fix `:update` on relations with default val cols

main
Ziyang Hu 1 year ago
parent a71f240a5c
commit 7965dfe67f

@ -44,7 +44,7 @@ impl<'a> SessionTx<'a> {
pub(crate) fn execute_relation<'s, S: Storage<'s>>( pub(crate) fn execute_relation<'s, S: Storage<'s>>(
&mut self, &mut self,
db: &Db<S>, db: &Db<S>,
res_iter: impl Iterator<Item=Tuple>, res_iter: impl Iterator<Item = Tuple>,
op: RelationOp, op: RelationOp,
meta: &InputRelationHandle, meta: &InputRelationHandle,
headers: &[Symbol], headers: &[Symbol],
@ -89,7 +89,7 @@ impl<'a> SessionTx<'a> {
&db.fixed_rules.read().unwrap(), &db.fixed_rules.read().unwrap(),
cur_vld, cur_vld,
)? )?
.get_single_program()?; .get_single_program()?;
let (_, cleanups) = db let (_, cleanups) = db
.run_query( .run_query(
@ -104,7 +104,7 @@ impl<'a> SessionTx<'a> {
if err.source_code().is_some() { if err.source_code().is_some() {
err err
} else { } else {
err.with_source_code(format!("{trigger}" )) err.with_source_code(format!("{trigger}"))
} }
})?; })?;
to_clear.extend(cleanups); to_clear.extend(cleanups);
@ -208,7 +208,7 @@ impl<'a> SessionTx<'a> {
fn put_into_relation<'s, S: Storage<'s>>( fn put_into_relation<'s, S: Storage<'s>>(
&mut self, &mut self,
db: &Db<S>, db: &Db<S>,
res_iter: impl Iterator<Item=Tuple>, res_iter: impl Iterator<Item = Tuple>,
headers: &[Symbol], headers: &[Symbol],
cur_vld: ValidityTs, cur_vld: ValidityTs,
callback_targets: &BTreeSet<SmartString<LazyCompact>>, callback_targets: &BTreeSet<SmartString<LazyCompact>>,
@ -223,7 +223,8 @@ impl<'a> SessionTx<'a> {
force_collect: &str, force_collect: &str,
span: SourceSpan, span: SourceSpan,
) -> Result<()> { ) -> Result<()> {
let is_callback_target = callback_targets.contains(&relation_store.name) || force_collect == &relation_store.name; let is_callback_target = callback_targets.contains(&relation_store.name)
|| force_collect == &relation_store.name;
if relation_store.access_level < AccessLevel::Protected { if relation_store.access_level < AccessLevel::Protected {
bail!(InsufficientAccessLevel( bail!(InsufficientAccessLevel(
@ -240,9 +241,10 @@ impl<'a> SessionTx<'a> {
headers, headers,
)?; )?;
let need_to_collect = !force_collect.is_empty() || (!relation_store.is_temp let need_to_collect = !force_collect.is_empty()
&& (is_callback_target || (!relation_store.is_temp
|| (propagate_triggers && !relation_store.put_triggers.is_empty()))); && (is_callback_target
|| (propagate_triggers && !relation_store.put_triggers.is_empty())));
let has_indices = !relation_store.indices.is_empty(); let has_indices = !relation_store.indices.is_empty();
let has_hnsw_indices = !relation_store.hnsw_indices.is_empty(); let has_hnsw_indices = !relation_store.hnsw_indices.is_empty();
let has_fts_indices = !relation_store.fts_indices.is_empty(); let has_fts_indices = !relation_store.fts_indices.is_empty();
@ -464,11 +466,9 @@ impl<'a> SessionTx<'a> {
) -> Result<BTreeMap<SmartString<LazyCompact>, (Arc<TextAnalyzer>, Vec<Bytecode>)>> { ) -> Result<BTreeMap<SmartString<LazyCompact>, (Arc<TextAnalyzer>, Vec<Bytecode>)>> {
let mut processors = BTreeMap::new(); let mut processors = BTreeMap::new();
for (name, (_, manifest)) in relation_store.fts_indices.iter() { for (name, (_, manifest)) in relation_store.fts_indices.iter() {
let tokenizer = self.tokenizers.get( let tokenizer = self
&name, .tokenizers
&manifest.tokenizer, .get(&name, &manifest.tokenizer, &manifest.filters)?;
&manifest.filters,
)?;
let parsed = CozoScriptParser::parse(Rule::expr, &manifest.extractor) let parsed = CozoScriptParser::parse(Rule::expr, &manifest.extractor)
.into_diagnostic()? .into_diagnostic()?
@ -481,11 +481,9 @@ impl<'a> SessionTx<'a> {
processors.insert(name.clone(), (tokenizer, extractor)); processors.insert(name.clone(), (tokenizer, extractor));
} }
for (name, (_, _, manifest)) in relation_store.lsh_indices.iter() { for (name, (_, _, manifest)) in relation_store.lsh_indices.iter() {
let tokenizer = self.tokenizers.get( let tokenizer = self
&name, .tokenizers
&manifest.tokenizer, .get(&name, &manifest.tokenizer, &manifest.filters)?;
&manifest.filters,
)?;
let parsed = CozoScriptParser::parse(Rule::expr, &manifest.extractor) let parsed = CozoScriptParser::parse(Rule::expr, &manifest.extractor)
.into_diagnostic()? .into_diagnostic()?
@ -522,7 +520,7 @@ impl<'a> SessionTx<'a> {
fn update_in_relation<'s, S: Storage<'s>>( fn update_in_relation<'s, S: Storage<'s>>(
&mut self, &mut self,
db: &Db<S>, db: &Db<S>,
res_iter: impl Iterator<Item=Tuple>, res_iter: impl Iterator<Item = Tuple>,
headers: &[Symbol], headers: &[Symbol],
cur_vld: ValidityTs, cur_vld: ValidityTs,
callback_targets: &BTreeSet<SmartString<LazyCompact>>, callback_targets: &BTreeSet<SmartString<LazyCompact>>,
@ -535,7 +533,8 @@ impl<'a> SessionTx<'a> {
force_collect: &str, force_collect: &str,
span: SourceSpan, span: SourceSpan,
) -> Result<()> { ) -> Result<()> {
let is_callback_target = callback_targets.contains(&relation_store.name) || force_collect == &relation_store.name; let is_callback_target = callback_targets.contains(&relation_store.name)
|| force_collect == &relation_store.name;
if relation_store.access_level < AccessLevel::Protected { if relation_store.access_level < AccessLevel::Protected {
bail!(InsufficientAccessLevel( bail!(InsufficientAccessLevel(
@ -552,9 +551,10 @@ impl<'a> SessionTx<'a> {
headers, headers,
)?; )?;
let need_to_collect = !force_collect.is_empty() || (!relation_store.is_temp let need_to_collect = !force_collect.is_empty()
&& (is_callback_target || (!relation_store.is_temp
|| (propagate_triggers && !relation_store.put_triggers.is_empty()))); && (is_callback_target
|| (propagate_triggers && !relation_store.put_triggers.is_empty())));
let has_indices = !relation_store.indices.is_empty(); let has_indices = !relation_store.indices.is_empty();
let has_hnsw_indices = !relation_store.hnsw_indices.is_empty(); let has_hnsw_indices = !relation_store.hnsw_indices.is_empty();
let has_fts_indices = !relation_store.fts_indices.is_empty(); let has_fts_indices = !relation_store.fts_indices.is_empty();
@ -567,7 +567,7 @@ impl<'a> SessionTx<'a> {
&metadata.keys, &metadata.keys,
key_bindings, key_bindings,
headers, headers,
); )?;
let mut stack = vec![]; let mut stack = vec![];
let hnsw_filters = Self::make_hnsw_filters(relation_store)?; let hnsw_filters = Self::make_hnsw_filters(relation_store)?;
@ -701,7 +701,7 @@ impl<'a> SessionTx<'a> {
&db.fixed_rules.read().unwrap(), &db.fixed_rules.read().unwrap(),
cur_vld, cur_vld,
)? )?
.get_single_program()?; .get_single_program()?;
make_const_rule( make_const_rule(
&mut program, &mut program,
@ -797,7 +797,7 @@ impl<'a> SessionTx<'a> {
fn ensure_not_in_relation( fn ensure_not_in_relation(
&mut self, &mut self,
res_iter: impl Iterator<Item=Tuple>, res_iter: impl Iterator<Item = Tuple>,
headers: &[Symbol], headers: &[Symbol],
cur_vld: ValidityTs, cur_vld: ValidityTs,
relation_store: &RelationHandle, relation_store: &RelationHandle,
@ -844,7 +844,7 @@ impl<'a> SessionTx<'a> {
fn ensure_in_relation( fn ensure_in_relation(
&mut self, &mut self,
res_iter: impl Iterator<Item=Tuple>, res_iter: impl Iterator<Item = Tuple>,
headers: &[Symbol], headers: &[Symbol],
cur_vld: ValidityTs, cur_vld: ValidityTs,
relation_store: &RelationHandle, relation_store: &RelationHandle,
@ -914,7 +914,7 @@ impl<'a> SessionTx<'a> {
fn remove_from_relation<'s, S: Storage<'s>>( fn remove_from_relation<'s, S: Storage<'s>>(
&mut self, &mut self,
db: &Db<S>, db: &Db<S>,
res_iter: impl Iterator<Item=Tuple>, res_iter: impl Iterator<Item = Tuple>,
headers: &[Symbol], headers: &[Symbol],
cur_vld: ValidityTs, cur_vld: ValidityTs,
callback_targets: &BTreeSet<SmartString<LazyCompact>>, callback_targets: &BTreeSet<SmartString<LazyCompact>>,
@ -928,7 +928,8 @@ impl<'a> SessionTx<'a> {
force_collect: &str, force_collect: &str,
span: SourceSpan, span: SourceSpan,
) -> Result<()> { ) -> Result<()> {
let is_callback_target = callback_targets.contains(&relation_store.name) || force_collect == relation_store.name; let is_callback_target =
callback_targets.contains(&relation_store.name) || force_collect == relation_store.name;
if relation_store.access_level < AccessLevel::Protected { if relation_store.access_level < AccessLevel::Protected {
bail!(InsufficientAccessLevel( bail!(InsufficientAccessLevel(
@ -944,9 +945,10 @@ impl<'a> SessionTx<'a> {
headers, headers,
)?; )?;
let need_to_collect = !force_collect.is_empty() || (!relation_store.is_temp let need_to_collect = !force_collect.is_empty()
&& (is_callback_target || (!relation_store.is_temp
|| (propagate_triggers && !relation_store.rm_triggers.is_empty()))); && (is_callback_target
|| (propagate_triggers && !relation_store.rm_triggers.is_empty())));
let has_indices = !relation_store.indices.is_empty(); let has_indices = !relation_store.indices.is_empty();
let has_hnsw_indices = !relation_store.hnsw_indices.is_empty(); let has_hnsw_indices = !relation_store.hnsw_indices.is_empty();
let has_fts_indices = !relation_store.fts_indices.is_empty(); let has_fts_indices = !relation_store.fts_indices.is_empty();
@ -1035,7 +1037,7 @@ impl<'a> SessionTx<'a> {
&db.fixed_rules.read().unwrap(), &db.fixed_rules.read().unwrap(),
cur_vld, cur_vld,
)? )?
.get_single_program()?; .get_single_program()?;
make_const_rule(&mut program, "_new", k_bindings.clone(), new_tuples.clone()); make_const_rule(&mut program, "_new", k_bindings.clone(), new_tuples.clone());
@ -1149,11 +1151,17 @@ fn make_update_extractors(
input: &[ColumnDef], input: &[ColumnDef],
bindings: &[Symbol], bindings: &[Symbol],
tuple_headers: &[Symbol], tuple_headers: &[Symbol],
) -> Vec<Option<DataExtractor>> { ) -> Result<Vec<Option<DataExtractor>>> {
stored let input_keys: BTreeSet<_> = input.iter().map(|b| &b.name).collect();
.iter() let mut extractors = Vec::with_capacity(stored.len());
.map(|s| make_extractor(s, input, bindings, tuple_headers).ok()) for col in stored.iter() {
.collect_vec() if input_keys.contains(&col.name) {
extractors.push(Some(make_extractor(col, input, bindings, tuple_headers)?));
} else {
extractors.push(None);
}
}
Ok(extractors)
} }
fn make_extractor( fn make_extractor(

@ -1203,3 +1203,26 @@ fn as_store_in_imperative_script() {
println!("{}", row); println!("{}", row);
} }
} }
#[test]
fn update_shall_not_destroy_values() {
let db = DbInstance::default();
db.run_default(r"?[x, y] <- [[1, 2]] :create z {x => y default 0}").unwrap();
let r = db.run_default(r"?[x, y] := *z {x, y}").unwrap();
assert_eq!(r.into_json()["rows"], json!([[1, 2]]));
db.run_default(r"?[x] <- [[1]] :update z {x}").unwrap();
let r = db.run_default(r"?[x, y] := *z {x, y}").unwrap();
assert_eq!(r.into_json()["rows"], json!([[1, 2]]));
}
#[test]
fn update_shall_work() {
let db = DbInstance::default();
db.run_default(r"?[x, y, z] <- [[1, 2, 3]] :create z {x => y, z}").unwrap();
let r = db.run_default(r"?[x, y, z] := *z {x, y, z}").unwrap();
assert_eq!(r.into_json()["rows"], json!([[1, 2, 3]]));
db.run_default(r"?[x, y] <- [[1, 4]] :update z {x, y}").unwrap();
let r = db.run_default(r"?[x, y, z] := *z {x, y, z}").unwrap();
assert_eq!(r.into_json()["rows"], json!([[1, 4, 3]]));
}
Loading…
Cancel
Save