complete FTS

main
Ziyang Hu 1 year ago
parent db9c40ef07
commit 683b66c0a7

@ -948,7 +948,7 @@ pub(crate) struct HnswSearch {
pub(crate) span: SourceSpan,
}
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum FtsScoreKind {
TFIDF,
TF,
@ -961,12 +961,12 @@ pub(crate) struct FtsSearch {
pub(crate) manifest: FtsIndexManifest,
pub(crate) bindings: Vec<Symbol>,
pub(crate) k: usize,
pub(crate) k1: f64,
pub(crate) b: f64,
// pub(crate) k1: f64,
// pub(crate) b: f64,
pub(crate) query: Symbol,
pub(crate) score_kind: FtsScoreKind,
pub(crate) bind_score: Option<Symbol>,
pub(crate) lax_mode: bool,
// pub(crate) lax_mode: bool,
pub(crate) filter: Option<Expr>,
pub(crate) span: SourceSpan,
}
@ -1094,45 +1094,45 @@ impl SearchInput {
ensure!(k > 0, ExpectedPosIntForFtsK(self.span));
let k1 = {
match self.parameters.remove("k1") {
None => 1.2,
Some(expr) => {
let r = expr.eval_to_const()?;
let r = r
.get_float()
.ok_or_else(|| miette!("k1 for FTS must be a float"))?;
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive float for `k1`")]
#[diagnostic(code(parser::expected_float_for_hnsw_k1))]
struct ExpectedPosFloatForFtsK1(#[label] SourceSpan);
ensure!(r > 0.0, ExpectedPosFloatForFtsK1(self.span));
r
}
}
};
let b = {
match self.parameters.remove("b") {
None => 0.75,
Some(expr) => {
let r = expr.eval_to_const()?;
let r = r
.get_float()
.ok_or_else(|| miette!("b for FTS must be a float"))?;
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive float for `b`")]
#[diagnostic(code(parser::expected_float_for_hnsw_b))]
struct ExpectedPosFloatForFtsB(#[label] SourceSpan);
ensure!(r > 0.0, ExpectedPosFloatForFtsB(self.span));
r
}
}
};
// let k1 = {
// match self.parameters.remove("k1") {
// None => 1.2,
// Some(expr) => {
// let r = expr.eval_to_const()?;
// let r = r
// .get_float()
// .ok_or_else(|| miette!("k1 for FTS must be a float"))?;
//
// #[derive(Debug, Error, Diagnostic)]
// #[error("Expected positive float for `k1`")]
// #[diagnostic(code(parser::expected_float_for_hnsw_k1))]
// struct ExpectedPosFloatForFtsK1(#[label] SourceSpan);
//
// ensure!(r > 0.0, ExpectedPosFloatForFtsK1(self.span));
// r
// }
// }
// };
//
// let b = {
// match self.parameters.remove("b") {
// None => 0.75,
// Some(expr) => {
// let r = expr.eval_to_const()?;
// let r = r
// .get_float()
// .ok_or_else(|| miette!("b for FTS must be a float"))?;
//
// #[derive(Debug, Error, Diagnostic)]
// #[error("Expected positive float for `b`")]
// #[diagnostic(code(parser::expected_float_for_hnsw_b))]
// struct ExpectedPosFloatForFtsB(#[label] SourceSpan);
//
// ensure!(r > 0.0, ExpectedPosFloatForFtsB(self.span));
// r
// }
// }
// };
let score_kind_expr = self.parameters.remove("score_kind");
let score_kind = match score_kind_expr {
@ -1151,17 +1151,17 @@ impl SearchInput {
None => FtsScoreKind::TFIDF,
};
let lax_mode_expr = self.parameters.remove("lax_mode");
let lax_mode = match lax_mode_expr {
Some(expr) => {
let r = expr.eval_to_const()?;
let r = r
.get_bool()
.ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?;
r
}
None => true,
};
// let lax_mode_expr = self.parameters.remove("lax_mode");
// let lax_mode = match lax_mode_expr {
// Some(expr) => {
// let r = expr.eval_to_const()?;
// let r = r
// .get_bool()
// .ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?;
// r
// }
// None => true,
// };
let filter = self.parameters.remove("filter");
@ -1191,9 +1191,9 @@ impl SearchInput {
query,
score_kind,
bind_score,
lax_mode,
k1,
b,
// lax_mode,
// k1,
// b,
filter,
span: self.span,
}));

@ -6,7 +6,7 @@
* You can obtain one at https://mozilla.org/MPL/2.0/.
*/
use crate::data::expr::{eval_bytecode, Bytecode};
use crate::data::expr::{eval_bytecode, eval_bytecode_pred, Bytecode};
use crate::data::program::{FtsScoreKind, FtsSearch};
use crate::data::tuple::{decode_tuple_from_key, Tuple, ENCODED_KEY_MIN_LEN};
use crate::data::value::LARGEST_UTF_CHAR;
@ -15,12 +15,13 @@ use crate::fts::tokenizer::TextAnalyzer;
use crate::parse::fts::parse_fts_query;
use crate::runtime::relation::RelationHandle;
use crate::runtime::transact::SessionTx;
use crate::{decode_tuple_from_kv, DataValue, SourceSpan};
use crate::{DataValue, SourceSpan};
use itertools::Itertools;
use miette::{bail, Diagnostic, Result};
use num_traits::real::Real;
use miette::{bail, miette, Diagnostic, Result};
use ordered_float::OrderedFloat;
use rustc_hash::{FxHashMap, FxHashSet};
use smartstring::{LazyCompact, SmartString};
use std::cmp::Reverse;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use thiserror::Error;
@ -28,15 +29,14 @@ use thiserror::Error;
#[derive(Default)]
pub(crate) struct FtsCache {
total_n_cache: FxHashMap<SmartString<LazyCompact>, usize>,
results_cache: FxHashMap<FtsExpr, Vec<(Tuple, f64)>>,
}
impl FtsCache {
fn get_n_for_relation(&mut self, rel: &RelationHandle, tx: &SessionTx<'_>) -> Result<usize> {
Ok(match self.total_n_cache.entry(rel.name.clone()) {
Entry::Vacant(v) => {
let start = rel.encode_key_for_store(&[], Default::default())?;
let end = rel.encode_key_for_store(&[DataValue::Bot], Default::default())?;
let start = rel.encode_partial_key_for_store(&[]);
let end = rel.encode_partial_key_for_store(&[DataValue::Bot]);
let val = tx.store_tx.range_count(&start, &end)?;
v.insert(val);
val
@ -47,15 +47,15 @@ impl FtsCache {
}
// NOTE(review): this view lists live `from`/`to` fields next to commented-out
// copies of the same fields — it looks like removed and added diff lines shown
// together; confirm which form is current before relying on `from`/`to`.
/// Position data for one element of a literal's index record: one value is
/// built per zipped element of the `froms`/`tos`/`positions` slices.
struct PositionInfo {
from: u32,
to: u32,
// from: u32,
// to: u32,
/// Integer taken from the `positions` slice (`vals[2]`) and cast to `u32`.
position: u32,
}
// NOTE(review): `doc_len` appears both as a live field and as a commented-out
// copy in this view — presumably removed/added diff lines shown together;
// confirm whether the field still exists.
/// One result row pushed by the literal index scan for each scanned entry.
struct LiteralStats {
/// The decoded index key with its first element dropped (`key_tuple[1..]`).
key: Tuple,
/// One `PositionInfo` per element of the record's position slices.
position_info: Vec<PositionInfo>,
doc_len: u32,
// doc_len: u32,
}
impl<'a> SessionTx<'a> {
@ -69,8 +69,8 @@ impl<'a> SessionTx<'a> {
let mut end_key_str = literal.value.clone();
end_key_str.push(LARGEST_UTF_CHAR);
let end_key = vec![DataValue::Str(end_key_str)];
let start_key_bytes = idx_handle.encode_key_for_store(&start_key, Default::default())?;
let end_key_bytes = idx_handle.encode_key_for_store(&end_key, Default::default())?;
let start_key_bytes = idx_handle.encode_partial_key_for_store(&start_key);
let end_key_bytes = idx_handle.encode_partial_key_for_store(&end_key);
let mut results = vec![];
for item in self.store_tx.range_scan(&start_key_bytes, &end_key_bytes) {
let (kvec, vvec) = item?;
@ -89,21 +89,21 @@ impl<'a> SessionTx<'a> {
let froms = vals[0].get_slice().unwrap();
let tos = vals[1].get_slice().unwrap();
let positions = vals[2].get_slice().unwrap();
let total_length = vals[3].get_int().unwrap();
// let total_length = vals[3].get_int().unwrap();
let position_info = froms
.iter()
.zip(tos.iter())
.zip(positions.iter())
.map(|((f, t), p)| PositionInfo {
from: f.get_int().unwrap() as u32,
to: t.get_int().unwrap() as u32,
.map(|(_, p)| PositionInfo {
// from: f.get_int().unwrap() as u32,
// to: t.get_int().unwrap() as u32,
position: p.get_int().unwrap() as u32,
})
.collect_vec();
results.push(LiteralStats {
key: key_tuple[1..].to_vec(),
position_info,
doc_len: total_length as u32,
// doc_len: total_length as u32,
});
}
Ok(results)
@ -119,11 +119,13 @@ impl<'a> SessionTx<'a> {
Ok(match ast {
FtsExpr::Literal(l) => {
let mut res = FxHashMap::default();
for el in self.fts_search_literal(l, &config.idx_handle)? {
let found_docs = self.fts_search_literal(l, &config.idx_handle)?;
let found_docs_len = found_docs.len();
for el in found_docs {
let score = Self::fts_compute_score(
el.position_info.len(),
found_docs_len,
n,
el.doc_len,
l.booster.0,
config,
);
@ -175,14 +177,11 @@ impl<'a> SessionTx<'a> {
for first_el in self.fts_search_literal(l_it.next().unwrap(), &config.idx_handle)? {
coll.insert(
first_el.key,
(
first_el
.position_info
.into_iter()
.map(|el| el.position)
.collect_vec(),
first_el.doc_len,
),
);
}
for lit_nxt in literals {
@ -191,7 +190,7 @@ impl<'a> SessionTx<'a> {
.into_iter()
.filter_map(|x| match coll.remove(&x.key) {
None => None,
Some((prev_pos, doc_len)) => {
Some(prev_pos) => {
let mut inner_coll = FxHashSet::default();
for p in prev_pos {
for pi in x.position_info.iter() {
@ -210,7 +209,7 @@ impl<'a> SessionTx<'a> {
if inner_coll.is_empty() {
None
} else {
Some((x.key, (inner_coll.into_iter().collect_vec(), doc_len)))
Some((x.key, inner_coll.into_iter().collect_vec()))
}
}
})
@ -220,11 +219,12 @@ impl<'a> SessionTx<'a> {
for lit in literals {
booster += lit.booster.0;
}
let coll_len = coll.len();
coll.into_iter()
.map(|(k, (cands, len))| {
.map(|(k, cands)| {
(
k,
Self::fts_compute_score(cands.len(), n, len, booster, config),
Self::fts_compute_score(cands.len(), coll_len, n, booster, config),
)
})
.collect()
@ -243,8 +243,8 @@ impl<'a> SessionTx<'a> {
}
fn fts_compute_score(
tf: usize,
n: usize,
doc_len: u32,
n_found_docs: usize,
n_total: usize,
booster: f64,
config: &FtsSearch,
) -> f64 {
@ -252,8 +252,8 @@ impl<'a> SessionTx<'a> {
match config.score_kind {
FtsScoreKind::TF => tf * booster,
FtsScoreKind::TFIDF => {
let doc_len = doc_len as f64;
let idf = ((n as f64 - doc_len + 0.5) / (doc_len + 0.5)).ln();
let n_found_docs = n_found_docs as f64;
let idf = (1.0 + (n_total as f64 - n_found_docs + 0.5) / (n_found_docs + 0.5)).ln();
tf * idf * booster
}
}
@ -271,8 +271,43 @@ impl<'a> SessionTx<'a> {
if ast.is_empty() {
return Ok(vec![]);
}
let result = self.fts_search_impl(&ast, config, filter_code, tokenizer, 0)?;
todo!()
let n = if config.score_kind == FtsScoreKind::TFIDF {
cache.get_n_for_relation(&config.base_handle, self)?
} else {
0
};
let mut result: Vec<_> = self
.fts_search_impl(&ast, config, filter_code, tokenizer, n)?
.into_iter()
.collect();
result.sort_by_key(|(_, score)| Reverse(OrderedFloat(*score)));
if config.filter.is_none() {
result.truncate(config.k);
}
let mut ret = Vec::with_capacity(config.k);
for (found_key, score) in result {
let mut cand_tuple = config
.base_handle
.get(self, &found_key)?
.ok_or_else(|| miette!("corrupted index"))?;
if config.bind_score.is_some() {
cand_tuple.push(DataValue::from(score));
}
if let Some((code, span)) = filter_code {
if !eval_bytecode_pred(code, &cand_tuple, stack, *span)? {
continue;
}
}
ret.push(cand_tuple);
if ret.len() >= config.k {
break;
}
}
Ok(ret)
}
pub(crate) fn put_fts_index_item(
&mut self,

@ -254,6 +254,16 @@ impl RelationHandle {
}
Ok(ret)
}
/// Encodes `tuple` as a partial storage key and returns the raw bytes.
///
/// The buffer is seeded with `encode_key_prefix(tuple.len())`, then every
/// value in `tuple` is appended in order via `encode_datavalue`. The result is
/// returned directly (no `Result` wrapper), so callers can use it as a
/// range-scan bound without error handling.
pub(crate) fn encode_partial_key_for_store(&self, tuple: &[DataValue]) -> Vec<u8> {
    let mut encoded = self.encode_key_prefix(tuple.len());
    tuple.iter().for_each(|value| {
        encoded.encode_datavalue(value);
    });
    encoded
}
pub(crate) fn encode_val_for_store(
&self,
tuple: &[DataValue],

@ -937,6 +937,15 @@ fn test_fts_indexing() {
for row in res.into_json()["rows"].as_array().unwrap() {
println!("{}", row);
}
let res = db
.run_script(
r"?[k, v, s] := ~a:fts{k, v | query: 'world', k: 2, bind_score: s}",
Default::default(),
)
.unwrap();
for row in res.into_json()["rows"].as_array().unwrap() {
println!("{}", row);
}
}
#[test]

Loading…
Cancel
Save