diff --git a/cozo-core/src/data/program.rs b/cozo-core/src/data/program.rs index f1128e0b..fc860056 100644 --- a/cozo-core/src/data/program.rs +++ b/cozo-core/src/data/program.rs @@ -948,7 +948,7 @@ pub(crate) struct HnswSearch { pub(crate) span: SourceSpan, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] pub(crate) enum FtsScoreKind { TFIDF, TF, @@ -961,12 +961,12 @@ pub(crate) struct FtsSearch { pub(crate) manifest: FtsIndexManifest, pub(crate) bindings: Vec, pub(crate) k: usize, - pub(crate) k1: f64, - pub(crate) b: f64, + // pub(crate) k1: f64, + // pub(crate) b: f64, pub(crate) query: Symbol, pub(crate) score_kind: FtsScoreKind, pub(crate) bind_score: Option, - pub(crate) lax_mode: bool, + // pub(crate) lax_mode: bool, pub(crate) filter: Option, pub(crate) span: SourceSpan, } @@ -1094,45 +1094,45 @@ impl SearchInput { ensure!(k > 0, ExpectedPosIntForFtsK(self.span)); - let k1 = { - match self.parameters.remove("k1") { - None => 1.2, - Some(expr) => { - let r = expr.eval_to_const()?; - let r = r - .get_float() - .ok_or_else(|| miette!("k1 for FTS must be a float"))?; - - #[derive(Debug, Error, Diagnostic)] - #[error("Expected positive float for `k1`")] - #[diagnostic(code(parser::expected_float_for_hnsw_k1))] - struct ExpectedPosFloatForFtsK1(#[label] SourceSpan); - - ensure!(r > 0.0, ExpectedPosFloatForFtsK1(self.span)); - r - } - } - }; - - let b = { - match self.parameters.remove("b") { - None => 0.75, - Some(expr) => { - let r = expr.eval_to_const()?; - let r = r - .get_float() - .ok_or_else(|| miette!("b for FTS must be a float"))?; - - #[derive(Debug, Error, Diagnostic)] - #[error("Expected positive float for `b`")] - #[diagnostic(code(parser::expected_float_for_hnsw_b))] - struct ExpectedPosFloatForFtsB(#[label] SourceSpan); - - ensure!(r > 0.0, ExpectedPosFloatForFtsB(self.span)); - r - } - } - }; + // let k1 = { + // match self.parameters.remove("k1") { + // None => 1.2, + // Some(expr) => { + // let r = expr.eval_to_const()?; + // let r = r + // .get_float() + // .ok_or_else(|| miette!("k1 for FTS must be a float"))?; + // + // #[derive(Debug, Error, Diagnostic)] + // #[error("Expected positive float for `k1`")] + // #[diagnostic(code(parser::expected_float_for_hnsw_k1))] + // struct ExpectedPosFloatForFtsK1(#[label] SourceSpan); + // + // ensure!(r > 0.0, ExpectedPosFloatForFtsK1(self.span)); + // r + // } + // } + // }; + // + // let b = { + // match self.parameters.remove("b") { + // None => 0.75, + // Some(expr) => { + // let r = expr.eval_to_const()?; + // let r = r + // .get_float() + // .ok_or_else(|| miette!("b for FTS must be a float"))?; + // + // #[derive(Debug, Error, Diagnostic)] + // #[error("Expected positive float for `b`")] + // #[diagnostic(code(parser::expected_float_for_hnsw_b))] + // struct ExpectedPosFloatForFtsB(#[label] SourceSpan); + // + // ensure!(r > 0.0, ExpectedPosFloatForFtsB(self.span)); + // r + // } + // } + // }; let score_kind_expr = self.parameters.remove("score_kind"); let score_kind = match score_kind_expr { @@ -1151,17 +1151,17 @@ impl SearchInput { None => FtsScoreKind::TFIDF, }; - let lax_mode_expr = self.parameters.remove("lax_mode"); - let lax_mode = match lax_mode_expr { - Some(expr) => { - let r = expr.eval_to_const()?; - let r = r - .get_bool() - .ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?; - r - } - None => true, - }; + // let lax_mode_expr = self.parameters.remove("lax_mode"); + // let lax_mode = match lax_mode_expr { + // Some(expr) => { + // let r = expr.eval_to_const()?; + // let r = r + // .get_bool() + // .ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?; + // r + // } + // None => true, + // }; let filter = self.parameters.remove("filter"); @@ -1191,9 +1191,9 @@ impl SearchInput { query, score_kind, bind_score, - lax_mode, - k1, - b, + // lax_mode, + // k1, + // b, filter, span: self.span, })); diff --git a/cozo-core/src/fts/indexing.rs b/cozo-core/src/fts/indexing.rs index c77d3001..26379935 100644 --- a/cozo-core/src/fts/indexing.rs +++ b/cozo-core/src/fts/indexing.rs @@ -6,7 +6,7 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ -use crate::data::expr::{eval_bytecode, Bytecode}; +use crate::data::expr::{eval_bytecode, eval_bytecode_pred, Bytecode}; use crate::data::program::{FtsScoreKind, FtsSearch}; use crate::data::tuple::{decode_tuple_from_key, Tuple, ENCODED_KEY_MIN_LEN}; use crate::data::value::LARGEST_UTF_CHAR; @@ -15,12 +15,13 @@ use crate::fts::tokenizer::TextAnalyzer; use crate::parse::fts::parse_fts_query; use crate::runtime::relation::RelationHandle; use crate::runtime::transact::SessionTx; -use crate::{decode_tuple_from_kv, DataValue, SourceSpan}; +use crate::{DataValue, SourceSpan}; use itertools::Itertools; -use miette::{bail, Diagnostic, Result}; -use num_traits::real::Real; +use miette::{bail, miette, Diagnostic, Result}; +use ordered_float::OrderedFloat; use rustc_hash::{FxHashMap, FxHashSet}; use smartstring::{LazyCompact, SmartString}; +use std::cmp::Reverse; use std::collections::hash_map::Entry; use std::collections::HashMap; use thiserror::Error; @@ -28,15 +29,14 @@ use thiserror::Error; #[derive(Default)] pub(crate) struct FtsCache { total_n_cache: FxHashMap, usize>, - results_cache: FxHashMap>, } impl FtsCache { fn get_n_for_relation(&mut self, rel: &RelationHandle, tx: &SessionTx<'_>) -> Result { Ok(match self.total_n_cache.entry(rel.name.clone()) { Entry::Vacant(v) => { - let start = rel.encode_key_for_store(&[], Default::default())?; - let end = rel.encode_key_for_store(&[DataValue::Bot], Default::default())?; + let start = rel.encode_partial_key_for_store(&[]); + let end = rel.encode_partial_key_for_store(&[DataValue::Bot]); let val = tx.store_tx.range_count(&start, &end)?; v.insert(val); val @@ -47,15 +47,15 @@ impl FtsCache { } struct PositionInfo { - from: u32, - to: u32, + // from: u32, + // to: u32, position: u32, } struct LiteralStats { key: Tuple, position_info: Vec, - doc_len: u32, + // doc_len: u32, } impl<'a> SessionTx<'a> { @@ -69,8 +69,8 @@ impl<'a> SessionTx<'a> { let mut end_key_str = literal.value.clone(); end_key_str.push(LARGEST_UTF_CHAR); let end_key = vec![DataValue::Str(end_key_str)]; - let start_key_bytes = idx_handle.encode_key_for_store(&start_key, Default::default())?; - let end_key_bytes = idx_handle.encode_key_for_store(&end_key, Default::default())?; + let start_key_bytes = idx_handle.encode_partial_key_for_store(&start_key); + let end_key_bytes = idx_handle.encode_partial_key_for_store(&end_key); let mut results = vec![]; for item in self.store_tx.range_scan(&start_key_bytes, &end_key_bytes) { let (kvec, vvec) = item?; @@ -89,21 +89,21 @@ impl<'a> SessionTx<'a> { let froms = vals[0].get_slice().unwrap(); let tos = vals[1].get_slice().unwrap(); let positions = vals[2].get_slice().unwrap(); - let total_length = vals[3].get_int().unwrap(); + // let total_length = vals[3].get_int().unwrap(); let position_info = froms .iter() .zip(tos.iter()) .zip(positions.iter()) - .map(|((f, t), p)| PositionInfo { - from: f.get_int().unwrap() as u32, - to: t.get_int().unwrap() as u32, + .map(|(_, p)| PositionInfo { + // from: f.get_int().unwrap() as u32, + // to: t.get_int().unwrap() as u32, position: p.get_int().unwrap() as u32, }) .collect_vec(); results.push(LiteralStats { key: key_tuple[1..].to_vec(), position_info, - doc_len: total_length as u32, + // doc_len: total_length as u32, }); } Ok(results) @@ -119,11 +119,13 @@ impl<'a> SessionTx<'a> { Ok(match ast { FtsExpr::Literal(l) => { let mut res = FxHashMap::default(); - for el in self.fts_search_literal(l, &config.idx_handle)? { + let found_docs = self.fts_search_literal(l, &config.idx_handle)?; + let found_docs_len = found_docs.len(); + for el in found_docs { let score = Self::fts_compute_score( el.position_info.len(), + found_docs_len, n, - el.doc_len, l.booster.0, config, ); @@ -175,14 +177,11 @@ impl<'a> SessionTx<'a> { for first_el in self.fts_search_literal(l_it.next().unwrap(), &config.idx_handle)? { coll.insert( first_el.key, - ( - first_el - .position_info - .into_iter() - .map(|el| el.position) - .collect_vec(), - first_el.doc_len, - ), + first_el + .position_info + .into_iter() + .map(|el| el.position) + .collect_vec(), ); } for lit_nxt in literals { @@ -191,7 +190,7 @@ impl<'a> SessionTx<'a> { .into_iter() .filter_map(|x| match coll.remove(&x.key) { None => None, - Some((prev_pos, doc_len)) => { + Some(prev_pos) => { let mut inner_coll = FxHashSet::default(); for p in prev_pos { for pi in x.position_info.iter() { @@ -210,7 +209,7 @@ impl<'a> SessionTx<'a> { if inner_coll.is_empty() { None } else { - Some((x.key, (inner_coll.into_iter().collect_vec(), doc_len))) + Some((x.key, inner_coll.into_iter().collect_vec())) } } }) @@ -220,11 +219,12 @@ impl<'a> SessionTx<'a> { for lit in literals { booster += lit.booster.0; } + let coll_len = coll.len(); coll.into_iter() - .map(|(k, (cands, len))| { + .map(|(k, cands)| { ( k, - Self::fts_compute_score(cands.len(), n, len, booster, config), + Self::fts_compute_score(cands.len(), coll_len, n, booster, config), ) }) .collect() @@ -243,8 +243,8 @@ impl<'a> SessionTx<'a> { } fn fts_compute_score( tf: usize, - n: usize, - doc_len: u32, + n_found_docs: usize, + n_total: usize, booster: f64, config: &FtsSearch, ) -> f64 { @@ -252,8 +252,8 @@ impl<'a> SessionTx<'a> { match config.score_kind { FtsScoreKind::TF => tf * booster, FtsScoreKind::TFIDF => { - let doc_len = doc_len as f64; - let idf = ((n as f64 - doc_len + 0.5) / (doc_len + 0.5)).ln(); + let n_found_docs = n_found_docs as f64; + let idf = (1.0 + (n_total as f64 - n_found_docs + 0.5) / (n_found_docs + 0.5)).ln(); tf * idf * booster } } @@ -271,8 +271,43 @@ impl<'a> SessionTx<'a> { if ast.is_empty() { return Ok(vec![]); } - let result = self.fts_search_impl(&ast, config, filter_code, tokenizer, 0)?; - todo!() + let n = if config.score_kind == FtsScoreKind::TFIDF { + cache.get_n_for_relation(&config.base_handle, self)? + } else { + 0 + }; + let mut result: Vec<_> = self + .fts_search_impl(&ast, config, filter_code, tokenizer, n)? + .into_iter() + .collect(); + result.sort_by_key(|(_, score)| Reverse(OrderedFloat(*score))); + if config.filter.is_none() { + result.truncate(config.k); + } + + let mut ret = Vec::with_capacity(config.k); + for (found_key, score) in result { + let mut cand_tuple = config + .base_handle + .get(self, &found_key)? + .ok_or_else(|| miette!("corrupted index"))?; + + if config.bind_score.is_some() { + cand_tuple.push(DataValue::from(score)); + } + + if let Some((code, span)) = filter_code { + if !eval_bytecode_pred(code, &cand_tuple, stack, *span)? { + continue; + } + } + + ret.push(cand_tuple); + if ret.len() >= config.k { + break; + } + } + Ok(ret) } pub(crate) fn put_fts_index_item( &mut self, diff --git a/cozo-core/src/runtime/relation.rs b/cozo-core/src/runtime/relation.rs index 131be264..6d94f21d 100644 --- a/cozo-core/src/runtime/relation.rs +++ b/cozo-core/src/runtime/relation.rs @@ -254,6 +254,16 @@ impl RelationHandle { } Ok(ret) } + pub(crate) fn encode_partial_key_for_store( + &self, + tuple: &[DataValue], + ) -> Vec { + let mut ret = self.encode_key_prefix(tuple.len()); + for val in tuple { + ret.encode_datavalue(val); + } + ret + } pub(crate) fn encode_val_for_store( &self, tuple: &[DataValue], diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index 6fa52f18..70580656 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -937,6 +937,15 @@ fn test_fts_indexing() { for row in res.into_json()["rows"].as_array().unwrap() { println!("{}", row); } + let res = db + .run_script( + r"?[k, v, s] := ~a:fts{k, v | query: 'world', k: 2, bind_score: s}", + Default::default(), + ) + .unwrap(); + for row in res.into_json()["rows"].as_array().unwrap() { + println!("{}", row); + } } #[test]