hash stuff

main
Ziyang Hu 1 year ago
parent 02c3358362
commit 9e84392798

@ -95,7 +95,7 @@ pub enum ColType {
Validity,
}
#[derive(Debug, Clone, Eq, PartialEq, Hash, serde_derive::Deserialize, serde_derive::Serialize)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, serde_derive::Deserialize, serde_derive::Serialize)]
pub enum VecElementType {
F32,
F64,

@ -18,6 +18,8 @@ use crate::data::relation::VecElementType;
use ordered_float::OrderedFloat;
use regex::Regex;
use serde::{Deserialize, Deserializer, Serialize};
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
use smartstring::{LazyCompact, SmartString};
use uuid::Uuid;
@ -182,6 +184,22 @@ impl Vector {
Vector::F64(_) => VecElementType::F64,
}
}
/// Computes a SHA-256 digest of this vector's contents.
///
/// Every element is fed to the hasher in iteration order as its
/// little-endian byte representation (`to_le_bytes`). Only the element
/// bytes are hashed; the variant (`F32` vs `F64`) is not mixed in
/// separately.
///
/// Returns the finalized fixed-size digest, exposed as `AsRef<[u8]>`.
pub(crate) fn get_hash(&self) -> impl AsRef<[u8]> {
    let mut digest = Sha256::new();
    match self {
        Vector::F32(elems) => elems
            .iter()
            .for_each(|x| digest.update(&x.to_le_bytes())),
        Vector::F64(elems) => elems
            .iter()
            .for_each(|x| digest.update(&x.to_le_bytes())),
    }
    digest.finalize_fixed()
}
}
impl PartialEq<Self> for Vector {

@ -13,13 +13,11 @@ use crate::data::value::Vector;
use crate::parse::sys::HnswDistance;
use crate::runtime::relation::RelationHandle;
use crate::runtime::transact::SessionTx;
use crate::{decode_tuple_from_kv, DataValue, Symbol};
use miette::{bail, Result};
use crate::{decode_tuple_from_kv, DataValue};
use miette::{bail, miette, Result};
use ordered_float::OrderedFloat;
use priority_queue::PriorityQueue;
use rand::Rng;
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
use smartstring::{LazyCompact, SmartString};
use std::cmp::{max, Reverse};
use std::collections::BTreeSet;
@ -45,12 +43,32 @@ pub(crate) struct HnswIndexManifest {
pub(crate) struct HnswKnnQueryOptions {
k: usize,
ef: usize,
max_distance: f64,
min_margin: f64,
auto_margin_factor: Option<f64>,
bind_field: Option<Symbol>,
bind_distance: Option<Symbol>,
bind_vector: Option<Symbol>,
bind_field: bool,
bind_field_idx: bool,
bind_distance: bool,
bind_vector: bool,
radius: Option<f64>,
orig_table_binding: Vec<usize>,
filter: Option<Vec<Bytecode>>,
}
impl HnswKnnQueryOptions {
    /// Returns the number of columns in each tuple produced for a query
    /// using these options.
    ///
    /// That is one column per entry of `orig_table_binding`, plus one extra
    /// column for each of `bind_field`, `bind_field_idx`, `bind_distance`
    /// and `bind_vector` that is switched on.
    fn return_tuple_len(&self) -> usize {
        let optional_columns = [
            self.bind_field,
            self.bind_field_idx,
            self.bind_distance,
            self.bind_vector,
        ];
        let enabled = optional_columns.iter().filter(|&&flag| flag).count();
        self.orig_table_binding.len() + enabled
    }
}
impl HnswIndexManifest {
@ -167,9 +185,23 @@ impl<'a> SessionTx<'a> {
orig_table: &RelationHandle,
idx_table: &RelationHandle,
) -> Result<()> {
// TODO check if this is an update!
let tuple_key = &tuple[..orig_table.metadata.keys.len()];
let hash = q.get_hash();
let mut canary_tuple = vec![DataValue::from(0)];
for _ in 0..2 {
canary_tuple.extend_from_slice(tuple_key);
canary_tuple.push(DataValue::from(idx as i64));
canary_tuple.push(DataValue::from(subidx as i64));
}
if let Some(v) = idx_table.get(self, &canary_tuple)? {
if let DataValue::Bytes(b) = &v[tuple_key.len() * 2 + 6] {
if b == hash.as_ref() {
return Ok(());
}
}
// TODO
self.hnsw_remove_vec()?;
}
let ep_res = idx_table
.scan_bounded_prefix(
@ -194,6 +226,7 @@ impl<'a> SessionTx<'a> {
if target_level < bottom_level {
// this becomes the entry point
self.hnsw_put_fresh_at_levels(
hash.as_ref(),
tuple_key,
idx,
subidx,
@ -223,7 +256,7 @@ impl<'a> SessionTx<'a> {
}
let mut self_tuple_val = vec![
DataValue::from(0.0),
DataValue::Null,
DataValue::Bytes(hash.as_ref().to_vec()),
DataValue::from(false),
];
for current_level in max(target_level, bottom_level)..=0 {
@ -255,15 +288,6 @@ impl<'a> SessionTx<'a> {
self_tuple_key[0] = DataValue::from(current_level);
self_tuple_val[0] = DataValue::from(neighbours.len() as f64);
// save hash in self-loops
let mut hasher = Sha256::new();
for (_, Reverse(OrderedFloat(dist))) in neighbours.iter() {
let dist_bs = dist.to_be_bytes();
Digest::update(&mut hasher, &dist_bs);
}
let hash = hasher.finalize_fixed();
self_tuple_val[1] = DataValue::Bytes(hash.to_vec());
let self_tuple_key_bytes =
idx_table.encode_key_for_store(&self_tuple_key, Default::default())?;
let self_tuple_val_bytes =
@ -351,7 +375,7 @@ impl<'a> SessionTx<'a> {
} else {
// This is the first vector in the index.
let level = manifest.get_random_level();
self.hnsw_put_fresh_at_levels(tuple_key, idx, subidx, orig_table, idx_table, level, 0)?;
self.hnsw_put_fresh_at_levels(hash.as_ref(), tuple_key, idx, subidx, orig_table, idx_table, level, 0)?;
}
Ok(())
}
@ -641,6 +665,7 @@ impl<'a> SessionTx<'a> {
}
fn hnsw_put_fresh_at_levels(
&mut self,
hash: &[u8],
tuple: &[DataValue],
idx: usize,
subidx: i32,
@ -663,7 +688,7 @@ impl<'a> SessionTx<'a> {
}
let target_value = [
DataValue::from(0.0),
DataValue::Null,
DataValue::Bytes(hash.to_vec()),
DataValue::from(false),
];
let target_key_bytes = idx_table.encode_key_for_store(&target_key, Default::default())?;
@ -692,10 +717,11 @@ impl<'a> SessionTx<'a> {
manifest: &HnswIndexManifest,
orig_table: &RelationHandle,
idx_table: &RelationHandle,
filter: Option<(&[Bytecode], &mut Vec<DataValue>)>,
filter: &Option<Vec<Bytecode>>,
stack: &mut Vec<DataValue>,
tuple: &Tuple,
) -> Result<bool> {
if let Some((code, stack)) = filter {
if let Some(code) = filter {
if !eval_bytecode_pred(code, tuple, stack, Default::default())? {
return Ok(false);
}
@ -724,9 +750,136 @@ impl<'a> SessionTx<'a> {
/// Placeholder for removal from the HNSW index.
///
/// Not implemented yet: the body is `todo!()`, so calling this panics
/// instead of returning an error.
pub(crate) fn hnsw_remove(&mut self) -> Result<()> {
todo!()
}
pub(crate) fn hnsw_knn(&self) -> Result<()> {
/// Placeholder for removing a single vector from the HNSW index.
///
/// Not implemented yet: the body is `todo!()`, so calling this panics
/// instead of returning an error.
// NOTE(review): takes no arguments identifying the vector to remove, so the
// signature will presumably change once this is implemented.
pub(crate) fn hnsw_remove_vec(&mut self) -> Result<()> {
todo!()
}
/// Runs a k-nearest-neighbour query against an HNSW index.
///
/// # Arguments
/// * `q` - the query vector; it is converted to `manifest.dtype` when its
///   element type differs.
/// * `config` - query options (`k`, `ef`, optional `radius`, optional
///   `filter`, and which columns to emit).
/// * `idx_table` - the relation storing the index graph.
/// * `orig_table` - the indexed relation; it is only read back when the
///   vector, a filter, or a non-key column is needed.
/// * `manifest` - the index manifest (vector dimension and element type).
///
/// # Returns
/// One tuple per accepted neighbour, at most `config.k` of them. Each tuple
/// holds the columns selected by `config.orig_table_binding`, followed (when
/// enabled) by the field index, the sub-index, the distance and the vector,
/// in that order. Rows are emitted in the order the candidate queue yields
/// them (largest distance first). An empty index yields an empty result.
///
/// # Errors
/// Fails if the query vector's length differs from `manifest.vec_dim`, if a
/// lower-level search or table access fails, or if the index refers to data
/// that cannot be found in `orig_table` ("corrupted index").
pub(crate) fn hnsw_knn(
    &self,
    q: Vector,
    config: &HnswKnnQueryOptions,
    idx_table: &RelationHandle,
    orig_table: &RelationHandle,
    manifest: &HnswIndexManifest,
) -> Result<Vec<Tuple>> {
    if q.len() != manifest.vec_dim {
        bail!("query vector dimension mismatch");
    }
    // Bring the query vector to the element type used by the index.
    let q = match (q, manifest.dtype) {
        (v @ Vector::F32(_), VecElementType::F32) => v,
        (v @ Vector::F64(_), VecElementType::F64) => v,
        (Vector::F32(v), VecElementType::F64) => Vector::F64(v.mapv(|x| x as f64)),
        (Vector::F64(v), VecElementType::F32) => Vector::F32(v.mapv(|x| x as f32)),
    };

    // The first entry in the scanned range serves as the entry point of the
    // search; if there is none the index is empty.
    let ep_res = idx_table
        .scan_bounded_prefix(
            self,
            &[],
            &[DataValue::from(i64::MIN)],
            &[DataValue::from(1)],
        )
        .next();
    let ep = match ep_res {
        Some(ep) => ep?,
        None => return Ok(vec![]),
    };

    let n_keys = orig_table.metadata.keys.len();
    let bottom_level = ep[0].get_int().unwrap();
    let ep_key = ep[1..n_keys + 1].to_vec();
    let ep_idx = ep[n_keys + 1].get_int().unwrap() as usize;
    let ep_subidx = ep[n_keys + 2].get_int().unwrap() as i32;
    let ep_distance =
        self.hnsw_compare_vector(&q, &ep_key, ep_idx, ep_subidx, manifest, orig_table)?;

    let mut found_nn = PriorityQueue::new();
    found_nn.push((ep_key, ep_idx, ep_subidx), OrderedFloat(ep_distance));

    // Walk the levels below zero with a search width of 1, then search
    // level 0 with the configured `ef`.
    for current_level in bottom_level..0 {
        self.hnsw_search_level(
            &q,
            1,
            current_level,
            manifest,
            orig_table,
            idx_table,
            &mut found_nn,
        )?;
    }
    self.hnsw_search_level(
        &q,
        config.ef,
        0,
        manifest,
        orig_table,
        idx_table,
        &mut found_nn,
    )?;
    if found_nn.is_empty() {
        return Ok(vec![]);
    }

    // Without a filter every candidate is accepted, so the queue can be cut
    // down to `k` right away (`pop` discards the largest distance). With a
    // filter the cut happens after filtering, see below.
    if config.filter.is_none() {
        while found_nn.len() > config.k {
            found_nn.pop();
        }
    }

    // The index only carries the key columns; anything else requires
    // reading the full row back from the original table. An empty binding
    // list simply binds no non-key column (previously this panicked on
    // `max().unwrap()`).
    // FIXME this is wasteful
    let binds_non_key_column = config.orig_table_binding.iter().any(|&i| i >= n_keys);
    let needs_query_original_table =
        config.bind_vector || config.filter.is_some() || binds_non_key_column;

    let mut ret = vec![];
    let mut stack = vec![];
    let ret_tuple_len = config.return_tuple_len();
    while let Some(((mut cand_tuple, cand_idx, cand_subidx), OrderedFloat(distance))) =
        found_nn.pop()
    {
        if let Some(r) = config.radius {
            if distance > r {
                continue;
            }
        }
        if needs_query_original_table {
            cand_tuple = orig_table
                .get(self, &cand_tuple)?
                .ok_or_else(|| miette!("corrupted index"))?;
        }
        if let Some(code) = &config.filter {
            if !eval_bytecode_pred(code, &cand_tuple, &mut stack, Default::default())? {
                continue;
            }
        }
        let mut cur = Vec::with_capacity(ret_tuple_len);
        for i in &config.orig_table_binding {
            cur.push(cand_tuple[*i].clone());
        }
        if config.bind_field {
            cur.push(DataValue::from(cand_idx as i64));
        }
        if config.bind_field_idx {
            cur.push(DataValue::from(cand_subidx as i64));
        }
        if config.bind_distance {
            cur.push(DataValue::from(distance));
        }
        if config.bind_vector {
            // A negative sub-index can never address a list element, so it
            // must denote "the field itself is the vector"; a non-negative
            // one selects an element of a list-valued field. (The original
            // had the two branches swapped and indexed the list with a
            // negative value cast to `usize`, which always panics.)
            let vec = if cand_subidx < 0 {
                cand_tuple[cand_idx].clone()
            } else {
                match &cand_tuple[cand_idx] {
                    DataValue::List(v) => v
                        .get(cand_subidx as usize)
                        .cloned()
                        .ok_or_else(|| miette!("corrupted index"))?,
                    _ => bail!("corrupted index"),
                }
            };
            cur.push(vec);
        }
        ret.push(cur);
    }

    // Rows were collected largest-distance-first, so the `k` nearest are the
    // last `k`. This only removes anything when a filter was present, since
    // the unfiltered path already capped the queue at `k`.
    if ret.len() > config.k {
        ret.drain(..ret.len() - config.k);
    }
    Ok(ret)
}
}
#[cfg(test)]

@ -12,7 +12,8 @@ use std::sync::atomic::Ordering;
use itertools::Itertools;
use log::error;
use miette::{bail, ensure, Diagnostic, Result};
use miette::{bail, ensure, Diagnostic, Result, IntoDiagnostic};
use pest::Parser;
use rmp_serde::Serializer;
use serde::Serialize;
use smartstring::{LazyCompact, SmartString};
@ -24,10 +25,11 @@ use crate::data::symb::Symbol;
use crate::data::tuple::{decode_tuple_from_key, Tuple, TupleT, ENCODED_KEY_MIN_LEN};
use crate::data::value::{DataValue, ValidityTs};
use crate::parse::sys::HnswIndexConfig;
use crate::parse::SourceSpan;
use crate::parse::{CozoScriptParser, Rule, SourceSpan};
use crate::query::compile::IndexPositionUse;
use crate::runtime::transact::SessionTx;
use crate::{NamedRows, StoreTx};
use crate::parse::expr::build_expr;
use crate::runtime::hnsw::HnswIndexManifest;
#[derive(
@ -124,6 +126,16 @@ struct StoredRelArityMismatch {
}
impl RelationHandle {
/// Builds a map from column name (as a `Symbol`) to the column's position
/// in a stored tuple.
///
/// Key columns occupy positions `0..keys.len()`; non-key columns follow
/// directly after them. Should a name occur more than once, the later
/// column's position is the one that remains in the map.
pub(crate) fn raw_binding_map(&self) -> BTreeMap<Symbol, usize> {
    let key_count = self.metadata.keys.len();
    let mut bindings = BTreeMap::new();
    bindings.extend(
        self.metadata
            .keys
            .iter()
            .enumerate()
            .map(|(pos, col)| (Symbol::new(col.name.clone(), Default::default()), pos)),
    );
    bindings.extend(self.metadata.non_keys.iter().enumerate().map(|(pos, col)| {
        (
            Symbol::new(col.name.clone(), Default::default()),
            key_count + pos,
        )
    }));
    bindings
}
pub(crate) fn has_triggers(&self) -> bool {
!self.put_triggers.is_empty() || !self.rm_triggers.is_empty()
}
@ -816,8 +828,18 @@ impl<'a> SessionTx<'a> {
// populate index
let all_tuples = rel_handle.scan_all(self).collect::<Result<Vec<_>>>()?;
let filter = if let Some(f_code) = &manifest.index_filter {
let parsed = CozoScriptParser::parse(Rule::expr, f_code).into_diagnostic()?.next().unwrap();
let mut code_expr = build_expr(parsed, &Default::default())?;
let binding_map = rel_handle.raw_binding_map();
code_expr.fill_binding_indices(&binding_map)?;
Some(code_expr.compile())
} else {
None
};
let mut stack = vec![];
for tuple in all_tuples {
self.hnsw_put(&manifest, &rel_handle, &idx_handle, None, &tuple)?;
self.hnsw_put(&manifest, &rel_handle, &idx_handle, &filter, &mut stack, &tuple)?;
}
rel_handle

Loading…
Cancel
Save