From e726141c1dbc6bd81b2b7e51b5ee1cf13b758686 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Sun, 13 Nov 2022 18:20:25 +0800 Subject: [PATCH] safer sqlite --- cozo-core/src/storage/sqlite.rs | 95 +++++++++++++++++++++++---------- 1 file changed, 68 insertions(+), 27 deletions(-) diff --git a/cozo-core/src/storage/sqlite.rs b/cozo-core/src/storage/sqlite.rs index 5cc1751d..58acd5d9 100644 --- a/cozo-core/src/storage/sqlite.rs +++ b/cozo-core/src/storage/sqlite.rs @@ -2,8 +2,11 @@ * Copyright 2022, The Cozo Project Authors. Licensed under MPL-2.0. */ +use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}; + use ::sqlite::Connection; -use miette::{miette, IntoDiagnostic, Result, WrapErr}; +use either::{Either, Left, Right}; +use miette::{bail, miette, IntoDiagnostic, Result, WrapErr}; use sqlite::{State, Statement}; use crate::data::tuple::Tuple; @@ -12,12 +15,13 @@ use crate::storage::{Storage, StoreTx}; /// The Sqlite storage engine pub struct SqliteStorage { - path: String, + lock: Arc>, + name: String, } /// create a sqlite backed database. `:memory:` is not OK. pub fn new_cozo_sqlite(path: String) -> Result> { - let connection = sqlite::open(&path).into_diagnostic()?; + let conn = sqlite::open(&path).into_diagnostic()?; let query = r#" create table if not exists cozo ( @@ -25,41 +29,55 @@ pub fn new_cozo_sqlite(path: String) -> Result> { v BLOB ); "#; - let mut statement = connection.prepare(query).unwrap(); + let mut statement = conn.prepare(query).unwrap(); while statement.next().into_diagnostic()? != State::Done {} - let ret = crate::Db::new(SqliteStorage { path })?; + let ret = crate::Db::new(SqliteStorage { + lock: Arc::new(Default::default()), + name: path, + })?; ret.initialize()?; Ok(ret) } -impl Storage<'_> for SqliteStorage { - type Tx = SqliteTx; - - fn transact(&'_ self, _write: bool) -> Result { - let conn = sqlite::open(&self.path).into_diagnostic()?; - { - let query = r#"begin;"#; - let mut statement = conn.prepare(query).unwrap(); - while statement.next().unwrap() != State::Done {} +impl<'s> Storage<'s> for SqliteStorage { + type Tx = SqliteTx<'s>; + + fn transact(&'s self, write: bool) -> Result { + let conn = sqlite::open(&self.name).into_diagnostic()?; + let lock = if write { + Right(self.lock.write().unwrap()) + } else { + Left(self.lock.read().unwrap()) + }; + if write { + let mut stmt = conn.prepare("begin;").into_diagnostic()?; + while stmt.next().into_diagnostic()? != State::Done {} } - - Ok(SqliteTx { conn }) + Ok(SqliteTx { + lock, + conn, + committed: false, + }) } fn del_range(&'_ self, lower: &[u8], upper: &[u8]) -> Result<()> { let lower_b = lower.to_vec(); let upper_b = upper.to_vec(); - let path = self.path.clone(); - let connection = sqlite::open(path).unwrap(); let query = r#" delete from cozo where k >= ? and k < ?; "#; - let mut statement = connection.prepare(query).unwrap(); - statement.bind((1, &lower_b as &[u8])).unwrap(); - statement.bind((2, &upper_b as &[u8])).unwrap(); - while statement.next().unwrap() != State::Done {} + let lock = self.lock.clone(); + let name = self.name.clone(); + std::thread::spawn(move || { + let _locked = lock.write().unwrap(); + let conn = sqlite::open(&name).unwrap(); + let mut statement = conn.prepare(query).unwrap(); + statement.bind((1, &lower_b as &[u8])).unwrap(); + statement.bind((2, &upper_b as &[u8])).unwrap(); + while statement.next().unwrap() != State::Done {} + }); Ok(()) } @@ -68,15 +86,29 @@ impl Storage<'_> for SqliteStorage { } } -pub struct SqliteTx { +pub struct SqliteTx<'a> { + lock: Either, RwLockWriteGuard<'a, ()>>, conn: Connection, + committed: bool, } -impl<'s> StoreTx<'s> for SqliteTx { +impl Drop for SqliteTx<'_> { + fn drop(&mut self) { + if let Right(RwLockWriteGuard { .. }) = self.lock { + if !self.committed { + let query = r#"rollback;"#; + let _ = self.conn.execute(query); + } + } + } +} + +impl<'s> StoreTx<'s> for SqliteTx<'s> { fn get(&self, key: &[u8], _for_update: bool) -> Result>> { let query = r#" select v from cozo where k = ?; "#; + let mut statement = self.conn.prepare(query).unwrap(); statement.bind((1, key)).unwrap(); Ok(match statement.next().into_diagnostic()? { @@ -93,6 +125,7 @@ impl<'s> StoreTx<'s> for SqliteTx { insert into cozo(k, v) values (?, ?) on conflict(k) do update set v=excluded.v; "#; + let mut statement = self.conn.prepare(query).unwrap(); statement.bind((1, key)).unwrap(); statement.bind((2, val)).unwrap(); @@ -112,6 +145,7 @@ impl<'s> StoreTx<'s> for SqliteTx { let mut statement = self.conn.prepare(query).unwrap(); statement.bind((1, key)).unwrap(); while statement.next().into_diagnostic()? != State::Done {} + Ok(()) } @@ -128,9 +162,16 @@ impl<'s> StoreTx<'s> for SqliteTx { } fn commit(&mut self) -> Result<()> { - let query = r#"commit;"#; - let mut statement = self.conn.prepare(query).unwrap(); - while statement.next().unwrap() != State::Done {} + if let Right(RwLockWriteGuard { .. }) = self.lock { + if !self.committed { + let query = r#"commit;"#; + let mut statement = self.conn.prepare(query).unwrap(); + while statement.next().into_diagnostic()? != State::Done {} + self.committed = true; + } else { + bail!("multiple commits") + } + } Ok(()) }