FTS search impl

main
Ziyang Hu 1 year ago
parent 5f90c6f9eb
commit db9c40ef07

@ -950,8 +950,8 @@ pub(crate) struct HnswSearch {
#[derive(Copy, Clone, Debug)]
pub(crate) enum FtsScoreKind {
BM25,
TFIDF,
TF,
}
#[derive(Clone, Debug)]
@ -961,11 +961,12 @@ pub(crate) struct FtsSearch {
pub(crate) manifest: FtsIndexManifest,
pub(crate) bindings: Vec<Symbol>,
pub(crate) k: usize,
pub(crate) k1: f64,
pub(crate) b: f64,
pub(crate) query: Symbol,
pub(crate) score_kind: FtsScoreKind,
pub(crate) bind_score: Option<Symbol>,
pub(crate) lax_mode: bool,
pub(crate) global_idf: bool,
pub(crate) filter: Option<Expr>,
pub(crate) span: SourceSpan,
}
@ -1093,37 +1094,70 @@ impl SearchInput {
ensure!(k > 0, ExpectedPosIntForFtsK(self.span));
let k1 = {
match self.parameters.remove("k1") {
None => 1.2,
Some(expr) => {
let r = expr.eval_to_const()?;
let r = r
.get_float()
.ok_or_else(|| miette!("k1 for FTS must be a float"))?;
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive float for `k1`")]
#[diagnostic(code(parser::expected_float_for_hnsw_k1))]
struct ExpectedPosFloatForFtsK1(#[label] SourceSpan);
ensure!(r > 0.0, ExpectedPosFloatForFtsK1(self.span));
r
}
}
};
let b = {
match self.parameters.remove("b") {
None => 0.75,
Some(expr) => {
let r = expr.eval_to_const()?;
let r = r
.get_float()
.ok_or_else(|| miette!("b for FTS must be a float"))?;
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive float for `b`")]
#[diagnostic(code(parser::expected_float_for_hnsw_b))]
struct ExpectedPosFloatForFtsB(#[label] SourceSpan);
ensure!(r > 0.0, ExpectedPosFloatForFtsB(self.span));
r
}
}
};
let score_kind_expr = self.parameters.remove("score_kind");
let score_kind = match score_kind_expr {
Some(expr) => {
let r = expr.eval_to_const()?;
let r = r.get_str().ok_or_else(|| miette!("Score kind for FTS must be a string"))?;
let r = r
.get_str()
.ok_or_else(|| miette!("Score kind for FTS must be a string"))?;
match r {
"bm25" => FtsScoreKind::BM25,
"tf_idf" => FtsScoreKind::TFIDF,
s => bail!("Unknown score kind for FTS: {}", s)
"tf" => FtsScoreKind::TF,
s => bail!("Unknown score kind for FTS: {}", s),
}
}
None => FtsScoreKind::BM25,
None => FtsScoreKind::TFIDF,
};
let lax_mode_expr = self.parameters.remove("lax_mode");
let lax_mode = match lax_mode_expr {
Some(expr) => {
let r = expr.eval_to_const()?;
let r = r.get_bool().ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?;
r
}
None => true,
};
let global_idf_expr = self.parameters.remove("global_idf");
let global_idf = match global_idf_expr {
Some(expr) => {
let r = expr.eval_to_const()?;
let r = r.get_bool().ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?;
let r = r
.get_bool()
.ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?;
r
}
None => true,
@ -1158,7 +1192,8 @@ impl SearchInput {
score_kind,
bind_score,
lax_mode,
global_idf,
k1,
b,
filter,
span: self.span,
}));
@ -1408,8 +1443,7 @@ impl SearchInput {
{
return self.normalize_hnsw(base_handle, idx_handle, manifest, gen);
}
if let Some((idx_handle, manifest)) =
base_handle.fts_indices.get(&self.index.name).cloned()
if let Some((idx_handle, manifest)) = base_handle.fts_indices.get(&self.index.name).cloned()
{
return self.normalize_fts(base_handle, idx_handle, manifest, gen);
}

@ -51,15 +51,15 @@ pub(crate) enum FtsExpr {
}
impl FtsExpr {
pub(crate) fn needs_idf(&self) -> bool {
match self {
FtsExpr::Literal(_) => false,
FtsExpr::Near(_) => false,
FtsExpr::And(exprs) => exprs.iter().any(|e| e.needs_idf()),
FtsExpr::Or(_) => true,
FtsExpr::Not(lhs, _) => lhs.needs_idf(),
}
}
// pub(crate) fn needs_idf(&self) -> bool {
// match self {
// FtsExpr::Literal(_) => false,
// FtsExpr::Near(_) => false,
// FtsExpr::And(exprs) => exprs.iter().any(|e| e.needs_idf()),
// FtsExpr::Or(_) => true,
// FtsExpr::Not(lhs, _) => lhs.needs_idf(),
// }
// }
pub(crate) fn tokenize(self, tokenizer: &TextAnalyzer) -> Self {
self.do_tokenize(tokenizer).flatten()

@ -7,10 +7,10 @@
*/
use crate::data::expr::{eval_bytecode, Bytecode};
use crate::data::program::FtsSearch;
use crate::data::program::{FtsScoreKind, FtsSearch};
use crate::data::tuple::{decode_tuple_from_key, Tuple, ENCODED_KEY_MIN_LEN};
use crate::data::value::LARGEST_UTF_CHAR;
use crate::fts::ast::{FtsExpr, FtsLiteral};
use crate::fts::ast::{FtsExpr, FtsLiteral, FtsNear};
use crate::fts::tokenizer::TextAnalyzer;
use crate::parse::fts::parse_fts_query;
use crate::runtime::relation::RelationHandle;
@ -18,6 +18,7 @@ use crate::runtime::transact::SessionTx;
use crate::{decode_tuple_from_kv, DataValue, SourceSpan};
use itertools::Itertools;
use miette::{bail, Diagnostic, Result};
use num_traits::real::Real;
use rustc_hash::{FxHashMap, FxHashSet};
use smartstring::{LazyCompact, SmartString};
use std::collections::hash_map::Entry;
@ -58,7 +59,7 @@ struct LiteralStats {
}
impl<'a> SessionTx<'a> {
fn search_literal(
fn fts_search_literal(
&self,
literal: &FtsLiteral,
idx_handle: &RelationHandle,
@ -107,6 +108,156 @@ impl<'a> SessionTx<'a> {
}
Ok(results)
}
fn fts_search_impl(
&self,
ast: &FtsExpr,
config: &FtsSearch,
filter_code: &Option<(Vec<Bytecode>, SourceSpan)>,
tokenizer: &TextAnalyzer,
n: usize,
) -> Result<FxHashMap<Tuple, f64>> {
Ok(match ast {
FtsExpr::Literal(l) => {
let mut res = FxHashMap::default();
for el in self.fts_search_literal(l, &config.idx_handle)? {
let score = Self::fts_compute_score(
el.position_info.len(),
n,
el.doc_len,
l.booster.0,
config,
);
res.insert(el.key, score);
}
res
}
FtsExpr::And(ls) => {
let mut l_iter = ls.iter();
let mut res = self.fts_search_impl(
l_iter.next().unwrap(),
config,
filter_code,
tokenizer,
n,
)?;
for nxt in l_iter {
let nxt_res = self.fts_search_impl(nxt, config, filter_code, tokenizer, n)?;
res = res
.into_iter()
.filter_map(|(k, v)| {
if let Some(nxt_v) = nxt_res.get(&k) {
Some((k, v + nxt_v))
} else {
None
}
})
.collect();
}
res
}
FtsExpr::Or(ls) => {
let mut res: FxHashMap<Tuple, f64> = FxHashMap::default();
for nxt in ls {
let nxt_res = self.fts_search_impl(nxt, config, filter_code, tokenizer, n)?;
for (k, v) in nxt_res {
if let Some(old_v) = res.get_mut(&k) {
*old_v = (*old_v).max(v);
} else {
res.insert(k, v);
}
}
}
res
}
FtsExpr::Near(FtsNear { literals, distance }) => {
let mut l_it = literals.iter();
let mut coll: FxHashMap<_, _> = FxHashMap::default();
for first_el in self.fts_search_literal(l_it.next().unwrap(), &config.idx_handle)? {
coll.insert(
first_el.key,
(
first_el
.position_info
.into_iter()
.map(|el| el.position)
.collect_vec(),
first_el.doc_len,
),
);
}
for lit_nxt in literals {
let el_res = self.fts_search_literal(lit_nxt, &config.idx_handle)?;
coll = el_res
.into_iter()
.filter_map(|x| match coll.remove(&x.key) {
None => None,
Some((prev_pos, doc_len)) => {
let mut inner_coll = FxHashSet::default();
for p in prev_pos {
for pi in x.position_info.iter() {
let cur = pi.position;
if cur > p {
if cur - p <= *distance {
inner_coll.insert(p);
}
} else {
if p - cur <= *distance {
inner_coll.insert(cur);
}
}
}
}
if inner_coll.is_empty() {
None
} else {
Some((x.key, (inner_coll.into_iter().collect_vec(), doc_len)))
}
}
})
.collect();
}
let mut booster = 0.0;
for lit in literals {
booster += lit.booster.0;
}
coll.into_iter()
.map(|(k, (cands, len))| {
(
k,
Self::fts_compute_score(cands.len(), n, len, booster, config),
)
})
.collect()
}
FtsExpr::Not(fst, snd) => {
let mut res = self.fts_search_impl(fst, config, filter_code, tokenizer, n)?;
for el in self
.fts_search_impl(snd, config, filter_code, tokenizer, n)?
.keys()
{
res.remove(el);
}
res
}
})
}
fn fts_compute_score(
tf: usize,
n: usize,
doc_len: u32,
booster: f64,
config: &FtsSearch,
) -> f64 {
let tf = tf as f64;
match config.score_kind {
FtsScoreKind::TF => tf * booster,
FtsScoreKind::TFIDF => {
let doc_len = doc_len as f64;
let idf = ((n as f64 - doc_len + 0.5) / (doc_len + 0.5)).ln();
tf * idf * booster
}
}
}
pub(crate) fn fts_search(
&self,
q: &str,
@ -120,15 +271,9 @@ impl<'a> SessionTx<'a> {
if ast.is_empty() {
return Ok(vec![]);
}
match cache.results_cache.entry(ast) {
Entry::Occupied(_) => {
todo!()
}
Entry::Vacant(_) => {
let result = self.fts_search_impl(&ast, config, filter_code, tokenizer, 0)?;
todo!()
}
}
}
pub(crate) fn put_fts_index_item(
&mut self,
tuple: &[DataValue],

Loading…
Cancel
Save