diff --git a/cozo-core/src/data/expr.rs b/cozo-core/src/data/expr.rs index 9e7fac71..0a3da6f4 100644 --- a/cozo-core/src/data/expr.rs +++ b/cozo-core/src/data/expr.rs @@ -789,6 +789,7 @@ pub(crate) fn get_op(name: &str) -> Option<&'static Op> { "is_infinite" => &OP_IS_INFINITE, "is_nan" => &OP_IS_NAN, "is_uuid" => &OP_IS_UUID, + "is_vec" => &OP_IS_VEC, "length" => &OP_LENGTH, "sorted" => &OP_SORTED, "reverse" => &OP_REVERSE, diff --git a/cozo-core/src/lib.rs b/cozo-core/src/lib.rs index 72371ad6..1a9c56a8 100644 --- a/cozo-core/src/lib.rs +++ b/cozo-core/src/lib.rs @@ -140,7 +140,6 @@ impl DbInstance { #[derive(serde_derive::Deserialize)] struct TiKvOpts { end_points: Vec, - #[serde(default = "Default::default")] optimistic: bool, } let opts: TiKvOpts = serde_json::from_str(options).into_diagnostic()?; diff --git a/cozo-core/src/parse/sys.rs b/cozo-core/src/parse/sys.rs index dcda0cb2..83f3c219 100644 --- a/cozo-core/src/parse/sys.rs +++ b/cozo-core/src/parse/sys.rs @@ -14,7 +14,6 @@ use miette::{ensure, miette, Diagnostic, Result}; use smartstring::{LazyCompact, SmartString}; use thiserror::Error; -use crate::data::functions::OP_LIST; use crate::data::program::InputProgram; use crate::data::relation::VecElementType; use crate::data::symb::Symbol; @@ -23,7 +22,7 @@ use crate::parse::expr::build_expr; use crate::parse::query::parse_query; use crate::parse::{ExtractSpan, Pairs, Rule, SourceSpan}; use crate::runtime::relation::AccessLevel; -use crate::{Expr, FixedRule}; +use crate::FixedRule; pub(crate) enum SysOp { Compact, @@ -41,10 +40,9 @@ pub(crate) enum SysOp { CreateIndex(Symbol, Symbol, Vec), CreateVectorIndex(HnswIndexConfig), RemoveIndex(Symbol, Symbol), - RemoveVectorIndex(Symbol, Symbol), } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize)] pub(crate) struct HnswIndexConfig { pub(crate) base_relation: SmartString, pub(crate) index_name: SmartString, @@ -58,7 +56,9 @@ pub(crate) struct HnswIndexConfig { pub(crate) index_filter: Option, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive( + Debug, Clone, Copy, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize, +)] pub(crate) enum HnswDistance { L2, InnerProduct, @@ -189,7 +189,6 @@ pub(crate) fn parse_sys( let inner = inner.into_inner().next().unwrap(); match inner.as_rule() { Rule::index_create_hnsw => { - let span = inner.extract_span(); let mut inner = inner.into_inner(); let rel = inner.next().unwrap(); let name = inner.next().unwrap(); @@ -280,7 +279,7 @@ pub(crate) fn parse_sys( let mut inner = inner.into_inner(); let rel = inner.next().unwrap(); let name = inner.next().unwrap(); - SysOp::RemoveVectorIndex( + SysOp::RemoveIndex( Symbol::new(rel.as_str(), rel.extract_span()), Symbol::new(name.as_str(), name.extract_span()), ) diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index 27751be2..9c0c0c3c 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -6,42 +6,41 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ -use std::collections::{BTreeMap, BTreeSet}; use std::collections::btree_map::Entry; +use std::collections::{BTreeMap, BTreeSet}; use std::default::Default; use std::fmt::{Debug, Formatter}; use std::iter; use std::path::Path; -use std::sync::{Arc, Mutex}; #[allow(unused_imports)] use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; #[allow(unused_imports)] use std::thread; #[allow(unused_imports)] use std::time::{Duration, SystemTime, UNIX_EPOCH}; #[allow(unused_imports)] -use crossbeam::channel::{bounded, Receiver, Sender, unbounded}; +use crossbeam::channel::{bounded, unbounded, Receiver, Sender}; use crossbeam::sync::ShardedLock; use either::{Left, Right}; use itertools::Itertools; -#[allow(unused_imports)] -use miette::{bail, Diagnostic, ensure, IntoDiagnostic, miette, Result, WrapErr}; use miette::Report; +#[allow(unused_imports)] +use miette::{bail, ensure, miette, Diagnostic, IntoDiagnostic, Result, WrapErr}; use serde_json::json; use smartstring::{LazyCompact, SmartString}; use thiserror::Error; -use crate::{decode_tuple_from_kv, FixedRule}; use crate::data::functions::current_validity; use crate::data::json::JsonValue; use crate::data::program::{InputProgram, QueryAssertion, RelationOp}; use crate::data::relation::ColumnDef; use crate::data::tuple::{Tuple, TupleT}; -use crate::data::value::{DataValue, LARGEST_UTF_CHAR, ValidityTs}; +use crate::data::value::{DataValue, ValidityTs, LARGEST_UTF_CHAR}; use crate::fixed_rule::DEFAULT_FIXED_RULES; -use crate::parse::{CozoScript, parse_script, SourceSpan}; use crate::parse::sys::SysOp; +use crate::parse::{parse_script, CozoScript, SourceSpan}; use crate::query::compile::{CompiledProgram, CompiledRule, CompiledRuleSet}; use crate::query::ra::{ FilteredRA, InnerJoin, NegJoin, RelAlgebra, ReorderRA, StoredRA, StoredWithValidityRA, @@ -52,11 +51,12 @@ use crate::runtime::callback::{ CallbackCollector, CallbackDeclaration, CallbackOp, EventCallbackRegistry, }; use crate::runtime::relation::{ - AccessLevel, extend_tuple_from_v, InsufficientAccessLevel, RelationHandle, RelationId, + extend_tuple_from_v, AccessLevel, InsufficientAccessLevel, RelationHandle, RelationId, }; use crate::runtime::transact::SessionTx; -use crate::storage::{Storage, StoreTx}; use crate::storage::temp::TempStorage; +use crate::storage::{Storage, StoreTx}; +use crate::{decode_tuple_from_kv, FixedRule}; pub(crate) struct RunningQueryHandle { pub(crate) started_at: f64, @@ -181,10 +181,15 @@ impl NamedRows { let headers = headers .as_array() .ok_or_else(|| miette!("'headers' field must be an array"))?; - let headers = headers.iter().map(|h| -> Result { - let h = h.as_str().ok_or_else(|| miette!("'headers' field must be an array of strings"))?; - Ok(h.to_string()) - }).try_collect()?; + let headers = headers + .iter() + .map(|h| -> Result { + let h = h + .as_str() + .ok_or_else(|| miette!("'headers' field must be an array of strings"))?; + Ok(h.to_string()) + }) + .try_collect()?; let rows = value .get("rows") .ok_or_else(|| miette!("NamedRows requires 'rows' field"))?; @@ -1158,11 +1163,18 @@ impl<'s, S: Storage<'s>> Db { )) } SysOp::CreateVectorIndex(config) => { - dbg!(&config); - todo!() - } - SysOp::RemoveVectorIndex(rel_name, idx_name) => { - todo!("remove vector index") + let lock = self + .obtain_relation_locks(iter::once(&config.base_relation)) + .pop() + .unwrap(); + let _guard = lock.write().unwrap(); + let mut tx = self.transact_write()?; + tx.create_hnsw_index(config)?; + tx.commit_tx()?; + Ok(NamedRows::new( + vec![STATUS_STR.to_string()], + vec![vec![DataValue::from(OK_STR)]], + )) } SysOp::RemoveIndex(rel_name, idx_name) => { let lock = self @@ -1631,13 +1643,13 @@ impl Poison { pub(crate) fn seconds_since_the_epoch() -> Result { #[cfg(not(target_arch = "wasm32"))] - let now = SystemTime::now(); + let now = SystemTime::now(); #[cfg(not(target_arch = "wasm32"))] - return Ok(now + return Ok(now .duration_since(UNIX_EPOCH) .into_diagnostic()? .as_secs_f64()); #[cfg(target_arch = "wasm32")] - Ok(js_sys::Date::now()) -} \ No newline at end of file + Ok(js_sys::Date::now()) +} diff --git a/cozo-core/src/runtime/relation.rs b/cozo-core/src/runtime/relation.rs index b550da5d..02954279 100644 --- a/cozo-core/src/runtime/relation.rs +++ b/cozo-core/src/runtime/relation.rs @@ -19,10 +19,11 @@ use smartstring::{LazyCompact, SmartString}; use thiserror::Error; use crate::data::memcmp::MemCmpEncoder; -use crate::data::relation::StoredRelationMetadata; +use crate::data::relation::{ColType, ColumnDef, NullableColType, StoredRelationMetadata}; use crate::data::symb::Symbol; use crate::data::tuple::{decode_tuple_from_key, Tuple, TupleT, ENCODED_KEY_MIN_LEN}; use crate::data::value::{DataValue, ValidityTs}; +use crate::parse::sys::HnswIndexConfig; use crate::parse::SourceSpan; use crate::query::compile::IndexPositionUse; use crate::runtime::transact::SessionTx; @@ -73,10 +74,9 @@ pub(crate) struct RelationHandle { pub(crate) rm_triggers: Vec, pub(crate) replace_triggers: Vec, pub(crate) access_level: AccessLevel, - #[serde(default)] pub(crate) is_temp: bool, - #[serde(default)] pub(crate) indices: BTreeMap, (RelationHandle, Vec)>, + pub(crate) hnsw_indices: BTreeMap, (RelationHandle, HnswIndexConfig)>, } #[derive( @@ -537,6 +537,7 @@ impl<'a> SessionTx<'a> { access_level: AccessLevel::Normal, is_temp, indices: Default::default(), + hnsw_indices: Default::default(), }; let name_key = vec![DataValue::Str(meta.name.clone())].encode_as_key(RelationId::SYSTEM); @@ -587,8 +588,11 @@ impl<'a> SessionTx<'a> { // bail!("Cannot destroy temp relation"); // } let store = self.get_relation(name, true)?; - if !store.indices.is_empty() { - bail!("Cannot remove stored relation `{}` with indices attached.", name); + if !store.indices.is_empty() || !store.hnsw_indices.is_empty() { + bail!( + "Cannot remove stored relation `{}` with indices attached.", + name + ); } if store.access_level < AccessLevel::Normal { bail!(InsufficientAccessLevel( @@ -603,6 +607,11 @@ impl<'a> SessionTx<'a> { to_clean.extend(more_to_clean); } + for k in store.hnsw_indices.keys() { + let more_to_clean = self.destroy_relation(&format!("{name}:{k}"))?; + to_clean.extend(more_to_clean); + } + let key = DataValue::from(name); let encoded = vec![key].encode_as_key(RelationId::SYSTEM); if is_temp { @@ -629,14 +638,206 @@ impl<'a> SessionTx<'a> { Ok(()) } + pub(crate) fn create_hnsw_index(&mut self, config: HnswIndexConfig) -> Result<()> { + // Get relation handle + let mut rel_handle = self.get_relation(&config.base_relation, true)?; + + // Check if index already exists + if rel_handle.indices.contains_key(&config.index_name) + || rel_handle.hnsw_indices.contains_key(&config.index_name) + { + #[derive(Debug, Error, Diagnostic)] + #[error("index {0} for relation {1} already exists")] + #[diagnostic(code(tx::index_already_exists))] + pub(crate) struct IndexAlreadyExists(String, String); + + bail!(IndexAlreadyExists( + config.index_name.to_string(), + config.index_name.to_string() + )); + } + + // Check that what we are indexing are really vectors + if config.vec_fields.is_empty() { + bail!("Cannot create HNSW index without vector fields"); + } + for field in config.vec_fields.iter() { + let mut found = false; + for col in rel_handle + .metadata + .non_keys + .iter() + .chain(rel_handle.metadata.keys.iter()) + { + if col.name == *field { + let mut col_type = col.typing.coltype.clone(); + if let ColType::List { eltype, .. } = &col_type { + col_type = eltype.coltype.clone(); + } + + if let ColType::Vec { eltype, len } = col_type { + if eltype != config.dtype { + bail!("Cannot create HNSW index with field {} of type {:?} (expected {:?})", field, eltype, config.dtype); + } + if len != config.vec_dim { + bail!("Cannot create HNSW index with field {} of dimension {} (expected {})", field, len, config.vec_dim); + } + } else { + bail!("Cannot create HNSW index with non-vector field {}", field) + } + + found = true; + break; + } + } + if !found { + bail!("Cannot create HNSW index with non-existent field {}", field); + } + } + + // We only allow string tags + for field in config.tag_fields.iter() { + for col in rel_handle + .metadata + .non_keys + .iter() + .chain(rel_handle.metadata.keys.iter()) + { + if col.name == *field { + let mut col_type = col.typing.coltype.clone(); + if let ColType::List { eltype, .. } = &col_type { + col_type = eltype.coltype.clone(); + } + + if col_type != ColType::String { + bail!( + "Cannot create HNSW index with field {} of type {:?} (expected Str)", + field, + col_type + ); + } + break; + } + } + } + + // Build key columns definitions + let mut idx_keys: Vec = vec![ColumnDef { + name: SmartString::from("layer"), + typing: NullableColType { + coltype: ColType::Int, + nullable: false, + }, + default_gen: None, + }]; + for prefix in ["fr", "to"] { + for col in rel_handle.metadata.keys.iter() { + let mut col = col.clone(); + col.name = SmartString::from(format!("{}_{}", prefix, config.index_name)); + idx_keys.push(col); + } + idx_keys.push(ColumnDef { + name: SmartString::from(format!("{}__field", prefix)), + typing: NullableColType { + coltype: ColType::Int, + nullable: false, + }, + default_gen: None, + }); + idx_keys.push(ColumnDef { + name: SmartString::from(format!("{}__sub_idx", prefix)), + typing: NullableColType { + coltype: ColType::Int, + nullable: false, + }, + default_gen: None, + }); + } + + // Build non-key columns definitions + let non_idx_keys = vec![ + ColumnDef { + name: SmartString::from("dist"), + typing: NullableColType { + coltype: ColType::Float, + nullable: false, + }, + default_gen: None, + }, + ColumnDef { + name: SmartString::from("tags"), + typing: NullableColType { + coltype: ColType::List { + eltype: Box::new(NullableColType { + coltype: ColType::String, + nullable: false, + }), + len: None, + }, + nullable: false, + }, + default_gen: None, + }, + ]; + // create index relation + let key_bindings = idx_keys + .iter() + .map(|col| Symbol::new(col.name.clone(), Default::default())) + .collect(); + let dep_bindings = non_idx_keys + .iter() + .map(|col| Symbol::new(col.name.clone(), Default::default())) + .collect(); + let idx_handle = InputRelationHandle { + name: Symbol::new( + format!("{}:{}", config.base_relation, config.index_name), + Default::default(), + ), + metadata: StoredRelationMetadata { + keys: idx_keys, + non_keys: non_idx_keys, + }, + key_bindings, + dep_bindings, + span: Default::default(), + }; + let idx_handle = self.create_relation(idx_handle)?; + + // populate index + // TODO + + // add index to relation + let base_name = DataValue::from(&config.base_relation as &str); + let idx_name = config.index_name.clone(); + rel_handle + .hnsw_indices + .insert(idx_name, (idx_handle, config)); + + // update relation metadata + let new_encoded = + vec![base_name].encode_as_key(RelationId::SYSTEM); + let mut meta_val = vec![]; + rel_handle + .serialize(&mut Serializer::new(&mut meta_val)) + .unwrap(); + self.store_tx.put(&new_encoded, &meta_val)?; + + Ok(()) + } + pub(crate) fn create_index( &mut self, rel_name: &Symbol, idx_name: &Symbol, cols: Vec, ) -> Result<()> { + // Get relation handle let mut rel_handle = self.get_relation(rel_name, true)?; - if rel_handle.indices.contains_key(&idx_name.name) { + + // Check if index already exists + if rel_handle.indices.contains_key(&idx_name.name) + || rel_handle.hnsw_indices.contains_key(&idx_name.name) + { #[derive(Debug, Error, Diagnostic)] #[error("index {0} for relation {1} already exists")] #[diagnostic(code(tx::index_already_exists))] @@ -648,6 +849,7 @@ impl<'a> SessionTx<'a> { )); } + // Build column definitions let mut col_defs = vec![]; 'outer: for col in cols.iter() { for orig_col in rel_handle @@ -692,6 +894,7 @@ impl<'a> SessionTx<'a> { non_keys: vec![], }; + // create index relation let idx_handle = InputRelationHandle { name: Symbol::new( format!("{}:{}", rel_name.name, idx_name.name), @@ -747,10 +950,12 @@ impl<'a> SessionTx<'a> { } } + // add index to relation rel_handle .indices .insert(idx_name.name.clone(), (idx_handle, extraction_indices)); + // update relation metadata let new_encoded = vec![DataValue::from(&rel_name.name as &str)].encode_as_key(RelationId::SYSTEM); let mut meta_val = vec![]; @@ -762,9 +967,13 @@ impl<'a> SessionTx<'a> { Ok(()) } - pub(crate) fn remove_index(&mut self, rel_name: &Symbol, idx_name: &Symbol) -> Result, Vec)>> { + pub(crate) fn remove_index( + &mut self, + rel_name: &Symbol, + idx_name: &Symbol, + ) -> Result, Vec)>> { let mut rel = self.get_relation(rel_name, true)?; - if rel.indices.remove(&idx_name.name).is_none() { + if rel.indices.remove(&idx_name.name).is_none() && rel.hnsw_indices.remove(&idx_name.name).is_none() { #[derive(Debug, Error, Diagnostic)] #[error("index {0} for relation {1} not found")] #[diagnostic(code(tx::idx_not_found))] diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index 7770830b..6948b543 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -786,7 +786,7 @@ fn test_vec_index() { .unwrap(); db.run_script(r" ::hnsw create a:vec { - dim: 128, + dim: 8, dtype: F32, fields: [v], tags: tags, diff --git a/cozo-core/src/storage/rocks.rs b/cozo-core/src/storage/rocks.rs index 198669f2..2eb7c0b8 100644 --- a/cozo-core/src/storage/rocks.rs +++ b/cozo-core/src/storage/rocks.rs @@ -23,7 +23,7 @@ use crate::utils::swap_option_result; use crate::Db; const KEY_PREFIX_LEN: usize = 9; -const CURRENT_STORAGE_VERSION: u64 = 1; +const CURRENT_STORAGE_VERSION: u64 = 2; /// Creates a RocksDB database object. /// This is currently the fastest persistent storage and it can