diff --git a/cozo-core/src/data/program.rs b/cozo-core/src/data/program.rs index c3412101..f1128e0b 100644 --- a/cozo-core/src/data/program.rs +++ b/cozo-core/src/data/program.rs @@ -950,8 +950,8 @@ pub(crate) struct HnswSearch { #[derive(Copy, Clone, Debug)] pub(crate) enum FtsScoreKind { - BM25, TFIDF, + TF, } #[derive(Clone, Debug)] @@ -961,11 +961,12 @@ pub(crate) struct FtsSearch { pub(crate) manifest: FtsIndexManifest, pub(crate) bindings: Vec, pub(crate) k: usize, + pub(crate) k1: f64, + pub(crate) b: f64, pub(crate) query: Symbol, pub(crate) score_kind: FtsScoreKind, pub(crate) bind_score: Option, pub(crate) lax_mode: bool, - pub(crate) global_idf: bool, pub(crate) filter: Option, pub(crate) span: SourceSpan, } @@ -1093,37 +1094,70 @@ impl SearchInput { ensure!(k > 0, ExpectedPosIntForFtsK(self.span)); + let k1 = { + match self.parameters.remove("k1") { + None => 1.2, + Some(expr) => { + let r = expr.eval_to_const()?; + let r = r + .get_float() + .ok_or_else(|| miette!("k1 for FTS must be a float"))?; + + #[derive(Debug, Error, Diagnostic)] + #[error("Expected positive float for `k1`")] + #[diagnostic(code(parser::expected_float_for_hnsw_k1))] + struct ExpectedPosFloatForFtsK1(#[label] SourceSpan); + + ensure!(r > 0.0, ExpectedPosFloatForFtsK1(self.span)); + r + } + } + }; + + let b = { + match self.parameters.remove("b") { + None => 0.75, + Some(expr) => { + let r = expr.eval_to_const()?; + let r = r + .get_float() + .ok_or_else(|| miette!("b for FTS must be a float"))?; + + #[derive(Debug, Error, Diagnostic)] + #[error("Expected positive float for `b`")] + #[diagnostic(code(parser::expected_float_for_hnsw_b))] + struct ExpectedPosFloatForFtsB(#[label] SourceSpan); + + ensure!(r > 0.0, ExpectedPosFloatForFtsB(self.span)); + r + } + } + }; let score_kind_expr = self.parameters.remove("score_kind"); let score_kind = match score_kind_expr { Some(expr) => { let r = expr.eval_to_const()?; - let r = r.get_str().ok_or_else(|| miette!("Score kind for FTS must be a string"))?; + let r = r + .get_str() + .ok_or_else(|| miette!("Score kind for FTS must be a string"))?; match r { - "bm25" => FtsScoreKind::BM25, "tf_idf" => FtsScoreKind::TFIDF, - s => bail!("Unknown score kind for FTS: {}", s) + "tf" => FtsScoreKind::TF, + s => bail!("Unknown score kind for FTS: {}", s), } } - None => FtsScoreKind::BM25, + None => FtsScoreKind::TFIDF, }; let lax_mode_expr = self.parameters.remove("lax_mode"); let lax_mode = match lax_mode_expr { Some(expr) => { let r = expr.eval_to_const()?; - let r = r.get_bool().ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?; - r - } - None => true, - }; - - let global_idf_expr = self.parameters.remove("global_idf"); - let global_idf = match global_idf_expr { - Some(expr) => { - let r = expr.eval_to_const()?; - let r = r.get_bool().ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?; + let r = r + .get_bool() + .ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?; r } None => true, @@ -1158,7 +1192,8 @@ impl SearchInput { score_kind, bind_score, lax_mode, - global_idf, + k1, + b, filter, span: self.span, })); @@ -1408,8 +1443,7 @@ impl SearchInput { { return self.normalize_hnsw(base_handle, idx_handle, manifest, gen); } - if let Some((idx_handle, manifest)) = - base_handle.fts_indices.get(&self.index.name).cloned() + if let Some((idx_handle, manifest)) = base_handle.fts_indices.get(&self.index.name).cloned() { return self.normalize_fts(base_handle, idx_handle, manifest, gen); } diff --git a/cozo-core/src/fts/ast.rs b/cozo-core/src/fts/ast.rs index cc5c1ef2..ae7468ee 100644 --- a/cozo-core/src/fts/ast.rs +++ b/cozo-core/src/fts/ast.rs @@ -51,15 +51,15 @@ pub(crate) enum FtsExpr { } impl FtsExpr { - pub(crate) fn needs_idf(&self) -> bool { - match self { - FtsExpr::Literal(_) => false, - FtsExpr::Near(_) => false, - FtsExpr::And(exprs) => exprs.iter().any(|e| e.needs_idf()), - FtsExpr::Or(_) => true, - FtsExpr::Not(lhs, _) => lhs.needs_idf(), - } - } + // pub(crate) fn needs_idf(&self) -> bool { + // match self { + // FtsExpr::Literal(_) => false, + // FtsExpr::Near(_) => false, + // FtsExpr::And(exprs) => exprs.iter().any(|e| e.needs_idf()), + // FtsExpr::Or(_) => true, + // FtsExpr::Not(lhs, _) => lhs.needs_idf(), + // } + // } pub(crate) fn tokenize(self, tokenizer: &TextAnalyzer) -> Self { self.do_tokenize(tokenizer).flatten() diff --git a/cozo-core/src/fts/indexing.rs b/cozo-core/src/fts/indexing.rs index 7a9e594f..c77d3001 100644 --- a/cozo-core/src/fts/indexing.rs +++ b/cozo-core/src/fts/indexing.rs @@ -7,10 +7,10 @@ */ use crate::data::expr::{eval_bytecode, Bytecode}; -use crate::data::program::FtsSearch; +use crate::data::program::{FtsScoreKind, FtsSearch}; use crate::data::tuple::{decode_tuple_from_key, Tuple, ENCODED_KEY_MIN_LEN}; use crate::data::value::LARGEST_UTF_CHAR; -use crate::fts::ast::{FtsExpr, FtsLiteral}; +use crate::fts::ast::{FtsExpr, FtsLiteral, FtsNear}; use crate::fts::tokenizer::TextAnalyzer; use crate::parse::fts::parse_fts_query; use crate::runtime::relation::RelationHandle; @@ -18,6 +18,7 @@ use crate::runtime::transact::SessionTx; use crate::{decode_tuple_from_kv, DataValue, SourceSpan}; use itertools::Itertools; use miette::{bail, Diagnostic, Result}; +use num_traits::real::Real; use rustc_hash::{FxHashMap, FxHashSet}; use smartstring::{LazyCompact, SmartString}; use std::collections::hash_map::Entry; @@ -58,7 +59,7 @@ struct LiteralStats { } impl<'a> SessionTx<'a> { - fn search_literal( + fn fts_search_literal( &self, literal: &FtsLiteral, idx_handle: &RelationHandle, @@ -107,6 +108,156 @@ impl<'a> SessionTx<'a> { } Ok(results) } + fn fts_search_impl( + &self, + ast: &FtsExpr, + config: &FtsSearch, + filter_code: &Option<(Vec, SourceSpan)>, + tokenizer: &TextAnalyzer, + n: usize, + ) -> Result> { + Ok(match ast { + FtsExpr::Literal(l) => { + let mut res = FxHashMap::default(); + for el in self.fts_search_literal(l, &config.idx_handle)? { + let score = Self::fts_compute_score( + el.position_info.len(), + n, + el.doc_len, + l.booster.0, + config, + ); + res.insert(el.key, score); + } + res + } + FtsExpr::And(ls) => { + let mut l_iter = ls.iter(); + let mut res = self.fts_search_impl( + l_iter.next().unwrap(), + config, + filter_code, + tokenizer, + n, + )?; + for nxt in l_iter { + let nxt_res = self.fts_search_impl(nxt, config, filter_code, tokenizer, n)?; + res = res + .into_iter() + .filter_map(|(k, v)| { + if let Some(nxt_v) = nxt_res.get(&k) { + Some((k, v + nxt_v)) + } else { + None + } + }) + .collect(); + } + res + } + FtsExpr::Or(ls) => { + let mut res: FxHashMap = FxHashMap::default(); + for nxt in ls { + let nxt_res = self.fts_search_impl(nxt, config, filter_code, tokenizer, n)?; + for (k, v) in nxt_res { + if let Some(old_v) = res.get_mut(&k) { + *old_v = (*old_v).max(v); + } else { + res.insert(k, v); + } + } + } + res + } + FtsExpr::Near(FtsNear { literals, distance }) => { + let mut l_it = literals.iter(); + let mut coll: FxHashMap<_, _> = FxHashMap::default(); + for first_el in self.fts_search_literal(l_it.next().unwrap(), &config.idx_handle)? { + coll.insert( + first_el.key, + ( + first_el + .position_info + .into_iter() + .map(|el| el.position) + .collect_vec(), + first_el.doc_len, + ), + ); + } + for lit_nxt in literals { + let el_res = self.fts_search_literal(lit_nxt, &config.idx_handle)?; + coll = el_res + .into_iter() + .filter_map(|x| match coll.remove(&x.key) { + None => None, + Some((prev_pos, doc_len)) => { + let mut inner_coll = FxHashSet::default(); + for p in prev_pos { + for pi in x.position_info.iter() { + let cur = pi.position; + if cur > p { + if cur - p <= *distance { + inner_coll.insert(p); + } + } else { + if p - cur <= *distance { + inner_coll.insert(cur); + } + } + } + } + if inner_coll.is_empty() { + None + } else { + Some((x.key, (inner_coll.into_iter().collect_vec(), doc_len))) + } + } + }) + .collect(); + } + let mut booster = 0.0; + for lit in literals { + booster += lit.booster.0; + } + coll.into_iter() + .map(|(k, (cands, len))| { + ( + k, + Self::fts_compute_score(cands.len(), n, len, booster, config), + ) + }) + .collect() + } + FtsExpr::Not(fst, snd) => { + let mut res = self.fts_search_impl(fst, config, filter_code, tokenizer, n)?; + for el in self + .fts_search_impl(snd, config, filter_code, tokenizer, n)? + .keys() + { + res.remove(el); + } + res + } + }) + } + fn fts_compute_score( + tf: usize, + n: usize, + doc_len: u32, + booster: f64, + config: &FtsSearch, + ) -> f64 { + let tf = tf as f64; + match config.score_kind { + FtsScoreKind::TF => tf * booster, + FtsScoreKind::TFIDF => { + let doc_len = doc_len as f64; + let idf = ((n as f64 - doc_len + 0.5) / (doc_len + 0.5)).ln(); + tf * idf * booster + } + } + } pub(crate) fn fts_search( &self, q: &str, @@ -120,14 +271,8 @@ impl<'a> SessionTx<'a> { if ast.is_empty() { return Ok(vec![]); } - match cache.results_cache.entry(ast) { - Entry::Occupied(_) => { - todo!() - } - Entry::Vacant(_) => { - todo!() - } - } + let result = self.fts_search_impl(&ast, config, filter_code, tokenizer, 0)?; + todo!() } pub(crate) fn put_fts_index_item( &mut self,