diff --git a/cozo-core/src/runtime/hnsw.rs b/cozo-core/src/runtime/hnsw.rs index 74a1f91b..822ca0de 100644 --- a/cozo-core/src/runtime/hnsw.rs +++ b/cozo-core/src/runtime/hnsw.rs @@ -339,9 +339,7 @@ impl<'a> SessionTx<'a> { if target_degree > m_max { // shrink links target_degree = self.hnsw_shrink_neighbour( - &neighbour.0, - neighbour.1, - neighbour.2, + &neighbour, m_max, current_level, manifest, @@ -377,9 +375,7 @@ impl<'a> SessionTx<'a> { } fn hnsw_shrink_neighbour( &mut self, - target: &[DataValue], - idx: usize, - sub_idx: i32, + target_key: &CompoundKey, m: usize, level: i64, manifest: &HnswIndexManifest, @@ -387,12 +383,11 @@ impl<'a> SessionTx<'a> { orig_table: &RelationHandle, vec_cache: &mut VectorCache, ) -> Result { - let c_key = (target.to_vec(), idx, sub_idx); - vec_cache.ensure_key(&c_key, orig_table, self)?; - let vec = vec_cache.get_key(&c_key).clone(); + vec_cache.ensure_key(&target_key, orig_table, self)?; + let vec = vec_cache.get_key(&target_key).clone(); let mut candidates = PriorityQueue::new(); for (neighbour_key, neighbour_dist) in - self.hnsw_get_neighbours(target.to_vec(), idx, sub_idx, level, idx_table, false)? + self.hnsw_get_neighbours(target_key, level, idx_table, false)? { candidates.push(neighbour_key, OrderedFloat(neighbour_dist)); } @@ -424,9 +419,9 @@ impl<'a> SessionTx<'a> { DataValue::from(false), ]; new_key.push(DataValue::from(level)); - new_key.extend_from_slice(target); - new_key.push(DataValue::from(idx as i64)); - new_key.push(DataValue::from(sub_idx as i64)); + new_key.extend_from_slice(&target_key.0); + new_key.push(DataValue::from(target_key.1 as i64)); + new_key.push(DataValue::from(target_key.2 as i64)); new_key.extend_from_slice(&new.0); new_key.push(DataValue::from(new.1 as i64)); new_key.push(DataValue::from(new.2 as i64)); @@ -440,9 +435,9 @@ impl<'a> SessionTx<'a> { if !new_candidate_set.contains(&old) { let mut old_key = Vec::with_capacity(orig_table.metadata.keys.len() * 2 + 5); old_key.push(DataValue::from(level)); - old_key.extend_from_slice(target); - old_key.push(DataValue::from(idx as i64)); - old_key.push(DataValue::from(sub_idx as i64)); + old_key.extend_from_slice(&target_key.0); + old_key.push(DataValue::from(target_key.1 as i64)); + old_key.push(DataValue::from(target_key.2 as i64)); old_key.extend_from_slice(&old.0); old_key.push(DataValue::from(old.1 as i64)); old_key.push(DataValue::from(old.2 as i64)); @@ -505,9 +500,7 @@ impl<'a> SessionTx<'a> { for (item, _) in found.iter() { // Extend by neighbours for (neighbour_key, _) in self.hnsw_get_neighbours( - item.0.clone(), - item.1, - item.2, + &item, level, idx_table, false, @@ -576,9 +569,7 @@ impl<'a> SessionTx<'a> { } // loop over each of the candidate's neighbors for (neighbour_key, _) in self.hnsw_get_neighbours( - candidate.0, - candidate.1, - candidate.2, + &candidate, cur_level, idx_table, false, @@ -604,19 +595,17 @@ impl<'a> SessionTx<'a> { } fn hnsw_get_neighbours<'b>( &'b self, - cand_key: Vec, - cand_idx: usize, - cand_sub_idx: i32, + cand_key: &'b CompoundKey, level: i64, idx_handle: &RelationHandle, include_deleted: bool, ) -> Result + 'b> { - let mut start_tuple = Vec::with_capacity(cand_key.len() + 3); + let mut start_tuple = Vec::with_capacity(cand_key.0.len() + 3); start_tuple.push(DataValue::from(level)); - start_tuple.extend_from_slice(&cand_key); - start_tuple.push(DataValue::from(cand_idx as i64)); - start_tuple.push(DataValue::from(cand_sub_idx as i64)); - let key_len = cand_key.len(); + start_tuple.extend_from_slice(&cand_key.0); + start_tuple.push(DataValue::from(cand_key.1 as i64)); + start_tuple.push(DataValue::from(cand_key.2 as i64)); + let key_len = cand_key.0.len(); Ok(idx_handle .scan_prefix(self, &start_tuple) .filter_map(move |res| { @@ -624,13 +613,13 @@ impl<'a> SessionTx<'a> { let key_idx = tuple[2 * key_len + 3].get_int().unwrap() as usize; let key_subidx = tuple[2 * key_len + 4].get_int().unwrap() as i32; - let key_slice = tuple[key_len + 3..2 * key_len + 3].to_vec(); - if key_slice == cand_key { + let key_tup = tuple[key_len + 3..2 * key_len + 3].to_vec(); + if key_tup == cand_key.0 { None } else { if include_deleted { return Some(( - (key_slice, key_idx, key_subidx), + (key_tup, key_idx, key_subidx), tuple[2 * key_len + 5].get_float().unwrap(), )); } @@ -639,7 +628,7 @@ impl<'a> SessionTx<'a> { None } else { Some(( - (key_slice, key_idx, key_subidx), + (key_tup, key_idx, key_subidx), tuple[2 * key_len + 5].get_float().unwrap(), )) } @@ -764,6 +753,7 @@ impl<'a> SessionTx<'a> { orig_table: &RelationHandle, idx_table: &RelationHandle, ) -> Result<()> { + let compound_key = (tuple_key.to_vec(), idx, subidx); // Go down the layers and remove all the links let mut encountered_singletons = false; for neg_layer in 0i64.. { @@ -782,7 +772,7 @@ impl<'a> SessionTx<'a> { } let neigbours = self - .hnsw_get_neighbours(tuple_key.to_vec(), idx, subidx, layer, idx_table, true)? + .hnsw_get_neighbours(&compound_key, layer, idx_table, true)? .collect_vec(); encountered_singletons |= neigbours.is_empty(); for (neighbour_key, _) in neigbours { diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index 5aef9997..8fa6781f 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -863,7 +863,7 @@ fn test_insertions() { db.run_script(r"?[k, v] := *a{k, v}", Default::default()) .unwrap(); db.run_script( - r"::hnsw create a:i {fields: [v], dim: 1536, ef: 16, m: 32}", + r"::hnsw create a:i {fields: [v], dim: 1536, ef: 16, m: 32, filter: k % 3 == 0}", Default::default(), ) .unwrap(); @@ -871,10 +871,10 @@ fn test_insertions() { .unwrap(); db.run_script(r"?[k] <- [[1]] :put a {k}", Default::default()) .unwrap(); - db.run_script(r"?[k] := k in int_range(100000) :put a {k}", Default::default()).unwrap(); + db.run_script(r"?[k] := k in int_range(300) :put a {k}", Default::default()).unwrap(); let res = db .run_script( - r"?[dist, k] := ~a:i{k | query: v, bind_distance: dist, k:10, ef: 5}, *a{k: 8888, v}", + r"?[dist, k] := ~a:i{k | query: v, bind_distance: dist, k:10, ef: 50, filter: k % 2 == 0, radius: 245}, *a{k: 96, v}", Default::default(), ) .unwrap();