From 6411e42e9ce991beddef99ba299d12ed5f35b95a Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Sat, 14 Jan 2023 15:23:02 +0800 Subject: [PATCH] improve python API --- cozo-core/src/lib.rs | 14 ++++-- cozo-core/src/runtime/db.rs | 16 ++++--- cozo-lib-python/src/lib.rs | 93 +++++++++++++++++++++++++------------ 3 files changed, 81 insertions(+), 42 deletions(-) diff --git a/cozo-core/src/lib.rs b/cozo-core/src/lib.rs index a0d56975..b504a933 100644 --- a/cozo-core/src/lib.rs +++ b/cozo-core/src/lib.rs @@ -222,10 +222,11 @@ impl DbInstance { self.run_script_fold_err(payload, params_json).to_string() } /// Dispatcher method. See [crate::Db::export_relations]. - pub fn export_relations<'a>( - &self, - relations: impl Iterator, - ) -> Result> { + pub fn export_relations<'a, I, T>(&self, relations: I) -> Result> + where + T: AsRef, + I: Iterator, + { match self { DbInstance::Mem(db) => db.export_relations(relations), #[cfg(feature = "storage-sqlite")] @@ -413,7 +414,10 @@ impl DbInstance { } } /// Dispatcher method. See [crate::Db::register_fixed_rule]. - pub fn register_fixed_rule(&self, name: String, rule_impl: R) -> Result<()> where R: FixedRule + 'static { + pub fn register_fixed_rule(&self, name: String, rule_impl: R) -> Result<()> + where + R: FixedRule + 'static, + { match self { DbInstance::Mem(db) => db.register_fixed_rule(name, rule_impl), #[cfg(feature = "storage-sqlite")] diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index 6f86a45a..dcf50cc9 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -118,7 +118,8 @@ pub struct NamedRows { pub headers: Vec, /// The rows pub rows: Vec, - pub(crate) next: Option>, + /// Contains the next named rows, if exists + pub next: Option>, } impl NamedRows { @@ -214,14 +215,15 @@ impl<'s, S: Storage<'s>> Db { /// Export relations to JSON data. /// /// `relations` contains names of the stored relations to export. - pub fn export_relations<'a>( - &'s self, - relations: impl Iterator, - ) -> Result> { + pub fn export_relations<'a, I, T>(&'s self, relations: I) -> Result> + where + T: AsRef, + I: Iterator, + { let tx = self.transact()?; let mut ret: BTreeMap = BTreeMap::new(); for rel in relations { - let handle = tx.get_relation(rel, false)?; + let handle = tx.get_relation(rel.as_ref(), false)?; if handle.access_level < AccessLevel::ReadOnly { bail!(InsufficientAccessLevel( @@ -256,7 +258,7 @@ impl<'s, S: Storage<'s>> Db { rows.push(tuple); } let headers = cols.iter().map(|col| col.to_string()).collect_vec(); - ret.insert(rel.to_string(), NamedRows::new(headers, rows)); + ret.insert(rel.as_ref().to_string(), NamedRows::new(headers, rows)); } Ok(ret) } diff --git a/cozo-lib-python/src/lib.rs b/cozo-lib-python/src/lib.rs index f768bd71..a754c102 100644 --- a/cozo-lib-python/src/lib.rs +++ b/cozo-lib-python/src/lib.rs @@ -9,21 +9,37 @@ use std::collections::BTreeMap; use std::thread; -use miette::{IntoDiagnostic, Result}; +use miette::{IntoDiagnostic, Report, Result}; use pyo3::exceptions::PyException; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString, PyTuple}; use cozo::*; -fn py_to_named_rows(ob: &PyAny) -> PyResult { +fn py_to_rows(ob: &PyAny) -> PyResult>> { let rows = ob.extract::>>()?; let res: Vec> = rows .into_iter() .map(|row| row.into_iter().map(py_to_value).collect::>()) .collect::>()?; + Ok(res) +} - Ok(NamedRows::new(vec![], res)) +fn report2py(r: Report) -> PyErr { + PyException::new_err(r.to_string()) +} + +fn py_to_named_rows(ob: &PyAny) -> PyResult { + let d = ob.downcast::()?; + let rows = d + .get_item("rows") + .ok_or_else(|| PyException::new_err("named rows must contain 'rows'"))?; + let rows = py_to_rows(rows)?; + let headers = d + .get_item("headers") + .ok_or_else(|| PyException::new_err("named rows must contain 'headers'"))?; + let headers = headers.extract::>()?; + Ok(NamedRows::new(headers, rows)) } fn py_to_value(ob: &PyAny) -> PyResult { @@ -76,7 +92,6 @@ fn options_to_py(opts: BTreeMap, py: Python<'_>) -> PyResult< let val = value_to_py(v, py); ret.set_item(k, val)?; } - Ok(ret.into()) } @@ -122,7 +137,11 @@ fn rows_to_py_rows(rows: Vec>, py: Python<'_>) -> PyObject { fn named_rows_to_py(named_rows: NamedRows, py: Python<'_>) -> PyObject { let rows = rows_to_py_rows(named_rows.rows, py); let headers = named_rows.headers.into_py(py); - BTreeMap::from([("rows", rows), ("headers", headers)]).into_py(py) + let next = match named_rows.next { + None => py.None(), + Some(nxt) => named_rows_to_py(*nxt, py), + }; + BTreeMap::from([("rows", rows), ("headers", headers), ("next", next)]).into_py(py) } #[pyclass] @@ -195,11 +214,10 @@ impl CozoDbPy { let py_opts = options_to_py(options, py).into_diagnostic()?; let args = PyTuple::new(py, vec![PyObject::from(py_inputs), py_opts]); let res = cb.as_ref(py).call1(args).into_diagnostic()?; - py_to_named_rows(res).into_diagnostic() + Ok(NamedRows::new(vec![], py_to_rows(res).into_diagnostic()?)) }) }); - db.register_fixed_rule(name, rule_impl) - .map_err(|err| PyException::new_err(err.to_string())) + db.register_fixed_rule(name, rule_impl).map_err(report2py) } else { Err(PyException::new_err(DB_CLOSED_MSG)) } @@ -221,46 +239,61 @@ impl CozoDbPy { Ok(false) } } - pub fn run_query(&self, py: Python<'_>, query: &str, params: &str) -> String { - if let Some(db) = &self.db { - py.allow_threads(|| db.run_script_str(query, params)) - } else { - DB_CLOSED_MSG.to_string() - } - } - pub fn export_relations(&self, py: Python<'_>, data: &str) -> String { + pub fn export_relations(&self, py: Python<'_>, relations: Vec) -> PyResult { if let Some(db) = &self.db { - py.allow_threads(|| db.export_relations_str(data)) + let res = match py.allow_threads(|| db.export_relations(relations.iter())) { + Ok(res) => res, + Err(err) => return Err(PyException::new_err(err.to_string())), + }; + let ret = PyDict::new(py); + for (k, v) in res { + ret.set_item(k, named_rows_to_py(v, py))?; + } + Ok(ret.into()) } else { - DB_CLOSED_MSG.to_string() + Err(PyException::new_err(DB_CLOSED_MSG.to_string())) } } - pub fn import_relations(&self, py: Python<'_>, data: &str) -> String { + pub fn import_relations(&self, py: Python<'_>, data: &PyDict) -> PyResult<()> { if let Some(db) = &self.db { - py.allow_threads(|| db.import_relations_str(data)) + let mut arg = BTreeMap::new(); + for (k, v) in data.iter() { + let k = k.extract::()?; + let vals = py_to_named_rows(v)?; + arg.insert(k, vals); + } + py.allow_threads(|| db.import_relations(arg)) + .map_err(report2py) } else { - DB_CLOSED_MSG.to_string() + Err(PyException::new_err(DB_CLOSED_MSG.to_string())) } } - pub fn backup(&self, py: Python<'_>, path: &str) -> String { + pub fn backup(&self, py: Python<'_>, path: &str) -> PyResult<()> { if let Some(db) = &self.db { - py.allow_threads(|| db.backup_db_str(path)) + py.allow_threads(|| db.backup_db(path)).map_err(report2py) } else { - DB_CLOSED_MSG.to_string() + Err(PyException::new_err(DB_CLOSED_MSG.to_string())) } } - pub fn restore(&self, py: Python<'_>, path: &str) -> String { + pub fn restore(&self, py: Python<'_>, path: &str) -> PyResult<()> { if let Some(db) = &self.db { - py.allow_threads(|| db.restore_backup_str(path)) + py.allow_threads(|| db.restore_backup(path)) + .map_err(report2py) } else { - DB_CLOSED_MSG.to_string() + Err(PyException::new_err(DB_CLOSED_MSG.to_string())) } } - pub fn import_from_backup(&self, py: Python<'_>, data: &str) -> String { + pub fn import_from_backup( + &self, + py: Python<'_>, + in_file: &str, + relations: Vec, + ) -> PyResult<()> { if let Some(db) = &self.db { - py.allow_threads(|| db.import_from_backup_str(data)) + py.allow_threads(|| db.import_from_backup(in_file, &relations)) + .map_err(report2py) } else { - DB_CLOSED_MSG.to_string() + Err(PyException::new_err(DB_CLOSED_MSG.to_string())) } } pub fn close(&mut self) -> bool {