diff --git a/cozo-core/src/data/expr.rs b/cozo-core/src/data/expr.rs index 2e0b4bb8..4700add4 100644 --- a/cozo-core/src/data/expr.rs +++ b/cozo-core/src/data/expr.rs @@ -606,6 +606,34 @@ impl Expr { } }) } + pub(crate) fn get_variables(&self) -> Result> { + let mut ret = BTreeSet::new(); + self.do_get_variables(&mut ret)?; + Ok(ret) + } + fn do_get_variables(&self, coll: &mut BTreeSet) -> Result<()> { + match self { + Expr::Binding { var, .. } => { + coll.insert(var.to_string()); + } + Expr::Const { .. } => {} + Expr::Apply { args, .. } => { + for arg in args.iter() { + arg.do_get_variables(coll)?; + } + } + Expr::Cond { clauses, .. } => { + for (cond, act) in clauses.iter() { + cond.do_get_variables(coll)?; + act.do_get_variables(coll)?; + } + } + Expr::UnboundApply { op, span, .. } => { + bail!(NoImplementationError(*span, op.to_string())); + } + } + Ok(()) + } pub(crate) fn to_var_list(&self) -> Result>> { match self { Expr::Apply { op, args, .. } => { diff --git a/cozo-core/src/lib.rs b/cozo-core/src/lib.rs index 55992690..44dbb4e8 100644 --- a/cozo-core/src/lib.rs +++ b/cozo-core/src/lib.rs @@ -75,6 +75,7 @@ pub use crate::runtime::callback::CallbackOp; pub use crate::runtime::db::Poison; pub use crate::runtime::db::TransactionPayload; pub use crate::runtime::db::evaluate_expressions; +pub use crate::runtime::db::get_variables; pub(crate) mod data; pub(crate) mod fixed_rule; diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index 064272be..a73cf661 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -1775,6 +1775,17 @@ pub fn evaluate_expressions( }) } +/// Get the variables referenced in a string expression +pub fn get_variables(src: &str, params: &BTreeMap) -> Result> { + _get_variables(src, params).map_err(|err| { + if err.source().is_none() { + err.with_source_code(format!("{src} ")) + } else { + err + } + }) +} + fn _evaluate_expressions( src: &str, params: &BTreeMap, @@ -1795,6 +1806,15 @@ fn _evaluate_expressions( Ok(ret) } +fn _get_variables(src: &str, params: &BTreeMap) -> Result> { + let mut exprs = parse_expressions(src, params)?; + let mut ret = BTreeSet::new(); + for expr in exprs.iter_mut() { + ret.extend(expr.get_variables()?); + } + Ok(ret) +} + /// Used for user-initiated termination of running queries #[derive(Clone, Default)] pub struct Poison(pub(crate) Arc); diff --git a/cozo-lib-python/src/lib.rs b/cozo-lib-python/src/lib.rs index 68d98c37..8e243030 100644 --- a/cozo-lib-python/src/lib.rs +++ b/cozo-lib-python/src/lib.rs @@ -6,7 +6,7 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use std::thread; use miette::{IntoDiagnostic, Report, Result}; @@ -417,10 +417,27 @@ fn eval_expressions( } } +#[pyfunction] +fn variables(py: Python<'_>, query: &str, params: &PyDict) -> PyResult> { + let params = convert_params(params).unwrap(); + match get_variables(query, ¶ms) { + Ok(rows) => Ok(rows), + Err(err) => { + let reports = format_error_as_json(err, Some(query)).to_string(); + let json_mod = py.import("json")?; + let loads_fn = json_mod.getattr("loads")?; + let args = PyTuple::new(py, [PyString::new(py, &reports)]); + let msg = loads_fn.call1(args)?; + Err(PyException::new_err(PyObject::from(msg))) + } + } +} + #[pymodule] fn cozo_embedded(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(eval_expressions, m)?)?; + m.add_function(wrap_pyfunction!(variables, m)?)?; Ok(()) }