`update` functionality

main
Ziyang Hu 1 year ago
parent d984304322
commit 39a0946920

@ -127,10 +127,11 @@ limit_option = {":limit" ~ expr}
offset_option = {":offset" ~ expr} offset_option = {":offset" ~ expr}
sort_option = {(":sort" | ":order") ~ (sort_arg ~ ",")* ~ sort_arg } sort_option = {(":sort" | ":order") ~ (sort_arg ~ ",")* ~ sort_arg }
relation_option = {relation_op ~ (compound_ident | underscore_ident) ~ table_schema?} relation_option = {relation_op ~ (compound_ident | underscore_ident) ~ table_schema?}
relation_op = _{relation_create | relation_replace | relation_put | relation_rm | relation_ensure | relation_ensure_not} relation_op = _{relation_create | relation_replace | relation_put | relation_update | relation_rm | relation_ensure | relation_ensure_not}
relation_create = {":create"} relation_create = {":create"}
relation_replace = {":replace"} relation_replace = {":replace"}
relation_put = {":put"} relation_put = {":put"}
relation_update = {":update"}
relation_rm = {":rm"} relation_rm = {":rm"}
relation_ensure = {":ensure"} relation_ensure = {":ensure"}
relation_ensure_not = {":ensure_not"} relation_ensure_not = {":ensure_not"}

@ -93,6 +93,9 @@ impl Display for QueryOutOptions {
RelationOp::Put => { RelationOp::Put => {
write!(f, ":put ")?; write!(f, ":put ")?;
} }
RelationOp::Update => {
write!(f, ":update ")?;
}
RelationOp::Rm => { RelationOp::Rm => {
write!(f, ":rm ")?; write!(f, ":rm ")?;
} }
@ -172,6 +175,7 @@ pub(crate) enum RelationOp {
Create, Create,
Replace, Replace,
Put, Put,
Update,
Rm, Rm,
Ensure, Ensure,
EnsureNot, EnsureNot,

@ -320,6 +320,7 @@ pub(crate) fn parse_query(
Rule::relation_create => RelationOp::Create, Rule::relation_create => RelationOp::Create,
Rule::relation_replace => RelationOp::Replace, Rule::relation_replace => RelationOp::Replace,
Rule::relation_put => RelationOp::Put, Rule::relation_put => RelationOp::Put,
Rule::relation_update => RelationOp::Update,
Rule::relation_rm => RelationOp::Rm, Rule::relation_rm => RelationOp::Rm,
Rule::relation_ensure => RelationOp::Ensure, Rule::relation_ensure => RelationOp::Ensure,
Rule::relation_ensure_not => RelationOp::EnsureNot, Rule::relation_ensure_not => RelationOp::EnsureNot,

@ -19,7 +19,7 @@ use crate::data::expr::Expr;
use crate::data::program::{FixedRuleApply, InputInlineRulesOrFixed, InputProgram, RelationOp}; use crate::data::program::{FixedRuleApply, InputInlineRulesOrFixed, InputProgram, RelationOp};
use crate::data::relation::{ColumnDef, NullableColType}; use crate::data::relation::{ColumnDef, NullableColType};
use crate::data::symb::Symbol; use crate::data::symb::Symbol;
use crate::data::tuple::Tuple; use crate::data::tuple::{Tuple, ENCODED_KEY_MIN_LEN};
use crate::data::value::{DataValue, ValidityTs}; use crate::data::value::{DataValue, ValidityTs};
use crate::fixed_rule::utilities::constant::Constant; use crate::fixed_rule::utilities::constant::Constant;
use crate::fixed_rule::FixedRuleHandle; use crate::fixed_rule::FixedRuleHandle;
@ -175,9 +175,7 @@ impl<'a> SessionTx<'a> {
} }
} }
if has_hnsw_indices { if has_hnsw_indices {
for (idx_handle, _) in for (idx_handle, _) in relation_store.hnsw_indices.values() {
relation_store.hnsw_indices.values()
{
self.hnsw_remove(&relation_store, idx_handle, &extracted)?; self.hnsw_remove(&relation_store, idx_handle, &extracted)?;
} }
} }
@ -389,6 +387,238 @@ impl<'a> SessionTx<'a> {
} }
} }
} }
RelationOp::Update => {
if relation_store.access_level < AccessLevel::Protected {
bail!(InsufficientAccessLevel(
relation_store.name.to_string(),
"row update".to_string(),
relation_store.access_level
));
}
let key_extractors = make_extractors(
&relation_store.metadata.keys,
&metadata.keys,
key_bindings,
headers,
)?;
let need_to_collect = !relation_store.is_temp
&& (is_callback_target
|| (propagate_triggers && !relation_store.put_triggers.is_empty()));
let has_indices = !relation_store.indices.is_empty();
let has_hnsw_indices = !relation_store.hnsw_indices.is_empty();
let mut new_tuples: Vec<DataValue> = vec![];
let mut old_tuples: Vec<DataValue> = vec![];
let val_extractors = make_update_extractors(
&relation_store.metadata.non_keys,
&metadata.non_keys,
dep_bindings,
headers,
);
let mut stack = vec![];
let mut hnsw_filters = BTreeMap::new();
if has_hnsw_indices {
for (name, (_, manifest)) in relation_store.hnsw_indices.iter() {
if let Some(f_code) = &manifest.index_filter {
let parsed = CozoScriptParser::parse(Rule::expr, f_code)
.into_diagnostic()?
.next()
.unwrap();
let mut code_expr = build_expr(parsed, &Default::default())?;
let binding_map = relation_store.raw_binding_map();
code_expr.fill_binding_indices(&binding_map)?;
hnsw_filters.insert(name.clone(), code_expr.compile());
}
}
}
for tuple in res_iter {
let mut new_kv: Vec<DataValue> = key_extractors
.iter()
.map(|ex| ex.extract_data(&tuple, cur_vld))
.try_collect()?;
let key = relation_store.encode_key_for_store(&new_kv, *span)?;
let original_val_bytes = if relation_store.is_temp {
self.temp_store_tx.get(&key, true)?
} else {
self.store_tx.get(&key, true)?
};
let original_val: Tuple = match original_val_bytes {
None => {
bail!(TransactAssertionFailure {
relation: relation_store.name.to_string(),
key: new_kv,
notice: "key to update does not exist".to_string()
})
}
Some(v) => rmp_serde::from_slice(&v[ENCODED_KEY_MIN_LEN..]).unwrap(),
};
let mut old_kv = Vec::with_capacity(relation_store.arity());
old_kv.extend_from_slice(&new_kv);
old_kv.extend_from_slice(&original_val);
new_kv.reserve_exact(relation_store.arity());
for (i, extractor) in val_extractors.iter().enumerate() {
match extractor {
None => {
new_kv.push(original_val[i].clone());
}
Some(ex) => {
let val = ex.extract_data(&tuple, cur_vld)?;
new_kv.push(val);
}
}
}
let new_val = relation_store.encode_val_for_store(&new_kv, *span)?;
if need_to_collect || has_indices || has_hnsw_indices {
if has_indices {
for (idx_rel, idx_extractor) in relation_store.indices.values() {
let idx_tup_old = idx_extractor
.iter()
.map(|i| old_kv[*i].clone())
.collect_vec();
let encoded_old = idx_rel
.encode_key_for_store(&idx_tup_old, Default::default())?;
self.store_tx.del(&encoded_old)?;
let idx_tup_new = idx_extractor
.iter()
.map(|i| new_kv[*i].clone())
.collect_vec();
let encoded_new = idx_rel
.encode_key_for_store(&idx_tup_new, Default::default())?;
self.store_tx.put(&encoded_new, &[])?;
}
}
if need_to_collect {
old_tuples.push(DataValue::List(old_kv));
}
if has_hnsw_indices {
for (name, (idx_handle, idx_manifest)) in
relation_store.hnsw_indices.iter()
{
let filter = hnsw_filters.get(name);
self.hnsw_put(
idx_manifest,
&relation_store,
idx_handle,
filter,
&mut stack,
&new_kv,
)?;
}
}
if need_to_collect {
new_tuples.push(DataValue::List(new_kv));
}
}
if relation_store.is_temp {
self.temp_store_tx.put(&key, &new_val)?;
} else {
self.store_tx.put(&key, &new_val)?;
}
}
if need_to_collect && !new_tuples.is_empty() {
let mut bindings = relation_store
.metadata
.keys
.iter()
.map(|k| Symbol::new(k.name.clone(), Default::default()))
.collect_vec();
let v_bindings = relation_store
.metadata
.non_keys
.iter()
.map(|k| Symbol::new(k.name.clone(), Default::default()));
bindings.extend(v_bindings);
let kv_bindings = bindings;
if propagate_triggers {
for trigger in &relation_store.put_triggers {
let mut program = parse_script(
trigger,
&Default::default(),
&db.fixed_rules.read().unwrap(),
cur_vld,
)?
.get_single_program()?;
make_const_rule(
&mut program,
"_new",
kv_bindings.clone(),
new_tuples.clone(),
);
make_const_rule(
&mut program,
"_old",
kv_bindings.clone(),
old_tuples.clone(),
);
let (_, cleanups) = db
.run_query(
self,
program,
cur_vld,
callback_targets,
callback_collector,
false,
)
.map_err(|err| {
if err.source_code().is_some() {
err
} else {
err.with_source_code(trigger.to_string())
}
})?;
to_clear.extend(cleanups);
}
}
if is_callback_target {
let target_collector = callback_collector
.entry(relation_store.name.clone())
.or_default();
let headers = kv_bindings
.into_iter()
.map(|k| k.name.to_string())
.collect_vec();
target_collector.push((
CallbackOp::Put,
NamedRows::new(
headers.clone(),
new_tuples
.into_iter()
.map(|v| match v {
DataValue::List(l) => l,
_ => unreachable!(),
})
.collect_vec(),
),
NamedRows::new(
headers,
old_tuples
.into_iter()
.map(|v| match v {
DataValue::List(l) => l,
_ => unreachable!(),
})
.collect_vec(),
),
))
}
}
}
RelationOp::Create | RelationOp::Replace | RelationOp::Put => { RelationOp::Create | RelationOp::Replace | RelationOp::Put => {
if relation_store.access_level < AccessLevel::Protected { if relation_store.access_level < AccessLevel::Protected {
bail!(InsufficientAccessLevel( bail!(InsufficientAccessLevel(
@ -485,7 +715,7 @@ impl<'a> SessionTx<'a> {
if has_hnsw_indices { if has_hnsw_indices {
for (name, (idx_handle, idx_manifest)) in for (name, (idx_handle, idx_manifest)) in
relation_store.hnsw_indices.iter() relation_store.hnsw_indices.iter()
{ {
let filter = hnsw_filters.get(name); let filter = hnsw_filters.get(name);
self.hnsw_put( self.hnsw_put(
@ -647,6 +877,18 @@ fn make_extractors(
.try_collect() .try_collect()
} }
fn make_update_extractors(
stored: &[ColumnDef],
input: &[ColumnDef],
bindings: &[Symbol],
tuple_headers: &[Symbol],
) -> Vec<Option<DataExtractor>> {
stored
.iter()
.map(|s| make_extractor(s, input, bindings, tuple_headers).ok())
.collect_vec()
}
fn make_extractor( fn make_extractor(
stored: &ColumnDef, stored: &ColumnDef,
input: &[ColumnDef], input: &[ColumnDef],

@ -1057,13 +1057,13 @@ impl<'s, S: Storage<'s>> Db<S> {
) )
} }
RelAlgebra::HnswSearch(HnswSearchRA { RelAlgebra::HnswSearch(HnswSearchRA {
hnsw_search, hnsw_search, ..
..
}) => ( }) => (
"hnsw_index", "hnsw_index",
json!(format!(":{}", hnsw_search.query.name)), json!(format!(":{}", hnsw_search.query.name)),
json!(hnsw_search.query.name), json!(hnsw_search.query.name),
json!(hnsw_search.filter json!(hnsw_search
.filter
.iter() .iter()
.map(|f| f.to_string()) .map(|f| f.to_string())
.collect_vec()), .collect_vec()),
@ -1321,7 +1321,8 @@ impl<'s, S: Storage<'s>> Db<S> {
StoreRelationNotFoundError(meta.name.to_string()) StoreRelationNotFoundError(meta.name.to_string())
); );
existing.ensure_compatible(meta, *op == RelationOp::Rm)?; existing
.ensure_compatible(meta, *op == RelationOp::Rm || *op == RelationOp::Update)?;
} }
}; };

@ -266,7 +266,7 @@ impl RelationHandle {
pub(crate) fn ensure_compatible( pub(crate) fn ensure_compatible(
&self, &self,
inp: &InputRelationHandle, inp: &InputRelationHandle,
is_remove: bool, is_remove_or_update: bool,
) -> Result<()> { ) -> Result<()> {
let InputRelationHandle { metadata, .. } = inp; let InputRelationHandle { metadata, .. } = inp;
// check that every given key is found and compatible // check that every given key is found and compatible
@ -280,7 +280,7 @@ impl RelationHandle {
for col in &self.metadata.keys { for col in &self.metadata.keys {
metadata.satisfied_by_required_col(col, true)?; metadata.satisfied_by_required_col(col, true)?;
} }
if !is_remove { if !is_remove_or_update {
for col in &self.metadata.non_keys { for col in &self.metadata.non_keys {
metadata.satisfied_by_required_col(col, false)?; metadata.satisfied_by_required_col(col, false)?;
} }

@ -489,6 +489,42 @@ fn test_callback() {
assert_eq!(collected[2].2.rows[0].len(), 3); assert_eq!(collected[2].2.rows[0].len(), 3);
} }
#[test]
fn test_update() {
let db = new_cozo_mem().unwrap();
db.run_script(
":create friends {fr: Int, to: Int => a: Any, b: Any, c: Any}",
Default::default(),
)
.unwrap();
db.run_script(
"?[fr, to, a, b, c] <- [[1,2,3,4,5]] :put friends {fr, to => a, b, c}",
Default::default(),
)
.unwrap();
let res = db
.run_script(
"?[fr, to, a, b, c] := *friends{fr, to, a, b, c}",
Default::default(),
)
.unwrap()
.into_json();
assert_eq!(res["rows"][0], json!([1, 2, 3, 4, 5]));
db.run_script(
"?[fr, to, b] <- [[1, 2, 100]] :update friends {fr, to => b}",
Default::default(),
)
.unwrap();
let res = db
.run_script(
"?[fr, to, a, b, c] := *friends{fr, to, a, b, c}",
Default::default(),
)
.unwrap()
.into_json();
assert_eq!(res["rows"][0], json!([1, 2, 3, 100, 5]));
}
#[test] #[test]
fn test_index() { fn test_index() {
let db = new_cozo_mem().unwrap(); let db = new_cozo_mem().unwrap();
@ -871,7 +907,11 @@ fn test_insertions() {
.unwrap(); .unwrap();
db.run_script(r"?[k] <- [[1]] :put a {k}", Default::default()) db.run_script(r"?[k] <- [[1]] :put a {k}", Default::default())
.unwrap(); .unwrap();
db.run_script(r"?[k] := k in int_range(300) :put a {k}", Default::default()).unwrap(); db.run_script(
r"?[k] := k in int_range(300) :put a {k}",
Default::default(),
)
.unwrap();
let res = db let res = db
.run_script( .run_script(
r"?[dist, k] := ~a:i{k | query: v, bind_distance: dist, k:10, ef: 50, filter: k % 2 == 0, radius: 245}, *a{k: 96, v}", r"?[dist, k] := ~a:i{k | query: v, bind_distance: dist, k:10, ef: 50, filter: k % 2 == 0, radius: 245}, *a{k: 96, v}",

Loading…
Cancel
Save