implemented callbacks

main
Ziyang Hu 2 years ago
parent 53bf0e53d1
commit 798facf94d

25
Cargo.lock generated

@ -551,6 +551,7 @@ dependencies = [
"chrono", "chrono",
"chrono-tz", "chrono-tz",
"cozorocks", "cozorocks",
"crossbeam",
"csv", "csv",
"document-features", "document-features",
"either", "either",
@ -699,6 +700,20 @@ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
] ]
[[package]]
name = "crossbeam"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c"
dependencies = [
"cfg-if 1.0.0",
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-epoch",
"crossbeam-queue",
"crossbeam-utils",
]
[[package]] [[package]]
name = "crossbeam-channel" name = "crossbeam-channel"
version = "0.5.6" version = "0.5.6"
@ -733,6 +748,16 @@ dependencies = [
"scopeguard", "scopeguard",
] ]
[[package]]
name = "crossbeam-queue"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add"
dependencies = [
"cfg-if 1.0.0",
"crossbeam-utils",
]
[[package]] [[package]]
name = "crossbeam-utils" name = "crossbeam-utils"
version = "0.8.14" version = "0.8.14"

@ -126,5 +126,6 @@ sqlite = { version = "0.30.1", optional = true }
sqlite3-src = { version = "0.4.0", optional = true, features = ["bundled"] } sqlite3-src = { version = "0.4.0", optional = true, features = ["bundled"] }
js-sys = { version = "0.3.60", optional = true } js-sys = { version = "0.3.60", optional = true }
graph = { version = "0.3.0", optional = true } graph = { version = "0.3.0", optional = true }
crossbeam = "0.8.2"
#redb = "0.9.0" #redb = "0.9.0"
#ouroboros = "0.15.5" #ouroboros = "0.15.5"

@ -6,12 +6,12 @@
* You can obtain one at https://mozilla.org/MPL/2.0/. * You can obtain one at https://mozilla.org/MPL/2.0/.
*/ */
use std::collections::BTreeMap; use std::collections::{BTreeMap, BTreeSet};
use std::sync::Arc; use std::sync::Arc;
use itertools::Itertools; use itertools::Itertools;
use miette::{bail, Diagnostic, Result, WrapErr}; use miette::{bail, Diagnostic, Result, WrapErr};
use smartstring::SmartString; use smartstring::{LazyCompact, SmartString};
use thiserror::Error; use thiserror::Error;
use crate::data::expr::Expr; use crate::data::expr::Expr;
@ -23,10 +23,11 @@ use crate::data::value::{DataValue, ValidityTs};
use crate::fixed_rule::utilities::constant::Constant; use crate::fixed_rule::utilities::constant::Constant;
use crate::fixed_rule::FixedRuleHandle; use crate::fixed_rule::FixedRuleHandle;
use crate::parse::parse_script; use crate::parse::parse_script;
use crate::runtime::db::CallbackOp;
use crate::runtime::relation::{AccessLevel, InputRelationHandle, InsufficientAccessLevel}; use crate::runtime::relation::{AccessLevel, InputRelationHandle, InsufficientAccessLevel};
use crate::runtime::transact::SessionTx; use crate::runtime::transact::SessionTx;
use crate::storage::Storage; use crate::storage::Storage;
use crate::{Db, StoreTx}; use crate::{Db, NamedRows, StoreTx};
#[derive(Debug, Error, Diagnostic)] #[derive(Debug, Error, Diagnostic)]
#[error("attempting to write into relation {0} of arity {1} with data of arity {2}")] #[error("attempting to write into relation {0} of arity {1} with data of arity {2}")]
@ -42,6 +43,11 @@ impl<'a> SessionTx<'a> {
meta: &InputRelationHandle, meta: &InputRelationHandle,
headers: &[Symbol], headers: &[Symbol],
cur_vld: ValidityTs, cur_vld: ValidityTs,
callback_targets: &BTreeSet<SmartString<LazyCompact>>,
callback_collector: &mut BTreeMap<
SmartString<LazyCompact>,
Vec<(CallbackOp, NamedRows, NamedRows)>,
>,
) -> Result<Vec<(Vec<u8>, Vec<u8>)>> { ) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
let mut to_clear = vec![]; let mut to_clear = vec![];
let mut replaced_old_triggers = None; let mut replaced_old_triggers = None;
@ -62,7 +68,9 @@ impl<'a> SessionTx<'a> {
parse_script(trigger, &Default::default(), &db.algorithms, cur_vld)? parse_script(trigger, &Default::default(), &db.algorithms, cur_vld)?
.get_single_program()?; .get_single_program()?;
let (_, cleanups) = db.run_query(self, program, cur_vld).map_err(|err| { let (_, cleanups) = db
.run_query(self, program, cur_vld, callback_targets, callback_collector)
.map_err(|err| {
if err.source_code().is_some() { if err.source_code().is_some() {
err err
} else { } else {
@ -92,6 +100,8 @@ impl<'a> SessionTx<'a> {
.. ..
} = meta; } = meta;
let is_callback_target = callback_targets.contains(&relation_store.name);
match op { match op {
RelationOp::Rm => { RelationOp::Rm => {
if relation_store.access_level < AccessLevel::Protected { if relation_store.access_level < AccessLevel::Protected {
@ -108,8 +118,8 @@ impl<'a> SessionTx<'a> {
headers, headers,
)?; )?;
let has_triggers = let need_to_collect = !relation_store.is_temp
!relation_store.is_temp && !relation_store.rm_triggers.is_empty(); && (is_callback_target || !relation_store.rm_triggers.is_empty());
let mut new_tuples: Vec<DataValue> = vec![]; let mut new_tuples: Vec<DataValue> = vec![];
let mut old_tuples: Vec<DataValue> = vec![]; let mut old_tuples: Vec<DataValue> = vec![];
@ -119,7 +129,7 @@ impl<'a> SessionTx<'a> {
.map(|ex| ex.extract_data(&tuple, cur_vld)) .map(|ex| ex.extract_data(&tuple, cur_vld))
.try_collect()?; .try_collect()?;
let key = relation_store.encode_key_for_store(&extracted, *span)?; let key = relation_store.encode_key_for_store(&extracted, *span)?;
if has_triggers { if need_to_collect {
if let Some(existing) = self.store_tx.get(&key, false)? { if let Some(existing) = self.store_tx.get(&key, false)? {
let mut tup = extracted.clone(); let mut tup = extracted.clone();
if !existing.is_empty() { if !existing.is_empty() {
@ -141,32 +151,45 @@ impl<'a> SessionTx<'a> {
} }
} }
if has_triggers && !new_tuples.is_empty() { if need_to_collect && !new_tuples.is_empty() {
for trigger in &relation_store.rm_triggers { let k_bindings = relation_store
let mut program =
parse_script(trigger, &Default::default(), &db.algorithms, cur_vld)?
.get_single_program()?;
let mut bindings = relation_store
.metadata .metadata
.keys .keys
.iter() .iter()
.map(|k| Symbol::new(k.name.clone(), Default::default())) .map(|k| Symbol::new(k.name.clone(), Default::default()))
.collect_vec(); .collect_vec();
make_const_rule(&mut program, "_new", bindings.clone(), new_tuples.clone());
let v_bindings = relation_store let v_bindings = relation_store
.metadata .metadata
.non_keys .non_keys
.iter() .iter()
.map(|k| Symbol::new(k.name.clone(), Default::default())); .map(|k| Symbol::new(k.name.clone(), Default::default()));
bindings.extend(v_bindings); let mut kv_bindings = k_bindings.clone();
kv_bindings.extend(v_bindings);
let kv_bindings = kv_bindings;
make_const_rule(&mut program, "_old", bindings, old_tuples.clone()); for trigger in &relation_store.rm_triggers {
let mut program =
parse_script(trigger, &Default::default(), &db.algorithms, cur_vld)?
.get_single_program()?;
make_const_rule(
&mut program,
"_new",
k_bindings.clone(),
new_tuples.clone(),
);
let (_, cleanups) = make_const_rule(
db.run_query(self, program, cur_vld).map_err(|err| { &mut program,
"_old",
kv_bindings.clone(),
old_tuples.clone(),
);
let (_, cleanups) = db
.run_query(self, program, cur_vld, callback_targets, callback_collector)
.map_err(|err| {
if err.source_code().is_some() { if err.source_code().is_some() {
err err
} else { } else {
@ -175,6 +198,41 @@ impl<'a> SessionTx<'a> {
})?; })?;
to_clear.extend(cleanups); to_clear.extend(cleanups);
} }
if is_callback_target {
let target_collector = callback_collector
.entry(relation_store.name.clone())
.or_default();
target_collector.push((
CallbackOp::Rm,
NamedRows {
headers: k_bindings
.into_iter()
.map(|k| k.name.to_string())
.collect_vec(),
rows: new_tuples
.into_iter()
.map(|v| match v {
DataValue::List(l) => l,
_ => unreachable!(),
})
.collect_vec(),
},
NamedRows {
headers: kv_bindings
.into_iter()
.map(|k| k.name.to_string())
.collect_vec(),
rows: old_tuples
.into_iter()
.map(|v| match v {
DataValue::List(l) => l,
_ => unreachable!(),
})
.collect_vec(),
},
))
}
} }
} }
RelationOp::Ensure => { RelationOp::Ensure => {
@ -288,8 +346,8 @@ impl<'a> SessionTx<'a> {
headers, headers,
)?; )?;
let has_triggers = let need_to_collect = !relation_store.is_temp
!relation_store.is_temp && !relation_store.put_triggers.is_empty(); && (is_callback_target || !relation_store.put_triggers.is_empty());
let mut new_tuples: Vec<DataValue> = vec![]; let mut new_tuples: Vec<DataValue> = vec![];
let mut old_tuples: Vec<DataValue> = vec![]; let mut old_tuples: Vec<DataValue> = vec![];
@ -310,7 +368,7 @@ impl<'a> SessionTx<'a> {
let key = relation_store.encode_key_for_store(&extracted, *span)?; let key = relation_store.encode_key_for_store(&extracted, *span)?;
let val = relation_store.encode_val_for_store(&extracted, *span)?; let val = relation_store.encode_val_for_store(&extracted, *span)?;
if has_triggers { if need_to_collect {
if let Some(existing) = self.store_tx.get(&key, false)? { if let Some(existing) = self.store_tx.get(&key, false)? {
let mut tup = extracted.clone(); let mut tup = extracted.clone();
let mut remaining = &existing[ENCODED_KEY_MIN_LEN..]; let mut remaining = &existing[ENCODED_KEY_MIN_LEN..];
@ -332,12 +390,7 @@ impl<'a> SessionTx<'a> {
} }
} }
if has_triggers && !new_tuples.is_empty() { if need_to_collect && !new_tuples.is_empty() {
for trigger in &relation_store.put_triggers {
let mut program =
parse_script(trigger, &Default::default(), &db.algorithms, cur_vld)?
.get_single_program()?;
let mut bindings = relation_store let mut bindings = relation_store
.metadata .metadata
.keys .keys
@ -351,11 +404,28 @@ impl<'a> SessionTx<'a> {
.map(|k| Symbol::new(k.name.clone(), Default::default())); .map(|k| Symbol::new(k.name.clone(), Default::default()));
bindings.extend(v_bindings); bindings.extend(v_bindings);
make_const_rule(&mut program, "_new", bindings.clone(), new_tuples.clone()); let kv_bindings = bindings;
make_const_rule(&mut program, "_old", bindings, old_tuples.clone()); for trigger in &relation_store.put_triggers {
let mut program =
parse_script(trigger, &Default::default(), &db.algorithms, cur_vld)?
.get_single_program()?;
make_const_rule(
&mut program,
"_new",
kv_bindings.clone(),
new_tuples.clone(),
);
make_const_rule(
&mut program,
"_old",
kv_bindings.clone(),
old_tuples.clone(),
);
let (_, cleanups) = let (_, cleanups) = db
db.run_query(self, program, cur_vld).map_err(|err| { .run_query(self, program, cur_vld, callback_targets, callback_collector)
.map_err(|err| {
if err.source_code().is_some() { if err.source_code().is_some() {
err err
} else { } else {
@ -364,6 +434,39 @@ impl<'a> SessionTx<'a> {
})?; })?;
to_clear.extend(cleanups); to_clear.extend(cleanups);
} }
if is_callback_target {
let target_collector = callback_collector
.entry(relation_store.name.clone())
.or_default();
let headers = kv_bindings
.into_iter()
.map(|k| k.name.to_string())
.collect_vec();
target_collector.push((
CallbackOp::Rm,
NamedRows {
headers: headers.clone(),
rows: new_tuples
.into_iter()
.map(|v| match v {
DataValue::List(l) => l,
_ => unreachable!(),
})
.collect_vec(),
},
NamedRows {
headers,
rows: old_tuples
.into_iter()
.map(|v| match v {
DataValue::List(l) => l,
_ => unreachable!(),
})
.collect_vec(),
},
))
}
} }
} }
}; };

@ -7,17 +7,18 @@
*/ */
use std::collections::btree_map::Entry; use std::collections::btree_map::Entry;
use std::collections::BTreeMap; use std::collections::{BTreeMap, BTreeSet};
use std::default::Default; use std::default::Default;
use std::fmt::{Debug, Formatter}; use std::fmt::{Debug, Formatter};
use std::path::Path; use std::path::Path;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock}; use std::sync::{Arc, Mutex, RwLock};
#[allow(unused_imports)] #[allow(unused_imports)]
use std::thread; use std::thread;
#[allow(unused_imports)] #[allow(unused_imports)]
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crossbeam::channel::{unbounded, Sender};
use either::{Either, Left, Right}; use either::{Either, Left, Right};
use itertools::Itertools; use itertools::Itertools;
#[allow(unused_imports)] #[allow(unused_imports)]
@ -81,7 +82,10 @@ pub enum CallbackOp {
Rm, Rm,
} }
pub type TxCallback = Box<dyn FnMut(CallbackOp, Tuple) + Send + Sync>; pub struct CallbackDeclaration {
dependent: SmartString<LazyCompact>,
callback: Box<dyn Fn(CallbackOp, NamedRows, NamedRows) + Send + Sync>,
}
/// The database object of Cozo. /// The database object of Cozo.
#[derive(Clone)] #[derive(Clone)]
@ -92,8 +96,14 @@ pub struct Db<S> {
queries_count: Arc<AtomicU64>, queries_count: Arc<AtomicU64>,
running_queries: Arc<Mutex<BTreeMap<u64, RunningQueryHandle>>>, running_queries: Arc<Mutex<BTreeMap<u64, RunningQueryHandle>>>,
pub(crate) algorithms: Arc<BTreeMap<String, Arc<Box<dyn FixedRule>>>>, pub(crate) algorithms: Arc<BTreeMap<String, Arc<Box<dyn FixedRule>>>>,
callback_count: Arc<AtomicU64>, callback_count: Arc<AtomicU32>,
event_callbacks: Arc<RwLock<BTreeMap<String, BTreeMap<u64, TxCallback>>>>, callback_sender: Sender<(SmartString<LazyCompact>, CallbackOp, NamedRows, NamedRows)>,
event_callbacks: Arc<
RwLock<(
BTreeMap<u32, CallbackDeclaration>,
BTreeMap<SmartString<LazyCompact>, BTreeSet<u32>>,
)>,
>,
} }
impl<S> Debug for Db<S> { impl<S> Debug for Db<S> {
@ -149,6 +159,7 @@ impl<'s, S: Storage<'s>> Db<S> {
/// You must call [`initialize`](Self::initialize) immediately after creation. /// You must call [`initialize`](Self::initialize) immediately after creation.
/// Due to lifetime restrictions we are not able to call that for you automatically. /// Due to lifetime restrictions we are not able to call that for you automatically.
pub fn new(storage: S) -> Result<Self> { pub fn new(storage: S) -> Result<Self> {
let (sender, receiver) = unbounded();
let ret = Self { let ret = Self {
db: storage, db: storage,
temp_db: Default::default(), temp_db: Default::default(),
@ -157,8 +168,23 @@ impl<'s, S: Storage<'s>> Db<S> {
running_queries: Arc::new(Mutex::new(Default::default())), running_queries: Arc::new(Mutex::new(Default::default())),
algorithms: DEFAULT_FIXED_RULES.clone(), algorithms: DEFAULT_FIXED_RULES.clone(),
callback_count: Arc::new(Default::default()), callback_count: Arc::new(Default::default()),
callback_sender: sender,
// callback_receiver: Arc::new(receiver),
event_callbacks: Arc::new(Default::default()), event_callbacks: Arc::new(Default::default()),
}; };
let callbacks = ret.event_callbacks.clone();
thread::spawn(move || {
while let Ok((table, op, new, old)) = receiver.recv() {
let (cbs, cb_dir) = &*callbacks.read().unwrap();
if let Some(cb_ids) = cb_dir.get(&table) {
for cb_id in cb_ids {
if let Some(cb) = cbs.get(cb_id) {
(cb.callback)(op, new.clone(), old.clone())
}
}
}
}
});
Ok(ret) Ok(ret)
} }
@ -392,7 +418,11 @@ impl<'s, S: Storage<'s>> Db<S> {
/// Note that triggers are _not_ run for the relations, if any exists. /// Note that triggers are _not_ run for the relations, if any exists.
/// If you need to activate triggers, use queries with parameters. /// If you need to activate triggers, use queries with parameters.
#[allow(unused_variables)] #[allow(unused_variables)]
pub fn import_from_backup(&'s self, in_file: impl AsRef<Path>, relations: &[String]) -> Result<()> { pub fn import_from_backup(
&'s self,
in_file: impl AsRef<Path>,
relations: &[String],
) -> Result<()> {
#[cfg(not(feature = "storage-sqlite"))] #[cfg(not(feature = "storage-sqlite"))]
bail!("backup requires the 'storage-sqlite' feature to be enabled"); bail!("backup requires the 'storage-sqlite' feature to be enabled");
@ -435,48 +465,64 @@ impl<'s, S: Storage<'s>> Db<S> {
dst_tx.commit_tx() dst_tx.commit_tx()
} }
} }
/// Register a custom fixed rule implementation /// Register a custom fixed rule implementation.
///
/// You must register fixed rules BEFORE you clone the database,
/// otherwise already cloned instances will not get the new fixed rule.
pub fn register_fixed_rule( pub fn register_fixed_rule(
&mut self, &mut self,
name: String, name: String,
rule_impl: Box<dyn FixedRule>, rule_impl: Box<dyn FixedRule>,
) -> Result<()> { ) -> Result<()> {
let inner = Arc::make_mut(&mut self.algorithms); let new = Arc::make_mut(&mut self.algorithms);
match inner.entry(name) { match new.entry(name) {
Entry::Vacant(ent) => { Entry::Vacant(ent) => {
ent.insert(Arc::new(rule_impl)); ent.insert(Arc::new(rule_impl));
Ok(()) Ok(())
} }
Entry::Occupied(ent) => { Entry::Occupied(ent) => {
bail!("A fixed rule with the name {} is already loaded", ent.key()) bail!("A fixed rule with the name {} is already registered", ent.key())
} }
} }
} }
/// Register callbacks to run when changes to relations are committed. /// Register callbacks to run when changes to relations are committed.
/// The returned ID can be used to unregister the callbacks. /// The returned ID can be used to unregister the callbacks.
/// It is OK to register callbacks for relations that do not exist (yet).
/// TODO: not yet implemented
#[allow(dead_code)] #[allow(dead_code)]
pub(crate) fn register_callback(&self, relation: String, cb: TxCallback) -> u64 { pub(crate) fn register_callback<CB>(&self, callback: CB, dependent: &str) -> Result<u32>
let id = self.callback_count.fetch_add(1, Ordering::AcqRel); where
CB: Fn(CallbackOp, NamedRows, NamedRows) + Send + Sync + 'static,
{
let cb = CallbackDeclaration {
dependent: SmartString::from(dependent),
callback: Box::new(callback),
};
let mut guard = self.event_callbacks.write().unwrap(); let mut guard = self.event_callbacks.write().unwrap();
let entries = guard.entry(relation).or_default(); let new_id = self.callback_count.fetch_add(1, Ordering::SeqCst);
entries.insert(id, cb); guard
id .1
.entry(SmartString::from(dependent))
.or_default()
.insert(new_id);
guard.0.insert(new_id, cb);
Ok(new_id)
} }
/// Unregister callbacks to run when changes to relations are committed. /// Unregister callbacks to run when changes to relations are committed.
#[allow(dead_code)] #[allow(dead_code)]
pub(crate) fn unregister_callback(&self, relation: String, id: u64) -> bool { pub(crate) fn unregister_callback(&self, id: u32) -> bool {
let mut guard = self.event_callbacks.write().unwrap(); let mut guard = self.event_callbacks.write().unwrap();
match guard.entry(relation) { let ret = guard.0.remove(&id);
Entry::Vacant(_) => false, if let Some(cb) = &ret {
Entry::Occupied(mut ent) => { guard.1.get_mut(&cb.dependent).unwrap().remove(&id);
let entries = ent.get_mut();
entries.remove(&id).is_some() if guard.1.get(&cb.dependent).unwrap().is_empty() {
guard.1.remove(&cb.dependent);
} }
} }
ret.is_some()
} }
fn compact_relation(&'s self) -> Result<()> { fn compact_relation(&'s self) -> Result<()> {
@ -518,13 +564,25 @@ impl<'s, S: Storage<'s>> Db<S> {
cleanups: &mut Vec<(Vec<u8>, Vec<u8>)>, cleanups: &mut Vec<(Vec<u8>, Vec<u8>)>,
cur_vld: ValidityTs, cur_vld: ValidityTs,
span: SourceSpan, span: SourceSpan,
callback_targets: &BTreeSet<SmartString<LazyCompact>>,
callback_collector: &mut BTreeMap<
SmartString<LazyCompact>,
Vec<(CallbackOp, NamedRows, NamedRows)>,
>,
) -> Result<bool> { ) -> Result<bool> {
let res = match p { let res = match p {
Left(rel) => { Left(rel) => {
let relation = tx.get_relation(rel, false)?; let relation = tx.get_relation(rel, false)?;
relation.as_named_rows(tx)? relation.as_named_rows(tx)?
} }
Right(p) => self.execute_single_program(p.clone(), tx, cleanups, cur_vld)?, Right(p) => self.execute_single_program(
p.clone(),
tx,
cleanups,
cur_vld,
callback_targets,
callback_collector,
)?,
}; };
Ok(match res.rows.first() { Ok(match res.rows.first() {
None => false, None => false,
@ -545,10 +603,16 @@ impl<'s, S: Storage<'s>> Db<S> {
tx: &mut SessionTx<'_>, tx: &mut SessionTx<'_>,
cleanups: &mut Vec<(Vec<u8>, Vec<u8>)>, cleanups: &mut Vec<(Vec<u8>, Vec<u8>)>,
cur_vld: ValidityTs, cur_vld: ValidityTs,
callback_targets: &BTreeSet<SmartString<LazyCompact>>,
callback_collector: &mut BTreeMap<
SmartString<LazyCompact>,
Vec<(CallbackOp, NamedRows, NamedRows)>,
>,
) -> Result<NamedRows> { ) -> Result<NamedRows> {
#[allow(unused_variables)] #[allow(unused_variables)]
let sleep_opt = p.out_opts.sleep; let sleep_opt = p.out_opts.sleep;
let (q_res, q_cleanups) = self.run_query(tx, p, cur_vld)?; let (q_res, q_cleanups) =
self.run_query(tx, p, cur_vld, callback_targets, callback_collector)?;
cleanups.extend(q_cleanups); cleanups.extend(q_cleanups);
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
if let Some(secs) = sleep_opt { if let Some(secs) = sleep_opt {
@ -556,15 +620,42 @@ impl<'s, S: Storage<'s>> Db<S> {
} }
Ok(q_res) Ok(q_res)
} }
fn current_callback_targets(&self) -> BTreeSet<SmartString<LazyCompact>> {
self.event_callbacks
.read()
.unwrap()
.1
.keys()
.cloned()
.collect()
}
fn send_callbacks(
&'s self,
collector: BTreeMap<SmartString<LazyCompact>, Vec<(CallbackOp, NamedRows, NamedRows)>>,
) {
for (k, vals) in collector {
for (op, new, old) in vals {
self.callback_sender
.send((k.clone(), op, new, old))
.expect("sending to callback processor failed");
}
}
}
fn do_run_script( fn do_run_script(
&'s self, &'s self,
payload: &str, payload: &str,
param_pool: &BTreeMap<String, DataValue>, param_pool: &BTreeMap<String, DataValue>,
cur_vld: ValidityTs, cur_vld: ValidityTs,
) -> Result<NamedRows> { ) -> Result<NamedRows> {
let mut callback_collector = BTreeMap::new();
match parse_script(payload, param_pool, &self.algorithms, cur_vld)? { match parse_script(payload, param_pool, &self.algorithms, cur_vld)? {
CozoScript::Single(p) => { CozoScript::Single(p) => {
let is_write = p.needs_write_tx(); let is_write = p.needs_write_tx();
let callback_targets = if is_write {
self.current_callback_targets()
} else {
Default::default()
};
let mut cleanups = vec![]; let mut cleanups = vec![];
let res; let res;
{ {
@ -574,7 +665,14 @@ impl<'s, S: Storage<'s>> Db<S> {
self.transact()? self.transact()?
}; };
res = self.execute_single_program(p, &mut tx, &mut cleanups, cur_vld)?; res = self.execute_single_program(
p,
&mut tx,
&mut cleanups,
cur_vld,
&callback_targets,
&mut callback_collector,
)?;
if is_write { if is_write {
tx.commit_tx()?; tx.commit_tx()?;
@ -591,6 +689,11 @@ impl<'s, S: Storage<'s>> Db<S> {
} }
CozoScript::Imperative(ps) => { CozoScript::Imperative(ps) => {
let is_write = ps.iter().any(|p| p.needs_write_tx()); let is_write = ps.iter().any(|p| p.needs_write_tx());
let callback_targets = if is_write {
self.current_callback_targets()
} else {
Default::default()
};
let mut cleanups: Vec<(Vec<u8>, Vec<u8>)> = vec![]; let mut cleanups: Vec<(Vec<u8>, Vec<u8>)> = vec![];
let ret; let ret;
{ {
@ -599,7 +702,14 @@ impl<'s, S: Storage<'s>> Db<S> {
} else { } else {
self.transact()? self.transact()?
}; };
match self.execute_imperative_stmts(&ps, &mut tx, &mut cleanups, cur_vld)? { match self.execute_imperative_stmts(
&ps,
&mut tx,
&mut cleanups,
cur_vld,
&callback_targets,
&mut callback_collector,
)? {
Left(res) => ret = res, Left(res) => ret = res,
Right(ctrl) => match ctrl { Right(ctrl) => match ctrl {
ControlCode::Termination(res) => { ControlCode::Termination(res) => {
@ -623,6 +733,10 @@ impl<'s, S: Storage<'s>> Db<S> {
assert!(cleanups.is_empty(), "non-empty cleanups on read-only tx"); assert!(cleanups.is_empty(), "non-empty cleanups on read-only tx");
} }
} }
if !callback_collector.is_empty() {
self.send_callbacks(callback_collector)
}
for (lower, upper) in cleanups { for (lower, upper) in cleanups {
self.db.del_range(&lower, &upper)?; self.db.del_range(&lower, &upper)?;
} }
@ -637,6 +751,11 @@ impl<'s, S: Storage<'s>> Db<S> {
tx: &mut SessionTx<'_>, tx: &mut SessionTx<'_>,
cleanups: &mut Vec<(Vec<u8>, Vec<u8>)>, cleanups: &mut Vec<(Vec<u8>, Vec<u8>)>,
cur_vld: ValidityTs, cur_vld: ValidityTs,
callback_targets: &BTreeSet<SmartString<LazyCompact>>,
callback_collector: &mut BTreeMap<
SmartString<LazyCompact>,
Vec<(CallbackOp, NamedRows, NamedRows)>,
>,
) -> Result<Either<NamedRows, ControlCode>> { ) -> Result<Either<NamedRows, ControlCode>> {
let mut ret = NamedRows::default(); let mut ret = NamedRows::default();
for p in ps { for p in ps {
@ -651,7 +770,14 @@ impl<'s, S: Storage<'s>> Db<S> {
return Ok(Right(ControlCode::Termination(NamedRows::default()))) return Ok(Right(ControlCode::Termination(NamedRows::default())))
} }
ImperativeStmt::ReturnProgram { prog, .. } => { ImperativeStmt::ReturnProgram { prog, .. } => {
ret = self.execute_single_program(prog.clone(), tx, cleanups, cur_vld)?; ret = self.execute_single_program(
prog.clone(),
tx,
cleanups,
cur_vld,
callback_targets,
callback_collector,
)?;
return Ok(Right(ControlCode::Termination(ret))); return Ok(Right(ControlCode::Termination(ret)));
} }
ImperativeStmt::ReturnTemp { rel, .. } => { ImperativeStmt::ReturnTemp { rel, .. } => {
@ -664,10 +790,24 @@ impl<'s, S: Storage<'s>> Db<S> {
ret = NamedRows::default(); ret = NamedRows::default();
} }
ImperativeStmt::Program { prog, .. } => { ImperativeStmt::Program { prog, .. } => {
ret = self.execute_single_program(prog.clone(), tx, cleanups, cur_vld)?; ret = self.execute_single_program(
prog.clone(),
tx,
cleanups,
cur_vld,
callback_targets,
callback_collector,
)?;
} }
ImperativeStmt::IgnoreErrorProgram { prog, .. } => { ImperativeStmt::IgnoreErrorProgram { prog, .. } => {
match self.execute_single_program(prog.clone(), tx, cleanups, cur_vld) { match self.execute_single_program(
prog.clone(),
tx,
cleanups,
cur_vld,
callback_targets,
callback_collector,
) {
Ok(res) => ret = res, Ok(res) => ret = res,
Err(_) => { Err(_) => {
ret = NamedRows { ret = NamedRows {
@ -684,11 +824,25 @@ impl<'s, S: Storage<'s>> Db<S> {
span, span,
negated, negated,
} => { } => {
let cond_val = let cond_val = self.execute_imperative_condition(
self.execute_imperative_condition(condition, tx, cleanups, cur_vld, *span)?; condition,
tx,
cleanups,
cur_vld,
*span,
callback_targets,
callback_collector,
)?;
let cond_val = if *negated { !cond_val } else { cond_val }; let cond_val = if *negated { !cond_val } else { cond_val };
let to_execute = if cond_val { then_branch } else { else_branch }; let to_execute = if cond_val { then_branch } else { else_branch };
match self.execute_imperative_stmts(to_execute, tx, cleanups, cur_vld)? { match self.execute_imperative_stmts(
to_execute,
tx,
cleanups,
cur_vld,
callback_targets,
callback_collector,
)? {
Left(rows) => { Left(rows) => {
ret = rows; ret = rows;
} }
@ -698,7 +852,14 @@ impl<'s, S: Storage<'s>> Db<S> {
ImperativeStmt::Loop { label, body, .. } => { ImperativeStmt::Loop { label, body, .. } => {
ret = Default::default(); ret = Default::default();
loop { loop {
match self.execute_imperative_stmts(body, tx, cleanups, cur_vld)? { match self.execute_imperative_stmts(
body,
tx,
cleanups,
cur_vld,
callback_targets,
callback_collector,
)? {
Left(_) => {} Left(_) => {}
Right(ctrl) => match ctrl { Right(ctrl) => match ctrl {
ControlCode::Termination(ret) => { ControlCode::Termination(ret) => {
@ -1055,6 +1216,11 @@ impl<'s, S: Storage<'s>> Db<S> {
tx: &mut SessionTx<'_>, tx: &mut SessionTx<'_>,
input_program: InputProgram, input_program: InputProgram,
cur_vld: ValidityTs, cur_vld: ValidityTs,
callback_targets: &BTreeSet<SmartString<LazyCompact>>,
callback_collector: &mut BTreeMap<
SmartString<LazyCompact>,
Vec<(CallbackOp, NamedRows, NamedRows)>,
>,
) -> Result<(NamedRows, Vec<(Vec<u8>, Vec<u8>)>)> { ) -> Result<(NamedRows, Vec<(Vec<u8>, Vec<u8>)>)> {
// cleanups contain stored relations that should be deleted at the end of query // cleanups contain stored relations that should be deleted at the end of query
let mut clean_ups = vec![]; let mut clean_ups = vec![];
@ -1197,6 +1363,8 @@ impl<'s, S: Storage<'s>> Db<S> {
meta, meta,
&entry_head_or_default, &entry_head_or_default,
cur_vld, cur_vld,
callback_targets,
callback_collector,
) )
.wrap_err_with(|| format!("when executing against relation '{}'", meta.name))?; .wrap_err_with(|| format!("when executing against relation '{}'", meta.name))?;
clean_ups.extend(to_clear); clean_ups.extend(to_clear);
@ -1249,6 +1417,8 @@ impl<'s, S: Storage<'s>> Db<S> {
meta, meta,
&entry_head_or_default, &entry_head_or_default,
cur_vld, cur_vld,
callback_targets,
callback_collector,
) )
.wrap_err_with(|| format!("when executing against relation '{}'", meta.name))?; .wrap_err_with(|| format!("when executing against relation '{}'", meta.name))?;
clean_ups.extend(to_clear); clean_ups.extend(to_clear);

Loading…
Cancel
Save