|
|
|
@ -14,7 +14,7 @@ use crate::data::value::Vector;
|
|
|
|
|
use crate::parse::sys::HnswDistance;
|
|
|
|
|
use crate::runtime::relation::RelationHandle;
|
|
|
|
|
use crate::runtime::transact::SessionTx;
|
|
|
|
|
use crate::{decode_tuple_from_kv, DataValue, SourceSpan};
|
|
|
|
|
use crate::{DataValue, SourceSpan};
|
|
|
|
|
use itertools::Itertools;
|
|
|
|
|
use miette::{bail, miette, Result};
|
|
|
|
|
use ordered_float::OrderedFloat;
|
|
|
|
@ -22,7 +22,7 @@ use priority_queue::PriorityQueue;
|
|
|
|
|
use rand::Rng;
|
|
|
|
|
use smartstring::{LazyCompact, SmartString};
|
|
|
|
|
use std::cmp::{max, Reverse};
|
|
|
|
|
use std::collections::BTreeSet;
|
|
|
|
|
use std::collections::{BTreeMap, BTreeSet};
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)]
|
|
|
|
|
pub(crate) struct HnswIndexManifest {
|
|
|
|
@ -50,49 +50,19 @@ impl HnswIndexManifest {
|
|
|
|
|
// the level is the largest integer smaller than r
|
|
|
|
|
-(r.floor() as i64)
|
|
|
|
|
}
|
|
|
|
|
fn get_vector(&self, tuple: &Tuple, idx: usize, sub_idx: i32) -> Result<Vector> {
|
|
|
|
|
let field = tuple.get(idx).unwrap();
|
|
|
|
|
if sub_idx >= 0 {
|
|
|
|
|
match field {
|
|
|
|
|
DataValue::List(l) => match l.get(sub_idx as usize) {
|
|
|
|
|
Some(DataValue::Vec(v)) => Ok(v.clone()),
|
|
|
|
|
_ => bail!(
|
|
|
|
|
"Cannot extract vector from {} for sub index {}",
|
|
|
|
|
field,
|
|
|
|
|
sub_idx
|
|
|
|
|
),
|
|
|
|
|
},
|
|
|
|
|
_ => bail!("Cannot interpret {} as list", field),
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
match field {
|
|
|
|
|
DataValue::Vec(v) => Ok(v.clone()),
|
|
|
|
|
_ => bail!("Cannot interpret {} as vector", field),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
fn get_distance(&self, q: &Vector, tuple: &Tuple, idx: usize, sub_idx: i32) -> Result<f64> {
|
|
|
|
|
let field = tuple.get(idx).unwrap();
|
|
|
|
|
let target = if sub_idx >= 0 {
|
|
|
|
|
match field {
|
|
|
|
|
DataValue::List(l) => match l.get(sub_idx as usize) {
|
|
|
|
|
Some(DataValue::Vec(v)) => v,
|
|
|
|
|
_ => bail!(
|
|
|
|
|
"Cannot extract vector from {} for sub index {}",
|
|
|
|
|
field,
|
|
|
|
|
sub_idx
|
|
|
|
|
),
|
|
|
|
|
},
|
|
|
|
|
_ => bail!("Cannot interpret {} as list", field),
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
match field {
|
|
|
|
|
DataValue::Vec(v) => v,
|
|
|
|
|
_ => bail!("Cannot interpret {} as vector", field),
|
|
|
|
|
|
|
|
|
|
// Identifies one indexed vector: (key of the owning tuple, index of the field
// within that tuple, sub-index). As used by `VectorCache::ensure_key`, a
// non-negative sub-index selects an element of a list-valued field, while a
// negative one means the field itself is the vector.
type CompoundKey = (Tuple, usize, i32);
|
|
|
|
|
|
|
|
|
|
// In-memory cache of vectors keyed by `CompoundKey`, bundled with the distance
// metric used to compare them.
//
// * `cache`    - vectors stored so far, looked up by compound key.
// * `distance` - the `HnswDistance` variant that `dist` matches on to pick the
//                distance formula.
struct VectorCache {
|
|
|
|
|
cache: BTreeMap<CompoundKey, Vector>,
|
|
|
|
|
distance: HnswDistance,
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
Ok(match self.distance {
|
|
|
|
|
HnswDistance::L2 => match (q, target) {
|
|
|
|
|
|
|
|
|
|
impl VectorCache {
|
|
|
|
|
fn dist(&self, v1: &Vector, v2: &Vector) -> f64 {
|
|
|
|
|
match self.distance {
|
|
|
|
|
HnswDistance::L2 => match (v1, v2) {
|
|
|
|
|
(Vector::F32(a), Vector::F32(b)) => {
|
|
|
|
|
let diff = a - b;
|
|
|
|
|
diff.dot(&diff) as f64
|
|
|
|
@ -101,13 +71,9 @@ impl HnswIndexManifest {
|
|
|
|
|
let diff = a - b;
|
|
|
|
|
diff.dot(&diff)
|
|
|
|
|
}
|
|
|
|
|
_ => bail!(
|
|
|
|
|
"Cannot compute L2 distance between {:?} and {:?}",
|
|
|
|
|
q,
|
|
|
|
|
target
|
|
|
|
|
),
|
|
|
|
|
_ => panic!("Cannot compute L2 distance between {:?} and {:?}", v1, v2),
|
|
|
|
|
},
|
|
|
|
|
HnswDistance::Cosine => match (q, target) {
|
|
|
|
|
HnswDistance::Cosine => match (v1, v2) {
|
|
|
|
|
(Vector::F32(a), Vector::F32(b)) => {
|
|
|
|
|
let a_norm = a.dot(a) as f64;
|
|
|
|
|
let b_norm = b.dot(b) as f64;
|
|
|
|
@ -120,13 +86,12 @@ impl HnswIndexManifest {
|
|
|
|
|
let dot = a.dot(b);
|
|
|
|
|
1.0 - dot / (a_norm * b_norm).sqrt()
|
|
|
|
|
}
|
|
|
|
|
_ => bail!(
|
|
|
|
|
_ => panic!(
|
|
|
|
|
"Cannot compute cosine distance between {:?} and {:?}",
|
|
|
|
|
q,
|
|
|
|
|
target
|
|
|
|
|
v1, v2
|
|
|
|
|
),
|
|
|
|
|
},
|
|
|
|
|
HnswDistance::InnerProduct => match (q, target) {
|
|
|
|
|
HnswDistance::InnerProduct => match (v1, v2) {
|
|
|
|
|
(Vector::F32(a), Vector::F32(b)) => {
|
|
|
|
|
let dot = a.dot(b);
|
|
|
|
|
1. - dot as f64
|
|
|
|
@ -135,13 +100,51 @@ impl HnswIndexManifest {
|
|
|
|
|
let dot = a.dot(b);
|
|
|
|
|
1. - dot as f64
|
|
|
|
|
}
|
|
|
|
|
_ => bail!(
|
|
|
|
|
"Cannot compute inner product between {:?} and {:?}",
|
|
|
|
|
q,
|
|
|
|
|
target
|
|
|
|
|
),
|
|
|
|
|
_ => panic!("Cannot compute inner product between {:?} and {:?}", v1, v2),
|
|
|
|
|
},
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
/// Distance between the vector `v` and the cached vector stored under `key`,
/// computed with this cache's distance metric.
///
/// # Panics
/// Panics if `key` is not present in the cache.
fn v_dist(&self, v: &Vector, key: &CompoundKey) -> f64 {
    self.dist(v, self.cache.get(key).unwrap())
}
|
|
|
|
|
/// Distance between the two cached vectors stored under `k1` and `k2`,
/// computed with this cache's distance metric.
///
/// # Panics
/// Panics if either key is not present in the cache.
fn k_dist(&self, k1: &CompoundKey, k2: &CompoundKey) -> f64 {
    let lhs = self.cache.get(k1).unwrap();
    let rhs = self.cache.get(k2).unwrap();
    self.dist(lhs, rhs)
}
|
|
|
|
|
/// Borrows the cached vector stored under `key`.
///
/// # Panics
/// Panics if `key` is not present in the cache.
fn get_key(&self, key: &CompoundKey) -> &Vector {
    let cached = self.cache.get(key);
    cached.unwrap()
}
|
|
|
|
|
fn ensure_key(
|
|
|
|
|
&mut self,
|
|
|
|
|
key: &CompoundKey,
|
|
|
|
|
handle: &RelationHandle,
|
|
|
|
|
tx: &SessionTx<'_>,
|
|
|
|
|
) -> Result<()> {
|
|
|
|
|
if !self.cache.contains_key(key) {
|
|
|
|
|
match handle.get(tx, &key.0)? {
|
|
|
|
|
Some(tuple) => {
|
|
|
|
|
let mut field = &tuple[key.1];
|
|
|
|
|
if key.2 >= 0 {
|
|
|
|
|
match field {
|
|
|
|
|
DataValue::List(l) => {
|
|
|
|
|
field = &l[key.2 as usize];
|
|
|
|
|
}
|
|
|
|
|
_ => bail!("Cannot interpret {} as list", field),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
match field {
|
|
|
|
|
DataValue::Vec(v) => {
|
|
|
|
|
self.cache.insert(key.clone(), v.clone());
|
|
|
|
|
}
|
|
|
|
|
_ => bail!("Cannot interpret {} as vector", field),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
None => bail!("Cannot find tuple {:?}", key.0),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -156,7 +159,12 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
orig_table: &RelationHandle,
|
|
|
|
|
idx_table: &RelationHandle,
|
|
|
|
|
) -> Result<()> {
|
|
|
|
|
let mut vec_cache = VectorCache {
|
|
|
|
|
cache: BTreeMap::new(),
|
|
|
|
|
distance: manifest.distance,
|
|
|
|
|
};
|
|
|
|
|
let tuple_key = &tuple[..orig_table.metadata.keys.len()];
|
|
|
|
|
vec_cache.cache.insert((tuple_key.to_vec(), idx, subidx), q.clone());
|
|
|
|
|
let hash = q.get_hash();
|
|
|
|
|
let mut canary_tuple = vec![DataValue::from(0)];
|
|
|
|
|
for _ in 0..2 {
|
|
|
|
@ -185,13 +193,15 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
let ep = ep?;
|
|
|
|
|
// bottom level since we are going up
|
|
|
|
|
let bottom_level = ep[0].get_int().unwrap();
|
|
|
|
|
let ep_key = ep[1..orig_table.metadata.keys.len() + 1].to_vec();
|
|
|
|
|
let ep_t_key = ep[1..orig_table.metadata.keys.len() + 1].to_vec();
|
|
|
|
|
let ep_idx = ep[orig_table.metadata.keys.len() + 1].get_int().unwrap() as usize;
|
|
|
|
|
let ep_subidx = ep[orig_table.metadata.keys.len() + 2].get_int().unwrap() as i32;
|
|
|
|
|
let ep_distance =
|
|
|
|
|
self.hnsw_compare_vector(q, &ep_key, idx, subidx, manifest, orig_table)?;
|
|
|
|
|
let ep_key = (ep_t_key, ep_idx, ep_subidx);
|
|
|
|
|
vec_cache.ensure_key(&ep_key, orig_table, self)?;
|
|
|
|
|
let ep_distance = vec_cache.v_dist(q, &ep_key);
|
|
|
|
|
// max queue
|
|
|
|
|
let mut found_nn = PriorityQueue::new();
|
|
|
|
|
found_nn.push((ep_key, ep_idx, ep_subidx), OrderedFloat(ep_distance));
|
|
|
|
|
found_nn.push(ep_key, OrderedFloat(ep_distance));
|
|
|
|
|
let target_level = manifest.get_random_level();
|
|
|
|
|
if target_level < bottom_level {
|
|
|
|
|
// this becomes the entry point
|
|
|
|
@ -211,10 +221,10 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
q,
|
|
|
|
|
1,
|
|
|
|
|
current_level,
|
|
|
|
|
manifest,
|
|
|
|
|
orig_table,
|
|
|
|
|
idx_table,
|
|
|
|
|
&mut found_nn,
|
|
|
|
|
&mut vec_cache,
|
|
|
|
|
)?;
|
|
|
|
|
}
|
|
|
|
|
let mut self_tuple_key = Vec::with_capacity(orig_table.metadata.keys.len() * 2 + 5);
|
|
|
|
@ -239,10 +249,10 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
q,
|
|
|
|
|
manifest.ef_construction,
|
|
|
|
|
current_level,
|
|
|
|
|
manifest,
|
|
|
|
|
orig_table,
|
|
|
|
|
idx_table,
|
|
|
|
|
&mut found_nn,
|
|
|
|
|
&mut vec_cache,
|
|
|
|
|
)?;
|
|
|
|
|
// add bidirectional links to the nearest neighbors
|
|
|
|
|
let neighbours = self.hnsw_select_neighbours_heuristic(
|
|
|
|
@ -253,6 +263,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
manifest,
|
|
|
|
|
idx_table,
|
|
|
|
|
orig_table,
|
|
|
|
|
&mut vec_cache,
|
|
|
|
|
)?;
|
|
|
|
|
// add self-link
|
|
|
|
|
self_tuple_key[0] = DataValue::from(current_level);
|
|
|
|
@ -336,6 +347,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
manifest,
|
|
|
|
|
idx_table,
|
|
|
|
|
orig_table,
|
|
|
|
|
&mut vec_cache,
|
|
|
|
|
)?;
|
|
|
|
|
}
|
|
|
|
|
// update degree
|
|
|
|
@ -373,24 +385,16 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
manifest: &HnswIndexManifest,
|
|
|
|
|
idx_table: &RelationHandle,
|
|
|
|
|
orig_table: &RelationHandle,
|
|
|
|
|
vec_cache: &mut VectorCache,
|
|
|
|
|
) -> Result<usize> {
|
|
|
|
|
let orig_key = orig_table.encode_key_for_store(target, Default::default())?;
|
|
|
|
|
let orig_val = match self.store_tx.get(&orig_key, false)? {
|
|
|
|
|
Some(bytes) => bytes,
|
|
|
|
|
None => {
|
|
|
|
|
bail!("Indexed vector not found, this signifies a bug in the index implementation")
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
let orig_tuple = decode_tuple_from_kv(&orig_key, &orig_val);
|
|
|
|
|
let vec = manifest.get_vector(&orig_tuple, idx, sub_idx)?;
|
|
|
|
|
let c_key = (target.to_vec(), idx, sub_idx);
|
|
|
|
|
vec_cache.ensure_key(&c_key, orig_table, self)?;
|
|
|
|
|
let vec = vec_cache.get_key(&c_key).clone();
|
|
|
|
|
let mut candidates = PriorityQueue::new();
|
|
|
|
|
for neighbour in
|
|
|
|
|
for (neighbour_key, neighbour_dist) in
|
|
|
|
|
self.hnsw_get_neighbours(target.to_vec(), idx, sub_idx, level, idx_table, false)?
|
|
|
|
|
{
|
|
|
|
|
candidates.push(
|
|
|
|
|
(neighbour.0, neighbour.1, neighbour.2),
|
|
|
|
|
OrderedFloat(neighbour.3),
|
|
|
|
|
);
|
|
|
|
|
candidates.push(neighbour_key, OrderedFloat(neighbour_dist));
|
|
|
|
|
}
|
|
|
|
|
let new_candidates = self.hnsw_select_neighbours_heuristic(
|
|
|
|
|
&vec,
|
|
|
|
@ -400,6 +404,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
manifest,
|
|
|
|
|
idx_table,
|
|
|
|
|
orig_table,
|
|
|
|
|
vec_cache,
|
|
|
|
|
)?;
|
|
|
|
|
let mut old_candidate_set = BTreeSet::new();
|
|
|
|
|
for (old, _) in &candidates {
|
|
|
|
@ -467,35 +472,30 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
|
|
|
|
|
Ok(new_degree)
|
|
|
|
|
}
|
|
|
|
|
fn hnsw_compare_vector(
|
|
|
|
|
&self,
|
|
|
|
|
q: &Vector,
|
|
|
|
|
target_key: &[DataValue],
|
|
|
|
|
target_idx: usize,
|
|
|
|
|
target_subidx: i32,
|
|
|
|
|
manifest: &HnswIndexManifest,
|
|
|
|
|
orig_table: &RelationHandle,
|
|
|
|
|
) -> Result<f64> {
|
|
|
|
|
let target_key_bytes = orig_table.encode_key_for_store(target_key, Default::default())?;
|
|
|
|
|
let bytes = match self.store_tx.get(&target_key_bytes, false)? {
|
|
|
|
|
Some(bytes) => bytes,
|
|
|
|
|
None => bail!("Indexed data not found, this signifies a bug in the index."),
|
|
|
|
|
};
|
|
|
|
|
let target_tuple = decode_tuple_from_kv(&target_key_bytes, &bytes);
|
|
|
|
|
manifest.get_distance(q, &target_tuple, target_idx, target_subidx)
|
|
|
|
|
}
|
|
|
|
|
fn hnsw_select_neighbours_heuristic(
|
|
|
|
|
&self,
|
|
|
|
|
q: &Vector,
|
|
|
|
|
found: &PriorityQueue<(Tuple, usize, i32), OrderedFloat<f64>>,
|
|
|
|
|
found: &PriorityQueue<CompoundKey, OrderedFloat<f64>>,
|
|
|
|
|
m: usize,
|
|
|
|
|
level: i64,
|
|
|
|
|
manifest: &HnswIndexManifest,
|
|
|
|
|
idx_table: &RelationHandle,
|
|
|
|
|
orig_table: &RelationHandle,
|
|
|
|
|
) -> Result<PriorityQueue<(Tuple, usize, i32), Reverse<OrderedFloat<f64>>>> {
|
|
|
|
|
vec_cache: &mut VectorCache,
|
|
|
|
|
) -> Result<PriorityQueue<CompoundKey, Reverse<OrderedFloat<f64>>>> {
|
|
|
|
|
let mut candidates = PriorityQueue::new();
|
|
|
|
|
let mut ret: PriorityQueue<_, Reverse<OrderedFloat<_>>> = PriorityQueue::new();
|
|
|
|
|
// Simple non-heuristic selection
|
|
|
|
|
// let mut temp = found.clone();
|
|
|
|
|
// while temp.len() > m {
|
|
|
|
|
// temp.pop();
|
|
|
|
|
// }
|
|
|
|
|
// for (item, dist) in temp.iter() {
|
|
|
|
|
// candidates.push(item.clone(), Reverse(*dist));
|
|
|
|
|
// }
|
|
|
|
|
// return Ok(candidates);
|
|
|
|
|
// End of simple non-heuristic selection
|
|
|
|
|
|
|
|
|
|
let mut ret: PriorityQueue<CompoundKey, Reverse<OrderedFloat<_>>> = PriorityQueue::new();
|
|
|
|
|
let mut discarded: PriorityQueue<_, Reverse<OrderedFloat<_>>> = PriorityQueue::new();
|
|
|
|
|
for (item, dist) in found.iter() {
|
|
|
|
|
// Add to candidates
|
|
|
|
@ -504,7 +504,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
if manifest.extend_candidates {
|
|
|
|
|
for (item, _) in found.iter() {
|
|
|
|
|
// Extend by neighbours
|
|
|
|
|
for neighbour in self.hnsw_get_neighbours(
|
|
|
|
|
for (neighbour_key, _) in self.hnsw_get_neighbours(
|
|
|
|
|
item.0.clone(),
|
|
|
|
|
item.1,
|
|
|
|
|
item.2,
|
|
|
|
@ -512,34 +512,31 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
idx_table,
|
|
|
|
|
false,
|
|
|
|
|
)? {
|
|
|
|
|
let dist = self.hnsw_compare_vector(
|
|
|
|
|
q,
|
|
|
|
|
&neighbour.0,
|
|
|
|
|
neighbour.1,
|
|
|
|
|
neighbour.2,
|
|
|
|
|
manifest,
|
|
|
|
|
orig_table,
|
|
|
|
|
)?;
|
|
|
|
|
vec_cache.ensure_key(&neighbour_key, orig_table, self)?;
|
|
|
|
|
let dist = vec_cache.v_dist(q, &neighbour_key);
|
|
|
|
|
candidates.push(
|
|
|
|
|
(neighbour.0, neighbour.1, neighbour.2),
|
|
|
|
|
(neighbour_key.0, neighbour_key.1, neighbour_key.2),
|
|
|
|
|
Reverse(OrderedFloat(dist)),
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
while !candidates.is_empty() && ret.len() < m {
|
|
|
|
|
let (nearest_triple, Reverse(OrderedFloat(nearest_dist))) = candidates.pop().unwrap();
|
|
|
|
|
match ret.peek() {
|
|
|
|
|
Some((_, Reverse(OrderedFloat(dist)))) => {
|
|
|
|
|
if nearest_dist < *dist {
|
|
|
|
|
ret.push(nearest_triple, Reverse(OrderedFloat(nearest_dist)));
|
|
|
|
|
} else if manifest.keep_pruned_connections {
|
|
|
|
|
discarded.push(nearest_triple, Reverse(OrderedFloat(nearest_dist)));
|
|
|
|
|
}
|
|
|
|
|
let (cand_key, Reverse(OrderedFloat(cand_dist_to_q))) = candidates.pop().unwrap();
|
|
|
|
|
let mut should_add = true;
|
|
|
|
|
for (existing, _) in ret.iter() {
|
|
|
|
|
vec_cache.ensure_key(&cand_key, orig_table, self)?;
|
|
|
|
|
vec_cache.ensure_key(&existing, orig_table, self)?;
|
|
|
|
|
let dist_to_existing = vec_cache.k_dist(existing, &cand_key);
|
|
|
|
|
if dist_to_existing < cand_dist_to_q {
|
|
|
|
|
should_add = false;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
None => {
|
|
|
|
|
ret.push(nearest_triple, Reverse(OrderedFloat(nearest_dist)));
|
|
|
|
|
}
|
|
|
|
|
if should_add {
|
|
|
|
|
ret.push(cand_key, Reverse(OrderedFloat(cand_dist_to_q)));
|
|
|
|
|
} else if manifest.keep_pruned_connections {
|
|
|
|
|
discarded.push(cand_key, Reverse(OrderedFloat(cand_dist_to_q)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if manifest.keep_pruned_connections {
|
|
|
|
@ -556,13 +553,14 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
q: &Vector,
|
|
|
|
|
ef: usize,
|
|
|
|
|
cur_level: i64,
|
|
|
|
|
manifest: &HnswIndexManifest,
|
|
|
|
|
orig_table: &RelationHandle,
|
|
|
|
|
idx_table: &RelationHandle,
|
|
|
|
|
found_nn: &mut PriorityQueue<(Tuple, usize, i32), OrderedFloat<f64>>,
|
|
|
|
|
found_nn: &mut PriorityQueue<CompoundKey, OrderedFloat<f64>>,
|
|
|
|
|
vec_cache: &mut VectorCache,
|
|
|
|
|
) -> Result<()> {
|
|
|
|
|
let mut visited: BTreeSet<(Tuple, usize, i32)> = BTreeSet::new();
|
|
|
|
|
let mut candidates: PriorityQueue<(Tuple, usize, i32), Reverse<OrderedFloat<f64>>> =
|
|
|
|
|
let mut visited: BTreeSet<CompoundKey> = BTreeSet::new();
|
|
|
|
|
// min queue
|
|
|
|
|
let mut candidates: PriorityQueue<CompoundKey, Reverse<OrderedFloat<f64>>> =
|
|
|
|
|
PriorityQueue::new();
|
|
|
|
|
|
|
|
|
|
for item in found_nn.iter() {
|
|
|
|
@ -577,7 +575,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
// loop over each of the candidate's neighbors
|
|
|
|
|
for neighbour_tetra in self.hnsw_get_neighbours(
|
|
|
|
|
for (neighbour_key, _) in self.hnsw_get_neighbours(
|
|
|
|
|
candidate.0,
|
|
|
|
|
candidate.1,
|
|
|
|
|
candidate.2,
|
|
|
|
@ -585,30 +583,20 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
idx_table,
|
|
|
|
|
false,
|
|
|
|
|
)? {
|
|
|
|
|
let neighbour_triple = (neighbour_tetra.0, neighbour_tetra.1, neighbour_tetra.2);
|
|
|
|
|
if visited.contains(&neighbour_triple) {
|
|
|
|
|
if visited.contains(&neighbour_key) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
let neighbour_dist = self.hnsw_compare_vector(
|
|
|
|
|
q,
|
|
|
|
|
&neighbour_triple.0,
|
|
|
|
|
neighbour_triple.1,
|
|
|
|
|
neighbour_triple.2,
|
|
|
|
|
manifest,
|
|
|
|
|
orig_table,
|
|
|
|
|
)?;
|
|
|
|
|
vec_cache.ensure_key(&neighbour_key, orig_table, self)?;
|
|
|
|
|
let neighbour_dist = vec_cache.v_dist(q, &neighbour_key);
|
|
|
|
|
let (_, OrderedFloat(cand_furtherest_dist)) = found_nn.peek().unwrap();
|
|
|
|
|
if found_nn.len() < ef || neighbour_dist < *cand_furtherest_dist {
|
|
|
|
|
candidates.push(
|
|
|
|
|
neighbour_triple.clone(),
|
|
|
|
|
Reverse(OrderedFloat(neighbour_dist)),
|
|
|
|
|
);
|
|
|
|
|
found_nn.push(neighbour_triple.clone(), OrderedFloat(neighbour_dist));
|
|
|
|
|
candidates.push(neighbour_key.clone(), Reverse(OrderedFloat(neighbour_dist)));
|
|
|
|
|
found_nn.push(neighbour_key.clone(), OrderedFloat(neighbour_dist));
|
|
|
|
|
if found_nn.len() > ef {
|
|
|
|
|
found_nn.pop();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
visited.insert(neighbour_triple);
|
|
|
|
|
visited.insert(neighbour_key);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -622,7 +610,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
level: i64,
|
|
|
|
|
idx_handle: &RelationHandle,
|
|
|
|
|
include_deleted: bool,
|
|
|
|
|
) -> Result<impl Iterator<Item = (Tuple, usize, i32, f64)> + 'b> {
|
|
|
|
|
) -> Result<impl Iterator<Item = (CompoundKey, f64)> + 'b> {
|
|
|
|
|
let mut start_tuple = Vec::with_capacity(cand_key.len() + 3);
|
|
|
|
|
start_tuple.push(DataValue::from(level));
|
|
|
|
|
start_tuple.extend_from_slice(&cand_key);
|
|
|
|
@ -642,9 +630,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
} else {
|
|
|
|
|
if include_deleted {
|
|
|
|
|
return Some((
|
|
|
|
|
key_slice,
|
|
|
|
|
key_idx,
|
|
|
|
|
key_subidx,
|
|
|
|
|
(key_slice, key_idx, key_subidx),
|
|
|
|
|
tuple[2 * key_len + 5].get_float().unwrap(),
|
|
|
|
|
));
|
|
|
|
|
}
|
|
|
|
@ -653,9 +639,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
None
|
|
|
|
|
} else {
|
|
|
|
|
Some((
|
|
|
|
|
key_slice,
|
|
|
|
|
key_idx,
|
|
|
|
|
key_subidx,
|
|
|
|
|
(key_slice, key_idx, key_subidx),
|
|
|
|
|
tuple[2 * key_len + 5].get_float().unwrap(),
|
|
|
|
|
))
|
|
|
|
|
}
|
|
|
|
@ -801,22 +785,22 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
.hnsw_get_neighbours(tuple_key.to_vec(), idx, subidx, layer, idx_table, true)?
|
|
|
|
|
.collect_vec();
|
|
|
|
|
encountered_singletons |= neigbours.is_empty();
|
|
|
|
|
for (neighbour_key, neighbour_idx, neighbour_subidx, _) in neigbours {
|
|
|
|
|
for (neighbour_key, _) in neigbours {
|
|
|
|
|
// REMARK: this still has some probability of disconnecting the graph.
|
|
|
|
|
// Should we accept that as a consequence of the probabilistic nature of the algorithm?
|
|
|
|
|
let mut out_key = vec![DataValue::from(layer)];
|
|
|
|
|
out_key.extend_from_slice(tuple_key);
|
|
|
|
|
out_key.push(DataValue::from(idx as i64));
|
|
|
|
|
out_key.push(DataValue::from(subidx as i64));
|
|
|
|
|
out_key.extend_from_slice(&neighbour_key);
|
|
|
|
|
out_key.push(DataValue::from(neighbour_idx as i64));
|
|
|
|
|
out_key.push(DataValue::from(neighbour_subidx as i64));
|
|
|
|
|
out_key.extend_from_slice(&neighbour_key.0);
|
|
|
|
|
out_key.push(DataValue::from(neighbour_key.1 as i64));
|
|
|
|
|
out_key.push(DataValue::from(neighbour_key.2 as i64));
|
|
|
|
|
let out_key_bytes = idx_table.encode_key_for_store(&out_key, Default::default())?;
|
|
|
|
|
self.store_tx.del(&out_key_bytes)?;
|
|
|
|
|
let mut in_key = vec![DataValue::from(layer)];
|
|
|
|
|
in_key.extend_from_slice(&neighbour_key);
|
|
|
|
|
in_key.push(DataValue::from(neighbour_idx as i64));
|
|
|
|
|
in_key.push(DataValue::from(neighbour_subidx as i64));
|
|
|
|
|
in_key.extend_from_slice(&neighbour_key.0);
|
|
|
|
|
in_key.push(DataValue::from(neighbour_key.1 as i64));
|
|
|
|
|
in_key.push(DataValue::from(neighbour_key.2 as i64));
|
|
|
|
|
in_key.extend_from_slice(tuple_key);
|
|
|
|
|
in_key.push(DataValue::from(idx as i64));
|
|
|
|
|
in_key.push(DataValue::from(subidx as i64));
|
|
|
|
@ -824,9 +808,9 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
self.store_tx.del(&in_key_bytes)?;
|
|
|
|
|
let mut neighbour_self_key = vec![DataValue::from(layer)];
|
|
|
|
|
for _ in 0..2 {
|
|
|
|
|
neighbour_self_key.extend_from_slice(&neighbour_key);
|
|
|
|
|
neighbour_self_key.push(DataValue::from(neighbour_idx as i64));
|
|
|
|
|
neighbour_self_key.push(DataValue::from(neighbour_subidx as i64));
|
|
|
|
|
neighbour_self_key.extend_from_slice(&neighbour_key.0);
|
|
|
|
|
neighbour_self_key.push(DataValue::from(neighbour_key.1 as i64));
|
|
|
|
|
neighbour_self_key.push(DataValue::from(neighbour_key.2 as i64));
|
|
|
|
|
}
|
|
|
|
|
let neighbour_val_bytes = self
|
|
|
|
|
.store_tx
|
|
|
|
@ -903,6 +887,11 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
(Vector::F64(v), VecElementType::F32) => Vector::F32(v.mapv(|x| x as f32)),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let mut vec_cache = VectorCache {
|
|
|
|
|
cache: Default::default(),
|
|
|
|
|
distance: config.manifest.distance,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let ep_res = config
|
|
|
|
|
.idx_handle
|
|
|
|
|
.scan_bounded_prefix(
|
|
|
|
@ -915,42 +904,37 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
if let Some(ep) = ep_res {
|
|
|
|
|
let ep = ep?;
|
|
|
|
|
let bottom_level = ep[0].get_int().unwrap();
|
|
|
|
|
let ep_key = ep[1..config.base_handle.metadata.keys.len() + 1].to_vec();
|
|
|
|
|
let ep_t_key = ep[1..config.base_handle.metadata.keys.len() + 1].to_vec();
|
|
|
|
|
let ep_idx = ep[config.base_handle.metadata.keys.len() + 1]
|
|
|
|
|
.get_int()
|
|
|
|
|
.unwrap() as usize;
|
|
|
|
|
let ep_subidx = ep[config.base_handle.metadata.keys.len() + 2]
|
|
|
|
|
.get_int()
|
|
|
|
|
.unwrap() as i32;
|
|
|
|
|
let ep_distance = self.hnsw_compare_vector(
|
|
|
|
|
&q,
|
|
|
|
|
&ep_key,
|
|
|
|
|
ep_idx,
|
|
|
|
|
ep_subidx,
|
|
|
|
|
&config.manifest,
|
|
|
|
|
&config.base_handle,
|
|
|
|
|
)?;
|
|
|
|
|
let ep_key = (ep_t_key, ep_idx, ep_subidx);
|
|
|
|
|
vec_cache.ensure_key(&ep_key, &config.base_handle, self)?;
|
|
|
|
|
let ep_distance = vec_cache.v_dist(&q, &ep_key);
|
|
|
|
|
let mut found_nn = PriorityQueue::new();
|
|
|
|
|
found_nn.push((ep_key, ep_idx, ep_subidx), OrderedFloat(ep_distance));
|
|
|
|
|
found_nn.push(ep_key, OrderedFloat(ep_distance));
|
|
|
|
|
for current_level in bottom_level..0 {
|
|
|
|
|
self.hnsw_search_level(
|
|
|
|
|
&q,
|
|
|
|
|
1,
|
|
|
|
|
current_level,
|
|
|
|
|
&config.manifest,
|
|
|
|
|
&config.base_handle,
|
|
|
|
|
&config.idx_handle,
|
|
|
|
|
&mut found_nn,
|
|
|
|
|
&mut vec_cache,
|
|
|
|
|
)?;
|
|
|
|
|
}
|
|
|
|
|
self.hnsw_search_level(
|
|
|
|
|
&q,
|
|
|
|
|
config.ef,
|
|
|
|
|
0,
|
|
|
|
|
&config.manifest,
|
|
|
|
|
&config.base_handle,
|
|
|
|
|
&config.idx_handle,
|
|
|
|
|
&mut found_nn,
|
|
|
|
|
&mut vec_cache,
|
|
|
|
|
)?;
|
|
|
|
|
if found_nn.is_empty() {
|
|
|
|
|
return Ok(vec![]);
|
|
|
|
@ -964,9 +948,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
|
|
|
|
|
let mut ret = vec![];
|
|
|
|
|
|
|
|
|
|
while let Some(((cand_tuple, cand_idx, cand_subidx), OrderedFloat(distance))) =
|
|
|
|
|
found_nn.pop()
|
|
|
|
|
{
|
|
|
|
|
while let Some((cand_key, OrderedFloat(distance))) = found_nn.pop() {
|
|
|
|
|
if let Some(r) = config.radius {
|
|
|
|
|
if distance > r {
|
|
|
|
|
continue;
|
|
|
|
@ -975,40 +957,40 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
|
|
|
|
|
let mut cand_tuple = config
|
|
|
|
|
.base_handle
|
|
|
|
|
.get(self, &cand_tuple)?
|
|
|
|
|
.get(self, &cand_key.0)?
|
|
|
|
|
.ok_or_else(|| miette!("corrupted index"))?;
|
|
|
|
|
|
|
|
|
|
if config.bind_field.is_some() {
|
|
|
|
|
let field = if cand_idx as usize >= config.base_handle.metadata.keys.len() {
|
|
|
|
|
config.base_handle.metadata.keys[cand_idx as usize]
|
|
|
|
|
let field = if cand_key.1 as usize >= config.base_handle.metadata.keys.len() {
|
|
|
|
|
config.base_handle.metadata.keys[cand_key.1 as usize]
|
|
|
|
|
.name
|
|
|
|
|
.clone()
|
|
|
|
|
} else {
|
|
|
|
|
config.base_handle.metadata.non_keys
|
|
|
|
|
[cand_idx as usize - config.base_handle.metadata.keys.len()]
|
|
|
|
|
[cand_key.1 as usize - config.base_handle.metadata.keys.len()]
|
|
|
|
|
.name
|
|
|
|
|
.clone()
|
|
|
|
|
};
|
|
|
|
|
cand_tuple.push(DataValue::Str(field));
|
|
|
|
|
}
|
|
|
|
|
if config.bind_field_idx.is_some() {
|
|
|
|
|
cand_tuple.push(if cand_subidx < 0 {
|
|
|
|
|
cand_tuple.push(if cand_key.2 < 0 {
|
|
|
|
|
DataValue::Null
|
|
|
|
|
} else {
|
|
|
|
|
DataValue::from(cand_subidx as i64)
|
|
|
|
|
DataValue::from(cand_key.2 as i64)
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
if config.bind_distance.is_some() {
|
|
|
|
|
cand_tuple.push(DataValue::from(distance));
|
|
|
|
|
}
|
|
|
|
|
if config.bind_vector.is_some() {
|
|
|
|
|
let vec = if cand_subidx < 0 {
|
|
|
|
|
match &cand_tuple[cand_idx] {
|
|
|
|
|
DataValue::List(v) => v[cand_subidx as usize].clone(),
|
|
|
|
|
let vec = if cand_key.2 < 0 {
|
|
|
|
|
match &cand_tuple[cand_key.1] {
|
|
|
|
|
DataValue::List(v) => v[cand_key.2 as usize].clone(),
|
|
|
|
|
_ => bail!("corrupted index"),
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
cand_tuple[cand_idx].clone()
|
|
|
|
|
cand_tuple[cand_key.1].clone()
|
|
|
|
|
};
|
|
|
|
|
cand_tuple.push(vec);
|
|
|
|
|
}
|
|
|
|
|