FTS infrastructure

main
Ziyang Hu 1 year ago
parent cc0025d514
commit 5f90c6f9eb

@ -22,6 +22,7 @@ use crate::data::relation::StoredRelationMetadata;
use crate::data::symb::{Symbol, PROG_ENTRY}; use crate::data::symb::{Symbol, PROG_ENTRY};
use crate::data::value::{DataValue, ValidityTs}; use crate::data::value::{DataValue, ValidityTs};
use crate::fixed_rule::{FixedRule, FixedRuleHandle}; use crate::fixed_rule::{FixedRule, FixedRuleHandle};
use crate::fts::FtsIndexManifest;
use crate::parse::SourceSpan; use crate::parse::SourceSpan;
use crate::query::logical::{Disjunction, NamedFieldNotFound}; use crate::query::logical::{Disjunction, NamedFieldNotFound};
use crate::runtime::hnsw::HnswIndexManifest; use crate::runtime::hnsw::HnswIndexManifest;
@ -947,6 +948,28 @@ pub(crate) struct HnswSearch {
pub(crate) span: SourceSpan, pub(crate) span: SourceSpan,
} }
/// Scoring scheme requested for a full-text search.
///
/// `SearchInput::normalize_fts` picks the variant from the `score_kind`
/// search parameter and falls back to `BM25` when the parameter is absent.
#[derive(Copy, Clone, Debug)]
pub(crate) enum FtsScoreKind {
/// Selected by the parameter string `"bm25"`.
BM25,
/// Selected by the parameter string `"tf_idf"`.
TFIDF,
}
/// A normalized full-text-search atom, as produced by
/// `SearchInput::normalize_fts`.
#[derive(Clone, Debug)]
pub(crate) struct FtsSearch {
/// Handle of the relation that owns the FTS index.
pub(crate) base_handle: RelationHandle,
/// Handle of the index relation found under the index name in
/// `base_handle.fts_indices`.
pub(crate) idx_handle: RelationHandle,
/// Manifest stored alongside the index; its `tokenizer` and `filters`
/// are used to obtain the tokenizer when the search is executed.
pub(crate) manifest: FtsIndexManifest,
/// One symbol per column of the base relation (keys first, then
/// non-keys); columns the query did not mention get an ignored symbol.
pub(crate) bindings: Vec<Symbol>,
/// Value of the required `k` parameter, validated to be a positive
/// integer. NOTE(review): presumably the maximum number of hits to
/// return -- confirm once the search itself is implemented.
pub(crate) k: usize,
/// Symbol holding the query text; it must be bound before this atom runs.
pub(crate) query: Symbol,
/// Scoring scheme from the `score_kind` parameter (default `BM25`).
pub(crate) score_kind: FtsScoreKind,
/// Optional symbol from the `bind_score` parameter; when present it is
/// yielded by `all_bindings` after the column bindings.
pub(crate) bind_score: Option<Symbol>,
/// Value of the `lax_mode` parameter (default `true`).
pub(crate) lax_mode: bool,
/// Value of the `global_idf` parameter (default `true`).
pub(crate) global_idf: bool,
/// Optional expression from the `filter` parameter, kept unevaluated here.
pub(crate) filter: Option<Expr>,
/// Source span of the search atom, used for diagnostics.
pub(crate) span: SourceSpan,
}
impl HnswSearch { impl HnswSearch {
pub(crate) fn all_bindings(&self) -> impl Iterator<Item = &Symbol> { pub(crate) fn all_bindings(&self) -> impl Iterator<Item = &Symbol> {
self.bindings self.bindings
@ -958,7 +981,190 @@ impl HnswSearch {
} }
} }
impl FtsSearch {
    /// Yields every symbol this atom binds: the per-column bindings in
    /// order, followed by the score binding when one was requested.
    pub(crate) fn all_bindings(&self) -> impl Iterator<Item = &Symbol> {
        let column_bindings = self.bindings.iter();
        // `Option<&Symbol>` iterates over zero or one item.
        let score_binding = self.bind_score.as_ref();
        column_bindings.chain(score_binding)
    }
}
impl SearchInput { impl SearchInput {
fn normalize_fts(
mut self,
base_handle: RelationHandle,
idx_handle: RelationHandle,
manifest: FtsIndexManifest,
gen: &mut TempSymbGen,
) -> Result<Disjunction> {
let mut conj = Vec::with_capacity(self.bindings.len() + 8);
let mut bindings = Vec::with_capacity(self.bindings.len());
let mut seen_variables = BTreeSet::new();
for col in base_handle
.metadata
.keys
.iter()
.chain(base_handle.metadata.non_keys.iter())
{
if let Some(arg) = self.bindings.remove(&col.name) {
match arg {
Expr::Binding { var, .. } => {
if var.is_ignored_symbol() {
bindings.push(gen.next_ignored(var.span));
} else if seen_variables.insert(var.clone()) {
bindings.push(var);
} else {
let span = var.span;
let dup = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: dup.clone(),
expr: Expr::Binding {
var,
tuple_pos: None,
},
one_many_unif: false,
span,
});
conj.push(unif);
bindings.push(dup);
}
}
expr => {
let span = expr.span();
let kw = gen.next(span);
bindings.push(kw.clone());
let unif = NormalFormAtom::Unification(Unification {
binding: kw,
expr,
one_many_unif: false,
span,
});
conj.push(unif)
}
}
} else {
bindings.push(gen.next_ignored(self.span));
}
}
if let Some((name, _)) = self.bindings.pop_first() {
bail!(NamedFieldNotFound(
self.relation.name.to_string(),
name.to_string(),
self.span
));
}
#[derive(Debug, Error, Diagnostic)]
#[error("Field `{0}` is required for HNSW search")]
#[diagnostic(code(parser::hnsw_query_required))]
struct HnswRequiredMissing(String, #[label] SourceSpan);
let query = match self
.parameters
.remove("query")
.ok_or_else(|| miette!(HnswRequiredMissing("query".to_string(), self.span)))?
{
Expr::Binding { var, .. } => var,
expr => {
let span = expr.span();
let kw = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: kw.clone(),
expr,
one_many_unif: false,
span,
});
conj.push(unif);
kw
}
};
let k_expr = self
.parameters
.remove("k")
.ok_or_else(|| miette!(HnswRequiredMissing("k".to_string(), self.span)))?;
let k = k_expr.eval_to_const()?;
let k = k.get_int().ok_or(ExpectedPosIntForFtsK(self.span))?;
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive integer for `k`")]
#[diagnostic(code(parser::expected_int_for_hnsw_k))]
struct ExpectedPosIntForFtsK(#[label] SourceSpan);
ensure!(k > 0, ExpectedPosIntForFtsK(self.span));
let score_kind_expr = self.parameters.remove("score_kind");
let score_kind = match score_kind_expr {
Some(expr) => {
let r = expr.eval_to_const()?;
let r = r.get_str().ok_or_else(|| miette!("Score kind for FTS must be a string"))?;
match r {
"bm25" => FtsScoreKind::BM25,
"tf_idf" => FtsScoreKind::TFIDF,
s => bail!("Unknown score kind for FTS: {}", s)
}
}
None => FtsScoreKind::BM25,
};
let lax_mode_expr = self.parameters.remove("lax_mode");
let lax_mode = match lax_mode_expr {
Some(expr) => {
let r = expr.eval_to_const()?;
let r = r.get_bool().ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?;
r
}
None => true,
};
let global_idf_expr = self.parameters.remove("global_idf");
let global_idf = match global_idf_expr {
Some(expr) => {
let r = expr.eval_to_const()?;
let r = r.get_bool().ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?;
r
}
None => true,
};
let filter = self.parameters.remove("filter");
let bind_score = match self.parameters.remove("bind_score") {
None => None,
Some(Expr::Binding { var, .. }) => Some(var),
Some(expr) => {
let span = expr.span();
let kw = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: kw.clone(),
expr,
one_many_unif: false,
span,
});
conj.push(unif);
Some(kw)
}
};
conj.push(NormalFormAtom::FtsSearch(FtsSearch {
base_handle,
idx_handle,
manifest,
bindings,
k: k as usize,
query,
score_kind,
bind_score,
lax_mode,
global_idf,
filter,
span: self.span,
}));
Ok(Disjunction::conj(conj))
}
fn normalize_hnsw( fn normalize_hnsw(
mut self, mut self,
base_handle: RelationHandle, base_handle: RelationHandle,
@ -1202,6 +1408,11 @@ impl SearchInput {
{ {
return self.normalize_hnsw(base_handle, idx_handle, manifest, gen); return self.normalize_hnsw(base_handle, idx_handle, manifest, gen);
} }
if let Some((idx_handle, manifest)) =
base_handle.fts_indices.get(&self.index.name).cloned()
{
return self.normalize_fts(base_handle, idx_handle, manifest, gen);
}
#[derive(Debug, Error, Diagnostic)] #[derive(Debug, Error, Diagnostic)]
#[error("Index {name} not found on relation {relation}")] #[error("Index {name} not found on relation {relation}")]
#[diagnostic(code(eval::hnsw_index_not_found))] #[diagnostic(code(eval::hnsw_index_not_found))]
@ -1340,6 +1551,7 @@ pub(crate) enum NormalFormAtom {
Predicate(Expr), Predicate(Expr),
Unification(Unification), Unification(Unification),
HnswSearch(HnswSearch), HnswSearch(HnswSearch),
FtsSearch(FtsSearch),
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -1351,6 +1563,7 @@ pub(crate) enum MagicAtom {
NegatedRelation(MagicRelationApplyAtom), NegatedRelation(MagicRelationApplyAtom),
Unification(Unification), Unification(Unification),
HnswSearch(HnswSearch), HnswSearch(HnswSearch),
FtsSearch(FtsSearch),
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]

@ -6,14 +6,15 @@
* You can obtain one at https://mozilla.org/MPL/2.0/. * You can obtain one at https://mozilla.org/MPL/2.0/.
*/ */
use ordered_float::OrderedFloat;
use crate::fts::tokenizer::TextAnalyzer; use crate::fts::tokenizer::TextAnalyzer;
use smartstring::{LazyCompact, SmartString}; use smartstring::{LazyCompact, SmartString};
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct FtsLiteral { pub(crate) struct FtsLiteral {
pub(crate) value: SmartString<LazyCompact>, pub(crate) value: SmartString<LazyCompact>,
pub(crate) is_prefix: bool, pub(crate) is_prefix: bool,
pub(crate) booster: f64, pub(crate) booster: OrderedFloat<f64>,
} }
impl FtsLiteral { impl FtsLiteral {
@ -34,13 +35,13 @@ impl FtsLiteral {
} }
} }
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct FtsNear { pub(crate) struct FtsNear {
pub(crate) literals: Vec<FtsLiteral>, pub(crate) literals: Vec<FtsLiteral>,
pub(crate) distance: u32, pub(crate) distance: u32,
} }
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) enum FtsExpr { pub(crate) enum FtsExpr {
Literal(FtsLiteral), Literal(FtsLiteral),
Near(FtsNear), Near(FtsNear),

@ -7,17 +7,128 @@
*/ */
use crate::data::expr::{eval_bytecode, Bytecode}; use crate::data::expr::{eval_bytecode, Bytecode};
use crate::data::program::FtsSearch;
use crate::data::tuple::{decode_tuple_from_key, Tuple, ENCODED_KEY_MIN_LEN};
use crate::data::value::LARGEST_UTF_CHAR;
use crate::fts::ast::{FtsExpr, FtsLiteral};
use crate::fts::tokenizer::TextAnalyzer; use crate::fts::tokenizer::TextAnalyzer;
use crate::parse::fts::parse_fts_query;
use crate::runtime::relation::RelationHandle; use crate::runtime::relation::RelationHandle;
use crate::runtime::transact::SessionTx; use crate::runtime::transact::SessionTx;
use crate::DataValue; use crate::{decode_tuple_from_kv, DataValue, SourceSpan};
use itertools::Itertools;
use miette::{bail, Diagnostic, Result}; use miette::{bail, Diagnostic, Result};
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::{FxHashMap, FxHashSet};
use smartstring::{LazyCompact, SmartString}; use smartstring::{LazyCompact, SmartString};
use std::collections::hash_map::Entry;
use std::collections::HashMap; use std::collections::HashMap;
use thiserror::Error; use thiserror::Error;
/// Memoization state shared across the FTS searches of one iteration.
#[derive(Default)]
pub(crate) struct FtsCache {
// Row counts keyed by relation name; filled lazily by `get_n_for_relation`.
total_n_cache: FxHashMap<SmartString<LazyCompact>, usize>,
// Entries keyed by the tokenized query AST, looked up in `fts_search`.
// NOTE(review): the `f64` is presumably the score of each tuple -- confirm
// once the search branches are implemented.
results_cache: FxHashMap<FtsExpr, Vec<(Tuple, f64)>>,
}
impl FtsCache {
    /// Returns the number of stored entries of `rel`, counting them in the
    /// store on first use and serving the cached figure afterwards.
    ///
    /// The count covers the key range from the encoding of the empty tuple
    /// up to the encoding of `[DataValue::Bot]`.
    fn get_n_for_relation(&mut self, rel: &RelationHandle, tx: &SessionTx<'_>) -> Result<usize> {
        let slot = match self.total_n_cache.entry(rel.name.clone()) {
            Entry::Occupied(hit) => return Ok(*hit.get()),
            Entry::Vacant(slot) => slot,
        };
        let lower = rel.encode_key_for_store(&[], Default::default())?;
        let upper = rel.encode_key_for_store(&[DataValue::Bot], Default::default())?;
        let row_count = tx.store_tx.range_count(&lower, &upper)?;
        Ok(*slot.insert(row_count))
    }
}
/// One occurrence of a term inside an indexed document, rebuilt by
/// `search_literal` from the three parallel lists stored in the index value.
struct PositionInfo {
// Element of the first stored list.
// NOTE(review): presumably the start offset of the token -- confirm.
from: u32,
// Element of the second stored list.
// NOTE(review): presumably the end offset of the token -- confirm.
to: u32,
// Element of the third stored list.
// NOTE(review): presumably the ordinal of the token in the text -- confirm.
position: u32,
}
/// What `search_literal` reports for one index entry matching a literal.
struct LiteralStats {
// The index key with its leading term column removed.
key: Tuple,
// All occurrences recorded for the term in this entry.
position_info: Vec<PositionInfo>,
// Fourth element of the stored index value.
// NOTE(review): presumably the token count of the document -- confirm.
doc_len: u32,
}
impl<'a> SessionTx<'a> { impl<'a> SessionTx<'a> {
/// Scans the FTS index relation for entries whose term matches `literal`.
///
/// The scan covers keys from the literal's text up to that text followed
/// by `LARGEST_UTF_CHAR`. It stops at the first entry whose term no longer
/// matches: a prefix test when `literal.is_prefix` is set, exact equality
/// otherwise.
///
/// # Errors
///
/// Propagates failures from key encoding and from the range scan.
///
/// # Panics
///
/// Panics if a stored entry does not have the expected layout: a string
/// term as first key column and a MessagePack-encoded value holding three
/// lists followed by an integer.
fn search_literal(
    &self,
    literal: &FtsLiteral,
    idx_handle: &RelationHandle,
) -> Result<Vec<LiteralStats>> {
    let needle: &str = &literal.value;
    let mut upper_text = literal.value.clone();
    upper_text.push(LARGEST_UTF_CHAR);

    let lower_bound = idx_handle.encode_key_for_store(
        &[DataValue::Str(SmartString::from(needle))],
        Default::default(),
    )?;
    let upper_bound =
        idx_handle.encode_key_for_store(&[DataValue::Str(upper_text)], Default::default())?;

    let n_key_cols = idx_handle.metadata.keys.len();
    let mut hits = Vec::new();
    for entry in self.store_tx.range_scan(&lower_bound, &upper_bound) {
        let (raw_key, raw_val) = entry?;
        let decoded_key = decode_tuple_from_key(&raw_key, n_key_cols);
        let term = decoded_key[0].get_str().unwrap();

        let is_match = if literal.is_prefix {
            term.starts_with(needle)
        } else {
            term == needle
        };
        if !is_match {
            // Keys come back in order, so nothing further can match.
            break;
        }

        let payload: Vec<DataValue> =
            rmp_serde::from_slice(&raw_val[ENCODED_KEY_MIN_LEN..]).unwrap();
        let froms = payload[0].get_slice().unwrap();
        let tos = payload[1].get_slice().unwrap();
        let positions = payload[2].get_slice().unwrap();
        let total_length = payload[3].get_int().unwrap();

        let mut position_info = Vec::new();
        for ((f, t), p) in froms.iter().zip(tos.iter()).zip(positions.iter()) {
            position_info.push(PositionInfo {
                from: f.get_int().unwrap() as u32,
                to: t.get_int().unwrap() as u32,
                position: p.get_int().unwrap() as u32,
            });
        }

        hits.push(LiteralStats {
            key: decoded_key[1..].to_vec(),
            position_info,
            doc_len: total_length as u32,
        });
    }
    Ok(hits)
}
/// Entry point for running one full-text query against an FTS index.
///
/// The query text `q` is parsed and then tokenized with `tokenizer`; an
/// empty result after tokenization yields no tuples. Otherwise the
/// tokenized AST is looked up in `cache.results_cache`.
///
/// NOTE: both cache branches are still `todo!()`, so any query that
/// survives tokenization currently panics. `config`, `filter_code` and
/// `stack` are not used yet.
///
/// # Errors
///
/// Returns the parser's error when `q` is not a valid FTS query.
pub(crate) fn fts_search(
&self,
q: &str,
config: &FtsSearch,
filter_code: &Option<(Vec<Bytecode>, SourceSpan)>,
tokenizer: &TextAnalyzer,
stack: &mut Vec<DataValue>,
cache: &mut FtsCache,
) -> Result<Vec<Tuple>> {
let ast = parse_fts_query(q)?.tokenize(tokenizer);
if ast.is_empty() {
return Ok(vec![]);
}
match cache.results_cache.entry(ast) {
// Cache hit: not implemented yet.
Entry::Occupied(_) => {
todo!()
}
// Cache miss: not implemented yet.
Entry::Vacant(_) => {
todo!()
}
}
}
pub(crate) fn put_fts_index_item( pub(crate) fn put_fts_index_item(
&mut self, &mut self,
tuple: &[DataValue], tuple: &[DataValue],
@ -55,7 +166,12 @@ impl<'a> SessionTx<'a> {
for k in &tuple[..rel_handle.metadata.keys.len()] { for k in &tuple[..rel_handle.metadata.keys.len()] {
key.push(k.clone()); key.push(k.clone());
} }
let mut val = vec![DataValue::Bot, DataValue::Bot, DataValue::Bot, DataValue::from(count)]; let mut val = vec![
DataValue::Bot,
DataValue::Bot,
DataValue::Bot,
DataValue::from(count),
];
for (text, (from, to, position)) in collector { for (text, (from, to, position)) in collector {
key[0] = DataValue::Str(text); key[0] = DataValue::Str(text);
val[0] = DataValue::List(from); val[0] = DataValue::List(from);

@ -16,7 +16,7 @@ use pest::pratt_parser::{Op, PrattParser};
use pest::Parser; use pest::Parser;
use smartstring::SmartString; use smartstring::SmartString;
fn parse_fts_query(q: &str) -> Result<FtsExpr> { pub(crate) fn parse_fts_query(q: &str) -> Result<FtsExpr> {
let mut pairs = CozoScriptParser::parse(Rule::fts_doc, q).into_diagnostic()?; let mut pairs = CozoScriptParser::parse(Rule::fts_doc, q).into_diagnostic()?;
let pairs = pairs.next().unwrap().into_inner(); let pairs = pairs.next().unwrap().into_inner();
let pairs: Vec<_> = pairs let pairs: Vec<_> = pairs
@ -124,7 +124,7 @@ fn build_phrase(pair: Pair<'_>) -> Result<FtsLiteral> {
Ok(FtsLiteral { Ok(FtsLiteral {
value: core_text, value: core_text,
is_prefix: is_quoted, is_prefix: is_quoted,
booster, booster: booster.into(),
}) })
} }

@ -505,6 +505,40 @@ impl<'a> SessionTx<'a> {
ret = ret.filter(Expr::build_and(post_filters, s.span))?; ret = ret.filter(Expr::build_and(post_filters, s.span))?;
} }
} }
MagicAtom::FtsSearch(s) => {
debug_assert!(
seen_variables.contains(&s.query),
"FTS search query must be bound"
);
let mut own_bindings = vec![];
let mut post_filters = vec![];
for var in s.all_bindings() {
if seen_variables.contains(var) {
let rk = gen_symb(var.span);
post_filters.push(Expr::build_equate(
vec![
Expr::Binding {
var: var.clone(),
tuple_pos: None,
},
Expr::Binding {
var: rk.clone(),
tuple_pos: None,
},
],
var.span,
));
own_bindings.push(rk);
} else {
seen_variables.insert(var.clone());
own_bindings.push(var.clone());
}
}
ret = ret.fts_search(s.clone(), own_bindings)?;
if !post_filters.is_empty() {
ret = ret.filter(Expr::build_and(post_filters, s.span))?;
}
}
MagicAtom::Unification(u) => { MagicAtom::Unification(u) => {
if seen_variables.contains(&u.binding) { if seen_variables.contains(&u.binding) {
let expr = if u.one_many_unif { let expr = if u.one_many_unif {

@ -176,6 +176,10 @@ fn magic_rewrite_ruleset(
seen_bindings.extend(s.all_bindings().cloned()); seen_bindings.extend(s.all_bindings().cloned());
collected_atoms.push(MagicAtom::HnswSearch(s)); collected_atoms.push(MagicAtom::HnswSearch(s));
} }
MagicAtom::FtsSearch(s) => {
seen_bindings.extend(s.all_bindings().cloned());
collected_atoms.push(MagicAtom::FtsSearch(s));
}
MagicAtom::Rule(r_app) => { MagicAtom::Rule(r_app) => {
if r_app.name.has_bound_adornment() { if r_app.name.has_bound_adornment() {
// we are guaranteed to have a magic rule application // we are guaranteed to have a magic rule application
@ -527,6 +531,15 @@ impl NormalFormAtom {
} }
MagicAtom::HnswSearch(s.clone()) MagicAtom::HnswSearch(s.clone())
} }
NormalFormAtom::FtsSearch(s) => {
for arg in s.all_bindings() {
if !seen_bindings.contains(arg) {
seen_bindings.insert(arg.clone());
}
}
MagicAtom::FtsSearch(s.clone())
}
NormalFormAtom::Predicate(p) => { NormalFormAtom::Predicate(p) => {
// predicate cannot introduce new bindings // predicate cannot introduce new bindings
MagicAtom::Predicate(p.clone()) MagicAtom::Predicate(p.clone())

@ -17,7 +17,7 @@ use miette::{bail, Diagnostic, Result};
use thiserror::Error; use thiserror::Error;
use crate::data::expr::{compute_bounds, eval_bytecode, eval_bytecode_pred, Bytecode, Expr}; use crate::data::expr::{compute_bounds, eval_bytecode, eval_bytecode_pred, Bytecode, Expr};
use crate::data::program::{HnswSearch, MagicSymbol}; use crate::data::program::{FtsSearch, HnswSearch, MagicSymbol};
use crate::data::relation::{ColType, NullableColType}; use crate::data::relation::{ColType, NullableColType};
use crate::data::symb::Symbol; use crate::data::symb::Symbol;
use crate::data::tuple::{Tuple, TupleIter}; use crate::data::tuple::{Tuple, TupleIter};
@ -39,6 +39,7 @@ pub(crate) enum RelAlgebra {
Filter(FilteredRA), Filter(FilteredRA),
Unification(UnificationRA), Unification(UnificationRA),
HnswSearch(HnswSearchRA), HnswSearch(HnswSearchRA),
FtsSearch(FtsSearchRA),
} }
impl RelAlgebra { impl RelAlgebra {
@ -54,6 +55,7 @@ impl RelAlgebra {
RelAlgebra::Unification(i) => i.span, RelAlgebra::Unification(i) => i.span,
RelAlgebra::StoredWithValidity(i) => i.span, RelAlgebra::StoredWithValidity(i) => i.span,
RelAlgebra::HnswSearch(i) => i.hnsw_search.span, RelAlgebra::HnswSearch(i) => i.hnsw_search.span,
RelAlgebra::FtsSearch(i) => i.fts_search.span,
} }
} }
} }
@ -283,6 +285,11 @@ impl Debug for RelAlgebra {
.field(&bindings) .field(&bindings)
.field(&s.hnsw_search.idx_handle.name) .field(&s.hnsw_search.idx_handle.name)
.finish(), .finish(),
RelAlgebra::FtsSearch(s) => f
.debug_tuple("FtsSearch")
.field(&bindings)
.field(&s.fts_search.idx_handle.name)
.finish(),
RelAlgebra::StoredWithValidity(r) => f RelAlgebra::StoredWithValidity(r) => f
.debug_tuple("StoredWithValidity") .debug_tuple("StoredWithValidity")
.field(&bindings) .field(&bindings)
@ -352,6 +359,9 @@ impl RelAlgebra {
RelAlgebra::HnswSearch(s) => { RelAlgebra::HnswSearch(s) => {
s.fill_binding_indices_and_compile()?; s.fill_binding_indices_and_compile()?;
} }
RelAlgebra::FtsSearch(s) => {
s.fill_binding_indices_and_compile()?;
}
RelAlgebra::StoredWithValidity(v) => { RelAlgebra::StoredWithValidity(v) => {
v.fill_binding_indices_and_compile()?; v.fill_binding_indices_and_compile()?;
} }
@ -448,7 +458,8 @@ impl RelAlgebra {
| RelAlgebra::Reorder(_) | RelAlgebra::Reorder(_)
| RelAlgebra::NegJoin(_) | RelAlgebra::NegJoin(_)
| RelAlgebra::Unification(_) | RelAlgebra::Unification(_)
| RelAlgebra::HnswSearch(_)) => { | RelAlgebra::HnswSearch(_)
| RelAlgebra::FtsSearch(_)) => {
let span = filter.span(); let span = filter.span();
RelAlgebra::Filter(FilteredRA { RelAlgebra::Filter(FilteredRA {
parent: Box::new(s), parent: Box::new(s),
@ -601,6 +612,18 @@ impl RelAlgebra {
own_bindings, own_bindings,
})) }))
} }
/// Wraps `self` as the parent of a new FTS search node.
///
/// The node starts without compiled filter bytecode. `own_bindings` are
/// the symbols the search adds after the parent's bindings.
pub(crate) fn fts_search(
    self,
    fts_search: FtsSearch,
    own_bindings: Vec<Symbol>,
) -> Result<Self> {
    let node = FtsSearchRA {
        parent: Box::new(self),
        fts_search,
        own_bindings,
        filter_bytecode: None,
    };
    Ok(Self::FtsSearch(node))
}
pub(crate) fn join( pub(crate) fn join(
self, self,
right: RelAlgebra, right: RelAlgebra,
@ -852,6 +875,83 @@ pub(crate) struct HnswSearchRA {
pub(crate) own_bindings: Vec<Symbol>, pub(crate) own_bindings: Vec<Symbol>,
} }
/// Relational-algebra node that runs a full-text search for every tuple
/// produced by its parent.
#[derive(Debug)]
pub(crate) struct FtsSearchRA {
/// Upstream node; it must bind the query variable of `fts_search`.
pub(crate) parent: Box<RelAlgebra>,
/// The normalized search atom this node executes.
pub(crate) fts_search: FtsSearch,
/// Compiled form of `fts_search.filter` with its span; `None` until
/// `fill_binding_indices_and_compile` runs, or when there is no filter.
pub(crate) filter_bytecode: Option<(Vec<Bytecode>, SourceSpan)>,
/// Symbols this node appends after the parent's bindings.
pub(crate) own_bindings: Vec<Symbol>,
}
impl FtsSearchRA {
    /// Resolves binding indices in the parent and, when the search has a
    /// filter, in the filter too, then compiles the filter to bytecode.
    ///
    /// Filter variables are resolved against `own_bindings` only, each
    /// symbol mapping to its position in that list.
    fn fill_binding_indices_and_compile(&mut self) -> Result<()> {
        self.parent.fill_binding_indices_and_compile()?;
        if let Some(filter) = self.fts_search.filter.as_mut() {
            let bindings: BTreeMap<_, _> = self
                .own_bindings
                .iter()
                .cloned()
                .enumerate()
                .map(|(idx, symb)| (symb, idx))
                .collect();
            filter.fill_binding_indices(&bindings)?;
            self.filter_bytecode = Some((filter.compile()?, filter.span()));
        }
        Ok(())
    }

    /// Runs the search once per parent tuple and yields each parent tuple
    /// extended by every tuple the search returned for it.
    ///
    /// # Errors
    ///
    /// Fails up front when the query variable is not among the parent's
    /// bindings or the tokenizer cannot be obtained. Per-tuple errors
    /// (non-string query value, search failure) surface through the
    /// returned iterator.
    fn iter<'a>(
        &'a self,
        tx: &'a SessionTx<'_>,
        delta_rule: Option<&MagicSymbol>,
        stores: &'a BTreeMap<MagicSymbol, EpochStore>,
    ) -> Result<TupleIter<'a>> {
        let bindings = self.parent.bindings_after_eliminate();
        // Locate the column holding the query text. A missing binding is
        // reported as an error rather than indexing with a sentinel later.
        let bind_idx = match bindings
            .iter()
            .position(|b| *b == self.fts_search.query)
        {
            Some(idx) => idx,
            None => bail!(
                "FTS search query variable `{}` is not bound",
                self.fts_search.query.name
            ),
        };
        let config = self.fts_search.clone();
        let filter_code = self.filter_bytecode.clone();
        let mut stack = vec![];
        let mut idf_cache = Default::default();
        let tokenizer = tx.tokenizers.get(
            &config.idx_handle.name,
            &config.manifest.tokenizer,
            &config.manifest.filters,
        )?;
        let it = self
            .parent
            .iter(tx, delta_rule, stores)?
            .map_ok(move |tuple| -> Result<_> {
                let res = {
                    // Borrow the query text; no need to clone the value.
                    let q = match &tuple[bind_idx] {
                        DataValue::Str(s) => s,
                        d => bail!("Expected string for FTS search, got {:?}", d),
                    };
                    tx.fts_search(
                        q,
                        &config,
                        &filter_code,
                        &tokenizer,
                        &mut stack,
                        &mut idf_cache,
                    )?
                };
                Ok(res.into_iter().map(move |t| {
                    let mut r = tuple.clone();
                    r.extend(t);
                    r
                }))
            })
            .map(flatten_err)
            .flatten_ok();
        Ok(Box::new(it))
    }
}
impl HnswSearchRA { impl HnswSearchRA {
fn fill_binding_indices_and_compile(&mut self) -> Result<()> { fn fill_binding_indices_and_compile(&mut self) -> Result<()> {
self.parent.fill_binding_indices_and_compile()?; self.parent.fill_binding_indices_and_compile()?;
@ -1644,6 +1744,7 @@ impl RelAlgebra {
RelAlgebra::NegJoin(r) => r.do_eliminate_temp_vars(used), RelAlgebra::NegJoin(r) => r.do_eliminate_temp_vars(used),
RelAlgebra::Unification(r) => r.do_eliminate_temp_vars(used), RelAlgebra::Unification(r) => r.do_eliminate_temp_vars(used),
RelAlgebra::HnswSearch(_) => Ok(()), RelAlgebra::HnswSearch(_) => Ok(()),
RelAlgebra::FtsSearch(_) => Ok(()),
} }
} }
@ -1659,6 +1760,7 @@ impl RelAlgebra {
RelAlgebra::NegJoin(r) => Some(&r.to_eliminate), RelAlgebra::NegJoin(r) => Some(&r.to_eliminate),
RelAlgebra::Unification(u) => Some(&u.to_eliminate), RelAlgebra::Unification(u) => Some(&u.to_eliminate),
RelAlgebra::HnswSearch(_) => None, RelAlgebra::HnswSearch(_) => None,
RelAlgebra::FtsSearch(_) => None,
} }
} }
@ -1693,6 +1795,11 @@ impl RelAlgebra {
bindings.extend_from_slice(&s.own_bindings); bindings.extend_from_slice(&s.own_bindings);
bindings bindings
} }
RelAlgebra::FtsSearch(s) => {
let mut bindings = s.parent.bindings_after_eliminate();
bindings.extend_from_slice(&s.own_bindings);
bindings
}
} }
} }
pub(crate) fn iter<'a>( pub(crate) fn iter<'a>(
@ -1712,6 +1819,7 @@ impl RelAlgebra {
RelAlgebra::NegJoin(r) => r.iter(tx, delta_rule, stores), RelAlgebra::NegJoin(r) => r.iter(tx, delta_rule, stores),
RelAlgebra::Unification(r) => r.iter(tx, delta_rule, stores), RelAlgebra::Unification(r) => r.iter(tx, delta_rule, stores),
RelAlgebra::HnswSearch(r) => r.iter(tx, delta_rule, stores), RelAlgebra::HnswSearch(r) => r.iter(tx, delta_rule, stores),
RelAlgebra::FtsSearch(r) => r.iter(tx, delta_rule, stores),
} }
} }
} }
@ -1892,6 +2000,7 @@ impl InnerJoin {
} }
} }
RelAlgebra::HnswSearch(_) => "hnsw_search_join", RelAlgebra::HnswSearch(_) => "hnsw_search_join",
RelAlgebra::FtsSearch(_) => "fts_search_join",
RelAlgebra::StoredWithValidity(_) => { RelAlgebra::StoredWithValidity(_) => {
let join_indices = self let join_indices = self
.joiner .joiner
@ -2003,7 +2112,8 @@ impl InnerJoin {
RelAlgebra::Join(_) RelAlgebra::Join(_)
| RelAlgebra::Filter(_) | RelAlgebra::Filter(_)
| RelAlgebra::Unification(_) | RelAlgebra::Unification(_)
| RelAlgebra::HnswSearch(_) => { | RelAlgebra::HnswSearch(_)
| RelAlgebra::FtsSearch(_) => {
self.materialized_join(tx, eliminate_indices, delta_rule, stores) self.materialized_join(tx, eliminate_indices, delta_rule, stores)
} }
RelAlgebra::Reorder(_) => { RelAlgebra::Reorder(_) => {

@ -80,6 +80,14 @@ impl NormalFormInlineRule {
pending.push(NormalFormAtom::HnswSearch(s)); pending.push(NormalFormAtom::HnswSearch(s));
} }
} }
NormalFormAtom::FtsSearch(s) => {
if seen_variables.contains(&s.query) {
seen_variables.extend(s.all_bindings().cloned());
round_1_collected.push(NormalFormAtom::FtsSearch(s));
} else {
pending.push(NormalFormAtom::FtsSearch(s));
}
}
} }
} }
@ -112,6 +120,10 @@ impl NormalFormInlineRule {
seen_variables.extend(s.all_bindings().cloned()); seen_variables.extend(s.all_bindings().cloned());
collected.push(NormalFormAtom::HnswSearch(s)); collected.push(NormalFormAtom::HnswSearch(s));
} }
NormalFormAtom::FtsSearch(s) => {
seen_variables.extend(s.all_bindings().cloned());
collected.push(NormalFormAtom::FtsSearch(s));
}
} }
for atom in last_pending.iter() { for atom in last_pending.iter() {
match atom { match atom {
@ -138,6 +150,14 @@ impl NormalFormInlineRule {
pending.push(NormalFormAtom::HnswSearch(s.clone())); pending.push(NormalFormAtom::HnswSearch(s.clone()));
} }
} }
NormalFormAtom::FtsSearch(s) => {
if seen_variables.contains(&s.query) {
seen_variables.extend(s.all_bindings().cloned());
collected.push(NormalFormAtom::FtsSearch(s.clone()));
} else {
pending.push(NormalFormAtom::FtsSearch(s.clone()));
}
}
NormalFormAtom::Predicate(p) => { NormalFormAtom::Predicate(p) => {
if p.bindings()?.is_subset(&seen_variables) { if p.bindings()?.is_subset(&seen_variables) {
collected.push(NormalFormAtom::Predicate(p.clone())); collected.push(NormalFormAtom::Predicate(p.clone()));
@ -183,6 +203,9 @@ impl NormalFormInlineRule {
NormalFormAtom::HnswSearch(s) => { NormalFormAtom::HnswSearch(s) => {
bail!(UnboundVariable(s.span)) bail!(UnboundVariable(s.span))
} }
NormalFormAtom::FtsSearch(s) => {
bail!(UnboundVariable(s.span))
}
} }
} }
} }

@ -30,7 +30,8 @@ impl NormalFormAtom {
| NormalFormAtom::NegatedRelation(_) | NormalFormAtom::NegatedRelation(_)
| NormalFormAtom::Predicate(_) | NormalFormAtom::Predicate(_)
| NormalFormAtom::Unification(_) | NormalFormAtom::Unification(_)
| NormalFormAtom::HnswSearch(_) => Default::default(), | NormalFormAtom::HnswSearch(_)
| NormalFormAtom::FtsSearch(_) => Default::default(),
NormalFormAtom::Rule(r) => BTreeMap::from([(&r.name, false)]), NormalFormAtom::Rule(r) => BTreeMap::from([(&r.name, false)]),
NormalFormAtom::NegatedRule(r) => BTreeMap::from([(&r.name, true)]), NormalFormAtom::NegatedRule(r) => BTreeMap::from([(&r.name, true)]),
} }

@ -43,10 +43,7 @@ use crate::fts::TokenizerCache;
use crate::parse::sys::SysOp; use crate::parse::sys::SysOp;
use crate::parse::{parse_script, CozoScript, SourceSpan}; use crate::parse::{parse_script, CozoScript, SourceSpan};
use crate::query::compile::{CompiledProgram, CompiledRule, CompiledRuleSet}; use crate::query::compile::{CompiledProgram, CompiledRule, CompiledRuleSet};
use crate::query::ra::{ use crate::query::ra::{FilteredRA, FtsSearchRA, HnswSearchRA, InnerJoin, NegJoin, RelAlgebra, ReorderRA, StoredRA, StoredWithValidityRA, TempStoreRA, UnificationRA};
FilteredRA, HnswSearchRA, InnerJoin, NegJoin, RelAlgebra, ReorderRA, StoredRA,
StoredWithValidityRA, TempStoreRA, UnificationRA,
};
#[allow(unused_imports)] #[allow(unused_imports)]
use crate::runtime::callback::{ use crate::runtime::callback::{
CallbackCollector, CallbackDeclaration, CallbackOp, EventCallbackRegistry, CallbackCollector, CallbackDeclaration, CallbackOp, EventCallbackRegistry,
@ -1073,6 +1070,18 @@ impl<'s, S: Storage<'s>> Db<S> {
.map(|f| f.to_string()) .map(|f| f.to_string())
.collect_vec()), .collect_vec()),
), ),
RelAlgebra::FtsSearch(FtsSearchRA {
fts_search, ..
}) => (
"fts_index",
json!(format!(":{}", fts_search.query.name)),
json!(fts_search.query.name),
json!(fts_search
.filter
.iter()
.map(|f| f.to_string())
.collect_vec()),
),
}; };
ret_for_relation.push(json!({ ret_for_relation.push(json!({
STRATUM: stratum, STRATUM: stratum,

Loading…
Cancel
Save