From d779e0cc13f98183bcf29fb2eaf0507a7019c681 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Tue, 20 Jun 2023 18:31:17 +0800 Subject: [PATCH] HNSW fixes; string slicing --- cozo-core/src/data/expr.rs | 1 + cozo-core/src/data/functions.rs | 25 ++++++++++++++++++++--- cozo-core/src/data/program.rs | 2 +- cozo-core/src/runtime/hnsw.rs | 2 ++ cozo-core/src/runtime/tests.rs | 36 +++++++++++++++++++++++++++++++++ 5 files changed, 62 insertions(+), 4 deletions(-) diff --git a/cozo-core/src/data/expr.rs b/cozo-core/src/data/expr.rs index 4700add4..a8d25eac 100644 --- a/cozo-core/src/data/expr.rs +++ b/cozo-core/src/data/expr.rs @@ -885,6 +885,7 @@ pub(crate) fn get_op(name: &str) -> Option<&'static Op> { "get" => &OP_GET, "maybe_get" => &OP_MAYBE_GET, "chars" => &OP_CHARS, + "slice_string" => &OP_SLICE_STRING, "from_substrings" => &OP_FROM_SUBSTRINGS, "slice" => &OP_SLICE, "regex_matches" => &OP_REGEX_MATCHES, diff --git a/cozo-core/src/data/functions.rs b/cozo-core/src/data/functions.rs index c2ced5df..1e804f00 100644 --- a/cozo-core/src/data/functions.rs +++ b/cozo-core/src/data/functions.rs @@ -1372,9 +1372,7 @@ define_op!(OP_ENDS_WITH, 2, false); pub(crate) fn op_ends_with(args: &[DataValue]) -> Result { match (&args[0], &args[1]) { (DataValue::Str(l), DataValue::Str(r)) => Ok(DataValue::from(l.ends_with(r as &str))), - (DataValue::Bytes(l), DataValue::Bytes(r)) => { - Ok(DataValue::from(l.ends_with(r as &[u8]))) - } + (DataValue::Bytes(l), DataValue::Bytes(r)) => Ok(DataValue::from(l.ends_with(r as &[u8]))), _ => bail!("'ends_with' requires strings or bytes"), } } @@ -1846,6 +1844,27 @@ pub(crate) fn op_chars(args: &[DataValue]) -> Result { )) } +define_op!(OP_SLICE_STRING, 3, false); +pub(crate) fn op_slice_string(args: &[DataValue]) -> Result { + let s = args[0] + .get_str() + .ok_or_else(|| miette!("first argument to 'slice_string' mut be a string"))?; + let m = args[1] + .get_int() + .ok_or_else(|| miette!("second argument to 'slice_string' mut be an integer"))?; + ensure!( + m >= 0, + "second argument to 'slice_string' mut be a positive integer" + ); + let n = args[2] + .get_int() + .ok_or_else(|| miette!("third argument to 'slice_string' mut be an integer"))?; + ensure!(n >= m, "third argument to 'slice_string' mut be a positive integer greater than the second argument"); + Ok(DataValue::Str( + s.chars().skip(m as usize).take((n - m) as usize).collect(), + )) +} + define_op!(OP_FROM_SUBSTRINGS, 1, false); pub(crate) fn op_from_substrings(args: &[DataValue]) -> Result { let mut ret = String::new(); diff --git a/cozo-core/src/data/program.rs b/cozo-core/src/data/program.rs index c24a0276..d3c6b0af 100644 --- a/cozo-core/src/data/program.rs +++ b/cozo-core/src/data/program.rs @@ -1011,8 +1011,8 @@ impl HnswSearch { self.bindings .iter() .chain(self.bind_field.iter()) - .chain(self.bind_distance.iter()) .chain(self.bind_field_idx.iter()) + .chain(self.bind_distance.iter()) .chain(self.bind_vector.iter()) } } diff --git a/cozo-core/src/runtime/hnsw.rs b/cozo-core/src/runtime/hnsw.rs index 97eca778..47724072 100644 --- a/cozo-core/src/runtime/hnsw.rs +++ b/cozo-core/src/runtime/hnsw.rs @@ -687,6 +687,7 @@ impl<'a> SessionTx<'a> { ) -> Result { if let Some(code) = filter { if !eval_bytecode_pred(code, tuple, stack, Default::default())? { + self.hnsw_remove(orig_table, idx_table, tuple)?; return Ok(false); } } @@ -955,6 +956,7 @@ impl<'a> SessionTx<'a> { .get(self, &cand_key.0)? .ok_or_else(|| miette!("corrupted index"))?; + // make sure the order is the same as in all_bindings()!!! if config.bind_field.is_some() { let field = if cand_key.1 < config.base_handle.metadata.keys.len() { config.base_handle.metadata.keys[cand_key.1].name.clone() diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index b8e6d420..6528b88b 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -697,6 +697,42 @@ fn test_vec_types() { println!("{}", res.into_json()); } +#[test] +fn test_vec_index_insertion() { + let db = DbInstance::new("mem", "", "").unwrap(); + db.run_default( + r" + ?[k, v, m] <- [['a', [1,2], true], + ['b', [2,3], false]] + + :create a {k: String => v: , m: Bool} + ", + ) + .unwrap(); + db.run_default( + r" + ::hnsw create a:vec { + dim: 2, + m: 50, + dtype: F32, + fields: [v], + distance: L2, + ef_construction: 20, + filter: m, + #extend_candidates: true, + #keep_pruned_connections: true, + }", + ) + .unwrap(); + let res = db.run_default("?[k] := *a:vec{layer: 0, fr_k, to_k}, k = fr_k or k = to_k").unwrap(); + assert_eq!(res.rows.len(), 1); + println!("update!"); + db.run_default(r#"?[k, m] <- [["a", false]] :update a {}"#).unwrap(); + let res = db.run_default("?[k] := *a:vec{layer: 0, fr_k, to_k}, k = fr_k or k = to_k").unwrap(); + assert_eq!(res.rows.len(), 0); + println!("{}", res.into_json()); +} + #[test] fn test_vec_index() { let db = DbInstance::new("mem", "", "").unwrap();