diff --git a/cozo-core/src/cozoscript.pest b/cozo-core/src/cozoscript.pest index 524e1bdf..77321231 100644 --- a/cozo-core/src/cozoscript.pest +++ b/cozo-core/src/cozoscript.pest @@ -127,10 +127,11 @@ limit_option = {":limit" ~ expr} offset_option = {":offset" ~ expr} sort_option = {(":sort" | ":order") ~ (sort_arg ~ ",")* ~ sort_arg } relation_option = {relation_op ~ (compound_ident | underscore_ident) ~ table_schema?} -relation_op = _{relation_create | relation_replace | relation_put | relation_rm | relation_ensure | relation_ensure_not} +relation_op = _{relation_create | relation_replace | relation_put | relation_update | relation_rm | relation_ensure | relation_ensure_not} relation_create = {":create"} relation_replace = {":replace"} relation_put = {":put"} +relation_update = {":update"} relation_rm = {":rm"} relation_ensure = {":ensure"} relation_ensure_not = {":ensure_not"} diff --git a/cozo-core/src/data/program.rs b/cozo-core/src/data/program.rs index 1ab60633..b79d33b5 100644 --- a/cozo-core/src/data/program.rs +++ b/cozo-core/src/data/program.rs @@ -93,6 +93,9 @@ impl Display for QueryOutOptions { RelationOp::Put => { write!(f, ":put ")?; } + RelationOp::Update => { + write!(f, ":update ")?; + } RelationOp::Rm => { write!(f, ":rm ")?; } @@ -172,6 +175,7 @@ pub(crate) enum RelationOp { Create, Replace, Put, + Update, Rm, Ensure, EnsureNot, diff --git a/cozo-core/src/parse/query.rs b/cozo-core/src/parse/query.rs index 17952089..9e9368d0 100644 --- a/cozo-core/src/parse/query.rs +++ b/cozo-core/src/parse/query.rs @@ -320,6 +320,7 @@ pub(crate) fn parse_query( Rule::relation_create => RelationOp::Create, Rule::relation_replace => RelationOp::Replace, Rule::relation_put => RelationOp::Put, + Rule::relation_update => RelationOp::Update, Rule::relation_rm => RelationOp::Rm, Rule::relation_ensure => RelationOp::Ensure, Rule::relation_ensure_not => RelationOp::EnsureNot, diff --git a/cozo-core/src/query/stored.rs b/cozo-core/src/query/stored.rs index 2f918581..2ca870ef 100644 --- a/cozo-core/src/query/stored.rs +++ b/cozo-core/src/query/stored.rs @@ -19,7 +19,7 @@ use crate::data::expr::Expr; use crate::data::program::{FixedRuleApply, InputInlineRulesOrFixed, InputProgram, RelationOp}; use crate::data::relation::{ColumnDef, NullableColType}; use crate::data::symb::Symbol; -use crate::data::tuple::Tuple; +use crate::data::tuple::{Tuple, ENCODED_KEY_MIN_LEN}; use crate::data::value::{DataValue, ValidityTs}; use crate::fixed_rule::utilities::constant::Constant; use crate::fixed_rule::FixedRuleHandle; @@ -175,9 +175,7 @@ impl<'a> SessionTx<'a> { } } if has_hnsw_indices { - for (idx_handle, _) in - relation_store.hnsw_indices.values() - { + for (idx_handle, _) in relation_store.hnsw_indices.values() { self.hnsw_remove(&relation_store, idx_handle, &extracted)?; } } @@ -389,6 +387,238 @@ impl<'a> SessionTx<'a> { } } } + RelationOp::Update => { + if relation_store.access_level < AccessLevel::Protected { + bail!(InsufficientAccessLevel( + relation_store.name.to_string(), + "row update".to_string(), + relation_store.access_level + )); + } + + let key_extractors = make_extractors( + &relation_store.metadata.keys, + &metadata.keys, + key_bindings, + headers, + )?; + + let need_to_collect = !relation_store.is_temp + && (is_callback_target + || (propagate_triggers && !relation_store.put_triggers.is_empty())); + let has_indices = !relation_store.indices.is_empty(); + let has_hnsw_indices = !relation_store.hnsw_indices.is_empty(); + let mut new_tuples: Vec = vec![]; + let mut old_tuples: Vec = vec![]; + + let val_extractors = make_update_extractors( + &relation_store.metadata.non_keys, + &metadata.non_keys, + dep_bindings, + headers, + ); + + let mut stack = vec![]; + let mut hnsw_filters = BTreeMap::new(); + if has_hnsw_indices { + for (name, (_, manifest)) in relation_store.hnsw_indices.iter() { + if let Some(f_code) = &manifest.index_filter { + let parsed = CozoScriptParser::parse(Rule::expr, f_code) + .into_diagnostic()? + .next() + .unwrap(); + let mut code_expr = build_expr(parsed, &Default::default())?; + let binding_map = relation_store.raw_binding_map(); + code_expr.fill_binding_indices(&binding_map)?; + hnsw_filters.insert(name.clone(), code_expr.compile()); + } + } + } + + for tuple in res_iter { + let mut new_kv: Vec = key_extractors + .iter() + .map(|ex| ex.extract_data(&tuple, cur_vld)) + .try_collect()?; + + let key = relation_store.encode_key_for_store(&new_kv, *span)?; + let original_val_bytes = if relation_store.is_temp { + self.temp_store_tx.get(&key, true)? + } else { + self.store_tx.get(&key, true)? + }; + let original_val: Tuple = match original_val_bytes { + None => { + bail!(TransactAssertionFailure { + relation: relation_store.name.to_string(), + key: new_kv, + notice: "key to update does not exist".to_string() + }) + } + Some(v) => rmp_serde::from_slice(&v[ENCODED_KEY_MIN_LEN..]).unwrap(), + }; + let mut old_kv = Vec::with_capacity(relation_store.arity()); + old_kv.extend_from_slice(&new_kv); + old_kv.extend_from_slice(&original_val); + new_kv.reserve_exact(relation_store.arity()); + for (i, extractor) in val_extractors.iter().enumerate() { + match extractor { + None => { + new_kv.push(original_val[i].clone()); + } + Some(ex) => { + let val = ex.extract_data(&tuple, cur_vld)?; + new_kv.push(val); + } + } + } + let new_val = relation_store.encode_val_for_store(&new_kv, *span)?; + + if need_to_collect || has_indices || has_hnsw_indices { + if has_indices { + for (idx_rel, idx_extractor) in relation_store.indices.values() { + let idx_tup_old = idx_extractor + .iter() + .map(|i| old_kv[*i].clone()) + .collect_vec(); + let encoded_old = idx_rel + .encode_key_for_store(&idx_tup_old, Default::default())?; + self.store_tx.del(&encoded_old)?; + + let idx_tup_new = idx_extractor + .iter() + .map(|i| new_kv[*i].clone()) + .collect_vec(); + let encoded_new = idx_rel + .encode_key_for_store(&idx_tup_new, Default::default())?; + self.store_tx.put(&encoded_new, &[])?; + } + } + + if need_to_collect { + old_tuples.push(DataValue::List(old_kv)); + } + + if has_hnsw_indices { + for (name, (idx_handle, idx_manifest)) in + relation_store.hnsw_indices.iter() + { + let filter = hnsw_filters.get(name); + self.hnsw_put( + idx_manifest, + &relation_store, + idx_handle, + filter, + &mut stack, + &new_kv, + )?; + } + } + + if need_to_collect { + new_tuples.push(DataValue::List(new_kv)); + } + } + + if relation_store.is_temp { + self.temp_store_tx.put(&key, &new_val)?; + } else { + self.store_tx.put(&key, &new_val)?; + } + } + + if need_to_collect && !new_tuples.is_empty() { + let mut bindings = relation_store + .metadata + .keys + .iter() + .map(|k| Symbol::new(k.name.clone(), Default::default())) + .collect_vec(); + let v_bindings = relation_store + .metadata + .non_keys + .iter() + .map(|k| Symbol::new(k.name.clone(), Default::default())); + bindings.extend(v_bindings); + + let kv_bindings = bindings; + if propagate_triggers { + for trigger in &relation_store.put_triggers { + let mut program = parse_script( + trigger, + &Default::default(), + &db.fixed_rules.read().unwrap(), + cur_vld, + )? + .get_single_program()?; + + make_const_rule( + &mut program, + "_new", + kv_bindings.clone(), + new_tuples.clone(), + ); + make_const_rule( + &mut program, + "_old", + kv_bindings.clone(), + old_tuples.clone(), + ); + + let (_, cleanups) = db + .run_query( + self, + program, + cur_vld, + callback_targets, + callback_collector, + false, + ) + .map_err(|err| { + if err.source_code().is_some() { + err + } else { + err.with_source_code(trigger.to_string()) + } + })?; + to_clear.extend(cleanups); + } + } + + if is_callback_target { + let target_collector = callback_collector + .entry(relation_store.name.clone()) + .or_default(); + let headers = kv_bindings + .into_iter() + .map(|k| k.name.to_string()) + .collect_vec(); + target_collector.push(( + CallbackOp::Put, + NamedRows::new( + headers.clone(), + new_tuples + .into_iter() + .map(|v| match v { + DataValue::List(l) => l, + _ => unreachable!(), + }) + .collect_vec(), + ), + NamedRows::new( + headers, + old_tuples + .into_iter() + .map(|v| match v { + DataValue::List(l) => l, + _ => unreachable!(), + }) + .collect_vec(), + ), + )) + } + } + } RelationOp::Create | RelationOp::Replace | RelationOp::Put => { if relation_store.access_level < AccessLevel::Protected { bail!(InsufficientAccessLevel( @@ -485,7 +715,7 @@ impl<'a> SessionTx<'a> { if has_hnsw_indices { for (name, (idx_handle, idx_manifest)) in - relation_store.hnsw_indices.iter() + relation_store.hnsw_indices.iter() { let filter = hnsw_filters.get(name); self.hnsw_put( @@ -647,6 +877,18 @@ fn make_extractors( .try_collect() } +fn make_update_extractors( + stored: &[ColumnDef], + input: &[ColumnDef], + bindings: &[Symbol], + tuple_headers: &[Symbol], +) -> Vec> { + stored + .iter() + .map(|s| make_extractor(s, input, bindings, tuple_headers).ok()) + .collect_vec() +} + fn make_extractor( stored: &ColumnDef, input: &[ColumnDef], diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index ab22e3b8..1b05a786 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -1057,13 +1057,13 @@ impl<'s, S: Storage<'s>> Db { ) } RelAlgebra::HnswSearch(HnswSearchRA { - hnsw_search, - .. + hnsw_search, .. }) => ( "hnsw_index", json!(format!(":{}", hnsw_search.query.name)), json!(hnsw_search.query.name), - json!(hnsw_search.filter + json!(hnsw_search + .filter .iter() .map(|f| f.to_string()) .collect_vec()), @@ -1321,7 +1321,8 @@ impl<'s, S: Storage<'s>> Db { StoreRelationNotFoundError(meta.name.to_string()) ); - existing.ensure_compatible(meta, *op == RelationOp::Rm)?; + existing + .ensure_compatible(meta, *op == RelationOp::Rm || *op == RelationOp::Update)?; } }; diff --git a/cozo-core/src/runtime/relation.rs b/cozo-core/src/runtime/relation.rs index 8869fa09..71e1052f 100644 --- a/cozo-core/src/runtime/relation.rs +++ b/cozo-core/src/runtime/relation.rs @@ -266,7 +266,7 @@ impl RelationHandle { pub(crate) fn ensure_compatible( &self, inp: &InputRelationHandle, - is_remove: bool, + is_remove_or_update: bool, ) -> Result<()> { let InputRelationHandle { metadata, .. } = inp; // check that every given key is found and compatible @@ -280,7 +280,7 @@ impl RelationHandle { for col in &self.metadata.keys { metadata.satisfied_by_required_col(col, true)?; } - if !is_remove { + if !is_remove_or_update { for col in &self.metadata.non_keys { metadata.satisfied_by_required_col(col, false)?; } diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index 8fa6781f..a8d5ccd2 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -489,6 +489,42 @@ fn test_callback() { assert_eq!(collected[2].2.rows[0].len(), 3); } +#[test] +fn test_update() { + let db = new_cozo_mem().unwrap(); + db.run_script( + ":create friends {fr: Int, to: Int => a: Any, b: Any, c: Any}", + Default::default(), + ) + .unwrap(); + db.run_script( + "?[fr, to, a, b, c] <- [[1,2,3,4,5]] :put friends {fr, to => a, b, c}", + Default::default(), + ) + .unwrap(); + let res = db + .run_script( + "?[fr, to, a, b, c] := *friends{fr, to, a, b, c}", + Default::default(), + ) + .unwrap() + .into_json(); + assert_eq!(res["rows"][0], json!([1, 2, 3, 4, 5])); + db.run_script( + "?[fr, to, b] <- [[1, 2, 100]] :update friends {fr, to => b}", + Default::default(), + ) + .unwrap(); + let res = db + .run_script( + "?[fr, to, a, b, c] := *friends{fr, to, a, b, c}", + Default::default(), + ) + .unwrap() + .into_json(); + assert_eq!(res["rows"][0], json!([1, 2, 3, 100, 5])); +} + #[test] fn test_index() { let db = new_cozo_mem().unwrap(); @@ -871,7 +907,11 @@ fn test_insertions() { .unwrap(); db.run_script(r"?[k] <- [[1]] :put a {k}", Default::default()) .unwrap(); - db.run_script(r"?[k] := k in int_range(300) :put a {k}", Default::default()).unwrap(); + db.run_script( + r"?[k] := k in int_range(300) :put a {k}", + Default::default(), + ) + .unwrap(); let res = db .run_script( r"?[dist, k] := ~a:i{k | query: v, bind_distance: dist, k:10, ef: 50, filter: k % 2 == 0, radius: 245}, *a{k: 96, v}",