pre-insertions for HNSW

main
Ziyang Hu 1 year ago
parent 6ccdf71892
commit 9462409b90

44
Cargo.lock generated

@ -353,6 +353,15 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "block-buffer"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4"
dependencies = [
"generic-array",
]
[[package]]
name = "block-buffer"
version = "0.10.4"
@ -705,6 +714,7 @@ dependencies = [
"serde_bytes",
"serde_derive",
"serde_json",
"sha2 0.9.9",
"sled",
"smallvec",
"smartstring",
@ -1034,13 +1044,22 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "digest"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066"
dependencies = [
"generic-array",
]
[[package]]
name = "digest"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f"
dependencies = [
"block-buffer",
"block-buffer 0.10.4",
"crypto-common",
]
@ -2346,6 +2365,12 @@ version = "1.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3"
[[package]]
name = "opaque-debug"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
[[package]]
name = "openssl"
version = "0.10.50"
@ -2547,7 +2572,7 @@ checksum = "6733073c7cff3d8459fda0e42f13a047870242aed8b509fe98000928975f359e"
dependencies = [
"once_cell",
"pest",
"sha2",
"sha2 0.10.6",
]
[[package]]
@ -3413,6 +3438,19 @@ dependencies = [
"serde",
]
[[package]]
name = "sha2"
version = "0.9.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800"
dependencies = [
"block-buffer 0.9.0",
"cfg-if 1.0.0",
"cpufeatures",
"digest 0.9.0",
"opaque-debug",
]
[[package]]
name = "sha2"
version = "0.10.6"
@ -3421,7 +3459,7 @@ checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0"
dependencies = [
"cfg-if 1.0.0",
"cpufeatures",
"digest",
"digest 0.10.6",
]
[[package]]

@ -128,3 +128,4 @@ js-sys = { version = "0.3.60", optional = true }
graph = { version = "0.3.0", optional = true }
crossbeam = "0.8.2"
ndarray = { version = "0.15.6", features = ["serde"] }
sha2 = "0.9.8"

@ -156,7 +156,7 @@ pub enum DataValue {
Bot,
}
#[derive(Clone, serde_derive::Serialize, serde_derive::Deserialize)]
#[derive(Debug, Clone, serde_derive::Serialize, serde_derive::Deserialize)]
pub enum Vector {
F32(Array1<f32>),
F64(Array1<f64>),

@ -49,28 +49,14 @@ pub(crate) struct HnswIndexConfig {
pub(crate) vec_dim: usize,
pub(crate) dtype: VecElementType,
pub(crate) vec_fields: Vec<SmartString<LazyCompact>>,
pub(crate) tag_fields: Vec<SmartString<LazyCompact>>,
pub(crate) distance: HnswDistance,
pub(crate) ef_construction: usize,
pub(crate) max_elements: usize,
pub(crate) m_neighbours: usize,
pub(crate) index_filter: Option<String>,
pub(crate) extend_candidates: bool,
pub(crate) keep_pruned_connections: bool
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize)]
pub(crate) struct HnswIndexManifest {
pub(crate) base_relation: SmartString<LazyCompact>,
pub(crate) index_name: SmartString<LazyCompact>,
pub(crate) vec_dim: usize,
pub(crate) dtype: VecElementType,
pub(crate) vec_fields: Vec<usize>,
pub(crate) tag_fields: Vec<usize>,
pub(crate) distance: HnswDistance,
pub(crate) ef_construction: usize,
pub(crate) max_elements: usize,
pub(crate) index_filter: Option<String>,
}
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize,
)]
@ -80,17 +66,6 @@ pub(crate) enum HnswDistance {
Cosine,
}
pub(crate) struct HnswKnnQueryOptions {
k: usize,
ef: usize,
max_distance: f64,
min_margin: f64,
auto_margin_factor: Option<f64>,
bind_field: Option<Symbol>,
bind_distance: Option<Symbol>,
bind_vector: Option<Symbol>,
}
#[derive(Debug, Diagnostic, Error)]
#[error("Cannot interpret {0} as process ID")]
#[diagnostic(code(parser::not_proc_id))]
@ -211,13 +186,13 @@ pub(crate) fn parse_sys(
let mut vec_dim = 0;
let mut dtype = VecElementType::F32;
let mut vec_fields = vec![];
let mut tag_fields = vec![];
let mut distance = HnswDistance::L2;
let mut ef_construction = 0;
let mut max_elements = 0;
let mut index_filter = None;
let mut extend_candidates = false;
let mut keep_pruned_connections = false;
// TODO this is a bit of a mess
for opt_pair in inner {
let mut opt_inner = opt_pair.into_inner();
let opt_name = opt_inner.next().unwrap();
@ -242,11 +217,7 @@ pub(crate) fn parse_sys(
let fields = build_expr(opt_val, &Default::default())?;
vec_fields = fields.to_var_list()?;
}
"tags" => {
let fields = build_expr(opt_val, &Default::default())?;
tag_fields = fields.to_var_list()?;
}
"distance" => {
"distance" | "dist" => {
distance = match opt_val.as_str() {
"L2" => HnswDistance::L2,
"IP" => HnswDistance::InnerProduct,
@ -259,13 +230,13 @@ pub(crate) fn parse_sys(
}
}
}
"ef_construction" => {
"ef_construction" | "ef" => {
ef_construction = opt_val
.as_str()
.parse()
.map_err(|e| miette!("Invalid ef_construction: {}", e))?;
}
"max_elements" => {
"m_neighbours" | "m" | "M" => {
max_elements = opt_val
.as_str()
.parse()
@ -274,6 +245,12 @@ pub(crate) fn parse_sys(
"filter" => {
index_filter = Some(opt_val.as_str().to_string());
}
"extend_candidates" => {
extend_candidates = opt_val.as_str() == "true";
}
"keep_pruned_connections" => {
keep_pruned_connections = opt_val.as_str() == "true";
}
_ => return Err(miette!("Invalid option: {}", opt_name.as_str())),
}
}
@ -283,11 +260,12 @@ pub(crate) fn parse_sys(
vec_dim,
dtype,
vec_fields,
tag_fields,
distance,
ef_construction,
max_elements,
m_neighbours: max_elements,
index_filter,
extend_candidates,
keep_pruned_connections,
})
}
Rule::index_drop => {

@ -153,7 +153,7 @@ impl<'a> SessionTx<'a> {
let mut old_tuples: Vec<DataValue> = vec![];
for tuple in res_iter {
let extracted = key_extractors
let extracted: Vec<DataValue> = key_extractors
.iter()
.map(|ex| ex.extract_data(&tuple, cur_vld))
.try_collect()?;
@ -309,7 +309,7 @@ impl<'a> SessionTx<'a> {
key_extractors.extend(val_extractors);
for tuple in res_iter {
let extracted = key_extractors
let extracted: Vec<DataValue> = key_extractors
.iter()
.map(|ex| ex.extract_data(&tuple, cur_vld))
.try_collect()?;
@ -360,7 +360,7 @@ impl<'a> SessionTx<'a> {
)?;
for tuple in res_iter {
let extracted = key_extractors
let extracted: Vec<DataValue> = key_extractors
.iter()
.map(|ex| ex.extract_data(&tuple, cur_vld))
.try_collect()?;
@ -411,7 +411,7 @@ impl<'a> SessionTx<'a> {
key_extractors.extend(val_extractors);
for tuple in res_iter {
let extracted = key_extractors
let extracted: Vec<DataValue> = key_extractors
.iter()
.map(|ex| ex.extract_data(&tuple, cur_vld))
.try_collect()?;

@ -7,30 +7,392 @@
*/
use crate::data::expr::{eval_bytecode_pred, Bytecode};
use crate::data::tuple::Tuple;
use crate::data::relation::VecElementType;
use crate::data::tuple::{decode_tuple_from_key, Tuple};
use crate::data::value::Vector;
use crate::parse::sys::HnswIndexManifest;
use crate::parse::sys::HnswDistance;
use crate::runtime::relation::RelationHandle;
use crate::runtime::transact::SessionTx;
use crate::DataValue;
use miette::Result;
use crate::{decode_tuple_from_kv, DataValue, Symbol};
use miette::{bail, Result};
use ordered_float::OrderedFloat;
use priority_queue::PriorityQueue;
use rand::Rng;
use smartstring::{LazyCompact, SmartString};
use std::cmp::{max, Reverse};
use std::collections::BTreeSet;
#[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)]
pub(crate) struct HnswIndexManifest {
pub(crate) base_relation: SmartString<LazyCompact>,
pub(crate) index_name: SmartString<LazyCompact>,
pub(crate) vec_dim: usize,
pub(crate) dtype: VecElementType,
pub(crate) vec_fields: Vec<usize>,
pub(crate) distance: HnswDistance,
pub(crate) ef_construction: usize,
pub(crate) m_neighbours: usize,
pub(crate) m_max: usize,
pub(crate) m_max0: usize,
pub(crate) level_multiplier: f64,
pub(crate) index_filter: Option<String>,
pub(crate) extend_candidates: bool,
pub(crate) keep_pruned_connections: bool,
}
pub(crate) struct HnswKnnQueryOptions {
k: usize,
ef: usize,
max_distance: f64,
min_margin: f64,
auto_margin_factor: Option<f64>,
bind_field: Option<Symbol>,
bind_distance: Option<Symbol>,
bind_vector: Option<Symbol>,
}
impl HnswIndexManifest {
fn get_random_level(&self) -> i64 {
let mut rng = rand::thread_rng();
let uniform_num: f64 = rng.gen_range(0.0..1.0);
let r = -uniform_num.ln() * self.level_multiplier;
// the level is the largest integer smaller than r
-(r.floor() as i64)
}
fn get_vector(&self, tuple: &Tuple, idx: usize, sub_idx: i32) -> Result<Vector> {
let field = tuple.get(idx).unwrap();
if sub_idx >= 0 {
match field {
DataValue::List(l) => match l.get(sub_idx as usize) {
Some(DataValue::Vec(v)) => Ok(v.clone()),
_ => bail!(
"Cannot extract vector from {} for sub index {}",
field,
sub_idx
),
},
_ => bail!("Cannot interpret {} as list", field),
}
} else {
match field {
DataValue::Vec(v) => Ok(v.clone()),
_ => bail!("Cannot interpret {} as vector", field),
}
}
}
fn get_distance(&self, q: &Vector, tuple: &Tuple, idx: usize, sub_idx: i32) -> Result<f64> {
let field = tuple.get(idx).unwrap();
let target = if sub_idx >= 0 {
match field {
DataValue::List(l) => match l.get(sub_idx as usize) {
Some(DataValue::Vec(v)) => v,
_ => bail!(
"Cannot extract vector from {} for sub index {}",
field,
sub_idx
),
},
_ => bail!("Cannot interpret {} as list", field),
}
} else {
match field {
DataValue::Vec(v) => v,
_ => bail!("Cannot interpret {} as vector", field),
}
};
Ok(match self.distance {
HnswDistance::L2 => match (q, target) {
(Vector::F32(a), Vector::F32(b)) => {
let diff = a - b;
diff.dot(&diff) as f64
}
(Vector::F64(a), Vector::F64(b)) => {
let diff = a - b;
diff.dot(&diff)
}
_ => bail!(
"Cannot compute L2 distance between {:?} and {:?}",
q,
target
),
},
HnswDistance::Cosine => match (q, target) {
(Vector::F32(a), Vector::F32(b)) => {
let a_norm = a.dot(a) as f64;
let b_norm = b.dot(b) as f64;
let dot = a.dot(b) as f64;
1.0 - dot / (a_norm * b_norm).sqrt()
}
(Vector::F64(a), Vector::F64(b)) => {
let a_norm = a.dot(a) as f64;
let b_norm = b.dot(b) as f64;
let dot = a.dot(b);
1.0 - dot / (a_norm * b_norm).sqrt()
}
_ => bail!(
"Cannot compute cosine distance between {:?} and {:?}",
q,
target
),
},
HnswDistance::InnerProduct => match (q, target) {
(Vector::F32(a), Vector::F32(b)) => {
let dot = a.dot(b);
1. - dot as f64
}
(Vector::F64(a), Vector::F64(b)) => {
let dot = a.dot(b);
1. - dot as f64
}
_ => bail!(
"Cannot compute inner product between {:?} and {:?}",
q,
target
),
},
})
}
}
impl<'a> SessionTx<'a> {
fn hnsw_put_vector(
&mut self,
vec: &Vector,
tuple: &Tuple,
q: &Vector,
idx: usize,
subidx: i32,
manifest: &HnswIndexManifest,
orig_table: &RelationHandle,
idx_table: &RelationHandle,
tags: &[SmartString<LazyCompact>]
) -> Result<()> {
let start_tuple =
idx_table.encode_key_for_store(&vec![DataValue::from(i64::MIN)], Default::default())?;
let end_tuple =
idx_table.encode_key_for_store(&vec![DataValue::from(1)], Default::default())?;
let ep_res = self.store_tx.range_scan(&start_tuple, &end_tuple).next();
if let Some(ep) = ep_res {
let (ep_key_bytes, _) = ep?;
let ep_key_tuple = decode_tuple_from_key(&ep_key_bytes);
// bottom level since we are going up
let bottom_level = ep_key_tuple[0].get_int().unwrap();
let ep_key = ep_key_tuple[1..orig_table.metadata.keys.len() + 1].to_vec();
let ep_idx = ep_key_tuple[orig_table.metadata.keys.len() + 1]
.get_int()
.unwrap() as usize;
let ep_subidx = ep_key_tuple[orig_table.metadata.keys.len() + 2]
.get_int()
.unwrap() as i32;
let ep_distance =
self.hnsw_compare_vector(q, &ep_key, idx, subidx, manifest, orig_table)?;
let mut found_nn = PriorityQueue::new();
found_nn.push((ep_key, ep_idx, ep_subidx), OrderedFloat(ep_distance));
let target_level = manifest.get_random_level();
if target_level < bottom_level {
// this becomes the entry point
self.hnsw_put_fresh_at_levels(
tuple,
idx,
subidx,
orig_table,
idx_table,
target_level,
bottom_level - 1,
)?;
}
for current_level in bottom_level..target_level {
self.hnsw_search_level(
q,
1,
current_level,
manifest,
orig_table,
idx_table,
&mut found_nn,
)?;
}
for current_level in max(target_level, bottom_level)..=0 {
self.hnsw_search_level(
q,
manifest.ef_construction,
current_level,
manifest,
orig_table,
idx_table,
&mut found_nn,
)?;
// add bidirectional links to the nearest neighbors
todo!();
// shrink links if necessary
todo!();
}
} else {
// This is the first vector in the index.
let level = manifest.get_random_level();
self.hnsw_put_fresh_at_levels(tuple, idx, subidx, orig_table, idx_table, level, 0)?;
}
Ok(())
}
fn hnsw_compare_vector(
&self,
q: &Vector,
target_key: &[DataValue],
target_idx: usize,
target_subidx: i32,
manifest: &HnswIndexManifest,
orig_table: &RelationHandle,
) -> Result<f64> {
let target_key_bytes = orig_table.encode_key_for_store(target_key, Default::default())?;
let bytes = match self.store_tx.get(&target_key_bytes, false)? {
Some(bytes) => bytes,
None => bail!("Indexed data not found, this signifies a bug in the index."),
};
let target_tuple = decode_tuple_from_kv(&target_key_bytes, &bytes);
manifest.get_distance(q, &target_tuple, target_idx, target_subidx)
}
fn hnsw_select_neighbours_heuristic(&self) -> Result<()> {
todo!()
}
pub(crate) fn hnsw_put(
fn hnsw_search_level(
&self,
q: &Vector,
ef: usize,
cur_level: i64,
manifest: &HnswIndexManifest,
orig_table: &RelationHandle,
idx_table: &RelationHandle,
found_nn: &mut PriorityQueue<(Tuple, usize, i32), OrderedFloat<f64>>,
) -> Result<()> {
let mut visited: BTreeSet<(Tuple, usize, i32)> = BTreeSet::new();
let mut candidates: PriorityQueue<(Tuple, usize, i32), Reverse<OrderedFloat<f64>>> =
PriorityQueue::new();
for item in found_nn.iter() {
visited.insert(item.0.clone());
candidates.push(item.0.clone(), Reverse(*item.1));
}
while let Some((candidate, Reverse(OrderedFloat(candidate_dist)))) = candidates.pop() {
let (_, OrderedFloat(furtherest_dist)) = found_nn.peek().unwrap();
let furtherest_dist = *furtherest_dist;
if candidate_dist > furtherest_dist {
break;
}
// loop over each of the candidate's neighbors
for neighbour_triple in self.hnsw_get_neighbours(
candidate.0,
candidate.1,
candidate.2,
cur_level,
idx_table,
)? {
if visited.contains(&neighbour_triple) {
continue;
}
let neighbour_dist = self.hnsw_compare_vector(
q,
&neighbour_triple.0,
neighbour_triple.1,
neighbour_triple.2,
manifest,
orig_table,
)?;
let (_, OrderedFloat(cand_furtherest_dist)) = found_nn.peek().unwrap();
if found_nn.len() < ef || neighbour_dist < *cand_furtherest_dist {
candidates.push(
neighbour_triple.clone(),
Reverse(OrderedFloat(neighbour_dist)),
);
found_nn.push(neighbour_triple.clone(), OrderedFloat(neighbour_dist));
if found_nn.len() > ef {
found_nn.pop();
}
}
visited.insert(neighbour_triple);
}
}
Ok(())
}
fn hnsw_get_neighbours<'b>(
&'b self,
cand_key: Vec<DataValue>,
cand_idx: usize,
cand_sub_idx: i32,
level: i64,
idx_handle: &RelationHandle,
) -> Result<impl Iterator<Item = (Tuple, usize, i32)> + 'b> {
let mut start_tuple = Vec::with_capacity(cand_key.len() + 3);
start_tuple.push(DataValue::from(level));
start_tuple.extend_from_slice(&cand_key);
start_tuple.push(DataValue::from(cand_idx as i64));
start_tuple.push(DataValue::from(cand_sub_idx as i64));
let mut end_tuple = start_tuple.clone();
end_tuple.push(DataValue::Bot);
let start_bytes = idx_handle.encode_key_for_store(&start_tuple, Default::default())?;
let end_bytes = idx_handle.encode_key_for_store(&end_tuple, Default::default())?;
Ok(self
.store_tx
.range_scan(&start_bytes, &end_bytes)
.filter_map(move |res| {
let (key, _value) = res.unwrap();
let key_tuple = decode_tuple_from_key(&key);
let key_total_len = key_tuple.len();
let key_idx = key_tuple[key_total_len - 2].get_int().unwrap() as usize;
let key_subidx = key_tuple[key_total_len - 1].get_int().unwrap() as i32;
let key_slice = key_tuple[cand_key.len() + 3..key_total_len - 2].to_vec();
if key_slice == cand_key {
None
} else {
Some((key_slice, key_idx, key_subidx))
}
}))
}
fn hnsw_put_fresh_at_levels(
&mut self,
config: &HnswIndexManifest,
tuple: &Tuple,
idx: usize,
subidx: i32,
orig_table: &RelationHandle,
idx_table: &RelationHandle,
bottom_level: i64,
top_level: i64,
) -> Result<()> {
let mut target_key = vec![DataValue::Null];
let mut canary_key = vec![DataValue::from(1)];
for _ in 0..2 {
for i in 0..orig_table.metadata.keys.len() {
target_key.push(tuple.get(i).unwrap().clone());
canary_key.push(DataValue::Null);
}
target_key.push(DataValue::from(idx as i64));
target_key.push(DataValue::from(subidx as i64));
canary_key.push(DataValue::Null);
canary_key.push(DataValue::Null);
}
let target_value = [DataValue::from(0.0), DataValue::Null];
let target_key_bytes = idx_table.encode_key_for_store(&target_key, Default::default())?;
// canary value is for conflict detection: prevent the scenario of disconnected graphs at all levels
let canary_value = [
DataValue::from(bottom_level),
DataValue::Bytes(target_key_bytes),
];
let canary_key_bytes = idx_table.encode_key_for_store(&canary_key, Default::default())?;
let canary_value_bytes =
idx_table.encode_val_for_store(&canary_value, Default::default())?;
self.store_tx.put(&canary_key_bytes, &canary_value_bytes)?;
for cur_level in bottom_level..=top_level {
target_key[0] = DataValue::from(cur_level);
let key = idx_table.encode_key_for_store(&target_key, Default::default())?;
let val = idx_table.encode_val_for_store(&target_value, Default::default())?;
self.store_tx.put(&key, &val)?;
}
Ok(())
}
pub(crate) fn hnsw_put(
&'a mut self,
manifest: &HnswIndexManifest,
orig_table: &RelationHandle,
idx_table: &RelationHandle,
filter: Option<(&[Bytecode], &mut Vec<DataValue>)>,
@ -42,7 +404,7 @@ impl<'a> SessionTx<'a> {
}
}
let mut extracted_vectors = vec![];
for idx in &config.vec_fields {
for idx in &manifest.vec_fields {
let val = tuple.get(*idx).unwrap();
if let DataValue::Vec(v) = val {
extracted_vectors.push((v, *idx, -1 as i32));
@ -57,34 +419,37 @@ impl<'a> SessionTx<'a> {
if extracted_vectors.is_empty() {
return Ok(false);
}
let mut extracted_tags: Vec<SmartString<LazyCompact>> = vec![];
for tag_idx in &config.tag_fields {
let tag_field = tuple.get(*tag_idx).unwrap();
if let Some(s) = tag_field.get_str() {
extracted_tags.push(SmartString::from(s));
} else if let DataValue::List(l) = tag_field {
for tag in l {
if let Some(s) = tag.get_str() {
extracted_tags.push(SmartString::from(s));
}
}
}
}
for (vec, idx, sub) in extracted_vectors {
self.hnsw_put_vector(vec, idx, sub, orig_table, idx_table, &extracted_tags)?;
self.hnsw_put_vector(&tuple, vec, idx, sub, manifest, orig_table, idx_table)?;
}
Ok(true)
}
pub(crate) fn hnsw_remove(
&mut self,
config: &HnswIndexManifest,
orig_table: &RelationHandle,
idx_table: &RelationHandle,
tuple: &Tuple,
) -> Result<()> {
pub(crate) fn hnsw_remove(&mut self) -> Result<()> {
todo!()
}
pub(crate) fn hnsw_knn(&self, node: u64, k: usize) -> Vec<(u64, f32)> {
pub(crate) fn hnsw_knn(&self) -> Result<()> {
todo!()
}
}
#[cfg(test)]
mod tests {
use rand::Rng;
use std::collections::BTreeMap;
#[test]
fn test_random_level() {
let m = 20;
let mult = 1. / (m as f64).ln();
let mut rng = rand::thread_rng();
let mut collected = BTreeMap::new();
for _ in 0..10000 {
let uniform_num: f64 = rng.gen_range(0.0..1.0);
let r = -uniform_num.ln() * mult;
// the level is the largest integer smaller than r
let level = -(r.floor() as i64);
collected.entry(level).and_modify(|x| *x += 1).or_insert(1);
}
println!("{:?}", collected);
}
}

@ -23,11 +23,12 @@ use crate::data::relation::{ColType, ColumnDef, NullableColType, StoredRelationM
use crate::data::symb::Symbol;
use crate::data::tuple::{decode_tuple_from_key, Tuple, TupleT, ENCODED_KEY_MIN_LEN};
use crate::data::value::{DataValue, ValidityTs};
use crate::parse::sys::{HnswIndexConfig, HnswIndexManifest};
use crate::parse::sys::HnswIndexConfig;
use crate::parse::SourceSpan;
use crate::query::compile::IndexPositionUse;
use crate::runtime::transact::SessionTx;
use crate::{NamedRows, StoreTx};
use crate::runtime::hnsw::HnswIndexManifest;
#[derive(
Copy,
@ -65,7 +66,7 @@ impl RelationId {
}
}
#[derive(Clone, Eq, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)]
#[derive(Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)]
pub(crate) struct RelationHandle {
pub(crate) name: SmartString<LazyCompact>,
pub(crate) id: RelationId,
@ -204,7 +205,11 @@ impl RelationHandle {
}
chosen
}
pub(crate) fn encode_key_for_store(&self, tuple: &Tuple, span: SourceSpan) -> Result<Vec<u8>> {
pub(crate) fn encode_key_for_store(
&self,
tuple: &[DataValue],
span: SourceSpan,
) -> Result<Vec<u8>> {
let len = self.metadata.keys.len();
ensure!(
tuple.len() >= len,
@ -221,7 +226,11 @@ impl RelationHandle {
}
Ok(ret)
}
pub(crate) fn encode_val_for_store(&self, tuple: &Tuple, _span: SourceSpan) -> Result<Vec<u8>> {
pub(crate) fn encode_val_for_store(
&self,
tuple: &[DataValue],
_span: SourceSpan,
) -> Result<Vec<u8>> {
let start = self.metadata.keys.len();
let len = self.metadata.non_keys.len();
let mut ret = self.encode_key_prefix(len);
@ -699,34 +708,6 @@ impl<'a> SessionTx<'a> {
}
}
// We only allow string tags
let mut tag_field_indices = vec![];
for field in config.tag_fields.iter() {
for (i, col) in rel_handle
.metadata
.keys
.iter()
.chain(rel_handle.metadata.non_keys.iter()).enumerate()
{
if col.name == *field {
let mut col_type = col.typing.coltype.clone();
if let ColType::List { eltype, .. } = &col_type {
col_type = eltype.coltype.clone();
}
if col_type != ColType::String {
bail!(
"Cannot create HNSW index with field {} of type {:?} (expected Str)",
field,
col_type
);
}
tag_field_indices.push(i);
break;
}
}
}
// Build key columns definitions
let mut idx_keys: Vec<ColumnDef> = vec![ColumnDef {
// layer -1 stores the self-loops
@ -775,16 +756,10 @@ impl<'a> SessionTx<'a> {
},
// For self-loops, stores a hash of the neighbours, for conflict detection
ColumnDef {
name: SmartString::from("tags"),
name: SmartString::from("hash"),
typing: NullableColType {
coltype: ColType::List {
eltype: Box::new(NullableColType {
coltype: ColType::String,
nullable: false,
}),
len: None,
},
nullable: false,
coltype: ColType::Bytes,
nullable: true,
},
default_gen: None,
},
@ -823,11 +798,15 @@ impl<'a> SessionTx<'a> {
vec_dim: config.vec_dim,
dtype: config.dtype,
vec_fields: vec_field_indices,
tag_fields: tag_field_indices,
distance: config.distance,
ef_construction: config.ef_construction,
max_elements: config.max_elements,
m_neighbours: config.m_neighbours,
m_max: config.m_neighbours,
m_max0: config.m_neighbours * 2,
level_multiplier: 1. / (config.m_neighbours as f64).ln(),
index_filter: config.index_filter,
extend_candidates: config.extend_candidates,
keep_pruned_connections: config.keep_pruned_connections,
};
rel_handle
.hnsw_indices

@ -787,12 +787,11 @@ fn test_vec_index() {
db.run_script(r"
::hnsw create a:vec {
dim: 8,
m: 50,
dtype: F32,
fields: [v],
tags: tags,
distance: Cosine,
ef_construction: 20,
max_elements: 50,
filter: k != 'k1'
}", Default::default())
.unwrap();

Loading…
Cancel
Save