diff --git a/cozo-core/src/query/stored.rs b/cozo-core/src/query/stored.rs index 55047368..6ffb98f6 100644 --- a/cozo-core/src/query/stored.rs +++ b/cozo-core/src/query/stored.rs @@ -444,7 +444,7 @@ fn make_const_rule( ); let bindings_arity = bindings.len(); program.prog.insert( - rule_symbol.clone(), + rule_symbol, InputInlineRulesOrAlgo::Algo { algo: AlgoApply { algo: AlgoHandle { diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index d3db8a0b..33cc20f2 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -180,11 +180,7 @@ impl<'s, S: Storage<'s>> Db { for data in tx.tx.range_scan(&start, &end) { let (k, v) = data?; let tuple = decode_tuple_from_kv(&k, &v); - let row = tuple - .0 - .into_iter() - .map(|dv| JsonValue::from(dv)) - .collect_vec(); + let row = tuple.0.into_iter().map(JsonValue::from).collect_vec(); rows.push(row); } let headers = cols.iter().map(|col| col.to_string()).collect_vec(); diff --git a/cozo-core/src/storage/sqlite.rs b/cozo-core/src/storage/sqlite.rs index b1e342ff..8ba67f70 100644 --- a/cozo-core/src/storage/sqlite.rs +++ b/cozo-core/src/storage/sqlite.rs @@ -6,7 +6,8 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ -use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}; +use std::cell::RefCell; +use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; use ::sqlite::Connection; use either::{Either, Left, Right}; @@ -22,6 +23,7 @@ use crate::storage::{Storage, StoreTx}; pub struct SqliteStorage { lock: Arc>, name: String, + pool: Arc>>, } /// Create a sqlite backed database. @@ -50,6 +52,7 @@ pub fn new_cozo_sqlite(path: String) -> Result> { let ret = crate::Db::new(SqliteStorage { lock: Arc::new(Default::default()), name: path, + pool: Arc::new(Mutex::new(vec![])), })?; ret.initialize()?; @@ -60,7 +63,12 @@ impl<'s> Storage<'s> for SqliteStorage { type Tx = SqliteTx<'s>; fn transact(&'s self, write: bool) -> Result { - let conn = sqlite::open(&self.name).into_diagnostic()?; + let conn = { + match self.pool.lock().unwrap().pop() { + None => sqlite::open(&self.name).into_diagnostic()?, + Some(conn) => conn, + } + }; let lock = if write { Right(self.lock.write().unwrap()) } else { @@ -72,7 +80,14 @@ impl<'s> Storage<'s> for SqliteStorage { } Ok(SqliteTx { lock, - conn, + storage: self, + conn: Some(conn), + stmts: [ + RefCell::new(None), + RefCell::new(None), + RefCell::new(None), + RefCell::new(None), + ], committed: false, }) } @@ -101,34 +116,74 @@ impl<'s> Storage<'s> for SqliteStorage { } fn range_compact(&'_ self, _lower: &[u8], _upper: &[u8]) -> Result<()> { + let mut pool = self.pool.lock().unwrap(); + while pool.pop().is_some() {} Ok(()) } } pub struct SqliteTx<'a> { lock: Either, RwLockWriteGuard<'a, ()>>, - conn: Connection, + storage: &'a SqliteStorage, + conn: Option, + stmts: [RefCell>>; N_CACHED_QUERIES], committed: bool, } +const N_QUERIES: usize = 5; +const N_CACHED_QUERIES: usize = 4; +const QUERIES: [&str; N_QUERIES] = [ + "select v from cozo where k = ?;", + "insert into cozo(k, v) values (?, ?) on conflict(k) do update set v=excluded.v;", + "delete from cozo where k = ?;", + "select 1 from cozo where k = ?;", + "select k, v from cozo where k >= ? and k < ? order by k;", +]; + +const GET_QUERY: usize = 0; +const PUT_QUERY: usize = 1; +const DEL_QUERY: usize = 2; +const EXISTS_QUERY: usize = 3; +const RANGE_QUERY: usize = 4; + impl Drop for SqliteTx<'_> { fn drop(&mut self) { if let Right(RwLockWriteGuard { .. }) = self.lock { if !self.committed { let query = r#"rollback;"#; - let _ = self.conn.execute(query); + let _ = self.conn.as_ref().unwrap().execute(query); } } + let mut pool = self.storage.pool.lock().unwrap(); + let conn = self.conn.take().unwrap(); + pool.push(conn) + } +} + +impl<'s> SqliteTx<'s> { + fn ensure_stmt(&self, idx: usize) { + let mut stmt = self.stmts[idx].borrow_mut(); + if stmt.is_none() { + let query = QUERIES[idx]; + let prepared = self.conn.as_ref().unwrap().prepare(query).unwrap(); + + // Casting away the lifetime! + // This is OK because we are abiding by the contract of the underlying C pointer, + // as required by Sqlite's implementation + let prepared = unsafe { std::mem::transmute(prepared) }; + + *stmt = Some(prepared) + } } } impl<'s> StoreTx<'s> for SqliteTx<'s> { fn get(&self, key: &[u8], _for_update: bool) -> Result>> { - let query = r#" - select v from cozo where k = ?; - "#; + self.ensure_stmt(GET_QUERY); + let mut statement = self.stmts[GET_QUERY].borrow_mut(); + let statement = statement.as_mut().unwrap(); + statement.reset().unwrap(); - let mut statement = self.conn.prepare(query).unwrap(); statement.bind((1, key)).unwrap(); Ok(match statement.next().into_diagnostic()? { State::Row => { @@ -140,12 +195,11 @@ impl<'s> StoreTx<'s> for SqliteTx<'s> { } fn put(&mut self, key: &[u8], val: &[u8]) -> Result<()> { - let query = r#" - insert into cozo(k, v) values (?, ?) - on conflict(k) do update set v=excluded.v; - "#; + self.ensure_stmt(PUT_QUERY); + let mut statement = self.stmts[PUT_QUERY].borrow_mut(); + let statement = statement.as_mut().unwrap(); + statement.reset().unwrap(); - let mut statement = self.conn.prepare(query).unwrap(); statement.bind((1, key)).unwrap(); statement.bind((2, val)).unwrap(); while statement.next().into_diagnostic()? != State::Done {} @@ -153,10 +207,11 @@ impl<'s> StoreTx<'s> for SqliteTx<'s> { } fn del(&mut self, key: &[u8]) -> Result<()> { - let query = r#" - delete from cozo where k = ?; - "#; - let mut statement = self.conn.prepare(query).unwrap(); + self.ensure_stmt(DEL_QUERY); + let mut statement = self.stmts[DEL_QUERY].borrow_mut(); + let statement = statement.as_mut().unwrap(); + statement.reset().unwrap(); + statement.bind((1, key)).unwrap(); while statement.next().into_diagnostic()? != State::Done {} @@ -164,10 +219,11 @@ impl<'s> StoreTx<'s> for SqliteTx<'s> { } fn exists(&self, key: &[u8], _for_update: bool) -> Result { - let query = r#" - select 1 from cozo where k = ?; - "#; - let mut statement = self.conn.prepare(query).unwrap(); + self.ensure_stmt(EXISTS_QUERY); + let mut statement = self.stmts[EXISTS_QUERY].borrow_mut(); + let statement = statement.as_mut().unwrap(); + statement.reset().unwrap(); + statement.bind((1, key)).unwrap(); Ok(match statement.next().into_diagnostic()? { State::Row => true, @@ -179,7 +235,7 @@ impl<'s> StoreTx<'s> for SqliteTx<'s> { if let Right(RwLockWriteGuard { .. }) = self.lock { if !self.committed { let query = r#"commit;"#; - let mut statement = self.conn.prepare(query).unwrap(); + let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap(); while statement.next().into_diagnostic()? != State::Done {} self.committed = true; } else { @@ -197,11 +253,10 @@ impl<'s> StoreTx<'s> for SqliteTx<'s> { where 's: 'a, { - let query = r#" - select k, v from cozo where k >= ? and k < ? - order by k; - "#; - let mut statement = self.conn.prepare(query).unwrap(); + // Range scans cannot use cached prepared statements, as several of them + // can be used at the same time. + let query = QUERIES[RANGE_QUERY]; + let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap(); statement.bind((1, lower)).unwrap(); statement.bind((2, upper)).unwrap(); Box::new(TupleIter(statement)) @@ -215,11 +270,8 @@ impl<'s> StoreTx<'s> for SqliteTx<'s> { where 's: 'a, { - let query = r#" - select k, v from cozo where k >= ? and k < ? - order by k; - "#; - let mut statement = self.conn.prepare(query).unwrap(); + let query = QUERIES[RANGE_QUERY]; + let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap(); statement.bind((1, lower)).unwrap(); statement.bind((2, upper)).unwrap(); Box::new(RawIter(statement)) @@ -232,18 +284,17 @@ impl<'s> StoreTx<'s> for SqliteTx<'s> { where 's: 'a, { - let query = r#" - insert into cozo(k, v) values (?, ?) - on conflict(k) do update set v=excluded.v; - "#; + self.ensure_stmt(PUT_QUERY); + let mut statement = self.stmts[PUT_QUERY].borrow_mut(); + let statement = statement.as_mut().unwrap(); + statement.reset().unwrap(); - let mut statement = self.conn.prepare(query).unwrap(); for pair in data { let (key, val) = pair?; statement.bind((1, key.as_slice())).unwrap(); statement.bind((2, val.as_slice())).unwrap(); while statement.next().into_diagnostic()? != State::Done {} - statement.reset().into_diagnostic()?; + statement.reset().unwrap(); } Ok(()) } diff --git a/cozo-core/tests/air_routes.rs b/cozo-core/tests/air_routes.rs index 14f65ccc..caade582 100644 --- a/cozo-core/tests/air_routes.rs +++ b/cozo-core/tests/air_routes.rs @@ -26,7 +26,7 @@ lazy_static! { _ = std::fs::remove_file(path); _ = std::fs::remove_dir_all(path); let db_kind = env::var("COZO_TEST_DB_ENGINE").unwrap_or("mem".to_string()); - + println!("Using {} engine", db_kind); let db = DbInstance::new(&db_kind, path, Default::default()).unwrap(); dbg!(creation.elapsed()); diff --git a/cozo-lib-c/cozo_c.h b/cozo-lib-c/cozo_c.h index 0827bfaa..f7f09a6f 100644 --- a/cozo-lib-c/cozo_c.h +++ b/cozo-lib-c/cozo_c.h @@ -64,6 +64,9 @@ char *cozo_run_query(int32_t db_id, const char *script_raw, const char *params_r /** * Import data into relations * + * Note that triggers are _not_ run for the relations, if any exists. + * If you need to activate triggers, use queries with parameters. + * * `db_id`: the ID representing the database. * `json_payload`: a UTF-8 encoded JSON payload, in the same form as returned by exporting relations. * @@ -108,6 +111,9 @@ char *cozo_restore(int32_t db_id, /** * Import data into relations from a backup * + * Note that triggers are _not_ run for the relations, if any exists. + * If you need to activate triggers, use queries with parameters. + * * `db_id`: the ID representing the database. * `json_payload`: a UTF-8 encoded JSON payload: `{"path": ..., "relations": [...]}` * diff --git a/cozoserver/src/main.rs b/cozoserver/src/main.rs index 25c5c90f..c1152ebc 100644 --- a/cozoserver/src/main.rs +++ b/cozoserver/src/main.rs @@ -69,7 +69,12 @@ fn main() { eprintln!("{}", SECURITY_WARNING); } - let db = DbInstance::new(args.engine.as_str(), args.path.as_str(), &args.config.clone()).unwrap(); + let db = DbInstance::new( + args.engine.as_str(), + args.path.as_str(), + &args.config.clone(), + ) + .unwrap(); if let Some(restore_path) = &args.restore { db.restore_backup(restore_path).unwrap(); @@ -140,13 +145,7 @@ fn main() { check_auth!(request, auth_guard); } - let relations = relations.split(",").filter_map(|t| { - if t.is_empty() { - None - } else { - Some(t) - } - }); + let relations = relations.split(',').filter(|t| !t.is_empty()); let result = db.export_relations(relations); match result { Ok(s) => { @@ -190,7 +189,7 @@ fn main() { let payload: BackupPayload = try_or_400!(rouille::input::json_input(request)); - let result = db.backup_db(payload.path.clone()); + let result = db.backup_db(payload.path); match result { Ok(()) => {