diff --git a/Cargo.lock b/Cargo.lock index 2b800839..601807a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -724,6 +724,7 @@ dependencies = [ "pest", "pest_derive", "priority-queue", + "quadrature", "rand 0.8.5", "rayon", "regex", @@ -746,6 +747,7 @@ dependencies = [ "tikv-client", "tikv-jemallocator-global", "tokio", + "twox-hash", "unicode-normalization", "uuid", ] @@ -2945,6 +2947,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "quadrature" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2054ccb02f454fcb2bc81e343aa0a171636a6331003fd5ec24c47a10966634b7" + [[package]] name = "quote" version = "1.0.26" @@ -4113,6 +4121,17 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +[[package]] +name = "twox-hash" +version = "1.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" +dependencies = [ + "cfg-if 1.0.0", + "rand 0.8.5", + "static_assertions", +] + [[package]] name = "typenum" version = "1.16.0" diff --git a/cozo-core/Cargo.toml b/cozo-core/Cargo.toml index 08d3d7fc..8dedafd3 100644 --- a/cozo-core/Cargo.toml +++ b/cozo-core/Cargo.toml @@ -130,6 +130,8 @@ crossbeam = "0.8.2" ndarray = { version = "0.15.6", features = ["serde"] } sha2 = "0.10.6" rustc-hash = "1.1.0" +twox-hash = "1.6.3" +quadrature = "0.1.2" # For the FTS feature jieba-rs = "0.6.7" aho-corasick = "1.0.1" diff --git a/cozo-core/src/cozoscript.pest b/cozo-core/src/cozoscript.pest index 2fa0be52..968994ac 100644 --- a/cozo-core/src/cozoscript.pest +++ b/cozo-core/src/cozoscript.pest @@ -56,7 +56,7 @@ underscore_ident = @{("_" | XID_START) ~ ("_" | XID_CONTINUE)*} relation_ident = @{"*" ~ (compound_or_index_ident | underscore_ident)} search_index_ident = _{"~" ~ compound_or_index_ident} compound_ident = @{ident ~ ("." ~ ident)*} -compound_or_index_ident = @{ident ~ ("." ~ ident)* ~ (":" ~ ident)?} +compound_or_index_ident = @{ident ~ ("." ~ ident)* ~ (":" ~ ident)*} rule = {rule_head ~ ":=" ~ rule_body ~ ";"?} const_rule = {rule_head ~ "<-" ~ expr ~ ";"?} diff --git a/cozo-core/src/data/program.rs b/cozo-core/src/data/program.rs index fc860056..306fd494 100644 --- a/cozo-core/src/data/program.rs +++ b/cozo-core/src/data/program.rs @@ -18,6 +18,7 @@ use thiserror::Error; use crate::data::aggr::Aggregation; use crate::data::expr::Expr; +use crate::data::functions::OP_LIST; use crate::data::relation::StoredRelationMetadata; use crate::data::symb::{Symbol, PROG_ENTRY}; use crate::data::value::{DataValue, ValidityTs}; @@ -26,6 +27,7 @@ use crate::fts::FtsIndexManifest; use crate::parse::SourceSpan; use crate::query::logical::{Disjunction, NamedFieldNotFound}; use crate::runtime::hnsw::HnswIndexManifest; +use crate::runtime::minhash_lsh::{LshSearch, MinHashLshIndexManifest}; use crate::runtime::relation::{ AccessLevel, InputRelationHandle, InsufficientAccessLevel, RelationHandle, }; @@ -950,8 +952,8 @@ pub(crate) struct HnswSearch { #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub(crate) enum FtsScoreKind { - TFIDF, - TF, + TfIdf, + Tf, } #[derive(Clone, Debug)] @@ -989,6 +991,212 @@ impl FtsSearch { } impl SearchInput { + fn normalize_lsh( + mut self, + base_handle: RelationHandle, + idx_handle: RelationHandle, + inv_idx_handle: RelationHandle, + manifest: MinHashLshIndexManifest, + gen: &mut TempSymbGen, + ) -> Result { + let mut conj = Vec::with_capacity(self.bindings.len() + 8); + let mut bindings = Vec::with_capacity(self.bindings.len()); + let mut seen_variables = BTreeSet::new(); + + for col in base_handle + .metadata + .keys + .iter() + .chain(base_handle.metadata.non_keys.iter()) + { + if let Some(arg) = self.bindings.remove(&col.name) { + match arg { + Expr::Binding { var, .. } => { + if var.is_ignored_symbol() { + bindings.push(gen.next_ignored(var.span)); + } else if seen_variables.insert(var.clone()) { + bindings.push(var); + } else { + let span = var.span; + let dup = gen.next(span); + let unif = NormalFormAtom::Unification(Unification { + binding: dup.clone(), + expr: Expr::Binding { + var, + tuple_pos: None, + }, + one_many_unif: false, + span, + }); + conj.push(unif); + bindings.push(dup); + } + } + expr => { + let span = expr.span(); + let kw = gen.next(span); + bindings.push(kw.clone()); + let unif = NormalFormAtom::Unification(Unification { + binding: kw, + expr, + one_many_unif: false, + span, + }); + conj.push(unif) + } + } + } else { + bindings.push(gen.next_ignored(self.span)); + } + } + + if let Some((name, _)) = self.bindings.pop_first() { + bail!(NamedFieldNotFound( + self.relation.name.to_string(), + name.to_string(), + self.span + )); + } + + #[derive(Debug, Error, Diagnostic)] + #[error("Field `{0}` is required for LSH search")] + #[diagnostic(code(parser::hnsw_query_required))] + struct LshRequiredMissing(String, #[label] SourceSpan); + + #[derive(Debug, Error, Diagnostic)] + #[error("Expected a list of keys for LSH search")] + #[diagnostic(code(parser::expected_list_for_lsh_keys))] + struct ExpectedListForLshKeys(#[label] SourceSpan); + + #[derive(Debug, Error, Diagnostic)] + #[error("Wrong arity for LSH keys, expected {1}, got {2}")] + #[diagnostic(code(parser::wrong_arity_for_lsh_keys))] + struct WrongArityForKeys(#[label] SourceSpan, usize, usize); + + let query = match self.parameters.remove("keys") { + None => match self.parameters.remove("key") { + None => { + bail!(LshRequiredMissing("keys".to_string(), self.span)) + } + Some(expr) => { + ensure!( + base_handle.indices.keys().len() == 1, + LshRequiredMissing("keys".to_string(), self.span) + ); + let span = expr.span(); + let kw = gen.next(span); + let unif = NormalFormAtom::Unification(Unification { + binding: kw.clone(), + expr: Expr::Apply { + op: &OP_LIST, + args: [expr].into(), + span, + }, + one_many_unif: false, + span, + }); + conj.push(unif); + kw + } + }, + Some(mut expr) => { + expr.partial_eval()?; + match expr { + Expr::Apply { op, args, span } => { + ensure!(op.name == OP_LIST.name, ExpectedListForLshKeys(span)); + ensure!( + args.len() == base_handle.indices.keys().len(), + WrongArityForKeys(span, base_handle.indices.keys().len(), args.len()) + ); + let kw = gen.next(span); + let unif = NormalFormAtom::Unification(Unification { + binding: kw.clone(), + expr: Expr::Apply { op, args, span }, + one_many_unif: false, + span, + }); + conj.push(unif); + kw + } + _ => { + bail!(ExpectedListForLshKeys(self.span)) + } + } + } + }; + + let k = match self.parameters.remove("k") { + None => None, + Some(k_expr) => { + let k = k_expr.eval_to_const()?; + let k = k.get_int().ok_or(ExpectedPosIntForFtsK(self.span))?; + + #[derive(Debug, Error, Diagnostic)] + #[error("Expected positive integer for `k`")] + #[diagnostic(code(parser::expected_int_for_hnsw_k))] + struct ExpectedPosIntForFtsK(#[label] SourceSpan); + + ensure!(k > 0, ExpectedPosIntForFtsK(self.span)); + Some(k as usize) + } + }; + + let filter = self.parameters.remove("filter"); + + let bind_similarity = match self.parameters.remove("bind_similarity") { + None => None, + Some(Expr::Binding { var, .. }) => Some(var), + Some(expr) => { + let span = expr.span(); + let kw = gen.next(span); + let unif = NormalFormAtom::Unification(Unification { + binding: kw.clone(), + expr, + one_many_unif: false, + span, + }); + conj.push(unif); + Some(kw) + } + }; + + let min_similarity = match self.parameters.remove("min_similarity") { + None => manifest.threshold, + Some(expr) => { + let min_similarity = expr.eval_to_const()?; + let min_similarity = min_similarity + .get_float() + .ok_or(ExpectedFloatForMinSimilarity(self.span))?; + + #[derive(Debug, Error, Diagnostic)] + #[error("Expected float for `min_similarity`")] + #[diagnostic(code(parser::expected_float_for_min_similarity))] + struct ExpectedFloatForMinSimilarity(#[label] SourceSpan); + + ensure!( + (0.0..=1.0).contains(&min_similarity), + ExpectedFloatForMinSimilarity(self.span) + ); + min_similarity + } + }; + + conj.push(NormalFormAtom::LshSearch(LshSearch { + base_handle, + idx_handle, + inv_idx_handle, + manifest, + bindings, + k, + bind_similarity, + query, + span: self.span, + min_similarity, + filter, + })); + + Ok(Disjunction::conj(conj)) + } fn normalize_fts( mut self, base_handle: RelationHandle, @@ -1094,46 +1302,6 @@ impl SearchInput { ensure!(k > 0, ExpectedPosIntForFtsK(self.span)); - // let k1 = { - // match self.parameters.remove("k1") { - // None => 1.2, - // Some(expr) => { - // let r = expr.eval_to_const()?; - // let r = r - // .get_float() - // .ok_or_else(|| miette!("k1 for FTS must be a float"))?; - // - // #[derive(Debug, Error, Diagnostic)] - // #[error("Expected positive float for `k1`")] - // #[diagnostic(code(parser::expected_float_for_hnsw_k1))] - // struct ExpectedPosFloatForFtsK1(#[label] SourceSpan); - // - // ensure!(r > 0.0, ExpectedPosFloatForFtsK1(self.span)); - // r - // } - // } - // }; - // - // let b = { - // match self.parameters.remove("b") { - // None => 0.75, - // Some(expr) => { - // let r = expr.eval_to_const()?; - // let r = r - // .get_float() - // .ok_or_else(|| miette!("b for FTS must be a float"))?; - // - // #[derive(Debug, Error, Diagnostic)] - // #[error("Expected positive float for `b`")] - // #[diagnostic(code(parser::expected_float_for_hnsw_b))] - // struct ExpectedPosFloatForFtsB(#[label] SourceSpan); - // - // ensure!(r > 0.0, ExpectedPosFloatForFtsB(self.span)); - // r - // } - // } - // }; - let score_kind_expr = self.parameters.remove("score_kind"); let score_kind = match score_kind_expr { Some(expr) => { @@ -1143,26 +1311,14 @@ impl SearchInput { .ok_or_else(|| miette!("Score kind for FTS must be a string"))?; match r { - "tf_idf" => FtsScoreKind::TFIDF, - "tf" => FtsScoreKind::TF, + "tf_idf" => FtsScoreKind::TfIdf, + "tf" => FtsScoreKind::Tf, s => bail!("Unknown score kind for FTS: {}", s), } } - None => FtsScoreKind::TFIDF, + None => FtsScoreKind::TfIdf, }; - // let lax_mode_expr = self.parameters.remove("lax_mode"); - // let lax_mode = match lax_mode_expr { - // Some(expr) => { - // let r = expr.eval_to_const()?; - // let r = r - // .get_bool() - // .ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?; - // r - // } - // None => true, - // }; - let filter = self.parameters.remove("filter"); let bind_score = match self.parameters.remove("bind_score") { @@ -1447,6 +1603,11 @@ impl SearchInput { { return self.normalize_fts(base_handle, idx_handle, manifest, gen); } + if let Some((idx_handle, inv_idx_handle, manifest)) = + base_handle.lsh_indices.get(&self.index.name).cloned() + { + return self.normalize_lsh(base_handle, idx_handle, inv_idx_handle, manifest, gen); + } #[derive(Debug, Error, Diagnostic)] #[error("Index {name} not found on relation {relation}")] #[diagnostic(code(eval::hnsw_index_not_found))] @@ -1586,6 +1747,7 @@ pub(crate) enum NormalFormAtom { Unification(Unification), HnswSearch(HnswSearch), FtsSearch(FtsSearch), + LshSearch(LshSearch), } #[derive(Debug, Clone)] @@ -1598,6 +1760,7 @@ pub(crate) enum MagicAtom { Unification(Unification), HnswSearch(HnswSearch), FtsSearch(FtsSearch), + LshSearch(LshSearch), } #[derive(Clone, Debug)] diff --git a/cozo-core/src/data/value.rs b/cozo-core/src/data/value.rs index 888f3899..123f2b98 100644 --- a/cozo-core/src/data/value.rs +++ b/cozo-core/src/data/value.rs @@ -619,6 +619,13 @@ impl Display for DataValue { } impl DataValue { + /// Returns a slice of bytes if this one is a Bytes + pub fn get_bytes(&self) -> Option<&[u8]> { + match self { + DataValue::Bytes(b) => Some(b), + _ => None, + } + } /// Returns a slice of DataValues if this one is a List pub fn get_slice(&self) -> Option<&[DataValue]> { match self { diff --git a/cozo-core/src/fts/indexing.rs b/cozo-core/src/fts/indexing.rs index 26379935..0ed01083 100644 --- a/cozo-core/src/fts/indexing.rs +++ b/cozo-core/src/fts/indexing.rs @@ -80,11 +80,10 @@ impl<'a> SessionTx<'a> { if !found_str_key.starts_with(start_key_str) { break; } - } else { - if found_str_key != start_key_str { - break; - } + } else if found_str_key != start_key_str { + break; } + let vals: Vec = rmp_serde::from_slice(&vvec[ENCODED_KEY_MIN_LEN..]).unwrap(); let froms = vals[0].get_slice().unwrap(); let tos = vals[1].get_slice().unwrap(); @@ -112,8 +111,6 @@ impl<'a> SessionTx<'a> { &self, ast: &FtsExpr, config: &FtsSearch, - filter_code: &Option<(Vec, SourceSpan)>, - tokenizer: &TextAnalyzer, n: usize, ) -> Result> { Ok(match ast { @@ -138,21 +135,13 @@ impl<'a> SessionTx<'a> { let mut res = self.fts_search_impl( l_iter.next().unwrap(), config, - filter_code, - tokenizer, n, )?; for nxt in l_iter { - let nxt_res = self.fts_search_impl(nxt, config, filter_code, tokenizer, n)?; + let nxt_res = self.fts_search_impl(nxt, config, n)?; res = res .into_iter() - .filter_map(|(k, v)| { - if let Some(nxt_v) = nxt_res.get(&k) { - Some((k, v + nxt_v)) - } else { - None - } - }) + .filter_map(|(k, v)| nxt_res.get(&k).map(|nxt_v| (k, v + nxt_v))) .collect(); } res @@ -160,7 +149,7 @@ impl<'a> SessionTx<'a> { FtsExpr::Or(ls) => { let mut res: FxHashMap = FxHashMap::default(); for nxt in ls { - let nxt_res = self.fts_search_impl(nxt, config, filter_code, tokenizer, n)?; + let nxt_res = self.fts_search_impl(nxt, config, n)?; for (k, v) in nxt_res { if let Some(old_v) = res.get_mut(&k) { *old_v = (*old_v).max(v); @@ -199,10 +188,8 @@ impl<'a> SessionTx<'a> { if cur - p <= *distance { inner_coll.insert(p); } - } else { - if p - cur <= *distance { - inner_coll.insert(cur); - } + } else if p - cur <= *distance { + inner_coll.insert(cur); } } } @@ -230,9 +217,9 @@ impl<'a> SessionTx<'a> { .collect() } FtsExpr::Not(fst, snd) => { - let mut res = self.fts_search_impl(fst, config, filter_code, tokenizer, n)?; + let mut res = self.fts_search_impl(fst, config, n)?; for el in self - .fts_search_impl(snd, config, filter_code, tokenizer, n)? + .fts_search_impl(snd, config, n)? .keys() { res.remove(el); @@ -250,8 +237,8 @@ impl<'a> SessionTx<'a> { ) -> f64 { let tf = tf as f64; match config.score_kind { - FtsScoreKind::TF => tf * booster, - FtsScoreKind::TFIDF => { + FtsScoreKind::Tf => tf * booster, + FtsScoreKind::TfIdf => { let n_found_docs = n_found_docs as f64; let idf = (1.0 + (n_total as f64 - n_found_docs + 0.5) / (n_found_docs + 0.5)).ln(); tf * idf * booster @@ -271,13 +258,13 @@ impl<'a> SessionTx<'a> { if ast.is_empty() { return Ok(vec![]); } - let n = if config.score_kind == FtsScoreKind::TFIDF { + let n = if config.score_kind == FtsScoreKind::TfIdf { cache.get_n_for_relation(&config.base_handle, self)? } else { 0 }; let mut result: Vec<_> = self - .fts_search_impl(&ast, config, filter_code, tokenizer, n)? + .fts_search_impl(&ast, config, n)? .into_iter() .collect(); result.sort_by_key(|(_, score)| Reverse(OrderedFloat(*score))); diff --git a/cozo-core/src/fts/tokenizer/tokenizer_impl.rs b/cozo-core/src/fts/tokenizer/tokenizer_impl.rs index b8711c51..915015c1 100644 --- a/cozo-core/src/fts/tokenizer/tokenizer_impl.rs +++ b/cozo-core/src/fts/tokenizer/tokenizer_impl.rs @@ -1,7 +1,10 @@ +use smartstring::{LazyCompact, SmartString}; /// The tokenizer module contains all of the tools used to process /// text in `tantivy`. use std::borrow::{Borrow, BorrowMut}; +use std::iter; use std::ops::{Deref, DerefMut}; +use rustc_hash::FxHashSet; use crate::fts::tokenizer::empty_tokenizer::EmptyTokenizer; @@ -60,7 +63,10 @@ impl TextAnalyzer { /// /// When creating a `TextAnalyzer` from a `Tokenizer` alone, prefer using /// `TextAnalyzer::from(tokenizer)`. - pub(crate) fn new(tokenizer: T, token_filters: Vec) -> TextAnalyzer { + pub(crate) fn new( + tokenizer: T, + token_filters: Vec, + ) -> TextAnalyzer { TextAnalyzer { tokenizer: Box::new(tokenizer), token_filters, @@ -96,6 +102,25 @@ impl TextAnalyzer { } token_stream } + pub(crate) fn unique_ngrams(&self, text: &str, n: usize) -> FxHashSet>> { + let mut token_steam = self.token_stream(text); + let mut coll: Vec> = vec![]; + while let Some(token) = token_steam.next() { + coll.push(SmartString::from(token.text.as_str())); + } + + if n == 1 { + coll.iter().map(|x| vec![x.clone()]).collect() + } else if n >= coll.len() { + iter::once(coll).collect() + } else { + let mut ret = FxHashSet::default(); + for chunk in coll.windows(n) { + ret.insert(chunk.to_vec()); + } + ret + } + } } impl Clone for TextAnalyzer { diff --git a/cozo-core/src/parse/fts.rs b/cozo-core/src/parse/fts.rs index 58a65f47..d8863b6e 100644 --- a/cozo-core/src/parse/fts.rs +++ b/cozo-core/src/parse/fts.rs @@ -21,7 +21,7 @@ pub(crate) fn parse_fts_query(q: &str) -> Result { let pairs = pairs.next().unwrap().into_inner(); let pairs: Vec<_> = pairs .filter(|r| r.as_rule() != Rule::EOI) - .map(|r| parse_fts_expr(r)) + .map(parse_fts_expr) .try_collect()?; Ok(if pairs.len() == 1 { pairs.into_iter().next().unwrap() @@ -157,7 +157,7 @@ mod tests { assert!(matches!(res, FtsExpr::Not(_, _))); let src = " NEAR(abc def \"ghi\"^22.8) "; let res = parse_fts_query(src).unwrap().flatten(); - assert!(matches!(res, FtsExpr::Near(FtsNear{distance: 10, ..}))); + assert!(matches!(res, FtsExpr::Near(FtsNear { distance: 10, .. }))); println!("{:#?}", res); } } diff --git a/cozo-core/src/parse/sys.rs b/cozo-core/src/parse/sys.rs index e4b67693..b1d78ce4 100644 --- a/cozo-core/src/parse/sys.rs +++ b/cozo-core/src/parse/sys.rs @@ -11,6 +11,7 @@ use std::sync::Arc; use itertools::Itertools; use miette::{bail, ensure, miette, Diagnostic, Result}; +use ordered_float::OrderedFloat; use smartstring::{LazyCompact, SmartString}; use thiserror::Error; @@ -41,8 +42,10 @@ pub(crate) enum SysOp { CreateIndex(Symbol, Symbol, Vec), CreateVectorIndex(HnswIndexConfig), CreateFtsIndex(FtsIndexConfig), + CreateMinHashLshIndex(MinHashLshConfig), RemoveIndex(Symbol, Symbol), } + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub(crate) struct FtsIndexConfig { pub(crate) base_relation: SmartString, @@ -52,6 +55,20 @@ pub(crate) struct FtsIndexConfig { pub(crate) filters: Vec, } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct MinHashLshConfig { + pub(crate) base_relation: SmartString, + pub(crate) index_name: SmartString, + pub(crate) extractor: String, + pub(crate) tokenizer: TokenizerConfig, + pub(crate) filters: Vec, + pub(crate) n_gram: usize, + pub(crate) n_perm: usize, + pub(crate) false_positive_weight: OrderedFloat, + pub(crate) false_negative_weight: OrderedFloat, + pub(crate) target_threshold: OrderedFloat, +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub(crate) struct HnswIndexConfig { pub(crate) base_relation: SmartString, @@ -186,7 +203,175 @@ pub(crate) fn parse_sys( SysOp::SetTriggers(rel, puts, rms, replaces) } Rule::lsh_idx_op => { - todo!() + let inner = inner.into_inner().next().unwrap(); + match inner.as_rule() { + Rule::index_create_adv => { + let mut inner = inner.into_inner(); + let rel = inner.next().unwrap(); + let name = inner.next().unwrap(); + let mut filters = vec![]; + let mut tokenizer = TokenizerConfig { + name: Default::default(), + args: Default::default(), + }; + let mut extractor = "".to_string(); + let mut n_gram = 1; + let mut n_perm = 200; + let mut target_threshold = 0.9; + let mut false_positive_weight = 1.0; + let mut false_negative_weight = 1.0; + for opt_pair in inner { + let mut opt_inner = opt_pair.into_inner(); + let opt_name = opt_inner.next().unwrap(); + let opt_val = opt_inner.next().unwrap(); + match opt_name.as_str() { + "false_positive_weight" => { + let mut expr = build_expr(opt_val, param_pool)?; + expr.partial_eval()?; + let v = expr.eval_to_const()?; + false_positive_weight = v.get_float().ok_or_else(|| { + miette!("false_positive_weight must be a float") + })?; + } + "false_negative_weight" => { + let mut expr = build_expr(opt_val, param_pool)?; + expr.partial_eval()?; + let v = expr.eval_to_const()?; + false_negative_weight = v.get_float().ok_or_else(|| { + miette!("false_negative_weight must be a float") + })?; + } + "n_gram" => { + let mut expr = build_expr(opt_val, param_pool)?; + expr.partial_eval()?; + let v = expr.eval_to_const()?; + n_gram = v + .get_int() + .ok_or_else(|| miette!("n_gram must be an integer"))? + as usize; + } + "n_perm" => { + let mut expr = build_expr(opt_val, param_pool)?; + expr.partial_eval()?; + let v = expr.eval_to_const()?; + n_perm = v + .get_int() + .ok_or_else(|| miette!("n_perm must be an integer"))? + as usize; + } + "target_threshold" => { + let mut expr = build_expr(opt_val, param_pool)?; + expr.partial_eval()?; + let v = expr.eval_to_const()?; + target_threshold = v + .get_float() + .ok_or_else(|| miette!("target_threshold must be a float"))?; + } + "extractor" => { + let mut ex = build_expr(opt_val, param_pool)?; + ex.partial_eval()?; + extractor = ex.to_string(); + } + "tokenizer" => { + let mut expr = build_expr(opt_val, param_pool)?; + expr.partial_eval()?; + match expr { + Expr::UnboundApply { op, args, .. } => { + let mut targs = vec![]; + for arg in args.iter() { + let v = arg.clone().eval_to_const()?; + targs.push(v); + } + tokenizer.name = op; + tokenizer.args = targs; + } + Expr::Binding { var, .. } => { + tokenizer.name = var.name; + tokenizer.args = vec![]; + } + _ => bail!("Tokenizer must be a symbol or a call for an existing tokenizer"), + } + } + "filters" => { + let mut expr = build_expr(opt_val, param_pool)?; + expr.partial_eval()?; + match expr { + Expr::Apply { op, args, .. } => { + if op.name != "LIST" { + bail!("Filters must be a list of filters"); + } + for arg in args.iter() { + match arg { + Expr::UnboundApply { op, args, .. } => { + let mut targs = vec![]; + for arg in args.iter() { + let v = arg.clone().eval_to_const()?; + targs.push(v); + } + filters.push(TokenizerConfig { + name: op.clone(), + args: targs, + }) + } + Expr::Binding { var, .. } => { + filters.push(TokenizerConfig { + name: var.name.clone(), + args: vec![], + }) + } + _ => bail!("Tokenizer must be a symbol or a call for an existing tokenizer"), + } + } + } + _ => bail!("Filters must be a list of filters"), + } + } + _ => bail!("Unknown option {} for FTS index", opt_name.as_str()), + } + } + ensure!( + false_positive_weight > 0., + "false_positive_weight must be positive" + ); + ensure!( + false_negative_weight > 0., + "false_negative_weight must be positive" + ); + ensure!(n_gram > 0, "n_gram must be positive"); + ensure!(n_perm > 0, "n_perm must be positive"); + ensure!( + target_threshold > 0. && target_threshold < 1., + "target_threshold must be between 0 and 1" + ); + let total_weights = false_positive_weight + false_negative_weight; + false_positive_weight /= total_weights; + false_negative_weight /= total_weights; + + let config = MinHashLshConfig { + base_relation: SmartString::from(rel.as_str()), + index_name: SmartString::from(name.as_str()), + extractor, + tokenizer, + filters, + n_gram, + n_perm, + false_positive_weight: false_positive_weight.into(), + false_negative_weight: false_negative_weight.into(), + target_threshold: target_threshold.into(), + }; + SysOp::CreateMinHashLshIndex(config) + } + Rule::index_drop => { + let mut inner = inner.into_inner(); + let rel = inner.next().unwrap(); + let name = inner.next().unwrap(); + SysOp::RemoveIndex( + Symbol::new(rel.as_str(), rel.extract_span()), + Symbol::new(name.as_str(), name.extract_span()), + ) + } + r => unreachable!("{:?}", r), + } } Rule::fts_idx_op => { let inner = inner.into_inner().next().unwrap(); diff --git a/cozo-core/src/query/compile.rs b/cozo-core/src/query/compile.rs index 68a91e62..f7961cc1 100644 --- a/cozo-core/src/query/compile.rs +++ b/cozo-core/src/query/compile.rs @@ -539,6 +539,40 @@ impl<'a> SessionTx<'a> { ret = ret.filter(Expr::build_and(post_filters, s.span))?; } } + MagicAtom::LshSearch(s) => { + debug_assert!( + seen_variables.contains(&s.query), + "FTS search query must be bound" + ); + let mut own_bindings = vec![]; + let mut post_filters = vec![]; + for var in s.all_bindings() { + if seen_variables.contains(var) { + let rk = gen_symb(var.span); + post_filters.push(Expr::build_equate( + vec![ + Expr::Binding { + var: var.clone(), + tuple_pos: None, + }, + Expr::Binding { + var: rk.clone(), + tuple_pos: None, + }, + ], + var.span, + )); + own_bindings.push(rk); + } else { + seen_variables.insert(var.clone()); + own_bindings.push(var.clone()); + } + } + ret = ret.lsh_search(s.clone(), own_bindings)?; + if !post_filters.is_empty() { + ret = ret.filter(Expr::build_and(post_filters, s.span))?; + } + } MagicAtom::Unification(u) => { if seen_variables.contains(&u.binding) { let expr = if u.one_many_unif { diff --git a/cozo-core/src/query/magic.rs b/cozo-core/src/query/magic.rs index b0996cec..ec9cff90 100644 --- a/cozo-core/src/query/magic.rs +++ b/cozo-core/src/query/magic.rs @@ -180,6 +180,10 @@ fn magic_rewrite_ruleset( seen_bindings.extend(s.all_bindings().cloned()); collected_atoms.push(MagicAtom::FtsSearch(s)); } + MagicAtom::LshSearch(s) => { + seen_bindings.extend(s.all_bindings().cloned()); + collected_atoms.push(MagicAtom::LshSearch(s)); + } MagicAtom::Rule(r_app) => { if r_app.name.has_bound_adornment() { // we are guaranteed to have a magic rule application @@ -539,6 +543,14 @@ impl NormalFormAtom { } MagicAtom::FtsSearch(s.clone()) } + NormalFormAtom::LshSearch(s) => { + for arg in s.all_bindings() { + if !seen_bindings.contains(arg) { + seen_bindings.insert(arg.clone()); + } + } + MagicAtom::LshSearch(s.clone()) + } NormalFormAtom::Predicate(p) => { // predicate cannot introduce new bindings diff --git a/cozo-core/src/query/ra.rs b/cozo-core/src/query/ra.rs index 9f6fe90a..aa699975 100644 --- a/cozo-core/src/query/ra.rs +++ b/cozo-core/src/query/ra.rs @@ -23,6 +23,7 @@ use crate::data::symb::Symbol; use crate::data::tuple::{Tuple, TupleIter}; use crate::data::value::{DataValue, ValidityTs}; use crate::parse::SourceSpan; +use crate::runtime::minhash_lsh::LshSearch; use crate::runtime::relation::RelationHandle; use crate::runtime::temp_store::EpochStore; use crate::runtime::transact::SessionTx; @@ -40,6 +41,7 @@ pub(crate) enum RelAlgebra { Unification(UnificationRA), HnswSearch(HnswSearchRA), FtsSearch(FtsSearchRA), + LshSearch(LshSearchRA), } impl RelAlgebra { @@ -56,6 +58,7 @@ impl RelAlgebra { RelAlgebra::StoredWithValidity(i) => i.span, RelAlgebra::HnswSearch(i) => i.hnsw_search.span, RelAlgebra::FtsSearch(i) => i.fts_search.span, + RelAlgebra::LshSearch(i) => i.lsh_search.span, } } } @@ -290,6 +293,11 @@ impl Debug for RelAlgebra { .field(&bindings) .field(&s.fts_search.idx_handle.name) .finish(), + RelAlgebra::LshSearch(s) => f + .debug_tuple("LshSearch") + .field(&bindings) + .field(&s.lsh_search.idx_handle.name) + .finish(), RelAlgebra::StoredWithValidity(r) => f .debug_tuple("StoredWithValidity") .field(&bindings) @@ -362,6 +370,9 @@ impl RelAlgebra { RelAlgebra::FtsSearch(s) => { s.fill_binding_indices_and_compile()?; } + RelAlgebra::LshSearch(s) => { + s.fill_binding_indices_and_compile()?; + } RelAlgebra::StoredWithValidity(v) => { v.fill_binding_indices_and_compile()?; } @@ -459,7 +470,8 @@ impl RelAlgebra { | RelAlgebra::NegJoin(_) | RelAlgebra::Unification(_) | RelAlgebra::HnswSearch(_) - | RelAlgebra::FtsSearch(_)) => { + | RelAlgebra::FtsSearch(_) + | RelAlgebra::LshSearch(_)) => { let span = filter.span(); RelAlgebra::Filter(FilteredRA { parent: Box::new(s), @@ -624,6 +636,18 @@ impl RelAlgebra { own_bindings, })) } + pub(crate) fn lsh_search( + self, + fts_search: LshSearch, + own_bindings: Vec, + ) -> Result { + Ok(Self::LshSearch(LshSearchRA { + parent: Box::new(self), + lsh_search: fts_search, + filter_bytecode: None, + own_bindings, + })) + } pub(crate) fn join( self, right: RelAlgebra, @@ -875,6 +899,78 @@ pub(crate) struct HnswSearchRA { pub(crate) own_bindings: Vec, } +#[derive(Debug)] +pub(crate) struct LshSearchRA { + pub(crate) parent: Box, + pub(crate) lsh_search: LshSearch, + pub(crate) filter_bytecode: Option<(Vec, SourceSpan)>, + pub(crate) own_bindings: Vec, +} + + +impl LshSearchRA { + fn fill_binding_indices_and_compile(&mut self) -> Result<()> { + self.parent.fill_binding_indices_and_compile()?; + if self.lsh_search.filter.is_some() { + let bindings: BTreeMap<_, _> = self + .own_bindings + .iter() + .cloned() + .enumerate() + .map(|(a, b)| (b, a)) + .collect(); + let filter = self.lsh_search.filter.as_mut().unwrap(); + filter.fill_binding_indices(&bindings)?; + self.filter_bytecode = Some((filter.compile()?, filter.span())); + } + Ok(()) + } + fn iter<'a>( + &'a self, + tx: &'a SessionTx<'_>, + delta_rule: Option<&MagicSymbol>, + stores: &'a BTreeMap, + ) -> Result> { + let bindings = self.parent.bindings_after_eliminate(); + let mut bind_idx = usize::MAX; + for (i, b) in bindings.iter().enumerate() { + if *b == self.lsh_search.query { + bind_idx = i; + break; + } + } + let config = self.lsh_search.clone(); + let filter_code = self.filter_bytecode.clone(); + let mut stack = vec![]; + + let it = self + .parent + .iter(tx, delta_rule, stores)? + .map_ok(move |tuple| -> Result<_> { + let q = match tuple[bind_idx].clone() { + DataValue::List(l) => l, + d => bail!("Expected list for LSH search, got {:?}", d), + }; + + let res = tx.lsh_search( + &q, + &config, + &mut stack, + &filter_code, + )?; + Ok(res.into_iter().map(move |t| { + let mut r = tuple.clone(); + r.extend(t); + r + })) + }) + .map(flatten_err) + .flatten_ok(); + Ok(Box::new(it)) + } +} + + #[derive(Debug)] pub(crate) struct FtsSearchRA { pub(crate) parent: Box, @@ -1745,6 +1841,7 @@ impl RelAlgebra { RelAlgebra::Unification(r) => r.do_eliminate_temp_vars(used), RelAlgebra::HnswSearch(_) => Ok(()), RelAlgebra::FtsSearch(_) => Ok(()), + RelAlgebra::LshSearch(_) => Ok(()), } } @@ -1761,6 +1858,7 @@ impl RelAlgebra { RelAlgebra::Unification(u) => Some(&u.to_eliminate), RelAlgebra::HnswSearch(_) => None, RelAlgebra::FtsSearch(_) => None, + RelAlgebra::LshSearch(_) => None, } } @@ -1800,6 +1898,11 @@ impl RelAlgebra { bindings.extend_from_slice(&s.own_bindings); bindings } + RelAlgebra::LshSearch(s) => { + let mut bindings = s.parent.bindings_after_eliminate(); + bindings.extend_from_slice(&s.own_bindings); + bindings + } } } pub(crate) fn iter<'a>( @@ -1820,6 +1923,7 @@ impl RelAlgebra { RelAlgebra::Unification(r) => r.iter(tx, delta_rule, stores), RelAlgebra::HnswSearch(r) => r.iter(tx, delta_rule, stores), RelAlgebra::FtsSearch(r) => r.iter(tx, delta_rule, stores), + RelAlgebra::LshSearch(r) => r.iter(tx, delta_rule, stores), } } } @@ -2001,6 +2105,7 @@ impl InnerJoin { } RelAlgebra::HnswSearch(_) => "hnsw_search_join", RelAlgebra::FtsSearch(_) => "fts_search_join", + RelAlgebra::LshSearch(_) => "lsh_search_join", RelAlgebra::StoredWithValidity(_) => { let join_indices = self .joiner @@ -2113,7 +2218,8 @@ impl InnerJoin { | RelAlgebra::Filter(_) | RelAlgebra::Unification(_) | RelAlgebra::HnswSearch(_) - | RelAlgebra::FtsSearch(_) => { + | RelAlgebra::FtsSearch(_) + | RelAlgebra::LshSearch(_) => { self.materialized_join(tx, eliminate_indices, delta_rule, stores) } RelAlgebra::Reorder(_) => { diff --git a/cozo-core/src/query/reorder.rs b/cozo-core/src/query/reorder.rs index f0c05ecd..cf805943 100644 --- a/cozo-core/src/query/reorder.rs +++ b/cozo-core/src/query/reorder.rs @@ -88,6 +88,14 @@ impl NormalFormInlineRule { pending.push(NormalFormAtom::FtsSearch(s)); } } + NormalFormAtom::LshSearch(s) => { + if seen_variables.contains(&s.query) { + seen_variables.extend(s.all_bindings().cloned()); + round_1_collected.push(NormalFormAtom::LshSearch(s)); + } else { + pending.push(NormalFormAtom::LshSearch(s)); + } + } } } @@ -124,6 +132,10 @@ impl NormalFormInlineRule { seen_variables.extend(s.all_bindings().cloned()); collected.push(NormalFormAtom::FtsSearch(s)); } + NormalFormAtom::LshSearch(s) => { + seen_variables.extend(s.all_bindings().cloned()); + collected.push(NormalFormAtom::LshSearch(s)); + } } for atom in last_pending.iter() { match atom { @@ -158,6 +170,14 @@ impl NormalFormInlineRule { pending.push(NormalFormAtom::FtsSearch(s.clone())); } } + NormalFormAtom::LshSearch(s) => { + if seen_variables.contains(&s.query) { + seen_variables.extend(s.all_bindings().cloned()); + collected.push(NormalFormAtom::LshSearch(s.clone())); + } else { + pending.push(NormalFormAtom::LshSearch(s.clone())); + } + } NormalFormAtom::Predicate(p) => { if p.bindings()?.is_subset(&seen_variables) { collected.push(NormalFormAtom::Predicate(p.clone())); @@ -206,6 +226,9 @@ impl NormalFormInlineRule { NormalFormAtom::FtsSearch(s) => { bail!(UnboundVariable(s.span)) } + NormalFormAtom::LshSearch(s) => { + bail!(UnboundVariable(s.span)) + } } } } diff --git a/cozo-core/src/query/stratify.rs b/cozo-core/src/query/stratify.rs index 53f2eeb1..5e7e2c90 100644 --- a/cozo-core/src/query/stratify.rs +++ b/cozo-core/src/query/stratify.rs @@ -31,7 +31,8 @@ impl NormalFormAtom { | NormalFormAtom::Predicate(_) | NormalFormAtom::Unification(_) | NormalFormAtom::HnswSearch(_) - | NormalFormAtom::FtsSearch(_) => Default::default(), + | NormalFormAtom::FtsSearch(_) + | NormalFormAtom::LshSearch(_) => Default::default(), NormalFormAtom::Rule(r) => BTreeMap::from([(&r.name, false)]), NormalFormAtom::NegatedRule(r) => BTreeMap::from([(&r.name, true)]), } diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index 196ede06..1ddf6a11 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -43,7 +43,7 @@ use crate::fts::TokenizerCache; use crate::parse::sys::SysOp; use crate::parse::{parse_script, CozoScript, SourceSpan}; use crate::query::compile::{CompiledProgram, CompiledRule, CompiledRuleSet}; -use crate::query::ra::{FilteredRA, FtsSearchRA, HnswSearchRA, InnerJoin, NegJoin, RelAlgebra, ReorderRA, StoredRA, StoredWithValidityRA, TempStoreRA, UnificationRA}; +use crate::query::ra::{FilteredRA, FtsSearchRA, HnswSearchRA, InnerJoin, LshSearchRA, NegJoin, RelAlgebra, ReorderRA, StoredRA, StoredWithValidityRA, TempStoreRA, UnificationRA}; #[allow(unused_imports)] use crate::runtime::callback::{ CallbackCollector, CallbackDeclaration, CallbackOp, EventCallbackRegistry, @@ -1082,6 +1082,18 @@ impl<'s, S: Storage<'s>> Db { .map(|f| f.to_string()) .collect_vec()), ), + RelAlgebra::LshSearch(LshSearchRA { + lsh_search, .. + }) => ( + "lsh_index", + json!(format!(":{}", lsh_search.query.name)), + json!(lsh_search.query.name), + json!(lsh_search + .filter + .iter() + .map(|f| f.to_string()) + .collect_vec()), + ), }; ret_for_relation.push(json!({ STRATUM: stratum, @@ -1217,6 +1229,20 @@ impl<'s, S: Storage<'s>> Db { vec![vec![DataValue::from(OK_STR)]], )) } + SysOp::CreateMinHashLshIndex(config) => { + let lock = self + .obtain_relation_locks(iter::once(&config.base_relation)) + .pop() + .unwrap(); + let _guard = lock.write().unwrap(); + let mut tx = self.transact_write()?; + tx.create_minhash_lsh_index(config)?; + tx.commit_tx()?; + Ok(NamedRows::new( + vec![STATUS_STR.to_string()], + vec![vec![DataValue::from(OK_STR)]], + )) + } SysOp::RemoveIndex(rel_name, idx_name) => { let lock = self .obtain_relation_locks(iter::once(&rel_name.name)) diff --git a/cozo-core/src/runtime/minhash_lsh.rs b/cozo-core/src/runtime/minhash_lsh.rs new file mode 100644 index 00000000..8e9bfcf1 --- /dev/null +++ b/cozo-core/src/runtime/minhash_lsh.rs @@ -0,0 +1,374 @@ +/* + * Copyright 2023, The Cozo Project Authors. + * + * This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. + * If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +// Some ideas are from https://github.com/schelterlabs/rust-minhash + +use crate::data::expr::{eval_bytecode, Bytecode, eval_bytecode_pred}; +use crate::data::tuple::Tuple; +use crate::fts::tokenizer::TextAnalyzer; +use crate::fts::TokenizerConfig; +use crate::runtime::relation::RelationHandle; +use crate::runtime::transact::SessionTx; +use crate::{DataValue, Expr, SourceSpan, Symbol}; +use miette::{bail, miette, Result}; +use quadrature::integrate; +use rand::{thread_rng, RngCore}; +use rustc_hash::FxHashMap; +use smartstring::{LazyCompact, SmartString}; +use std::cmp::min; +use std::hash::{Hash, Hasher}; +use twox_hash::XxHash32; + +impl<'a> SessionTx<'a> { + pub(crate) fn del_lsh_index_item( + &mut self, + tuple: &[DataValue], + bytes: Option>, + idx_handle: &RelationHandle, + inv_idx_handle: &RelationHandle, + manifest: &MinHashLshIndexManifest, + ) -> Result<()> { + let bytes = match bytes { + None => { + if let Some(mut found) = inv_idx_handle.get_val_only(self, tuple)? { + let inv_key = inv_idx_handle.encode_key_for_store(tuple, Default::default())?; + self.store_tx.del(&inv_key)?; + match found.pop() { + Some(DataValue::Bytes(b)) => b, + _ => unreachable!(), + } + } else { + return Ok(()); + } + } + Some(b) => b, + }; + + let mut key = Vec::with_capacity(bytes.len() + 2); + key.push(DataValue::Bot); + key.push(DataValue::Bot); + key.extend_from_slice(tuple); + for (i, chunk) in bytes + .chunks_exact(manifest.r * std::mem::size_of::()) + .enumerate() + { + key[0] = DataValue::from(i as i64); + key[1] = DataValue::Bytes(chunk.to_vec()); + let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?; + self.store_tx.del(&key_bytes)?; + } + Ok(()) + } + pub(crate) fn put_lsh_index_item( + &mut self, + tuple: &[DataValue], + extractor: &[Bytecode], + stack: &mut Vec, + tokenizer: &TextAnalyzer, + rel_handle: &RelationHandle, + idx_handle: &RelationHandle, + inv_idx_handle: &RelationHandle, + manifest: &MinHashLshIndexManifest, + hash_perms: &HashPermutations, + ) -> Result<()> { + if let Some(mut found) = + inv_idx_handle.get_val_only(self, &tuple[..rel_handle.metadata.keys.len()])? + { + let bytes = match found.pop() { + Some(DataValue::Bytes(b)) => b, + _ => unreachable!(), + }; + self.del_lsh_index_item(tuple, Some(bytes), idx_handle, inv_idx_handle, manifest)?; + } + let to_index = eval_bytecode(extractor, tuple, stack)?; + let min_hash = match to_index { + DataValue::Null => return Ok(()), + DataValue::List(l) => HashValues::new(l.iter(), hash_perms), + DataValue::Str(s) => { + let n_grams = tokenizer.unique_ngrams(&s, manifest.n_gram); + HashValues::new(n_grams.iter(), hash_perms) + } + _ => bail!("Cannot put value {:?} into a LSH index", to_index), + }; + let bytes = min_hash.get_bytes(); + let inv_key_part = &tuple[..rel_handle.metadata.keys.len()]; + let inv_val_part = vec![DataValue::Bytes(bytes.to_vec())]; + let inv_key = inv_idx_handle.encode_key_for_store(inv_key_part, Default::default())?; + let inv_val = + inv_idx_handle.encode_val_only_for_store(&inv_val_part, Default::default())?; + self.store_tx.put(&inv_key, &inv_val)?; + + let mut key = Vec::with_capacity(bytes.len() + 2); + key.push(DataValue::Bot); + key.push(DataValue::Bot); + key.extend_from_slice(inv_key_part); + let chunk_size = manifest.r * std::mem::size_of::(); + for i in 0..manifest.b { + let byte_range = &bytes[i * chunk_size..(i + 1) * chunk_size]; + key[0] = DataValue::from(i as i64); + key[1] = DataValue::Bytes(byte_range.to_vec()); + let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?; + self.store_tx.put(&key_bytes, &[])?; + } + + Ok(()) + } + pub(crate) fn lsh_search( + &self, + tuple: &[DataValue], + config: &LshSearch, + stack: &mut Vec, + filter_code: &Option<(Vec, SourceSpan)>, + ) -> Result> { + let bytes = if let Some(mut found) = config + .inv_idx_handle + .get_val_only(self, &tuple[..config.base_handle.metadata.keys.len()])? + { + match found.pop() { + Some(DataValue::Bytes(b)) => b, + _ => unreachable!(), + } + } else { + return Ok(vec![]); + }; + let chunk_size = config.manifest.r * std::mem::size_of::(); + let mut key_prefix = Vec::with_capacity(2); + let mut found_tuples: FxHashMap<_, usize> = FxHashMap::default(); + for (i, chunk) in bytes.chunks_exact(chunk_size).enumerate() { + key_prefix.clear(); + key_prefix.push(DataValue::from(i as i64)); + key_prefix.push(DataValue::Bytes(chunk.to_vec())); + for ks in config.idx_handle.scan_prefix(self, &key_prefix) { + let ks = ks?; + let key_part = &ks[2..]; + if key_part != tuple { + let found = found_tuples.entry(key_part.to_vec()).or_default(); + *found += 1; + } + } + } + let mut ret = vec![]; + for (key, count) in found_tuples { + let similarity = count as f64 / config.manifest.r as f64; + if similarity < config.min_similarity { + continue; + } + let mut orig_tuple = config + .base_handle + .get(self, &key)? + .ok_or_else(|| miette!("Tuple not found in base LSH relation"))?; + if let Some((filter_code, span)) = filter_code { + if !eval_bytecode_pred(filter_code, &orig_tuple, stack, *span)? { + continue; + } + } + if config.bind_similarity.is_some() { + orig_tuple.push(DataValue::from(similarity)); + } + ret.push(orig_tuple); + if let Some(k) = config.k { + if ret.len() >= k { + break; + } + } + } + Ok(ret) + } +} + +#[derive(Clone, Debug)] +pub(crate) struct LshSearch { + pub(crate) base_handle: RelationHandle, + pub(crate) idx_handle: RelationHandle, + pub(crate) inv_idx_handle: RelationHandle, + pub(crate) manifest: MinHashLshIndexManifest, + pub(crate) bindings: Vec, + pub(crate) k: Option, + pub(crate) query: Symbol, + pub(crate) bind_similarity: Option, + pub(crate) min_similarity: f64, + // pub(crate) lax_mode: bool, + pub(crate) filter: Option, + pub(crate) span: SourceSpan, +} + +impl LshSearch { + pub(crate) fn all_bindings(&self) -> impl Iterator { + self.bindings.iter().chain(self.bind_similarity.iter()) + } +} + +pub(crate) struct HashValues(pub(crate) Vec); +pub(crate) struct HashPermutations(pub(crate) Vec); + +#[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)] +pub(crate) struct MinHashLshIndexManifest { + pub(crate) base_relation: SmartString, + pub(crate) index_name: SmartString, + pub(crate) extractor: String, + pub(crate) n_gram: usize, + pub(crate) tokenizer: TokenizerConfig, + pub(crate) filters: Vec, + + pub(crate) num_perm: usize, + pub(crate) b: usize, + pub(crate) r: usize, + pub(crate) threshold: f64, + pub(crate) perms: Vec, +} + +impl MinHashLshIndexManifest { + pub(crate) fn get_hash_perms(&self) -> HashPermutations { + HashPermutations::from_bytes(&self.perms) + } +} + +#[derive(Clone, Debug)] +pub(crate) struct LshParams { + pub b: usize, + pub r: usize, +} + +#[derive(Clone)] +pub(crate) struct Weights(pub(crate) f64, pub(crate) f64); + +const _ALLOWED_INTEGRATE_ERR: f64 = 0.001; + +// code is mostly from https://github.com/schelterlabs/rust-minhash/blob/81ea3fec24fd888a330a71b6932623643346b591/src/minhash_lsh.rs +impl LshParams { + pub fn find_optimal_params(threshold: f64, num_perm: usize, weights: &Weights) -> LshParams { + let Weights(false_positive_weight, false_negative_weight) = weights; + let mut min_error = f64::INFINITY; + let mut opt = LshParams { b: 0, r: 0 }; + for b in 1..num_perm + 1 { + let max_r = num_perm / b; + for r in 1..max_r + 1 { + let false_pos = LshParams::false_positive_probability(threshold, b, r); + let false_neg = LshParams::false_negative_probability(threshold, b, r); + let error = false_pos * false_positive_weight + false_neg * false_negative_weight; + if error < min_error { + min_error = error; + opt = LshParams { b, r }; + } + } + } + opt + } + + fn false_positive_probability(threshold: f64, b: usize, r: usize) -> f64 { + let _probability = + |s| -> f64 { 1. - f64::powf(1. - f64::powi(s, r as i32), b as f64) }; + integrate(_probability, 0.0, threshold, _ALLOWED_INTEGRATE_ERR).integral + } + + fn false_negative_probability(threshold: f64, b: usize, r: usize) -> f64 { + let _probability = + |s| -> f64 { 1. - (1. - f64::powf(1. - f64::powi(s, r as i32), b as f64)) }; + integrate(_probability, threshold, 1.0, _ALLOWED_INTEGRATE_ERR).integral + } +} + +impl HashPermutations { + pub(crate) fn new(n_perms: usize) -> Self { + let mut rng = thread_rng(); + let mut perms = Vec::with_capacity(n_perms); + for _ in 0..n_perms { + perms.push(rng.next_u32()); + } + Self(perms) + } + pub(crate) fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts( + self.0.as_ptr() as *const u8, + self.0.len() * std::mem::size_of::(), + ) + } + } + // this is the inverse of `as_bytes` + pub(crate) fn from_bytes(bytes: &[u8]) -> Self { + unsafe { + let ptr = bytes.as_ptr() as *const u32; + let len = bytes.len() / std::mem::size_of::(); + let perms = std::slice::from_raw_parts(ptr, len); + Self(perms.to_vec()) + } + } +} + +impl HashValues { + pub(crate) fn new(values: impl Iterator, perms: &HashPermutations) -> Self { + let mut ret = Self::init(perms); + ret.update(values, perms); + ret + } + pub(crate) fn init(perms: &HashPermutations) -> Self { + Self(vec![u32::MAX; perms.0.len()]) + } + pub(crate) fn update( + &mut self, + values: impl Iterator, + perms: &HashPermutations, + ) { + for v in values { + for (i, seed) in perms.0.iter().enumerate() { + let mut hasher = XxHash32::with_seed(*seed); + v.hash(&mut hasher); + let hash = hasher.finish() as u32; + self.0[i] = min(self.0[i], hash); + } + } + } + #[cfg(test)] + pub(crate) fn jaccard(&self, other_minhash: &Self) -> f32 { + let matches = self + .0 + .iter() + .zip_eq(&other_minhash.0) + .filter(|(left, right)| left == right) + .count(); + let result = matches as f32 / self.0.len() as f32; + result + } + pub(crate) fn get_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts( + self.0.as_ptr() as *const u8, + self.0.len() * std::mem::size_of::(), + ) + } + } + // pub(crate) fn get_byte_chunks(&self, n_chunks: usize) -> impl Iterator { + // let chunk_size = self.0.len() * std::mem::size_of::() / n_chunks; + // self.get_bytes().chunks_exact(chunk_size) + // } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_minhash() { + let perms = HashPermutations::new(20000); + let mut m1 = HashValues::new([1, 2, 3, 4, 5, 6].iter(), &perms); + let mut m2 = HashValues::new([4, 3, 2, 1, 5, 6].iter(), &perms); + assert_eq!(m1.0, m2.0); + // println!("{:?}", &m1.0); + // println!("{:?}", &m2.0); + assert_eq!(m1.jaccard(&m2), 1.0); + m1.update([7, 8, 9].iter(), &perms); + assert!(m1.jaccard(&m2) < 1.0); + println!("{:?}", m1.jaccard(&m2)); + m2.update([17, 18, 19].iter(), &perms); + assert!(m1.jaccard(&m2) < 1.0); + println!("{:?}", m1.jaccard(&m2)); + // println!("{:?}", m2.get_byte_chunks(2).collect_vec()); + assert_eq!(perms.0, HashPermutations::from_bytes(perms.as_bytes()).0); + } +} diff --git a/cozo-core/src/runtime/mod.rs b/cozo-core/src/runtime/mod.rs index 91dde669..e8e02eac 100644 --- a/cozo-core/src/runtime/mod.rs +++ b/cozo-core/src/runtime/mod.rs @@ -13,5 +13,6 @@ pub(crate) mod relation; pub(crate) mod temp_store; pub(crate) mod transact; pub(crate) mod hnsw; +pub(crate) mod minhash_lsh; #[cfg(test)] mod tests; diff --git a/cozo-core/src/runtime/relation.rs b/cozo-core/src/runtime/relation.rs index 6d94f21d..1bd6dd87 100644 --- a/cozo-core/src/runtime/relation.rs +++ b/cozo-core/src/runtime/relation.rs @@ -26,10 +26,11 @@ use crate::data::tuple::{decode_tuple_from_key, Tuple, TupleT, ENCODED_KEY_MIN_L use crate::data::value::{DataValue, ValidityTs}; use crate::fts::FtsIndexManifest; use crate::parse::expr::build_expr; -use crate::parse::sys::{FtsIndexConfig, HnswIndexConfig}; +use crate::parse::sys::{FtsIndexConfig, HnswIndexConfig, MinHashLshConfig}; use crate::parse::{CozoScriptParser, Rule, SourceSpan}; use crate::query::compile::IndexPositionUse; use crate::runtime::hnsw::HnswIndexManifest; +use crate::runtime::minhash_lsh::{HashPermutations, LshParams, MinHashLshIndexManifest, Weights}; use crate::runtime::transact::SessionTx; use crate::{NamedRows, StoreTx}; @@ -83,6 +84,10 @@ pub(crate) struct RelationHandle { pub(crate) hnsw_indices: BTreeMap, (RelationHandle, HnswIndexManifest)>, pub(crate) fts_indices: BTreeMap, (RelationHandle, FtsIndexManifest)>, + pub(crate) lsh_indices: BTreeMap< + SmartString, + (RelationHandle, RelationHandle, MinHashLshIndexManifest), + >, } impl RelationHandle { @@ -90,9 +95,13 @@ impl RelationHandle { self.indices.contains_key(index_name) || self.hnsw_indices.contains_key(index_name) || self.fts_indices.contains_key(index_name) + || self.lsh_indices.contains_key(index_name) } pub(crate) fn has_no_index(&self) -> bool { - self.indices.is_empty() && self.hnsw_indices.is_empty() && self.fts_indices.is_empty() + self.indices.is_empty() + && self.hnsw_indices.is_empty() + && self.fts_indices.is_empty() + && self.lsh_indices.is_empty() } } @@ -254,10 +263,7 @@ impl RelationHandle { } Ok(ret) } - pub(crate) fn encode_partial_key_for_store( - &self, - tuple: &[DataValue], - ) -> Vec { + pub(crate) fn encode_partial_key_for_store(&self, tuple: &[DataValue]) -> Vec { let mut ret = self.encode_key_prefix(tuple.len()); for val in tuple { ret.encode_datavalue(val); @@ -392,6 +398,25 @@ impl RelationHandle { } } + pub(crate) fn get_val_only( + &self, + tx: &SessionTx<'_>, + key: &[DataValue], + ) -> Result> { + let key_data = key.encode_as_key(self.id); + if self.is_temp { + Ok(tx + .temp_store_tx + .get(&key_data, false)? + .map(|val_data| rmp_serde::from_slice(&val_data[ENCODED_KEY_MIN_LEN..]).unwrap())) + } else { + Ok(tx + .store_tx + .get(&key_data, false)? + .map(|val_data| rmp_serde::from_slice(&val_data[ENCODED_KEY_MIN_LEN..]).unwrap())) + } + } + pub(crate) fn exists(&self, tx: &SessionTx<'_>, key: &[DataValue]) -> Result { let key_data = key.encode_as_key(self.id); if self.is_temp { @@ -594,6 +619,7 @@ impl<'a> SessionTx<'a> { indices: Default::default(), hnsw_indices: Default::default(), fts_indices: Default::default(), + lsh_indices: Default::default(), }; let name_key = vec![DataValue::Str(meta.name.clone())].encode_as_key(RelationId::SYSTEM); @@ -694,6 +720,141 @@ impl<'a> SessionTx<'a> { Ok(()) } + pub(crate) fn create_minhash_lsh_index(&mut self, config: MinHashLshConfig) -> Result<()> { + // Get relation handle + let mut rel_handle = self.get_relation(&config.base_relation, true)?; + + // Check if index already exists + if rel_handle.has_index(&config.index_name) { + bail!(IndexAlreadyExists( + config.index_name.to_string(), + config.index_name.to_string() + )); + } + + let inv_idx_keys = rel_handle.metadata.keys.clone(); + let inv_idx_vals = vec![ColumnDef { + name: SmartString::from("minhash"), + typing: NullableColType { + coltype: ColType::Bytes, + nullable: false, + }, + default_gen: None, + }]; + + let mut idx_keys = vec![ + ColumnDef { + name: SmartString::from("perm"), + typing: NullableColType { + coltype: ColType::Int, + nullable: false, + }, + default_gen: None, + }, + ColumnDef { + name: SmartString::from("hash"), + typing: NullableColType { + coltype: ColType::Bytes, + nullable: false, + }, + default_gen: None, + }, + ]; + for k in rel_handle.metadata.keys.iter() { + idx_keys.push(ColumnDef { + name: format!("src_{}", k.name).into(), + typing: k.typing.clone(), + default_gen: None, + }); + } + let idx_vals = vec![]; + + let idx_handle = self.write_idx_relation( + &config.base_relation, + &config.index_name, + idx_keys, + idx_vals, + )?; + + let inv_idx_handle = self.write_idx_relation( + &config.base_relation, + &config.index_name, + inv_idx_keys, + inv_idx_vals, + )?; + + // add index to relation + let params = LshParams::find_optimal_params( + config.target_threshold.0, + config.n_perm, + &Weights( + config.false_positive_weight.0, + config.false_negative_weight.0, + ), + ); + let perms = HashPermutations::new(config.n_perm); + let manifest = MinHashLshIndexManifest { + base_relation: config.base_relation, + index_name: config.index_name, + extractor: config.extractor, + n_gram: config.n_gram, + tokenizer: config.tokenizer, + filters: config.filters, + num_perm: config.n_perm, + b: params.b, + r: params.r, + threshold: config.target_threshold.0, + perms: perms.as_bytes().to_vec(), + }; + + // populate index + let tokenizer = + self.tokenizers + .get(&idx_handle.name, &manifest.tokenizer, &manifest.filters)?; + let parsed = CozoScriptParser::parse(Rule::expr, &manifest.extractor) + .into_diagnostic()? + .next() + .unwrap(); + let mut code_expr = build_expr(parsed, &Default::default())?; + let binding_map = rel_handle.raw_binding_map(); + code_expr.fill_binding_indices(&binding_map)?; + let extractor = code_expr.compile()?; + + let mut stack = vec![]; + + let existing: Vec<_> = rel_handle.scan_all(self).try_collect()?; + let hash_perms = manifest.get_hash_perms(); + for tuple in existing { + self.put_lsh_index_item( + &tuple, + &extractor, + &mut stack, + &tokenizer, + &rel_handle, + &idx_handle, + &inv_idx_handle, + &manifest, + &hash_perms, + )?; + } + + rel_handle.lsh_indices.insert( + manifest.index_name.clone(), + (idx_handle, inv_idx_handle, manifest), + ); + + // update relation metadata + let new_encoded = + vec![DataValue::from(&rel_handle.name as &str)].encode_as_key(RelationId::SYSTEM); + let mut meta_val = vec![]; + rel_handle + .serialize(&mut Serializer::new(&mut meta_val)) + .unwrap(); + self.store_tx.put(&new_encoded, &meta_val)?; + + Ok(()) + } + pub(crate) fn create_fts_index(&mut self, config: FtsIndexConfig) -> Result<()> { // Get relation handle let mut rel_handle = self.get_relation(&config.base_relation, true)?; @@ -748,7 +909,7 @@ impl<'a> SessionTx<'a> { }, ColumnDef { name: SmartString::from("position"), - typing: col_type.clone(), + typing: col_type, default_gen: None, }, ColumnDef { @@ -1192,8 +1353,10 @@ impl<'a> SessionTx<'a> { idx_name: &Symbol, ) -> Result, Vec)>> { let mut rel = self.get_relation(rel_name, true)?; + let is_lsh = rel.lsh_indices.contains_key(&idx_name.name); if rel.indices.remove(&idx_name.name).is_none() && rel.hnsw_indices.remove(&idx_name.name).is_none() + && rel.lsh_indices.remove(&idx_name.name).is_none() { #[derive(Debug, Error, Diagnostic)] #[error("index {0} for relation {1} not found")] @@ -1203,7 +1366,13 @@ impl<'a> SessionTx<'a> { bail!(IndexNotFound(idx_name.to_string(), rel_name.to_string())); } - let to_clean = self.destroy_relation(&format!("{}:{}", rel_name.name, idx_name.name))?; + let mut to_clean = + self.destroy_relation(&format!("{}:{}", rel_name.name, idx_name.name))?; + if is_lsh { + to_clean.extend( + self.destroy_relation(&format!("{}:{}:inv", rel_name.name, idx_name.name))?, + ); + } let new_encoded = vec![DataValue::from(&rel_name.name as &str)].encode_as_key(RelationId::SYSTEM); diff --git a/cozo-core/src/storage/mod.rs b/cozo-core/src/storage/mod.rs index 1840eda0..4c53e3df 100644 --- a/cozo-core/src/storage/mod.rs +++ b/cozo-core/src/storage/mod.rs @@ -62,6 +62,13 @@ pub trait StoreTx<'s>: Sync { /// the key has not been modified outside the transaction. fn get(&self, key: &[u8], for_update: bool) -> Result>>; + /// Get multiple keys. If `for_update` is `true` (only possible in a write transaction), + /// then the database needs to guarantee that `commit()` can only succeed if + /// the keys have not been modified outside the transaction. + fn multi_get(&self, keys: &[Vec], for_update: bool) -> Result>>> { + keys.iter().map(|k| self.get(k, for_update)).collect() + } + /// Put a key-value pair into the storage. In case of existing key, /// the storage engine needs to overwrite the old value. fn put(&mut self, key: &[u8], val: &[u8]) -> Result<()>;