|
|
|
@ -13,13 +13,11 @@ use crate::data::value::Vector;
|
|
|
|
|
use crate::parse::sys::HnswDistance;
|
|
|
|
|
use crate::runtime::relation::RelationHandle;
|
|
|
|
|
use crate::runtime::transact::SessionTx;
|
|
|
|
|
use crate::{decode_tuple_from_kv, DataValue, Symbol};
|
|
|
|
|
use miette::{bail, Result};
|
|
|
|
|
use crate::{decode_tuple_from_kv, DataValue};
|
|
|
|
|
use miette::{bail, miette, Result};
|
|
|
|
|
use ordered_float::OrderedFloat;
|
|
|
|
|
use priority_queue::PriorityQueue;
|
|
|
|
|
use rand::Rng;
|
|
|
|
|
use sha2::digest::FixedOutput;
|
|
|
|
|
use sha2::{Digest, Sha256};
|
|
|
|
|
use smartstring::{LazyCompact, SmartString};
|
|
|
|
|
use std::cmp::{max, Reverse};
|
|
|
|
|
use std::collections::BTreeSet;
|
|
|
|
@ -45,12 +43,32 @@ pub(crate) struct HnswIndexManifest {
|
|
|
|
|
pub(crate) struct HnswKnnQueryOptions {
|
|
|
|
|
k: usize,
|
|
|
|
|
ef: usize,
|
|
|
|
|
max_distance: f64,
|
|
|
|
|
min_margin: f64,
|
|
|
|
|
auto_margin_factor: Option<f64>,
|
|
|
|
|
bind_field: Option<Symbol>,
|
|
|
|
|
bind_distance: Option<Symbol>,
|
|
|
|
|
bind_vector: Option<Symbol>,
|
|
|
|
|
bind_field: bool,
|
|
|
|
|
bind_field_idx: bool,
|
|
|
|
|
bind_distance: bool,
|
|
|
|
|
bind_vector: bool,
|
|
|
|
|
radius: Option<f64>,
|
|
|
|
|
orig_table_binding: Vec<usize>,
|
|
|
|
|
filter: Option<Vec<Bytecode>>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl HnswKnnQueryOptions {
|
|
|
|
|
fn return_tuple_len(&self) -> usize {
|
|
|
|
|
let mut ret_tuple_len = self.orig_table_binding.len();
|
|
|
|
|
if self.bind_field {
|
|
|
|
|
ret_tuple_len += 1;
|
|
|
|
|
}
|
|
|
|
|
if self.bind_field_idx {
|
|
|
|
|
ret_tuple_len += 1;
|
|
|
|
|
}
|
|
|
|
|
if self.bind_distance {
|
|
|
|
|
ret_tuple_len += 1;
|
|
|
|
|
}
|
|
|
|
|
if self.bind_vector {
|
|
|
|
|
ret_tuple_len += 1;
|
|
|
|
|
}
|
|
|
|
|
ret_tuple_len
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl HnswIndexManifest {
|
|
|
|
@ -167,9 +185,23 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
orig_table: &RelationHandle,
|
|
|
|
|
idx_table: &RelationHandle,
|
|
|
|
|
) -> Result<()> {
|
|
|
|
|
// TODO check if this is an update!
|
|
|
|
|
|
|
|
|
|
let tuple_key = &tuple[..orig_table.metadata.keys.len()];
|
|
|
|
|
let hash = q.get_hash();
|
|
|
|
|
let mut canary_tuple = vec![DataValue::from(0)];
|
|
|
|
|
for _ in 0..2 {
|
|
|
|
|
canary_tuple.extend_from_slice(tuple_key);
|
|
|
|
|
canary_tuple.push(DataValue::from(idx as i64));
|
|
|
|
|
canary_tuple.push(DataValue::from(subidx as i64));
|
|
|
|
|
}
|
|
|
|
|
if let Some(v) = idx_table.get(self, &canary_tuple)? {
|
|
|
|
|
if let DataValue::Bytes(b) = &v[tuple_key.len() * 2 + 6] {
|
|
|
|
|
if b == hash.as_ref() {
|
|
|
|
|
return Ok(());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// TODO
|
|
|
|
|
self.hnsw_remove_vec()?;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let ep_res = idx_table
|
|
|
|
|
.scan_bounded_prefix(
|
|
|
|
@ -194,6 +226,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
if target_level < bottom_level {
|
|
|
|
|
// this becomes the entry point
|
|
|
|
|
self.hnsw_put_fresh_at_levels(
|
|
|
|
|
hash.as_ref(),
|
|
|
|
|
tuple_key,
|
|
|
|
|
idx,
|
|
|
|
|
subidx,
|
|
|
|
@ -223,7 +256,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
}
|
|
|
|
|
let mut self_tuple_val = vec![
|
|
|
|
|
DataValue::from(0.0),
|
|
|
|
|
DataValue::Null,
|
|
|
|
|
DataValue::Bytes(hash.as_ref().to_vec()),
|
|
|
|
|
DataValue::from(false),
|
|
|
|
|
];
|
|
|
|
|
for current_level in max(target_level, bottom_level)..=0 {
|
|
|
|
@ -255,15 +288,6 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
self_tuple_key[0] = DataValue::from(current_level);
|
|
|
|
|
self_tuple_val[0] = DataValue::from(neighbours.len() as f64);
|
|
|
|
|
|
|
|
|
|
// save hash in self-loops
|
|
|
|
|
let mut hasher = Sha256::new();
|
|
|
|
|
for (_, Reverse(OrderedFloat(dist))) in neighbours.iter() {
|
|
|
|
|
let dist_bs = dist.to_be_bytes();
|
|
|
|
|
Digest::update(&mut hasher, &dist_bs);
|
|
|
|
|
}
|
|
|
|
|
let hash = hasher.finalize_fixed();
|
|
|
|
|
self_tuple_val[1] = DataValue::Bytes(hash.to_vec());
|
|
|
|
|
|
|
|
|
|
let self_tuple_key_bytes =
|
|
|
|
|
idx_table.encode_key_for_store(&self_tuple_key, Default::default())?;
|
|
|
|
|
let self_tuple_val_bytes =
|
|
|
|
@ -351,7 +375,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
} else {
|
|
|
|
|
// This is the first vector in the index.
|
|
|
|
|
let level = manifest.get_random_level();
|
|
|
|
|
self.hnsw_put_fresh_at_levels(tuple_key, idx, subidx, orig_table, idx_table, level, 0)?;
|
|
|
|
|
self.hnsw_put_fresh_at_levels(hash.as_ref(), tuple_key, idx, subidx, orig_table, idx_table, level, 0)?;
|
|
|
|
|
}
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -641,6 +665,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
}
|
|
|
|
|
fn hnsw_put_fresh_at_levels(
|
|
|
|
|
&mut self,
|
|
|
|
|
hash: &[u8],
|
|
|
|
|
tuple: &[DataValue],
|
|
|
|
|
idx: usize,
|
|
|
|
|
subidx: i32,
|
|
|
|
@ -663,7 +688,7 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
}
|
|
|
|
|
let target_value = [
|
|
|
|
|
DataValue::from(0.0),
|
|
|
|
|
DataValue::Null,
|
|
|
|
|
DataValue::Bytes(hash.to_vec()),
|
|
|
|
|
DataValue::from(false),
|
|
|
|
|
];
|
|
|
|
|
let target_key_bytes = idx_table.encode_key_for_store(&target_key, Default::default())?;
|
|
|
|
@ -692,10 +717,11 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
manifest: &HnswIndexManifest,
|
|
|
|
|
orig_table: &RelationHandle,
|
|
|
|
|
idx_table: &RelationHandle,
|
|
|
|
|
filter: Option<(&[Bytecode], &mut Vec<DataValue>)>,
|
|
|
|
|
filter: &Option<Vec<Bytecode>>,
|
|
|
|
|
stack: &mut Vec<DataValue>,
|
|
|
|
|
tuple: &Tuple,
|
|
|
|
|
) -> Result<bool> {
|
|
|
|
|
if let Some((code, stack)) = filter {
|
|
|
|
|
if let Some(code) = filter {
|
|
|
|
|
if !eval_bytecode_pred(code, tuple, stack, Default::default())? {
|
|
|
|
|
return Ok(false);
|
|
|
|
|
}
|
|
|
|
@ -724,9 +750,136 @@ impl<'a> SessionTx<'a> {
|
|
|
|
|
pub(crate) fn hnsw_remove(&mut self) -> Result<()> {
|
|
|
|
|
todo!()
|
|
|
|
|
}
|
|
|
|
|
pub(crate) fn hnsw_knn(&self) -> Result<()> {
|
|
|
|
|
pub(crate) fn hnsw_remove_vec(&mut self) -> Result<()> {
|
|
|
|
|
todo!()
|
|
|
|
|
}
|
|
|
|
|
pub(crate) fn hnsw_knn(
|
|
|
|
|
&self,
|
|
|
|
|
q: Vector,
|
|
|
|
|
config: &HnswKnnQueryOptions,
|
|
|
|
|
idx_table: &RelationHandle,
|
|
|
|
|
orig_table: &RelationHandle,
|
|
|
|
|
manifest: &HnswIndexManifest,
|
|
|
|
|
) -> Result<Vec<Tuple>> {
|
|
|
|
|
if q.len() != manifest.vec_dim {
|
|
|
|
|
bail!("query vector dimension mismatch");
|
|
|
|
|
}
|
|
|
|
|
let q = match (q, manifest.dtype) {
|
|
|
|
|
(v @ Vector::F32(_), VecElementType::F32) => v,
|
|
|
|
|
(v @ Vector::F64(_), VecElementType::F64) => v,
|
|
|
|
|
(Vector::F32(v), VecElementType::F64) => Vector::F64(v.mapv(|x| x as f64)),
|
|
|
|
|
(Vector::F64(v), VecElementType::F32) => Vector::F32(v.mapv(|x| x as f32)),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let ep_res = idx_table
|
|
|
|
|
.scan_bounded_prefix(
|
|
|
|
|
self,
|
|
|
|
|
&[],
|
|
|
|
|
&[DataValue::from(i64::MIN)],
|
|
|
|
|
&[DataValue::from(1)],
|
|
|
|
|
)
|
|
|
|
|
.next();
|
|
|
|
|
if let Some(ep) = ep_res {
|
|
|
|
|
let ep = ep?;
|
|
|
|
|
let bottom_level = ep[0].get_int().unwrap();
|
|
|
|
|
let ep_key = ep[1..orig_table.metadata.keys.len() + 1].to_vec();
|
|
|
|
|
let ep_idx = ep[orig_table.metadata.keys.len() + 1].get_int().unwrap() as usize;
|
|
|
|
|
let ep_subidx = ep[orig_table.metadata.keys.len() + 2].get_int().unwrap() as i32;
|
|
|
|
|
let ep_distance =
|
|
|
|
|
self.hnsw_compare_vector(&q, &ep_key, ep_idx, ep_subidx, manifest, orig_table)?;
|
|
|
|
|
let mut found_nn = PriorityQueue::new();
|
|
|
|
|
found_nn.push((ep_key, ep_idx, ep_subidx), OrderedFloat(ep_distance));
|
|
|
|
|
for current_level in bottom_level..0 {
|
|
|
|
|
self.hnsw_search_level(
|
|
|
|
|
&q,
|
|
|
|
|
1,
|
|
|
|
|
current_level,
|
|
|
|
|
manifest,
|
|
|
|
|
orig_table,
|
|
|
|
|
idx_table,
|
|
|
|
|
&mut found_nn,
|
|
|
|
|
)?;
|
|
|
|
|
}
|
|
|
|
|
self.hnsw_search_level(
|
|
|
|
|
&q,
|
|
|
|
|
config.ef,
|
|
|
|
|
0,
|
|
|
|
|
manifest,
|
|
|
|
|
orig_table,
|
|
|
|
|
idx_table,
|
|
|
|
|
&mut found_nn,
|
|
|
|
|
)?;
|
|
|
|
|
if found_nn.is_empty() {
|
|
|
|
|
return Ok(vec![]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if config.filter.is_none() {
|
|
|
|
|
while found_nn.len() > config.k {
|
|
|
|
|
found_nn.pop();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let max_binding = *config.orig_table_binding.iter().max().unwrap();
|
|
|
|
|
// FIXME this is wasteful
|
|
|
|
|
let needs_query_original_table = config.bind_vector
|
|
|
|
|
|| config.filter.is_some()
|
|
|
|
|
|| max_binding >= orig_table.metadata.keys.len();
|
|
|
|
|
|
|
|
|
|
let mut ret = vec![];
|
|
|
|
|
let mut stack = vec![];
|
|
|
|
|
let ret_tuple_len = config.return_tuple_len();
|
|
|
|
|
|
|
|
|
|
while let Some(((mut cand_tuple, cand_idx, cand_subidx), OrderedFloat(distance))) =
|
|
|
|
|
found_nn.pop()
|
|
|
|
|
{
|
|
|
|
|
if let Some(r) = config.radius {
|
|
|
|
|
if distance > r {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if needs_query_original_table {
|
|
|
|
|
cand_tuple = orig_table
|
|
|
|
|
.get(self, &cand_tuple)?
|
|
|
|
|
.ok_or_else(|| miette!("corrupted index"))?;
|
|
|
|
|
}
|
|
|
|
|
if let Some(code) = &config.filter {
|
|
|
|
|
if !eval_bytecode_pred(code, &cand_tuple, &mut stack, Default::default())? {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
let mut cur = Vec::with_capacity(ret_tuple_len);
|
|
|
|
|
|
|
|
|
|
for i in &config.orig_table_binding {
|
|
|
|
|
cur.push(cand_tuple[*i].clone());
|
|
|
|
|
}
|
|
|
|
|
if config.bind_field {
|
|
|
|
|
cur.push(DataValue::from(cand_idx as i64));
|
|
|
|
|
}
|
|
|
|
|
if config.bind_field_idx {
|
|
|
|
|
cur.push(DataValue::from(cand_subidx as i64));
|
|
|
|
|
}
|
|
|
|
|
if config.bind_distance {
|
|
|
|
|
cur.push(DataValue::from(distance));
|
|
|
|
|
}
|
|
|
|
|
if config.bind_vector {
|
|
|
|
|
let vec = if cand_subidx < 0 {
|
|
|
|
|
match &cand_tuple[cand_idx] {
|
|
|
|
|
DataValue::List(v) => v[cand_subidx as usize].clone(),
|
|
|
|
|
_ => bail!("corrupted index"),
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
cand_tuple[cand_idx].clone()
|
|
|
|
|
};
|
|
|
|
|
cur.push(vec);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ret.push(cur);
|
|
|
|
|
}
|
|
|
|
|
Ok(ret)
|
|
|
|
|
} else {
|
|
|
|
|
Ok(vec![])
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[cfg(test)]
|
|
|
|
|