correct heuristic

main
Ziyang Hu 1 year ago
parent d8e6ee5887
commit 4f839bbab1

@ -14,7 +14,7 @@ use crate::data::value::Vector;
use crate::parse::sys::HnswDistance; use crate::parse::sys::HnswDistance;
use crate::runtime::relation::RelationHandle; use crate::runtime::relation::RelationHandle;
use crate::runtime::transact::SessionTx; use crate::runtime::transact::SessionTx;
use crate::{decode_tuple_from_kv, DataValue, SourceSpan}; use crate::{DataValue, SourceSpan};
use itertools::Itertools; use itertools::Itertools;
use miette::{bail, miette, Result}; use miette::{bail, miette, Result};
use ordered_float::OrderedFloat; use ordered_float::OrderedFloat;
@ -22,7 +22,7 @@ use priority_queue::PriorityQueue;
use rand::Rng; use rand::Rng;
use smartstring::{LazyCompact, SmartString}; use smartstring::{LazyCompact, SmartString};
use std::cmp::{max, Reverse}; use std::cmp::{max, Reverse};
use std::collections::BTreeSet; use std::collections::{BTreeMap, BTreeSet};
#[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)] #[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)]
pub(crate) struct HnswIndexManifest { pub(crate) struct HnswIndexManifest {
@ -50,49 +50,19 @@ impl HnswIndexManifest {
// the level is the largest integer smaller than r // the level is the largest integer smaller than r
-(r.floor() as i64) -(r.floor() as i64)
} }
fn get_vector(&self, tuple: &Tuple, idx: usize, sub_idx: i32) -> Result<Vector> { }
let field = tuple.get(idx).unwrap();
if sub_idx >= 0 { type CompoundKey = (Tuple, usize, i32);
match field {
DataValue::List(l) => match l.get(sub_idx as usize) { struct VectorCache {
Some(DataValue::Vec(v)) => Ok(v.clone()), cache: BTreeMap<CompoundKey, Vector>,
_ => bail!( distance: HnswDistance,
"Cannot extract vector from {} for sub index {}", }
field,
sub_idx impl VectorCache {
), fn dist(&self, v1: &Vector, v2: &Vector) -> f64 {
}, match self.distance {
_ => bail!("Cannot interpret {} as list", field), HnswDistance::L2 => match (v1, v2) {
}
} else {
match field {
DataValue::Vec(v) => Ok(v.clone()),
_ => bail!("Cannot interpret {} as vector", field),
}
}
}
fn get_distance(&self, q: &Vector, tuple: &Tuple, idx: usize, sub_idx: i32) -> Result<f64> {
let field = tuple.get(idx).unwrap();
let target = if sub_idx >= 0 {
match field {
DataValue::List(l) => match l.get(sub_idx as usize) {
Some(DataValue::Vec(v)) => v,
_ => bail!(
"Cannot extract vector from {} for sub index {}",
field,
sub_idx
),
},
_ => bail!("Cannot interpret {} as list", field),
}
} else {
match field {
DataValue::Vec(v) => v,
_ => bail!("Cannot interpret {} as vector", field),
}
};
Ok(match self.distance {
HnswDistance::L2 => match (q, target) {
(Vector::F32(a), Vector::F32(b)) => { (Vector::F32(a), Vector::F32(b)) => {
let diff = a - b; let diff = a - b;
diff.dot(&diff) as f64 diff.dot(&diff) as f64
@ -101,13 +71,9 @@ impl HnswIndexManifest {
let diff = a - b; let diff = a - b;
diff.dot(&diff) diff.dot(&diff)
} }
_ => bail!( _ => panic!("Cannot compute L2 distance between {:?} and {:?}", v1, v2),
"Cannot compute L2 distance between {:?} and {:?}",
q,
target
),
}, },
HnswDistance::Cosine => match (q, target) { HnswDistance::Cosine => match (v1, v2) {
(Vector::F32(a), Vector::F32(b)) => { (Vector::F32(a), Vector::F32(b)) => {
let a_norm = a.dot(a) as f64; let a_norm = a.dot(a) as f64;
let b_norm = b.dot(b) as f64; let b_norm = b.dot(b) as f64;
@ -120,13 +86,12 @@ impl HnswIndexManifest {
let dot = a.dot(b); let dot = a.dot(b);
1.0 - dot / (a_norm * b_norm).sqrt() 1.0 - dot / (a_norm * b_norm).sqrt()
} }
_ => bail!( _ => panic!(
"Cannot compute cosine distance between {:?} and {:?}", "Cannot compute cosine distance between {:?} and {:?}",
q, v1, v2
target
), ),
}, },
HnswDistance::InnerProduct => match (q, target) { HnswDistance::InnerProduct => match (v1, v2) {
(Vector::F32(a), Vector::F32(b)) => { (Vector::F32(a), Vector::F32(b)) => {
let dot = a.dot(b); let dot = a.dot(b);
1. - dot as f64 1. - dot as f64
@ -135,13 +100,51 @@ impl HnswIndexManifest {
let dot = a.dot(b); let dot = a.dot(b);
1. - dot as f64 1. - dot as f64
} }
_ => bail!( _ => panic!("Cannot compute inner product between {:?} and {:?}", v1, v2),
"Cannot compute inner product between {:?} and {:?}",
q,
target
),
}, },
}) }
}
/// Distance between the query vector `v` and the cached vector stored under `key`,
/// measured with this cache's configured distance.
///
/// # Panics
/// Panics if `key` has not been loaded into the cache (call `ensure_key` first),
/// or if `dist` panics on the pair of vectors.
fn v_dist(&self, v: &Vector, key: &CompoundKey) -> f64 {
    self.dist(v, self.get_key(key))
}
/// Distance between the two cached vectors stored under `k1` and `k2`,
/// measured with this cache's configured distance.
///
/// # Panics
/// Panics if either key has not been loaded into the cache (call `ensure_key`
/// for both first), or if `dist` panics on the pair of vectors.
fn k_dist(&self, k1: &CompoundKey, k2: &CompoundKey) -> f64 {
    // Look up `k1` before `k2`, as tuple fields are evaluated left to right.
    let (first, second) = (self.get_key(k1), self.get_key(k2));
    self.dist(first, second)
}
/// Returns a reference to the cached vector stored under `key`.
///
/// This is a pure cache read: it never touches the underlying relation.
///
/// # Panics
/// Panics if `key` is not present in the cache. Callers are expected to have
/// loaded it with `ensure_key` (or inserted it directly) beforehand.
fn get_key(&self, key: &CompoundKey) -> &Vector {
    self.cache
        .get(key)
        .expect("vector must be loaded with `ensure_key` before it is read from the cache")
}
/// Loads the vector addressed by `key` into the cache if it is not already there.
///
/// `key` is `(tuple key, field index, sub-index)`: the tuple is fetched with
/// `handle.get(tx, &key.0)`, `key.1` selects the field within it, and a
/// non-negative `key.2` selects one element of that field when it is a list.
///
/// # Errors
/// Fails if the lookup itself fails, if no tuple is found for `key.0`, if a
/// sub-index is given but the field is not a list, or if the selected value is
/// not a vector.
fn ensure_key(
&mut self,
key: &CompoundKey,
handle: &RelationHandle,
tx: &SessionTx<'_>,
) -> Result<()> {
// Only go to the relation when the vector is not cached yet.
if !self.cache.contains_key(key) {
match handle.get(tx, &key.0)? {
Some(tuple) => {
// NOTE(review): plain indexing here and on the list below; presumably this
// panics instead of returning an error when the index is out of range. Confirm.
let mut field = &tuple[key.1];
// A negative sub-index means the field itself is expected to be the vector.
if key.2 >= 0 {
match field {
DataValue::List(l) => {
field = &l[key.2 as usize];
}
_ => bail!("Cannot interpret {} as list", field),
}
}
match field {
DataValue::Vec(v) => {
// The vector is cloned so the cache owns its own copy.
self.cache.insert(key.clone(), v.clone());
}
_ => bail!("Cannot interpret {} as vector", field),
}
}
None => bail!("Cannot find tuple {:?}", key.0),
}
}
Ok(())
} }
} }
@ -156,7 +159,12 @@ impl<'a> SessionTx<'a> {
orig_table: &RelationHandle, orig_table: &RelationHandle,
idx_table: &RelationHandle, idx_table: &RelationHandle,
) -> Result<()> { ) -> Result<()> {
let mut vec_cache = VectorCache {
cache: BTreeMap::new(),
distance: manifest.distance,
};
let tuple_key = &tuple[..orig_table.metadata.keys.len()]; let tuple_key = &tuple[..orig_table.metadata.keys.len()];
vec_cache.cache.insert((tuple_key.to_vec(), idx, subidx), q.clone());
let hash = q.get_hash(); let hash = q.get_hash();
let mut canary_tuple = vec![DataValue::from(0)]; let mut canary_tuple = vec![DataValue::from(0)];
for _ in 0..2 { for _ in 0..2 {
@ -185,13 +193,15 @@ impl<'a> SessionTx<'a> {
let ep = ep?; let ep = ep?;
// bottom level since we are going up // bottom level since we are going up
let bottom_level = ep[0].get_int().unwrap(); let bottom_level = ep[0].get_int().unwrap();
let ep_key = ep[1..orig_table.metadata.keys.len() + 1].to_vec(); let ep_t_key = ep[1..orig_table.metadata.keys.len() + 1].to_vec();
let ep_idx = ep[orig_table.metadata.keys.len() + 1].get_int().unwrap() as usize; let ep_idx = ep[orig_table.metadata.keys.len() + 1].get_int().unwrap() as usize;
let ep_subidx = ep[orig_table.metadata.keys.len() + 2].get_int().unwrap() as i32; let ep_subidx = ep[orig_table.metadata.keys.len() + 2].get_int().unwrap() as i32;
let ep_distance = let ep_key = (ep_t_key, ep_idx, ep_subidx);
self.hnsw_compare_vector(q, &ep_key, idx, subidx, manifest, orig_table)?; vec_cache.ensure_key(&ep_key, orig_table, self)?;
let ep_distance = vec_cache.v_dist(q, &ep_key);
// max queue
let mut found_nn = PriorityQueue::new(); let mut found_nn = PriorityQueue::new();
found_nn.push((ep_key, ep_idx, ep_subidx), OrderedFloat(ep_distance)); found_nn.push(ep_key, OrderedFloat(ep_distance));
let target_level = manifest.get_random_level(); let target_level = manifest.get_random_level();
if target_level < bottom_level { if target_level < bottom_level {
// this becomes the entry point // this becomes the entry point
@ -211,10 +221,10 @@ impl<'a> SessionTx<'a> {
q, q,
1, 1,
current_level, current_level,
manifest,
orig_table, orig_table,
idx_table, idx_table,
&mut found_nn, &mut found_nn,
&mut vec_cache,
)?; )?;
} }
let mut self_tuple_key = Vec::with_capacity(orig_table.metadata.keys.len() * 2 + 5); let mut self_tuple_key = Vec::with_capacity(orig_table.metadata.keys.len() * 2 + 5);
@ -239,10 +249,10 @@ impl<'a> SessionTx<'a> {
q, q,
manifest.ef_construction, manifest.ef_construction,
current_level, current_level,
manifest,
orig_table, orig_table,
idx_table, idx_table,
&mut found_nn, &mut found_nn,
&mut vec_cache,
)?; )?;
// add bidirectional links to the nearest neighbors // add bidirectional links to the nearest neighbors
let neighbours = self.hnsw_select_neighbours_heuristic( let neighbours = self.hnsw_select_neighbours_heuristic(
@ -253,6 +263,7 @@ impl<'a> SessionTx<'a> {
manifest, manifest,
idx_table, idx_table,
orig_table, orig_table,
&mut vec_cache,
)?; )?;
// add self-link // add self-link
self_tuple_key[0] = DataValue::from(current_level); self_tuple_key[0] = DataValue::from(current_level);
@ -336,6 +347,7 @@ impl<'a> SessionTx<'a> {
manifest, manifest,
idx_table, idx_table,
orig_table, orig_table,
&mut vec_cache,
)?; )?;
} }
// update degree // update degree
@ -373,24 +385,16 @@ impl<'a> SessionTx<'a> {
manifest: &HnswIndexManifest, manifest: &HnswIndexManifest,
idx_table: &RelationHandle, idx_table: &RelationHandle,
orig_table: &RelationHandle, orig_table: &RelationHandle,
vec_cache: &mut VectorCache,
) -> Result<usize> { ) -> Result<usize> {
let orig_key = orig_table.encode_key_for_store(target, Default::default())?; let c_key = (target.to_vec(), idx, sub_idx);
let orig_val = match self.store_tx.get(&orig_key, false)? { vec_cache.ensure_key(&c_key, orig_table, self)?;
Some(bytes) => bytes, let vec = vec_cache.get_key(&c_key).clone();
None => {
bail!("Indexed vector not found, this signifies a bug in the index implementation")
}
};
let orig_tuple = decode_tuple_from_kv(&orig_key, &orig_val);
let vec = manifest.get_vector(&orig_tuple, idx, sub_idx)?;
let mut candidates = PriorityQueue::new(); let mut candidates = PriorityQueue::new();
for neighbour in for (neighbour_key, neighbour_dist) in
self.hnsw_get_neighbours(target.to_vec(), idx, sub_idx, level, idx_table, false)? self.hnsw_get_neighbours(target.to_vec(), idx, sub_idx, level, idx_table, false)?
{ {
candidates.push( candidates.push(neighbour_key, OrderedFloat(neighbour_dist));
(neighbour.0, neighbour.1, neighbour.2),
OrderedFloat(neighbour.3),
);
} }
let new_candidates = self.hnsw_select_neighbours_heuristic( let new_candidates = self.hnsw_select_neighbours_heuristic(
&vec, &vec,
@ -400,6 +404,7 @@ impl<'a> SessionTx<'a> {
manifest, manifest,
idx_table, idx_table,
orig_table, orig_table,
vec_cache,
)?; )?;
let mut old_candidate_set = BTreeSet::new(); let mut old_candidate_set = BTreeSet::new();
for (old, _) in &candidates { for (old, _) in &candidates {
@ -467,35 +472,30 @@ impl<'a> SessionTx<'a> {
Ok(new_degree) Ok(new_degree)
} }
fn hnsw_compare_vector(
&self,
q: &Vector,
target_key: &[DataValue],
target_idx: usize,
target_subidx: i32,
manifest: &HnswIndexManifest,
orig_table: &RelationHandle,
) -> Result<f64> {
let target_key_bytes = orig_table.encode_key_for_store(target_key, Default::default())?;
let bytes = match self.store_tx.get(&target_key_bytes, false)? {
Some(bytes) => bytes,
None => bail!("Indexed data not found, this signifies a bug in the index."),
};
let target_tuple = decode_tuple_from_kv(&target_key_bytes, &bytes);
manifest.get_distance(q, &target_tuple, target_idx, target_subidx)
}
fn hnsw_select_neighbours_heuristic( fn hnsw_select_neighbours_heuristic(
&self, &self,
q: &Vector, q: &Vector,
found: &PriorityQueue<(Tuple, usize, i32), OrderedFloat<f64>>, found: &PriorityQueue<CompoundKey, OrderedFloat<f64>>,
m: usize, m: usize,
level: i64, level: i64,
manifest: &HnswIndexManifest, manifest: &HnswIndexManifest,
idx_table: &RelationHandle, idx_table: &RelationHandle,
orig_table: &RelationHandle, orig_table: &RelationHandle,
) -> Result<PriorityQueue<(Tuple, usize, i32), Reverse<OrderedFloat<f64>>>> { vec_cache: &mut VectorCache,
) -> Result<PriorityQueue<CompoundKey, Reverse<OrderedFloat<f64>>>> {
let mut candidates = PriorityQueue::new(); let mut candidates = PriorityQueue::new();
let mut ret: PriorityQueue<_, Reverse<OrderedFloat<_>>> = PriorityQueue::new(); // Simple non-heuristic selection
// let mut temp = found.clone();
// while temp.len() > m {
// temp.pop();
// }
// for (item, dist) in temp.iter() {
// candidates.push(item.clone(), Reverse(*dist));
// }
// return Ok(candidates);
// End of simple non-heuristic selection
let mut ret: PriorityQueue<CompoundKey, Reverse<OrderedFloat<_>>> = PriorityQueue::new();
let mut discarded: PriorityQueue<_, Reverse<OrderedFloat<_>>> = PriorityQueue::new(); let mut discarded: PriorityQueue<_, Reverse<OrderedFloat<_>>> = PriorityQueue::new();
for (item, dist) in found.iter() { for (item, dist) in found.iter() {
// Add to candidates // Add to candidates
@ -504,7 +504,7 @@ impl<'a> SessionTx<'a> {
if manifest.extend_candidates { if manifest.extend_candidates {
for (item, _) in found.iter() { for (item, _) in found.iter() {
// Extend by neighbours // Extend by neighbours
for neighbour in self.hnsw_get_neighbours( for (neighbour_key, _) in self.hnsw_get_neighbours(
item.0.clone(), item.0.clone(),
item.1, item.1,
item.2, item.2,
@ -512,35 +512,32 @@ impl<'a> SessionTx<'a> {
idx_table, idx_table,
false, false,
)? { )? {
let dist = self.hnsw_compare_vector( vec_cache.ensure_key(&neighbour_key, orig_table, self)?;
q, let dist = vec_cache.v_dist(q, &neighbour_key);
&neighbour.0,
neighbour.1,
neighbour.2,
manifest,
orig_table,
)?;
candidates.push( candidates.push(
(neighbour.0, neighbour.1, neighbour.2), (neighbour_key.0, neighbour_key.1, neighbour_key.2),
Reverse(OrderedFloat(dist)), Reverse(OrderedFloat(dist)),
); );
} }
} }
} }
while !candidates.is_empty() && ret.len() < m { while !candidates.is_empty() && ret.len() < m {
let (nearest_triple, Reverse(OrderedFloat(nearest_dist))) = candidates.pop().unwrap(); let (cand_key, Reverse(OrderedFloat(cand_dist_to_q))) = candidates.pop().unwrap();
match ret.peek() { let mut should_add = true;
Some((_, Reverse(OrderedFloat(dist)))) => { for (existing, _) in ret.iter() {
if nearest_dist < *dist { vec_cache.ensure_key(&cand_key, orig_table, self)?;
ret.push(nearest_triple, Reverse(OrderedFloat(nearest_dist))); vec_cache.ensure_key(&existing, orig_table, self)?;
} else if manifest.keep_pruned_connections { let dist_to_existing = vec_cache.k_dist(existing, &cand_key);
discarded.push(nearest_triple, Reverse(OrderedFloat(nearest_dist))); if dist_to_existing < cand_dist_to_q {
} should_add = false;
} break;
None => {
ret.push(nearest_triple, Reverse(OrderedFloat(nearest_dist)));
} }
} }
if should_add {
ret.push(cand_key, Reverse(OrderedFloat(cand_dist_to_q)));
} else if manifest.keep_pruned_connections {
discarded.push(cand_key, Reverse(OrderedFloat(cand_dist_to_q)));
}
} }
if manifest.keep_pruned_connections { if manifest.keep_pruned_connections {
while !discarded.is_empty() && ret.len() < m { while !discarded.is_empty() && ret.len() < m {
@ -556,13 +553,14 @@ impl<'a> SessionTx<'a> {
q: &Vector, q: &Vector,
ef: usize, ef: usize,
cur_level: i64, cur_level: i64,
manifest: &HnswIndexManifest,
orig_table: &RelationHandle, orig_table: &RelationHandle,
idx_table: &RelationHandle, idx_table: &RelationHandle,
found_nn: &mut PriorityQueue<(Tuple, usize, i32), OrderedFloat<f64>>, found_nn: &mut PriorityQueue<CompoundKey, OrderedFloat<f64>>,
vec_cache: &mut VectorCache,
) -> Result<()> { ) -> Result<()> {
let mut visited: BTreeSet<(Tuple, usize, i32)> = BTreeSet::new(); let mut visited: BTreeSet<CompoundKey> = BTreeSet::new();
let mut candidates: PriorityQueue<(Tuple, usize, i32), Reverse<OrderedFloat<f64>>> = // min queue
let mut candidates: PriorityQueue<CompoundKey, Reverse<OrderedFloat<f64>>> =
PriorityQueue::new(); PriorityQueue::new();
for item in found_nn.iter() { for item in found_nn.iter() {
@ -577,7 +575,7 @@ impl<'a> SessionTx<'a> {
break; break;
} }
// loop over each of the candidate's neighbors // loop over each of the candidate's neighbors
for neighbour_tetra in self.hnsw_get_neighbours( for (neighbour_key, _) in self.hnsw_get_neighbours(
candidate.0, candidate.0,
candidate.1, candidate.1,
candidate.2, candidate.2,
@ -585,30 +583,20 @@ impl<'a> SessionTx<'a> {
idx_table, idx_table,
false, false,
)? { )? {
let neighbour_triple = (neighbour_tetra.0, neighbour_tetra.1, neighbour_tetra.2); if visited.contains(&neighbour_key) {
if visited.contains(&neighbour_triple) {
continue; continue;
} }
let neighbour_dist = self.hnsw_compare_vector( vec_cache.ensure_key(&neighbour_key, orig_table, self)?;
q, let neighbour_dist = vec_cache.v_dist(q, &neighbour_key);
&neighbour_triple.0,
neighbour_triple.1,
neighbour_triple.2,
manifest,
orig_table,
)?;
let (_, OrderedFloat(cand_furtherest_dist)) = found_nn.peek().unwrap(); let (_, OrderedFloat(cand_furtherest_dist)) = found_nn.peek().unwrap();
if found_nn.len() < ef || neighbour_dist < *cand_furtherest_dist { if found_nn.len() < ef || neighbour_dist < *cand_furtherest_dist {
candidates.push( candidates.push(neighbour_key.clone(), Reverse(OrderedFloat(neighbour_dist)));
neighbour_triple.clone(), found_nn.push(neighbour_key.clone(), OrderedFloat(neighbour_dist));
Reverse(OrderedFloat(neighbour_dist)),
);
found_nn.push(neighbour_triple.clone(), OrderedFloat(neighbour_dist));
if found_nn.len() > ef { if found_nn.len() > ef {
found_nn.pop(); found_nn.pop();
} }
} }
visited.insert(neighbour_triple); visited.insert(neighbour_key);
} }
} }
@ -622,7 +610,7 @@ impl<'a> SessionTx<'a> {
level: i64, level: i64,
idx_handle: &RelationHandle, idx_handle: &RelationHandle,
include_deleted: bool, include_deleted: bool,
) -> Result<impl Iterator<Item = (Tuple, usize, i32, f64)> + 'b> { ) -> Result<impl Iterator<Item = (CompoundKey, f64)> + 'b> {
let mut start_tuple = Vec::with_capacity(cand_key.len() + 3); let mut start_tuple = Vec::with_capacity(cand_key.len() + 3);
start_tuple.push(DataValue::from(level)); start_tuple.push(DataValue::from(level));
start_tuple.extend_from_slice(&cand_key); start_tuple.extend_from_slice(&cand_key);
@ -642,9 +630,7 @@ impl<'a> SessionTx<'a> {
} else { } else {
if include_deleted { if include_deleted {
return Some(( return Some((
key_slice, (key_slice, key_idx, key_subidx),
key_idx,
key_subidx,
tuple[2 * key_len + 5].get_float().unwrap(), tuple[2 * key_len + 5].get_float().unwrap(),
)); ));
} }
@ -653,9 +639,7 @@ impl<'a> SessionTx<'a> {
None None
} else { } else {
Some(( Some((
key_slice, (key_slice, key_idx, key_subidx),
key_idx,
key_subidx,
tuple[2 * key_len + 5].get_float().unwrap(), tuple[2 * key_len + 5].get_float().unwrap(),
)) ))
} }
@ -801,22 +785,22 @@ impl<'a> SessionTx<'a> {
.hnsw_get_neighbours(tuple_key.to_vec(), idx, subidx, layer, idx_table, true)? .hnsw_get_neighbours(tuple_key.to_vec(), idx, subidx, layer, idx_table, true)?
.collect_vec(); .collect_vec();
encountered_singletons |= neigbours.is_empty(); encountered_singletons |= neigbours.is_empty();
for (neighbour_key, neighbour_idx, neighbour_subidx, _) in neigbours { for (neighbour_key, _) in neigbours {
// REMARK: this still has some probability of disconnecting the graph. // REMARK: this still has some probability of disconnecting the graph.
// Should we accept that as a consequence of the probabilistic nature of the algorithm? // Should we accept that as a consequence of the probabilistic nature of the algorithm?
let mut out_key = vec![DataValue::from(layer)]; let mut out_key = vec![DataValue::from(layer)];
out_key.extend_from_slice(tuple_key); out_key.extend_from_slice(tuple_key);
out_key.push(DataValue::from(idx as i64)); out_key.push(DataValue::from(idx as i64));
out_key.push(DataValue::from(subidx as i64)); out_key.push(DataValue::from(subidx as i64));
out_key.extend_from_slice(&neighbour_key); out_key.extend_from_slice(&neighbour_key.0);
out_key.push(DataValue::from(neighbour_idx as i64)); out_key.push(DataValue::from(neighbour_key.1 as i64));
out_key.push(DataValue::from(neighbour_subidx as i64)); out_key.push(DataValue::from(neighbour_key.2 as i64));
let out_key_bytes = idx_table.encode_key_for_store(&out_key, Default::default())?; let out_key_bytes = idx_table.encode_key_for_store(&out_key, Default::default())?;
self.store_tx.del(&out_key_bytes)?; self.store_tx.del(&out_key_bytes)?;
let mut in_key = vec![DataValue::from(layer)]; let mut in_key = vec![DataValue::from(layer)];
in_key.extend_from_slice(&neighbour_key); in_key.extend_from_slice(&neighbour_key.0);
in_key.push(DataValue::from(neighbour_idx as i64)); in_key.push(DataValue::from(neighbour_key.1 as i64));
in_key.push(DataValue::from(neighbour_subidx as i64)); in_key.push(DataValue::from(neighbour_key.2 as i64));
in_key.extend_from_slice(tuple_key); in_key.extend_from_slice(tuple_key);
in_key.push(DataValue::from(idx as i64)); in_key.push(DataValue::from(idx as i64));
in_key.push(DataValue::from(subidx as i64)); in_key.push(DataValue::from(subidx as i64));
@ -824,9 +808,9 @@ impl<'a> SessionTx<'a> {
self.store_tx.del(&in_key_bytes)?; self.store_tx.del(&in_key_bytes)?;
let mut neighbour_self_key = vec![DataValue::from(layer)]; let mut neighbour_self_key = vec![DataValue::from(layer)];
for _ in 0..2 { for _ in 0..2 {
neighbour_self_key.extend_from_slice(&neighbour_key); neighbour_self_key.extend_from_slice(&neighbour_key.0);
neighbour_self_key.push(DataValue::from(neighbour_idx as i64)); neighbour_self_key.push(DataValue::from(neighbour_key.1 as i64));
neighbour_self_key.push(DataValue::from(neighbour_subidx as i64)); neighbour_self_key.push(DataValue::from(neighbour_key.2 as i64));
} }
let neighbour_val_bytes = self let neighbour_val_bytes = self
.store_tx .store_tx
@ -903,6 +887,11 @@ impl<'a> SessionTx<'a> {
(Vector::F64(v), VecElementType::F32) => Vector::F32(v.mapv(|x| x as f32)), (Vector::F64(v), VecElementType::F32) => Vector::F32(v.mapv(|x| x as f32)),
}; };
let mut vec_cache = VectorCache {
cache: Default::default(),
distance: config.manifest.distance,
};
let ep_res = config let ep_res = config
.idx_handle .idx_handle
.scan_bounded_prefix( .scan_bounded_prefix(
@ -915,42 +904,37 @@ impl<'a> SessionTx<'a> {
if let Some(ep) = ep_res { if let Some(ep) = ep_res {
let ep = ep?; let ep = ep?;
let bottom_level = ep[0].get_int().unwrap(); let bottom_level = ep[0].get_int().unwrap();
let ep_key = ep[1..config.base_handle.metadata.keys.len() + 1].to_vec(); let ep_t_key = ep[1..config.base_handle.metadata.keys.len() + 1].to_vec();
let ep_idx = ep[config.base_handle.metadata.keys.len() + 1] let ep_idx = ep[config.base_handle.metadata.keys.len() + 1]
.get_int() .get_int()
.unwrap() as usize; .unwrap() as usize;
let ep_subidx = ep[config.base_handle.metadata.keys.len() + 2] let ep_subidx = ep[config.base_handle.metadata.keys.len() + 2]
.get_int() .get_int()
.unwrap() as i32; .unwrap() as i32;
let ep_distance = self.hnsw_compare_vector( let ep_key = (ep_t_key, ep_idx, ep_subidx);
&q, vec_cache.ensure_key(&ep_key, &config.base_handle, self)?;
&ep_key, let ep_distance = vec_cache.v_dist(&q, &ep_key);
ep_idx,
ep_subidx,
&config.manifest,
&config.base_handle,
)?;
let mut found_nn = PriorityQueue::new(); let mut found_nn = PriorityQueue::new();
found_nn.push((ep_key, ep_idx, ep_subidx), OrderedFloat(ep_distance)); found_nn.push(ep_key, OrderedFloat(ep_distance));
for current_level in bottom_level..0 { for current_level in bottom_level..0 {
self.hnsw_search_level( self.hnsw_search_level(
&q, &q,
1, 1,
current_level, current_level,
&config.manifest,
&config.base_handle, &config.base_handle,
&config.idx_handle, &config.idx_handle,
&mut found_nn, &mut found_nn,
&mut vec_cache,
)?; )?;
} }
self.hnsw_search_level( self.hnsw_search_level(
&q, &q,
config.ef, config.ef,
0, 0,
&config.manifest,
&config.base_handle, &config.base_handle,
&config.idx_handle, &config.idx_handle,
&mut found_nn, &mut found_nn,
&mut vec_cache,
)?; )?;
if found_nn.is_empty() { if found_nn.is_empty() {
return Ok(vec![]); return Ok(vec![]);
@ -964,9 +948,7 @@ impl<'a> SessionTx<'a> {
let mut ret = vec![]; let mut ret = vec![];
while let Some(((cand_tuple, cand_idx, cand_subidx), OrderedFloat(distance))) = while let Some((cand_key, OrderedFloat(distance))) = found_nn.pop() {
found_nn.pop()
{
if let Some(r) = config.radius { if let Some(r) = config.radius {
if distance > r { if distance > r {
continue; continue;
@ -975,40 +957,40 @@ impl<'a> SessionTx<'a> {
let mut cand_tuple = config let mut cand_tuple = config
.base_handle .base_handle
.get(self, &cand_tuple)? .get(self, &cand_key.0)?
.ok_or_else(|| miette!("corrupted index"))?; .ok_or_else(|| miette!("corrupted index"))?;
if config.bind_field.is_some() { if config.bind_field.is_some() {
let field = if cand_idx as usize >= config.base_handle.metadata.keys.len() { let field = if cand_key.1 as usize >= config.base_handle.metadata.keys.len() {
config.base_handle.metadata.keys[cand_idx as usize] config.base_handle.metadata.keys[cand_key.1 as usize]
.name .name
.clone() .clone()
} else { } else {
config.base_handle.metadata.non_keys config.base_handle.metadata.non_keys
[cand_idx as usize - config.base_handle.metadata.keys.len()] [cand_key.1 as usize - config.base_handle.metadata.keys.len()]
.name .name
.clone() .clone()
}; };
cand_tuple.push(DataValue::Str(field)); cand_tuple.push(DataValue::Str(field));
} }
if config.bind_field_idx.is_some() { if config.bind_field_idx.is_some() {
cand_tuple.push(if cand_subidx < 0 { cand_tuple.push(if cand_key.2 < 0 {
DataValue::Null DataValue::Null
} else { } else {
DataValue::from(cand_subidx as i64) DataValue::from(cand_key.2 as i64)
}); });
} }
if config.bind_distance.is_some() { if config.bind_distance.is_some() {
cand_tuple.push(DataValue::from(distance)); cand_tuple.push(DataValue::from(distance));
} }
if config.bind_vector.is_some() { if config.bind_vector.is_some() {
let vec = if cand_subidx < 0 { let vec = if cand_key.2 < 0 {
match &cand_tuple[cand_idx] { match &cand_tuple[cand_key.1] {
DataValue::List(v) => v[cand_subidx as usize].clone(), DataValue::List(v) => v[cand_key.2 as usize].clone(),
_ => bail!("corrupted index"), _ => bail!("corrupted index"),
} }
} else { } else {
cand_tuple[cand_idx].clone() cand_tuple[cand_key.1].clone()
}; };
cand_tuple.push(vec); cand_tuple.push(vec);
} }

@ -863,7 +863,7 @@ fn test_insertions() {
db.run_script(r"?[k, v] := *a{k, v}", Default::default()) db.run_script(r"?[k, v] := *a{k, v}", Default::default())
.unwrap(); .unwrap();
db.run_script( db.run_script(
r"::hnsw create a:i {fields: [v], dim: 100, ef: 200, m: 50}", r"::hnsw create a:i {fields: [v], dim: 100, ef: 16, m: 32}",
Default::default(), Default::default(),
) )
.unwrap(); .unwrap();
@ -871,10 +871,10 @@ fn test_insertions() {
.unwrap(); .unwrap();
db.run_script(r"?[k] <- [[1]] :put a {k}", Default::default()) db.run_script(r"?[k] <- [[1]] :put a {k}", Default::default())
.unwrap(); .unwrap();
db.run_script(r"?[k] := k in int_range(1000) :put a {k}", Default::default()).unwrap(); db.run_script(r"?[k] := k in int_range(10000) :put a {k}", Default::default()).unwrap();
let res = db let res = db
.run_script( .run_script(
r"?[dist, k] := ~a:i{k | query: v, bind_distance: dist, k:10, ef: 200}, *a{k: 333, v}", r"?[dist, k] := ~a:i{k | query: v, bind_distance: dist, k:10, ef: 5}, *a{k: 8888, v}",
Default::default(), Default::default(),
) )
.unwrap(); .unwrap();

Loading…
Cancel
Save