HNSW fixes; string slicing

main
Ziyang Hu 1 year ago
parent deab88a5d6
commit d779e0cc13

@ -885,6 +885,7 @@ pub(crate) fn get_op(name: &str) -> Option<&'static Op> {
"get" => &OP_GET, "get" => &OP_GET,
"maybe_get" => &OP_MAYBE_GET, "maybe_get" => &OP_MAYBE_GET,
"chars" => &OP_CHARS, "chars" => &OP_CHARS,
"slice_string" => &OP_SLICE_STRING,
"from_substrings" => &OP_FROM_SUBSTRINGS, "from_substrings" => &OP_FROM_SUBSTRINGS,
"slice" => &OP_SLICE, "slice" => &OP_SLICE,
"regex_matches" => &OP_REGEX_MATCHES, "regex_matches" => &OP_REGEX_MATCHES,

@ -1372,9 +1372,7 @@ define_op!(OP_ENDS_WITH, 2, false);
pub(crate) fn op_ends_with(args: &[DataValue]) -> Result<DataValue> { pub(crate) fn op_ends_with(args: &[DataValue]) -> Result<DataValue> {
match (&args[0], &args[1]) { match (&args[0], &args[1]) {
(DataValue::Str(l), DataValue::Str(r)) => Ok(DataValue::from(l.ends_with(r as &str))), (DataValue::Str(l), DataValue::Str(r)) => Ok(DataValue::from(l.ends_with(r as &str))),
(DataValue::Bytes(l), DataValue::Bytes(r)) => { (DataValue::Bytes(l), DataValue::Bytes(r)) => Ok(DataValue::from(l.ends_with(r as &[u8]))),
Ok(DataValue::from(l.ends_with(r as &[u8])))
}
_ => bail!("'ends_with' requires strings or bytes"), _ => bail!("'ends_with' requires strings or bytes"),
} }
} }
@ -1846,6 +1844,27 @@ pub(crate) fn op_chars(args: &[DataValue]) -> Result<DataValue> {
)) ))
} }
define_op!(OP_SLICE_STRING, 3, false);
pub(crate) fn op_slice_string(args: &[DataValue]) -> Result<DataValue> {
let s = args[0]
.get_str()
.ok_or_else(|| miette!("first argument to 'slice_string' mut be a string"))?;
let m = args[1]
.get_int()
.ok_or_else(|| miette!("second argument to 'slice_string' mut be an integer"))?;
ensure!(
m >= 0,
"second argument to 'slice_string' mut be a positive integer"
);
let n = args[2]
.get_int()
.ok_or_else(|| miette!("third argument to 'slice_string' mut be an integer"))?;
ensure!(n >= m, "third argument to 'slice_string' mut be a positive integer greater than the second argument");
Ok(DataValue::Str(
s.chars().skip(m as usize).take((n - m) as usize).collect(),
))
}
define_op!(OP_FROM_SUBSTRINGS, 1, false); define_op!(OP_FROM_SUBSTRINGS, 1, false);
pub(crate) fn op_from_substrings(args: &[DataValue]) -> Result<DataValue> { pub(crate) fn op_from_substrings(args: &[DataValue]) -> Result<DataValue> {
let mut ret = String::new(); let mut ret = String::new();

@ -1011,8 +1011,8 @@ impl HnswSearch {
self.bindings self.bindings
.iter() .iter()
.chain(self.bind_field.iter()) .chain(self.bind_field.iter())
.chain(self.bind_distance.iter())
.chain(self.bind_field_idx.iter()) .chain(self.bind_field_idx.iter())
.chain(self.bind_distance.iter())
.chain(self.bind_vector.iter()) .chain(self.bind_vector.iter())
} }
} }

@ -687,6 +687,7 @@ impl<'a> SessionTx<'a> {
) -> Result<bool> { ) -> Result<bool> {
if let Some(code) = filter { if let Some(code) = filter {
if !eval_bytecode_pred(code, tuple, stack, Default::default())? { if !eval_bytecode_pred(code, tuple, stack, Default::default())? {
self.hnsw_remove(orig_table, idx_table, tuple)?;
return Ok(false); return Ok(false);
} }
} }
@ -955,6 +956,7 @@ impl<'a> SessionTx<'a> {
.get(self, &cand_key.0)? .get(self, &cand_key.0)?
.ok_or_else(|| miette!("corrupted index"))?; .ok_or_else(|| miette!("corrupted index"))?;
// make sure the order is the same as in all_bindings()!!!
if config.bind_field.is_some() { if config.bind_field.is_some() {
let field = if cand_key.1 < config.base_handle.metadata.keys.len() { let field = if cand_key.1 < config.base_handle.metadata.keys.len() {
config.base_handle.metadata.keys[cand_key.1].name.clone() config.base_handle.metadata.keys[cand_key.1].name.clone()

@ -697,6 +697,42 @@ fn test_vec_types() {
println!("{}", res.into_json()); println!("{}", res.into_json());
} }
#[test]
fn test_vec_index_insertion() {
let db = DbInstance::new("mem", "", "").unwrap();
db.run_default(
r"
?[k, v, m] <- [['a', [1,2], true],
['b', [2,3], false]]
:create a {k: String => v: <F32; 2>, m: Bool}
",
)
.unwrap();
db.run_default(
r"
::hnsw create a:vec {
dim: 2,
m: 50,
dtype: F32,
fields: [v],
distance: L2,
ef_construction: 20,
filter: m,
#extend_candidates: true,
#keep_pruned_connections: true,
}",
)
.unwrap();
let res = db.run_default("?[k] := *a:vec{layer: 0, fr_k, to_k}, k = fr_k or k = to_k").unwrap();
assert_eq!(res.rows.len(), 1);
println!("update!");
db.run_default(r#"?[k, m] <- [["a", false]] :update a {}"#).unwrap();
let res = db.run_default("?[k] := *a:vec{layer: 0, fr_k, to_k}, k = fr_k or k = to_k").unwrap();
assert_eq!(res.rows.len(), 0);
println!("{}", res.into_json());
}
#[test] #[test]
fn test_vec_index() { fn test_vec_index() {
let db = DbInstance::new("mem", "", "").unwrap(); let db = DbInstance::new("mem", "", "").unwrap();

Loading…
Cancel
Save