insertion now compiles

main
Ziyang Hu 1 year ago
parent 9462409b90
commit 02c3358362

@ -8,7 +8,7 @@
use crate::data::expr::{eval_bytecode_pred, Bytecode};
use crate::data::relation::VecElementType;
use crate::data::tuple::{decode_tuple_from_key, Tuple};
use crate::data::tuple::{Tuple, ENCODED_KEY_MIN_LEN};
use crate::data::value::Vector;
use crate::parse::sys::HnswDistance;
use crate::runtime::relation::RelationHandle;
@ -18,6 +18,8 @@ use miette::{bail, Result};
use ordered_float::OrderedFloat;
use priority_queue::PriorityQueue;
use rand::Rng;
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
use smartstring::{LazyCompact, SmartString};
use std::cmp::{max, Reverse};
use std::collections::BTreeSet;
@ -165,23 +167,25 @@ impl<'a> SessionTx<'a> {
orig_table: &RelationHandle,
idx_table: &RelationHandle,
) -> Result<()> {
let start_tuple =
idx_table.encode_key_for_store(&vec![DataValue::from(i64::MIN)], Default::default())?;
let end_tuple =
idx_table.encode_key_for_store(&vec![DataValue::from(1)], Default::default())?;
let ep_res = self.store_tx.range_scan(&start_tuple, &end_tuple).next();
// TODO check if this is an update!
let tuple_key = &tuple[..orig_table.metadata.keys.len()];
let ep_res = idx_table
.scan_bounded_prefix(
self,
&[],
&[DataValue::from(i64::MIN)],
&[DataValue::from(1)],
)
.next();
if let Some(ep) = ep_res {
let (ep_key_bytes, _) = ep?;
let ep_key_tuple = decode_tuple_from_key(&ep_key_bytes);
let ep = ep?;
// bottom level since we are going up
let bottom_level = ep_key_tuple[0].get_int().unwrap();
let ep_key = ep_key_tuple[1..orig_table.metadata.keys.len() + 1].to_vec();
let ep_idx = ep_key_tuple[orig_table.metadata.keys.len() + 1]
.get_int()
.unwrap() as usize;
let ep_subidx = ep_key_tuple[orig_table.metadata.keys.len() + 2]
.get_int()
.unwrap() as i32;
let bottom_level = ep[0].get_int().unwrap();
let ep_key = ep[1..orig_table.metadata.keys.len() + 1].to_vec();
let ep_idx = ep[orig_table.metadata.keys.len() + 1].get_int().unwrap() as usize;
let ep_subidx = ep[orig_table.metadata.keys.len() + 2].get_int().unwrap() as i32;
let ep_distance =
self.hnsw_compare_vector(q, &ep_key, idx, subidx, manifest, orig_table)?;
let mut found_nn = PriorityQueue::new();
@ -190,7 +194,7 @@ impl<'a> SessionTx<'a> {
if target_level < bottom_level {
// this becomes the entry point
self.hnsw_put_fresh_at_levels(
tuple,
tuple_key,
idx,
subidx,
orig_table,
@ -210,7 +214,24 @@ impl<'a> SessionTx<'a> {
&mut found_nn,
)?;
}
let mut self_tuple_key = Vec::with_capacity(orig_table.metadata.keys.len() * 2 + 5);
self_tuple_key.push(DataValue::from(0));
for _ in 0..2 {
self_tuple_key.extend_from_slice(tuple_key);
self_tuple_key.push(DataValue::from(idx as i64));
self_tuple_key.push(DataValue::from(subidx as i64));
}
let mut self_tuple_val = vec![
DataValue::from(0.0),
DataValue::Null,
DataValue::from(false),
];
for current_level in max(target_level, bottom_level)..=0 {
let m_max = if current_level == 0 {
manifest.m_max0
} else {
manifest.m_max
};
self.hnsw_search_level(
q,
manifest.ef_construction,
@ -221,15 +242,220 @@ impl<'a> SessionTx<'a> {
&mut found_nn,
)?;
// add bidirectional links to the nearest neighbors
todo!();
// shrink links if necessary
todo!();
let neighbours = self.hnsw_select_neighbours_heuristic(
q,
&found_nn,
m_max,
current_level,
manifest,
idx_table,
orig_table,
)?;
// add self-link
self_tuple_key[0] = DataValue::from(current_level);
self_tuple_val[0] = DataValue::from(neighbours.len() as f64);
// save hash in self-loops
let mut hasher = Sha256::new();
for (_, Reverse(OrderedFloat(dist))) in neighbours.iter() {
let dist_bs = dist.to_be_bytes();
Digest::update(&mut hasher, &dist_bs);
}
let hash = hasher.finalize_fixed();
self_tuple_val[1] = DataValue::Bytes(hash.to_vec());
let self_tuple_key_bytes =
idx_table.encode_key_for_store(&self_tuple_key, Default::default())?;
let self_tuple_val_bytes =
idx_table.encode_val_only_for_store(&self_tuple_val, Default::default())?;
self.store_tx
.put(&self_tuple_key_bytes, &self_tuple_val_bytes)?;
// add bidirectional links
for (neighbour, Reverse(OrderedFloat(dist))) in neighbours.iter() {
let mut out_key = Vec::with_capacity(orig_table.metadata.keys.len() * 2 + 5);
let out_val = vec![
DataValue::from(*dist),
DataValue::Null,
DataValue::from(false),
];
out_key.push(DataValue::from(current_level));
out_key.extend_from_slice(tuple_key);
out_key.push(DataValue::from(idx as i64));
out_key.push(DataValue::from(subidx as i64));
out_key.extend_from_slice(&neighbour.0);
out_key.push(DataValue::from(neighbour.1 as i64));
out_key.push(DataValue::from(neighbour.2 as i64));
let out_key_bytes =
idx_table.encode_key_for_store(&out_key, Default::default())?;
let out_val_bytes =
idx_table.encode_val_only_for_store(&out_val, Default::default())?;
// println!("tuple: {:?}", tuple_key);
// println!("out_key: {:?}", out_key);
self.store_tx.put(&out_key_bytes, &out_val_bytes)?;
let mut in_key = Vec::with_capacity(orig_table.metadata.keys.len() * 2 + 5);
let in_val = vec![
DataValue::from(*dist),
DataValue::Null,
DataValue::from(false),
];
in_key.push(DataValue::from(current_level));
in_key.extend_from_slice(&neighbour.0);
in_key.push(DataValue::from(neighbour.1 as i64));
in_key.push(DataValue::from(neighbour.2 as i64));
in_key.extend_from_slice(tuple_key);
in_key.push(DataValue::from(idx as i64));
in_key.push(DataValue::from(subidx as i64));
// println!("in_key: {:?}", in_key);
let in_key_bytes =
idx_table.encode_key_for_store(&in_key, Default::default())?;
let in_val_bytes =
idx_table.encode_val_only_for_store(&in_val, Default::default())?;
self.store_tx.put(&in_key_bytes, &in_val_bytes)?;
// shrink links if necessary
let mut target_self_key =
Vec::with_capacity(orig_table.metadata.keys.len() * 2 + 5);
target_self_key.push(DataValue::from(current_level));
for _ in 0..2 {
target_self_key.extend_from_slice(&neighbour.0);
target_self_key.push(DataValue::from(neighbour.1 as i64));
target_self_key.push(DataValue::from(neighbour.2 as i64));
}
let target_self_key_bytes =
idx_table.encode_key_for_store(&target_self_key, Default::default())?;
let target_self_val_bytes = match self.store_tx.get(&target_self_key_bytes, false)? {
Some(bytes) => bytes,
None => bail!("Indexed vector not found, this signifies a bug in the index implementation"),
};
let target_self_val: Vec<DataValue> =
rmp_serde::from_slice(&target_self_val_bytes[ENCODED_KEY_MIN_LEN..])
.unwrap();
let target_degree = target_self_val[0].get_float().unwrap() as usize;
if target_degree > m_max {
// shrink links
self.hnsw_shrink_neighbour(
&neighbour.0,
neighbour.1,
neighbour.2,
m_max,
current_level,
manifest,
idx_table,
orig_table,
)?;
}
}
}
} else {
// This is the first vector in the index.
let level = manifest.get_random_level();
self.hnsw_put_fresh_at_levels(tuple, idx, subidx, orig_table, idx_table, level, 0)?;
self.hnsw_put_fresh_at_levels(tuple_key, idx, subidx, orig_table, idx_table, level, 0)?;
}
Ok(())
}
/// Re-selects the outgoing links of `target` on `level` so that at most `m`
/// of them remain selected, and persists the difference to the index table.
///
/// Steps, as performed below:
/// 1. Load the row stored under `target` in `orig_table` and extract the
///    indexed vector at (`idx`, `sub_idx`) through the manifest.
/// 2. Collect the current neighbours reported by `hnsw_get_neighbours`.
/// 3. Run `hnsw_select_neighbours_heuristic` over them with limit `m`.
/// 4. Write a link row for every selected neighbour that was not in the old
///    set, with value `[distance, Null, false]`.
/// 5. For every old neighbour that was not re-selected: delete its link row
///    if the stored flag (third value) is already `true`, otherwise rewrite
///    it with the flag set to `true`.
///
/// NOTE(review): the third value presumably corresponds to the `ignore_link`
/// column of the index relation — confirm against the schema.
///
/// # Errors
/// Fails if the base row or an expected link row is missing from the store,
/// or if encoding / store access fails.
///
/// # Panics
/// Panics if a stored link value cannot be decoded or its flag is not a bool.
fn hnsw_shrink_neighbour(
    &mut self,
    target: &[DataValue],
    idx: usize,
    sub_idx: i32,
    m: usize,
    level: i64,
    manifest: &HnswIndexManifest,
    idx_table: &RelationHandle,
    orig_table: &RelationHandle,
) -> Result<()> {
    // Fetch the base row so the vector of `target` can be extracted.
    let base_key_bytes = orig_table.encode_key_for_store(target, Default::default())?;
    let base_val_bytes = match self.store_tx.get(&base_key_bytes, false)? {
        None => {
            bail!("Indexed vector not found, this signifies a bug in the index implementation")
        }
        Some(found) => found,
    };
    let base_tuple = decode_tuple_from_kv(&base_key_bytes, &base_val_bytes);
    let target_vec = manifest.get_vector(&base_tuple, idx, sub_idx)?;

    // Current neighbours of `target`, prioritised by their stored distance.
    let mut candidates = PriorityQueue::new();
    for (n_key, n_idx, n_sub_idx, n_dist) in
        self.hnsw_get_neighbours(target.to_vec(), idx, sub_idx, level, idx_table)?
    {
        candidates.push((n_key, n_idx, n_sub_idx), OrderedFloat(n_dist));
    }

    let new_candidates = self.hnsw_select_neighbours_heuristic(
        &target_vec,
        &candidates,
        m,
        level,
        manifest,
        idx_table,
        orig_table,
    )?;

    // Membership sets used to compute the difference between old and new links.
    let old_candidate_set: BTreeSet<_> =
        candidates.iter().map(|(item, _)| item.clone()).collect();
    let new_candidate_set: BTreeSet<_> = new_candidates
        .iter()
        .map(|(item, _)| item.clone())
        .collect();

    // Builds the index key of the link `target -> other` on `level`:
    // [level, target.., idx, sub_idx, other.., other_idx, other_sub_idx].
    let key_capacity = orig_table.metadata.keys.len() * 2 + 5;
    let link_key = |other_key: &[DataValue], other_idx: usize, other_sub_idx: i32| {
        let mut key = Vec::with_capacity(key_capacity);
        key.push(DataValue::from(level));
        key.extend_from_slice(target);
        key.push(DataValue::from(idx as i64));
        key.push(DataValue::from(sub_idx as i64));
        key.extend_from_slice(other_key);
        key.push(DataValue::from(other_idx as i64));
        key.push(DataValue::from(other_sub_idx as i64));
        key
    };

    // Links that were selected but did not exist before.
    for (added, Reverse(OrderedFloat(added_dist))) in new_candidates {
        if old_candidate_set.contains(&added) {
            continue;
        }
        let key = link_key(&added.0, added.1, added.2);
        let val = vec![
            DataValue::from(added_dist),
            DataValue::Null,
            DataValue::from(false),
        ];
        let key_bytes = idx_table.encode_key_for_store(&key, Default::default())?;
        let val_bytes = idx_table.encode_val_only_for_store(&val, Default::default())?;
        self.store_tx.put(&key_bytes, &val_bytes)?;
    }

    // Links that existed before but were not re-selected.
    for (removed, OrderedFloat(removed_dist)) in candidates {
        if new_candidate_set.contains(&removed) {
            continue;
        }
        let key = link_key(&removed.0, removed.1, removed.2);
        let key_bytes = idx_table.encode_key_for_store(&key, Default::default())?;
        let stored_bytes = match self.store_tx.get(&key_bytes, false)? {
            None => {
                bail!("Indexed vector not found, this signifies a bug in the index implementation")
            }
            Some(found) => found,
        };
        let stored: Vec<DataValue> =
            rmp_serde::from_slice(&stored_bytes[ENCODED_KEY_MIN_LEN..]).unwrap();
        let already_flagged = stored[2].get_bool().unwrap();
        if already_flagged {
            // Flag was already set: remove the link row entirely.
            self.store_tx.del(&key_bytes)?;
        } else {
            // First time this link is dropped: keep the row, set the flag.
            let val = vec![
                DataValue::from(removed_dist),
                DataValue::Null,
                DataValue::from(true),
            ];
            let val_bytes = idx_table.encode_val_only_for_store(&val, Default::default())?;
            self.store_tx.put(&key_bytes, &val_bytes)?;
        }
    }
    Ok(())
}
fn hnsw_compare_vector(
@ -249,8 +475,67 @@ impl<'a> SessionTx<'a> {
let target_tuple = decode_tuple_from_kv(&target_key_bytes, &bytes);
manifest.get_distance(q, &target_tuple, target_idx, target_subidx)
}
fn hnsw_select_neighbours_heuristic(&self) -> Result<()> {
todo!()
fn hnsw_select_neighbours_heuristic(
&self,
q: &Vector,
found: &PriorityQueue<(Tuple, usize, i32), OrderedFloat<f64>>,
m: usize,
level: i64,
manifest: &HnswIndexManifest,
idx_table: &RelationHandle,
orig_table: &RelationHandle,
) -> Result<PriorityQueue<(Tuple, usize, i32), Reverse<OrderedFloat<f64>>>> {
let mut candidates = PriorityQueue::new();
let mut ret: PriorityQueue<_, Reverse<OrderedFloat<_>>> = PriorityQueue::new();
let mut discarded: PriorityQueue<_, Reverse<OrderedFloat<_>>> = PriorityQueue::new();
for (item, dist) in found.iter() {
// Add to candidates
candidates.push(item.clone(), Reverse(*dist));
}
if manifest.extend_candidates {
for (item, _) in found.iter() {
// Extend by neighbours
for neighbour in
self.hnsw_get_neighbours(item.0.clone(), item.1, item.2, level, idx_table)?
{
let dist = self.hnsw_compare_vector(
q,
&neighbour.0,
neighbour.1,
neighbour.2,
manifest,
orig_table,
)?;
candidates.push(
(neighbour.0, neighbour.1, neighbour.2),
Reverse(OrderedFloat(dist)),
);
}
}
}
while !candidates.is_empty() && ret.len() < m {
let (nearest_triple, Reverse(OrderedFloat(nearest_dist))) = candidates.pop().unwrap();
match ret.peek() {
Some((_, Reverse(OrderedFloat(dist)))) => {
if nearest_dist < *dist {
ret.push(nearest_triple, Reverse(OrderedFloat(nearest_dist)));
} else if manifest.keep_pruned_connections {
discarded.push(nearest_triple, Reverse(OrderedFloat(nearest_dist)));
}
}
None => {
ret.push(nearest_triple, Reverse(OrderedFloat(nearest_dist)));
}
}
}
if manifest.keep_pruned_connections {
while !discarded.is_empty() && ret.len() < m {
let (nearest_triple, Reverse(OrderedFloat(nearest_dist))) =
discarded.pop().unwrap();
ret.push(nearest_triple, Reverse(OrderedFloat(nearest_dist)));
}
}
Ok(ret)
}
fn hnsw_search_level(
&self,
@ -278,13 +563,14 @@ impl<'a> SessionTx<'a> {
break;
}
// loop over each of the candidate's neighbors
for neighbour_triple in self.hnsw_get_neighbours(
for neighbour_tetra in self.hnsw_get_neighbours(
candidate.0,
candidate.1,
candidate.2,
cur_level,
idx_table,
)? {
let neighbour_triple = (neighbour_tetra.0, neighbour_tetra.1, neighbour_tetra.2);
if visited.contains(&neighbour_triple) {
continue;
}
@ -320,36 +606,42 @@ impl<'a> SessionTx<'a> {
cand_sub_idx: i32,
level: i64,
idx_handle: &RelationHandle,
) -> Result<impl Iterator<Item = (Tuple, usize, i32)> + 'b> {
) -> Result<impl Iterator<Item = (Tuple, usize, i32, f64)> + 'b> {
let mut start_tuple = Vec::with_capacity(cand_key.len() + 3);
start_tuple.push(DataValue::from(level));
start_tuple.extend_from_slice(&cand_key);
start_tuple.push(DataValue::from(cand_idx as i64));
start_tuple.push(DataValue::from(cand_sub_idx as i64));
let mut end_tuple = start_tuple.clone();
end_tuple.push(DataValue::Bot);
let start_bytes = idx_handle.encode_key_for_store(&start_tuple, Default::default())?;
let end_bytes = idx_handle.encode_key_for_store(&end_tuple, Default::default())?;
Ok(self
.store_tx
.range_scan(&start_bytes, &end_bytes)
let key_len = cand_key.len();
Ok(idx_handle
.scan_prefix(self, &start_tuple)
.filter_map(move |res| {
let (key, _value) = res.unwrap();
let key_tuple = decode_tuple_from_key(&key);
let key_total_len = key_tuple.len();
let key_idx = key_tuple[key_total_len - 2].get_int().unwrap() as usize;
let key_subidx = key_tuple[key_total_len - 1].get_int().unwrap() as i32;
let key_slice = key_tuple[cand_key.len() + 3..key_total_len - 2].to_vec();
let tuple = res.unwrap();
// println!("tuple: {:?}", tuple);
// println!("key_len: {}", key_len);
let key_idx = tuple[2 * key_len + 3].get_int().unwrap() as usize;
let key_subidx = tuple[2 * key_len + 4].get_int().unwrap() as i32;
let key_slice = tuple[key_len + 3..2 * key_len + 3].to_vec();
if key_slice == cand_key {
None
} else {
Some((key_slice, key_idx, key_subidx))
let is_deleted = tuple[2 * key_len + 7].get_bool().unwrap();
if is_deleted {
None
} else {
Some((
key_slice,
key_idx,
key_subidx,
tuple[2 * key_len + 5].get_float().unwrap(),
))
}
}
}))
}
fn hnsw_put_fresh_at_levels(
&mut self,
tuple: &Tuple,
tuple: &[DataValue],
idx: usize,
subidx: i32,
orig_table: &RelationHandle,
@ -369,29 +661,34 @@ impl<'a> SessionTx<'a> {
canary_key.push(DataValue::Null);
canary_key.push(DataValue::Null);
}
let target_value = [DataValue::from(0.0), DataValue::Null];
let target_value = [
DataValue::from(0.0),
DataValue::Null,
DataValue::from(false),
];
let target_key_bytes = idx_table.encode_key_for_store(&target_key, Default::default())?;
// canary value is for conflict detection: prevent the scenario of disconnected graphs at all levels
let canary_value = [
DataValue::from(bottom_level),
DataValue::Bytes(target_key_bytes),
DataValue::from(false),
];
let canary_key_bytes = idx_table.encode_key_for_store(&canary_key, Default::default())?;
let canary_value_bytes =
idx_table.encode_val_for_store(&canary_value, Default::default())?;
idx_table.encode_val_only_for_store(&canary_value, Default::default())?;
self.store_tx.put(&canary_key_bytes, &canary_value_bytes)?;
for cur_level in bottom_level..=top_level {
target_key[0] = DataValue::from(cur_level);
let key = idx_table.encode_key_for_store(&target_key, Default::default())?;
let val = idx_table.encode_val_for_store(&target_value, Default::default())?;
let val = idx_table.encode_val_only_for_store(&target_value, Default::default())?;
self.store_tx.put(&key, &val)?;
}
Ok(())
}
pub(crate) fn hnsw_put(
&'a mut self,
&mut self,
manifest: &HnswIndexManifest,
orig_table: &RelationHandle,
idx_table: &RelationHandle,

@ -241,7 +241,7 @@ impl RelationHandle {
}
pub(crate) fn encode_val_only_for_store(
&self,
tuple: &Tuple,
tuple: &[DataValue],
_span: SourceSpan,
) -> Result<Vec<u8>> {
let mut ret = self.encode_key_prefix(tuple.len());
@ -407,13 +407,13 @@ impl RelationHandle {
pub(crate) fn scan_bounded_prefix<'a>(
&self,
tx: &'a SessionTx<'_>,
prefix: &Tuple,
prefix: &[DataValue],
lower: &[DataValue],
upper: &[DataValue],
) -> impl Iterator<Item = Result<Tuple>> + 'a {
let mut lower_t = prefix.clone();
let mut lower_t = prefix.to_vec();
lower_t.extend_from_slice(lower);
let mut upper_t = prefix.clone();
let mut upper_t = prefix.to_vec();
upper_t.extend_from_slice(upper);
upper_t.push(DataValue::Bot);
let lower_encoded = lower_t.encode_as_key(self.id);
@ -763,6 +763,14 @@ impl<'a> SessionTx<'a> {
},
default_gen: None,
},
ColumnDef {
name: SmartString::from("ignore_link"),
typing: NullableColType {
coltype: ColType::Bool,
nullable: false,
},
default_gen: None,
}
];
// create index relation
let key_bindings = idx_keys
@ -788,9 +796,6 @@ impl<'a> SessionTx<'a> {
};
let idx_handle = self.create_relation(idx_handle)?;
// populate index
// TODO
// add index to relation
let manifest = HnswIndexManifest {
base_relation: config.base_relation.clone(),
@ -808,6 +813,13 @@ impl<'a> SessionTx<'a> {
extend_candidates: config.extend_candidates,
keep_pruned_connections: config.keep_pruned_connections,
};
// populate index
let all_tuples = rel_handle.scan_all(self).collect::<Result<Vec<_>>>()?;
for tuple in all_tuples {
self.hnsw_put(&manifest, &rel_handle, &idx_handle, None, &tuple)?;
}
rel_handle
.hnsw_indices
.insert(config.index_name.clone(), (idx_handle, manifest));

@ -782,7 +782,14 @@ fn test_vec_types() {
#[test]
fn test_vec_index() {
let db = DbInstance::new("mem", "", "").unwrap();
db.run_script(":create a {k: String => tags: [String], v: <F32; 8>}", Default::default())
db.run_script(r"
?[k, v] <- [['a', [1,2,3,4,5,6,7,8]],
['b', [2,3,4,5,6,7,8,9]],
['bb', [2,3,4,5,6,7,8,9]],
['c', [2,3,4,5,6,7,8,19]]]
:create a {k: String => v: <F32; 8>}
", Default::default())
.unwrap();
db.run_script(r"
::hnsw create a:vec {

Loading…
Cancel
Save