diff --git a/cozo-core/src/storage/rocks.rs b/cozo-core/src/storage/rocks.rs index d5414600..756e8594 100644 --- a/cozo-core/src/storage/rocks.rs +++ b/cozo-core/src/storage/rocks.rs @@ -154,6 +154,8 @@ pub struct RocksDbTx { db_tx: Tx, } +unsafe impl Sync for RocksDbTx {} + impl<'s> StoreTx<'s> for RocksDbTx { #[inline] fn get(&self, key: &[u8], for_update: bool) -> Result>> { diff --git a/cozo-core/src/storage/sqlite.rs b/cozo-core/src/storage/sqlite.rs index 83d779b8..22472051 100644 --- a/cozo-core/src/storage/sqlite.rs +++ b/cozo-core/src/storage/sqlite.rs @@ -6,13 +6,12 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ -use std::cell::RefCell; use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; use ::sqlite::Connection; use either::{Either, Left, Right}; use miette::{bail, miette, IntoDiagnostic, Result}; -use sqlite::{State, Statement}; +use sqlite::{ConnectionWithFullMutex, State, Statement}; use crate::data::tuple::{check_key_for_validity, Tuple}; use crate::data::value::ValidityTs; @@ -25,7 +24,7 @@ use crate::utils::swap_option_result; pub struct SqliteStorage { lock: Arc>, name: String, - pool: Arc>>, + pool: Arc>>, } /// Create a sqlite backed database. @@ -40,7 +39,7 @@ pub fn new_cozo_sqlite(path: String) -> Result> { if path.is_empty() { bail!("empty path for sqlite storage") } - let conn = sqlite::open(&path).into_diagnostic()?; + let conn = Connection::open_with_full_mutex(&path).into_diagnostic()?; let query = r#" create table if not exists cozo ( @@ -67,7 +66,7 @@ impl<'s> Storage<'s> for SqliteStorage { fn transact(&'s self, write: bool) -> Result { let conn = { match self.pool.lock().unwrap().pop() { - None => sqlite::open(&self.name).into_diagnostic()?, + None => Connection::open_with_full_mutex(&self.name).into_diagnostic()?, Some(conn) => conn, } }; @@ -85,10 +84,10 @@ impl<'s> Storage<'s> for SqliteStorage { storage: self, conn: Some(conn), stmts: [ - RefCell::new(None), - RefCell::new(None), - RefCell::new(None), - RefCell::new(None), + Mutex::new(None), + Mutex::new(None), + Mutex::new(None), + Mutex::new(None), ], committed: false, }) @@ -141,11 +140,13 @@ impl<'s> Storage<'s> for SqliteStorage { pub struct SqliteTx<'a> { lock: Either, RwLockWriteGuard<'a, ()>>, storage: &'a SqliteStorage, - conn: Option, - stmts: [RefCell>>; N_CACHED_QUERIES], + conn: Option, + stmts: [Mutex>>; N_CACHED_QUERIES], committed: bool, } +unsafe impl Sync for SqliteTx<'_> {} + const N_QUERIES: usize = 6; const N_CACHED_QUERIES: usize = 4; const QUERIES: [&str; N_QUERIES] = [ @@ -180,7 +181,7 @@ impl Drop for SqliteTx<'_> { impl<'s> SqliteTx<'s> { fn ensure_stmt(&self, idx: usize) { - let mut stmt = self.stmts[idx].borrow_mut(); + let mut stmt = self.stmts[idx].lock().unwrap(); if stmt.is_none() { let query = QUERIES[idx]; let prepared = self.conn.as_ref().unwrap().prepare(query).unwrap(); @@ -198,7 +199,7 @@ impl<'s> SqliteTx<'s> { impl<'s> StoreTx<'s> for SqliteTx<'s> { fn get(&self, key: &[u8], _for_update: bool) -> Result>> { self.ensure_stmt(GET_QUERY); - let mut statement = self.stmts[GET_QUERY].borrow_mut(); + let mut statement = self.stmts[GET_QUERY].lock().unwrap(); let statement = statement.as_mut().unwrap(); statement.reset().unwrap(); @@ -214,7 +215,7 @@ impl<'s> StoreTx<'s> for SqliteTx<'s> { fn put(&mut self, key: &[u8], val: &[u8]) -> Result<()> { self.ensure_stmt(PUT_QUERY); - let mut statement = self.stmts[PUT_QUERY].borrow_mut(); + let mut statement = self.stmts[PUT_QUERY].lock().unwrap(); let statement = statement.as_mut().unwrap(); statement.reset().unwrap(); @@ -226,7 +227,7 @@ impl<'s> StoreTx<'s> for SqliteTx<'s> { fn del(&mut self, key: &[u8]) -> Result<()> { self.ensure_stmt(DEL_QUERY); - let mut statement = self.stmts[DEL_QUERY].borrow_mut(); + let mut statement = self.stmts[DEL_QUERY].lock().unwrap(); let statement = statement.as_mut().unwrap(); statement.reset().unwrap(); @@ -238,7 +239,7 @@ impl<'s> StoreTx<'s> for SqliteTx<'s> { fn exists(&self, key: &[u8], _for_update: bool) -> Result { self.ensure_stmt(EXISTS_QUERY); - let mut statement = self.stmts[EXISTS_QUERY].borrow_mut(); + let mut statement = self.stmts[EXISTS_QUERY].lock().unwrap(); let statement = statement.as_mut().unwrap(); statement.reset().unwrap(); @@ -385,7 +386,7 @@ impl<'l> SkipIter<'l> { if let Some(mut tup) = ret { let v = self.stmt.read::, _>(1).unwrap(); extend_tuple_from_v(&mut tup, &v); - return Ok(Some(tup)) + return Ok(Some(tup)); } } } diff --git a/cozo-core/tests/air_routes.rs b/cozo-core/tests/air_routes.rs index ed2c9d8e..5b5b0c8f 100644 --- a/cozo-core/tests/air_routes.rs +++ b/cozo-core/tests/air_routes.rs @@ -193,6 +193,30 @@ fn empty() { assert!(res.is_err()); } +#[test] +fn parallel_counts() { + initialize(&TEST_DB); + let res = TEST_DB + .run_script( + r#" + a[count(fr)] := *route{fr} + b[count(fr)] := *route{fr} + c[count(fr)] := *route{fr} + d[count(fr)] := *route{fr} + e[count(fr)] := *route{fr} + ?[x] := a[a], b[b], c[c], d[d], e[e], x = a + b + c + d + e + "#, + Default::default(), + ) + .unwrap() + .rows + .remove(0) + .remove(0) + .as_i64() + .unwrap(); + assert_eq!(res, 50637 * 5); +} + #[test] fn bfs() { initialize(&TEST_DB);