diff --git a/cozo-core/src/cozoscript.pest b/cozo-core/src/cozoscript.pest
index 77321231..493d82b3 100644
--- a/cozo-core/src/cozoscript.pest
+++ b/cozo-core/src/cozoscript.pest
@@ -13,11 +13,13 @@ query_script_inner_no_bracket = { (option | rule | const_rule | fixed_rule)+ }
 imperative_script = {SOI ~ imperative_stmt+ ~ EOI}
 sys_script = {SOI ~ "::" ~ (list_relations_op | list_relation_op | remove_relations_op | trigger_relation_op | trigger_relation_show_op | rename_relations_op | running_op | kill_op | explain_op |
-                        access_level_op | index_op | vec_idx_op | compact_op | list_fixed_rules) ~ EOI}
+                        access_level_op | index_op | vec_idx_op | fts_idx_op | lsh_idx_op | compact_op | list_fixed_rules) ~ EOI}
 index_op = {"index" ~ (index_create | index_drop)}
-vec_idx_op = {"hnsw" ~ (index_create_hnsw | index_drop)}
+vec_idx_op = {"hnsw" ~ (index_create_adv | index_drop)}
+fts_idx_op = {"fts" ~ (index_create_adv | index_drop)}
+lsh_idx_op = {"lsh" ~ (index_create_adv | index_drop)}
 index_create = {"create" ~ compound_ident ~ ":" ~ ident ~ "{" ~ (ident ~ ",")* ~ ident? ~ "}"}
-index_create_hnsw = {"create" ~ compound_ident ~ ":" ~ ident ~ "{" ~ (index_opt_field ~ ",")* ~ index_opt_field? ~ "}"}
+index_create_adv = {"create" ~ compound_ident ~ ":" ~ ident ~ "{" ~ (index_opt_field ~ ",")* ~ index_opt_field? ~ "}"}
 index_drop = {"drop" ~ compound_ident ~ ":" ~ ident }
 compact_op = {"compact"}
 list_fixed_rules = {"fixed_rules"}
diff --git a/cozo-core/src/data/expr.rs b/cozo-core/src/data/expr.rs
index 580f2120..c9cb3608 100644
--- a/cozo-core/src/data/expr.rs
+++ b/cozo-core/src/data/expr.rs
@@ -825,6 +825,7 @@ pub(crate) fn get_op(name: &str) -> Option<&'static Op> {
         "regex_replace_all" => &OP_REGEX_REPLACE_ALL,
         "regex_extract" => &OP_REGEX_EXTRACT,
         "regex_extract_first" => &OP_REGEX_EXTRACT_FIRST,
+        "t2s" => &OP_T2S,
         "encode_base64" => &OP_ENCODE_BASE64,
         "decode_base64" => &OP_DECODE_BASE64,
         "first" => &OP_FIRST,
diff --git a/cozo-core/src/data/functions.rs b/cozo-core/src/data/functions.rs
index d03d46eb..aca38beb 100644
--- a/cozo-core/src/data/functions.rs
+++ b/cozo-core/src/data/functions.rs
@@ -137,7 +137,7 @@ fn get_json_path<'a>(
                     .get_int()
                     .ok_or_else(|| miette!("json path must be a string or a number"))?
                     as usize;
-                if arr.len() >= key + 1 {
+                if arr.len() <= key + 1 {
                     arr.resize_with(key + 1, || JsonValue::Null);
                 }

@@ -1445,6 +1445,14 @@ pub(crate) fn op_regex_extract_first(args: &[DataValue]) -> Result<DataValue> {
     }
 }
+define_op!(OP_T2S, 1, false);
+fn op_t2s(args: &[DataValue]) -> Result<DataValue> {
+    Ok(match &args[0] {
+        DataValue::Str(s) => DataValue::Str(fast2s::convert(s).into()),
+        d => d.clone(),
+    })
+}
+
 define_op!(OP_IS_NULL, 1, false);
 pub(crate) fn op_is_null(args: &[DataValue]) -> Result<DataValue> {
     Ok(DataValue::from(matches!(args[0], DataValue::Null)))
 }
@@ -1753,7 +1761,7 @@ fn get_impl(args: &[DataValue]) -> Result<DataValue> {
                 .get_int()
                 .ok_or_else(|| miette!("second argument to 'get' mut be an integer"))?;
             let idx = get_index(n, l.len())?;
-            return Ok(l[idx].clone());
+            Ok(l[idx].clone())
         }
         DataValue::Json(json) => {
             let res = match &args[1] {
@@ -1770,8 +1778,7 @@ fn get_impl(args: &[DataValue]) -> Result<DataValue> {
                         .clone()
                 }
                 DataValue::List(l) => {
-                    let mut v = json.clone();
-                    get_json_path_immutable(&mut v, l)?.clone()
+                    get_json_path_immutable(json, l)?.clone()
                 }
                 _ => bail!("second argument to 'get' mut be a string or integer"),
             };
diff --git a/cozo-core/src/fts/mod.rs b/cozo-core/src/fts/mod.rs
index 00b0b22a..ae99d51d 100644
--- a/cozo-core/src/fts/mod.rs
+++ b/cozo-core/src/fts/mod.rs
@@ -263,7 +263,7 @@ impl TokenizerCache {
                 hashed_cache.insert(hash.as_ref().to_vec(), analyzer.clone());
                 let mut idx_cache = self.named_cache.write().unwrap();
                 idx_cache.insert(tokenizer_name.into(), analyzer.clone());
-                return Ok(analyzer);
+                Ok(analyzer)
             }
         }
     }
diff --git a/cozo-core/src/parse/sys.rs b/cozo-core/src/parse/sys.rs
index bb81b339..2edc6ae6 100644
--- a/cozo-core/src/parse/sys.rs
+++ b/cozo-core/src/parse/sys.rs
@@ -10,7 +10,7 @@ use std::collections::BTreeMap;
 use std::sync::Arc;

 use itertools::Itertools;
-use miette::{ensure, miette, Diagnostic, Result, bail};
+use miette::{bail, ensure, miette, Diagnostic, Result};
 use smartstring::{LazyCompact, SmartString};
 use thiserror::Error;

@@ -54,7 +54,7 @@ pub(crate) struct HnswIndexConfig {
     pub(crate) m_neighbours: usize,
     pub(crate) index_filter: Option<String>,
     pub(crate) extend_candidates: bool,
-    pub(crate) keep_pruned_connections: bool
+    pub(crate) keep_pruned_connections: bool,
 }

 #[derive(
@@ -175,10 +175,16 @@ pub(crate) fn parse_sys(
             }
             SysOp::SetTriggers(rel, puts, rms, replaces)
         }
+        Rule::fts_idx_op => {
+            todo!()
+        }
+        Rule::lsh_idx_op => {
+            todo!()
+        }
         Rule::vec_idx_op => {
             let inner = inner.into_inner().next().unwrap();
             match inner.as_rule() {
-                Rule::index_create_hnsw => {
+                Rule::index_create_adv => {
                     let mut inner = inner.into_inner();
                     let rel = inner.next().unwrap();
                     let name = inner.next().unwrap();
diff --git a/cozo-core/src/query/stored.rs b/cozo-core/src/query/stored.rs
index 2ca870ef..298fbe62 100644
--- a/cozo-core/src/query/stored.rs
+++ b/cozo-core/src/query/stored.rs
@@ -15,9 +15,9 @@ use pest::Parser;
 use smartstring::{LazyCompact, SmartString};
 use thiserror::Error;

-use crate::data::expr::Expr;
+use crate::data::expr::{Bytecode, Expr};
 use crate::data::program::{FixedRuleApply, InputInlineRulesOrFixed, InputProgram, RelationOp};
-use crate::data::relation::{ColumnDef, NullableColType};
+use crate::data::relation::{ColumnDef, NullableColType, StoredRelationMetadata};
 use crate::data::symb::Symbol;
 use crate::data::tuple::{Tuple, ENCODED_KEY_MIN_LEN};
 use crate::data::value::{DataValue, ValidityTs};
@@ -27,11 +27,11 @@ use crate::parse::expr::build_expr;
 use crate::parse::{parse_script, CozoScriptParser, Rule};
 use
crate::runtime::callback::{CallbackCollector, CallbackOp}; use crate::runtime::relation::{ - extend_tuple_from_v, AccessLevel, InputRelationHandle, InsufficientAccessLevel, + extend_tuple_from_v, AccessLevel, InputRelationHandle, InsufficientAccessLevel, RelationHandle, }; use crate::runtime::transact::SessionTx; use crate::storage::Storage; -use crate::{Db, NamedRows, StoreTx}; +use crate::{Db, NamedRows, SourceSpan, StoreTx}; #[derive(Debug, Error, Diagnostic)] #[error("attempting to write into relation {0} of arity {1} with data of arity {2}")] @@ -129,713 +129,773 @@ impl<'a> SessionTx<'a> { .. } = meta; + match op { + RelationOp::Rm => self.remove_from_relation( + db, + res_iter, + headers, + cur_vld, + callback_targets, + callback_collector, + propagate_triggers, + &mut to_clear, + &mut relation_store, + metadata, + key_bindings, + *span, + )?, + RelationOp::Ensure => self.ensure_in_relation( + res_iter, + headers, + cur_vld, + &mut relation_store, + metadata, + key_bindings, + dep_bindings, + *span, + )?, + RelationOp::EnsureNot => self.ensure_not_in_relation( + res_iter, + headers, + cur_vld, + &mut relation_store, + metadata, + key_bindings, + *span, + )?, + RelationOp::Update => self.update_in_relation( + db, + res_iter, + headers, + cur_vld, + callback_targets, + callback_collector, + propagate_triggers, + &mut to_clear, + &mut relation_store, + metadata, + key_bindings, + dep_bindings, + *span, + )?, + RelationOp::Create | RelationOp::Replace | RelationOp::Put => self.put_into_relation( + db, + res_iter, + headers, + cur_vld, + callback_targets, + callback_collector, + propagate_triggers, + &mut to_clear, + &mut relation_store, + metadata, + key_bindings, + dep_bindings, + *span, + )?, + }; + + Ok(to_clear) + } + + fn put_into_relation<'s, S: Storage<'s>>( + &mut self, + db: &Db, + res_iter: impl Iterator, + headers: &[Symbol], + cur_vld: ValidityTs, + callback_targets: &BTreeSet>, + callback_collector: &mut CallbackCollector, + propagate_triggers: bool, + to_clear: &mut Vec<(Vec, Vec)>, + relation_store: &mut RelationHandle, + metadata: &StoredRelationMetadata, + key_bindings: &[Symbol], + dep_bindings: &[Symbol], + span: SourceSpan, + ) -> Result<()> { let is_callback_target = callback_targets.contains(&relation_store.name); - match op { - RelationOp::Rm => { - if relation_store.access_level < AccessLevel::Protected { - bail!(InsufficientAccessLevel( - relation_store.name.to_string(), - "row removal".to_string(), - relation_store.access_level - )); - } - let key_extractors = make_extractors( - &relation_store.metadata.keys, - &metadata.keys, - key_bindings, - headers, - )?; - - let need_to_collect = !relation_store.is_temp - && (is_callback_target - || (propagate_triggers && !relation_store.rm_triggers.is_empty())); - let has_indices = !relation_store.indices.is_empty(); - let has_hnsw_indices = !relation_store.hnsw_indices.is_empty(); - let mut new_tuples: Vec = vec![]; - let mut old_tuples: Vec = vec![]; - - for tuple in res_iter { - let extracted: Vec = key_extractors - .iter() - .map(|ex| ex.extract_data(&tuple, cur_vld)) - .try_collect()?; - let key = relation_store.encode_key_for_store(&extracted, *span)?; - if need_to_collect || has_indices || has_hnsw_indices { - if let Some(existing) = self.store_tx.get(&key, false)? 
{ - let mut tup = extracted.clone(); - extend_tuple_from_v(&mut tup, &existing); - if has_indices { - for (idx_rel, extractor) in relation_store.indices.values() { - let idx_tup = - extractor.iter().map(|i| tup[*i].clone()).collect_vec(); - let encoded = idx_rel - .encode_key_for_store(&idx_tup, Default::default())?; - self.store_tx.del(&encoded)?; - } - } - if has_hnsw_indices { - for (idx_handle, _) in relation_store.hnsw_indices.values() { - self.hnsw_remove(&relation_store, idx_handle, &extracted)?; - } - } - if need_to_collect { - old_tuples.push(DataValue::List(tup)); - } - } - if need_to_collect { - new_tuples.push(DataValue::List(extracted.clone())); - } + if relation_store.access_level < AccessLevel::Protected { + bail!(InsufficientAccessLevel( + relation_store.name.to_string(), + "row insertion".to_string(), + relation_store.access_level + )); + } + + let mut key_extractors = make_extractors( + &relation_store.metadata.keys, + &metadata.keys, + key_bindings, + headers, + )?; + + let need_to_collect = !relation_store.is_temp + && (is_callback_target + || (propagate_triggers && !relation_store.put_triggers.is_empty())); + let has_indices = !relation_store.indices.is_empty(); + let has_hnsw_indices = !relation_store.hnsw_indices.is_empty(); + let mut new_tuples: Vec = vec![]; + let mut old_tuples: Vec = vec![]; + + let val_extractors = make_extractors( + &relation_store.metadata.non_keys, + &metadata.non_keys, + dep_bindings, + headers, + )?; + key_extractors.extend(val_extractors); + let mut stack = vec![]; + let hnsw_filters = Self::make_hnsw_filters(relation_store)?; + + for tuple in res_iter { + let extracted: Vec = key_extractors + .iter() + .map(|ex| ex.extract_data(&tuple, cur_vld)) + .try_collect()?; + + let key = relation_store.encode_key_for_store(&extracted, span)?; + let val = relation_store.encode_val_for_store(&extracted, span)?; + + if need_to_collect || has_indices || has_hnsw_indices { + if let Some(existing) = self.store_tx.get(&key, false)? { + let mut tup = extracted[0..relation_store.metadata.keys.len()].to_vec(); + extend_tuple_from_v(&mut tup, &existing); + if has_indices && extracted != tup { + self.update_in_index(relation_store, &extracted, &tup)?; + } + + if need_to_collect { + old_tuples.push(DataValue::List(tup)); } - if relation_store.is_temp { - self.temp_store_tx.del(&key)?; - } else { - self.store_tx.del(&key)?; + } else if has_indices { + for (idx_rel, extractor) in relation_store.indices.values() { + let idx_tup_new = extractor + .iter() + .map(|i| extracted[*i].clone()) + .collect_vec(); + let encoded_new = + idx_rel.encode_key_for_store(&idx_tup_new, Default::default())?; + self.store_tx.put(&encoded_new, &[])?; } } - // triggers and callbacks - if need_to_collect && !new_tuples.is_empty() { - let k_bindings = relation_store - .metadata - .keys - .iter() - .map(|k| Symbol::new(k.name.clone(), Default::default())) - .collect_vec(); - - let v_bindings = relation_store - .metadata - .non_keys - .iter() - .map(|k| Symbol::new(k.name.clone(), Default::default())); - let mut kv_bindings = k_bindings.clone(); - kv_bindings.extend(v_bindings); - let kv_bindings = kv_bindings; - - if propagate_triggers { - for trigger in &relation_store.rm_triggers { - let mut program = parse_script( - trigger, - &Default::default(), - &db.fixed_rules.read().unwrap(), - cur_vld, - )? 
- .get_single_program()?; - - make_const_rule( - &mut program, - "_new", - k_bindings.clone(), - new_tuples.clone(), - ); - - make_const_rule( - &mut program, - "_old", - kv_bindings.clone(), - old_tuples.clone(), - ); - - let (_, cleanups) = db - .run_query( - self, - program, - cur_vld, - callback_targets, - callback_collector, - false, - ) - .map_err(|err| { - if err.source_code().is_some() { - err - } else { - err.with_source_code(trigger.to_string()) - } - })?; - to_clear.extend(cleanups); - } - } + self.update_in_hnsw(relation_store, &mut stack, &hnsw_filters, &extracted)?; - if is_callback_target { - let target_collector = callback_collector - .entry(relation_store.name.clone()) - .or_default(); - target_collector.push(( - CallbackOp::Rm, - NamedRows::new( - k_bindings - .into_iter() - .map(|k| k.name.to_string()) - .collect_vec(), - new_tuples - .into_iter() - .map(|v| match v { - DataValue::List(l) => l, - _ => unreachable!(), - }) - .collect_vec(), - ), - NamedRows::new( - kv_bindings - .into_iter() - .map(|k| k.name.to_string()) - .collect_vec(), - old_tuples - .into_iter() - .map(|v| match v { - DataValue::List(l) => l, - _ => unreachable!(), - }) - .collect_vec(), - ), - )) - } + if need_to_collect { + new_tuples.push(DataValue::List(extracted)); } } - RelationOp::Ensure => { - if relation_store.access_level < AccessLevel::ReadOnly { - bail!(InsufficientAccessLevel( - relation_store.name.to_string(), - "row check".to_string(), - relation_store.access_level - )); - } - let mut key_extractors = make_extractors( - &relation_store.metadata.keys, - &metadata.keys, - key_bindings, - headers, - )?; + if relation_store.is_temp { + self.temp_store_tx.put(&key, &val)?; + } else { + self.store_tx.put(&key, &val)?; + } + } - let val_extractors = make_extractors( - &relation_store.metadata.non_keys, - &metadata.non_keys, - dep_bindings, - headers, - )?; - key_extractors.extend(val_extractors); - - for tuple in res_iter { - let extracted: Vec = key_extractors - .iter() - .map(|ex| ex.extract_data(&tuple, cur_vld)) - .try_collect()?; - - let key = relation_store.encode_key_for_store(&extracted, *span)?; - let val = relation_store.encode_val_for_store(&extracted, *span)?; - - let existing = if relation_store.is_temp { - self.temp_store_tx.get(&key, true)? - } else { - self.store_tx.get(&key, true)? 
- }; - match existing { - None => { - bail!(TransactAssertionFailure { - relation: relation_store.name.to_string(), - key: extracted, - notice: "key does not exist in database".to_string() - }) - } - Some(v) => { - if &v as &[u8] != &val as &[u8] { - bail!(TransactAssertionFailure { - relation: relation_store.name.to_string(), - key: extracted, - notice: "key exists in database, but value does not match" - .to_string() - }) - } - } + if need_to_collect && !new_tuples.is_empty() { + self.collect_mutations( + db, + cur_vld, + callback_targets, + callback_collector, + propagate_triggers, + to_clear, + relation_store, + is_callback_target, + new_tuples, + old_tuples, + )?; + } + Ok(()) + } + + fn update_in_hnsw( + &mut self, + relation_store: &RelationHandle, + stack: &mut Vec, + hnsw_filters: &BTreeMap, Vec>, + new_kv: &[DataValue], + ) -> Result<()> { + for (name, (idx_handle, idx_manifest)) in relation_store.hnsw_indices.iter() { + let filter = hnsw_filters.get(name); + self.hnsw_put( + idx_manifest, + relation_store, + idx_handle, + filter, + stack, + new_kv, + )?; + } + Ok(()) + } + + fn make_hnsw_filters( + relation_store: &RelationHandle, + ) -> Result, Vec>> { + let mut hnsw_filters = BTreeMap::new(); + for (name, (_, manifest)) in relation_store.hnsw_indices.iter() { + if let Some(f_code) = &manifest.index_filter { + let parsed = CozoScriptParser::parse(Rule::expr, f_code) + .into_diagnostic()? + .next() + .unwrap(); + let mut code_expr = build_expr(parsed, &Default::default())?; + let binding_map = relation_store.raw_binding_map(); + code_expr.fill_binding_indices(&binding_map)?; + hnsw_filters.insert(name.clone(), code_expr.compile()); + } + } + Ok(hnsw_filters) + } + + fn update_in_relation<'s, S: Storage<'s>>( + &mut self, + db: &Db, + res_iter: impl Iterator, + headers: &[Symbol], + cur_vld: ValidityTs, + callback_targets: &BTreeSet>, + callback_collector: &mut CallbackCollector, + propagate_triggers: bool, + to_clear: &mut Vec<(Vec, Vec)>, + relation_store: &mut RelationHandle, + metadata: &StoredRelationMetadata, + key_bindings: &[Symbol], + dep_bindings: &[Symbol], + span: SourceSpan, + ) -> Result<()> { + let is_callback_target = callback_targets.contains(&relation_store.name); + + if relation_store.access_level < AccessLevel::Protected { + bail!(InsufficientAccessLevel( + relation_store.name.to_string(), + "row update".to_string(), + relation_store.access_level + )); + } + + let key_extractors = make_extractors( + &relation_store.metadata.keys, + &metadata.keys, + key_bindings, + headers, + )?; + + let need_to_collect = !relation_store.is_temp + && (is_callback_target + || (propagate_triggers && !relation_store.put_triggers.is_empty())); + let has_indices = !relation_store.indices.is_empty(); + let has_hnsw_indices = !relation_store.hnsw_indices.is_empty(); + let mut new_tuples: Vec = vec![]; + let mut old_tuples: Vec = vec![]; + + let val_extractors = make_update_extractors( + &relation_store.metadata.non_keys, + &metadata.non_keys, + dep_bindings, + headers, + ); + + let mut stack = vec![]; + let hnsw_filters = Self::make_hnsw_filters(relation_store)?; + + for tuple in res_iter { + let mut new_kv: Vec = key_extractors + .iter() + .map(|ex| ex.extract_data(&tuple, cur_vld)) + .try_collect()?; + + let key = relation_store.encode_key_for_store(&new_kv, span)?; + let original_val_bytes = if relation_store.is_temp { + self.temp_store_tx.get(&key, true)? + } else { + self.store_tx.get(&key, true)? 
+ }; + let original_val: Tuple = match original_val_bytes { + None => { + bail!(TransactAssertionFailure { + relation: relation_store.name.to_string(), + key: new_kv, + notice: "key to update does not exist".to_string() + }) + } + Some(v) => rmp_serde::from_slice(&v[ENCODED_KEY_MIN_LEN..]).unwrap(), + }; + let mut old_kv = Vec::with_capacity(relation_store.arity()); + old_kv.extend_from_slice(&new_kv); + old_kv.extend_from_slice(&original_val); + new_kv.reserve_exact(relation_store.arity()); + for (i, extractor) in val_extractors.iter().enumerate() { + match extractor { + None => { + new_kv.push(original_val[i].clone()); + } + Some(ex) => { + let val = ex.extract_data(&tuple, cur_vld)?; + new_kv.push(val); } } } - RelationOp::EnsureNot => { - if relation_store.access_level < AccessLevel::ReadOnly { - bail!(InsufficientAccessLevel( - relation_store.name.to_string(), - "row check".to_string(), - relation_store.access_level - )); + let new_val = relation_store.encode_val_for_store(&new_kv, span)?; + + if need_to_collect || has_indices || has_hnsw_indices { + self.update_in_index(relation_store, &new_kv, &old_kv)?; + + if need_to_collect { + old_tuples.push(DataValue::List(old_kv)); } - let key_extractors = make_extractors( - &relation_store.metadata.keys, - &metadata.keys, - key_bindings, - headers, - )?; - - for tuple in res_iter { - let extracted: Vec = key_extractors - .iter() - .map(|ex| ex.extract_data(&tuple, cur_vld)) - .try_collect()?; - let key = relation_store.encode_key_for_store(&extracted, *span)?; - let already_exists = if relation_store.is_temp { - self.temp_store_tx.exists(&key, true)? - } else { - self.store_tx.exists(&key, true)? - }; - if already_exists { - bail!(TransactAssertionFailure { - relation: relation_store.name.to_string(), - key: extracted, - notice: "key exists in database".to_string() - }) - } + self.update_in_hnsw(relation_store, &mut stack, &hnsw_filters, &new_kv)?; + + if need_to_collect { + new_tuples.push(DataValue::List(new_kv)); } } - RelationOp::Update => { - if relation_store.access_level < AccessLevel::Protected { - bail!(InsufficientAccessLevel( - relation_store.name.to_string(), - "row update".to_string(), - relation_store.access_level - )); - } - let key_extractors = make_extractors( - &relation_store.metadata.keys, - &metadata.keys, - key_bindings, - headers, - )?; - - let need_to_collect = !relation_store.is_temp - && (is_callback_target - || (propagate_triggers && !relation_store.put_triggers.is_empty())); - let has_indices = !relation_store.indices.is_empty(); - let has_hnsw_indices = !relation_store.hnsw_indices.is_empty(); - let mut new_tuples: Vec = vec![]; - let mut old_tuples: Vec = vec![]; - - let val_extractors = make_update_extractors( - &relation_store.metadata.non_keys, - &metadata.non_keys, - dep_bindings, - headers, + if relation_store.is_temp { + self.temp_store_tx.put(&key, &new_val)?; + } else { + self.store_tx.put(&key, &new_val)?; + } + } + + if need_to_collect && !new_tuples.is_empty() { + self.collect_mutations( + db, + cur_vld, + callback_targets, + callback_collector, + propagate_triggers, + to_clear, + relation_store, + is_callback_target, + new_tuples, + old_tuples, + )?; + } + Ok(()) + } + + fn collect_mutations<'s, S: Storage<'s>>( + &mut self, + db: &Db, + cur_vld: ValidityTs, + callback_targets: &BTreeSet>, + callback_collector: &mut CallbackCollector, + propagate_triggers: bool, + to_clear: &mut Vec<(Vec, Vec)>, + relation_store: &RelationHandle, + is_callback_target: bool, + new_tuples: Vec, + old_tuples: Vec, + ) 
-> Result<()> { + let mut bindings = relation_store + .metadata + .keys + .iter() + .map(|k| Symbol::new(k.name.clone(), Default::default())) + .collect_vec(); + let v_bindings = relation_store + .metadata + .non_keys + .iter() + .map(|k| Symbol::new(k.name.clone(), Default::default())); + bindings.extend(v_bindings); + + let kv_bindings = bindings; + if propagate_triggers { + for trigger in &relation_store.put_triggers { + let mut program = parse_script( + trigger, + &Default::default(), + &db.fixed_rules.read().unwrap(), + cur_vld, + )? + .get_single_program()?; + + make_const_rule( + &mut program, + "_new", + kv_bindings.clone(), + new_tuples.to_vec(), + ); + make_const_rule( + &mut program, + "_old", + kv_bindings.clone(), + old_tuples.to_vec(), ); - let mut stack = vec![]; - let mut hnsw_filters = BTreeMap::new(); - if has_hnsw_indices { - for (name, (_, manifest)) in relation_store.hnsw_indices.iter() { - if let Some(f_code) = &manifest.index_filter { - let parsed = CozoScriptParser::parse(Rule::expr, f_code) - .into_diagnostic()? - .next() - .unwrap(); - let mut code_expr = build_expr(parsed, &Default::default())?; - let binding_map = relation_store.raw_binding_map(); - code_expr.fill_binding_indices(&binding_map)?; - hnsw_filters.insert(name.clone(), code_expr.compile()); + let (_, cleanups) = db + .run_query( + self, + program, + cur_vld, + callback_targets, + callback_collector, + false, + ) + .map_err(|err| { + if err.source_code().is_some() { + err + } else { + err.with_source_code(trigger.to_string()) } - } - } + })?; + to_clear.extend(cleanups); + } + } - for tuple in res_iter { - let mut new_kv: Vec = key_extractors - .iter() - .map(|ex| ex.extract_data(&tuple, cur_vld)) - .try_collect()?; - - let key = relation_store.encode_key_for_store(&new_kv, *span)?; - let original_val_bytes = if relation_store.is_temp { - self.temp_store_tx.get(&key, true)? - } else { - self.store_tx.get(&key, true)? 
- }; - let original_val: Tuple = match original_val_bytes { - None => { - bail!(TransactAssertionFailure { - relation: relation_store.name.to_string(), - key: new_kv, - notice: "key to update does not exist".to_string() - }) - } - Some(v) => rmp_serde::from_slice(&v[ENCODED_KEY_MIN_LEN..]).unwrap(), - }; - let mut old_kv = Vec::with_capacity(relation_store.arity()); - old_kv.extend_from_slice(&new_kv); - old_kv.extend_from_slice(&original_val); - new_kv.reserve_exact(relation_store.arity()); - for (i, extractor) in val_extractors.iter().enumerate() { - match extractor { - None => { - new_kv.push(original_val[i].clone()); - } - Some(ex) => { - let val = ex.extract_data(&tuple, cur_vld)?; - new_kv.push(val); - } - } - } - let new_val = relation_store.encode_val_for_store(&new_kv, *span)?; - - if need_to_collect || has_indices || has_hnsw_indices { - if has_indices { - for (idx_rel, idx_extractor) in relation_store.indices.values() { - let idx_tup_old = idx_extractor - .iter() - .map(|i| old_kv[*i].clone()) - .collect_vec(); - let encoded_old = idx_rel - .encode_key_for_store(&idx_tup_old, Default::default())?; - self.store_tx.del(&encoded_old)?; - - let idx_tup_new = idx_extractor - .iter() - .map(|i| new_kv[*i].clone()) - .collect_vec(); - let encoded_new = idx_rel - .encode_key_for_store(&idx_tup_new, Default::default())?; - self.store_tx.put(&encoded_new, &[])?; - } - } + if is_callback_target { + let target_collector = callback_collector + .entry(relation_store.name.clone()) + .or_default(); + let headers = kv_bindings + .into_iter() + .map(|k| k.name.to_string()) + .collect_vec(); + target_collector.push(( + CallbackOp::Put, + NamedRows::new( + headers.clone(), + new_tuples + .into_iter() + .map(|v| match v { + DataValue::List(l) => l, + _ => unreachable!(), + }) + .collect_vec(), + ), + NamedRows::new( + headers, + old_tuples + .into_iter() + .map(|v| match v { + DataValue::List(l) => l, + _ => unreachable!(), + }) + .collect_vec(), + ), + )) + } + Ok(()) + } - if need_to_collect { - old_tuples.push(DataValue::List(old_kv)); - } + fn update_in_index( + &mut self, + relation_store: &RelationHandle, + new_kv: &[DataValue], + old_kv: &[DataValue], + ) -> Result<()> { + for (idx_rel, idx_extractor) in relation_store.indices.values() { + let idx_tup_old = idx_extractor + .iter() + .map(|i| old_kv[*i].clone()) + .collect_vec(); + let encoded_old = idx_rel.encode_key_for_store(&idx_tup_old, Default::default())?; + self.store_tx.del(&encoded_old)?; + + let idx_tup_new = idx_extractor + .iter() + .map(|i| new_kv[*i].clone()) + .collect_vec(); + let encoded_new = idx_rel.encode_key_for_store(&idx_tup_new, Default::default())?; + self.store_tx.put(&encoded_new, &[])?; + } + Ok(()) + } - if has_hnsw_indices { - for (name, (idx_handle, idx_manifest)) in - relation_store.hnsw_indices.iter() - { - let filter = hnsw_filters.get(name); - self.hnsw_put( - idx_manifest, - &relation_store, - idx_handle, - filter, - &mut stack, - &new_kv, - )?; - } - } + fn ensure_not_in_relation( + &mut self, + res_iter: impl Iterator, + headers: &[Symbol], + cur_vld: ValidityTs, + relation_store: &mut RelationHandle, + metadata: &StoredRelationMetadata, + key_bindings: &[Symbol], + span: SourceSpan, + ) -> Result<()> { + if relation_store.access_level < AccessLevel::ReadOnly { + bail!(InsufficientAccessLevel( + relation_store.name.to_string(), + "row check".to_string(), + relation_store.access_level + )); + } - if need_to_collect { - new_tuples.push(DataValue::List(new_kv)); - } - } + let key_extractors = 
make_extractors( + &relation_store.metadata.keys, + &metadata.keys, + key_bindings, + headers, + )?; + + for tuple in res_iter { + let extracted: Vec = key_extractors + .iter() + .map(|ex| ex.extract_data(&tuple, cur_vld)) + .try_collect()?; + let key = relation_store.encode_key_for_store(&extracted, span)?; + let already_exists = if relation_store.is_temp { + self.temp_store_tx.exists(&key, true)? + } else { + self.store_tx.exists(&key, true)? + }; + if already_exists { + bail!(TransactAssertionFailure { + relation: relation_store.name.to_string(), + key: extracted, + notice: "key exists in database".to_string() + }) + } + } + Ok(()) + } - if relation_store.is_temp { - self.temp_store_tx.put(&key, &new_val)?; - } else { - self.store_tx.put(&key, &new_val)?; - } - } + fn ensure_in_relation( + &mut self, + res_iter: impl Iterator, + headers: &[Symbol], + cur_vld: ValidityTs, + relation_store: &mut RelationHandle, + metadata: &StoredRelationMetadata, + key_bindings: &[Symbol], + dep_bindings: &[Symbol], + span: SourceSpan, + ) -> Result<()> { + if relation_store.access_level < AccessLevel::ReadOnly { + bail!(InsufficientAccessLevel( + relation_store.name.to_string(), + "row check".to_string(), + relation_store.access_level + )); + } - if need_to_collect && !new_tuples.is_empty() { - let mut bindings = relation_store - .metadata - .keys - .iter() - .map(|k| Symbol::new(k.name.clone(), Default::default())) - .collect_vec(); - let v_bindings = relation_store - .metadata - .non_keys - .iter() - .map(|k| Symbol::new(k.name.clone(), Default::default())); - bindings.extend(v_bindings); - - let kv_bindings = bindings; - if propagate_triggers { - for trigger in &relation_store.put_triggers { - let mut program = parse_script( - trigger, - &Default::default(), - &db.fixed_rules.read().unwrap(), - cur_vld, - )? 
- .get_single_program()?; - - make_const_rule( - &mut program, - "_new", - kv_bindings.clone(), - new_tuples.clone(), - ); - make_const_rule( - &mut program, - "_old", - kv_bindings.clone(), - old_tuples.clone(), - ); - - let (_, cleanups) = db - .run_query( - self, - program, - cur_vld, - callback_targets, - callback_collector, - false, - ) - .map_err(|err| { - if err.source_code().is_some() { - err - } else { - err.with_source_code(trigger.to_string()) - } - })?; - to_clear.extend(cleanups); - } - } + let mut key_extractors = make_extractors( + &relation_store.metadata.keys, + &metadata.keys, + key_bindings, + headers, + )?; - if is_callback_target { - let target_collector = callback_collector - .entry(relation_store.name.clone()) - .or_default(); - let headers = kv_bindings - .into_iter() - .map(|k| k.name.to_string()) - .collect_vec(); - target_collector.push(( - CallbackOp::Put, - NamedRows::new( - headers.clone(), - new_tuples - .into_iter() - .map(|v| match v { - DataValue::List(l) => l, - _ => unreachable!(), - }) - .collect_vec(), - ), - NamedRows::new( - headers, - old_tuples - .into_iter() - .map(|v| match v { - DataValue::List(l) => l, - _ => unreachable!(), - }) - .collect_vec(), - ), - )) + let val_extractors = make_extractors( + &relation_store.metadata.non_keys, + &metadata.non_keys, + dep_bindings, + headers, + )?; + key_extractors.extend(val_extractors); + + for tuple in res_iter { + let extracted: Vec = key_extractors + .iter() + .map(|ex| ex.extract_data(&tuple, cur_vld)) + .try_collect()?; + + let key = relation_store.encode_key_for_store(&extracted, span)?; + let val = relation_store.encode_val_for_store(&extracted, span)?; + + let existing = if relation_store.is_temp { + self.temp_store_tx.get(&key, true)? + } else { + self.store_tx.get(&key, true)? + }; + match existing { + None => { + bail!(TransactAssertionFailure { + relation: relation_store.name.to_string(), + key: extracted, + notice: "key does not exist in database".to_string() + }) + } + Some(v) => { + if &v as &[u8] != &val as &[u8] { + bail!(TransactAssertionFailure { + relation: relation_store.name.to_string(), + key: extracted, + notice: "key exists in database, but value does not match".to_string() + }) } } } - RelationOp::Create | RelationOp::Replace | RelationOp::Put => { - if relation_store.access_level < AccessLevel::Protected { - bail!(InsufficientAccessLevel( - relation_store.name.to_string(), - "row insertion".to_string(), - relation_store.access_level - )); - } + } + Ok(()) + } - let mut key_extractors = make_extractors( - &relation_store.metadata.keys, - &metadata.keys, - key_bindings, - headers, - )?; - - let need_to_collect = !relation_store.is_temp - && (is_callback_target - || (propagate_triggers && !relation_store.put_triggers.is_empty())); - let has_indices = !relation_store.indices.is_empty(); - let has_hnsw_indices = !relation_store.hnsw_indices.is_empty(); - let mut new_tuples: Vec = vec![]; - let mut old_tuples: Vec = vec![]; - - let val_extractors = make_extractors( - &relation_store.metadata.non_keys, - &metadata.non_keys, - dep_bindings, - headers, - )?; - key_extractors.extend(val_extractors); - let mut stack = vec![]; - let mut hnsw_filters = BTreeMap::new(); - if has_hnsw_indices { - for (name, (_, manifest)) in relation_store.hnsw_indices.iter() { - if let Some(f_code) = &manifest.index_filter { - let parsed = CozoScriptParser::parse(Rule::expr, f_code) - .into_diagnostic()? 
- .next() - .unwrap(); - let mut code_expr = build_expr(parsed, &Default::default())?; - let binding_map = relation_store.raw_binding_map(); - code_expr.fill_binding_indices(&binding_map)?; - hnsw_filters.insert(name.clone(), code_expr.compile()); + fn remove_from_relation<'s, S: Storage<'s>>( + &mut self, + db: &Db, + res_iter: impl Iterator, + headers: &[Symbol], + cur_vld: ValidityTs, + callback_targets: &BTreeSet>, + callback_collector: &mut CallbackCollector, + propagate_triggers: bool, + to_clear: &mut Vec<(Vec, Vec)>, + relation_store: &mut RelationHandle, + metadata: &StoredRelationMetadata, + key_bindings: &[Symbol], + span: SourceSpan, + ) -> Result<()> { + let is_callback_target = callback_targets.contains(&relation_store.name); + + if relation_store.access_level < AccessLevel::Protected { + bail!(InsufficientAccessLevel( + relation_store.name.to_string(), + "row removal".to_string(), + relation_store.access_level + )); + } + let key_extractors = make_extractors( + &relation_store.metadata.keys, + &metadata.keys, + key_bindings, + headers, + )?; + + let need_to_collect = !relation_store.is_temp + && (is_callback_target + || (propagate_triggers && !relation_store.rm_triggers.is_empty())); + let has_indices = !relation_store.indices.is_empty(); + let has_hnsw_indices = !relation_store.hnsw_indices.is_empty(); + let mut new_tuples: Vec = vec![]; + let mut old_tuples: Vec = vec![]; + + for tuple in res_iter { + let extracted: Vec = key_extractors + .iter() + .map(|ex| ex.extract_data(&tuple, cur_vld)) + .try_collect()?; + let key = relation_store.encode_key_for_store(&extracted, span)?; + if need_to_collect || has_indices || has_hnsw_indices { + if let Some(existing) = self.store_tx.get(&key, false)? { + let mut tup = extracted.clone(); + extend_tuple_from_v(&mut tup, &existing); + if has_indices { + for (idx_rel, extractor) in relation_store.indices.values() { + let idx_tup = extractor.iter().map(|i| tup[*i].clone()).collect_vec(); + let encoded = + idx_rel.encode_key_for_store(&idx_tup, Default::default())?; + self.store_tx.del(&encoded)?; } } + if has_hnsw_indices { + for (idx_handle, _) in relation_store.hnsw_indices.values() { + self.hnsw_remove(relation_store, idx_handle, &extracted)?; + } + } + if need_to_collect { + old_tuples.push(DataValue::List(tup)); + } } + if need_to_collect { + new_tuples.push(DataValue::List(extracted.clone())); + } + } + if relation_store.is_temp { + self.temp_store_tx.del(&key)?; + } else { + self.store_tx.del(&key)?; + } + } - for tuple in res_iter { - let extracted: Vec = key_extractors - .iter() - .map(|ex| ex.extract_data(&tuple, cur_vld)) - .try_collect()?; - - let key = relation_store.encode_key_for_store(&extracted, *span)?; - let val = relation_store.encode_val_for_store(&extracted, *span)?; - - if need_to_collect || has_indices || has_hnsw_indices { - if let Some(existing) = self.store_tx.get(&key, false)? 
{ - let mut tup = extracted[0..relation_store.metadata.keys.len()].to_vec(); - extend_tuple_from_v(&mut tup, &existing); - if has_indices && extracted != tup { - for (idx_rel, extractor) in relation_store.indices.values() { - let idx_tup_old = - extractor.iter().map(|i| tup[*i].clone()).collect_vec(); - let encoded_old = idx_rel - .encode_key_for_store(&idx_tup_old, Default::default())?; - self.store_tx.del(&encoded_old)?; - - let idx_tup_new = extractor - .iter() - .map(|i| extracted[*i].clone()) - .collect_vec(); - let encoded_new = idx_rel - .encode_key_for_store(&idx_tup_new, Default::default())?; - self.store_tx.put(&encoded_new, &[])?; - } - } - - if need_to_collect { - old_tuples.push(DataValue::List(tup)); - } - } else if has_indices { - for (idx_rel, extractor) in relation_store.indices.values() { - let idx_tup_new = extractor - .iter() - .map(|i| extracted[*i].clone()) - .collect_vec(); - let encoded_new = idx_rel - .encode_key_for_store(&idx_tup_new, Default::default())?; - self.store_tx.put(&encoded_new, &[])?; - } - } + // triggers and callbacks + if need_to_collect && !new_tuples.is_empty() { + let k_bindings = relation_store + .metadata + .keys + .iter() + .map(|k| Symbol::new(k.name.clone(), Default::default())) + .collect_vec(); + + let v_bindings = relation_store + .metadata + .non_keys + .iter() + .map(|k| Symbol::new(k.name.clone(), Default::default())); + let mut kv_bindings = k_bindings.clone(); + kv_bindings.extend(v_bindings); + let kv_bindings = kv_bindings; + + if propagate_triggers { + for trigger in &relation_store.rm_triggers { + let mut program = parse_script( + trigger, + &Default::default(), + &db.fixed_rules.read().unwrap(), + cur_vld, + )? + .get_single_program()?; - if has_hnsw_indices { - for (name, (idx_handle, idx_manifest)) in - relation_store.hnsw_indices.iter() - { - let filter = hnsw_filters.get(name); - self.hnsw_put( - idx_manifest, - &relation_store, - idx_handle, - filter, - &mut stack, - &extracted, - )?; - } - } + make_const_rule(&mut program, "_new", k_bindings.clone(), new_tuples.clone()); - if need_to_collect { - new_tuples.push(DataValue::List(extracted)); - } - } + make_const_rule( + &mut program, + "_old", + kv_bindings.clone(), + old_tuples.clone(), + ); - if relation_store.is_temp { - self.temp_store_tx.put(&key, &val)?; - } else { - self.store_tx.put(&key, &val)?; - } + let (_, cleanups) = db + .run_query( + self, + program, + cur_vld, + callback_targets, + callback_collector, + false, + ) + .map_err(|err| { + if err.source_code().is_some() { + err + } else { + err.with_source_code(trigger.to_string()) + } + })?; + to_clear.extend(cleanups); } + } - if need_to_collect && !new_tuples.is_empty() { - let mut bindings = relation_store - .metadata - .keys - .iter() - .map(|k| Symbol::new(k.name.clone(), Default::default())) - .collect_vec(); - let v_bindings = relation_store - .metadata - .non_keys - .iter() - .map(|k| Symbol::new(k.name.clone(), Default::default())); - bindings.extend(v_bindings); - - let kv_bindings = bindings; - if propagate_triggers { - for trigger in &relation_store.put_triggers { - let mut program = parse_script( - trigger, - &Default::default(), - &db.fixed_rules.read().unwrap(), - cur_vld, - )? 
- .get_single_program()?; - - make_const_rule( - &mut program, - "_new", - kv_bindings.clone(), - new_tuples.clone(), - ); - make_const_rule( - &mut program, - "_old", - kv_bindings.clone(), - old_tuples.clone(), - ); - - let (_, cleanups) = db - .run_query( - self, - program, - cur_vld, - callback_targets, - callback_collector, - false, - ) - .map_err(|err| { - if err.source_code().is_some() { - err - } else { - err.with_source_code(trigger.to_string()) - } - })?; - to_clear.extend(cleanups); - } - } - - if is_callback_target { - let target_collector = callback_collector - .entry(relation_store.name.clone()) - .or_default(); - let headers = kv_bindings + if is_callback_target { + let target_collector = callback_collector + .entry(relation_store.name.clone()) + .or_default(); + target_collector.push(( + CallbackOp::Rm, + NamedRows::new( + k_bindings .into_iter() .map(|k| k.name.to_string()) - .collect_vec(); - target_collector.push(( - CallbackOp::Put, - NamedRows::new( - headers.clone(), - new_tuples - .into_iter() - .map(|v| match v { - DataValue::List(l) => l, - _ => unreachable!(), - }) - .collect_vec(), - ), - NamedRows::new( - headers, - old_tuples - .into_iter() - .map(|v| match v { - DataValue::List(l) => l, - _ => unreachable!(), - }) - .collect_vec(), - ), - )) - } - } + .collect_vec(), + new_tuples + .into_iter() + .map(|v| match v { + DataValue::List(l) => l, + _ => unreachable!(), + }) + .collect_vec(), + ), + NamedRows::new( + kv_bindings + .into_iter() + .map(|k| k.name.to_string()) + .collect_vec(), + old_tuples + .into_iter() + .map(|v| match v { + DataValue::List(l) => l, + _ => unreachable!(), + }) + .collect_vec(), + ), + )) } - }; - - Ok(to_clear) + } + Ok(()) } } diff --git a/cozo-core/src/runtime/hnsw.rs b/cozo-core/src/runtime/hnsw.rs index 3ff5d9a0..36a4ff8b 100644 --- a/cozo-core/src/runtime/hnsw.rs +++ b/cozo-core/src/runtime/hnsw.rs @@ -151,7 +151,7 @@ impl VectorCache { impl<'a> SessionTx<'a> { fn hnsw_put_vector( &mut self, - tuple: &Tuple, + tuple: &[DataValue], q: &Vector, idx: usize, subidx: i32, @@ -685,7 +685,7 @@ impl<'a> SessionTx<'a> { idx_table: &RelationHandle, filter: Option<&Vec>, stack: &mut Vec, - tuple: &Tuple, + tuple: &[DataValue], ) -> Result { if let Some(code) = filter { if !eval_bytecode_pred(code, tuple, stack, Default::default())? {
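
Illustrative note (not part of the patch above): a minimal sketch of how the new `t2s` builtin registered in data/expr.rs and implemented in data/functions.rs is expected to behave. It assumes the usual `DataValue` `From` conversions from cozo-core and that the test lives next to `op_t2s` in the same crate; the expected output string simply reflects what `fast2s::convert` returns for this sample input, and the test name is hypothetical.

// Hypothetical unit test for op_t2s: string arguments are converted from
// Traditional to Simplified Chinese via the `fast2s` crate, while
// non-string values are returned unchanged.
#[test]
fn t2s_converts_traditional_chinese() {
    let converted = op_t2s(&[DataValue::from("漢字")]).unwrap();
    assert_eq!(converted, DataValue::from("汉字"));

    // a non-string value passes through untouched
    let untouched = op_t2s(&[DataValue::from(42i64)]).unwrap();
    assert_eq!(untouched, DataValue::from(42i64));
}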