hnsw preliminary

main
Ziyang Hu 1 year ago
parent a8f946a719
commit 6ccdf71892

@ -42,7 +42,7 @@ pub(crate) enum SysOp {
RemoveIndex(Symbol, Symbol),
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct HnswIndexConfig {
pub(crate) base_relation: SmartString<LazyCompact>,
pub(crate) index_name: SmartString<LazyCompact>,
@ -56,6 +56,21 @@ pub(crate) struct HnswIndexConfig {
pub(crate) index_filter: Option<String>,
}
/// Resolved description of an HNSW index that is stored alongside the index
/// relation handle in a relation's `hnsw_indices` map.
///
/// It carries the same settings as `HnswIndexConfig`, except that the vector
/// and tag fields are recorded as column positions rather than as column names.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize)]
pub(crate) struct HnswIndexManifest {
/// Name of the relation the index is built on.
pub(crate) base_relation: SmartString<LazyCompact>,
/// Name of the index; also the key under which the manifest is registered.
pub(crate) index_name: SmartString<LazyCompact>,
/// Copied from the index config; presumably the dimension of the indexed vectors.
pub(crate) vec_dim: usize,
/// Element type of the indexed vectors, copied from the index config.
pub(crate) dtype: VecElementType,
/// Positions of the columns holding vectors, counted over the base
/// relation's key columns followed by its non-key columns.
pub(crate) vec_fields: Vec<usize>,
/// Positions of the columns holding tags, counted the same way as `vec_fields`.
pub(crate) tag_fields: Vec<usize>,
/// Distance function, copied from the index config.
pub(crate) distance: HnswDistance,
/// Copied from the index config; presumably the HNSW construction-time
/// candidate list size (TODO confirm).
pub(crate) ef_construction: usize,
/// Copied from the index config; presumably a capacity bound for the index
/// (TODO confirm).
pub(crate) max_elements: usize,
/// Optional filter, copied from the index config as a string; presumably the
/// source of the predicate deciding which tuples get indexed (TODO confirm).
pub(crate) index_filter: Option<String>,
}
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize,
)]

@ -0,0 +1,90 @@
/*
* Copyright 2023, The Cozo Project Authors.
*
* This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
* If a copy of the MPL was not distributed with this file,
* You can obtain one at https://mozilla.org/MPL/2.0/.
*/
use crate::data::expr::{eval_bytecode_pred, Bytecode};
use crate::data::tuple::Tuple;
use crate::data::value::Vector;
use crate::parse::sys::HnswIndexManifest;
use crate::runtime::relation::RelationHandle;
use crate::runtime::transact::SessionTx;
use crate::DataValue;
use miette::Result;
use smartstring::{LazyCompact, SmartString};
impl<'a> SessionTx<'a> {
    /// Inserts one vector into an HNSW index.
    ///
    /// `idx` is the position of the column the vector came from and `subidx`
    /// is its position inside a list-valued column, or `-1` when the column
    /// itself is the vector (see [`Self::hnsw_put`], the only caller here).
    ///
    /// Not implemented yet: calling this currently panics via `todo!()`.
    fn hnsw_put_vector(
        &mut self,
        vec: &Vector,
        idx: usize,
        subidx: i32,
        orig_table: &RelationHandle,
        idx_table: &RelationHandle,
        tags: &[SmartString<LazyCompact>],
    ) -> Result<()> {
        todo!()
    }

    /// Indexes every vector found in `tuple` according to `config`.
    ///
    /// When `filter` is given, its bytecode predicate is evaluated against
    /// `tuple` first (using the supplied stack), and the tuple is skipped if
    /// the predicate is false.
    ///
    /// Vectors are collected from each column listed in `config.vec_fields`:
    /// a column that is a vector contributes that vector, a column that is a
    /// list contributes each list element that is a vector, and any other
    /// value is ignored. Tags are collected the same way from
    /// `config.tag_fields`, keeping only string values.
    ///
    /// # Returns
    ///
    /// `Ok(false)` if the filter rejected the tuple or the tuple contained no
    /// vectors; `Ok(true)` once every collected vector has been handed to
    /// `hnsw_put_vector`.
    ///
    /// # Errors
    ///
    /// Propagates errors from evaluating the filter and from
    /// `hnsw_put_vector`.
    ///
    /// # Panics
    ///
    /// Panics if an index in `config.vec_fields` or `config.tag_fields` is
    /// out of bounds for `tuple`.
    pub(crate) fn hnsw_put(
        &mut self,
        config: &HnswIndexManifest,
        orig_table: &RelationHandle,
        idx_table: &RelationHandle,
        filter: Option<(&[Bytecode], &mut Vec<DataValue>)>,
        tuple: &Tuple,
    ) -> Result<bool> {
        if let Some((code, stack)) = filter {
            if !eval_bytecode_pred(code, tuple, stack, Default::default())? {
                return Ok(false);
            }
        }

        let mut extracted_vectors = vec![];
        for &idx in &config.vec_fields {
            let val = tuple
                .get(idx)
                .expect("HNSW manifest vector field index must be within the tuple");
            match val {
                // -1 marks "the column itself is the vector", as opposed to
                // a position inside a list-valued column.
                DataValue::Vec(v) => extracted_vectors.push((v, idx, -1_i32)),
                DataValue::List(l) => {
                    for (sidx, item) in l.iter().enumerate() {
                        if let DataValue::Vec(v) = item {
                            // NOTE(review): truncating cast; assumes lists hold
                            // fewer than i32::MAX elements - confirm.
                            extracted_vectors.push((v, idx, sidx as i32));
                        }
                    }
                }
                _ => {}
            }
        }
        if extracted_vectors.is_empty() {
            return Ok(false);
        }

        let mut extracted_tags: Vec<SmartString<LazyCompact>> = vec![];
        for &tag_idx in &config.tag_fields {
            let tag_field = tuple
                .get(tag_idx)
                .expect("HNSW manifest tag field index must be within the tuple");
            if let Some(s) = tag_field.get_str() {
                extracted_tags.push(SmartString::from(s));
            } else if let DataValue::List(l) = tag_field {
                for tag in l {
                    if let Some(s) = tag.get_str() {
                        extracted_tags.push(SmartString::from(s));
                    }
                }
            }
        }

        // Every vector of the tuple shares the same set of tags.
        for (vec, idx, sub) in extracted_vectors {
            self.hnsw_put_vector(vec, idx, sub, orig_table, idx_table, &extracted_tags)?;
        }
        Ok(true)
    }

    /// Removes the index entries belonging to `tuple`.
    ///
    /// Not implemented yet: calling this currently panics via `todo!()`.
    pub(crate) fn hnsw_remove(
        &mut self,
        config: &HnswIndexManifest,
        orig_table: &RelationHandle,
        idx_table: &RelationHandle,
        tuple: &Tuple,
    ) -> Result<()> {
        todo!()
    }

    /// Nearest-neighbour lookup; presumably meant to return up to `k`
    /// `(node, distance)` pairs closest to `node` (TODO confirm once written).
    ///
    /// Not implemented yet: calling this currently panics via `todo!()`.
    pub(crate) fn hnsw_knn(&self, node: u64, k: usize) -> Vec<(u64, f32)> {
        todo!()
    }
}

@ -11,6 +11,7 @@ pub(crate) mod db;
pub(crate) mod imperative;
pub(crate) mod relation;
pub(crate) mod temp_store;
pub(crate) mod transact;
pub(crate) mod hnsw;
#[cfg(test)]
mod tests;
pub(crate) mod transact;

@ -23,7 +23,7 @@ use crate::data::relation::{ColType, ColumnDef, NullableColType, StoredRelationM
use crate::data::symb::Symbol;
use crate::data::tuple::{decode_tuple_from_key, Tuple, TupleT, ENCODED_KEY_MIN_LEN};
use crate::data::value::{DataValue, ValidityTs};
use crate::parse::sys::HnswIndexConfig;
use crate::parse::sys::{HnswIndexConfig, HnswIndexManifest};
use crate::parse::SourceSpan;
use crate::query::compile::IndexPositionUse;
use crate::runtime::transact::SessionTx;
@ -76,7 +76,8 @@ pub(crate) struct RelationHandle {
pub(crate) access_level: AccessLevel,
pub(crate) is_temp: bool,
pub(crate) indices: BTreeMap<SmartString<LazyCompact>, (RelationHandle, Vec<usize>)>,
pub(crate) hnsw_indices: BTreeMap<SmartString<LazyCompact>, (RelationHandle, HnswIndexConfig)>,
pub(crate) hnsw_indices:
BTreeMap<SmartString<LazyCompact>, (RelationHandle, HnswIndexManifest)>,
}
#[derive(
@ -661,13 +662,15 @@ impl<'a> SessionTx<'a> {
if config.vec_fields.is_empty() {
bail!("Cannot create HNSW index without vector fields");
}
let mut vec_field_indices = vec![];
for field in config.vec_fields.iter() {
let mut found = false;
for col in rel_handle
for (i, col) in rel_handle
.metadata
.non_keys
.keys
.iter()
.chain(rel_handle.metadata.keys.iter())
.chain(rel_handle.metadata.non_keys.iter())
.enumerate()
{
if col.name == *field {
let mut col_type = col.typing.coltype.clone();
@ -687,6 +690,7 @@ impl<'a> SessionTx<'a> {
}
found = true;
vec_field_indices.push(i);
break;
}
}
@ -696,12 +700,13 @@ impl<'a> SessionTx<'a> {
}
// We only allow string tags
let mut tag_field_indices = vec![];
for field in config.tag_fields.iter() {
for col in rel_handle
for (i, col) in rel_handle
.metadata
.non_keys
.keys
.iter()
.chain(rel_handle.metadata.keys.iter())
.chain(rel_handle.metadata.non_keys.iter()).enumerate()
{
if col.name == *field {
let mut col_type = col.typing.coltype.clone();
@ -716,6 +721,7 @@ impl<'a> SessionTx<'a> {
col_type
);
}
tag_field_indices.push(i);
break;
}
}
@ -811,15 +817,25 @@ impl<'a> SessionTx<'a> {
// TODO
// add index to relation
let base_name = DataValue::from(&config.base_relation as &str);
let idx_name = config.index_name.clone();
let manifest = HnswIndexManifest {
base_relation: config.base_relation.clone(),
index_name: config.index_name.clone(),
vec_dim: config.vec_dim,
dtype: config.dtype,
vec_fields: vec_field_indices,
tag_fields: tag_field_indices,
distance: config.distance,
ef_construction: config.ef_construction,
max_elements: config.max_elements,
index_filter: config.index_filter,
};
rel_handle
.hnsw_indices
.insert(idx_name, (idx_handle, config));
.insert(config.index_name.clone(), (idx_handle, manifest));
// update relation metadata
let new_encoded =
vec![base_name].encode_as_key(RelationId::SYSTEM);
vec![DataValue::from(&config.base_relation as &str)].encode_as_key(RelationId::SYSTEM);
let mut meta_val = vec![];
rel_handle
.serialize(&mut Serializer::new(&mut meta_val))
@ -977,7 +993,9 @@ impl<'a> SessionTx<'a> {
idx_name: &Symbol,
) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
let mut rel = self.get_relation(rel_name, true)?;
if rel.indices.remove(&idx_name.name).is_none() && rel.hnsw_indices.remove(&idx_name.name).is_none() {
if rel.indices.remove(&idx_name.name).is_none()
&& rel.hnsw_indices.remove(&idx_name.name).is_none()
{
#[derive(Debug, Error, Diagnostic)]
#[error("index {0} for relation {1} not found")]
#[diagnostic(code(tx::idx_not_found))]

Loading…
Cancel
Save