|
|
@ -20,9 +20,9 @@ use miette::{bail, miette, Result};
|
|
|
|
use ordered_float::OrderedFloat;
|
|
|
|
use ordered_float::OrderedFloat;
|
|
|
|
use priority_queue::PriorityQueue;
|
|
|
|
use priority_queue::PriorityQueue;
|
|
|
|
use rand::Rng;
|
|
|
|
use rand::Rng;
|
|
|
|
|
|
|
|
use rustc_hash::{FxHashMap, FxHashSet};
|
|
|
|
use smartstring::{LazyCompact, SmartString};
|
|
|
|
use smartstring::{LazyCompact, SmartString};
|
|
|
|
use std::cmp::{max, Reverse};
|
|
|
|
use std::cmp::{max, Reverse};
|
|
|
|
use rustc_hash::{FxHashMap, FxHashSet};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)]
|
|
|
|
#[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)]
|
|
|
|
pub(crate) struct HnswIndexManifest {
|
|
|
|
pub(crate) struct HnswIndexManifest {
|
|
|
@ -81,8 +81,8 @@ impl VectorCache {
|
|
|
|
1.0 - dot / (a_norm * b_norm).sqrt()
|
|
|
|
1.0 - dot / (a_norm * b_norm).sqrt()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
(Vector::F64(a), Vector::F64(b)) => {
|
|
|
|
(Vector::F64(a), Vector::F64(b)) => {
|
|
|
|
let a_norm = a.dot(a) as f64;
|
|
|
|
let a_norm = a.dot(a);
|
|
|
|
let b_norm = b.dot(b) as f64;
|
|
|
|
let b_norm = b.dot(b);
|
|
|
|
let dot = a.dot(b);
|
|
|
|
let dot = a.dot(b);
|
|
|
|
1.0 - dot / (a_norm * b_norm).sqrt()
|
|
|
|
1.0 - dot / (a_norm * b_norm).sqrt()
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -98,7 +98,7 @@ impl VectorCache {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
(Vector::F64(a), Vector::F64(b)) => {
|
|
|
|
(Vector::F64(a), Vector::F64(b)) => {
|
|
|
|
let dot = a.dot(b);
|
|
|
|
let dot = a.dot(b);
|
|
|
|
1. - dot as f64
|
|
|
|
1. - dot
|
|
|
|
}
|
|
|
|
}
|
|
|
|
_ => panic!("Cannot compute inner product between {:?} and {:?}", v1, v2),
|
|
|
|
_ => panic!("Cannot compute inner product between {:?} and {:?}", v1, v2),
|
|
|
|
},
|
|
|
|
},
|
|
|
@ -164,7 +164,9 @@ impl<'a> SessionTx<'a> {
|
|
|
|
distance: manifest.distance,
|
|
|
|
distance: manifest.distance,
|
|
|
|
};
|
|
|
|
};
|
|
|
|
let tuple_key = &tuple[..orig_table.metadata.keys.len()];
|
|
|
|
let tuple_key = &tuple[..orig_table.metadata.keys.len()];
|
|
|
|
vec_cache.cache.insert((tuple_key.to_vec(), idx, subidx), q.clone());
|
|
|
|
vec_cache
|
|
|
|
|
|
|
|
.cache
|
|
|
|
|
|
|
|
.insert((tuple_key.to_vec(), idx, subidx), q.clone());
|
|
|
|
let hash = q.get_hash();
|
|
|
|
let hash = q.get_hash();
|
|
|
|
let mut canary_tuple = vec![DataValue::from(0)];
|
|
|
|
let mut canary_tuple = vec![DataValue::from(0)];
|
|
|
|
for _ in 0..2 {
|
|
|
|
for _ in 0..2 {
|
|
|
@ -339,7 +341,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
if target_degree > m_max {
|
|
|
|
if target_degree > m_max {
|
|
|
|
// shrink links
|
|
|
|
// shrink links
|
|
|
|
target_degree = self.hnsw_shrink_neighbour(
|
|
|
|
target_degree = self.hnsw_shrink_neighbour(
|
|
|
|
&neighbour,
|
|
|
|
neighbour,
|
|
|
|
m_max,
|
|
|
|
m_max,
|
|
|
|
current_level,
|
|
|
|
current_level,
|
|
|
|
manifest,
|
|
|
|
manifest,
|
|
|
@ -383,8 +385,8 @@ impl<'a> SessionTx<'a> {
|
|
|
|
orig_table: &RelationHandle,
|
|
|
|
orig_table: &RelationHandle,
|
|
|
|
vec_cache: &mut VectorCache,
|
|
|
|
vec_cache: &mut VectorCache,
|
|
|
|
) -> Result<usize> {
|
|
|
|
) -> Result<usize> {
|
|
|
|
vec_cache.ensure_key(&target_key, orig_table, self)?;
|
|
|
|
vec_cache.ensure_key(target_key, orig_table, self)?;
|
|
|
|
let vec = vec_cache.get_key(&target_key).clone();
|
|
|
|
let vec = vec_cache.get_key(target_key).clone();
|
|
|
|
let mut candidates = PriorityQueue::new();
|
|
|
|
let mut candidates = PriorityQueue::new();
|
|
|
|
for (neighbour_key, neighbour_dist) in
|
|
|
|
for (neighbour_key, neighbour_dist) in
|
|
|
|
self.hnsw_get_neighbours(target_key, level, idx_table, false)?
|
|
|
|
self.hnsw_get_neighbours(target_key, level, idx_table, false)?
|
|
|
@ -499,12 +501,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
if manifest.extend_candidates {
|
|
|
|
if manifest.extend_candidates {
|
|
|
|
for (item, _) in found.iter() {
|
|
|
|
for (item, _) in found.iter() {
|
|
|
|
// Extend by neighbours
|
|
|
|
// Extend by neighbours
|
|
|
|
for (neighbour_key, _) in self.hnsw_get_neighbours(
|
|
|
|
for (neighbour_key, _) in self.hnsw_get_neighbours(item, level, idx_table, false)? {
|
|
|
|
&item,
|
|
|
|
|
|
|
|
level,
|
|
|
|
|
|
|
|
idx_table,
|
|
|
|
|
|
|
|
false,
|
|
|
|
|
|
|
|
)? {
|
|
|
|
|
|
|
|
vec_cache.ensure_key(&neighbour_key, orig_table, self)?;
|
|
|
|
vec_cache.ensure_key(&neighbour_key, orig_table, self)?;
|
|
|
|
let dist = vec_cache.v_dist(q, &neighbour_key);
|
|
|
|
let dist = vec_cache.v_dist(q, &neighbour_key);
|
|
|
|
candidates.push(
|
|
|
|
candidates.push(
|
|
|
@ -519,7 +516,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
let mut should_add = true;
|
|
|
|
let mut should_add = true;
|
|
|
|
for (existing, _) in ret.iter() {
|
|
|
|
for (existing, _) in ret.iter() {
|
|
|
|
vec_cache.ensure_key(&cand_key, orig_table, self)?;
|
|
|
|
vec_cache.ensure_key(&cand_key, orig_table, self)?;
|
|
|
|
vec_cache.ensure_key(&existing, orig_table, self)?;
|
|
|
|
vec_cache.ensure_key(existing, orig_table, self)?;
|
|
|
|
let dist_to_existing = vec_cache.k_dist(existing, &cand_key);
|
|
|
|
let dist_to_existing = vec_cache.k_dist(existing, &cand_key);
|
|
|
|
if dist_to_existing < cand_dist_to_q {
|
|
|
|
if dist_to_existing < cand_dist_to_q {
|
|
|
|
should_add = false;
|
|
|
|
should_add = false;
|
|
|
@ -568,12 +565,9 @@ impl<'a> SessionTx<'a> {
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// loop over each of the candidate's neighbors
|
|
|
|
// loop over each of the candidate's neighbors
|
|
|
|
for (neighbour_key, _) in self.hnsw_get_neighbours(
|
|
|
|
for (neighbour_key, _) in
|
|
|
|
&candidate,
|
|
|
|
self.hnsw_get_neighbours(&candidate, cur_level, idx_table, false)?
|
|
|
|
cur_level,
|
|
|
|
{
|
|
|
|
idx_table,
|
|
|
|
|
|
|
|
false,
|
|
|
|
|
|
|
|
)? {
|
|
|
|
|
|
|
|
if visited.contains(&neighbour_key) {
|
|
|
|
if visited.contains(&neighbour_key) {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -702,7 +696,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
for idx in &manifest.vec_fields {
|
|
|
|
for idx in &manifest.vec_fields {
|
|
|
|
let val = tuple.get(*idx).unwrap();
|
|
|
|
let val = tuple.get(*idx).unwrap();
|
|
|
|
if let DataValue::Vec(v) = val {
|
|
|
|
if let DataValue::Vec(v) = val {
|
|
|
|
extracted_vectors.push((v, *idx, -1 as i32));
|
|
|
|
extracted_vectors.push((v, *idx, -1));
|
|
|
|
} else if let DataValue::List(l) = val {
|
|
|
|
} else if let DataValue::List(l) = val {
|
|
|
|
for (sidx, v) in l.iter().enumerate() {
|
|
|
|
for (sidx, v) in l.iter().enumerate() {
|
|
|
|
if let DataValue::Vec(v) = v {
|
|
|
|
if let DataValue::Vec(v) = v {
|
|
|
@ -715,7 +709,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
return Ok(false);
|
|
|
|
return Ok(false);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (vec, idx, sub) in extracted_vectors {
|
|
|
|
for (vec, idx, sub) in extracted_vectors {
|
|
|
|
self.hnsw_put_vector(&tuple, vec, idx, sub, manifest, orig_table, idx_table)?;
|
|
|
|
self.hnsw_put_vector(tuple, vec, idx, sub, manifest, orig_table, idx_table)?;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Ok(true)
|
|
|
|
Ok(true)
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -951,13 +945,11 @@ impl<'a> SessionTx<'a> {
|
|
|
|
.ok_or_else(|| miette!("corrupted index"))?;
|
|
|
|
.ok_or_else(|| miette!("corrupted index"))?;
|
|
|
|
|
|
|
|
|
|
|
|
if config.bind_field.is_some() {
|
|
|
|
if config.bind_field.is_some() {
|
|
|
|
let field = if cand_key.1 as usize >= config.base_handle.metadata.keys.len() {
|
|
|
|
let field = if cand_key.1 >= config.base_handle.metadata.keys.len() {
|
|
|
|
config.base_handle.metadata.keys[cand_key.1 as usize]
|
|
|
|
config.base_handle.metadata.keys[cand_key.1].name.clone()
|
|
|
|
.name
|
|
|
|
|
|
|
|
.clone()
|
|
|
|
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
config.base_handle.metadata.non_keys
|
|
|
|
config.base_handle.metadata.non_keys
|
|
|
|
[cand_key.1 as usize - config.base_handle.metadata.keys.len()]
|
|
|
|
[cand_key.1 - config.base_handle.metadata.keys.len()]
|
|
|
|
.name
|
|
|
|
.name
|
|
|
|
.clone()
|
|
|
|
.clone()
|
|
|
|
};
|
|
|
|
};
|
|
|
|