From 7965dfe67f79af8ac6f69c633605f381b9b68637 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Mon, 29 May 2023 22:28:34 +0800 Subject: [PATCH] fix `:update` on relations with default val cols --- cozo-core/src/query/stored.rs | 84 +++++++++++++++++++--------------- cozo-core/src/runtime/tests.rs | 23 ++++++++++ 2 files changed, 69 insertions(+), 38 deletions(-) diff --git a/cozo-core/src/query/stored.rs b/cozo-core/src/query/stored.rs index 9f2beb2d..ce346a11 100644 --- a/cozo-core/src/query/stored.rs +++ b/cozo-core/src/query/stored.rs @@ -44,7 +44,7 @@ impl<'a> SessionTx<'a> { pub(crate) fn execute_relation<'s, S: Storage<'s>>( &mut self, db: &Db, - res_iter: impl Iterator, + res_iter: impl Iterator, op: RelationOp, meta: &InputRelationHandle, headers: &[Symbol], @@ -89,7 +89,7 @@ impl<'a> SessionTx<'a> { &db.fixed_rules.read().unwrap(), cur_vld, )? - .get_single_program()?; + .get_single_program()?; let (_, cleanups) = db .run_query( @@ -104,7 +104,7 @@ impl<'a> SessionTx<'a> { if err.source_code().is_some() { err } else { - err.with_source_code(format!("{trigger}" )) + err.with_source_code(format!("{trigger}")) } })?; to_clear.extend(cleanups); @@ -208,7 +208,7 @@ impl<'a> SessionTx<'a> { fn put_into_relation<'s, S: Storage<'s>>( &mut self, db: &Db, - res_iter: impl Iterator, + res_iter: impl Iterator, headers: &[Symbol], cur_vld: ValidityTs, callback_targets: &BTreeSet>, @@ -223,7 +223,8 @@ impl<'a> SessionTx<'a> { force_collect: &str, span: SourceSpan, ) -> Result<()> { - let is_callback_target = callback_targets.contains(&relation_store.name) || force_collect == &relation_store.name; + let is_callback_target = callback_targets.contains(&relation_store.name) + || force_collect == &relation_store.name; if relation_store.access_level < AccessLevel::Protected { bail!(InsufficientAccessLevel( @@ -240,9 +241,10 @@ impl<'a> SessionTx<'a> { headers, )?; - let need_to_collect = !force_collect.is_empty() || (!relation_store.is_temp - && (is_callback_target - || (propagate_triggers && !relation_store.put_triggers.is_empty()))); + let need_to_collect = !force_collect.is_empty() + || (!relation_store.is_temp + && (is_callback_target + || (propagate_triggers && !relation_store.put_triggers.is_empty()))); let has_indices = !relation_store.indices.is_empty(); let has_hnsw_indices = !relation_store.hnsw_indices.is_empty(); let has_fts_indices = !relation_store.fts_indices.is_empty(); @@ -464,11 +466,9 @@ impl<'a> SessionTx<'a> { ) -> Result, (Arc, Vec)>> { let mut processors = BTreeMap::new(); for (name, (_, manifest)) in relation_store.fts_indices.iter() { - let tokenizer = self.tokenizers.get( - &name, - &manifest.tokenizer, - &manifest.filters, - )?; + let tokenizer = self + .tokenizers + .get(&name, &manifest.tokenizer, &manifest.filters)?; let parsed = CozoScriptParser::parse(Rule::expr, &manifest.extractor) .into_diagnostic()? @@ -481,11 +481,9 @@ impl<'a> SessionTx<'a> { processors.insert(name.clone(), (tokenizer, extractor)); } for (name, (_, _, manifest)) in relation_store.lsh_indices.iter() { - let tokenizer = self.tokenizers.get( - &name, - &manifest.tokenizer, - &manifest.filters, - )?; + let tokenizer = self + .tokenizers + .get(&name, &manifest.tokenizer, &manifest.filters)?; let parsed = CozoScriptParser::parse(Rule::expr, &manifest.extractor) .into_diagnostic()? @@ -522,7 +520,7 @@ impl<'a> SessionTx<'a> { fn update_in_relation<'s, S: Storage<'s>>( &mut self, db: &Db, - res_iter: impl Iterator, + res_iter: impl Iterator, headers: &[Symbol], cur_vld: ValidityTs, callback_targets: &BTreeSet>, @@ -535,7 +533,8 @@ impl<'a> SessionTx<'a> { force_collect: &str, span: SourceSpan, ) -> Result<()> { - let is_callback_target = callback_targets.contains(&relation_store.name) || force_collect == &relation_store.name; + let is_callback_target = callback_targets.contains(&relation_store.name) + || force_collect == &relation_store.name; if relation_store.access_level < AccessLevel::Protected { bail!(InsufficientAccessLevel( @@ -552,9 +551,10 @@ impl<'a> SessionTx<'a> { headers, )?; - let need_to_collect = !force_collect.is_empty() || (!relation_store.is_temp - && (is_callback_target - || (propagate_triggers && !relation_store.put_triggers.is_empty()))); + let need_to_collect = !force_collect.is_empty() + || (!relation_store.is_temp + && (is_callback_target + || (propagate_triggers && !relation_store.put_triggers.is_empty()))); let has_indices = !relation_store.indices.is_empty(); let has_hnsw_indices = !relation_store.hnsw_indices.is_empty(); let has_fts_indices = !relation_store.fts_indices.is_empty(); @@ -567,7 +567,7 @@ impl<'a> SessionTx<'a> { &metadata.keys, key_bindings, headers, - ); + )?; let mut stack = vec![]; let hnsw_filters = Self::make_hnsw_filters(relation_store)?; @@ -701,7 +701,7 @@ impl<'a> SessionTx<'a> { &db.fixed_rules.read().unwrap(), cur_vld, )? - .get_single_program()?; + .get_single_program()?; make_const_rule( &mut program, @@ -797,7 +797,7 @@ impl<'a> SessionTx<'a> { fn ensure_not_in_relation( &mut self, - res_iter: impl Iterator, + res_iter: impl Iterator, headers: &[Symbol], cur_vld: ValidityTs, relation_store: &RelationHandle, @@ -844,7 +844,7 @@ impl<'a> SessionTx<'a> { fn ensure_in_relation( &mut self, - res_iter: impl Iterator, + res_iter: impl Iterator, headers: &[Symbol], cur_vld: ValidityTs, relation_store: &RelationHandle, @@ -914,7 +914,7 @@ impl<'a> SessionTx<'a> { fn remove_from_relation<'s, S: Storage<'s>>( &mut self, db: &Db, - res_iter: impl Iterator, + res_iter: impl Iterator, headers: &[Symbol], cur_vld: ValidityTs, callback_targets: &BTreeSet>, @@ -928,7 +928,8 @@ impl<'a> SessionTx<'a> { force_collect: &str, span: SourceSpan, ) -> Result<()> { - let is_callback_target = callback_targets.contains(&relation_store.name) || force_collect == relation_store.name; + let is_callback_target = + callback_targets.contains(&relation_store.name) || force_collect == relation_store.name; if relation_store.access_level < AccessLevel::Protected { bail!(InsufficientAccessLevel( @@ -944,9 +945,10 @@ impl<'a> SessionTx<'a> { headers, )?; - let need_to_collect = !force_collect.is_empty() || (!relation_store.is_temp - && (is_callback_target - || (propagate_triggers && !relation_store.rm_triggers.is_empty()))); + let need_to_collect = !force_collect.is_empty() + || (!relation_store.is_temp + && (is_callback_target + || (propagate_triggers && !relation_store.rm_triggers.is_empty()))); let has_indices = !relation_store.indices.is_empty(); let has_hnsw_indices = !relation_store.hnsw_indices.is_empty(); let has_fts_indices = !relation_store.fts_indices.is_empty(); @@ -1035,7 +1037,7 @@ impl<'a> SessionTx<'a> { &db.fixed_rules.read().unwrap(), cur_vld, )? - .get_single_program()?; + .get_single_program()?; make_const_rule(&mut program, "_new", k_bindings.clone(), new_tuples.clone()); @@ -1149,11 +1151,17 @@ fn make_update_extractors( input: &[ColumnDef], bindings: &[Symbol], tuple_headers: &[Symbol], -) -> Vec> { - stored - .iter() - .map(|s| make_extractor(s, input, bindings, tuple_headers).ok()) - .collect_vec() +) -> Result>> { + let input_keys: BTreeSet<_> = input.iter().map(|b| &b.name).collect(); + let mut extractors = Vec::with_capacity(stored.len()); + for col in stored.iter() { + if input_keys.contains(&col.name) { + extractors.push(Some(make_extractor(col, input, bindings, tuple_headers)?)); + } else { + extractors.push(None); + } + } + Ok(extractors) } fn make_extractor( diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index ea417607..afcdb1ab 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -1203,3 +1203,26 @@ fn as_store_in_imperative_script() { println!("{}", row); } } + +#[test] +fn update_shall_not_destroy_values() { + let db = DbInstance::default(); + db.run_default(r"?[x, y] <- [[1, 2]] :create z {x => y default 0}").unwrap(); + let r = db.run_default(r"?[x, y] := *z {x, y}").unwrap(); + assert_eq!(r.into_json()["rows"], json!([[1, 2]])); + db.run_default(r"?[x] <- [[1]] :update z {x}").unwrap(); + let r = db.run_default(r"?[x, y] := *z {x, y}").unwrap(); + assert_eq!(r.into_json()["rows"], json!([[1, 2]])); +} + + +#[test] +fn update_shall_work() { + let db = DbInstance::default(); + db.run_default(r"?[x, y, z] <- [[1, 2, 3]] :create z {x => y, z}").unwrap(); + let r = db.run_default(r"?[x, y, z] := *z {x, y, z}").unwrap(); + assert_eq!(r.into_json()["rows"], json!([[1, 2, 3]])); + db.run_default(r"?[x, y] <- [[1, 4]] :update z {x, y}").unwrap(); + let r = db.run_default(r"?[x, y, z] := *z {x, y, z}").unwrap(); + assert_eq!(r.into_json()["rows"], json!([[1, 4, 3]])); +} \ No newline at end of file