diff --git a/cozo-core/src/cozoscript.pest b/cozo-core/src/cozoscript.pest index 493d82b3..36df7164 100644 --- a/cozo-core/src/cozoscript.pest +++ b/cozo-core/src/cozoscript.pest @@ -54,7 +54,7 @@ param = @{"$" ~ (XID_CONTINUE | "_")*} ident = @{XID_START ~ ("_" | XID_CONTINUE)*} underscore_ident = @{("_" | XID_START) ~ ("_" | XID_CONTINUE)*} relation_ident = @{"*" ~ (compound_or_index_ident | underscore_ident)} -vector_index_ident = _{"~" ~ compound_or_index_ident} +search_index_ident = _{"~" ~ compound_or_index_ident} compound_ident = @{ident ~ ("." ~ ident)*} compound_or_index_ident = @{ident ~ ("." ~ ident)* ~ (":" ~ ident)?} @@ -80,10 +80,10 @@ rule_body = {(disjunction ~ ",")* ~ disjunction?} rule_apply = {underscore_ident ~ "[" ~ apply_args ~ "]"} relation_named_apply = {relation_ident ~ "{" ~ named_apply_args ~ validity_clause? ~ "}"} relation_apply = {relation_ident ~ "[" ~ apply_args ~ validity_clause? ~ "]"} -vector_search = {vector_index_ident ~ "{" ~ named_apply_args ~ "|" ~ (index_opt_field ~ ",")* ~ index_opt_field? ~ "}"} +search_apply = {search_index_ident ~ "{" ~ named_apply_args ~ "|" ~ (index_opt_field ~ ",")* ~ index_opt_field? ~ "}"} disjunction = {(atom ~ "or" )* ~ atom} -atom = _{ negation | relation_named_apply | relation_apply | vector_search | rule_apply | unify_multi | unify | expr | grouped} +atom = _{ negation | relation_named_apply | relation_apply | search_apply | rule_apply | unify_multi | unify | expr | grouped} unify = {var ~ "=" ~ expr} unify_multi = {var ~ "in" ~ expr} negation = {"not" ~ atom} diff --git a/cozo-core/src/data/program.rs b/cozo-core/src/data/program.rs index b79d33b5..ed5e80f0 100644 --- a/cozo-core/src/data/program.rs +++ b/cozo-core/src/data/program.rs @@ -11,7 +11,7 @@ use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Debug, Display, Formatter}; use std::sync::Arc; -use miette::{bail, ensure, Diagnostic, Result}; +use miette::{bail, ensure, miette, Diagnostic, Result}; use smallvec::SmallVec; use smartstring::{LazyCompact, SmartString}; use thiserror::Error; @@ -915,25 +915,17 @@ pub(crate) enum InputAtom { Unification { inner: Unification, }, - HnswSearch { - inner: HnswSearchInput, + Search { + inner: SearchInput, }, } #[derive(Clone)] -pub(crate) struct HnswSearchInput { +pub(crate) struct SearchInput { pub(crate) relation: Symbol, pub(crate) index: Symbol, pub(crate) bindings: BTreeMap, Expr>, - pub(crate) k: usize, - pub(crate) ef: usize, - pub(crate) query: Expr, - pub(crate) bind_field: Option, - pub(crate) bind_field_idx: Option, - pub(crate) bind_distance: Option, - pub(crate) bind_vector: Option, - pub(crate) radius: Option, - pub(crate) filter: Option, + pub(crate) parameters: BTreeMap, Expr>, pub(crate) span: SourceSpan, } @@ -966,41 +958,14 @@ impl HnswSearch { } } -impl HnswSearchInput { - pub(crate) fn normalize( +impl SearchInput { + fn normalize_hnsw( mut self, + base_handle: RelationHandle, + idx_handle: RelationHandle, + manifest: HnswIndexManifest, gen: &mut TempSymbGen, - tx: &SessionTx<'_>, ) -> Result { - let base_handle = tx.get_relation(&self.relation, false)?; - if base_handle.access_level < AccessLevel::ReadOnly { - bail!(InsufficientAccessLevel( - base_handle.name.to_string(), - "reading rows".to_string(), - base_handle.access_level - )); - } - let (idx_handle, manifest) = base_handle - .hnsw_indices - .get(&self.index.name) - .ok_or_else(|| { - #[derive(Debug, Error, Diagnostic)] - #[error("hnsw index {name} not found on relation {relation}")] - #[diagnostic(code(eval::hnsw_index_not_found))] - struct HnswIndexNotFound { - relation: String, - name: String, - #[label] - span: SourceSpan, - } - HnswIndexNotFound { - relation: self.relation.name.to_string(), - name: self.index.name.to_string(), - span: self.index.span, - } - })? - .clone(); - let mut conj = Vec::with_capacity(self.bindings.len() + 8); let mut bindings = Vec::with_capacity(self.bindings.len()); let mut seen_variables = BTreeSet::new(); @@ -1060,7 +1025,16 @@ impl HnswSearchInput { )); } - let query = match self.query { + #[derive(Debug, Error, Diagnostic)] + #[error("Field `{0}` is required for HNSW search")] + #[diagnostic(code(parser::hnsw_query_required))] + struct HnswRequiredMissing(String, #[label] SourceSpan); + + let query = match self + .parameters + .remove("query") + .ok_or_else(|| miette!(HnswRequiredMissing("query".to_string(), self.span)))? + { Expr::Binding { var, .. } => var, expr => { let span = expr.span(); @@ -1076,7 +1050,54 @@ impl HnswSearchInput { } }; - let bind_field = match self.bind_field { + let k_expr = self + .parameters + .remove("k") + .ok_or_else(|| miette!(HnswRequiredMissing("k".to_string(), self.span)))?; + let k = k_expr.eval_to_const()?; + let k = k.get_int().ok_or(ExpectedPosIntForHnswK(self.span))?; + + #[derive(Debug, Error, Diagnostic)] + #[error("Expected positive integer for `k`")] + #[diagnostic(code(parser::expected_int_for_hnsw_k))] + struct ExpectedPosIntForHnswK(#[label] SourceSpan); + + ensure!(k > 0, ExpectedPosIntForHnswK(self.span)); + + let ef_expr = self + .parameters + .remove("ef") + .ok_or_else(|| miette!(HnswRequiredMissing("ef".to_string(), self.span)))?; + let ef = ef_expr.eval_to_const()?; + let ef = ef.get_int().ok_or(ExpectedPosIntForHnswEf(self.span))?; + + #[derive(Debug, Error, Diagnostic)] + #[error("Expected positive integer for `ef`")] + #[diagnostic(code(parser::expected_int_for_hnsw_ef))] + struct ExpectedPosIntForHnswEf(#[label] SourceSpan); + + ensure!(ef > 0, ExpectedPosIntForHnswEf(self.span)); + + let radius_expr = self.parameters.remove("radius"); + let radius = match radius_expr { + Some(expr) => { + let r = expr.eval_to_const()?; + let r = r.get_float().ok_or(ExpectedFloatForHnswRadius(self.span))?; + + #[derive(Debug, Error, Diagnostic)] + #[error("Expected positive float for `radius`")] + #[diagnostic(code(parser::expected_float_for_hnsw_radius))] + struct ExpectedFloatForHnswRadius(#[label] SourceSpan); + + ensure!(r > 0.0, ExpectedFloatForHnswRadius(self.span)); + Some(r) + } + None => None, + }; + + let filter = self.parameters.remove("filter"); + + let bind_field = match self.parameters.remove("bind_field") { None => None, Some(Expr::Binding { var, .. }) => Some(var), Some(expr) => { @@ -1093,7 +1114,7 @@ impl HnswSearchInput { } }; - let bind_field_idx = match self.bind_field_idx { + let bind_field_idx = match self.parameters.remove("bind_field_idx") { None => None, Some(Expr::Binding { var, .. }) => Some(var), Some(expr) => { @@ -1110,7 +1131,7 @@ impl HnswSearchInput { } }; - let bind_distance = match self.bind_distance { + let bind_distance = match self.parameters.remove("bind_distance") { None => None, Some(Expr::Binding { var, .. }) => Some(var), Some(expr) => { @@ -1127,7 +1148,7 @@ impl HnswSearchInput { } }; - let bind_vector = match self.bind_vector { + let bind_vector = match self.parameters.remove("bind_vector") { None => None, Some(Expr::Binding { var, .. }) => Some(var), Some(expr) => { @@ -1149,37 +1170,53 @@ impl HnswSearchInput { idx_handle, manifest, bindings, - k: self.k, - ef: self.ef, + k: k as usize, + ef: ef as usize, query, bind_field, bind_field_idx, bind_distance, bind_vector, - radius: self.radius, - filter: self.filter, + radius, + filter, span: self.span, })); - // ret.push(if is_negated { - // NormalFormAtom::NegatedRelation(NormalFormRelationApplyAtom { - // name: self.name, - // args, - // valid_at: self.valid_at, - // span: self.span, - // }) - // } else { - // NormalFormAtom::Relation(NormalFormRelationApplyAtom { - // name: self.name, - // args, - // valid_at: self.valid_at, - // span: self.span, - // }) - // }); - // Disjunction::conj(ret) - Ok(Disjunction::conj(conj)) } + pub(crate) fn normalize( + self, + gen: &mut TempSymbGen, + tx: &SessionTx<'_>, + ) -> Result { + let base_handle = tx.get_relation(&self.relation, false)?; + if base_handle.access_level < AccessLevel::ReadOnly { + bail!(InsufficientAccessLevel( + base_handle.name.to_string(), + "reading rows".to_string(), + base_handle.access_level + )); + } + if let Some((idx_handle, manifest)) = + base_handle.hnsw_indices.get(&self.index.name).cloned() + { + return self.normalize_hnsw(base_handle, idx_handle, manifest, gen); + } + #[derive(Debug, Error, Diagnostic)] + #[error("Index {name} not found on relation {relation}")] + #[diagnostic(code(eval::hnsw_index_not_found))] + struct IndexNotFound { + relation: String, + name: String, + #[label] + span: SourceSpan, + } + bail!(IndexNotFound { + relation: self.relation.to_string(), + name: self.index.to_string(), + span: self.span, + }) + } } impl Debug for InputAtom { @@ -1213,32 +1250,14 @@ impl Display for InputAtom { write!(f, ":{name}")?; f.debug_list().entries(args).finish()?; } - InputAtom::HnswSearch { inner } => { + InputAtom::Search { inner } => { write!(f, "~{}:{}{{", inner.relation, inner.index)?; for (binding, expr) in &inner.bindings { write!(f, "{binding}: {expr}, ")?; } write!(f, "| ")?; - write!(f, " query: {}, ", inner.query)?; - write!(f, " k: {}, ", inner.k)?; - write!(f, " ef: {}, ", inner.ef)?; - if let Some(radius) = &inner.radius { - write!(f, " radius: {}, ", radius)?; - } - if let Some(filter) = &inner.filter { - write!(f, " filter: {}, ", filter)?; - } - if let Some(bind_distance) = &inner.bind_distance { - write!(f, " bind_distance: {}, ", bind_distance)?; - } - if let Some(bind_field) = &inner.bind_field { - write!(f, " bind_field: {}, ", bind_field)?; - } - if let Some(bind_field_idx) = &inner.bind_field_idx { - write!(f, " bind_field_idx: {}, ", bind_field_idx)?; - } - if let Some(bind_vector) = &inner.bind_vector { - write!(f, " bind_vector: {}, ", bind_vector)?; + for (k, v) in inner.parameters.iter() { + write!(f, "{k}: {v}, ")?; } write!(f, "}}")?; } @@ -1307,7 +1326,7 @@ impl InputAtom { InputAtom::Relation { inner, .. } => inner.span, InputAtom::Predicate { inner, .. } => inner.span(), InputAtom::Unification { inner, .. } => inner.span, - InputAtom::HnswSearch { inner, .. } => inner.span, + InputAtom::Search { inner, .. } => inner.span, } } } diff --git a/cozo-core/src/fts/mod.rs b/cozo-core/src/fts/mod.rs index ae99d51d..bc73ff0a 100644 --- a/cozo-core/src/fts/mod.rs +++ b/cozo-core/src/fts/mod.rs @@ -26,12 +26,12 @@ pub(crate) mod cangjie; pub(crate) mod tokenizer; #[derive(Debug, Clone, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize)] -pub(crate) struct TokenizerFilterConfig { +pub(crate) struct TokenizerConfig { pub(crate) name: SmartString, pub(crate) args: Vec, } -impl TokenizerFilterConfig { +impl TokenizerConfig { // use sha256::digest; pub(crate) fn config_hash(&self, filters: &[Self]) -> impl AsRef<[u8]> { let mut hasher = Sha256::new(); @@ -225,8 +225,8 @@ pub(crate) struct FtsIndexConfig { base_relation: SmartString, index_name: SmartString, fts_fields: Vec>, - tokenizer: TokenizerFilterConfig, - filters: Vec, + tokenizer: TokenizerConfig, + filters: Vec, } #[derive(Default)] @@ -239,8 +239,8 @@ impl TokenizerCache { pub(crate) fn get( &self, tokenizer_name: &str, - tokenizer: &TokenizerFilterConfig, - filters: &[TokenizerFilterConfig], + tokenizer: &TokenizerConfig, + filters: &[TokenizerConfig], ) -> Result> { { let idx_cache = self.named_cache.read().unwrap(); diff --git a/cozo-core/src/parse/query.rs b/cozo-core/src/parse/query.rs index 9e9368d0..80a322a3 100644 --- a/cozo-core/src/parse/query.rs +++ b/cozo-core/src/parse/query.rs @@ -22,12 +22,7 @@ use thiserror::Error; use crate::data::aggr::{parse_aggr, Aggregation}; use crate::data::expr::Expr; use crate::data::functions::{str2vld, MAX_VALIDITY_TS}; -use crate::data::program::{ - FixedRuleApply, FixedRuleArg, HnswSearchInput, InputAtom, InputInlineRule, - InputInlineRulesOrFixed, InputNamedFieldRelationApplyAtom, InputProgram, - InputRelationApplyAtom, InputRuleApplyAtom, QueryAssertion, QueryOutOptions, RelationOp, - SortDir, Unification, -}; +use crate::data::program::{FixedRuleApply, FixedRuleArg, InputAtom, InputInlineRule, InputInlineRulesOrFixed, InputNamedFieldRelationApplyAtom, InputProgram, InputRelationApplyAtom, InputRuleApplyAtom, QueryAssertion, QueryOutOptions, RelationOp, SearchInput, SortDir, Unification}; use crate::data::relation::{ColType, ColumnDef, NullableColType, StoredRelationMetadata}; use crate::data::symb::{Symbol, PROG_ENTRY}; use crate::data::value::{DataValue, ValidityTs}; @@ -626,7 +621,7 @@ fn parse_atom( }, } } - Rule::vector_search => { + Rule::search_apply => { let span = src.extract_span(); let mut src = src.into_inner(); let name_p = src.next().unwrap(); @@ -649,141 +644,19 @@ fn parse_atom( .into_inner() .map(|arg| extract_named_apply_arg(arg, param_pool)) .try_collect()?; + let parameters: BTreeMap, Expr> = src + .map(|arg| extract_named_apply_arg(arg, param_pool)) + .try_collect()?; - let mut opts = HnswSearchInput { + let opts = SearchInput { relation, index, bindings, - k: 0, - ef: 0, - query: Expr::Const { - val: DataValue::Bot, - span: SourceSpan::default(), - }, - bind_field: None, - bind_field_idx: None, - bind_distance: None, - bind_vector: None, - radius: None, - filter: None, span, + parameters, }; - for field in src { - let mut fs = field.into_inner(); - let fname = fs.next().unwrap(); - let fval = fs.next().unwrap(); - match fname.as_str() { - "query" => { - opts.query = build_expr(fval, param_pool)?; - } - "filter" => { - opts.filter = Some(build_expr(fval, param_pool)?); - } - "bind_field" => { - opts.bind_field = Some(build_expr(fval, param_pool)?); - } - "bind_field_idx" => { - opts.bind_field_idx = Some(build_expr(fval, param_pool)?); - } - "bind_distance" => { - opts.bind_distance = Some(build_expr(fval, param_pool)?); - } - "bind_vector" => { - opts.bind_vector = Some(build_expr(fval, param_pool)?); - } - "k" => { - #[derive(Debug, Error, Diagnostic)] - #[error("Expected positive integer for `k`")] - #[diagnostic(code(parser::expected_int_for_hnsw_k))] - struct ExpectedPosIntForHnswK(#[label] SourceSpan); - - let k = build_expr(fval, param_pool)?; - let k = k.eval_to_const()?; - let k = k - .get_int() - .ok_or_else(|| ExpectedPosIntForHnswK(fname.extract_span()))?; - ensure!(k > 0, ExpectedPosIntForHnswK(fname.extract_span())); - opts.k = k as usize; - } - "ef" => { - #[derive(Debug, Error, Diagnostic)] - #[error("Expected positive integer for `ef`")] - #[diagnostic(code(parser::expected_int_for_hnsw_ef))] - struct ExpectedPosIntForHnswEf(#[label] SourceSpan); - - let ef = build_expr(fval, param_pool)?; - let ef = ef.eval_to_const()?; - let ef = ef - .get_int() - .ok_or_else(|| ExpectedPosIntForHnswEf(fname.extract_span()))?; - ensure!(ef > 0, ExpectedPosIntForHnswEf(fname.extract_span())); - opts.ef = ef as usize; - } - "radius" => { - #[derive(Debug, Error, Diagnostic)] - #[error("Expected positive float for `radius`")] - #[diagnostic(code(parser::expected_float_for_hnsw_radius))] - struct ExpectedPosFloatForHnswRadius(#[label] SourceSpan); - - let radius = build_expr(fval, param_pool)?; - let radius = radius.eval_to_const()?; - let radius = radius - .get_float() - .ok_or_else(|| ExpectedPosFloatForHnswRadius(fname.extract_span()))?; - ensure!( - radius > 0.0, - ExpectedPosFloatForHnswRadius(fname.extract_span()) - ); - opts.radius = Some(radius); - } - _ => { - #[derive(Debug, Error, Diagnostic)] - #[error("Unexpected HNSW option `{0}`")] - #[diagnostic(code(parser::unexpected_hnsw_option))] - struct UnexpectedHnswOption(String, #[label] SourceSpan); - bail!(UnexpectedHnswOption( - fname.as_str().into(), - fname.extract_span() - )) - } - } - } - - if matches!( - opts.query, - Expr::Const { - val: DataValue::Bot, - .. - } - ) { - #[derive(Debug, Error, Diagnostic)] - #[error("Field `query` is required for HNSW search")] - #[diagnostic(code(parser::hnsw_query_required))] - struct HnswQueryRequired(#[label] SourceSpan); - - bail!(HnswQueryRequired(span)) - } - - if opts.k == 0 { - #[derive(Debug, Error, Diagnostic)] - #[error("Field `k > 0` is required for HNSW search")] - #[diagnostic(code(parser::hnsw_k_required))] - struct HnswKRequired(#[label] SourceSpan); - - bail!(HnswKRequired(span)) - } - - if opts.ef == 0 { - #[derive(Debug, Error, Diagnostic)] - #[error("Field `ef > 0` is required for HNSW search")] - #[diagnostic(code(parser::hnsw_ef_required))] - struct HnswEfRequired(#[label] SourceSpan); - - bail!(HnswEfRequired(span)) - } - - InputAtom::HnswSearch { inner: opts } + InputAtom::Search { inner: opts } } Rule::relation_named_apply => { let span = src.extract_span(); diff --git a/cozo-core/src/parse/sys.rs b/cozo-core/src/parse/sys.rs index 00723825..7b812b41 100644 --- a/cozo-core/src/parse/sys.rs +++ b/cozo-core/src/parse/sys.rs @@ -18,6 +18,7 @@ use crate::data::program::InputProgram; use crate::data::relation::VecElementType; use crate::data::symb::Symbol; use crate::data::value::{DataValue, ValidityTs}; +use crate::fts::TokenizerConfig; use crate::parse::expr::build_expr; use crate::parse::query::parse_query; use crate::parse::{ExtractSpan, Pairs, Rule, SourceSpan}; @@ -39,8 +40,17 @@ pub(crate) enum SysOp { SetAccessLevel(Vec, AccessLevel), CreateIndex(Symbol, Symbol, Vec), CreateVectorIndex(HnswIndexConfig), + CreateFtsIndex(FtsIndexConfig), RemoveIndex(Symbol, Symbol), } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct FtsIndexConfig { + pub(crate) base_relation: SmartString, + pub(crate) index_name: SmartString, + pub(crate) extractor: String, + pub(crate) tokenizer: TokenizerConfig, + pub(crate) filters: Vec, +} #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub(crate) struct HnswIndexConfig { @@ -227,7 +237,9 @@ pub(crate) fn parse_sys( let v = build_expr(opt_val, param_pool)? .eval_to_const()? .get_int() - .ok_or_else(|| miette!("Invalid m_neighbours: {}", opt_val_str))?; + .ok_or_else(|| { + miette!("Invalid m_neighbours: {}", opt_val_str) + })?; ensure!(v > 0, "Invalid m_neighbours: {}", v); m_neighbours = v as usize; } diff --git a/cozo-core/src/query/logical.rs b/cozo-core/src/query/logical.rs index 5bc6a2c5..6efb5856 100644 --- a/cozo-core/src/query/logical.rs +++ b/cozo-core/src/query/logical.rs @@ -121,11 +121,11 @@ impl InputAtom { InputAtom::Unification { inner } => { bail!(UnsafeNegation(inner.span)) } - InputAtom::HnswSearch { inner } => { + InputAtom::Search { inner } => { bail!(UnsafeNegation(inner.span)) } }, - InputAtom::HnswSearch { inner } => InputAtom::HnswSearch { inner }, + InputAtom::Search { inner } => InputAtom::Search { inner }, }) } @@ -229,7 +229,7 @@ impl InputAtom { InputAtom::Unification { inner: u } => { Disjunction::singlet(NormalFormAtom::Unification(u)) } - InputAtom::HnswSearch { inner } => inner.normalize(gen, tx)?, + InputAtom::Search { inner } => inner.normalize(gen, tx)?, }) } } diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index 2c614c45..c3402e76 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -1192,6 +1192,9 @@ impl<'s, S: Storage<'s>> Db { vec![vec![DataValue::from(OK_STR)]], )) } + SysOp::CreateFtsIndex(_) => { + todo!("FTS index creation is not yet implemented") + } SysOp::RemoveIndex(rel_name, idx_name) => { let lock = self .obtain_relation_locks(iter::once(&rel_name.name)) diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index 6544e0f3..172cda3f 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -19,7 +19,7 @@ use crate::data::expr::Expr; use crate::data::symb::Symbol; use crate::data::value::DataValue; use crate::fixed_rule::FixedRulePayload; -use crate::fts::{TokenizerCache, TokenizerFilterConfig}; +use crate::fts::{TokenizerCache, TokenizerConfig}; use crate::parse::SourceSpan; use crate::runtime::callback::CallbackOp; use crate::runtime::db::Poison; @@ -948,7 +948,7 @@ fn tentivy_tokenizers() { let tokenizer = tokenizers .get( "simple", - &TokenizerFilterConfig { + &TokenizerConfig { name: "Simple".into(), args: vec![], }, @@ -970,7 +970,7 @@ fn tentivy_tokenizers() { let tokenizer = tokenizers .get( "cangjie", - &TokenizerFilterConfig { + &TokenizerConfig { name: "Cangjie".into(), args: vec![], },