complete LSH

main
Ziyang Hu 1 year ago
parent 683b66c0a7
commit ba98a5c137

Cargo.lock (generated): 19 changed lines

@ -724,6 +724,7 @@ dependencies = [
"pest",
"pest_derive",
"priority-queue",
"quadrature",
"rand 0.8.5",
"rayon",
"regex",
@ -746,6 +747,7 @@ dependencies = [
"tikv-client",
"tikv-jemallocator-global",
"tokio",
"twox-hash",
"unicode-normalization",
"uuid",
]
@ -2945,6 +2947,12 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "quadrature"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2054ccb02f454fcb2bc81e343aa0a171636a6331003fd5ec24c47a10966634b7"
[[package]]
name = "quote"
version = "1.0.26"
@ -4113,6 +4121,17 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
[[package]]
name = "twox-hash"
version = "1.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675"
dependencies = [
"cfg-if 1.0.0",
"rand 0.8.5",
"static_assertions",
]
[[package]]
name = "typenum"
version = "1.16.0"

@ -130,6 +130,8 @@ crossbeam = "0.8.2"
ndarray = { version = "0.15.6", features = ["serde"] }
sha2 = "0.10.6"
rustc-hash = "1.1.0"
twox-hash = "1.6.3"
quadrature = "0.1.2"
# For the FTS feature
jieba-rs = "0.6.7"
aho-corasick = "1.0.1"
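
The two new crates back the LSH machinery: twox-hash provides the seeded 32-bit XxHash that stands in for MinHash's random permutations, and quadrature supplies the numerical integration used to tune the banding parameters. A minimal standalone sketch of the seeded-hash usage (the seed and input here are illustrative, not from this codebase):

    use std::hash::Hasher;
    use twox_hash::XxHash32;

    fn main() {
        // One MinHash "permutation" is simply XxHash32 under one random seed;
        // n_perm different seeds give n_perm independent hash functions.
        let mut hasher = XxHash32::with_seed(0xDEAD_BEEF);
        hasher.write(b"some shingle");
        let h = hasher.finish() as u32; // 32-bit digest, zero-extended to u64 by Hasher
        println!("{h:08x}");
    }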

@ -56,7 +56,7 @@ underscore_ident = @{("_" | XID_START) ~ ("_" | XID_CONTINUE)*}
relation_ident = @{"*" ~ (compound_or_index_ident | underscore_ident)}
search_index_ident = _{"~" ~ compound_or_index_ident}
compound_ident = @{ident ~ ("." ~ ident)*}
compound_or_index_ident = @{ident ~ ("." ~ ident)* ~ (":" ~ ident)?}
compound_or_index_ident = @{ident ~ ("." ~ ident)* ~ (":" ~ ident)*}
rule = {rule_head ~ ":=" ~ rule_body ~ ";"?}
const_rule = {rule_head ~ "<-" ~ expr ~ ";"?}
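
A note on the grammar tweak: compound_or_index_ident previously allowed at most one ":"-suffixed segment; repeating it lets multi-segment names such as *docs:lsh_idx:inv parse, which the LSH feature needs because (as the index-drop code later in this commit shows) the inverse index lives under the name {relation}:{index}:inv.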

@ -18,6 +18,7 @@ use thiserror::Error;
use crate::data::aggr::Aggregation;
use crate::data::expr::Expr;
use crate::data::functions::OP_LIST;
use crate::data::relation::StoredRelationMetadata;
use crate::data::symb::{Symbol, PROG_ENTRY};
use crate::data::value::{DataValue, ValidityTs};
@ -26,6 +27,7 @@ use crate::fts::FtsIndexManifest;
use crate::parse::SourceSpan;
use crate::query::logical::{Disjunction, NamedFieldNotFound};
use crate::runtime::hnsw::HnswIndexManifest;
use crate::runtime::minhash_lsh::{LshSearch, MinHashLshIndexManifest};
use crate::runtime::relation::{
AccessLevel, InputRelationHandle, InsufficientAccessLevel, RelationHandle,
};
@ -950,8 +952,8 @@ pub(crate) struct HnswSearch {
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum FtsScoreKind {
TFIDF,
TF,
TfIdf,
Tf,
}
#[derive(Clone, Debug)]
@ -989,6 +991,212 @@ impl FtsSearch {
}
impl SearchInput {
fn normalize_lsh(
mut self,
base_handle: RelationHandle,
idx_handle: RelationHandle,
inv_idx_handle: RelationHandle,
manifest: MinHashLshIndexManifest,
gen: &mut TempSymbGen,
) -> Result<Disjunction> {
let mut conj = Vec::with_capacity(self.bindings.len() + 8);
let mut bindings = Vec::with_capacity(self.bindings.len());
let mut seen_variables = BTreeSet::new();
for col in base_handle
.metadata
.keys
.iter()
.chain(base_handle.metadata.non_keys.iter())
{
if let Some(arg) = self.bindings.remove(&col.name) {
match arg {
Expr::Binding { var, .. } => {
if var.is_ignored_symbol() {
bindings.push(gen.next_ignored(var.span));
} else if seen_variables.insert(var.clone()) {
bindings.push(var);
} else {
let span = var.span;
let dup = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: dup.clone(),
expr: Expr::Binding {
var,
tuple_pos: None,
},
one_many_unif: false,
span,
});
conj.push(unif);
bindings.push(dup);
}
}
expr => {
let span = expr.span();
let kw = gen.next(span);
bindings.push(kw.clone());
let unif = NormalFormAtom::Unification(Unification {
binding: kw,
expr,
one_many_unif: false,
span,
});
conj.push(unif)
}
}
} else {
bindings.push(gen.next_ignored(self.span));
}
}
if let Some((name, _)) = self.bindings.pop_first() {
bail!(NamedFieldNotFound(
self.relation.name.to_string(),
name.to_string(),
self.span
));
}
#[derive(Debug, Error, Diagnostic)]
#[error("Field `{0}` is required for LSH search")]
#[diagnostic(code(parser::hnsw_query_required))]
struct LshRequiredMissing(String, #[label] SourceSpan);
#[derive(Debug, Error, Diagnostic)]
#[error("Expected a list of keys for LSH search")]
#[diagnostic(code(parser::expected_list_for_lsh_keys))]
struct ExpectedListForLshKeys(#[label] SourceSpan);
#[derive(Debug, Error, Diagnostic)]
#[error("Wrong arity for LSH keys, expected {1}, got {2}")]
#[diagnostic(code(parser::wrong_arity_for_lsh_keys))]
struct WrongArityForKeys(#[label] SourceSpan, usize, usize);
let query = match self.parameters.remove("keys") {
None => match self.parameters.remove("key") {
None => {
bail!(LshRequiredMissing("keys".to_string(), self.span))
}
Some(expr) => {
ensure!(
base_handle.metadata.keys.len() == 1,
LshRequiredMissing("keys".to_string(), self.span)
);
let span = expr.span();
let kw = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: kw.clone(),
expr: Expr::Apply {
op: &OP_LIST,
args: [expr].into(),
span,
},
one_many_unif: false,
span,
});
conj.push(unif);
kw
}
},
Some(mut expr) => {
expr.partial_eval()?;
match expr {
Expr::Apply { op, args, span } => {
ensure!(op.name == OP_LIST.name, ExpectedListForLshKeys(span));
ensure!(
args.len() == base_handle.metadata.keys.len(),
WrongArityForKeys(span, base_handle.metadata.keys.len(), args.len())
);
let kw = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: kw.clone(),
expr: Expr::Apply { op, args, span },
one_many_unif: false,
span,
});
conj.push(unif);
kw
}
_ => {
bail!(ExpectedListForLshKeys(self.span))
}
}
}
};
let k = match self.parameters.remove("k") {
None => None,
Some(k_expr) => {
let k = k_expr.eval_to_const()?;
let k = k.get_int().ok_or(ExpectedPosIntForFtsK(self.span))?;
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive integer for `k`")]
#[diagnostic(code(parser::expected_int_for_hnsw_k))]
struct ExpectedPosIntForFtsK(#[label] SourceSpan);
ensure!(k > 0, ExpectedPosIntForFtsK(self.span));
Some(k as usize)
}
};
let filter = self.parameters.remove("filter");
let bind_similarity = match self.parameters.remove("bind_similarity") {
None => None,
Some(Expr::Binding { var, .. }) => Some(var),
Some(expr) => {
let span = expr.span();
let kw = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: kw.clone(),
expr,
one_many_unif: false,
span,
});
conj.push(unif);
Some(kw)
}
};
let min_similarity = match self.parameters.remove("min_similarity") {
None => manifest.threshold,
Some(expr) => {
let min_similarity = expr.eval_to_const()?;
let min_similarity = min_similarity
.get_float()
.ok_or(ExpectedFloatForMinSimilarity(self.span))?;
#[derive(Debug, Error, Diagnostic)]
#[error("Expected float for `min_similarity`")]
#[diagnostic(code(parser::expected_float_for_min_similarity))]
struct ExpectedFloatForMinSimilarity(#[label] SourceSpan);
ensure!(
(0.0..=1.0).contains(&min_similarity),
ExpectedFloatForMinSimilarity(self.span)
);
min_similarity
}
};
conj.push(NormalFormAtom::LshSearch(LshSearch {
base_handle,
idx_handle,
inv_idx_handle,
manifest,
bindings,
k,
bind_similarity,
query,
span: self.span,
min_similarity,
filter,
}));
Ok(Disjunction::conj(conj))
}
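For reference, the parameters recognized by normalize_lsh are keys (or key, when the base relation has a single key column), k, filter, bind_similarity, and min_similarity, the last defaulting to the threshold stored in the index manifest; the whole search is lowered into a NormalFormAtom::LshSearch plus unification atoms for any argument that is not already a plain binding.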
fn normalize_fts(
mut self,
base_handle: RelationHandle,
@ -1094,46 +1302,6 @@ impl SearchInput {
ensure!(k > 0, ExpectedPosIntForFtsK(self.span));
// let k1 = {
// match self.parameters.remove("k1") {
// None => 1.2,
// Some(expr) => {
// let r = expr.eval_to_const()?;
// let r = r
// .get_float()
// .ok_or_else(|| miette!("k1 for FTS must be a float"))?;
//
// #[derive(Debug, Error, Diagnostic)]
// #[error("Expected positive float for `k1`")]
// #[diagnostic(code(parser::expected_float_for_hnsw_k1))]
// struct ExpectedPosFloatForFtsK1(#[label] SourceSpan);
//
// ensure!(r > 0.0, ExpectedPosFloatForFtsK1(self.span));
// r
// }
// }
// };
//
// let b = {
// match self.parameters.remove("b") {
// None => 0.75,
// Some(expr) => {
// let r = expr.eval_to_const()?;
// let r = r
// .get_float()
// .ok_or_else(|| miette!("b for FTS must be a float"))?;
//
// #[derive(Debug, Error, Diagnostic)]
// #[error("Expected positive float for `b`")]
// #[diagnostic(code(parser::expected_float_for_hnsw_b))]
// struct ExpectedPosFloatForFtsB(#[label] SourceSpan);
//
// ensure!(r > 0.0, ExpectedPosFloatForFtsB(self.span));
// r
// }
// }
// };
let score_kind_expr = self.parameters.remove("score_kind");
let score_kind = match score_kind_expr {
Some(expr) => {
@ -1143,26 +1311,14 @@ impl SearchInput {
.ok_or_else(|| miette!("Score kind for FTS must be a string"))?;
match r {
"tf_idf" => FtsScoreKind::TFIDF,
"tf" => FtsScoreKind::TF,
"tf_idf" => FtsScoreKind::TfIdf,
"tf" => FtsScoreKind::Tf,
s => bail!("Unknown score kind for FTS: {}", s),
}
}
None => FtsScoreKind::TFIDF,
None => FtsScoreKind::TfIdf,
};
// let lax_mode_expr = self.parameters.remove("lax_mode");
// let lax_mode = match lax_mode_expr {
// Some(expr) => {
// let r = expr.eval_to_const()?;
// let r = r
// .get_bool()
// .ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?;
// r
// }
// None => true,
// };
let filter = self.parameters.remove("filter");
let bind_score = match self.parameters.remove("bind_score") {
@ -1447,6 +1603,11 @@ impl SearchInput {
{
return self.normalize_fts(base_handle, idx_handle, manifest, gen);
}
if let Some((idx_handle, inv_idx_handle, manifest)) =
base_handle.lsh_indices.get(&self.index.name).cloned()
{
return self.normalize_lsh(base_handle, idx_handle, inv_idx_handle, manifest, gen);
}
#[derive(Debug, Error, Diagnostic)]
#[error("Index {name} not found on relation {relation}")]
#[diagnostic(code(eval::hnsw_index_not_found))]
@ -1586,6 +1747,7 @@ pub(crate) enum NormalFormAtom {
Unification(Unification),
HnswSearch(HnswSearch),
FtsSearch(FtsSearch),
LshSearch(LshSearch),
}
#[derive(Debug, Clone)]
@ -1598,6 +1760,7 @@ pub(crate) enum MagicAtom {
Unification(Unification),
HnswSearch(HnswSearch),
FtsSearch(FtsSearch),
LshSearch(LshSearch),
}
#[derive(Clone, Debug)]

@ -619,6 +619,13 @@ impl Display for DataValue {
}
impl DataValue {
/// Returns a slice of bytes if this one is a Bytes
pub fn get_bytes(&self) -> Option<&[u8]> {
match self {
DataValue::Bytes(b) => Some(b),
_ => None,
}
}
/// Returns a slice of DataValues if this one is a List
pub fn get_slice(&self) -> Option<&[DataValue]> {
match self {

@ -80,11 +80,10 @@ impl<'a> SessionTx<'a> {
if !found_str_key.starts_with(start_key_str) {
break;
}
} else {
if found_str_key != start_key_str {
} else if found_str_key != start_key_str {
break;
}
}
let vals: Vec<DataValue> = rmp_serde::from_slice(&vvec[ENCODED_KEY_MIN_LEN..]).unwrap();
let froms = vals[0].get_slice().unwrap();
let tos = vals[1].get_slice().unwrap();
@ -112,8 +111,6 @@ impl<'a> SessionTx<'a> {
&self,
ast: &FtsExpr,
config: &FtsSearch,
filter_code: &Option<(Vec<Bytecode>, SourceSpan)>,
tokenizer: &TextAnalyzer,
n: usize,
) -> Result<FxHashMap<Tuple, f64>> {
Ok(match ast {
@ -138,21 +135,13 @@ impl<'a> SessionTx<'a> {
let mut res = self.fts_search_impl(
l_iter.next().unwrap(),
config,
filter_code,
tokenizer,
n,
)?;
for nxt in l_iter {
let nxt_res = self.fts_search_impl(nxt, config, filter_code, tokenizer, n)?;
let nxt_res = self.fts_search_impl(nxt, config, n)?;
res = res
.into_iter()
.filter_map(|(k, v)| {
if let Some(nxt_v) = nxt_res.get(&k) {
Some((k, v + nxt_v))
} else {
None
}
})
.filter_map(|(k, v)| nxt_res.get(&k).map(|nxt_v| (k, v + nxt_v)))
.collect();
}
res
@ -160,7 +149,7 @@ impl<'a> SessionTx<'a> {
FtsExpr::Or(ls) => {
let mut res: FxHashMap<Tuple, f64> = FxHashMap::default();
for nxt in ls {
let nxt_res = self.fts_search_impl(nxt, config, filter_code, tokenizer, n)?;
let nxt_res = self.fts_search_impl(nxt, config, n)?;
for (k, v) in nxt_res {
if let Some(old_v) = res.get_mut(&k) {
*old_v = (*old_v).max(v);
@ -199,13 +188,11 @@ impl<'a> SessionTx<'a> {
if cur - p <= *distance {
inner_coll.insert(p);
}
} else {
if p - cur <= *distance {
} else if p - cur <= *distance {
inner_coll.insert(cur);
}
}
}
}
if inner_coll.is_empty() {
None
} else {
@ -230,9 +217,9 @@ impl<'a> SessionTx<'a> {
.collect()
}
FtsExpr::Not(fst, snd) => {
let mut res = self.fts_search_impl(fst, config, filter_code, tokenizer, n)?;
let mut res = self.fts_search_impl(fst, config, n)?;
for el in self
.fts_search_impl(snd, config, filter_code, tokenizer, n)?
.fts_search_impl(snd, config, n)?
.keys()
{
res.remove(el);
@ -250,8 +237,8 @@ impl<'a> SessionTx<'a> {
) -> f64 {
let tf = tf as f64;
match config.score_kind {
FtsScoreKind::TF => tf * booster,
FtsScoreKind::TFIDF => {
FtsScoreKind::Tf => tf * booster,
FtsScoreKind::TfIdf => {
let n_found_docs = n_found_docs as f64;
let idf = (1.0 + (n_total as f64 - n_found_docs + 0.5) / (n_found_docs + 0.5)).ln();
tf * idf * booster
@ -271,13 +258,13 @@ impl<'a> SessionTx<'a> {
if ast.is_empty() {
return Ok(vec![]);
}
let n = if config.score_kind == FtsScoreKind::TFIDF {
let n = if config.score_kind == FtsScoreKind::TfIdf {
cache.get_n_for_relation(&config.base_handle, self)?
} else {
0
};
let mut result: Vec<_> = self
.fts_search_impl(&ast, config, filter_code, tokenizer, n)?
.fts_search_impl(&ast, config, n)?
.into_iter()
.collect();
result.sort_by_key(|(_, score)| Reverse(OrderedFloat(*score)));

@ -1,7 +1,10 @@
use smartstring::{LazyCompact, SmartString};
/// The tokenizer module contains all of the tools used to process
/// text in `tantivy`.
use std::borrow::{Borrow, BorrowMut};
use std::iter;
use std::ops::{Deref, DerefMut};
use rustc_hash::FxHashSet;
use crate::fts::tokenizer::empty_tokenizer::EmptyTokenizer;
@ -60,7 +63,10 @@ impl TextAnalyzer {
///
/// When creating a `TextAnalyzer` from a `Tokenizer` alone, prefer using
/// `TextAnalyzer::from(tokenizer)`.
pub(crate) fn new<T: Tokenizer>(tokenizer: T, token_filters: Vec<BoxTokenFilter>) -> TextAnalyzer {
pub(crate) fn new<T: Tokenizer>(
tokenizer: T,
token_filters: Vec<BoxTokenFilter>,
) -> TextAnalyzer {
TextAnalyzer {
tokenizer: Box::new(tokenizer),
token_filters,
@ -96,6 +102,25 @@ impl TextAnalyzer {
}
token_stream
}
pub(crate) fn unique_ngrams(&self, text: &str, n: usize) -> FxHashSet<Vec<SmartString<LazyCompact>>> {
let mut token_stream = self.token_stream(text);
let mut coll: Vec<SmartString<LazyCompact>> = vec![];
while let Some(token) = token_stream.next() {
coll.push(SmartString::from(token.text.as_str()));
}
if n == 1 {
coll.iter().map(|x| vec![x.clone()]).collect()
} else if n >= coll.len() {
iter::once(coll).collect()
} else {
let mut ret = FxHashSet::default();
for chunk in coll.windows(n) {
ret.insert(chunk.to_vec());
}
ret
}
}
}
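
unique_ngrams collects the distinct n-token shingles of the analyzed text: n = 1 degenerates to the set of single tokens, and n at or above the token count makes the whole token sequence one shingle. A standalone sketch of the same shingling logic over plain Strings (a hypothetical helper, independent of TextAnalyzer and its token types):

    use std::collections::HashSet;

    fn unique_ngrams(tokens: &[String], n: usize) -> HashSet<Vec<String>> {
        if n == 1 {
            tokens.iter().map(|t| vec![t.clone()]).collect()
        } else if n >= tokens.len() {
            // the whole sequence is the only shingle
            std::iter::once(tokens.to_vec()).collect()
        } else {
            tokens.windows(n).map(|w| w.to_vec()).collect()
        }
    }

    fn main() {
        let toks: Vec<String> = ["a", "b", "c", "d"].iter().map(|s| s.to_string()).collect();
        // prints the set {["a", "b"], ["b", "c"], ["c", "d"]}
        println!("{:?}", unique_ngrams(&toks, 2));
    }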
impl Clone for TextAnalyzer {

@ -21,7 +21,7 @@ pub(crate) fn parse_fts_query(q: &str) -> Result<FtsExpr> {
let pairs = pairs.next().unwrap().into_inner();
let pairs: Vec<_> = pairs
.filter(|r| r.as_rule() != Rule::EOI)
.map(|r| parse_fts_expr(r))
.map(parse_fts_expr)
.try_collect()?;
Ok(if pairs.len() == 1 {
pairs.into_iter().next().unwrap()
@ -157,7 +157,7 @@ mod tests {
assert!(matches!(res, FtsExpr::Not(_, _)));
let src = " NEAR(abc def \"ghi\"^22.8) ";
let res = parse_fts_query(src).unwrap().flatten();
assert!(matches!(res, FtsExpr::Near(FtsNear{distance: 10, ..})));
assert!(matches!(res, FtsExpr::Near(FtsNear { distance: 10, .. })));
println!("{:#?}", res);
}
}

@ -11,6 +11,7 @@ use std::sync::Arc;
use itertools::Itertools;
use miette::{bail, ensure, miette, Diagnostic, Result};
use ordered_float::OrderedFloat;
use smartstring::{LazyCompact, SmartString};
use thiserror::Error;
@ -41,8 +42,10 @@ pub(crate) enum SysOp {
CreateIndex(Symbol, Symbol, Vec<Symbol>),
CreateVectorIndex(HnswIndexConfig),
CreateFtsIndex(FtsIndexConfig),
CreateMinHashLshIndex(MinHashLshConfig),
RemoveIndex(Symbol, Symbol),
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct FtsIndexConfig {
pub(crate) base_relation: SmartString<LazyCompact>,
@ -52,6 +55,20 @@ pub(crate) struct FtsIndexConfig {
pub(crate) filters: Vec<TokenizerConfig>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct MinHashLshConfig {
pub(crate) base_relation: SmartString<LazyCompact>,
pub(crate) index_name: SmartString<LazyCompact>,
pub(crate) extractor: String,
pub(crate) tokenizer: TokenizerConfig,
pub(crate) filters: Vec<TokenizerConfig>,
pub(crate) n_gram: usize,
pub(crate) n_perm: usize,
pub(crate) false_positive_weight: OrderedFloat<f64>,
pub(crate) false_negative_weight: OrderedFloat<f64>,
pub(crate) target_threshold: OrderedFloat<f64>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct HnswIndexConfig {
pub(crate) base_relation: SmartString<LazyCompact>,
@ -186,7 +203,175 @@ pub(crate) fn parse_sys(
SysOp::SetTriggers(rel, puts, rms, replaces)
}
Rule::lsh_idx_op => {
todo!()
let inner = inner.into_inner().next().unwrap();
match inner.as_rule() {
Rule::index_create_adv => {
let mut inner = inner.into_inner();
let rel = inner.next().unwrap();
let name = inner.next().unwrap();
let mut filters = vec![];
let mut tokenizer = TokenizerConfig {
name: Default::default(),
args: Default::default(),
};
let mut extractor = "".to_string();
let mut n_gram = 1;
let mut n_perm = 200;
let mut target_threshold = 0.9;
let mut false_positive_weight = 1.0;
let mut false_negative_weight = 1.0;
for opt_pair in inner {
let mut opt_inner = opt_pair.into_inner();
let opt_name = opt_inner.next().unwrap();
let opt_val = opt_inner.next().unwrap();
match opt_name.as_str() {
"false_positive_weight" => {
let mut expr = build_expr(opt_val, param_pool)?;
expr.partial_eval()?;
let v = expr.eval_to_const()?;
false_positive_weight = v.get_float().ok_or_else(|| {
miette!("false_positive_weight must be a float")
})?;
}
"false_negative_weight" => {
let mut expr = build_expr(opt_val, param_pool)?;
expr.partial_eval()?;
let v = expr.eval_to_const()?;
false_negative_weight = v.get_float().ok_or_else(|| {
miette!("false_negative_weight must be a float")
})?;
}
"n_gram" => {
let mut expr = build_expr(opt_val, param_pool)?;
expr.partial_eval()?;
let v = expr.eval_to_const()?;
n_gram = v
.get_int()
.ok_or_else(|| miette!("n_gram must be an integer"))?
as usize;
}
"n_perm" => {
let mut expr = build_expr(opt_val, param_pool)?;
expr.partial_eval()?;
let v = expr.eval_to_const()?;
n_perm = v
.get_int()
.ok_or_else(|| miette!("n_perm must be an integer"))?
as usize;
}
"target_threshold" => {
let mut expr = build_expr(opt_val, param_pool)?;
expr.partial_eval()?;
let v = expr.eval_to_const()?;
target_threshold = v
.get_float()
.ok_or_else(|| miette!("target_threshold must be a float"))?;
}
"extractor" => {
let mut ex = build_expr(opt_val, param_pool)?;
ex.partial_eval()?;
extractor = ex.to_string();
}
"tokenizer" => {
let mut expr = build_expr(opt_val, param_pool)?;
expr.partial_eval()?;
match expr {
Expr::UnboundApply { op, args, .. } => {
let mut targs = vec![];
for arg in args.iter() {
let v = arg.clone().eval_to_const()?;
targs.push(v);
}
tokenizer.name = op;
tokenizer.args = targs;
}
Expr::Binding { var, .. } => {
tokenizer.name = var.name;
tokenizer.args = vec![];
}
_ => bail!("Tokenizer must be a symbol or a call for an existing tokenizer"),
}
}
"filters" => {
let mut expr = build_expr(opt_val, param_pool)?;
expr.partial_eval()?;
match expr {
Expr::Apply { op, args, .. } => {
if op.name != "LIST" {
bail!("Filters must be a list of filters");
}
for arg in args.iter() {
match arg {
Expr::UnboundApply { op, args, .. } => {
let mut targs = vec![];
for arg in args.iter() {
let v = arg.clone().eval_to_const()?;
targs.push(v);
}
filters.push(TokenizerConfig {
name: op.clone(),
args: targs,
})
}
Expr::Binding { var, .. } => {
filters.push(TokenizerConfig {
name: var.name.clone(),
args: vec![],
})
}
_ => bail!("Tokenizer must be a symbol or a call for an existing tokenizer"),
}
}
}
_ => bail!("Filters must be a list of filters"),
}
}
_ => bail!("Unknown option {} for LSH index", opt_name.as_str()),
}
}
ensure!(
false_positive_weight > 0.,
"false_positive_weight must be positive"
);
ensure!(
false_negative_weight > 0.,
"false_negative_weight must be positive"
);
ensure!(n_gram > 0, "n_gram must be positive");
ensure!(n_perm > 0, "n_perm must be positive");
ensure!(
target_threshold > 0. && target_threshold < 1.,
"target_threshold must be between 0 and 1"
);
let total_weights = false_positive_weight + false_negative_weight;
false_positive_weight /= total_weights;
false_negative_weight /= total_weights;
let config = MinHashLshConfig {
base_relation: SmartString::from(rel.as_str()),
index_name: SmartString::from(name.as_str()),
extractor,
tokenizer,
filters,
n_gram,
n_perm,
false_positive_weight: false_positive_weight.into(),
false_negative_weight: false_negative_weight.into(),
target_threshold: target_threshold.into(),
};
SysOp::CreateMinHashLshIndex(config)
}
Rule::index_drop => {
let mut inner = inner.into_inner();
let rel = inner.next().unwrap();
let name = inner.next().unwrap();
SysOp::RemoveIndex(
Symbol::new(rel.as_str(), rel.extract_span()),
Symbol::new(name.as_str(), name.extract_span()),
)
}
r => unreachable!("{:?}", r),
}
}
Rule::fts_idx_op => {
let inner = inner.into_inner().next().unwrap();

@ -539,6 +539,40 @@ impl<'a> SessionTx<'a> {
ret = ret.filter(Expr::build_and(post_filters, s.span))?;
}
}
MagicAtom::LshSearch(s) => {
debug_assert!(
seen_variables.contains(&s.query),
"FTS search query must be bound"
);
let mut own_bindings = vec![];
let mut post_filters = vec![];
for var in s.all_bindings() {
if seen_variables.contains(var) {
let rk = gen_symb(var.span);
post_filters.push(Expr::build_equate(
vec![
Expr::Binding {
var: var.clone(),
tuple_pos: None,
},
Expr::Binding {
var: rk.clone(),
tuple_pos: None,
},
],
var.span,
));
own_bindings.push(rk);
} else {
seen_variables.insert(var.clone());
own_bindings.push(var.clone());
}
}
ret = ret.lsh_search(s.clone(), own_bindings)?;
if !post_filters.is_empty() {
ret = ret.filter(Expr::build_and(post_filters, s.span))?;
}
}
MagicAtom::Unification(u) => {
if seen_variables.contains(&u.binding) {
let expr = if u.one_many_unif {

@ -180,6 +180,10 @@ fn magic_rewrite_ruleset(
seen_bindings.extend(s.all_bindings().cloned());
collected_atoms.push(MagicAtom::FtsSearch(s));
}
MagicAtom::LshSearch(s) => {
seen_bindings.extend(s.all_bindings().cloned());
collected_atoms.push(MagicAtom::LshSearch(s));
}
MagicAtom::Rule(r_app) => {
if r_app.name.has_bound_adornment() {
// we are guaranteed to have a magic rule application
@ -539,6 +543,14 @@ impl NormalFormAtom {
}
MagicAtom::FtsSearch(s.clone())
}
NormalFormAtom::LshSearch(s) => {
for arg in s.all_bindings() {
if !seen_bindings.contains(arg) {
seen_bindings.insert(arg.clone());
}
}
MagicAtom::LshSearch(s.clone())
}
NormalFormAtom::Predicate(p) => {
// predicate cannot introduce new bindings

@ -23,6 +23,7 @@ use crate::data::symb::Symbol;
use crate::data::tuple::{Tuple, TupleIter};
use crate::data::value::{DataValue, ValidityTs};
use crate::parse::SourceSpan;
use crate::runtime::minhash_lsh::LshSearch;
use crate::runtime::relation::RelationHandle;
use crate::runtime::temp_store::EpochStore;
use crate::runtime::transact::SessionTx;
@ -40,6 +41,7 @@ pub(crate) enum RelAlgebra {
Unification(UnificationRA),
HnswSearch(HnswSearchRA),
FtsSearch(FtsSearchRA),
LshSearch(LshSearchRA),
}
impl RelAlgebra {
@ -56,6 +58,7 @@ impl RelAlgebra {
RelAlgebra::StoredWithValidity(i) => i.span,
RelAlgebra::HnswSearch(i) => i.hnsw_search.span,
RelAlgebra::FtsSearch(i) => i.fts_search.span,
RelAlgebra::LshSearch(i) => i.lsh_search.span,
}
}
}
@ -290,6 +293,11 @@ impl Debug for RelAlgebra {
.field(&bindings)
.field(&s.fts_search.idx_handle.name)
.finish(),
RelAlgebra::LshSearch(s) => f
.debug_tuple("LshSearch")
.field(&bindings)
.field(&s.lsh_search.idx_handle.name)
.finish(),
RelAlgebra::StoredWithValidity(r) => f
.debug_tuple("StoredWithValidity")
.field(&bindings)
@ -362,6 +370,9 @@ impl RelAlgebra {
RelAlgebra::FtsSearch(s) => {
s.fill_binding_indices_and_compile()?;
}
RelAlgebra::LshSearch(s) => {
s.fill_binding_indices_and_compile()?;
}
RelAlgebra::StoredWithValidity(v) => {
v.fill_binding_indices_and_compile()?;
}
@ -459,7 +470,8 @@ impl RelAlgebra {
| RelAlgebra::NegJoin(_)
| RelAlgebra::Unification(_)
| RelAlgebra::HnswSearch(_)
| RelAlgebra::FtsSearch(_)) => {
| RelAlgebra::FtsSearch(_)
| RelAlgebra::LshSearch(_)) => {
let span = filter.span();
RelAlgebra::Filter(FilteredRA {
parent: Box::new(s),
@ -624,6 +636,18 @@ impl RelAlgebra {
own_bindings,
}))
}
pub(crate) fn lsh_search(
self,
lsh_search: LshSearch,
own_bindings: Vec<Symbol>,
) -> Result<Self> {
Ok(Self::LshSearch(LshSearchRA {
parent: Box::new(self),
lsh_search,
filter_bytecode: None,
own_bindings,
}))
}
pub(crate) fn join(
self,
right: RelAlgebra,
@ -875,6 +899,78 @@ pub(crate) struct HnswSearchRA {
pub(crate) own_bindings: Vec<Symbol>,
}
#[derive(Debug)]
pub(crate) struct LshSearchRA {
pub(crate) parent: Box<RelAlgebra>,
pub(crate) lsh_search: LshSearch,
pub(crate) filter_bytecode: Option<(Vec<Bytecode>, SourceSpan)>,
pub(crate) own_bindings: Vec<Symbol>,
}
impl LshSearchRA {
fn fill_binding_indices_and_compile(&mut self) -> Result<()> {
self.parent.fill_binding_indices_and_compile()?;
if self.lsh_search.filter.is_some() {
let bindings: BTreeMap<_, _> = self
.own_bindings
.iter()
.cloned()
.enumerate()
.map(|(a, b)| (b, a))
.collect();
let filter = self.lsh_search.filter.as_mut().unwrap();
filter.fill_binding_indices(&bindings)?;
self.filter_bytecode = Some((filter.compile()?, filter.span()));
}
Ok(())
}
fn iter<'a>(
&'a self,
tx: &'a SessionTx<'_>,
delta_rule: Option<&MagicSymbol>,
stores: &'a BTreeMap<MagicSymbol, EpochStore>,
) -> Result<TupleIter<'a>> {
let bindings = self.parent.bindings_after_eliminate();
let mut bind_idx = usize::MAX;
for (i, b) in bindings.iter().enumerate() {
if *b == self.lsh_search.query {
bind_idx = i;
break;
}
}
let config = self.lsh_search.clone();
let filter_code = self.filter_bytecode.clone();
let mut stack = vec![];
let it = self
.parent
.iter(tx, delta_rule, stores)?
.map_ok(move |tuple| -> Result<_> {
let q = match tuple[bind_idx].clone() {
DataValue::List(l) => l,
d => bail!("Expected list for LSH search, got {:?}", d),
};
let res = tx.lsh_search(
&q,
&config,
&mut stack,
&filter_code,
)?;
Ok(res.into_iter().map(move |t| {
let mut r = tuple.clone();
r.extend(t);
r
}))
})
.map(flatten_err)
.flatten_ok();
Ok(Box::new(it))
}
}
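For reference, the search flow here: the bound query is a list holding the key of an already-indexed row. lsh_search fetches that row's stored MinHash from the inverse index, prefix-scans the band index once per band with (band_no, band_bytes) as the prefix, tallies for every other row the number of colliding bands, converts the tally into a similarity estimate, and then applies min_similarity, the optional filter bytecode against the base tuple, and the k cutoff.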
#[derive(Debug)]
pub(crate) struct FtsSearchRA {
pub(crate) parent: Box<RelAlgebra>,
@ -1745,6 +1841,7 @@ impl RelAlgebra {
RelAlgebra::Unification(r) => r.do_eliminate_temp_vars(used),
RelAlgebra::HnswSearch(_) => Ok(()),
RelAlgebra::FtsSearch(_) => Ok(()),
RelAlgebra::LshSearch(_) => Ok(()),
}
}
@ -1761,6 +1858,7 @@ impl RelAlgebra {
RelAlgebra::Unification(u) => Some(&u.to_eliminate),
RelAlgebra::HnswSearch(_) => None,
RelAlgebra::FtsSearch(_) => None,
RelAlgebra::LshSearch(_) => None,
}
}
@ -1800,6 +1898,11 @@ impl RelAlgebra {
bindings.extend_from_slice(&s.own_bindings);
bindings
}
RelAlgebra::LshSearch(s) => {
let mut bindings = s.parent.bindings_after_eliminate();
bindings.extend_from_slice(&s.own_bindings);
bindings
}
}
}
pub(crate) fn iter<'a>(
@ -1820,6 +1923,7 @@ impl RelAlgebra {
RelAlgebra::Unification(r) => r.iter(tx, delta_rule, stores),
RelAlgebra::HnswSearch(r) => r.iter(tx, delta_rule, stores),
RelAlgebra::FtsSearch(r) => r.iter(tx, delta_rule, stores),
RelAlgebra::LshSearch(r) => r.iter(tx, delta_rule, stores),
}
}
}
@ -2001,6 +2105,7 @@ impl InnerJoin {
}
RelAlgebra::HnswSearch(_) => "hnsw_search_join",
RelAlgebra::FtsSearch(_) => "fts_search_join",
RelAlgebra::LshSearch(_) => "lsh_search_join",
RelAlgebra::StoredWithValidity(_) => {
let join_indices = self
.joiner
@ -2113,7 +2218,8 @@ impl InnerJoin {
| RelAlgebra::Filter(_)
| RelAlgebra::Unification(_)
| RelAlgebra::HnswSearch(_)
| RelAlgebra::FtsSearch(_) => {
| RelAlgebra::FtsSearch(_)
| RelAlgebra::LshSearch(_) => {
self.materialized_join(tx, eliminate_indices, delta_rule, stores)
}
RelAlgebra::Reorder(_) => {

@ -88,6 +88,14 @@ impl NormalFormInlineRule {
pending.push(NormalFormAtom::FtsSearch(s));
}
}
NormalFormAtom::LshSearch(s) => {
if seen_variables.contains(&s.query) {
seen_variables.extend(s.all_bindings().cloned());
round_1_collected.push(NormalFormAtom::LshSearch(s));
} else {
pending.push(NormalFormAtom::LshSearch(s));
}
}
}
}
@ -124,6 +132,10 @@ impl NormalFormInlineRule {
seen_variables.extend(s.all_bindings().cloned());
collected.push(NormalFormAtom::FtsSearch(s));
}
NormalFormAtom::LshSearch(s) => {
seen_variables.extend(s.all_bindings().cloned());
collected.push(NormalFormAtom::LshSearch(s));
}
}
for atom in last_pending.iter() {
match atom {
@ -158,6 +170,14 @@ impl NormalFormInlineRule {
pending.push(NormalFormAtom::FtsSearch(s.clone()));
}
}
NormalFormAtom::LshSearch(s) => {
if seen_variables.contains(&s.query) {
seen_variables.extend(s.all_bindings().cloned());
collected.push(NormalFormAtom::LshSearch(s.clone()));
} else {
pending.push(NormalFormAtom::LshSearch(s.clone()));
}
}
NormalFormAtom::Predicate(p) => {
if p.bindings()?.is_subset(&seen_variables) {
collected.push(NormalFormAtom::Predicate(p.clone()));
@ -206,6 +226,9 @@ impl NormalFormInlineRule {
NormalFormAtom::FtsSearch(s) => {
bail!(UnboundVariable(s.span))
}
NormalFormAtom::LshSearch(s) => {
bail!(UnboundVariable(s.span))
}
}
}
}

@ -31,7 +31,8 @@ impl NormalFormAtom {
| NormalFormAtom::Predicate(_)
| NormalFormAtom::Unification(_)
| NormalFormAtom::HnswSearch(_)
| NormalFormAtom::FtsSearch(_) => Default::default(),
| NormalFormAtom::FtsSearch(_)
| NormalFormAtom::LshSearch(_) => Default::default(),
NormalFormAtom::Rule(r) => BTreeMap::from([(&r.name, false)]),
NormalFormAtom::NegatedRule(r) => BTreeMap::from([(&r.name, true)]),
}

@ -43,7 +43,7 @@ use crate::fts::TokenizerCache;
use crate::parse::sys::SysOp;
use crate::parse::{parse_script, CozoScript, SourceSpan};
use crate::query::compile::{CompiledProgram, CompiledRule, CompiledRuleSet};
use crate::query::ra::{FilteredRA, FtsSearchRA, HnswSearchRA, InnerJoin, NegJoin, RelAlgebra, ReorderRA, StoredRA, StoredWithValidityRA, TempStoreRA, UnificationRA};
use crate::query::ra::{FilteredRA, FtsSearchRA, HnswSearchRA, InnerJoin, LshSearchRA, NegJoin, RelAlgebra, ReorderRA, StoredRA, StoredWithValidityRA, TempStoreRA, UnificationRA};
#[allow(unused_imports)]
use crate::runtime::callback::{
CallbackCollector, CallbackDeclaration, CallbackOp, EventCallbackRegistry,
@ -1082,6 +1082,18 @@ impl<'s, S: Storage<'s>> Db<S> {
.map(|f| f.to_string())
.collect_vec()),
),
RelAlgebra::LshSearch(LshSearchRA {
lsh_search, ..
}) => (
"lsh_index",
json!(format!(":{}", lsh_search.idx_handle.name)),
json!(lsh_search.query.name),
json!(lsh_search
.filter
.iter()
.map(|f| f.to_string())
.collect_vec()),
),
};
ret_for_relation.push(json!({
STRATUM: stratum,
@ -1217,6 +1229,20 @@ impl<'s, S: Storage<'s>> Db<S> {
vec![vec![DataValue::from(OK_STR)]],
))
}
SysOp::CreateMinHashLshIndex(config) => {
let lock = self
.obtain_relation_locks(iter::once(&config.base_relation))
.pop()
.unwrap();
let _guard = lock.write().unwrap();
let mut tx = self.transact_write()?;
tx.create_minhash_lsh_index(config)?;
tx.commit_tx()?;
Ok(NamedRows::new(
vec![STATUS_STR.to_string()],
vec![vec![DataValue::from(OK_STR)]],
))
}
SysOp::RemoveIndex(rel_name, idx_name) => {
let lock = self
.obtain_relation_locks(iter::once(&rel_name.name))

@ -0,0 +1,374 @@
/*
* Copyright 2023, The Cozo Project Authors.
*
* This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
* If a copy of the MPL was not distributed with this file,
* You can obtain one at https://mozilla.org/MPL/2.0/.
*/
// Some ideas are from https://github.com/schelterlabs/rust-minhash
use crate::data::expr::{eval_bytecode, Bytecode, eval_bytecode_pred};
use crate::data::tuple::Tuple;
use crate::fts::tokenizer::TextAnalyzer;
use crate::fts::TokenizerConfig;
use crate::runtime::relation::RelationHandle;
use crate::runtime::transact::SessionTx;
use crate::{DataValue, Expr, SourceSpan, Symbol};
use miette::{bail, miette, Result};
use quadrature::integrate;
use rand::{thread_rng, RngCore};
use rustc_hash::FxHashMap;
use smartstring::{LazyCompact, SmartString};
use std::cmp::min;
use std::hash::{Hash, Hasher};
use twox_hash::XxHash32;
impl<'a> SessionTx<'a> {
pub(crate) fn del_lsh_index_item(
&mut self,
tuple: &[DataValue],
bytes: Option<Vec<u8>>,
idx_handle: &RelationHandle,
inv_idx_handle: &RelationHandle,
manifest: &MinHashLshIndexManifest,
) -> Result<()> {
let bytes = match bytes {
None => {
if let Some(mut found) = inv_idx_handle.get_val_only(self, tuple)? {
let inv_key = inv_idx_handle.encode_key_for_store(tuple, Default::default())?;
self.store_tx.del(&inv_key)?;
match found.pop() {
Some(DataValue::Bytes(b)) => b,
_ => unreachable!(),
}
} else {
return Ok(());
}
}
Some(b) => b,
};
let mut key = Vec::with_capacity(bytes.len() + 2);
key.push(DataValue::Bot);
key.push(DataValue::Bot);
key.extend_from_slice(tuple);
for (i, chunk) in bytes
.chunks_exact(manifest.r * std::mem::size_of::<u32>())
.enumerate()
{
key[0] = DataValue::from(i as i64);
key[1] = DataValue::Bytes(chunk.to_vec());
let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?;
self.store_tx.del(&key_bytes)?;
}
Ok(())
}
pub(crate) fn put_lsh_index_item(
&mut self,
tuple: &[DataValue],
extractor: &[Bytecode],
stack: &mut Vec<DataValue>,
tokenizer: &TextAnalyzer,
rel_handle: &RelationHandle,
idx_handle: &RelationHandle,
inv_idx_handle: &RelationHandle,
manifest: &MinHashLshIndexManifest,
hash_perms: &HashPermutations,
) -> Result<()> {
if let Some(mut found) =
inv_idx_handle.get_val_only(self, &tuple[..rel_handle.metadata.keys.len()])?
{
let bytes = match found.pop() {
Some(DataValue::Bytes(b)) => b,
_ => unreachable!(),
};
self.del_lsh_index_item(tuple, Some(bytes), idx_handle, inv_idx_handle, manifest)?;
}
let to_index = eval_bytecode(extractor, tuple, stack)?;
let min_hash = match to_index {
DataValue::Null => return Ok(()),
DataValue::List(l) => HashValues::new(l.iter(), hash_perms),
DataValue::Str(s) => {
let n_grams = tokenizer.unique_ngrams(&s, manifest.n_gram);
HashValues::new(n_grams.iter(), hash_perms)
}
_ => bail!("Cannot put value {:?} into a LSH index", to_index),
};
let bytes = min_hash.get_bytes();
let inv_key_part = &tuple[..rel_handle.metadata.keys.len()];
let inv_val_part = vec![DataValue::Bytes(bytes.to_vec())];
let inv_key = inv_idx_handle.encode_key_for_store(inv_key_part, Default::default())?;
let inv_val =
inv_idx_handle.encode_val_only_for_store(&inv_val_part, Default::default())?;
self.store_tx.put(&inv_key, &inv_val)?;
let mut key = Vec::with_capacity(bytes.len() + 2);
key.push(DataValue::Bot);
key.push(DataValue::Bot);
key.extend_from_slice(inv_key_part);
let chunk_size = manifest.r * std::mem::size_of::<u32>();
for i in 0..manifest.b {
let byte_range = &bytes[i * chunk_size..(i + 1) * chunk_size];
key[0] = DataValue::from(i as i64);
key[1] = DataValue::Bytes(byte_range.to_vec());
let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?;
self.store_tx.put(&key_bytes, &[])?;
}
Ok(())
}
pub(crate) fn lsh_search(
&self,
tuple: &[DataValue],
config: &LshSearch,
stack: &mut Vec<DataValue>,
filter_code: &Option<(Vec<Bytecode>, SourceSpan)>,
) -> Result<Vec<Tuple>> {
let bytes = if let Some(mut found) = config
.inv_idx_handle
.get_val_only(self, &tuple[..config.base_handle.metadata.keys.len()])?
{
match found.pop() {
Some(DataValue::Bytes(b)) => b,
_ => unreachable!(),
}
} else {
return Ok(vec![]);
};
let chunk_size = config.manifest.r * std::mem::size_of::<u32>();
let mut key_prefix = Vec::with_capacity(2);
let mut found_tuples: FxHashMap<_, usize> = FxHashMap::default();
for (i, chunk) in bytes.chunks_exact(chunk_size).enumerate() {
key_prefix.clear();
key_prefix.push(DataValue::from(i as i64));
key_prefix.push(DataValue::Bytes(chunk.to_vec()));
for ks in config.idx_handle.scan_prefix(self, &key_prefix) {
let ks = ks?;
let key_part = &ks[2..];
if key_part != tuple {
let found = found_tuples.entry(key_part.to_vec()).or_default();
*found += 1;
}
}
}
let mut ret = vec![];
for (key, count) in found_tuples {
let similarity = count as f64 / config.manifest.r as f64;
if similarity < config.min_similarity {
continue;
}
let mut orig_tuple = config
.base_handle
.get(self, &key)?
.ok_or_else(|| miette!("Tuple not found in base LSH relation"))?;
if let Some((filter_code, span)) = filter_code {
if !eval_bytecode_pred(filter_code, &orig_tuple, stack, *span)? {
continue;
}
}
if config.bind_similarity.is_some() {
orig_tuple.push(DataValue::from(similarity));
}
ret.push(orig_tuple);
if let Some(k) = config.k {
if ret.len() >= k {
break;
}
}
}
Ok(ret)
}
}
#[derive(Clone, Debug)]
pub(crate) struct LshSearch {
pub(crate) base_handle: RelationHandle,
pub(crate) idx_handle: RelationHandle,
pub(crate) inv_idx_handle: RelationHandle,
pub(crate) manifest: MinHashLshIndexManifest,
pub(crate) bindings: Vec<Symbol>,
pub(crate) k: Option<usize>,
pub(crate) query: Symbol,
pub(crate) bind_similarity: Option<Symbol>,
pub(crate) min_similarity: f64,
// pub(crate) lax_mode: bool,
pub(crate) filter: Option<Expr>,
pub(crate) span: SourceSpan,
}
impl LshSearch {
pub(crate) fn all_bindings(&self) -> impl Iterator<Item = &Symbol> {
self.bindings.iter().chain(self.bind_similarity.iter())
}
}
pub(crate) struct HashValues(pub(crate) Vec<u32>);
pub(crate) struct HashPermutations(pub(crate) Vec<u32>);
#[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)]
pub(crate) struct MinHashLshIndexManifest {
pub(crate) base_relation: SmartString<LazyCompact>,
pub(crate) index_name: SmartString<LazyCompact>,
pub(crate) extractor: String,
pub(crate) n_gram: usize,
pub(crate) tokenizer: TokenizerConfig,
pub(crate) filters: Vec<TokenizerConfig>,
pub(crate) num_perm: usize,
pub(crate) b: usize,
pub(crate) r: usize,
pub(crate) threshold: f64,
pub(crate) perms: Vec<u8>,
}
impl MinHashLshIndexManifest {
pub(crate) fn get_hash_perms(&self) -> HashPermutations {
HashPermutations::from_bytes(&self.perms)
}
}
#[derive(Clone, Debug)]
pub(crate) struct LshParams {
pub b: usize,
pub r: usize,
}
#[derive(Clone)]
pub(crate) struct Weights(pub(crate) f64, pub(crate) f64);
const _ALLOWED_INTEGRATE_ERR: f64 = 0.001;
// code is mostly from https://github.com/schelterlabs/rust-minhash/blob/81ea3fec24fd888a330a71b6932623643346b591/src/minhash_lsh.rs
impl LshParams {
pub fn find_optimal_params(threshold: f64, num_perm: usize, weights: &Weights) -> LshParams {
let Weights(false_positive_weight, false_negative_weight) = weights;
let mut min_error = f64::INFINITY;
let mut opt = LshParams { b: 0, r: 0 };
for b in 1..num_perm + 1 {
let max_r = num_perm / b;
for r in 1..max_r + 1 {
let false_pos = LshParams::false_positive_probability(threshold, b, r);
let false_neg = LshParams::false_negative_probability(threshold, b, r);
let error = false_pos * false_positive_weight + false_neg * false_negative_weight;
if error < min_error {
min_error = error;
opt = LshParams { b, r };
}
}
}
opt
}
fn false_positive_probability(threshold: f64, b: usize, r: usize) -> f64 {
let _probability =
|s| -> f64 { 1. - f64::powf(1. - f64::powi(s, r as i32), b as f64) };
integrate(_probability, 0.0, threshold, _ALLOWED_INTEGRATE_ERR).integral
}
fn false_negative_probability(threshold: f64, b: usize, r: usize) -> f64 {
let _probability =
|s| -> f64 { 1. - (1. - f64::powf(1. - f64::powi(s, r as i32), b as f64)) };
integrate(_probability, threshold, 1.0, _ALLOWED_INTEGRATE_ERR).integral
}
}
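
find_optimal_params implements the standard MinHash-LSH banding analysis. With b bands of r rows each, a pair whose true Jaccard similarity is s becomes a candidate in at least one band with probability

$$P_{\mathrm{cand}}(s) = 1 - \left(1 - s^{r}\right)^{b},$$

so against the target threshold t the two probability helpers integrate the two failure modes,

$$\mathrm{FP}(b,r) = \int_{0}^{t} P_{\mathrm{cand}}(s)\,ds, \qquad \mathrm{FN}(b,r) = \int_{t}^{1} \left(1 - P_{\mathrm{cand}}(s)\right) ds,$$

and the exhaustive scan over all pairs with $b \cdot r \le \mathrm{num\_perm}$ picks the $(b, r)$ minimizing $w_{fp}\,\mathrm{FP} + w_{fn}\,\mathrm{FN}$.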
impl HashPermutations {
pub(crate) fn new(n_perms: usize) -> Self {
let mut rng = thread_rng();
let mut perms = Vec::with_capacity(n_perms);
for _ in 0..n_perms {
perms.push(rng.next_u32());
}
Self(perms)
}
pub(crate) fn as_bytes(&self) -> &[u8] {
unsafe {
std::slice::from_raw_parts(
self.0.as_ptr() as *const u8,
self.0.len() * std::mem::size_of::<u32>(),
)
}
}
// this is the inverse of `as_bytes`
pub(crate) fn from_bytes(bytes: &[u8]) -> Self {
unsafe {
let ptr = bytes.as_ptr() as *const u32;
let len = bytes.len() / std::mem::size_of::<u32>();
let perms = std::slice::from_raw_parts(ptr, len);
Self(perms.to_vec())
}
}
}
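One caveat worth flagging on as_bytes/from_bytes: both reinterpret the u32 buffer in native byte order, so serialized permutations round-trip on the machine that wrote them but are not portable across endianness, and the raw u32 read in from_bytes additionally assumes the byte slice is 4-byte aligned.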
impl HashValues {
pub(crate) fn new<T: Hash>(values: impl Iterator<Item = T>, perms: &HashPermutations) -> Self {
let mut ret = Self::init(perms);
ret.update(values, perms);
ret
}
pub(crate) fn init(perms: &HashPermutations) -> Self {
Self(vec![u32::MAX; perms.0.len()])
}
pub(crate) fn update<T: Hash>(
&mut self,
values: impl Iterator<Item = T>,
perms: &HashPermutations,
) {
for v in values {
for (i, seed) in perms.0.iter().enumerate() {
let mut hasher = XxHash32::with_seed(*seed);
v.hash(&mut hasher);
let hash = hasher.finish() as u32;
self.0[i] = min(self.0[i], hash);
}
}
}
#[cfg(test)]
pub(crate) fn jaccard(&self, other_minhash: &Self) -> f32 {
let matches = self
.0
.iter()
.zip_eq(&other_minhash.0)
.filter(|(left, right)| left == right)
.count();
matches as f32 / self.0.len() as f32
}
pub(crate) fn get_bytes(&self) -> &[u8] {
unsafe {
std::slice::from_raw_parts(
self.0.as_ptr() as *const u8,
self.0.len() * std::mem::size_of::<u32>(),
)
}
}
// pub(crate) fn get_byte_chunks(&self, n_chunks: usize) -> impl Iterator<Item = &[u8]> {
// let chunk_size = self.0.len() * std::mem::size_of::<u32>() / n_chunks;
// self.get_bytes().chunks_exact(chunk_size)
// }
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_minhash() {
let perms = HashPermutations::new(20000);
let mut m1 = HashValues::new([1, 2, 3, 4, 5, 6].iter(), &perms);
let mut m2 = HashValues::new([4, 3, 2, 1, 5, 6].iter(), &perms);
assert_eq!(m1.0, m2.0);
// println!("{:?}", &m1.0);
// println!("{:?}", &m2.0);
assert_eq!(m1.jaccard(&m2), 1.0);
m1.update([7, 8, 9].iter(), &perms);
assert!(m1.jaccard(&m2) < 1.0);
println!("{:?}", m1.jaccard(&m2));
m2.update([17, 18, 19].iter(), &perms);
assert!(m1.jaccard(&m2) < 1.0);
println!("{:?}", m1.jaccard(&m2));
// println!("{:?}", m2.get_byte_chunks(2).collect_vec());
assert_eq!(perms.0, HashPermutations::from_bytes(perms.as_bytes()).0);
}
}
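
The test exercises the core MinHash identity: under a random hash function h, two sets agree on their minimum exactly as often as they overlap,

$$\Pr\left[\min_{a \in A} h(a) = \min_{b \in B} h(b)\right] = J(A,B) = \frac{|A \cap B|}{|A \cup B|},$$

so averaging agreement over num_perm independent seeds gives an unbiased estimate of the Jaccard similarity; with 20000 permutations, as here, the estimate is tight enough for the assertions to hold in practice.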

@ -13,5 +13,6 @@ pub(crate) mod relation;
pub(crate) mod temp_store;
pub(crate) mod transact;
pub(crate) mod hnsw;
pub(crate) mod minhash_lsh;
#[cfg(test)]
mod tests;

@ -26,10 +26,11 @@ use crate::data::tuple::{decode_tuple_from_key, Tuple, TupleT, ENCODED_KEY_MIN_L
use crate::data::value::{DataValue, ValidityTs};
use crate::fts::FtsIndexManifest;
use crate::parse::expr::build_expr;
use crate::parse::sys::{FtsIndexConfig, HnswIndexConfig};
use crate::parse::sys::{FtsIndexConfig, HnswIndexConfig, MinHashLshConfig};
use crate::parse::{CozoScriptParser, Rule, SourceSpan};
use crate::query::compile::IndexPositionUse;
use crate::runtime::hnsw::HnswIndexManifest;
use crate::runtime::minhash_lsh::{HashPermutations, LshParams, MinHashLshIndexManifest, Weights};
use crate::runtime::transact::SessionTx;
use crate::{NamedRows, StoreTx};
@ -83,6 +84,10 @@ pub(crate) struct RelationHandle {
pub(crate) hnsw_indices:
BTreeMap<SmartString<LazyCompact>, (RelationHandle, HnswIndexManifest)>,
pub(crate) fts_indices: BTreeMap<SmartString<LazyCompact>, (RelationHandle, FtsIndexManifest)>,
pub(crate) lsh_indices: BTreeMap<
SmartString<LazyCompact>,
(RelationHandle, RelationHandle, MinHashLshIndexManifest),
>,
}
impl RelationHandle {
@ -90,9 +95,13 @@ impl RelationHandle {
self.indices.contains_key(index_name)
|| self.hnsw_indices.contains_key(index_name)
|| self.fts_indices.contains_key(index_name)
|| self.lsh_indices.contains_key(index_name)
}
pub(crate) fn has_no_index(&self) -> bool {
self.indices.is_empty() && self.hnsw_indices.is_empty() && self.fts_indices.is_empty()
self.indices.is_empty()
&& self.hnsw_indices.is_empty()
&& self.fts_indices.is_empty()
&& self.lsh_indices.is_empty()
}
}
@ -254,10 +263,7 @@ impl RelationHandle {
}
Ok(ret)
}
pub(crate) fn encode_partial_key_for_store(
&self,
tuple: &[DataValue],
) -> Vec<u8> {
pub(crate) fn encode_partial_key_for_store(&self, tuple: &[DataValue]) -> Vec<u8> {
let mut ret = self.encode_key_prefix(tuple.len());
for val in tuple {
ret.encode_datavalue(val);
@ -392,6 +398,25 @@ impl RelationHandle {
}
}
pub(crate) fn get_val_only(
&self,
tx: &SessionTx<'_>,
key: &[DataValue],
) -> Result<Option<Tuple>> {
let key_data = key.encode_as_key(self.id);
if self.is_temp {
Ok(tx
.temp_store_tx
.get(&key_data, false)?
.map(|val_data| rmp_serde::from_slice(&val_data[ENCODED_KEY_MIN_LEN..]).unwrap()))
} else {
Ok(tx
.store_tx
.get(&key_data, false)?
.map(|val_data| rmp_serde::from_slice(&val_data[ENCODED_KEY_MIN_LEN..]).unwrap()))
}
}
pub(crate) fn exists(&self, tx: &SessionTx<'_>, key: &[DataValue]) -> Result<bool> {
let key_data = key.encode_as_key(self.id);
if self.is_temp {
@ -594,6 +619,7 @@ impl<'a> SessionTx<'a> {
indices: Default::default(),
hnsw_indices: Default::default(),
fts_indices: Default::default(),
lsh_indices: Default::default(),
};
let name_key = vec![DataValue::Str(meta.name.clone())].encode_as_key(RelationId::SYSTEM);
@ -694,6 +720,141 @@ impl<'a> SessionTx<'a> {
Ok(())
}
pub(crate) fn create_minhash_lsh_index(&mut self, config: MinHashLshConfig) -> Result<()> {
// Get relation handle
let mut rel_handle = self.get_relation(&config.base_relation, true)?;
// Check if index already exists
if rel_handle.has_index(&config.index_name) {
bail!(IndexAlreadyExists(
config.index_name.to_string(),
config.base_relation.to_string()
));
}
let inv_idx_keys = rel_handle.metadata.keys.clone();
let inv_idx_vals = vec![ColumnDef {
name: SmartString::from("minhash"),
typing: NullableColType {
coltype: ColType::Bytes,
nullable: false,
},
default_gen: None,
}];
let mut idx_keys = vec![
ColumnDef {
name: SmartString::from("perm"),
typing: NullableColType {
coltype: ColType::Int,
nullable: false,
},
default_gen: None,
},
ColumnDef {
name: SmartString::from("hash"),
typing: NullableColType {
coltype: ColType::Bytes,
nullable: false,
},
default_gen: None,
},
];
for k in rel_handle.metadata.keys.iter() {
idx_keys.push(ColumnDef {
name: format!("src_{}", k.name).into(),
typing: k.typing.clone(),
default_gen: None,
});
}
let idx_vals = vec![];
let idx_handle = self.write_idx_relation(
&config.base_relation,
&config.index_name,
idx_keys,
idx_vals,
)?;
let inv_idx_handle = self.write_idx_relation(
&config.base_relation,
&config.index_name,
inv_idx_keys,
inv_idx_vals,
)?;
// add index to relation
let params = LshParams::find_optimal_params(
config.target_threshold.0,
config.n_perm,
&Weights(
config.false_positive_weight.0,
config.false_negative_weight.0,
),
);
let perms = HashPermutations::new(config.n_perm);
let manifest = MinHashLshIndexManifest {
base_relation: config.base_relation,
index_name: config.index_name,
extractor: config.extractor,
n_gram: config.n_gram,
tokenizer: config.tokenizer,
filters: config.filters,
num_perm: config.n_perm,
b: params.b,
r: params.r,
threshold: config.target_threshold.0,
perms: perms.as_bytes().to_vec(),
};
// populate index
let tokenizer =
self.tokenizers
.get(&idx_handle.name, &manifest.tokenizer, &manifest.filters)?;
let parsed = CozoScriptParser::parse(Rule::expr, &manifest.extractor)
.into_diagnostic()?
.next()
.unwrap();
let mut code_expr = build_expr(parsed, &Default::default())?;
let binding_map = rel_handle.raw_binding_map();
code_expr.fill_binding_indices(&binding_map)?;
let extractor = code_expr.compile()?;
let mut stack = vec![];
let existing: Vec<_> = rel_handle.scan_all(self).try_collect()?;
let hash_perms = manifest.get_hash_perms();
for tuple in existing {
self.put_lsh_index_item(
&tuple,
&extractor,
&mut stack,
&tokenizer,
&rel_handle,
&idx_handle,
&inv_idx_handle,
&manifest,
&hash_perms,
)?;
}
rel_handle.lsh_indices.insert(
manifest.index_name.clone(),
(idx_handle, inv_idx_handle, manifest),
);
// update relation metadata
let new_encoded =
vec![DataValue::from(&rel_handle.name as &str)].encode_as_key(RelationId::SYSTEM);
let mut meta_val = vec![];
rel_handle
.serialize(&mut Serializer::new(&mut meta_val))
.unwrap();
self.store_tx.put(&new_encoded, &meta_val)?;
Ok(())
}
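
Index creation thus materializes two storage relations: the band index, keyed by (perm, hash, src_<key>...) with no value columns, and the inverse index, mapping the base relation's keys to the serialized MinHash bytes. The banding parameters (b, r) come from find_optimal_params, the random hash seeds are generated once and persisted in the manifest, and every existing row is pushed through the extractor and tokenizer before the updated handle is written back to the system table.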
pub(crate) fn create_fts_index(&mut self, config: FtsIndexConfig) -> Result<()> {
// Get relation handle
let mut rel_handle = self.get_relation(&config.base_relation, true)?;
@ -748,7 +909,7 @@ impl<'a> SessionTx<'a> {
},
ColumnDef {
name: SmartString::from("position"),
typing: col_type.clone(),
typing: col_type,
default_gen: None,
},
ColumnDef {
@ -1192,8 +1353,10 @@ impl<'a> SessionTx<'a> {
idx_name: &Symbol,
) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
let mut rel = self.get_relation(rel_name, true)?;
let is_lsh = rel.lsh_indices.contains_key(&idx_name.name);
if rel.indices.remove(&idx_name.name).is_none()
&& rel.hnsw_indices.remove(&idx_name.name).is_none()
&& rel.lsh_indices.remove(&idx_name.name).is_none()
{
#[derive(Debug, Error, Diagnostic)]
#[error("index {0} for relation {1} not found")]
@ -1203,7 +1366,13 @@ impl<'a> SessionTx<'a> {
bail!(IndexNotFound(idx_name.to_string(), rel_name.to_string()));
}
let to_clean = self.destroy_relation(&format!("{}:{}", rel_name.name, idx_name.name))?;
let mut to_clean =
self.destroy_relation(&format!("{}:{}", rel_name.name, idx_name.name))?;
if is_lsh {
to_clean.extend(
self.destroy_relation(&format!("{}:{}:inv", rel_name.name, idx_name.name))?,
);
}
let new_encoded =
vec![DataValue::from(&rel_name.name as &str)].encode_as_key(RelationId::SYSTEM);

@ -62,6 +62,13 @@ pub trait StoreTx<'s>: Sync {
/// the key has not been modified outside the transaction.
fn get(&self, key: &[u8], for_update: bool) -> Result<Option<Vec<u8>>>;
/// Get multiple keys. If `for_update` is `true` (only possible in a write transaction),
/// then the database needs to guarantee that `commit()` can only succeed if
/// the keys have not been modified outside the transaction.
fn multi_get(&self, keys: &[Vec<u8>], for_update: bool) -> Result<Vec<Option<Vec<u8>>>> {
keys.iter().map(|k| self.get(k, for_update)).collect()
}
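The default multi_get simply loops over get, so existing backends remain correct without changes; an engine with a native batched primitive (RocksDB's MultiGet, for instance) can override it to fetch all keys in one round trip.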
/// Put a key-value pair into the storage. In case of existing key,
/// the storage engine needs to overwrite the old value.
fn put(&mut self, key: &[u8], val: &[u8]) -> Result<()>;
