Python callbacks and fixed rules

main
Ziyang Hu 2 years ago
parent a5ad6b5633
commit 5bb68486f5

1
Cargo.lock generated

@ -643,6 +643,7 @@ name = "cozo_py"
version = "0.5.0"
dependencies = [
"cozo",
"miette",
"pyo3",
]

@ -9,12 +9,14 @@
use std::collections::BTreeMap;
use std::sync::Arc;
use crossbeam::channel::{bounded, Receiver, Sender};
#[allow(unused_imports)]
use either::{Left, Right};
#[cfg(feature = "graph-algo")]
use graph::prelude::{CsrLayout, DirectedCsrGraph, GraphBuilder};
use itertools::Itertools;
use lazy_static::lazy_static;
use miette::IntoDiagnostic;
#[allow(unused_imports)]
use miette::{bail, ensure, Diagnostic, Report, Result};
use smartstring::{LazyCompact, SmartString};
@ -586,6 +588,28 @@ impl SimpleFixedRule {
rule: Box::new(rule),
}
}
/// Build a [`SimpleFixedRule`] whose evaluation is delegated over channels.
///
/// Returns the rule together with the two channel endpoints the application
/// side needs:
///
/// * a `Receiver` yielding the `(inputs, options)` pair of every invocation;
/// * a `Sender` through which the application hands back the result.
///
/// Both channels are created with capacity zero. A reply is not tagged with
/// the request it belongs to: the rule simply sends its request and then
/// waits for the next value arriving on the reply channel.
pub fn rule_with_channel(
    return_arity: usize,
) -> (
    Self,
    Receiver<(Vec<NamedRows>, JsonValue)>,
    Sender<Result<NamedRows>>,
) {
    let (request_tx, request_rx) = bounded::<(Vec<NamedRows>, JsonValue)>(0);
    let (reply_tx, reply_rx) = bounded::<Result<NamedRows>>(0);
    let rule = Self {
        return_arity,
        rule: Box::new(move |inputs, options| -> Result<NamedRows> {
            // Hand the invocation over to the application ...
            request_tx.send((inputs, options)).into_diagnostic()?;
            // ... and block until it answers; a closed channel becomes an error.
            reply_rx.recv().into_diagnostic()?
        }),
    };
    (rule, request_rx, reply_tx)
}
}
impl FixedRule for SimpleFixedRule {
@ -635,7 +659,7 @@ impl FixedRule for SimpleFixedRule {
let results: NamedRows = (self.rule)(inputs, options)?;
for row in results.rows {
#[derive(Debug, Error, Diagnostic)]
#[error("arity mismatch: expect {0}, got {1}")]
#[error("arity mismatch: expect {1}, got {2}")]
#[diagnostic(code(parser::simple_fixed_rule_arity_mismatch))]
struct ArityMismatch(#[label] SourceSpan, usize, usize);

@ -36,6 +36,7 @@ use std::path::Path;
#[allow(unused_imports)]
use std::time::Instant;
use crossbeam::channel::Receiver;
use lazy_static::lazy_static;
pub use miette::Error;
use miette::Report;
@ -375,22 +376,24 @@ impl DbInstance {
self.import_from_backup(&json_payload.path, &json_payload.relations)
}
/// Dispatcher method. See [crate::Db::register_callback].
///
/// Forwards `relation` and `capacity` unchanged to the `register_callback`
/// of whichever storage backend this instance wraps, and returns what that
/// call returns: the registration ID paired with the receiving end of the
/// channel on which `(op, new rows, old rows)` tuples are delivered.
#[cfg(not(target_arch = "wasm32"))]
pub fn register_callback<CB>(&self, relation: &str, callback: CB) -> Result<u32>
where
CB: Fn(CallbackOp, NamedRows, NamedRows) + Send + Sync + 'static,
{
pub fn register_callback(
&self,
relation: &str,
capacity: Option<usize>,
) -> (u32, Receiver<(CallbackOp, NamedRows, NamedRows)>) {
// One arm per storage backend; every arm except `Mem` is feature-gated.
match self {
DbInstance::Mem(db) => db.register_callback(relation, callback),
DbInstance::Mem(db) => db.register_callback(relation, capacity),
#[cfg(feature = "storage-sqlite")]
DbInstance::Sqlite(db) => db.register_callback(relation, callback),
DbInstance::Sqlite(db) => db.register_callback(relation, capacity),
#[cfg(feature = "storage-rocksdb")]
DbInstance::RocksDb(db) => db.register_callback(relation, callback),
DbInstance::RocksDb(db) => db.register_callback(relation, capacity),
#[cfg(feature = "storage-sled")]
DbInstance::Sled(db) => db.register_callback(relation, callback),
DbInstance::Sled(db) => db.register_callback(relation, capacity),
#[cfg(feature = "storage-tikv")]
DbInstance::TiKv(db) => db.register_callback(relation, callback),
DbInstance::TiKv(db) => db.register_callback(relation, capacity),
}
}
@ -410,11 +413,7 @@ impl DbInstance {
}
}
/// Dispatcher method. See [crate::Db::register_fixed_rule].
pub fn register_fixed_rule(
&self,
name: String,
rule_impl: Box<dyn FixedRule>,
) -> Result<()> {
pub fn register_fixed_rule<R>(&self, name: String, rule_impl: R) -> Result<()> where R: FixedRule + 'static {
match self {
DbInstance::Mem(db) => db.register_fixed_rule(name, rule_impl),
#[cfg(feature = "storage-sqlite")]

@ -9,6 +9,7 @@
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Display, Formatter};
use crossbeam::channel::Sender;
use smartstring::{LazyCompact, SmartString};
use crate::{Db, NamedRows, Storage};
@ -31,10 +32,20 @@ impl Display for CallbackOp {
}
}
impl CallbackOp {
/// Get the string representation
pub fn as_str(&self) -> &'static str {
match self {
CallbackOp::Put => "Put",
CallbackOp::Rm => "Rm"
}
}
}
/// A single registration for change notifications on one stored relation.
#[cfg(not(target_arch = "wasm32"))]
pub struct CallbackDeclaration {
/// Name of the relation this registration is attached to.
pub(crate) dependent: SmartString<LazyCompact>,
/// Closure invoked with the operation, the new rows and the old rows.
pub(crate) callback: Box<dyn Fn(CallbackOp, NamedRows, NamedRows) + Send + Sync>,
/// Sending half of the channel carrying `(operation, new rows, old rows)`.
pub(crate) sender: Sender<(CallbackOp, NamedRows, NamedRows)>,
}
pub(crate) type CallbackCollector =
@ -66,11 +77,40 @@ impl<'s, S: Storage<'s>> Db<S> {
}
/// Deliver the collected mutations to the channels registered for each relation.
///
/// For every `(op, new, old)` entry of every relation in `collector`, the
/// entry is sent to each registration listed for that relation. A
/// registration whose `send` fails is remembered and, once the whole
/// collector has been processed, removed from `event_callbacks`.
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn send_callbacks(&'s self, collector: CallbackCollector) {
for (k, vals) in collector {
// IDs of registrations whose channel refused a message.
let mut to_remove = vec![];
for (table, vals) in collector {
for (op, new, old) in vals {
self.callback_sender
.send((k.clone(), op, new, old))
.expect("sending to callback processor failed");
// The read lock is taken anew for every entry and held while sending.
let (cbs, cb_dir) = &*self.event_callbacks.read().unwrap();
if let Some(cb_ids) = cb_dir.get(&table) {
let mut it = cb_ids.iter();
// The first ID is held back so that it can take ownership of `new`
// and `old`; every other registration receives clones.
if let Some(fst) = it.next() {
for cb_id in it {
if let Some(cb) = cbs.get(cb_id) {
if cb.sender.send((op, new.clone(), old.clone())).is_err() {
to_remove.push(*cb_id)
}
}
}
if let Some(cb) = cbs.get(fst) {
if cb.sender.send((op, new, old)).is_err() {
to_remove.push(*fst)
}
}
}
}
}
}
// Only take the write lock when there is actually something to remove.
if !to_remove.is_empty() {
let (cbs, cb_dir) = &mut *self.event_callbacks.write().unwrap();
for removing_id in &to_remove {
// Drop the declaration itself, then its ID from the per-relation set.
if let Some(removed) = cbs.remove(removing_id) {
if let Some(set) = cb_dir.get_mut(&removed.dependent) {
set.remove(removing_id);
}
}
}
}
}

@ -20,8 +20,7 @@ use std::thread;
#[allow(unused_imports)]
use std::time::{Duration, SystemTime, UNIX_EPOCH};
#[cfg(not(target_arch = "wasm32"))]
use crossbeam::channel::{unbounded, Sender};
use crossbeam::channel::{bounded, unbounded, Receiver};
use crossbeam::sync::ShardedLock;
use either::{Left, Right};
use itertools::Itertools;
@ -92,10 +91,7 @@ pub struct Db<S> {
#[cfg(not(target_arch = "wasm32"))]
callback_count: AtomicU32,
#[cfg(not(target_arch = "wasm32"))]
pub(crate) callback_sender:
Sender<(SmartString<LazyCompact>, CallbackOp, NamedRows, NamedRows)>,
#[cfg(not(target_arch = "wasm32"))]
pub(crate) event_callbacks: Arc<ShardedLock<EventCallbackRegistry>>,
pub(crate) event_callbacks: ShardedLock<EventCallbackRegistry>,
relation_locks: ShardedLock<BTreeMap<SmartString<LazyCompact>, Arc<ShardedLock<()>>>>,
}
@ -183,8 +179,6 @@ impl<'s, S: Storage<'s>> Db<S> {
/// You must call [`initialize`](Self::initialize) immediately after creation.
/// Due to lifetime restrictions we are not able to call that for you automatically.
pub fn new(storage: S) -> Result<Self> {
#[cfg(not(target_arch = "wasm32"))]
let (sender, receiver) = unbounded();
let ret = Self {
db: storage,
temp_db: Default::default(),
@ -194,29 +188,11 @@ impl<'s, S: Storage<'s>> Db<S> {
fixed_rules: ShardedLock::new(DEFAULT_FIXED_RULES.clone()),
#[cfg(not(target_arch = "wasm32"))]
callback_count: Default::default(),
#[cfg(not(target_arch = "wasm32"))]
callback_sender: sender,
// callback_receiver: Arc::new(receiver),
#[cfg(not(target_arch = "wasm32"))]
event_callbacks: Default::default(),
relation_locks: Default::default(),
};
#[cfg(not(target_arch = "wasm32"))]
{
let callbacks = ret.event_callbacks.clone();
thread::spawn(move || {
while let Ok((table, op, new, old)) = receiver.recv() {
let (cbs, cb_dir) = &*callbacks.read().unwrap();
if let Some(cb_ids) = cb_dir.get(&table) {
for cb_id in cb_ids {
if let Some(cb) = cbs.get(cb_id) {
(cb.callback)(op, new.clone(), old.clone())
}
}
}
}
});
}
Ok(ret)
}
@ -548,10 +524,13 @@ impl<'s, S: Storage<'s>> Db<S> {
}
}
/// Register a custom fixed rule implementation.
pub fn register_fixed_rule(&self, name: String, rule_impl: Box<dyn FixedRule>) -> Result<()> {
pub fn register_fixed_rule<R>(&self, name: String, rule_impl: R) -> Result<()>
where
R: FixedRule + 'static,
{
match self.fixed_rules.write().unwrap().entry(name) {
Entry::Vacant(ent) => {
ent.insert(Arc::new(rule_impl));
ent.insert(Arc::new(Box::new(rule_impl)));
Ok(())
}
Entry::Occupied(ent) => {
@ -571,31 +550,22 @@ impl<'s, S: Storage<'s>> Db<S> {
Ok(self.fixed_rules.write().unwrap().remove(name).is_some())
}
/// Register callbacks to run when changes to the requested relation are successfully committed.
/// The returned ID can be used to unregister the callbacks.
///
/// When any mutation is made against the registered relation,
/// the callback will be called with three arguments:
///
/// * The type of the mutation (`Put` or `Rm`)
/// * The incoming rows
/// * For `Put`, the rows inserted / upserted
/// * For `Rm`, the keys of the rows to be removed. These keys do not necessarily exist in the database.
/// * The evicted rows
/// * For `Put`, old rows that were updated
/// * For `Rm`, the rows actually removed.
///
/// The database has a single dedicated thread for executing callbacks no matter how many callbacks
/// are registered. You should not perform blocking operations in the callback provided to this function.
/// If blocking operations are required, send the data to some other thread and perform the operations there.
/// Register a callback channel to receive changes when mutations to the requested relation are successfully committed.
/// The returned ID can be used to unregister the callback channel.
#[cfg(not(target_arch = "wasm32"))]
pub fn register_callback<CB>(&self, relation: &str, callback: CB) -> Result<u32>
where
CB: Fn(CallbackOp, NamedRows, NamedRows) + Send + Sync + 'static,
{
pub fn register_callback(
&self,
relation: &str,
capacity: Option<usize>,
) -> (u32, Receiver<(CallbackOp, NamedRows, NamedRows)>) {
let (sender, receiver) = if let Some(c) = capacity {
bounded(c)
} else {
unbounded()
};
let cb = CallbackDeclaration {
dependent: SmartString::from(relation),
callback: Box::new(callback),
sender: sender,
};
let mut guard = self.event_callbacks.write().unwrap();
@ -607,10 +577,10 @@ impl<'s, S: Storage<'s>> Db<S> {
.insert(new_id);
guard.0.insert(new_id, cb);
Ok(new_id)
(new_id, receiver)
}
/// Unregister callbacks to run when changes to relations are committed.
/// Unregister callbacks/channels to run when changes to relations are committed.
#[cfg(not(target_arch = "wasm32"))]
pub fn unregister_callback(&self, id: u32) -> bool {
let mut guard = self.event_callbacks.write().unwrap();

@ -8,7 +8,6 @@
*/
use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use itertools::Itertools;
@ -444,12 +443,8 @@ fn test_trigger() {
#[test]
fn test_callback() {
let db = new_cozo_mem().unwrap();
let collected = Arc::new(Mutex::new(vec![]));
let copy = collected.clone();
db.register_callback("friends", move |op, new, old| {
copy.lock().unwrap().push((op, new, old))
})
.unwrap();
let mut collected = vec![];
let (_id, receiver) = db.register_callback("friends", None).unwrap();
db.run_script(
":create friends {fr: Int, to: Int => data: Any}",
Default::default(),
@ -471,7 +466,10 @@ fn test_callback() {
)
.unwrap();
std::thread::sleep(Duration::from_secs_f64(0.01));
let collected = collected.lock().unwrap().clone();
while let Ok(d) = receiver.try_recv() {
collected.push(d);
}
let collected = collected;
assert_eq!(collected[0].0, CallbackOp::Put);
assert_eq!(collected[0].1.rows.len(), 2);
assert_eq!(collected[0].1.rows[0].len(), 3);
@ -578,7 +576,7 @@ fn test_index() {
#[test]
fn test_custom_rules() {
let mut db = new_cozo_mem().unwrap();
let db = new_cozo_mem().unwrap();
struct Custom;
impl FixedRule for Custom {

@ -43,3 +43,4 @@ io-uring = ["cozo/io-uring"]
[dependencies]
cozo = { version = "0.5.0", path = "../cozo-core", default-features = false }
pyo3 = { version = "0.17.1", features = ["extension-module", "abi3", "abi3-py37"] }
miette = "5.5.0"

@ -6,13 +6,26 @@
* You can obtain one at https://mozilla.org/MPL/2.0/.
*/
use std::collections::BTreeMap;
use std::thread;
use miette::{IntoDiagnostic, Result};
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use std::collections::BTreeMap;
use pyo3::types::{PyDict, PyList, PyString, PyTuple};
use cozo::*;
fn py_to_named_rows(ob: &PyAny) -> PyResult<NamedRows> {
let rows = ob.extract::<Vec<Vec<&PyAny>>>()?;
let res: Vec<Vec<DataValue>> = rows
.into_iter()
.map(|row| row.into_iter().map(py_to_value).collect::<PyResult<_>>())
.collect::<PyResult<_>>()?;
Ok(NamedRows::new(vec![], res))
}
fn py_to_value(ob: &PyAny) -> PyResult<DataValue> {
Ok(if ob.is_none() {
DataValue::Null
@ -40,7 +53,9 @@ fn py_to_value(ob: &PyAny) -> PyResult<DataValue> {
}
DataValue::List(coll)
} else {
return Err(PyException::new_err(format!("Cannot convert {ob} into Cozo value")));
return Err(PyException::new_err(format!(
"Cannot convert {ob} into Cozo value"
)));
})
}
@ -81,10 +96,8 @@ fn value_to_py(val: DataValue, py: Python<'_>) -> PyObject {
}
}
fn named_rows_to_py(named_rows: NamedRows, py: Python<'_>) -> PyObject {
let rows = named_rows
.rows
.into_iter()
fn rows_to_py_rows(rows: Vec<Vec<DataValue>>, py: Python<'_>) -> PyObject {
rows.into_iter()
.map(|row| {
row.into_iter()
.map(|val| value_to_py(val, py))
@ -92,7 +105,11 @@ fn named_rows_to_py(named_rows: NamedRows, py: Python<'_>) -> PyObject {
.into_py(py)
})
.collect::<Vec<_>>()
.into_py(py);
.into_py(py)
}
/// Convert [`NamedRows`] into a Python mapping with the keys `"rows"` and
/// `"headers"`.
fn named_rows_to_py(named_rows: NamedRows, py: Python<'_>) -> PyObject {
    let mut payload = BTreeMap::new();
    payload.insert("rows", rows_to_py_rows(named_rows.rows, py));
    payload.insert("headers", named_rows.headers.into_py(py));
    payload.into_py(py)
}
@ -130,22 +147,86 @@ impl CozoDbPy {
// Register a Python callable to be invoked for changes to relation `rel`.
//
// Returns the ID handed out by the database's `register_callback`, or an
// exception carrying `DB_CLOSED_MSG` when `self.db` is `None`.
pub fn register_callback(&self, rel: &str, callback: &PyAny) -> PyResult<u32> {
if let Some(db) = &self.db {
let cb: Py<PyAny> = callback.into();
match db.register_callback(rel, move |op, new, old| {
Python::with_gil(|py| {
let callable = cb.as_ref(py);
let _ = callable.call0();
})
}) {
Ok(id) => Ok(id),
Err(err) => {
let reports = format_error_as_json(err, None).to_string();
Err(PyException::new_err(reports))
// `None` is passed as the channel capacity.
let (id, ch) = db.register_callback(rel, None);
// A dedicated thread drains the channel; the loop ends when iteration
// over the receiver ends.
thread::spawn(move || {
for (op, new, old) in ch {
Python::with_gil(|py| {
// The callable is invoked as `callback(op_name, new_rows, old_rows)`.
let op = PyString::new(py, op.as_str()).into();
let new_py = rows_to_py_rows(new.rows, py);
let old_py = rows_to_py_rows(old.rows, py);
let args = PyTuple::new(py, [op, new_py, old_py]);
let callable = cb.as_ref(py);
// An error from the callable is printed to stderr, not propagated.
if let Err(err) = callable.call1(args) {
eprintln!("{}", err);
}
})
}
});
Ok(id)
} else {
Err(PyException::new_err(DB_CLOSED_MSG))
}
}
// Register the Python callable `callback` as the fixed rule `name` with the
// given return `arity`.
//
// The rule is backed by a channel pair. After a successful registration a
// thread is spawned that, for every invocation, calls
// `callback(inputs, options)` under the GIL and sends the outcome back to the
// database; the thread stops once sending the outcome fails.
//
// Fails with `DB_CLOSED_MSG` when `self.db` is `None`, or with the
// registration error's message when the database rejects the rule.
pub fn register_fixed_rule(
    &self,
    name: String,
    arity: usize,
    callback: &PyAny,
) -> PyResult<()> {
    let db = match &self.db {
        Some(db) => db,
        None => return Err(PyException::new_err(DB_CLOSED_MSG)),
    };
    let cb: Py<PyAny> = callback.into();
    let (rule_impl, receiver, sender) = SimpleFixedRule::rule_with_channel(arity);
    if let Err(err) = db.register_fixed_rule(name, rule_impl) {
        return Err(PyException::new_err(err.to_string()));
    }
    thread::spawn(move || {
        for (inputs, options) in receiver {
            let outcome = Python::with_gil(|py| -> Result<NamedRows> {
                // `json.loads` turns the serialized options into Python objects.
                let json_loads: Py<PyAny> = PyModule::import(py, "json")
                    .into_diagnostic()?
                    .getattr("loads")
                    .into_diagnostic()?
                    .into();
                let py_inputs = PyList::new(
                    py,
                    inputs.into_iter().map(|nr| rows_to_py_rows(nr.rows, py)),
                );
                let opts_json = PyString::new(py, &options.to_string());
                let py_opts = json_loads
                    .call1(py, PyTuple::new(py, vec![opts_json]))
                    .into_diagnostic()?;
                let call_args = PyTuple::new(py, vec![PyObject::from(py_inputs), py_opts]);
                let returned = cb.as_ref(py).call1(call_args).into_diagnostic()?;
                py_to_named_rows(returned).into_diagnostic()
            });
            // Stop serving once the outcome can no longer be delivered.
            if sender.send(outcome).is_err() {
                break;
            }
        }
    });
    Ok(())
}
// Unregister the callback with the given `id`.
//
// Returns the database's answer, or `false` when `self.db` is `None`.
pub fn unregister_callback(&self, id: u32) -> bool {
    match &self.db {
        Some(db) => db.unregister_callback(id),
        None => false,
    }
}
// Unregister the fixed rule called `name`.
//
// Returns the database's answer, `Ok(false)` when `self.db` is `None`, or an
// exception carrying the error's message when the database reports an error.
pub fn unregister_fixed_rule(&self, name: &str) -> PyResult<bool> {
    let db = match &self.db {
        Some(db) => db,
        None => return Ok(false),
    };
    db.unregister_fixed_rule(name)
        .map_err(|err| PyException::new_err(err.to_string()))
}
pub fn run_query(&self, py: Python<'_>, query: &str, params: &str) -> String {
if let Some(db) = &self.db {
py.allow_threads(|| db.run_script_str(query, params))

Loading…
Cancel
Save