diff --git a/cozo-core/src/runtime/hnsw.rs b/cozo-core/src/runtime/hnsw.rs index e87d559f..fc7f129d 100644 --- a/cozo-core/src/runtime/hnsw.rs +++ b/cozo-core/src/runtime/hnsw.rs @@ -14,6 +14,7 @@ use crate::parse::sys::HnswDistance; use crate::runtime::relation::RelationHandle; use crate::runtime::transact::SessionTx; use crate::{decode_tuple_from_kv, DataValue}; +use itertools::Itertools; use miette::{bail, miette, Result}; use ordered_float::OrderedFloat; use priority_queue::PriorityQueue; @@ -199,8 +200,7 @@ impl<'a> SessionTx<'a> { return Ok(()); } } - // TODO - self.hnsw_remove_vec()?; + self.hnsw_remove_vec(tuple_key, idx, subidx, orig_table, idx_table)?; } let ep_res = idx_table @@ -314,8 +314,6 @@ impl<'a> SessionTx<'a> { idx_table.encode_key_for_store(&out_key, Default::default())?; let out_val_bytes = idx_table.encode_val_only_for_store(&out_val, Default::default())?; - // println!("tuple: {:?}", tuple_key); - // println!("out_key: {:?}", out_key); self.store_tx.put(&out_key_bytes, &out_val_bytes)?; let mut in_key = Vec::with_capacity(orig_table.metadata.keys.len() * 2 + 5); @@ -331,7 +329,7 @@ impl<'a> SessionTx<'a> { in_key.extend_from_slice(tuple_key); in_key.push(DataValue::from(idx as i64)); in_key.push(DataValue::from(subidx as i64)); - // println!("in_key: {:?}", in_key); + let in_key_bytes = idx_table.encode_key_for_store(&in_key, Default::default())?; let in_val_bytes = @@ -375,7 +373,16 @@ impl<'a> SessionTx<'a> { } else { // This is the first vector in the index. let level = manifest.get_random_level(); - self.hnsw_put_fresh_at_levels(hash.as_ref(), tuple_key, idx, subidx, orig_table, idx_table, level, 0)?; + self.hnsw_put_fresh_at_levels( + hash.as_ref(), + tuple_key, + idx, + subidx, + orig_table, + idx_table, + level, + 0, + )?; } Ok(()) } @@ -401,7 +408,7 @@ impl<'a> SessionTx<'a> { let vec = manifest.get_vector(&orig_tuple, idx, sub_idx)?; let mut candidates = PriorityQueue::new(); for neighbour in - self.hnsw_get_neighbours(target.to_vec(), idx, sub_idx, level, idx_table)? + self.hnsw_get_neighbours(target.to_vec(), idx, sub_idx, level, idx_table, false)? { candidates.push( (neighbour.0, neighbour.1, neighbour.2), @@ -519,9 +526,14 @@ impl<'a> SessionTx<'a> { if manifest.extend_candidates { for (item, _) in found.iter() { // Extend by neighbours - for neighbour in - self.hnsw_get_neighbours(item.0.clone(), item.1, item.2, level, idx_table)? - { + for neighbour in self.hnsw_get_neighbours( + item.0.clone(), + item.1, + item.2, + level, + idx_table, + false, + )? { let dist = self.hnsw_compare_vector( q, &neighbour.0, @@ -593,6 +605,7 @@ impl<'a> SessionTx<'a> { candidate.2, cur_level, idx_table, + false, )? { let neighbour_triple = (neighbour_tetra.0, neighbour_tetra.1, neighbour_tetra.2); if visited.contains(&neighbour_triple) { @@ -630,6 +643,7 @@ impl<'a> SessionTx<'a> { cand_sub_idx: i32, level: i64, idx_handle: &RelationHandle, + include_deleted: bool, ) -> Result + 'b> { let mut start_tuple = Vec::with_capacity(cand_key.len() + 3); start_tuple.push(DataValue::from(level)); @@ -641,14 +655,21 @@ impl<'a> SessionTx<'a> { .scan_prefix(self, &start_tuple) .filter_map(move |res| { let tuple = res.unwrap(); - // println!("tuple: {:?}", tuple); - // println!("key_len: {}", key_len); + let key_idx = tuple[2 * key_len + 3].get_int().unwrap() as usize; let key_subidx = tuple[2 * key_len + 4].get_int().unwrap() as i32; let key_slice = tuple[key_len + 3..2 * key_len + 3].to_vec(); if key_slice == cand_key { None } else { + if include_deleted { + return Some(( + key_slice, + key_idx, + key_subidx, + tuple[2 * key_len + 5].get_float().unwrap(), + )); + } let is_deleted = tuple[2 * key_len + 7].get_bool().unwrap(); if is_deleted { None @@ -750,8 +771,95 @@ impl<'a> SessionTx<'a> { pub(crate) fn hnsw_remove(&mut self) -> Result<()> { todo!() } - pub(crate) fn hnsw_remove_vec(&mut self) -> Result<()> { - todo!() + fn hnsw_remove_vec( + &mut self, + tuple_key: &[DataValue], + idx: usize, + subidx: i32, + orig_table: &RelationHandle, + idx_table: &RelationHandle, + ) -> Result<()> { + // Go down the layers and remove all the links + let mut encountered_singletons = false; + for neg_layer in 0i64.. { + let layer = -neg_layer; + let mut self_key = vec![DataValue::from(layer)]; + for _ in 0..2 { + self_key.extend_from_slice(tuple_key); + self_key.push(DataValue::from(idx as i64)); + self_key.push(DataValue::from(subidx as i64)); + } + let self_key_bytes = idx_table.encode_key_for_store(&self_key, Default::default())?; + if self.store_tx.exists(&self_key_bytes, false)? { + self.store_tx.del(&self_key_bytes)?; + } else { + break; + } + + let neigbours = self + .hnsw_get_neighbours(tuple_key.to_vec(), idx, subidx, layer, idx_table, true)? + .collect_vec(); + encountered_singletons |= neigbours.is_empty(); + for (neighbour_key, neighbour_idx, neighbour_subidx, _) in neigbours { + let mut out_key = vec![DataValue::from(layer)]; + out_key.extend_from_slice(tuple_key); + out_key.push(DataValue::from(idx as i64)); + out_key.push(DataValue::from(subidx as i64)); + out_key.extend_from_slice(&neighbour_key); + out_key.push(DataValue::from(neighbour_idx as i64)); + out_key.push(DataValue::from(neighbour_subidx as i64)); + let out_key_bytes = idx_table.encode_key_for_store(&out_key, Default::default())?; + self.store_tx.del(&out_key_bytes)?; + let mut in_key = vec![DataValue::from(layer)]; + in_key.extend_from_slice(&neighbour_key); + in_key.push(DataValue::from(neighbour_idx as i64)); + in_key.push(DataValue::from(neighbour_subidx as i64)); + in_key.extend_from_slice(tuple_key); + in_key.push(DataValue::from(idx as i64)); + in_key.push(DataValue::from(subidx as i64)); + let in_key_bytes = idx_table.encode_key_for_store(&in_key, Default::default())?; + self.store_tx.del(&in_key_bytes)?; + } + } + + if encountered_singletons { + // the entry point is removed, we need to do something + let ep_res = idx_table + .scan_bounded_prefix( + self, + &[], + &[DataValue::from(i64::MIN)], + &[DataValue::from(1)], + ) + .next(); + let mut canary_key = vec![DataValue::from(1)]; + for _ in 0..2 { + for _ in 0..orig_table.metadata.keys.len() { + canary_key.push(DataValue::Null); + } + canary_key.push(DataValue::Null); + canary_key.push(DataValue::Null); + } + let canary_key_bytes = idx_table.encode_key_for_store(&canary_key, Default::default())?; + if let Some(ep) = ep_res { + let ep = ep?; + let target_key_bytes = idx_table.encode_key_for_store(&ep, Default::default())?; + let bottom_level = ep[0].get_int().unwrap(); + // canary value is for conflict detection: prevent the scenario of disconnected graphs at all levels + let canary_value = [ + DataValue::from(bottom_level), + DataValue::Bytes(target_key_bytes), + DataValue::from(false), + ]; + let canary_value_bytes = idx_table.encode_val_only_for_store(&canary_value, Default::default())?; + self.store_tx.put(&canary_key_bytes, &canary_value_bytes)?; + } else { + // HA! we have removed the last item in the index + self.store_tx.del(&canary_key_bytes)?; + } + } + + Ok(()) } pub(crate) fn hnsw_knn( &self, @@ -853,10 +961,22 @@ impl<'a> SessionTx<'a> { cur.push(cand_tuple[*i].clone()); } if config.bind_field { - cur.push(DataValue::from(cand_idx as i64)); + let field = if cand_idx as usize >= orig_table.metadata.keys.len() { + orig_table.metadata.keys[cand_idx as usize].name.clone() + } else { + orig_table.metadata.non_keys + [cand_idx as usize - orig_table.metadata.keys.len()] + .name + .clone() + }; + cur.push(DataValue::Str(field)); } if config.bind_field_idx { - cur.push(DataValue::from(cand_subidx as i64)); + cur.push(if cand_subidx < 0 { + DataValue::Null + } else { + DataValue::from(cand_subidx as i64) + }); } if config.bind_distance { cur.push(DataValue::from(distance)); diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index 3bd8bd02..5106060c 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -786,7 +786,9 @@ fn test_vec_index() { ?[k, v] <- [['a', [1,2,3,4,5,6,7,8]], ['b', [2,3,4,5,6,7,8,9]], ['bb', [2,3,4,5,6,7,8,9]], - ['c', [2,3,4,5,6,7,8,19]]] + ['c', [2,3,4,5,6,7,8,19]], + ['a', [2,3,4,5,6,7,8,9]], + ['b', [1,1,1,1,1,1,1,1]]] :create a {k: String => v: } ", Default::default())