diff --git a/cozo-core/src/parse/sys.rs b/cozo-core/src/parse/sys.rs index 83f3c219..0fcf5946 100644 --- a/cozo-core/src/parse/sys.rs +++ b/cozo-core/src/parse/sys.rs @@ -42,7 +42,7 @@ pub(crate) enum SysOp { RemoveIndex(Symbol, Symbol), } -#[derive(Debug, Clone, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub(crate) struct HnswIndexConfig { pub(crate) base_relation: SmartString, pub(crate) index_name: SmartString, @@ -56,6 +56,21 @@ pub(crate) struct HnswIndexConfig { pub(crate) index_filter: Option, } +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize)] +pub(crate) struct HnswIndexManifest { + pub(crate) base_relation: SmartString, + pub(crate) index_name: SmartString, + pub(crate) vec_dim: usize, + pub(crate) dtype: VecElementType, + pub(crate) vec_fields: Vec, + pub(crate) tag_fields: Vec, + pub(crate) distance: HnswDistance, + pub(crate) ef_construction: usize, + pub(crate) max_elements: usize, + pub(crate) index_filter: Option, +} + + #[derive( Debug, Clone, Copy, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize, )] diff --git a/cozo-core/src/runtime/hnsw.rs b/cozo-core/src/runtime/hnsw.rs new file mode 100644 index 00000000..cbda0519 --- /dev/null +++ b/cozo-core/src/runtime/hnsw.rs @@ -0,0 +1,90 @@ +/* + * Copyright 2023, The Cozo Project Authors. + * + * This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. + * If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +use crate::data::expr::{eval_bytecode_pred, Bytecode}; +use crate::data::tuple::Tuple; +use crate::data::value::Vector; +use crate::parse::sys::HnswIndexManifest; +use crate::runtime::relation::RelationHandle; +use crate::runtime::transact::SessionTx; +use crate::DataValue; +use miette::Result; +use smartstring::{LazyCompact, SmartString}; + +impl<'a> SessionTx<'a> { + fn hnsw_put_vector( + &mut self, + vec: &Vector, + idx: usize, + subidx: i32, + orig_table: &RelationHandle, + idx_table: &RelationHandle, + tags: &[SmartString] + ) -> Result<()> { + todo!() + } + pub(crate) fn hnsw_put( + &mut self, + config: &HnswIndexManifest, + orig_table: &RelationHandle, + idx_table: &RelationHandle, + filter: Option<(&[Bytecode], &mut Vec)>, + tuple: &Tuple, + ) -> Result { + if let Some((code, stack)) = filter { + if !eval_bytecode_pred(code, tuple, stack, Default::default())? { + return Ok(false); + } + } + let mut extracted_vectors = vec![]; + for idx in &config.vec_fields { + let val = tuple.get(*idx).unwrap(); + if let DataValue::Vec(v) = val { + extracted_vectors.push((v, *idx, -1 as i32)); + } else if let DataValue::List(l) = val { + for (sidx, v) in l.iter().enumerate() { + if let DataValue::Vec(v) = v { + extracted_vectors.push((v, *idx, sidx as i32)); + } + } + } + } + if extracted_vectors.is_empty() { + return Ok(false); + } + let mut extracted_tags: Vec> = vec![]; + for tag_idx in &config.tag_fields { + let tag_field = tuple.get(*tag_idx).unwrap(); + if let Some(s) = tag_field.get_str() { + extracted_tags.push(SmartString::from(s)); + } else if let DataValue::List(l) = tag_field { + for tag in l { + if let Some(s) = tag.get_str() { + extracted_tags.push(SmartString::from(s)); + } + } + } + } + for (vec, idx, sub) in extracted_vectors { + self.hnsw_put_vector(vec, idx, sub, orig_table, idx_table, &extracted_tags)?; + } + Ok(true) + } + pub(crate) fn hnsw_remove( + &mut self, + config: &HnswIndexManifest, + orig_table: &RelationHandle, + idx_table: &RelationHandle, + tuple: &Tuple, + ) -> Result<()> { + todo!() + } + pub(crate) fn hnsw_knn(&self, node: u64, k: usize) -> Vec<(u64, f32)> { + todo!() + } +} diff --git a/cozo-core/src/runtime/mod.rs b/cozo-core/src/runtime/mod.rs index c110ece3..91dde669 100644 --- a/cozo-core/src/runtime/mod.rs +++ b/cozo-core/src/runtime/mod.rs @@ -11,6 +11,7 @@ pub(crate) mod db; pub(crate) mod imperative; pub(crate) mod relation; pub(crate) mod temp_store; +pub(crate) mod transact; +pub(crate) mod hnsw; #[cfg(test)] mod tests; -pub(crate) mod transact; diff --git a/cozo-core/src/runtime/relation.rs b/cozo-core/src/runtime/relation.rs index 2a67bbd0..ace0d0e8 100644 --- a/cozo-core/src/runtime/relation.rs +++ b/cozo-core/src/runtime/relation.rs @@ -23,7 +23,7 @@ use crate::data::relation::{ColType, ColumnDef, NullableColType, StoredRelationM use crate::data::symb::Symbol; use crate::data::tuple::{decode_tuple_from_key, Tuple, TupleT, ENCODED_KEY_MIN_LEN}; use crate::data::value::{DataValue, ValidityTs}; -use crate::parse::sys::HnswIndexConfig; +use crate::parse::sys::{HnswIndexConfig, HnswIndexManifest}; use crate::parse::SourceSpan; use crate::query::compile::IndexPositionUse; use crate::runtime::transact::SessionTx; @@ -76,7 +76,8 @@ pub(crate) struct RelationHandle { pub(crate) access_level: AccessLevel, pub(crate) is_temp: bool, pub(crate) indices: BTreeMap, (RelationHandle, Vec)>, - pub(crate) hnsw_indices: BTreeMap, (RelationHandle, HnswIndexConfig)>, + pub(crate) hnsw_indices: + BTreeMap, (RelationHandle, HnswIndexManifest)>, } #[derive( @@ -661,13 +662,15 @@ impl<'a> SessionTx<'a> { if config.vec_fields.is_empty() { bail!("Cannot create HNSW index without vector fields"); } + let mut vec_field_indices = vec![]; for field in config.vec_fields.iter() { let mut found = false; - for col in rel_handle + for (i, col) in rel_handle .metadata - .non_keys + .keys .iter() - .chain(rel_handle.metadata.keys.iter()) + .chain(rel_handle.metadata.non_keys.iter()) + .enumerate() { if col.name == *field { let mut col_type = col.typing.coltype.clone(); @@ -687,6 +690,7 @@ impl<'a> SessionTx<'a> { } found = true; + vec_field_indices.push(i); break; } } @@ -696,12 +700,13 @@ impl<'a> SessionTx<'a> { } // We only allow string tags + let mut tag_field_indices = vec![]; for field in config.tag_fields.iter() { - for col in rel_handle + for (i, col) in rel_handle .metadata - .non_keys + .keys .iter() - .chain(rel_handle.metadata.keys.iter()) + .chain(rel_handle.metadata.non_keys.iter()).enumerate() { if col.name == *field { let mut col_type = col.typing.coltype.clone(); @@ -716,6 +721,7 @@ impl<'a> SessionTx<'a> { col_type ); } + tag_field_indices.push(i); break; } } @@ -811,15 +817,25 @@ impl<'a> SessionTx<'a> { // TODO // add index to relation - let base_name = DataValue::from(&config.base_relation as &str); - let idx_name = config.index_name.clone(); + let manifest = HnswIndexManifest { + base_relation: config.base_relation.clone(), + index_name: config.index_name.clone(), + vec_dim: config.vec_dim, + dtype: config.dtype, + vec_fields: vec_field_indices, + tag_fields: tag_field_indices, + distance: config.distance, + ef_construction: config.ef_construction, + max_elements: config.max_elements, + index_filter: config.index_filter, + }; rel_handle .hnsw_indices - .insert(idx_name, (idx_handle, config)); + .insert(config.index_name.clone(), (idx_handle, manifest)); // update relation metadata let new_encoded = - vec![base_name].encode_as_key(RelationId::SYSTEM); + vec![DataValue::from(&config.base_relation as &str)].encode_as_key(RelationId::SYSTEM); let mut meta_val = vec![]; rel_handle .serialize(&mut Serializer::new(&mut meta_val)) @@ -977,7 +993,9 @@ impl<'a> SessionTx<'a> { idx_name: &Symbol, ) -> Result, Vec)>> { let mut rel = self.get_relation(rel_name, true)?; - if rel.indices.remove(&idx_name.name).is_none() && rel.hnsw_indices.remove(&idx_name.name).is_none() { + if rel.indices.remove(&idx_name.name).is_none() + && rel.hnsw_indices.remove(&idx_name.name).is_none() + { #[derive(Debug, Error, Diagnostic)] #[error("index {0} for relation {1} not found")] #[diagnostic(code(tx::idx_not_found))]