diff --git a/cozo-core/src/data/program.rs b/cozo-core/src/data/program.rs index 682bfac4..c3412101 100644 --- a/cozo-core/src/data/program.rs +++ b/cozo-core/src/data/program.rs @@ -22,6 +22,7 @@ use crate::data::relation::StoredRelationMetadata; use crate::data::symb::{Symbol, PROG_ENTRY}; use crate::data::value::{DataValue, ValidityTs}; use crate::fixed_rule::{FixedRule, FixedRuleHandle}; +use crate::fts::FtsIndexManifest; use crate::parse::SourceSpan; use crate::query::logical::{Disjunction, NamedFieldNotFound}; use crate::runtime::hnsw::HnswIndexManifest; @@ -947,6 +948,28 @@ pub(crate) struct HnswSearch { pub(crate) span: SourceSpan, } +#[derive(Copy, Clone, Debug)] +pub(crate) enum FtsScoreKind { + BM25, + TFIDF, +} + +#[derive(Clone, Debug)] +pub(crate) struct FtsSearch { + pub(crate) base_handle: RelationHandle, + pub(crate) idx_handle: RelationHandle, + pub(crate) manifest: FtsIndexManifest, + pub(crate) bindings: Vec, + pub(crate) k: usize, + pub(crate) query: Symbol, + pub(crate) score_kind: FtsScoreKind, + pub(crate) bind_score: Option, + pub(crate) lax_mode: bool, + pub(crate) global_idf: bool, + pub(crate) filter: Option, + pub(crate) span: SourceSpan, +} + impl HnswSearch { pub(crate) fn all_bindings(&self) -> impl Iterator { self.bindings @@ -958,7 +981,190 @@ impl HnswSearch { } } +impl FtsSearch { + pub(crate) fn all_bindings(&self) -> impl Iterator { + self.bindings.iter().chain(self.bind_score.iter()) + } +} + impl SearchInput { + fn normalize_fts( + mut self, + base_handle: RelationHandle, + idx_handle: RelationHandle, + manifest: FtsIndexManifest, + gen: &mut TempSymbGen, + ) -> Result { + let mut conj = Vec::with_capacity(self.bindings.len() + 8); + let mut bindings = Vec::with_capacity(self.bindings.len()); + let mut seen_variables = BTreeSet::new(); + + for col in base_handle + .metadata + .keys + .iter() + .chain(base_handle.metadata.non_keys.iter()) + { + if let Some(arg) = self.bindings.remove(&col.name) { + match arg { + Expr::Binding { var, .. } => { + if var.is_ignored_symbol() { + bindings.push(gen.next_ignored(var.span)); + } else if seen_variables.insert(var.clone()) { + bindings.push(var); + } else { + let span = var.span; + let dup = gen.next(span); + let unif = NormalFormAtom::Unification(Unification { + binding: dup.clone(), + expr: Expr::Binding { + var, + tuple_pos: None, + }, + one_many_unif: false, + span, + }); + conj.push(unif); + bindings.push(dup); + } + } + expr => { + let span = expr.span(); + let kw = gen.next(span); + bindings.push(kw.clone()); + let unif = NormalFormAtom::Unification(Unification { + binding: kw, + expr, + one_many_unif: false, + span, + }); + conj.push(unif) + } + } + } else { + bindings.push(gen.next_ignored(self.span)); + } + } + + if let Some((name, _)) = self.bindings.pop_first() { + bail!(NamedFieldNotFound( + self.relation.name.to_string(), + name.to_string(), + self.span + )); + } + + #[derive(Debug, Error, Diagnostic)] + #[error("Field `{0}` is required for HNSW search")] + #[diagnostic(code(parser::hnsw_query_required))] + struct HnswRequiredMissing(String, #[label] SourceSpan); + + let query = match self + .parameters + .remove("query") + .ok_or_else(|| miette!(HnswRequiredMissing("query".to_string(), self.span)))? + { + Expr::Binding { var, .. } => var, + expr => { + let span = expr.span(); + let kw = gen.next(span); + let unif = NormalFormAtom::Unification(Unification { + binding: kw.clone(), + expr, + one_many_unif: false, + span, + }); + conj.push(unif); + kw + } + }; + + let k_expr = self + .parameters + .remove("k") + .ok_or_else(|| miette!(HnswRequiredMissing("k".to_string(), self.span)))?; + let k = k_expr.eval_to_const()?; + let k = k.get_int().ok_or(ExpectedPosIntForFtsK(self.span))?; + + #[derive(Debug, Error, Diagnostic)] + #[error("Expected positive integer for `k`")] + #[diagnostic(code(parser::expected_int_for_hnsw_k))] + struct ExpectedPosIntForFtsK(#[label] SourceSpan); + + ensure!(k > 0, ExpectedPosIntForFtsK(self.span)); + + + let score_kind_expr = self.parameters.remove("score_kind"); + let score_kind = match score_kind_expr { + Some(expr) => { + let r = expr.eval_to_const()?; + let r = r.get_str().ok_or_else(|| miette!("Score kind for FTS must be a string"))?; + + match r { + "bm25" => FtsScoreKind::BM25, + "tf_idf" => FtsScoreKind::TFIDF, + s => bail!("Unknown score kind for FTS: {}", s) + } + } + None => FtsScoreKind::BM25, + }; + + let lax_mode_expr = self.parameters.remove("lax_mode"); + let lax_mode = match lax_mode_expr { + Some(expr) => { + let r = expr.eval_to_const()?; + let r = r.get_bool().ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?; + r + } + None => true, + }; + + let global_idf_expr = self.parameters.remove("global_idf"); + let global_idf = match global_idf_expr { + Some(expr) => { + let r = expr.eval_to_const()?; + let r = r.get_bool().ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?; + r + } + None => true, + }; + + let filter = self.parameters.remove("filter"); + + let bind_score = match self.parameters.remove("bind_score") { + None => None, + Some(Expr::Binding { var, .. }) => Some(var), + Some(expr) => { + let span = expr.span(); + let kw = gen.next(span); + let unif = NormalFormAtom::Unification(Unification { + binding: kw.clone(), + expr, + one_many_unif: false, + span, + }); + conj.push(unif); + Some(kw) + } + }; + + conj.push(NormalFormAtom::FtsSearch(FtsSearch { + base_handle, + idx_handle, + manifest, + bindings, + k: k as usize, + query, + score_kind, + bind_score, + lax_mode, + global_idf, + filter, + span: self.span, + })); + + Ok(Disjunction::conj(conj)) + } fn normalize_hnsw( mut self, base_handle: RelationHandle, @@ -1202,6 +1408,11 @@ impl SearchInput { { return self.normalize_hnsw(base_handle, idx_handle, manifest, gen); } + if let Some((idx_handle, manifest)) = + base_handle.fts_indices.get(&self.index.name).cloned() + { + return self.normalize_fts(base_handle, idx_handle, manifest, gen); + } #[derive(Debug, Error, Diagnostic)] #[error("Index {name} not found on relation {relation}")] #[diagnostic(code(eval::hnsw_index_not_found))] @@ -1340,6 +1551,7 @@ pub(crate) enum NormalFormAtom { Predicate(Expr), Unification(Unification), HnswSearch(HnswSearch), + FtsSearch(FtsSearch), } #[derive(Debug, Clone)] @@ -1351,6 +1563,7 @@ pub(crate) enum MagicAtom { NegatedRelation(MagicRelationApplyAtom), Unification(Unification), HnswSearch(HnswSearch), + FtsSearch(FtsSearch), } #[derive(Clone, Debug)] diff --git a/cozo-core/src/fts/ast.rs b/cozo-core/src/fts/ast.rs index c952d36b..cc5c1ef2 100644 --- a/cozo-core/src/fts/ast.rs +++ b/cozo-core/src/fts/ast.rs @@ -6,14 +6,15 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ +use ordered_float::OrderedFloat; use crate::fts::tokenizer::TextAnalyzer; use smartstring::{LazyCompact, SmartString}; -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub(crate) struct FtsLiteral { pub(crate) value: SmartString, pub(crate) is_prefix: bool, - pub(crate) booster: f64, + pub(crate) booster: OrderedFloat, } impl FtsLiteral { @@ -34,13 +35,13 @@ impl FtsLiteral { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub(crate) struct FtsNear { pub(crate) literals: Vec, pub(crate) distance: u32, } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub(crate) enum FtsExpr { Literal(FtsLiteral), Near(FtsNear), diff --git a/cozo-core/src/fts/indexing.rs b/cozo-core/src/fts/indexing.rs index d077809a..7a9e594f 100644 --- a/cozo-core/src/fts/indexing.rs +++ b/cozo-core/src/fts/indexing.rs @@ -7,17 +7,128 @@ */ use crate::data::expr::{eval_bytecode, Bytecode}; +use crate::data::program::FtsSearch; +use crate::data::tuple::{decode_tuple_from_key, Tuple, ENCODED_KEY_MIN_LEN}; +use crate::data::value::LARGEST_UTF_CHAR; +use crate::fts::ast::{FtsExpr, FtsLiteral}; use crate::fts::tokenizer::TextAnalyzer; +use crate::parse::fts::parse_fts_query; use crate::runtime::relation::RelationHandle; use crate::runtime::transact::SessionTx; -use crate::DataValue; +use crate::{decode_tuple_from_kv, DataValue, SourceSpan}; +use itertools::Itertools; use miette::{bail, Diagnostic, Result}; use rustc_hash::{FxHashMap, FxHashSet}; use smartstring::{LazyCompact, SmartString}; +use std::collections::hash_map::Entry; use std::collections::HashMap; use thiserror::Error; +#[derive(Default)] +pub(crate) struct FtsCache { + total_n_cache: FxHashMap, usize>, + results_cache: FxHashMap>, +} + +impl FtsCache { + fn get_n_for_relation(&mut self, rel: &RelationHandle, tx: &SessionTx<'_>) -> Result { + Ok(match self.total_n_cache.entry(rel.name.clone()) { + Entry::Vacant(v) => { + let start = rel.encode_key_for_store(&[], Default::default())?; + let end = rel.encode_key_for_store(&[DataValue::Bot], Default::default())?; + let val = tx.store_tx.range_count(&start, &end)?; + v.insert(val); + val + } + Entry::Occupied(o) => *o.get(), + }) + } +} + +struct PositionInfo { + from: u32, + to: u32, + position: u32, +} + +struct LiteralStats { + key: Tuple, + position_info: Vec, + doc_len: u32, +} + impl<'a> SessionTx<'a> { + fn search_literal( + &self, + literal: &FtsLiteral, + idx_handle: &RelationHandle, + ) -> Result> { + let start_key_str = &literal.value as &str; + let start_key = vec![DataValue::Str(SmartString::from(start_key_str))]; + let mut end_key_str = literal.value.clone(); + end_key_str.push(LARGEST_UTF_CHAR); + let end_key = vec![DataValue::Str(end_key_str)]; + let start_key_bytes = idx_handle.encode_key_for_store(&start_key, Default::default())?; + let end_key_bytes = idx_handle.encode_key_for_store(&end_key, Default::default())?; + let mut results = vec![]; + for item in self.store_tx.range_scan(&start_key_bytes, &end_key_bytes) { + let (kvec, vvec) = item?; + let key_tuple = decode_tuple_from_key(&kvec, idx_handle.metadata.keys.len()); + let found_str_key = key_tuple[0].get_str().unwrap(); + if literal.is_prefix { + if !found_str_key.starts_with(start_key_str) { + break; + } + } else { + if found_str_key != start_key_str { + break; + } + } + let vals: Vec = rmp_serde::from_slice(&vvec[ENCODED_KEY_MIN_LEN..]).unwrap(); + let froms = vals[0].get_slice().unwrap(); + let tos = vals[1].get_slice().unwrap(); + let positions = vals[2].get_slice().unwrap(); + let total_length = vals[3].get_int().unwrap(); + let position_info = froms + .iter() + .zip(tos.iter()) + .zip(positions.iter()) + .map(|((f, t), p)| PositionInfo { + from: f.get_int().unwrap() as u32, + to: t.get_int().unwrap() as u32, + position: p.get_int().unwrap() as u32, + }) + .collect_vec(); + results.push(LiteralStats { + key: key_tuple[1..].to_vec(), + position_info, + doc_len: total_length as u32, + }); + } + Ok(results) + } + pub(crate) fn fts_search( + &self, + q: &str, + config: &FtsSearch, + filter_code: &Option<(Vec, SourceSpan)>, + tokenizer: &TextAnalyzer, + stack: &mut Vec, + cache: &mut FtsCache, + ) -> Result> { + let ast = parse_fts_query(q)?.tokenize(tokenizer); + if ast.is_empty() { + return Ok(vec![]); + } + match cache.results_cache.entry(ast) { + Entry::Occupied(_) => { + todo!() + } + Entry::Vacant(_) => { + todo!() + } + } + } pub(crate) fn put_fts_index_item( &mut self, tuple: &[DataValue], @@ -55,7 +166,12 @@ impl<'a> SessionTx<'a> { for k in &tuple[..rel_handle.metadata.keys.len()] { key.push(k.clone()); } - let mut val = vec![DataValue::Bot, DataValue::Bot, DataValue::Bot, DataValue::from(count)]; + let mut val = vec![ + DataValue::Bot, + DataValue::Bot, + DataValue::Bot, + DataValue::from(count), + ]; for (text, (from, to, position)) in collector { key[0] = DataValue::Str(text); val[0] = DataValue::List(from); diff --git a/cozo-core/src/parse/fts.rs b/cozo-core/src/parse/fts.rs index 513ad9e8..58a65f47 100644 --- a/cozo-core/src/parse/fts.rs +++ b/cozo-core/src/parse/fts.rs @@ -16,7 +16,7 @@ use pest::pratt_parser::{Op, PrattParser}; use pest::Parser; use smartstring::SmartString; -fn parse_fts_query(q: &str) -> Result { +pub(crate) fn parse_fts_query(q: &str) -> Result { let mut pairs = CozoScriptParser::parse(Rule::fts_doc, q).into_diagnostic()?; let pairs = pairs.next().unwrap().into_inner(); let pairs: Vec<_> = pairs @@ -124,7 +124,7 @@ fn build_phrase(pair: Pair<'_>) -> Result { Ok(FtsLiteral { value: core_text, is_prefix: is_quoted, - booster, + booster: booster.into(), }) } diff --git a/cozo-core/src/query/compile.rs b/cozo-core/src/query/compile.rs index 408087f4..68a91e62 100644 --- a/cozo-core/src/query/compile.rs +++ b/cozo-core/src/query/compile.rs @@ -505,6 +505,40 @@ impl<'a> SessionTx<'a> { ret = ret.filter(Expr::build_and(post_filters, s.span))?; } } + MagicAtom::FtsSearch(s) => { + debug_assert!( + seen_variables.contains(&s.query), + "FTS search query must be bound" + ); + let mut own_bindings = vec![]; + let mut post_filters = vec![]; + for var in s.all_bindings() { + if seen_variables.contains(var) { + let rk = gen_symb(var.span); + post_filters.push(Expr::build_equate( + vec![ + Expr::Binding { + var: var.clone(), + tuple_pos: None, + }, + Expr::Binding { + var: rk.clone(), + tuple_pos: None, + }, + ], + var.span, + )); + own_bindings.push(rk); + } else { + seen_variables.insert(var.clone()); + own_bindings.push(var.clone()); + } + } + ret = ret.fts_search(s.clone(), own_bindings)?; + if !post_filters.is_empty() { + ret = ret.filter(Expr::build_and(post_filters, s.span))?; + } + } MagicAtom::Unification(u) => { if seen_variables.contains(&u.binding) { let expr = if u.one_many_unif { diff --git a/cozo-core/src/query/magic.rs b/cozo-core/src/query/magic.rs index d414d63b..b0996cec 100644 --- a/cozo-core/src/query/magic.rs +++ b/cozo-core/src/query/magic.rs @@ -176,6 +176,10 @@ fn magic_rewrite_ruleset( seen_bindings.extend(s.all_bindings().cloned()); collected_atoms.push(MagicAtom::HnswSearch(s)); } + MagicAtom::FtsSearch(s) => { + seen_bindings.extend(s.all_bindings().cloned()); + collected_atoms.push(MagicAtom::FtsSearch(s)); + } MagicAtom::Rule(r_app) => { if r_app.name.has_bound_adornment() { // we are guaranteed to have a magic rule application @@ -527,6 +531,15 @@ impl NormalFormAtom { } MagicAtom::HnswSearch(s.clone()) } + NormalFormAtom::FtsSearch(s) => { + for arg in s.all_bindings() { + if !seen_bindings.contains(arg) { + seen_bindings.insert(arg.clone()); + } + } + MagicAtom::FtsSearch(s.clone()) + } + NormalFormAtom::Predicate(p) => { // predicate cannot introduce new bindings MagicAtom::Predicate(p.clone()) diff --git a/cozo-core/src/query/ra.rs b/cozo-core/src/query/ra.rs index 997bf4a6..9f6fe90a 100644 --- a/cozo-core/src/query/ra.rs +++ b/cozo-core/src/query/ra.rs @@ -17,7 +17,7 @@ use miette::{bail, Diagnostic, Result}; use thiserror::Error; use crate::data::expr::{compute_bounds, eval_bytecode, eval_bytecode_pred, Bytecode, Expr}; -use crate::data::program::{HnswSearch, MagicSymbol}; +use crate::data::program::{FtsSearch, HnswSearch, MagicSymbol}; use crate::data::relation::{ColType, NullableColType}; use crate::data::symb::Symbol; use crate::data::tuple::{Tuple, TupleIter}; @@ -39,6 +39,7 @@ pub(crate) enum RelAlgebra { Filter(FilteredRA), Unification(UnificationRA), HnswSearch(HnswSearchRA), + FtsSearch(FtsSearchRA), } impl RelAlgebra { @@ -54,6 +55,7 @@ impl RelAlgebra { RelAlgebra::Unification(i) => i.span, RelAlgebra::StoredWithValidity(i) => i.span, RelAlgebra::HnswSearch(i) => i.hnsw_search.span, + RelAlgebra::FtsSearch(i) => i.fts_search.span, } } } @@ -283,6 +285,11 @@ impl Debug for RelAlgebra { .field(&bindings) .field(&s.hnsw_search.idx_handle.name) .finish(), + RelAlgebra::FtsSearch(s) => f + .debug_tuple("FtsSearch") + .field(&bindings) + .field(&s.fts_search.idx_handle.name) + .finish(), RelAlgebra::StoredWithValidity(r) => f .debug_tuple("StoredWithValidity") .field(&bindings) @@ -352,6 +359,9 @@ impl RelAlgebra { RelAlgebra::HnswSearch(s) => { s.fill_binding_indices_and_compile()?; } + RelAlgebra::FtsSearch(s) => { + s.fill_binding_indices_and_compile()?; + } RelAlgebra::StoredWithValidity(v) => { v.fill_binding_indices_and_compile()?; } @@ -448,7 +458,8 @@ impl RelAlgebra { | RelAlgebra::Reorder(_) | RelAlgebra::NegJoin(_) | RelAlgebra::Unification(_) - | RelAlgebra::HnswSearch(_)) => { + | RelAlgebra::HnswSearch(_) + | RelAlgebra::FtsSearch(_)) => { let span = filter.span(); RelAlgebra::Filter(FilteredRA { parent: Box::new(s), @@ -601,6 +612,18 @@ impl RelAlgebra { own_bindings, })) } + pub(crate) fn fts_search( + self, + fts_search: FtsSearch, + own_bindings: Vec, + ) -> Result { + Ok(Self::FtsSearch(FtsSearchRA { + parent: Box::new(self), + fts_search, + filter_bytecode: None, + own_bindings, + })) + } pub(crate) fn join( self, right: RelAlgebra, @@ -852,6 +875,83 @@ pub(crate) struct HnswSearchRA { pub(crate) own_bindings: Vec, } +#[derive(Debug)] +pub(crate) struct FtsSearchRA { + pub(crate) parent: Box, + pub(crate) fts_search: FtsSearch, + pub(crate) filter_bytecode: Option<(Vec, SourceSpan)>, + pub(crate) own_bindings: Vec, +} + +impl FtsSearchRA { + fn fill_binding_indices_and_compile(&mut self) -> Result<()> { + self.parent.fill_binding_indices_and_compile()?; + if self.fts_search.filter.is_some() { + let bindings: BTreeMap<_, _> = self + .own_bindings + .iter() + .cloned() + .enumerate() + .map(|(a, b)| (b, a)) + .collect(); + let filter = self.fts_search.filter.as_mut().unwrap(); + filter.fill_binding_indices(&bindings)?; + self.filter_bytecode = Some((filter.compile()?, filter.span())); + } + Ok(()) + } + fn iter<'a>( + &'a self, + tx: &'a SessionTx<'_>, + delta_rule: Option<&MagicSymbol>, + stores: &'a BTreeMap, + ) -> Result> { + let bindings = self.parent.bindings_after_eliminate(); + let mut bind_idx = usize::MAX; + for (i, b) in bindings.iter().enumerate() { + if *b == self.fts_search.query { + bind_idx = i; + break; + } + } + let config = self.fts_search.clone(); + let filter_code = self.filter_bytecode.clone(); + let mut stack = vec![]; + let mut idf_cache = Default::default(); + let tokenizer = tx.tokenizers.get( + &config.idx_handle.name, + &config.manifest.tokenizer, + &config.manifest.filters, + )?; + let it = self + .parent + .iter(tx, delta_rule, stores)? + .map_ok(move |tuple| -> Result<_> { + let q = match tuple[bind_idx].clone() { + DataValue::Str(s) => s, + d => bail!("Expected string for FTS search, got {:?}", d), + }; + + let res = tx.fts_search( + &q, + &config, + &filter_code, + &tokenizer, + &mut stack, + &mut idf_cache, + )?; + Ok(res.into_iter().map(move |t| { + let mut r = tuple.clone(); + r.extend(t); + r + })) + }) + .map(flatten_err) + .flatten_ok(); + Ok(Box::new(it)) + } +} + impl HnswSearchRA { fn fill_binding_indices_and_compile(&mut self) -> Result<()> { self.parent.fill_binding_indices_and_compile()?; @@ -1644,6 +1744,7 @@ impl RelAlgebra { RelAlgebra::NegJoin(r) => r.do_eliminate_temp_vars(used), RelAlgebra::Unification(r) => r.do_eliminate_temp_vars(used), RelAlgebra::HnswSearch(_) => Ok(()), + RelAlgebra::FtsSearch(_) => Ok(()), } } @@ -1659,6 +1760,7 @@ impl RelAlgebra { RelAlgebra::NegJoin(r) => Some(&r.to_eliminate), RelAlgebra::Unification(u) => Some(&u.to_eliminate), RelAlgebra::HnswSearch(_) => None, + RelAlgebra::FtsSearch(_) => None, } } @@ -1693,6 +1795,11 @@ impl RelAlgebra { bindings.extend_from_slice(&s.own_bindings); bindings } + RelAlgebra::FtsSearch(s) => { + let mut bindings = s.parent.bindings_after_eliminate(); + bindings.extend_from_slice(&s.own_bindings); + bindings + } } } pub(crate) fn iter<'a>( @@ -1712,6 +1819,7 @@ impl RelAlgebra { RelAlgebra::NegJoin(r) => r.iter(tx, delta_rule, stores), RelAlgebra::Unification(r) => r.iter(tx, delta_rule, stores), RelAlgebra::HnswSearch(r) => r.iter(tx, delta_rule, stores), + RelAlgebra::FtsSearch(r) => r.iter(tx, delta_rule, stores), } } } @@ -1892,6 +2000,7 @@ impl InnerJoin { } } RelAlgebra::HnswSearch(_) => "hnsw_search_join", + RelAlgebra::FtsSearch(_) => "fts_search_join", RelAlgebra::StoredWithValidity(_) => { let join_indices = self .joiner @@ -2003,7 +2112,8 @@ impl InnerJoin { RelAlgebra::Join(_) | RelAlgebra::Filter(_) | RelAlgebra::Unification(_) - | RelAlgebra::HnswSearch(_) => { + | RelAlgebra::HnswSearch(_) + | RelAlgebra::FtsSearch(_) => { self.materialized_join(tx, eliminate_indices, delta_rule, stores) } RelAlgebra::Reorder(_) => { diff --git a/cozo-core/src/query/reorder.rs b/cozo-core/src/query/reorder.rs index 7d3e348d..f0c05ecd 100644 --- a/cozo-core/src/query/reorder.rs +++ b/cozo-core/src/query/reorder.rs @@ -80,6 +80,14 @@ impl NormalFormInlineRule { pending.push(NormalFormAtom::HnswSearch(s)); } } + NormalFormAtom::FtsSearch(s) => { + if seen_variables.contains(&s.query) { + seen_variables.extend(s.all_bindings().cloned()); + round_1_collected.push(NormalFormAtom::FtsSearch(s)); + } else { + pending.push(NormalFormAtom::FtsSearch(s)); + } + } } } @@ -112,6 +120,10 @@ impl NormalFormInlineRule { seen_variables.extend(s.all_bindings().cloned()); collected.push(NormalFormAtom::HnswSearch(s)); } + NormalFormAtom::FtsSearch(s) => { + seen_variables.extend(s.all_bindings().cloned()); + collected.push(NormalFormAtom::FtsSearch(s)); + } } for atom in last_pending.iter() { match atom { @@ -138,6 +150,14 @@ impl NormalFormInlineRule { pending.push(NormalFormAtom::HnswSearch(s.clone())); } } + NormalFormAtom::FtsSearch(s) => { + if seen_variables.contains(&s.query) { + seen_variables.extend(s.all_bindings().cloned()); + collected.push(NormalFormAtom::FtsSearch(s.clone())); + } else { + pending.push(NormalFormAtom::FtsSearch(s.clone())); + } + } NormalFormAtom::Predicate(p) => { if p.bindings()?.is_subset(&seen_variables) { collected.push(NormalFormAtom::Predicate(p.clone())); @@ -183,6 +203,9 @@ impl NormalFormInlineRule { NormalFormAtom::HnswSearch(s) => { bail!(UnboundVariable(s.span)) } + NormalFormAtom::FtsSearch(s) => { + bail!(UnboundVariable(s.span)) + } } } } diff --git a/cozo-core/src/query/stratify.rs b/cozo-core/src/query/stratify.rs index 6cb2eac4..53f2eeb1 100644 --- a/cozo-core/src/query/stratify.rs +++ b/cozo-core/src/query/stratify.rs @@ -30,7 +30,8 @@ impl NormalFormAtom { | NormalFormAtom::NegatedRelation(_) | NormalFormAtom::Predicate(_) | NormalFormAtom::Unification(_) - | NormalFormAtom::HnswSearch(_) => Default::default(), + | NormalFormAtom::HnswSearch(_) + | NormalFormAtom::FtsSearch(_) => Default::default(), NormalFormAtom::Rule(r) => BTreeMap::from([(&r.name, false)]), NormalFormAtom::NegatedRule(r) => BTreeMap::from([(&r.name, true)]), } diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index f909f228..196ede06 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -43,10 +43,7 @@ use crate::fts::TokenizerCache; use crate::parse::sys::SysOp; use crate::parse::{parse_script, CozoScript, SourceSpan}; use crate::query::compile::{CompiledProgram, CompiledRule, CompiledRuleSet}; -use crate::query::ra::{ - FilteredRA, HnswSearchRA, InnerJoin, NegJoin, RelAlgebra, ReorderRA, StoredRA, - StoredWithValidityRA, TempStoreRA, UnificationRA, -}; +use crate::query::ra::{FilteredRA, FtsSearchRA, HnswSearchRA, InnerJoin, NegJoin, RelAlgebra, ReorderRA, StoredRA, StoredWithValidityRA, TempStoreRA, UnificationRA}; #[allow(unused_imports)] use crate::runtime::callback::{ CallbackCollector, CallbackDeclaration, CallbackOp, EventCallbackRegistry, @@ -1073,6 +1070,18 @@ impl<'s, S: Storage<'s>> Db { .map(|f| f.to_string()) .collect_vec()), ), + RelAlgebra::FtsSearch(FtsSearchRA { + fts_search, .. + }) => ( + "fts_index", + json!(format!(":{}", fts_search.query.name)), + json!(fts_search.query.name), + json!(fts_search + .filter + .iter() + .map(|f| f.to_string()) + .collect_vec()), + ), }; ret_for_relation.push(json!({ STRATUM: stratum,