|
|
@ -7,13 +7,14 @@
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
|
|
use crate::data::expr::{eval_bytecode_pred, Bytecode};
|
|
|
|
use crate::data::expr::{eval_bytecode_pred, Bytecode};
|
|
|
|
|
|
|
|
use crate::data::program::HnswSearch;
|
|
|
|
use crate::data::relation::VecElementType;
|
|
|
|
use crate::data::relation::VecElementType;
|
|
|
|
use crate::data::tuple::{Tuple, ENCODED_KEY_MIN_LEN};
|
|
|
|
use crate::data::tuple::{Tuple, ENCODED_KEY_MIN_LEN};
|
|
|
|
use crate::data::value::Vector;
|
|
|
|
use crate::data::value::Vector;
|
|
|
|
use crate::parse::sys::HnswDistance;
|
|
|
|
use crate::parse::sys::HnswDistance;
|
|
|
|
use crate::runtime::relation::RelationHandle;
|
|
|
|
use crate::runtime::relation::RelationHandle;
|
|
|
|
use crate::runtime::transact::SessionTx;
|
|
|
|
use crate::runtime::transact::SessionTx;
|
|
|
|
use crate::{decode_tuple_from_kv, DataValue};
|
|
|
|
use crate::{decode_tuple_from_kv, DataValue, SourceSpan};
|
|
|
|
use itertools::Itertools;
|
|
|
|
use itertools::Itertools;
|
|
|
|
use miette::{bail, miette, Result};
|
|
|
|
use miette::{bail, miette, Result};
|
|
|
|
use ordered_float::OrderedFloat;
|
|
|
|
use ordered_float::OrderedFloat;
|
|
|
@ -41,37 +42,37 @@ pub(crate) struct HnswIndexManifest {
|
|
|
|
pub(crate) keep_pruned_connections: bool,
|
|
|
|
pub(crate) keep_pruned_connections: bool,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Clone)]
|
|
|
|
// #[derive(Clone)]
|
|
|
|
pub(crate) struct HnswKnnQueryOptions {
|
|
|
|
// pub(crate) struct HnswKnnQueryOptions {
|
|
|
|
k: usize,
|
|
|
|
// k: usize,
|
|
|
|
ef: usize,
|
|
|
|
// ef: usize,
|
|
|
|
bind_field: bool,
|
|
|
|
// bind_field: bool,
|
|
|
|
bind_field_idx: bool,
|
|
|
|
// bind_field_idx: bool,
|
|
|
|
bind_distance: bool,
|
|
|
|
// bind_distance: bool,
|
|
|
|
bind_vector: bool,
|
|
|
|
// bind_vector: bool,
|
|
|
|
radius: Option<f64>,
|
|
|
|
// radius: Option<f64>,
|
|
|
|
orig_table_binding: Vec<usize>,
|
|
|
|
// orig_table_binding: Vec<usize>,
|
|
|
|
filter: Option<Vec<Bytecode>>,
|
|
|
|
// filter: Option<Vec<Bytecode>>,
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
impl HnswKnnQueryOptions {
|
|
|
|
// impl HnswKnnQueryOptions {
|
|
|
|
fn return_tuple_len(&self) -> usize {
|
|
|
|
// fn return_tuple_len(&self) -> usize {
|
|
|
|
let mut ret_tuple_len = self.orig_table_binding.len();
|
|
|
|
// let mut ret_tuple_len = self.orig_table_binding.len();
|
|
|
|
if self.bind_field {
|
|
|
|
// if self.bind_field {
|
|
|
|
ret_tuple_len += 1;
|
|
|
|
// ret_tuple_len += 1;
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
if self.bind_field_idx {
|
|
|
|
// if self.bind_field_idx {
|
|
|
|
ret_tuple_len += 1;
|
|
|
|
// ret_tuple_len += 1;
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
if self.bind_distance {
|
|
|
|
// if self.bind_distance {
|
|
|
|
ret_tuple_len += 1;
|
|
|
|
// ret_tuple_len += 1;
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
if self.bind_vector {
|
|
|
|
// if self.bind_vector {
|
|
|
|
ret_tuple_len += 1;
|
|
|
|
// ret_tuple_len += 1;
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
ret_tuple_len
|
|
|
|
// ret_tuple_len
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
impl HnswIndexManifest {
|
|
|
|
impl HnswIndexManifest {
|
|
|
|
fn get_random_level(&self) -> i64 {
|
|
|
|
fn get_random_level(&self) -> i64 {
|
|
|
@ -890,22 +891,23 @@ impl<'a> SessionTx<'a> {
|
|
|
|
pub(crate) fn hnsw_knn(
|
|
|
|
pub(crate) fn hnsw_knn(
|
|
|
|
&self,
|
|
|
|
&self,
|
|
|
|
q: Vector,
|
|
|
|
q: Vector,
|
|
|
|
config: &HnswKnnQueryOptions,
|
|
|
|
config: &HnswSearch,
|
|
|
|
idx_table: &RelationHandle,
|
|
|
|
filter_bytecode: &Option<(Vec<Bytecode>, SourceSpan)>,
|
|
|
|
orig_table: &RelationHandle,
|
|
|
|
stack: &mut Vec<DataValue>,
|
|
|
|
manifest: &HnswIndexManifest,
|
|
|
|
|
|
|
|
) -> Result<Vec<Tuple>> {
|
|
|
|
) -> Result<Vec<Tuple>> {
|
|
|
|
if q.len() != manifest.vec_dim {
|
|
|
|
println!("hnsw_knn for {:?} on {:#?}", q, config);
|
|
|
|
|
|
|
|
if q.len() != config.manifest.vec_dim {
|
|
|
|
bail!("query vector dimension mismatch");
|
|
|
|
bail!("query vector dimension mismatch");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
let q = match (q, manifest.dtype) {
|
|
|
|
let q = match (q, config.manifest.dtype) {
|
|
|
|
(v @ Vector::F32(_), VecElementType::F32) => v,
|
|
|
|
(v @ Vector::F32(_), VecElementType::F32) => v,
|
|
|
|
(v @ Vector::F64(_), VecElementType::F64) => v,
|
|
|
|
(v @ Vector::F64(_), VecElementType::F64) => v,
|
|
|
|
(Vector::F32(v), VecElementType::F64) => Vector::F64(v.mapv(|x| x as f64)),
|
|
|
|
(Vector::F32(v), VecElementType::F64) => Vector::F64(v.mapv(|x| x as f64)),
|
|
|
|
(Vector::F64(v), VecElementType::F32) => Vector::F32(v.mapv(|x| x as f32)),
|
|
|
|
(Vector::F64(v), VecElementType::F32) => Vector::F32(v.mapv(|x| x as f32)),
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
let ep_res = idx_table
|
|
|
|
let ep_res = config
|
|
|
|
|
|
|
|
.idx_handle
|
|
|
|
.scan_bounded_prefix(
|
|
|
|
.scan_bounded_prefix(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
&[],
|
|
|
|
&[],
|
|
|
@ -914,13 +916,24 @@ impl<'a> SessionTx<'a> {
|
|
|
|
)
|
|
|
|
)
|
|
|
|
.next();
|
|
|
|
.next();
|
|
|
|
if let Some(ep) = ep_res {
|
|
|
|
if let Some(ep) = ep_res {
|
|
|
|
|
|
|
|
println!("found ep");
|
|
|
|
let ep = ep?;
|
|
|
|
let ep = ep?;
|
|
|
|
let bottom_level = ep[0].get_int().unwrap();
|
|
|
|
let bottom_level = ep[0].get_int().unwrap();
|
|
|
|
let ep_key = ep[1..orig_table.metadata.keys.len() + 1].to_vec();
|
|
|
|
let ep_key = ep[1..config.base_handle.metadata.keys.len() + 1].to_vec();
|
|
|
|
let ep_idx = ep[orig_table.metadata.keys.len() + 1].get_int().unwrap() as usize;
|
|
|
|
let ep_idx = ep[config.base_handle.metadata.keys.len() + 1]
|
|
|
|
let ep_subidx = ep[orig_table.metadata.keys.len() + 2].get_int().unwrap() as i32;
|
|
|
|
.get_int()
|
|
|
|
let ep_distance =
|
|
|
|
.unwrap() as usize;
|
|
|
|
self.hnsw_compare_vector(&q, &ep_key, ep_idx, ep_subidx, manifest, orig_table)?;
|
|
|
|
let ep_subidx = ep[config.base_handle.metadata.keys.len() + 2]
|
|
|
|
|
|
|
|
.get_int()
|
|
|
|
|
|
|
|
.unwrap() as i32;
|
|
|
|
|
|
|
|
let ep_distance = self.hnsw_compare_vector(
|
|
|
|
|
|
|
|
&q,
|
|
|
|
|
|
|
|
&ep_key,
|
|
|
|
|
|
|
|
ep_idx,
|
|
|
|
|
|
|
|
ep_subidx,
|
|
|
|
|
|
|
|
&config.manifest,
|
|
|
|
|
|
|
|
&config.base_handle,
|
|
|
|
|
|
|
|
)?;
|
|
|
|
let mut found_nn = PriorityQueue::new();
|
|
|
|
let mut found_nn = PriorityQueue::new();
|
|
|
|
found_nn.push((ep_key, ep_idx, ep_subidx), OrderedFloat(ep_distance));
|
|
|
|
found_nn.push((ep_key, ep_idx, ep_subidx), OrderedFloat(ep_distance));
|
|
|
|
for current_level in bottom_level..0 {
|
|
|
|
for current_level in bottom_level..0 {
|
|
|
@ -928,9 +941,9 @@ impl<'a> SessionTx<'a> {
|
|
|
|
&q,
|
|
|
|
&q,
|
|
|
|
1,
|
|
|
|
1,
|
|
|
|
current_level,
|
|
|
|
current_level,
|
|
|
|
manifest,
|
|
|
|
&config.manifest,
|
|
|
|
orig_table,
|
|
|
|
&config.base_handle,
|
|
|
|
idx_table,
|
|
|
|
&config.idx_handle,
|
|
|
|
&mut found_nn,
|
|
|
|
&mut found_nn,
|
|
|
|
)?;
|
|
|
|
)?;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -938,12 +951,13 @@ impl<'a> SessionTx<'a> {
|
|
|
|
&q,
|
|
|
|
&q,
|
|
|
|
config.ef,
|
|
|
|
config.ef,
|
|
|
|
0,
|
|
|
|
0,
|
|
|
|
manifest,
|
|
|
|
&config.manifest,
|
|
|
|
orig_table,
|
|
|
|
&config.base_handle,
|
|
|
|
idx_table,
|
|
|
|
&config.idx_handle,
|
|
|
|
&mut found_nn,
|
|
|
|
&mut found_nn,
|
|
|
|
)?;
|
|
|
|
)?;
|
|
|
|
if found_nn.is_empty() {
|
|
|
|
if found_nn.is_empty() {
|
|
|
|
|
|
|
|
println!("no candidates found");
|
|
|
|
return Ok(vec![]);
|
|
|
|
return Ok(vec![]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -953,61 +967,47 @@ impl<'a> SessionTx<'a> {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
let max_binding = *config.orig_table_binding.iter().max().unwrap();
|
|
|
|
|
|
|
|
// FIXME this is wasteful
|
|
|
|
|
|
|
|
let needs_query_original_table = config.bind_vector
|
|
|
|
|
|
|
|
|| config.filter.is_some()
|
|
|
|
|
|
|
|
|| max_binding >= orig_table.metadata.keys.len();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let mut ret = vec![];
|
|
|
|
let mut ret = vec![];
|
|
|
|
let mut stack = vec![];
|
|
|
|
|
|
|
|
let ret_tuple_len = config.return_tuple_len();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
while let Some(((mut cand_tuple, cand_idx, cand_subidx), OrderedFloat(distance))) =
|
|
|
|
while let Some(((cand_tuple, cand_idx, cand_subidx), OrderedFloat(distance))) =
|
|
|
|
found_nn.pop()
|
|
|
|
found_nn.pop()
|
|
|
|
{
|
|
|
|
{
|
|
|
|
|
|
|
|
println!("found candidate {:?} at distance {}", cand_tuple, distance);
|
|
|
|
if let Some(r) = config.radius {
|
|
|
|
if let Some(r) = config.radius {
|
|
|
|
if distance > r {
|
|
|
|
if distance > r {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if needs_query_original_table {
|
|
|
|
|
|
|
|
cand_tuple = orig_table
|
|
|
|
|
|
|
|
.get(self, &cand_tuple)?
|
|
|
|
|
|
|
|
.ok_or_else(|| miette!("corrupted index"))?;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if let Some(code) = &config.filter {
|
|
|
|
|
|
|
|
if !eval_bytecode_pred(code, &cand_tuple, &mut stack, Default::default())? {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
let mut cur = Vec::with_capacity(ret_tuple_len);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i in &config.orig_table_binding {
|
|
|
|
let mut cand_tuple = config
|
|
|
|
cur.push(cand_tuple[*i].clone());
|
|
|
|
.base_handle
|
|
|
|
}
|
|
|
|
.get(self, &cand_tuple)?
|
|
|
|
if config.bind_field {
|
|
|
|
.ok_or_else(|| miette!("corrupted index"))?;
|
|
|
|
let field = if cand_idx as usize >= orig_table.metadata.keys.len() {
|
|
|
|
|
|
|
|
orig_table.metadata.keys[cand_idx as usize].name.clone()
|
|
|
|
if config.bind_field.is_some() {
|
|
|
|
|
|
|
|
let field = if cand_idx as usize >= config.base_handle.metadata.keys.len() {
|
|
|
|
|
|
|
|
config.base_handle.metadata.keys[cand_idx as usize]
|
|
|
|
|
|
|
|
.name
|
|
|
|
|
|
|
|
.clone()
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
orig_table.metadata.non_keys
|
|
|
|
config.base_handle.metadata.non_keys
|
|
|
|
[cand_idx as usize - orig_table.metadata.keys.len()]
|
|
|
|
[cand_idx as usize - config.base_handle.metadata.keys.len()]
|
|
|
|
.name
|
|
|
|
.name
|
|
|
|
.clone()
|
|
|
|
.clone()
|
|
|
|
};
|
|
|
|
};
|
|
|
|
cur.push(DataValue::Str(field));
|
|
|
|
cand_tuple.push(DataValue::Str(field));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if config.bind_field_idx {
|
|
|
|
if config.bind_field_idx.is_some() {
|
|
|
|
cur.push(if cand_subidx < 0 {
|
|
|
|
cand_tuple.push(if cand_subidx < 0 {
|
|
|
|
DataValue::Null
|
|
|
|
DataValue::Null
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
DataValue::from(cand_subidx as i64)
|
|
|
|
DataValue::from(cand_subidx as i64)
|
|
|
|
});
|
|
|
|
});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if config.bind_distance {
|
|
|
|
if config.bind_distance.is_some() {
|
|
|
|
cur.push(DataValue::from(distance));
|
|
|
|
cand_tuple.push(DataValue::from(distance));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if config.bind_vector {
|
|
|
|
if config.bind_vector.is_some() {
|
|
|
|
let vec = if cand_subidx < 0 {
|
|
|
|
let vec = if cand_subidx < 0 {
|
|
|
|
match &cand_tuple[cand_idx] {
|
|
|
|
match &cand_tuple[cand_idx] {
|
|
|
|
DataValue::List(v) => v[cand_subidx as usize].clone(),
|
|
|
|
DataValue::List(v) => v[cand_subidx as usize].clone(),
|
|
|
@ -1016,10 +1016,16 @@ impl<'a> SessionTx<'a> {
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
cand_tuple[cand_idx].clone()
|
|
|
|
cand_tuple[cand_idx].clone()
|
|
|
|
};
|
|
|
|
};
|
|
|
|
cur.push(vec);
|
|
|
|
cand_tuple.push(vec);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if let Some((code, span)) = filter_bytecode {
|
|
|
|
|
|
|
|
if !eval_bytecode_pred(code, &cand_tuple, stack, *span)? {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ret.push(cur);
|
|
|
|
ret.push(cand_tuple);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Ok(ret)
|
|
|
|
Ok(ret)
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|