diff --git a/cozo-core/src/data/value.rs b/cozo-core/src/data/value.rs index a883199e..cb531454 100644 --- a/cozo-core/src/data/value.rs +++ b/cozo-core/src/data/value.rs @@ -17,8 +17,9 @@ use serde::{Deserialize, Deserializer, Serialize}; use smartstring::{LazyCompact, SmartString}; use uuid::Uuid; +/// UUID value in the database #[derive(Clone, Hash, Eq, PartialEq, serde_derive::Deserialize, serde_derive::Serialize)] -pub struct UuidWrapper(pub(crate) Uuid); +pub struct UuidWrapper(pub Uuid); impl PartialOrd for UuidWrapper { fn partial_cmp(&self, other: &Self) -> Option { @@ -37,8 +38,9 @@ impl Ord for UuidWrapper { } } +/// A Regex in the database. Used internally in functions. #[derive(Clone)] -pub struct RegexWrapper(pub(crate) Regex); +pub struct RegexWrapper(pub Regex); impl Hash for RegexWrapper { fn hash(&self, state: &mut H) { @@ -84,6 +86,7 @@ impl PartialOrd for RegexWrapper { } } +/// Timestamp part of validity #[derive( Copy, Clone, @@ -98,6 +101,7 @@ impl PartialOrd for RegexWrapper { )] pub struct ValidityTs(pub Reverse); +/// Validity for time travel #[derive( Copy, Clone, @@ -110,25 +114,39 @@ pub struct ValidityTs(pub Reverse); Hash, )] pub struct Validity { - pub(crate) timestamp: ValidityTs, - pub(crate) is_assert: Reverse, + /// Timestamp, sorted descendingly + pub timestamp: ValidityTs, + /// Whether this validity is an assertion, sorted descendingly + pub is_assert: Reverse, } +/// A Value in the database #[derive( Clone, PartialEq, Eq, PartialOrd, Ord, serde_derive::Deserialize, serde_derive::Serialize, Hash, )] pub enum DataValue { + /// null Null, + /// boolean Bool(bool), + /// number, may be int or float Num(Num), + /// string Str(SmartString), + /// bytes #[serde(with = "serde_bytes")] Bytes(Vec), + /// UUID Uuid(UuidWrapper), + /// Regex, used internally only Regex(RegexWrapper), + /// list List(Vec), + /// set, used internally only Set(BTreeSet), + /// validity Validity(Validity), + /// bottom type, used internally only Bot, } @@ -162,9 +180,12 @@ impl From for DataValue { } } +/// Representing a number #[derive(Copy, Clone, serde_derive::Deserialize, serde_derive::Serialize)] pub enum Num { + /// intger number Int(i64), + /// float number Float(f64), } @@ -340,6 +361,7 @@ impl DataValue { _ => None, } } + /// Returns bool if this one is. pub fn get_bool(&self) -> Option { match self { DataValue::Bool(b) => Some(*b), diff --git a/cozo-core/src/lib.rs b/cozo-core/src/lib.rs index 0e6da4e4..4e5cde12 100644 --- a/cozo-core/src/lib.rs +++ b/cozo-core/src/lib.rs @@ -45,6 +45,7 @@ use miette::{ }; use serde_json::json; +pub use data::value::{DataValue, Num, RegexWrapper, UuidWrapper, Validity, ValidityTs}; pub use fixed_rule::FixedRule; pub use runtime::db::Db; pub use runtime::db::NamedRows; @@ -153,7 +154,7 @@ impl DbInstance { pub fn run_script( &self, payload: &str, - params: BTreeMap, + params: BTreeMap, ) -> Result { match self { DbInstance::Mem(db) => db.run_script(payload, params), @@ -172,7 +173,7 @@ impl DbInstance { pub fn run_script_fold_err( &self, payload: &str, - params: BTreeMap, + params: BTreeMap, ) -> JsonValue { #[cfg(not(target_arch = "wasm32"))] let start = Instant::now(); @@ -198,7 +199,10 @@ impl DbInstance { BTreeMap::default() } else { match serde_json::from_str::>(params) { - Ok(map) => map, + Ok(map) => map + .into_iter() + .map(|(k, v)| (k, DataValue::from(v))) + .collect(), Err(_) => { return json!({"ok": false, "message": "params argument is not a JSON map"}) .to_string() diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index 8423d0cc..30a58637 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -163,13 +163,9 @@ impl<'s, S: Storage<'s>> Db { pub fn run_script( &'s self, payload: &str, - params: BTreeMap, + params: BTreeMap, ) -> Result { let cur_vld = current_validity(); - let params = params - .into_iter() - .map(|(k, v)| (k, DataValue::from(v))) - .collect(); self.do_run_script(payload, ¶ms, cur_vld) } /// Export relations to JSON data. diff --git a/cozo-lib-python/src/lib.rs b/cozo-lib-python/src/lib.rs index f263dbdc..685fe715 100644 --- a/cozo-lib-python/src/lib.rs +++ b/cozo-lib-python/src/lib.rs @@ -8,9 +8,95 @@ use pyo3::exceptions::PyException; use pyo3::prelude::*; +use pyo3::types::{PyDict, PyList}; +use std::collections::BTreeMap; use cozo::*; +fn py_to_value(ob: &PyAny) -> PyResult { + Ok(if ob.is_none() { + DataValue::Null + } else if let Ok(i) = ob.extract::() { + DataValue::from(i) + } else if let Ok(f) = ob.extract::() { + DataValue::from(f) + } else if let Ok(s) = ob.extract::() { + DataValue::from(s) + } else if let Ok(b) = ob.extract::>() { + DataValue::Bytes(b) + } else if let Ok(l) = ob.downcast::() { + let mut coll = Vec::with_capacity(l.len()); + for el in l { + let el = py_to_value(el)?; + coll.push(el) + } + DataValue::List(coll) + } else if let Ok(d) = ob.downcast::() { + let mut coll = Vec::with_capacity(d.len()); + for (k, v) in d { + let k = py_to_value(k)?; + let v = py_to_value(v)?; + coll.push(DataValue::List(vec![k, v])) + } + DataValue::List(coll) + } else { + return Err(PyException::new_err(format!("Cannot convert {} into Cozo value", ob)).into()); + }) +} + +fn convert_params(ob: &PyDict) -> PyResult> { + let mut ret = BTreeMap::new(); + for (k, v) in ob { + let k: String = k.extract()?; + let v = py_to_value(v)?; + ret.insert(k, v); + } + Ok(ret) +} + +fn value_to_py(val: DataValue, py: Python<'_>) -> PyObject { + match val { + DataValue::Null => py.None(), + DataValue::Bool(b) => b.into_py(py), + DataValue::Num(num) => match num { + Num::Int(i) => i.into_py(py), + Num::Float(f) => f.into_py(py), + }, + DataValue::Str(s) => s.as_str().into_py(py), + DataValue::Bytes(b) => b.into_py(py), + DataValue::Uuid(uuid) => uuid.0.to_string().into_py(py), + DataValue::Regex(rx) => rx.0.as_str().into_py(py), + DataValue::List(l) => { + let vs: Vec<_> = l.into_iter().map(|v| value_to_py(v, py)).collect(); + vs.into_py(py) + } + DataValue::Set(l) => { + let vs: Vec<_> = l.into_iter().map(|v| value_to_py(v, py)).collect(); + vs.into_py(py) + } + DataValue::Validity(vld) => { + [vld.timestamp.0 .0.into_py(py), vld.is_assert.0.into_py(py)].into_py(py) + } + DataValue::Bot => py.None(), + } +} + +fn named_rows_to_py(named_rows: NamedRows, py: Python<'_>) -> PyObject { + let rows = named_rows + .rows + .into_iter() + .map(|row| { + row.into_iter() + .map(|val| value_to_py(val, py)) + .collect::>() + .into_py(py) + }) + .collect::>() + .into_py(py); + let headers = named_rows.headers.into_py(py); + BTreeMap::from([("rows", rows), ("headers", headers)]).into_py(py) +} + #[pyclass] struct CozoDbPy { db: Option, @@ -27,6 +113,20 @@ impl CozoDbPy { Err(err) => Err(PyException::new_err(format!("{err:?}"))), } } + pub fn run_script(&self, py: Python<'_>, query: &str, params: &PyDict) -> PyResult { + if let Some(db) = &self.db { + let params = convert_params(params)?; + match py.allow_threads(|| db.run_script(query, params)) { + Ok(rows) => Ok(named_rows_to_py(rows, py)), + Err(err) => { + let reports = format_error_as_json(err, Some(query)).to_string(); + Err(PyException::new_err(reports)) + } + } + } else { + Err(PyException::new_err(DB_CLOSED_MSG)) + } + } pub fn run_query(&self, py: Python<'_>, query: &str, params: &str) -> String { if let Some(db) = &self.db { py.allow_threads(|| db.run_script_str(query, params))