From 974dfc2e2a7d2ec86ec2e869e980a91e6843ecf7 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Tue, 16 May 2023 00:11:13 +0800 Subject: [PATCH] evaluate a list of expressions --- cozo-bin/src/repl.rs | 8 ++- cozo-core/src/cozoscript.pest | 4 +- cozo-core/src/data/expr.rs | 1 - cozo-core/src/lib.rs | 3 +- cozo-core/src/parse/mod.rs | 44 +++++++++++++- cozo-core/src/query/stored.rs | 6 +- cozo-core/src/runtime/db.rs | 111 +++++++++++++++++++++++----------- cozo-lib-python/src/lib.rs | 23 +++++++ 8 files changed, 155 insertions(+), 45 deletions(-) diff --git a/cozo-bin/src/repl.rs b/cozo-bin/src/repl.rs index b1eeebda..f51edda9 100644 --- a/cozo-bin/src/repl.rs +++ b/cozo-bin/src/repl.rs @@ -20,7 +20,7 @@ use rustyline::history::DefaultHistory; use rustyline::Changeset; use serde_json::{json, Value}; -use cozo::{DataValue, DbInstance, NamedRows}; +use cozo::{DataValue, DbInstance, evaluate_expressions, NamedRows}; struct Indented; @@ -208,6 +208,12 @@ fn process_line( .split_once(|c: char| c.is_whitespace()) .unwrap_or((remaining, "")); match op { + "eval" => { + let out = evaluate_expressions(payload, params, params)?; + for val in out { + println!("{val}"); + } + } "set" => { let (key, v_str) = payload .trim() diff --git a/cozo-core/src/cozoscript.pest b/cozo-core/src/cozoscript.pest index d4c54d69..c7f7d67b 100644 --- a/cozo-core/src/cozoscript.pest +++ b/cozo-core/src/cozoscript.pest @@ -266,4 +266,6 @@ fts_expr = {fts_term ~ (fts_op ~ fts_term)*} fts_op = _{fts_and | fts_or | fts_not} fts_and = {"AND"} fts_or = {"OR" | "," | ";"} -fts_not = {"NOT"} \ No newline at end of file +fts_not = {"NOT"} + +expression_script = {SOI ~ expr ~ ("," ~ expr)* ~ EOI} \ No newline at end of file diff --git a/cozo-core/src/data/expr.rs b/cozo-core/src/data/expr.rs index fb230b3b..2e0b4bb8 100644 --- a/cozo-core/src/data/expr.rs +++ b/cozo-core/src/data/expr.rs @@ -335,7 +335,6 @@ impl Expr { #[derive(Debug, Error, Diagnostic)] #[error("Cannot find binding {0}")] #[diagnostic(code(eval::bad_binding))] - #[diagnostic(help("This could indicate a system problem"))] struct BadBindingError(String, #[label] SourceSpan); let found_idx = *binding_map diff --git a/cozo-core/src/lib.rs b/cozo-core/src/lib.rs index bf95a421..55992690 100644 --- a/cozo-core/src/lib.rs +++ b/cozo-core/src/lib.rs @@ -74,6 +74,7 @@ pub use crate::parse::SourceSpan; pub use crate::runtime::callback::CallbackOp; pub use crate::runtime::db::Poison; pub use crate::runtime::db::TransactionPayload; +pub use crate::runtime::db::evaluate_expressions; pub(crate) mod data; pub(crate) mod fixed_rule; @@ -543,7 +544,7 @@ impl MultiTransaction { pub fn format_error_as_json(mut err: Report, source: Option<&str>) -> JsonValue { if err.source_code().is_none() { if let Some(src) = source { - err = err.with_source_code(src.to_string()); + err = err.with_source_code(format!("{src} ")); } } let mut text_err = String::new(); diff --git a/cozo-core/src/parse/mod.rs b/cozo-core/src/parse/mod.rs index 42190ac2..5b9651ce 100644 --- a/cozo-core/src/parse/mod.rs +++ b/cozo-core/src/parse/mod.rs @@ -21,18 +21,19 @@ use thiserror::Error; use crate::data::program::InputProgram; use crate::data::relation::NullableColType; use crate::data::value::{DataValue, ValidityTs}; +use crate::parse::expr::build_expr; use crate::parse::imperative::parse_imperative_block; use crate::parse::query::parse_query; use crate::parse::schema::parse_nullable_type; use crate::parse::sys::{parse_sys, SysOp}; -use crate::FixedRule; +use crate::{Expr, FixedRule}; pub(crate) mod expr; +pub(crate) mod fts; pub(crate) mod imperative; pub(crate) mod query; pub(crate) mod schema; pub(crate) mod sys; -pub(crate) mod fts; #[derive(pest_derive::Parser)] #[grammar = "cozoscript.pest"] @@ -209,6 +210,34 @@ pub(crate) fn parse_type(src: &str) -> Result { parse_nullable_type(parsed.into_inner().next().unwrap()) } +pub(crate) fn parse_expressions( + src: &str, + param_pool: &BTreeMap, +) -> Result> { + let mut ret = vec![]; + let parsed = CozoScriptParser::parse(Rule::expression_script, src) + .map_err(|err| { + let span = match err.location { + InputLocation::Pos(p) => SourceSpan(p, 0), + InputLocation::Span((start, end)) => SourceSpan(start, end - start), + }; + ParseError { span } + })? + .next() + .unwrap() + .into_inner(); + for rule in parsed { + match rule.as_rule() { + Rule::expr => { + ret.push(build_expr(rule, param_pool)?); + } + Rule::EOI => {} + _ => unreachable!(), + } + } + Ok(ret) +} + pub(crate) fn parse_script( src: &str, param_pool: &BTreeMap, @@ -257,3 +286,14 @@ impl ExtractSpan for Pair<'_> { SourceSpan(start, end - start) } } + +#[cfg(test)] +mod tests { + use crate::parse::parse_expressions; + + #[test] + fn test_expressions() { + let x = parse_expressions("null, 1, 2, 3, 5, 6 > 7", &Default::default()).unwrap(); + println!("{:?}", x); + } +} diff --git a/cozo-core/src/query/stored.rs b/cozo-core/src/query/stored.rs index 4ee2dd0a..a0807322 100644 --- a/cozo-core/src/query/stored.rs +++ b/cozo-core/src/query/stored.rs @@ -104,7 +104,7 @@ impl<'a> SessionTx<'a> { if err.source_code().is_some() { err } else { - err.with_source_code(trigger.to_string()) + err.with_source_code(format!("{trigger}" )) } })?; to_clear.extend(cleanups); @@ -729,7 +729,7 @@ impl<'a> SessionTx<'a> { if err.source_code().is_some() { err } else { - err.with_source_code(trigger.to_string()) + err.with_source_code(format!("{trigger} ")) } })?; to_clear.extend(cleanups); @@ -1059,7 +1059,7 @@ impl<'a> SessionTx<'a> { if err.source_code().is_some() { err } else { - err.with_source_code(trigger.to_string()) + err.with_source_code(format!("{trigger} ")) } })?; to_clear.extend(cleanups); diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index fd932c9a..064272be 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -41,7 +41,7 @@ use crate::data::value::{DataValue, ValidityTs, LARGEST_UTF_CHAR}; use crate::fixed_rule::DEFAULT_FIXED_RULES; use crate::fts::TokenizerCache; use crate::parse::sys::SysOp; -use crate::parse::{parse_script, CozoScript, SourceSpan}; +use crate::parse::{parse_expressions, parse_script, CozoScript, SourceSpan}; use crate::query::compile::{CompiledProgram, CompiledRule, CompiledRuleSet}; use crate::query::ra::{ FilteredRA, FtsSearchRA, HnswSearchRA, InnerJoin, LshSearchRA, NegJoin, RelAlgebra, ReorderRA, @@ -51,11 +51,13 @@ use crate::query::ra::{ use crate::runtime::callback::{ CallbackCollector, CallbackDeclaration, CallbackOp, EventCallbackRegistry, }; -use crate::runtime::relation::{extend_tuple_from_v, AccessLevel, InsufficientAccessLevel, RelationHandle, RelationId}; +use crate::runtime::relation::{ + extend_tuple_from_v, AccessLevel, InsufficientAccessLevel, RelationHandle, RelationId, +}; use crate::runtime::transact::SessionTx; use crate::storage::temp::TempStorage; use crate::storage::{Storage, StoreTx}; -use crate::{decode_tuple_from_kv, FixedRule}; +use crate::{decode_tuple_from_kv, FixedRule, Symbol}; pub(crate) struct RunningQueryHandle { pub(crate) started_at: f64, @@ -375,9 +377,9 @@ impl<'s, S: Storage<'s>> Db { /// /// `relations` contains names of the stored relations to export. pub fn export_relations(&'s self, relations: I) -> Result> - where - T: AsRef, - I: Iterator, + where + T: AsRef, + I: Iterator, { let tx = self.transact()?; let mut ret: BTreeMap = BTreeMap::new(); @@ -687,8 +689,8 @@ impl<'s, S: Storage<'s>> Db { } /// Register a custom fixed rule implementation. pub fn register_fixed_rule(&self, name: String, rule_impl: R) -> Result<()> - where - R: FixedRule + 'static, + where + R: FixedRule + 'static, { match self.fixed_rules.write().unwrap().entry(name) { Entry::Vacant(ent) => { @@ -757,7 +759,7 @@ impl<'s, S: Storage<'s>> Db { ret.is_some() } - pub(crate) fn obtain_relation_locks<'a, T: Iterator>>( + pub(crate) fn obtain_relation_locks<'a, T: Iterator>>( &'s self, rels: T, ) -> Vec>> { @@ -829,7 +831,7 @@ impl<'s, S: Storage<'s>> Db { callback_collector: &mut CallbackCollector, ) -> Result { #[allow(unused_variables)] - let sleep_opt = p.out_opts.sleep; + let sleep_opt = p.out_opts.sleep; let (q_res, q_cleanups) = self.run_query(tx, p, cur_vld, callback_targets, callback_collector, true)?; cleanups.extend(q_cleanups); @@ -968,28 +970,28 @@ impl<'s, S: Storage<'s>> Db { ("fixed", json!(null), json!(null), json!(null)) } RelAlgebra::TempStore(TempStoreRA { - storage_key, - filters, - .. - }) => ( + storage_key, + filters, + .. + }) => ( "load_mem", json!(storage_key.to_string()), json!(null), json!(filters.iter().map(|f| f.to_string()).collect_vec()), ), RelAlgebra::Stored(StoredRA { - storage, filters, .. - }) => ( + storage, filters, .. + }) => ( "load_stored", json!(format!(":{}", storage.name)), json!(null), json!(filters.iter().map(|f| f.to_string()).collect_vec()), ), RelAlgebra::StoredWithValidity(StoredWithValidityRA { - storage, - filters, - .. - }) => ( + storage, + filters, + .. + }) => ( "load_stored_with_validity", json!(format!(":{}", storage.name)), json!(null), @@ -1028,10 +1030,10 @@ impl<'s, S: Storage<'s>> Db { ("reorder", json!(null), json!(null), json!(null)) } RelAlgebra::Filter(FilteredRA { - parent, - filters: pred, - .. - }) => { + parent, + filters: pred, + .. + }) => { rel_stack.push(parent); ( "filter", @@ -1041,12 +1043,12 @@ impl<'s, S: Storage<'s>> Db { ) } RelAlgebra::Unification(UnificationRA { - parent, - binding, - expr, - is_multi, - .. - }) => { + parent, + binding, + expr, + is_multi, + .. + }) => { rel_stack.push(parent); ( if *is_multi { "multi-unify" } else { "unify" }, @@ -1056,8 +1058,8 @@ impl<'s, S: Storage<'s>> Db { ) } RelAlgebra::HnswSearch(HnswSearchRA { - hnsw_search, .. - }) => ( + hnsw_search, .. + }) => ( "hnsw_index", json!(format!(":{}", hnsw_search.query.name)), json!(hnsw_search.query.name), @@ -1440,7 +1442,7 @@ impl<'s, S: Storage<'s>> Db { if let Some(tuple) = result_store.all_iter().next() { #[derive(Debug, Error, Diagnostic)] #[error( - "The query is asserted to return no result, but a tuple {0:?} is found" + "The query is asserted to return no result, but a tuple {0:?} is found" )] #[diagnostic(code(eval::assert_none_failure))] struct AssertNoneFailure(Tuple, #[label] SourceSpan); @@ -1493,7 +1495,8 @@ impl<'s, S: Storage<'s>> Db { ) .wrap_err_with(|| format!("when executing against relation '{}'", meta.name))?; clean_ups.extend(to_clear); - let returned_rows = tx.get_returning_rows(callback_collector, &meta.name, returning)?; + let returned_rows = + tx.get_returning_rows(callback_collector, &meta.name, returning)?; Ok((returned_rows, clean_ups)) } else { // not sorting outputs @@ -1548,7 +1551,8 @@ impl<'s, S: Storage<'s>> Db { ) .wrap_err_with(|| format!("when executing against relation '{}'", meta.name))?; clean_ups.extend(to_clear); - let returned_rows = tx.get_returning_rows(callback_collector, &meta.name, returning)?; + let returned_rows = + tx.get_returning_rows(callback_collector, &meta.name, returning)?; Ok((returned_rows, clean_ups)) } else { @@ -1756,6 +1760,41 @@ impl<'s, S: Storage<'s>> Db { } } +/// Evaluate a string expression in the context of a set of parameters and variables +pub fn evaluate_expressions( + src: &str, + params: &BTreeMap, + vars: &BTreeMap, +) -> Result> { + _evaluate_expressions(src, params, vars).map_err(|err| { + if err.source().is_none() { + err.with_source_code(format!("{src} ")) + } else { + err + } + }) +} + +fn _evaluate_expressions( + src: &str, + params: &BTreeMap, + vars: &BTreeMap, +) -> Result> { + let mut exprs = parse_expressions(src, params)?; + let mut ctx = vec![]; + let mut binding_map = BTreeMap::new(); + for (i, (k, v)) in vars.iter().enumerate() { + ctx.push(v.clone()); + binding_map.insert(Symbol::new(k, Default::default()), i); + } + let mut ret = vec![]; + for expr in exprs.iter_mut() { + expr.fill_binding_indices(&binding_map)?; + ret.push(expr.eval(&ctx)?); + } + Ok(ret) +} + /// Used for user-initiated termination of running queries #[derive(Clone, Default)] pub struct Poison(pub(crate) Arc); @@ -1792,7 +1831,7 @@ impl Poison { pub(crate) fn seconds_since_the_epoch() -> Result { #[cfg(not(target_arch = "wasm32"))] - let now = SystemTime::now(); + let now = SystemTime::now(); #[cfg(not(target_arch = "wasm32"))] return Ok(now .duration_since(UNIX_EPOCH) diff --git a/cozo-lib-python/src/lib.rs b/cozo-lib-python/src/lib.rs index bd0381e6..68d98c37 100644 --- a/cozo-lib-python/src/lib.rs +++ b/cozo-lib-python/src/lib.rs @@ -395,9 +395,32 @@ impl CozoDbMulTx { } } +#[pyfunction] +fn eval_expressions( + py: Python<'_>, + query: &str, + params: &PyDict, + bindings: &PyDict, +) -> PyResult> { + let params = convert_params(params).unwrap(); + let bindings = convert_params(bindings).unwrap(); + match evaluate_expressions(query, ¶ms, &bindings) { + Ok(rows) => Ok(rows.into_iter().map(|r| value_to_py(r, py)).collect()), + Err(err) => { + let reports = format_error_as_json(err, Some(query)).to_string(); + let json_mod = py.import("json")?; + let loads_fn = json_mod.getattr("loads")?; + let args = PyTuple::new(py, [PyString::new(py, &reports)]); + let msg = loads_fn.call1(args)?; + Err(PyException::new_err(PyObject::from(msg))) + } + } +} + #[pymodule] fn cozo_embedded(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_function(wrap_pyfunction!(eval_expressions, m)?)?; Ok(()) }