hnsw preliminary

main
Ziyang Hu 1 year ago
parent a8f946a719
commit 6ccdf71892

@ -42,7 +42,7 @@ pub(crate) enum SysOp {
RemoveIndex(Symbol, Symbol), RemoveIndex(Symbol, Symbol),
} }
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct HnswIndexConfig { pub(crate) struct HnswIndexConfig {
pub(crate) base_relation: SmartString<LazyCompact>, pub(crate) base_relation: SmartString<LazyCompact>,
pub(crate) index_name: SmartString<LazyCompact>, pub(crate) index_name: SmartString<LazyCompact>,
@ -56,6 +56,21 @@ pub(crate) struct HnswIndexConfig {
pub(crate) index_filter: Option<String>, pub(crate) index_filter: Option<String>,
} }
/// Resolved description of an HNSW index, as stored alongside the relation it indexes.
///
/// It is built from an `HnswIndexConfig` when the index is created: most fields are
/// copied over unchanged, while the vector and tag columns are stored as column
/// positions (`usize`) instead of the values found in the config.
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`; the field order may be
/// part of the persisted format depending on the serializer in use, so confirm
/// before reordering or inserting fields.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize)]
pub(crate) struct HnswIndexManifest {
/// Name of the relation the index is built on (copied from the config).
pub(crate) base_relation: SmartString<LazyCompact>,
/// Name of the index itself (copied from the config).
pub(crate) index_name: SmartString<LazyCompact>,
/// Vector dimension (copied from the config) -- presumably the length every
/// indexed vector must have; TODO confirm.
pub(crate) vec_dim: usize,
/// Element type of the indexed vectors (copied from the config).
pub(crate) dtype: VecElementType,
/// Positions of the columns holding vectors, counted over the relation's key
/// columns followed by its non-key columns. These are used to index directly
/// into a stored tuple.
pub(crate) vec_fields: Vec<usize>,
/// Positions of the columns holding tags, using the same numbering as
/// `vec_fields`. Tag values are read as strings or lists of strings.
pub(crate) tag_fields: Vec<usize>,
/// Distance function of the index (copied from the config).
pub(crate) distance: HnswDistance,
/// Copied from the config -- presumably the HNSW `ef_construction` build
/// parameter; TODO confirm.
pub(crate) ef_construction: usize,
/// Copied from the config -- presumably an upper bound on the number of
/// indexed elements; TODO confirm.
pub(crate) max_elements: usize,
/// Optional filter (copied from the config) -- presumably the source text of a
/// predicate deciding which tuples get indexed; TODO confirm.
pub(crate) index_filter: Option<String>,
}
#[derive( #[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize,
)] )]

@ -0,0 +1,90 @@
/*
* Copyright 2023, The Cozo Project Authors.
*
* This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
* If a copy of the MPL was not distributed with this file,
* You can obtain one at https://mozilla.org/MPL/2.0/.
*/
use crate::data::expr::{eval_bytecode_pred, Bytecode};
use crate::data::tuple::Tuple;
use crate::data::value::Vector;
use crate::parse::sys::HnswIndexManifest;
use crate::runtime::relation::RelationHandle;
use crate::runtime::transact::SessionTx;
use crate::DataValue;
use miette::Result;
use smartstring::{LazyCompact, SmartString};
impl<'a> SessionTx<'a> {
fn hnsw_put_vector(
&mut self,
vec: &Vector,
idx: usize,
subidx: i32,
orig_table: &RelationHandle,
idx_table: &RelationHandle,
tags: &[SmartString<LazyCompact>]
) -> Result<()> {
todo!()
}
pub(crate) fn hnsw_put(
&mut self,
config: &HnswIndexManifest,
orig_table: &RelationHandle,
idx_table: &RelationHandle,
filter: Option<(&[Bytecode], &mut Vec<DataValue>)>,
tuple: &Tuple,
) -> Result<bool> {
if let Some((code, stack)) = filter {
if !eval_bytecode_pred(code, tuple, stack, Default::default())? {
return Ok(false);
}
}
let mut extracted_vectors = vec![];
for idx in &config.vec_fields {
let val = tuple.get(*idx).unwrap();
if let DataValue::Vec(v) = val {
extracted_vectors.push((v, *idx, -1 as i32));
} else if let DataValue::List(l) = val {
for (sidx, v) in l.iter().enumerate() {
if let DataValue::Vec(v) = v {
extracted_vectors.push((v, *idx, sidx as i32));
}
}
}
}
if extracted_vectors.is_empty() {
return Ok(false);
}
let mut extracted_tags: Vec<SmartString<LazyCompact>> = vec![];
for tag_idx in &config.tag_fields {
let tag_field = tuple.get(*tag_idx).unwrap();
if let Some(s) = tag_field.get_str() {
extracted_tags.push(SmartString::from(s));
} else if let DataValue::List(l) = tag_field {
for tag in l {
if let Some(s) = tag.get_str() {
extracted_tags.push(SmartString::from(s));
}
}
}
}
for (vec, idx, sub) in extracted_vectors {
self.hnsw_put_vector(vec, idx, sub, orig_table, idx_table, &extracted_tags)?;
}
Ok(true)
}
pub(crate) fn hnsw_remove(
&mut self,
config: &HnswIndexManifest,
orig_table: &RelationHandle,
idx_table: &RelationHandle,
tuple: &Tuple,
) -> Result<()> {
todo!()
}
pub(crate) fn hnsw_knn(&self, node: u64, k: usize) -> Vec<(u64, f32)> {
todo!()
}
}

@ -11,6 +11,7 @@ pub(crate) mod db;
pub(crate) mod imperative; pub(crate) mod imperative;
pub(crate) mod relation; pub(crate) mod relation;
pub(crate) mod temp_store; pub(crate) mod temp_store;
pub(crate) mod transact;
pub(crate) mod hnsw;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
pub(crate) mod transact;

@ -23,7 +23,7 @@ use crate::data::relation::{ColType, ColumnDef, NullableColType, StoredRelationM
use crate::data::symb::Symbol; use crate::data::symb::Symbol;
use crate::data::tuple::{decode_tuple_from_key, Tuple, TupleT, ENCODED_KEY_MIN_LEN}; use crate::data::tuple::{decode_tuple_from_key, Tuple, TupleT, ENCODED_KEY_MIN_LEN};
use crate::data::value::{DataValue, ValidityTs}; use crate::data::value::{DataValue, ValidityTs};
use crate::parse::sys::HnswIndexConfig; use crate::parse::sys::{HnswIndexConfig, HnswIndexManifest};
use crate::parse::SourceSpan; use crate::parse::SourceSpan;
use crate::query::compile::IndexPositionUse; use crate::query::compile::IndexPositionUse;
use crate::runtime::transact::SessionTx; use crate::runtime::transact::SessionTx;
@ -76,7 +76,8 @@ pub(crate) struct RelationHandle {
pub(crate) access_level: AccessLevel, pub(crate) access_level: AccessLevel,
pub(crate) is_temp: bool, pub(crate) is_temp: bool,
pub(crate) indices: BTreeMap<SmartString<LazyCompact>, (RelationHandle, Vec<usize>)>, pub(crate) indices: BTreeMap<SmartString<LazyCompact>, (RelationHandle, Vec<usize>)>,
pub(crate) hnsw_indices: BTreeMap<SmartString<LazyCompact>, (RelationHandle, HnswIndexConfig)>, pub(crate) hnsw_indices:
BTreeMap<SmartString<LazyCompact>, (RelationHandle, HnswIndexManifest)>,
} }
#[derive( #[derive(
@ -661,13 +662,15 @@ impl<'a> SessionTx<'a> {
if config.vec_fields.is_empty() { if config.vec_fields.is_empty() {
bail!("Cannot create HNSW index without vector fields"); bail!("Cannot create HNSW index without vector fields");
} }
let mut vec_field_indices = vec![];
for field in config.vec_fields.iter() { for field in config.vec_fields.iter() {
let mut found = false; let mut found = false;
for col in rel_handle for (i, col) in rel_handle
.metadata .metadata
.non_keys .keys
.iter() .iter()
.chain(rel_handle.metadata.keys.iter()) .chain(rel_handle.metadata.non_keys.iter())
.enumerate()
{ {
if col.name == *field { if col.name == *field {
let mut col_type = col.typing.coltype.clone(); let mut col_type = col.typing.coltype.clone();
@ -687,6 +690,7 @@ impl<'a> SessionTx<'a> {
} }
found = true; found = true;
vec_field_indices.push(i);
break; break;
} }
} }
@ -696,12 +700,13 @@ impl<'a> SessionTx<'a> {
} }
// We only allow string tags // We only allow string tags
let mut tag_field_indices = vec![];
for field in config.tag_fields.iter() { for field in config.tag_fields.iter() {
for col in rel_handle for (i, col) in rel_handle
.metadata .metadata
.non_keys .keys
.iter() .iter()
.chain(rel_handle.metadata.keys.iter()) .chain(rel_handle.metadata.non_keys.iter()).enumerate()
{ {
if col.name == *field { if col.name == *field {
let mut col_type = col.typing.coltype.clone(); let mut col_type = col.typing.coltype.clone();
@ -716,6 +721,7 @@ impl<'a> SessionTx<'a> {
col_type col_type
); );
} }
tag_field_indices.push(i);
break; break;
} }
} }
@ -811,15 +817,25 @@ impl<'a> SessionTx<'a> {
// TODO // TODO
// add index to relation // add index to relation
let base_name = DataValue::from(&config.base_relation as &str); let manifest = HnswIndexManifest {
let idx_name = config.index_name.clone(); base_relation: config.base_relation.clone(),
index_name: config.index_name.clone(),
vec_dim: config.vec_dim,
dtype: config.dtype,
vec_fields: vec_field_indices,
tag_fields: tag_field_indices,
distance: config.distance,
ef_construction: config.ef_construction,
max_elements: config.max_elements,
index_filter: config.index_filter,
};
rel_handle rel_handle
.hnsw_indices .hnsw_indices
.insert(idx_name, (idx_handle, config)); .insert(config.index_name.clone(), (idx_handle, manifest));
// update relation metadata // update relation metadata
let new_encoded = let new_encoded =
vec![base_name].encode_as_key(RelationId::SYSTEM); vec![DataValue::from(&config.base_relation as &str)].encode_as_key(RelationId::SYSTEM);
let mut meta_val = vec![]; let mut meta_val = vec![];
rel_handle rel_handle
.serialize(&mut Serializer::new(&mut meta_val)) .serialize(&mut Serializer::new(&mut meta_val))
@ -977,7 +993,9 @@ impl<'a> SessionTx<'a> {
idx_name: &Symbol, idx_name: &Symbol,
) -> Result<Vec<(Vec<u8>, Vec<u8>)>> { ) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
let mut rel = self.get_relation(rel_name, true)?; let mut rel = self.get_relation(rel_name, true)?;
if rel.indices.remove(&idx_name.name).is_none() && rel.hnsw_indices.remove(&idx_name.name).is_none() { if rel.indices.remove(&idx_name.name).is_none()
&& rel.hnsw_indices.remove(&idx_name.name).is_none()
{
#[derive(Debug, Error, Diagnostic)] #[derive(Debug, Error, Diagnostic)]
#[error("index {0} for relation {1} not found")] #[error("index {0} for relation {1} not found")]
#[diagnostic(code(tx::idx_not_found))] #[diagnostic(code(tx::idx_not_found))]

Loading…
Cancel
Save