diff --git a/Cargo.lock b/Cargo.lock index c013ace2..5853fb3b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -643,6 +643,7 @@ name = "cozo_py" version = "0.5.0" dependencies = [ "cozo", + "miette", "pyo3", ] diff --git a/cozo-core/src/fixed_rule/mod.rs b/cozo-core/src/fixed_rule/mod.rs index feca8c99..b41b9f6f 100644 --- a/cozo-core/src/fixed_rule/mod.rs +++ b/cozo-core/src/fixed_rule/mod.rs @@ -9,12 +9,14 @@ use std::collections::BTreeMap; use std::sync::Arc; +use crossbeam::channel::{bounded, Receiver, Sender}; #[allow(unused_imports)] use either::{Left, Right}; #[cfg(feature = "graph-algo")] use graph::prelude::{CsrLayout, DirectedCsrGraph, GraphBuilder}; use itertools::Itertools; use lazy_static::lazy_static; +use miette::IntoDiagnostic; #[allow(unused_imports)] use miette::{bail, ensure, Diagnostic, Report, Result}; use smartstring::{LazyCompact, SmartString}; @@ -586,6 +588,28 @@ impl SimpleFixedRule { rule: Box::new(rule), } } + /// Construct a SimpleFixedRule that uses channels for communication. + pub fn rule_with_channel( + return_arity: usize, + ) -> ( + Self, + Receiver<(Vec, JsonValue)>, + Sender>, + ) { + let (db2app_sender, db2app_receiver) = bounded(0); + let (app2db_sender, app2db_receiver) = bounded(0); + ( + Self { + return_arity, + rule: Box::new(move |inputs, options| -> Result { + db2app_sender.send((inputs, options)).into_diagnostic()?; + app2db_receiver.recv().into_diagnostic()? + }), + }, + db2app_receiver, + app2db_sender, + ) + } } impl FixedRule for SimpleFixedRule { @@ -635,7 +659,7 @@ impl FixedRule for SimpleFixedRule { let results: NamedRows = (self.rule)(inputs, options)?; for row in results.rows { #[derive(Debug, Error, Diagnostic)] - #[error("arity mismatch: expect {0}, got {1}")] + #[error("arity mismatch: expect {1}, got {2}")] #[diagnostic(code(parser::simple_fixed_rule_arity_mismatch))] struct ArityMismatch(#[label] SourceSpan, usize, usize); diff --git a/cozo-core/src/lib.rs b/cozo-core/src/lib.rs index e160a21a..a0d56975 100644 --- a/cozo-core/src/lib.rs +++ b/cozo-core/src/lib.rs @@ -36,6 +36,7 @@ use std::path::Path; #[allow(unused_imports)] use std::time::Instant; +use crossbeam::channel::Receiver; use lazy_static::lazy_static; pub use miette::Error; use miette::Report; @@ -375,22 +376,24 @@ impl DbInstance { self.import_from_backup(&json_payload.path, &json_payload.relations) } + /// Dispatcher method. See [crate::Db::register_callback]. #[cfg(not(target_arch = "wasm32"))] - pub fn register_callback(&self, relation: &str, callback: CB) -> Result - where - CB: Fn(CallbackOp, NamedRows, NamedRows) + Send + Sync + 'static, - { + pub fn register_callback( + &self, + relation: &str, + capacity: Option, + ) -> (u32, Receiver<(CallbackOp, NamedRows, NamedRows)>) { match self { - DbInstance::Mem(db) => db.register_callback(relation, callback), + DbInstance::Mem(db) => db.register_callback(relation, capacity), #[cfg(feature = "storage-sqlite")] - DbInstance::Sqlite(db) => db.register_callback(relation, callback), + DbInstance::Sqlite(db) => db.register_callback(relation, capacity), #[cfg(feature = "storage-rocksdb")] - DbInstance::RocksDb(db) => db.register_callback(relation, callback), + DbInstance::RocksDb(db) => db.register_callback(relation, capacity), #[cfg(feature = "storage-sled")] - DbInstance::Sled(db) => db.register_callback(relation, callback), + DbInstance::Sled(db) => db.register_callback(relation, capacity), #[cfg(feature = "storage-tikv")] - DbInstance::TiKv(db) => db.register_callback(relation, callback), + DbInstance::TiKv(db) => db.register_callback(relation, capacity), } } @@ -410,11 +413,7 @@ impl DbInstance { } } /// Dispatcher method. See [crate::Db::register_fixed_rule]. - pub fn register_fixed_rule( - &self, - name: String, - rule_impl: Box, - ) -> Result<()> { + pub fn register_fixed_rule(&self, name: String, rule_impl: R) -> Result<()> where R: FixedRule + 'static { match self { DbInstance::Mem(db) => db.register_fixed_rule(name, rule_impl), #[cfg(feature = "storage-sqlite")] diff --git a/cozo-core/src/runtime/callback.rs b/cozo-core/src/runtime/callback.rs index 075574be..1f7fdcfc 100644 --- a/cozo-core/src/runtime/callback.rs +++ b/cozo-core/src/runtime/callback.rs @@ -9,6 +9,7 @@ use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Display, Formatter}; +use crossbeam::channel::Sender; use smartstring::{LazyCompact, SmartString}; use crate::{Db, NamedRows, Storage}; @@ -31,10 +32,20 @@ impl Display for CallbackOp { } } +impl CallbackOp { + /// Get the string representation + pub fn as_str(&self) -> &'static str { + match self { + CallbackOp::Put => "Put", + CallbackOp::Rm => "Rm" + } + } +} + #[cfg(not(target_arch = "wasm32"))] pub struct CallbackDeclaration { pub(crate) dependent: SmartString, - pub(crate) callback: Box, + pub(crate) sender: Sender<(CallbackOp, NamedRows, NamedRows)>, } pub(crate) type CallbackCollector = @@ -66,11 +77,40 @@ impl<'s, S: Storage<'s>> Db { } #[cfg(not(target_arch = "wasm32"))] pub(crate) fn send_callbacks(&'s self, collector: CallbackCollector) { - for (k, vals) in collector { + let mut to_remove = vec![]; + + for (table, vals) in collector { for (op, new, old) in vals { - self.callback_sender - .send((k.clone(), op, new, old)) - .expect("sending to callback processor failed"); + let (cbs, cb_dir) = &*self.event_callbacks.read().unwrap(); + if let Some(cb_ids) = cb_dir.get(&table) { + let mut it = cb_ids.iter(); + if let Some(fst) = it.next() { + for cb_id in it { + if let Some(cb) = cbs.get(cb_id) { + if cb.sender.send((op, new.clone(), old.clone())).is_err() { + to_remove.push(*cb_id) + } + } + } + + if let Some(cb) = cbs.get(fst) { + if cb.sender.send((op, new, old)).is_err() { + to_remove.push(*fst) + } + } + } + } + } + } + + if !to_remove.is_empty() { + let (cbs, cb_dir) = &mut *self.event_callbacks.write().unwrap(); + for removing_id in &to_remove { + if let Some(removed) = cbs.remove(removing_id) { + if let Some(set) = cb_dir.get_mut(&removed.dependent) { + set.remove(removing_id); + } + } } } } diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index cc08f51f..6f86a45a 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -20,8 +20,7 @@ use std::thread; #[allow(unused_imports)] use std::time::{Duration, SystemTime, UNIX_EPOCH}; -#[cfg(not(target_arch = "wasm32"))] -use crossbeam::channel::{unbounded, Sender}; +use crossbeam::channel::{bounded, unbounded, Receiver}; use crossbeam::sync::ShardedLock; use either::{Left, Right}; use itertools::Itertools; @@ -92,10 +91,7 @@ pub struct Db { #[cfg(not(target_arch = "wasm32"))] callback_count: AtomicU32, #[cfg(not(target_arch = "wasm32"))] - pub(crate) callback_sender: - Sender<(SmartString, CallbackOp, NamedRows, NamedRows)>, - #[cfg(not(target_arch = "wasm32"))] - pub(crate) event_callbacks: Arc>, + pub(crate) event_callbacks: ShardedLock, relation_locks: ShardedLock, Arc>>>, } @@ -183,8 +179,6 @@ impl<'s, S: Storage<'s>> Db { /// You must call [`initialize`](Self::initialize) immediately after creation. /// Due to lifetime restrictions we are not able to call that for you automatically. pub fn new(storage: S) -> Result { - #[cfg(not(target_arch = "wasm32"))] - let (sender, receiver) = unbounded(); let ret = Self { db: storage, temp_db: Default::default(), @@ -194,29 +188,11 @@ impl<'s, S: Storage<'s>> Db { fixed_rules: ShardedLock::new(DEFAULT_FIXED_RULES.clone()), #[cfg(not(target_arch = "wasm32"))] callback_count: Default::default(), - #[cfg(not(target_arch = "wasm32"))] - callback_sender: sender, // callback_receiver: Arc::new(receiver), #[cfg(not(target_arch = "wasm32"))] event_callbacks: Default::default(), relation_locks: Default::default(), }; - #[cfg(not(target_arch = "wasm32"))] - { - let callbacks = ret.event_callbacks.clone(); - thread::spawn(move || { - while let Ok((table, op, new, old)) = receiver.recv() { - let (cbs, cb_dir) = &*callbacks.read().unwrap(); - if let Some(cb_ids) = cb_dir.get(&table) { - for cb_id in cb_ids { - if let Some(cb) = cbs.get(cb_id) { - (cb.callback)(op, new.clone(), old.clone()) - } - } - } - } - }); - } Ok(ret) } @@ -548,10 +524,13 @@ impl<'s, S: Storage<'s>> Db { } } /// Register a custom fixed rule implementation. - pub fn register_fixed_rule(&self, name: String, rule_impl: Box) -> Result<()> { + pub fn register_fixed_rule(&self, name: String, rule_impl: R) -> Result<()> + where + R: FixedRule + 'static, + { match self.fixed_rules.write().unwrap().entry(name) { Entry::Vacant(ent) => { - ent.insert(Arc::new(rule_impl)); + ent.insert(Arc::new(Box::new(rule_impl))); Ok(()) } Entry::Occupied(ent) => { @@ -571,31 +550,22 @@ impl<'s, S: Storage<'s>> Db { Ok(self.fixed_rules.write().unwrap().remove(name).is_some()) } - /// Register callbacks to run when changes to the requested relation are successfully committed. - /// The returned ID can be used to unregister the callbacks. - /// - /// When any mutation is made against the registered relation, - /// the callback will be called with three arguments: - /// - /// * The type of the mutation (`Put` or `Rm`) - /// * The incoming rows - /// * For `Put`, the rows inserted / upserted - /// * For `Rm`, the keys of the rows to be removed. These keys do not necessarily existed in the database. - /// * The evicted rows - /// * For `Put`, old rows that were updated - /// * For `Rm`, the rows actually removed. - /// - /// The database has a single dedicated thread for executing callbacks no matter how many callbacks - /// are registered. You should not perform blocking operations in the callback provided to this function. - /// If blocking operations are required, send the data to some other thread and perform the operations there. + /// Register callback channel to receive changes when the requested relation are successfully committed. + /// The returned ID can be used to unregister the callback channel. #[cfg(not(target_arch = "wasm32"))] - pub fn register_callback(&self, relation: &str, callback: CB) -> Result - where - CB: Fn(CallbackOp, NamedRows, NamedRows) + Send + Sync + 'static, - { + pub fn register_callback( + &self, + relation: &str, + capacity: Option, + ) -> (u32, Receiver<(CallbackOp, NamedRows, NamedRows)>) { + let (sender, receiver) = if let Some(c) = capacity { + bounded(c) + } else { + unbounded() + }; let cb = CallbackDeclaration { dependent: SmartString::from(relation), - callback: Box::new(callback), + sender: sender, }; let mut guard = self.event_callbacks.write().unwrap(); @@ -607,10 +577,10 @@ impl<'s, S: Storage<'s>> Db { .insert(new_id); guard.0.insert(new_id, cb); - Ok(new_id) + (new_id, receiver) } - /// Unregister callbacks to run when changes to relations are committed. + /// Unregister callbacks/channels to run when changes to relations are committed. #[cfg(not(target_arch = "wasm32"))] pub fn unregister_callback(&self, id: u32) -> bool { let mut guard = self.event_callbacks.write().unwrap(); diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index c47d0097..5068fbe1 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -8,7 +8,6 @@ */ use std::collections::BTreeMap; -use std::sync::{Arc, Mutex}; use std::time::Duration; use itertools::Itertools; @@ -444,12 +443,8 @@ fn test_trigger() { #[test] fn test_callback() { let db = new_cozo_mem().unwrap(); - let collected = Arc::new(Mutex::new(vec![])); - let copy = collected.clone(); - db.register_callback("friends", move |op, new, old| { - copy.lock().unwrap().push((op, new, old)) - }) - .unwrap(); + let mut collected = vec![]; + let (_id, receiver) = db.register_callback("friends", None).unwrap(); db.run_script( ":create friends {fr: Int, to: Int => data: Any}", Default::default(), @@ -471,7 +466,10 @@ fn test_callback() { ) .unwrap(); std::thread::sleep(Duration::from_secs_f64(0.01)); - let collected = collected.lock().unwrap().clone(); + while let Ok(d) = receiver.try_recv() { + collected.push(d); + } + let collected = collected; assert_eq!(collected[0].0, CallbackOp::Put); assert_eq!(collected[0].1.rows.len(), 2); assert_eq!(collected[0].1.rows[0].len(), 3); @@ -578,7 +576,7 @@ fn test_index() { #[test] fn test_custom_rules() { - let mut db = new_cozo_mem().unwrap(); + let db = new_cozo_mem().unwrap(); struct Custom; impl FixedRule for Custom { diff --git a/cozo-lib-python/Cargo.toml b/cozo-lib-python/Cargo.toml index 551dec67..dda9216a 100644 --- a/cozo-lib-python/Cargo.toml +++ b/cozo-lib-python/Cargo.toml @@ -43,3 +43,4 @@ io-uring = ["cozo/io-uring"] [dependencies] cozo = { version = "0.5.0", path = "../cozo-core", default-features = false } pyo3 = { version = "0.17.1", features = ["extension-module", "abi3", "abi3-py37"] } +miette = "5.5.0" \ No newline at end of file diff --git a/cozo-lib-python/src/lib.rs b/cozo-lib-python/src/lib.rs index aedfa8a2..077b6ef3 100644 --- a/cozo-lib-python/src/lib.rs +++ b/cozo-lib-python/src/lib.rs @@ -6,13 +6,26 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ +use std::collections::BTreeMap; +use std::thread; + +use miette::{IntoDiagnostic, Result}; use pyo3::exceptions::PyException; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList}; -use std::collections::BTreeMap; +use pyo3::types::{PyDict, PyList, PyString, PyTuple}; use cozo::*; +fn py_to_named_rows(ob: &PyAny) -> PyResult { + let rows = ob.extract::>>()?; + let res: Vec> = rows + .into_iter() + .map(|row| row.into_iter().map(py_to_value).collect::>()) + .collect::>()?; + + Ok(NamedRows::new(vec![], res)) +} + fn py_to_value(ob: &PyAny) -> PyResult { Ok(if ob.is_none() { DataValue::Null @@ -40,7 +53,9 @@ fn py_to_value(ob: &PyAny) -> PyResult { } DataValue::List(coll) } else { - return Err(PyException::new_err(format!("Cannot convert {ob} into Cozo value"))); + return Err(PyException::new_err(format!( + "Cannot convert {ob} into Cozo value" + ))); }) } @@ -81,10 +96,8 @@ fn value_to_py(val: DataValue, py: Python<'_>) -> PyObject { } } -fn named_rows_to_py(named_rows: NamedRows, py: Python<'_>) -> PyObject { - let rows = named_rows - .rows - .into_iter() +fn rows_to_py_rows(rows: Vec>, py: Python<'_>) -> PyObject { + rows.into_iter() .map(|row| { row.into_iter() .map(|val| value_to_py(val, py)) @@ -92,7 +105,11 @@ fn named_rows_to_py(named_rows: NamedRows, py: Python<'_>) -> PyObject { .into_py(py) }) .collect::>() - .into_py(py); + .into_py(py) +} + +fn named_rows_to_py(named_rows: NamedRows, py: Python<'_>) -> PyObject { + let rows = rows_to_py_rows(named_rows.rows, py); let headers = named_rows.headers.into_py(py); BTreeMap::from([("rows", rows), ("headers", headers)]).into_py(py) } @@ -130,22 +147,86 @@ impl CozoDbPy { pub fn register_callback(&self, rel: &str, callback: &PyAny) -> PyResult { if let Some(db) = &self.db { let cb: Py = callback.into(); - match db.register_callback(rel, move |op, new, old| { - Python::with_gil(|py| { - let callable = cb.as_ref(py); - let _ = callable.call0(); - }) - }) { - Ok(id) => Ok(id), - Err(err) => { - let reports = format_error_as_json(err, None).to_string(); - Err(PyException::new_err(reports)) + let (id, ch) = db.register_callback(rel, None); + thread::spawn(move || { + for (op, new, old) in ch { + Python::with_gil(|py| { + let op = PyString::new(py, op.as_str()).into(); + let new_py = rows_to_py_rows(new.rows, py); + let old_py = rows_to_py_rows(old.rows, py); + let args = PyTuple::new(py, [op, new_py, old_py]); + let callable = cb.as_ref(py); + if let Err(err) = callable.call1(args) { + eprintln!("{}", err); + } + }) + } + }); + Ok(id) + } else { + Err(PyException::new_err(DB_CLOSED_MSG)) + } + } + pub fn register_fixed_rule( + &self, + name: String, + arity: usize, + callback: &PyAny, + ) -> PyResult<()> { + if let Some(db) = &self.db { + let cb: Py = callback.into(); + let (rule_impl, receiver, sender) = SimpleFixedRule::rule_with_channel(arity); + match db.register_fixed_rule(name, rule_impl) { + Ok(_) => { + thread::spawn(move || { + for (inputs, options) in receiver { + let res = Python::with_gil(|py| -> Result { + let json_convert = + PyModule::import(py, "json").into_diagnostic()?; + let json_convert: Py = + json_convert.getattr("loads").into_diagnostic()?.into(); + let py_inputs = PyList::new( + py, + inputs.into_iter().map(|nr| rows_to_py_rows(nr.rows, py)), + ); + let opts_str = PyString::new(py, &options.to_string()); + let args = PyTuple::new(py, vec![opts_str]); + let py_opts = json_convert.call1(py, args).into_diagnostic()?; + let args = + PyTuple::new(py, vec![PyObject::from(py_inputs), py_opts]); + let res = cb.as_ref(py).call1(args).into_diagnostic()?; + py_to_named_rows(res).into_diagnostic() + }); + if sender.send(res).is_err() { + break; + } + } + }); + Ok(()) } + Err(err) => Err(PyException::new_err(err.to_string())), } } else { Err(PyException::new_err(DB_CLOSED_MSG)) } } + pub fn unregister_callback(&self, id: u32) -> bool { + if let Some(db) = &self.db { + db.unregister_callback(id) + } else { + false + } + } + pub fn unregister_fixed_rule(&self, name: &str) -> PyResult { + if let Some(db) = &self.db { + match db.unregister_fixed_rule(name) { + Ok(b) => Ok(b), + Err(err) => Err(PyException::new_err(err.to_string())), + } + } else { + Ok(false) + } + } pub fn run_query(&self, py: Python<'_>, query: &str, params: &str) -> String { if let Some(db) = &self.db { py.allow_threads(|| db.run_script_str(query, params))