diff --git a/cozo-core/src/lib.rs b/cozo-core/src/lib.rs index b504a933..926d8e45 100644 --- a/cozo-core/src/lib.rs +++ b/cozo-core/src/lib.rs @@ -33,10 +33,11 @@ use std::collections::BTreeMap; use std::path::Path; +use std::thread; #[allow(unused_imports)] use std::time::Instant; -use crossbeam::channel::Receiver; +use crossbeam::channel::{bounded, Receiver, Sender}; use lazy_static::lazy_static; pub use miette::Error; use miette::Report; @@ -71,6 +72,7 @@ pub use crate::fixed_rule::SimpleFixedRule; pub use crate::parse::SourceSpan; pub use crate::runtime::callback::CallbackOp; pub use crate::runtime::db::Poison; +pub use crate::runtime::db::TransactionPayload; #[cfg(not(target_arch = "wasm32"))] pub(crate) mod data; @@ -90,6 +92,7 @@ pub(crate) mod utils; /// Other methods are wrappers simplifying signatures to deal with only strings. /// These methods made code for interop with other languages much easier, /// but are not desirable if you are using Rust. +#[derive(Clone)] pub enum DbInstance { /// In memory storage (not persistent) Mem(Db), @@ -444,6 +447,87 @@ impl DbInstance { DbInstance::TiKv(db) => db.unregister_fixed_rule(name), } } + + /// Dispatcher method. See [crate::Db::run_multi_transaction] + pub fn run_multi_transaction( + &self, + write: bool, + payloads: Receiver, + results: Sender>, + ) { + match self { + DbInstance::Mem(db) => db.run_multi_transaction(write, payloads, results), + #[cfg(feature = "storage-sqlite")] + DbInstance::Sqlite(db) => db.run_multi_transaction(write, payloads, results), + #[cfg(feature = "storage-rocksdb")] + DbInstance::RocksDb(db) => db.run_multi_transaction(write, payloads, results), + #[cfg(feature = "storage-sled")] + DbInstance::Sled(db) => db.run_multi_transaction(write, payloads, results), + #[cfg(feature = "storage-tikv")] + DbInstance::TiKv(db) => db.run_multi_transaction(write, payloads, results), + } + } + /// A higher-level, blocking wrapper for [crate::Db::run_multi_transaction]. Runs the transaction on a dedicated thread. + /// Write transactions _may_ block other reads, but we guarantee that this does not happen for the RocksDB backend. + pub fn multi_transaction(&self, write: bool) -> MultiTransaction { + let (app2db_send, app2db_recv) = bounded(1); + let (db2app_send, db2app_recv) = bounded(1); + let db = self.clone(); + thread::spawn(move || db.run_multi_transaction(write, app2db_recv, db2app_send)); + MultiTransaction { + sender: app2db_send, + receiver: db2app_recv, + } + } +} + +/// A multi-transaction handle. +/// You should use either the fields directly, or the associated functions. +pub struct MultiTransaction { + /// Commands can be sent into the transaction through this channel + pub sender: Sender, + /// Results can be retrieved from the transaction from this channel + pub receiver: Receiver>, +} + +impl MultiTransaction { + /// Runs a single script in the transaction. + pub fn run_script( + &self, + payload: &str, + params: BTreeMap, + ) -> Result { + if let Err(err) = self + .sender + .send(TransactionPayload::Query((payload.to_string(), params))) + { + bail!(err); + } + match self.receiver.recv() { + Ok(r) => r, + Err(err) => bail!(err), + } + } + /// Commits the multi-transaction + pub fn commit(&self) -> Result<()> { + if let Err(err) = self.sender.send(TransactionPayload::Commit) { + bail!(err); + } + match self.receiver.recv() { + Ok(_) => Ok(()), + Err(err) => bail!(err), + } + } + /// Aborts the multi-transaction + pub fn abort(&self) -> Result<()> { + if let Err(err) = self.sender.send(TransactionPayload::Abort) { + bail!(err); + } + match self.receiver.recv() { + Ok(_) => Ok(()), + Err(err) => bail!(err), + } + } } /// Convert error raised by the database into friendly JSON format diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index af428157..83467122 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -20,7 +20,7 @@ use std::thread; #[allow(unused_imports)] use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use crossbeam::channel::{bounded, unbounded, Receiver}; +use crossbeam::channel::{bounded, unbounded, Receiver, Sender}; use crossbeam::sync::ShardedLock; use either::{Left, Right}; use itertools::Itertools; @@ -176,6 +176,17 @@ impl NamedRows { const STATUS_STR: &str = "status"; const OK_STR: &str = "OK"; +/// Commands to be sent to a multi-transaction +#[derive(Eq, PartialEq, Debug)] +pub enum TransactionPayload { + /// Commit the current transaction + Commit, + /// Abort the current transaction + Abort, + /// Run a query inside the transaction + Query((String, BTreeMap)), +} + impl<'s, S: Storage<'s>> Db { /// Create a new database object with the given storage. /// You must call [`initialize`](Self::initialize) immediately after creation. @@ -204,6 +215,112 @@ impl<'s, S: Storage<'s>> Db { Ok(()) } + /// Run a multi-transaction. A command should be sent to `payloads`, and the result should be + /// retrieved from `results`. A transaction ends when it receives a `Commit` or `Abort`, + /// or when a query is not successful. After a transaction ends, sending / receiving from + /// the channels will fail. + /// + /// Write transactions _may_ block other reads, but we guarantee that this does not happen + /// for the RocksDB backend. + pub fn run_multi_transaction( + &'s self, + is_write: bool, + payloads: Receiver, + results: Sender>, + ) { + let tx = if is_write { + self.transact_write() + } else { + self.transact() + }; + let mut cleanups: Vec<(Vec, Vec)> = vec![]; + let mut tx = match tx { + Ok(tx) => tx, + Err(err) => { + let _ = results.send(Err(err)); + return; + } + }; + + let ts = current_validity(); + let callback_targets = self.current_callback_targets(); + let mut callback_collector = BTreeMap::new(); + let mut write_locks = BTreeMap::new(); + + for payload in payloads { + match payload { + TransactionPayload::Commit => { + let _ = results.send(tx.commit_tx().map(|_| NamedRows::default())); + #[cfg(not(target_arch = "wasm32"))] + if !callback_collector.is_empty() { + self.send_callbacks(callback_collector) + } + + for (lower, upper) in cleanups { + if let Err(err) = self.db.del_range(&lower, &upper) { + eprintln!("{err:?}") + } + } + + break; + } + TransactionPayload::Abort => { + let _ = results.send(Ok(NamedRows::default())); + break; + } + TransactionPayload::Query((script, params)) => { + let p = + match parse_script(&script, ¶ms, &self.fixed_rules.read().unwrap(), ts) + { + Ok(p) => p, + Err(err) => { + if results.send(Err(err)).is_err() { + break; + } else { + continue; + } + } + }; + + let p = match p.get_single_program() { + Ok(p) => p, + Err(err) => { + if results.send(Err(err)).is_err() { + break; + } else { + continue; + } + } + }; + if let Some(write_lock_name) = p.needs_write_lock() { + match write_locks.entry(write_lock_name) { + Entry::Vacant(e) => { + let lock = self + .obtain_relation_locks(iter::once(e.key())) + .pop() + .unwrap(); + e.insert(lock); + } + Entry::Occupied(_) => {} + } + } + + let res = self.execute_single_program( + p, + &mut tx, + &mut cleanups, + ts, + &callback_targets, + &mut callback_collector, + ); + if results.send(res).is_err() { + break; + } + } + } + } + } + /// Run the CozoScript passed in. The `params` argument is a map of parameters. pub fn run_script( &'s self, diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index 5068fbe1..8b75066e 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -22,7 +22,7 @@ use crate::fixed_rule::FixedRulePayload; use crate::parse::SourceSpan; use crate::runtime::callback::CallbackOp; use crate::runtime::db::Poison; -use crate::{new_cozo_mem, FixedRule, RegularTempStore}; +use crate::{new_cozo_mem, DbInstance, FixedRule, RegularTempStore}; #[test] fn test_limit_offset() { @@ -444,7 +444,7 @@ fn test_trigger() { fn test_callback() { let db = new_cozo_mem().unwrap(); let mut collected = vec![]; - let (_id, receiver) = db.register_callback("friends", None).unwrap(); + let (_id, receiver) = db.register_callback("friends", None); db.run_script( ":create friends {fr: Int, to: Int => data: Any}", Default::default(), @@ -611,7 +611,7 @@ fn test_custom_rules() { } } - db.register_fixed_rule("SumCols".to_string(), Box::new(Custom)) + db.register_fixed_rule("SumCols".to_string(), Custom) .unwrap(); let res = db .run_script( @@ -704,3 +704,37 @@ fn test_index_short() { .unwrap(); assert_eq!(res.into_json()["rows"], json!([[1, 5]])); } + +#[test] +fn test_multi_tx() { + let db = DbInstance::new("mem", "", "").unwrap(); + let tx = db.multi_transaction(true); + tx.run_script(":create a {a}", Default::default()).unwrap(); + tx.run_script("?[a] <- [[1]] :put a {a}", Default::default()) + .unwrap(); + assert!(tx.run_script(":create a {a}", Default::default()).is_err()); + tx.run_script("?[a] <- [[2]] :put a {a}", Default::default()) + .unwrap(); + tx.run_script("?[a] <- [[3]] :put a {a}", Default::default()) + .unwrap(); + tx.commit().unwrap(); + assert_eq!( + db.run_script("?[a] := *a[a]", Default::default()) + .unwrap() + .into_json()["rows"], + json!([[1], [2], [3]]) + ); + + let db = DbInstance::new("mem", "", "").unwrap(); + let tx = db.multi_transaction(true); + tx.run_script(":create a {a}", Default::default()).unwrap(); + tx.run_script("?[a] <- [[1]] :put a {a}", Default::default()) + .unwrap(); + assert!(tx.run_script(":create a {a}", Default::default()).is_err()); + tx.run_script("?[a] <- [[2]] :put a {a}", Default::default()) + .unwrap(); + tx.run_script("?[a] <- [[3]] :put a {a}", Default::default()) + .unwrap(); + tx.abort().unwrap(); + assert!(db.run_script("?[a] := *a[a]", Default::default()).is_err()); +} diff --git a/cozo-core/src/storage/mod.rs b/cozo-core/src/storage/mod.rs index bec4544f..9097677e 100644 --- a/cozo-core/src/storage/mod.rs +++ b/cozo-core/src/storage/mod.rs @@ -26,7 +26,7 @@ pub(crate) mod tikv; // pub(crate) mod re; /// Swappable storage trait for Cozo's storage engine -pub trait Storage<'s>: Sync + Clone { +pub trait Storage<'s>: Send + Sync + Clone { /// The associated transaction type used by this engine type Tx: StoreTx<'s>;