query runs

main
Ziyang Hu 1 year ago
parent 84163915a1
commit a03af029db

@ -757,6 +757,24 @@ fn parse_atom(
bail!(HnswQueryRequired(span))
}
if opts.k == 0 {
#[derive(Debug, Error, Diagnostic)]
#[error("Field `k > 0` is required for HNSW search")]
#[diagnostic(code(parser::hnsw_k_required))]
struct HnswKRequired(#[label] SourceSpan);
bail!(HnswKRequired(span))
}
if opts.ef == 0 {
#[derive(Debug, Error, Diagnostic)]
#[error("Field `ef > 0` is required for HNSW search")]
#[diagnostic(code(parser::hnsw_ef_required))]
struct HnswEfRequired(#[label] SourceSpan);
bail!(HnswEfRequired(span))
}
InputAtom::HnswSearch { inner: opts }
}
Rule::relation_named_apply => {

@ -854,6 +854,7 @@ pub(crate) struct HnswSearchRA {
impl HnswSearchRA {
fn fill_binding_indices_and_compile(&mut self) -> Result<()> {
self.parent.fill_binding_indices_and_compile()?;
if self.hnsw_search.filter.is_some() {
let bindings: BTreeMap<_, _> = self
.own_bindings
@ -868,6 +869,43 @@ impl HnswSearchRA {
}
Ok(())
}
/// Runs an HNSW search once per tuple produced by the parent relation.
///
/// For every parent tuple, the value bound to `self.hnsw_search.query` is used as the
/// query vector for `SessionTx::hnsw_knn`. Each tuple returned by the search is appended
/// to a copy of the parent tuple, and the resulting tuples are yielded in a flat stream.
///
/// # Errors
///
/// The returned iterator yields an error for a parent tuple when:
/// - the query variable is not among the parent's bindings (or the tuple is too short
///   to contain it),
/// - the bound value is not a `DataValue::Vec`,
/// - `hnsw_knn` itself fails.
///
/// Errors from the parent's own `iter` are propagated unchanged.
fn iter<'a>(
    &'a self,
    tx: &'a SessionTx<'_>,
    delta_rule: Option<&MagicSymbol>,
    stores: &'a BTreeMap<MagicSymbol, EpochStore>,
) -> Result<TupleIter<'a>> {
    let bindings = self.parent.bindings_after_eliminate();
    // Position of the query variable in the parent's output tuple. Kept as an `Option`
    // rather than a `usize::MAX` sentinel so a missing binding surfaces as an error
    // instead of an out-of-bounds panic when indexing the tuple.
    let bind_idx = bindings
        .iter()
        .position(|b| *b == self.hnsw_search.query);
    let config = self.hnsw_search.clone();
    let filter_code = self.filter_bytecode.clone();
    // Scratch stack handed to `hnsw_knn` for every parent tuple.
    let mut stack = vec![];
    let it = self
        .parent
        .iter(tx, delta_rule, stores)?
        .map_ok(move |tuple| -> Result<_> {
            // Checked lookup: only the matched vector is cloned, not arbitrary values.
            let v = match bind_idx.and_then(|i| tuple.get(i)) {
                Some(DataValue::Vec(v)) => v.clone(),
                Some(d) => bail!("Expected vector, got {:?}", d),
                None => bail!("HNSW query variable is not bound in the input tuple"),
            };
            let res = tx.hnsw_knn(v, &config, &filter_code, &mut stack)?;
            // Prefix every search hit with the parent tuple that produced it.
            Ok(res.into_iter().map(move |t| {
                let mut r = tuple.clone();
                r.extend(t);
                r
            }))
        })
        // Collapse the nested `Result<Result<_>>` produced by the fallible closure.
        .map(flatten_err)
        .flatten_ok();
    Ok(Box::new(it))
}
}
#[derive(Debug)]
@ -1673,9 +1711,7 @@ impl RelAlgebra {
RelAlgebra::Filter(r) => r.iter(tx, delta_rule, stores),
RelAlgebra::NegJoin(r) => r.iter(tx, delta_rule, stores),
RelAlgebra::Unification(r) => r.iter(tx, delta_rule, stores),
RelAlgebra::HnswSearch(r) => {
todo!()
}
RelAlgebra::HnswSearch(r) => r.iter(tx, delta_rule, stores),
}
}
}

@ -7,13 +7,14 @@
*/
use crate::data::expr::{eval_bytecode_pred, Bytecode};
use crate::data::program::HnswSearch;
use crate::data::relation::VecElementType;
use crate::data::tuple::{Tuple, ENCODED_KEY_MIN_LEN};
use crate::data::value::Vector;
use crate::parse::sys::HnswDistance;
use crate::runtime::relation::RelationHandle;
use crate::runtime::transact::SessionTx;
use crate::{decode_tuple_from_kv, DataValue};
use crate::{decode_tuple_from_kv, DataValue, SourceSpan};
use itertools::Itertools;
use miette::{bail, miette, Result};
use ordered_float::OrderedFloat;
@ -41,37 +42,37 @@ pub(crate) struct HnswIndexManifest {
pub(crate) keep_pruned_connections: bool,
}
/// Options controlling a single HNSW k-nearest-neighbour query and the shape of the
/// tuples it returns (see `return_tuple_len` and `hnsw_knn`).
#[derive(Clone)]
pub(crate) struct HnswKnnQueryOptions {
// Presumably the number of neighbours requested — not read in the visible code; confirm.
k: usize,
// Passed by `hnsw_knn` to the level-0 search step (upper levels pass `1` instead).
ef: usize,
// When set, the name of the indexed field is appended to each result tuple.
bind_field: bool,
// When set, the candidate's sub-index is appended (`Null` if the sub-index is negative).
bind_field_idx: bool,
// When set, the candidate's distance to the query is appended to each result tuple.
bind_distance: bool,
// When set, the candidate's vector value is appended to each result tuple.
bind_vector: bool,
// Optional cut-off: candidates whose distance exceeds this value are skipped.
radius: Option<f64>,
// Indices into the candidate tuple whose values are copied, in order, to the start
// of each result tuple.
orig_table_binding: Vec<usize>,
// Optional bytecode predicate evaluated against the candidate tuple; candidates for
// which it evaluates to false are skipped.
filter: Option<Vec<Bytecode>>,
}
// #[derive(Clone)]
// pub(crate) struct HnswKnnQueryOptions {
// k: usize,
// ef: usize,
// bind_field: bool,
// bind_field_idx: bool,
// bind_distance: bool,
// bind_vector: bool,
// radius: Option<f64>,
// orig_table_binding: Vec<usize>,
// filter: Option<Vec<Bytecode>>,
// }
impl HnswKnnQueryOptions {
    /// Number of columns in each tuple produced for this query: one per entry in
    /// `orig_table_binding`, plus one for every enabled `bind_*` flag.
    fn return_tuple_len(&self) -> usize {
        let optional_columns = [
            self.bind_field,
            self.bind_field_idx,
            self.bind_distance,
            self.bind_vector,
        ];
        let enabled = optional_columns.iter().filter(|&&flag| flag).count();
        self.orig_table_binding.len() + enabled
    }
}
// impl HnswKnnQueryOptions {
// fn return_tuple_len(&self) -> usize {
// let mut ret_tuple_len = self.orig_table_binding.len();
// if self.bind_field {
// ret_tuple_len += 1;
// }
// if self.bind_field_idx {
// ret_tuple_len += 1;
// }
// if self.bind_distance {
// ret_tuple_len += 1;
// }
// if self.bind_vector {
// ret_tuple_len += 1;
// }
// ret_tuple_len
// }
// }
impl HnswIndexManifest {
fn get_random_level(&self) -> i64 {
@ -890,22 +891,23 @@ impl<'a> SessionTx<'a> {
pub(crate) fn hnsw_knn(
&self,
q: Vector,
config: &HnswKnnQueryOptions,
idx_table: &RelationHandle,
orig_table: &RelationHandle,
manifest: &HnswIndexManifest,
config: &HnswSearch,
filter_bytecode: &Option<(Vec<Bytecode>, SourceSpan)>,
stack: &mut Vec<DataValue>,
) -> Result<Vec<Tuple>> {
if q.len() != manifest.vec_dim {
println!("hnsw_knn for {:?} on {:#?}", q, config);
if q.len() != config.manifest.vec_dim {
bail!("query vector dimension mismatch");
}
let q = match (q, manifest.dtype) {
let q = match (q, config.manifest.dtype) {
(v @ Vector::F32(_), VecElementType::F32) => v,
(v @ Vector::F64(_), VecElementType::F64) => v,
(Vector::F32(v), VecElementType::F64) => Vector::F64(v.mapv(|x| x as f64)),
(Vector::F64(v), VecElementType::F32) => Vector::F32(v.mapv(|x| x as f32)),
};
let ep_res = idx_table
let ep_res = config
.idx_handle
.scan_bounded_prefix(
self,
&[],
@ -914,13 +916,24 @@ impl<'a> SessionTx<'a> {
)
.next();
if let Some(ep) = ep_res {
println!("found ep");
let ep = ep?;
let bottom_level = ep[0].get_int().unwrap();
let ep_key = ep[1..orig_table.metadata.keys.len() + 1].to_vec();
let ep_idx = ep[orig_table.metadata.keys.len() + 1].get_int().unwrap() as usize;
let ep_subidx = ep[orig_table.metadata.keys.len() + 2].get_int().unwrap() as i32;
let ep_distance =
self.hnsw_compare_vector(&q, &ep_key, ep_idx, ep_subidx, manifest, orig_table)?;
let ep_key = ep[1..config.base_handle.metadata.keys.len() + 1].to_vec();
let ep_idx = ep[config.base_handle.metadata.keys.len() + 1]
.get_int()
.unwrap() as usize;
let ep_subidx = ep[config.base_handle.metadata.keys.len() + 2]
.get_int()
.unwrap() as i32;
let ep_distance = self.hnsw_compare_vector(
&q,
&ep_key,
ep_idx,
ep_subidx,
&config.manifest,
&config.base_handle,
)?;
let mut found_nn = PriorityQueue::new();
found_nn.push((ep_key, ep_idx, ep_subidx), OrderedFloat(ep_distance));
for current_level in bottom_level..0 {
@ -928,9 +941,9 @@ impl<'a> SessionTx<'a> {
&q,
1,
current_level,
manifest,
orig_table,
idx_table,
&config.manifest,
&config.base_handle,
&config.idx_handle,
&mut found_nn,
)?;
}
@ -938,12 +951,13 @@ impl<'a> SessionTx<'a> {
&q,
config.ef,
0,
manifest,
orig_table,
idx_table,
&config.manifest,
&config.base_handle,
&config.idx_handle,
&mut found_nn,
)?;
if found_nn.is_empty() {
println!("no candidates found");
return Ok(vec![]);
}
@ -953,61 +967,47 @@ impl<'a> SessionTx<'a> {
}
}
let max_binding = *config.orig_table_binding.iter().max().unwrap();
// FIXME this is wasteful
let needs_query_original_table = config.bind_vector
|| config.filter.is_some()
|| max_binding >= orig_table.metadata.keys.len();
let mut ret = vec![];
let mut stack = vec![];
let ret_tuple_len = config.return_tuple_len();
while let Some(((mut cand_tuple, cand_idx, cand_subidx), OrderedFloat(distance))) =
while let Some(((cand_tuple, cand_idx, cand_subidx), OrderedFloat(distance))) =
found_nn.pop()
{
println!("found candidate {:?} at distance {}", cand_tuple, distance);
if let Some(r) = config.radius {
if distance > r {
continue;
}
}
if needs_query_original_table {
cand_tuple = orig_table
.get(self, &cand_tuple)?
.ok_or_else(|| miette!("corrupted index"))?;
}
if let Some(code) = &config.filter {
if !eval_bytecode_pred(code, &cand_tuple, &mut stack, Default::default())? {
continue;
}
}
let mut cur = Vec::with_capacity(ret_tuple_len);
for i in &config.orig_table_binding {
cur.push(cand_tuple[*i].clone());
}
if config.bind_field {
let field = if cand_idx as usize >= orig_table.metadata.keys.len() {
orig_table.metadata.keys[cand_idx as usize].name.clone()
let mut cand_tuple = config
.base_handle
.get(self, &cand_tuple)?
.ok_or_else(|| miette!("corrupted index"))?;
if config.bind_field.is_some() {
let field = if cand_idx as usize >= config.base_handle.metadata.keys.len() {
config.base_handle.metadata.keys[cand_idx as usize]
.name
.clone()
} else {
orig_table.metadata.non_keys
[cand_idx as usize - orig_table.metadata.keys.len()]
config.base_handle.metadata.non_keys
[cand_idx as usize - config.base_handle.metadata.keys.len()]
.name
.clone()
};
cur.push(DataValue::Str(field));
cand_tuple.push(DataValue::Str(field));
}
if config.bind_field_idx {
cur.push(if cand_subidx < 0 {
if config.bind_field_idx.is_some() {
cand_tuple.push(if cand_subidx < 0 {
DataValue::Null
} else {
DataValue::from(cand_subidx as i64)
});
}
if config.bind_distance {
cur.push(DataValue::from(distance));
if config.bind_distance.is_some() {
cand_tuple.push(DataValue::from(distance));
}
if config.bind_vector {
if config.bind_vector.is_some() {
let vec = if cand_subidx < 0 {
match &cand_tuple[cand_idx] {
DataValue::List(v) => v[cand_subidx as usize].clone(),
@ -1016,10 +1016,16 @@ impl<'a> SessionTx<'a> {
} else {
cand_tuple[cand_idx].clone()
};
cur.push(vec);
cand_tuple.push(vec);
}
if let Some((code, span)) = filter_bytecode {
if !eval_bytecode_pred(code, &cand_tuple, stack, *span)? {
continue;
}
}
ret.push(cur);
ret.push(cand_tuple);
}
Ok(ret)
} else {

@ -826,12 +826,12 @@ fn test_vec_index() {
.run_script(
r"
#::explain {
?[v] := ~a:vec{k: 'a', v | query: q}, q = vec([1,1,1,1,1,1,1,1])
?[k, dist, v] := ~a:vec{k, v | query: q, k: 10, ef: 20, bind_distance: dist}, q = vec([1,1,1,1,1,1,1,1])
#}
",
Default::default(),
)
.unwrap();
println!("{:#?}", res.into_json()["rows"]);
println!("res: {:#?}", res.into_json()["rows"]);
// println!("{:#?}", db.export_relations(["a", "a:vec"].iter()));
}

Loading…
Cancel
Save