complete LSH

main
Ziyang Hu 1 year ago
parent 683b66c0a7
commit ba98a5c137

Cargo.lock generated

@@ -724,6 +724,7 @@ dependencies = [
  "pest",
  "pest_derive",
  "priority-queue",
+ "quadrature",
  "rand 0.8.5",
  "rayon",
  "regex",
@@ -746,6 +747,7 @@ dependencies = [
  "tikv-client",
  "tikv-jemallocator-global",
  "tokio",
+ "twox-hash",
  "unicode-normalization",
  "uuid",
 ]
@@ -2945,6 +2947,12 @@ dependencies = [
  "syn 1.0.109",
 ]
 
+[[package]]
+name = "quadrature"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2054ccb02f454fcb2bc81e343aa0a171636a6331003fd5ec24c47a10966634b7"
+
 [[package]]
 name = "quote"
 version = "1.0.26"
@@ -4113,6 +4121,17 @@ version = "0.2.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
 
+[[package]]
+name = "twox-hash"
+version = "1.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675"
+dependencies = [
+ "cfg-if 1.0.0",
+ "rand 0.8.5",
+ "static_assertions",
+]
+
 [[package]]
 name = "typenum"
 version = "1.16.0"

@@ -130,6 +130,8 @@ crossbeam = "0.8.2"
 ndarray = { version = "0.15.6", features = ["serde"] }
 sha2 = "0.10.6"
 rustc-hash = "1.1.0"
+twox-hash = "1.6.3"
+quadrature = "0.1.2"
 # For the FTS feature
 jieba-rs = "0.6.7"
 aho-corasick = "1.0.1"
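The two new dependencies serve the LSH code introduced later in this commit: twox-hash provides the seeded XxHash32 hasher that the MinHash signatures are built from, and quadrature provides the numerical integrator used to tune the banding parameters. A minimal sketch of the two APIs as they are used there (the seed and the integrand are stand-ins, not values from the commit):

use std::hash::{Hash, Hasher};

fn sketch() {
    // Seeded 32-bit hash, as in HashValues::update below.
    let mut hasher = twox_hash::XxHash32::with_seed(42);
    "token".hash(&mut hasher);
    let h = hasher.finish() as u32;

    // Numerical integration over [0, 0.9] with a 0.001 error target,
    // as in LshParams::false_positive_probability below.
    let p = quadrature::integrate(|s: f64| s * s, 0.0, 0.9, 0.001).integral;
    println!("{} {}", h, p);
}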

@@ -56,7 +56,7 @@ underscore_ident = @{("_" | XID_START) ~ ("_" | XID_CONTINUE)*}
 relation_ident = @{"*" ~ (compound_or_index_ident | underscore_ident)}
 search_index_ident = _{"~" ~ compound_or_index_ident}
 compound_ident = @{ident ~ ("." ~ ident)*}
-compound_or_index_ident = @{ident ~ ("." ~ ident)* ~ (":" ~ ident)?}
+compound_or_index_ident = @{ident ~ ("." ~ ident)* ~ (":" ~ ident)*}
 rule = {rule_head ~ ":=" ~ rule_body ~ ";"?}
 const_rule = {rule_head ~ "<-" ~ expr ~ ";"?}
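The change from `?` to `*` lets an index identifier carry more than one `:`-separated segment. That appears to be what allows the new LSH machinery to name its companion inverse-index relation: judging from the index-drop path later in this commit, that relation is addressed as `rel:idx:inv`, which needs two `:` segments to parse.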

@@ -18,6 +18,7 @@ use thiserror::Error;
 use crate::data::aggr::Aggregation;
 use crate::data::expr::Expr;
+use crate::data::functions::OP_LIST;
 use crate::data::relation::StoredRelationMetadata;
 use crate::data::symb::{Symbol, PROG_ENTRY};
 use crate::data::value::{DataValue, ValidityTs};
@@ -26,6 +27,7 @@ use crate::fts::FtsIndexManifest;
 use crate::parse::SourceSpan;
 use crate::query::logical::{Disjunction, NamedFieldNotFound};
 use crate::runtime::hnsw::HnswIndexManifest;
+use crate::runtime::minhash_lsh::{LshSearch, MinHashLshIndexManifest};
 use crate::runtime::relation::{
     AccessLevel, InputRelationHandle, InsufficientAccessLevel, RelationHandle,
 };
@@ -950,8 +952,8 @@ pub(crate) struct HnswSearch {
 
 #[derive(Copy, Clone, Debug, PartialEq, Eq)]
 pub(crate) enum FtsScoreKind {
-    TFIDF,
-    TF,
+    TfIdf,
+    Tf,
 }
 
 #[derive(Clone, Debug)]
@@ -989,6 +991,212 @@ impl FtsSearch {
 }
 
 impl SearchInput {
+    fn normalize_lsh(
+        mut self,
+        base_handle: RelationHandle,
+        idx_handle: RelationHandle,
+        inv_idx_handle: RelationHandle,
+        manifest: MinHashLshIndexManifest,
+        gen: &mut TempSymbGen,
+    ) -> Result<Disjunction> {
+        let mut conj = Vec::with_capacity(self.bindings.len() + 8);
+        let mut bindings = Vec::with_capacity(self.bindings.len());
+        let mut seen_variables = BTreeSet::new();
+        for col in base_handle
+            .metadata
+            .keys
+            .iter()
+            .chain(base_handle.metadata.non_keys.iter())
+        {
+            if let Some(arg) = self.bindings.remove(&col.name) {
+                match arg {
+                    Expr::Binding { var, .. } => {
+                        if var.is_ignored_symbol() {
+                            bindings.push(gen.next_ignored(var.span));
+                        } else if seen_variables.insert(var.clone()) {
+                            bindings.push(var);
+                        } else {
+                            let span = var.span;
+                            let dup = gen.next(span);
+                            let unif = NormalFormAtom::Unification(Unification {
+                                binding: dup.clone(),
+                                expr: Expr::Binding {
+                                    var,
+                                    tuple_pos: None,
+                                },
+                                one_many_unif: false,
+                                span,
+                            });
+                            conj.push(unif);
+                            bindings.push(dup);
+                        }
+                    }
+                    expr => {
+                        let span = expr.span();
+                        let kw = gen.next(span);
+                        bindings.push(kw.clone());
+                        let unif = NormalFormAtom::Unification(Unification {
+                            binding: kw,
+                            expr,
+                            one_many_unif: false,
+                            span,
+                        });
+                        conj.push(unif)
+                    }
+                }
+            } else {
+                bindings.push(gen.next_ignored(self.span));
+            }
+        }
+
+        if let Some((name, _)) = self.bindings.pop_first() {
+            bail!(NamedFieldNotFound(
+                self.relation.name.to_string(),
+                name.to_string(),
+                self.span
+            ));
+        }
+
+        #[derive(Debug, Error, Diagnostic)]
+        #[error("Field `{0}` is required for LSH search")]
+        #[diagnostic(code(parser::hnsw_query_required))]
+        struct LshRequiredMissing(String, #[label] SourceSpan);
+
+        #[derive(Debug, Error, Diagnostic)]
+        #[error("Expected a list of keys for LSH search")]
+        #[diagnostic(code(parser::expected_list_for_lsh_keys))]
+        struct ExpectedListForLshKeys(#[label] SourceSpan);
+
+        #[derive(Debug, Error, Diagnostic)]
+        #[error("Wrong arity for LSH keys, expected {1}, got {2}")]
+        #[diagnostic(code(parser::wrong_arity_for_lsh_keys))]
+        struct WrongArityForKeys(#[label] SourceSpan, usize, usize);
+
+        let query = match self.parameters.remove("keys") {
+            None => match self.parameters.remove("key") {
+                None => {
+                    bail!(LshRequiredMissing("keys".to_string(), self.span))
+                }
+                Some(expr) => {
+                    ensure!(
+                        base_handle.indices.keys().len() == 1,
+                        LshRequiredMissing("keys".to_string(), self.span)
+                    );
+                    let span = expr.span();
+                    let kw = gen.next(span);
+                    let unif = NormalFormAtom::Unification(Unification {
+                        binding: kw.clone(),
+                        expr: Expr::Apply {
+                            op: &OP_LIST,
+                            args: [expr].into(),
+                            span,
+                        },
+                        one_many_unif: false,
+                        span,
+                    });
+                    conj.push(unif);
+                    kw
+                }
+            },
+            Some(mut expr) => {
+                expr.partial_eval()?;
+                match expr {
+                    Expr::Apply { op, args, span } => {
+                        ensure!(op.name == OP_LIST.name, ExpectedListForLshKeys(span));
+                        ensure!(
+                            args.len() == base_handle.indices.keys().len(),
+                            WrongArityForKeys(span, base_handle.indices.keys().len(), args.len())
+                        );
+                        let kw = gen.next(span);
+                        let unif = NormalFormAtom::Unification(Unification {
+                            binding: kw.clone(),
+                            expr: Expr::Apply { op, args, span },
+                            one_many_unif: false,
+                            span,
+                        });
+                        conj.push(unif);
+                        kw
+                    }
+                    _ => {
+                        bail!(ExpectedListForLshKeys(self.span))
+                    }
+                }
+            }
+        };
+
+        let k = match self.parameters.remove("k") {
+            None => None,
+            Some(k_expr) => {
+                let k = k_expr.eval_to_const()?;
+                let k = k.get_int().ok_or(ExpectedPosIntForFtsK(self.span))?;
+
+                #[derive(Debug, Error, Diagnostic)]
+                #[error("Expected positive integer for `k`")]
+                #[diagnostic(code(parser::expected_int_for_hnsw_k))]
+                struct ExpectedPosIntForFtsK(#[label] SourceSpan);
+
+                ensure!(k > 0, ExpectedPosIntForFtsK(self.span));
+                Some(k as usize)
+            }
+        };
+        let filter = self.parameters.remove("filter");
+        let bind_similarity = match self.parameters.remove("bind_similarity") {
+            None => None,
+            Some(Expr::Binding { var, .. }) => Some(var),
+            Some(expr) => {
+                let span = expr.span();
+                let kw = gen.next(span);
+                let unif = NormalFormAtom::Unification(Unification {
+                    binding: kw.clone(),
+                    expr,
+                    one_many_unif: false,
+                    span,
+                });
+                conj.push(unif);
+                Some(kw)
+            }
+        };
+        let min_similarity = match self.parameters.remove("min_similarity") {
+            None => manifest.threshold,
+            Some(expr) => {
+                let min_similarity = expr.eval_to_const()?;
+                let min_similarity = min_similarity
+                    .get_float()
+                    .ok_or(ExpectedFloatForMinSimilarity(self.span))?;
+
+                #[derive(Debug, Error, Diagnostic)]
+                #[error("Expected float for `min_similarity`")]
+                #[diagnostic(code(parser::expected_float_for_min_similarity))]
+                struct ExpectedFloatForMinSimilarity(#[label] SourceSpan);
+
+                ensure!(
+                    (0.0..=1.0).contains(&min_similarity),
+                    ExpectedFloatForMinSimilarity(self.span)
+                );
+                min_similarity
+            }
+        };
+        conj.push(NormalFormAtom::LshSearch(LshSearch {
+            base_handle,
+            idx_handle,
+            inv_idx_handle,
+            manifest,
+            bindings,
+            k,
+            bind_similarity,
+            query,
+            span: self.span,
+            min_similarity,
+            filter,
+        }));
+        Ok(Disjunction::conj(conj))
+    }
+
     fn normalize_fts(
         mut self,
         base_handle: RelationHandle,
@@ -1094,46 +1302,6 @@ impl SearchInput {
         ensure!(k > 0, ExpectedPosIntForFtsK(self.span));
 
-        // let k1 = {
-        //     match self.parameters.remove("k1") {
-        //         None => 1.2,
-        //         Some(expr) => {
-        //             let r = expr.eval_to_const()?;
-        //             let r = r
-        //                 .get_float()
-        //                 .ok_or_else(|| miette!("k1 for FTS must be a float"))?;
-        //
-        //             #[derive(Debug, Error, Diagnostic)]
-        //             #[error("Expected positive float for `k1`")]
-        //             #[diagnostic(code(parser::expected_float_for_hnsw_k1))]
-        //             struct ExpectedPosFloatForFtsK1(#[label] SourceSpan);
-        //
-        //             ensure!(r > 0.0, ExpectedPosFloatForFtsK1(self.span));
-        //             r
-        //         }
-        //     }
-        // };
-        //
-        // let b = {
-        //     match self.parameters.remove("b") {
-        //         None => 0.75,
-        //         Some(expr) => {
-        //             let r = expr.eval_to_const()?;
-        //             let r = r
-        //                 .get_float()
-        //                 .ok_or_else(|| miette!("b for FTS must be a float"))?;
-        //
-        //             #[derive(Debug, Error, Diagnostic)]
-        //             #[error("Expected positive float for `b`")]
-        //             #[diagnostic(code(parser::expected_float_for_hnsw_b))]
-        //             struct ExpectedPosFloatForFtsB(#[label] SourceSpan);
-        //
-        //             ensure!(r > 0.0, ExpectedPosFloatForFtsB(self.span));
-        //             r
-        //         }
-        //     }
-        // };
-
         let score_kind_expr = self.parameters.remove("score_kind");
         let score_kind = match score_kind_expr {
             Some(expr) => {
@@ -1143,26 +1311,14 @@ impl SearchInput {
                     .ok_or_else(|| miette!("Score kind for FTS must be a string"))?;
 
                 match r {
-                    "tf_idf" => FtsScoreKind::TFIDF,
-                    "tf" => FtsScoreKind::TF,
+                    "tf_idf" => FtsScoreKind::TfIdf,
+                    "tf" => FtsScoreKind::Tf,
                     s => bail!("Unknown score kind for FTS: {}", s),
                 }
             }
-            None => FtsScoreKind::TFIDF,
+            None => FtsScoreKind::TfIdf,
         };
-        // let lax_mode_expr = self.parameters.remove("lax_mode");
-        // let lax_mode = match lax_mode_expr {
-        //     Some(expr) => {
-        //         let r = expr.eval_to_const()?;
-        //         let r = r
-        //             .get_bool()
-        //             .ok_or_else(|| miette!("Lax mode for FTS must be a boolean"))?;
-        //         r
-        //     }
-        //     None => true,
-        // };
 
         let filter = self.parameters.remove("filter");
         let bind_score = match self.parameters.remove("bind_score") {
@@ -1447,6 +1603,11 @@ impl SearchInput {
         {
             return self.normalize_fts(base_handle, idx_handle, manifest, gen);
         }
+        if let Some((idx_handle, inv_idx_handle, manifest)) =
+            base_handle.lsh_indices.get(&self.index.name).cloned()
+        {
+            return self.normalize_lsh(base_handle, idx_handle, inv_idx_handle, manifest, gen);
+        }
 
         #[derive(Debug, Error, Diagnostic)]
         #[error("Index {name} not found on relation {relation}")]
         #[diagnostic(code(eval::hnsw_index_not_found))]
@@ -1586,6 +1747,7 @@ pub(crate) enum NormalFormAtom {
     Unification(Unification),
     HnswSearch(HnswSearch),
     FtsSearch(FtsSearch),
+    LshSearch(LshSearch),
 }
 
 #[derive(Debug, Clone)]
@@ -1598,6 +1760,7 @@ pub(crate) enum MagicAtom {
     Unification(Unification),
     HnswSearch(HnswSearch),
     FtsSearch(FtsSearch),
+    LshSearch(LshSearch),
 }
 
 #[derive(Clone, Debug)]
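Taken together, normalize_lsh defines the query-time surface of the new index. Reading off the parameters it consumes (`keys`/`key` for the query value, plus optional `k`, `min_similarity`, `bind_similarity`, and `filter`), a hypothetical invocation in the existing `~relation:index{...}` search syntax would look roughly like this; the relation, index, and parameter values are invented for illustration:

?[title, sim] := ~docs:lsh_idx{title |
    key: $q,
    k: 10,
    min_similarity: 0.7,
    bind_similarity: sim,
}

When `min_similarity` is not given, it defaults to the threshold stored in the index manifest.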

@@ -619,6 +619,13 @@ impl Display for DataValue {
 }
 
 impl DataValue {
+    /// Returns a slice of bytes if this one is a Bytes
+    pub fn get_bytes(&self) -> Option<&[u8]> {
+        match self {
+            DataValue::Bytes(b) => Some(b),
+            _ => None,
+        }
+    }
     /// Returns a slice of DataValues if this one is a List
     pub fn get_slice(&self) -> Option<&[DataValue]> {
         match self {

@@ -80,11 +80,10 @@ impl<'a> SessionTx<'a> {
                 if !found_str_key.starts_with(start_key_str) {
                     break;
                 }
-            } else {
-                if found_str_key != start_key_str {
-                    break;
-                }
+            } else if found_str_key != start_key_str {
+                break;
             }
             let vals: Vec<DataValue> = rmp_serde::from_slice(&vvec[ENCODED_KEY_MIN_LEN..]).unwrap();
             let froms = vals[0].get_slice().unwrap();
             let tos = vals[1].get_slice().unwrap();
@@ -112,8 +111,6 @@ impl<'a> SessionTx<'a> {
         &self,
         ast: &FtsExpr,
         config: &FtsSearch,
-        filter_code: &Option<(Vec<Bytecode>, SourceSpan)>,
-        tokenizer: &TextAnalyzer,
         n: usize,
     ) -> Result<FxHashMap<Tuple, f64>> {
         Ok(match ast {
@@ -138,21 +135,13 @@ impl<'a> SessionTx<'a> {
                 let mut res = self.fts_search_impl(
                     l_iter.next().unwrap(),
                     config,
-                    filter_code,
-                    tokenizer,
                     n,
                 )?;
                 for nxt in l_iter {
-                    let nxt_res = self.fts_search_impl(nxt, config, filter_code, tokenizer, n)?;
+                    let nxt_res = self.fts_search_impl(nxt, config, n)?;
                     res = res
                         .into_iter()
-                        .filter_map(|(k, v)| {
-                            if let Some(nxt_v) = nxt_res.get(&k) {
-                                Some((k, v + nxt_v))
-                            } else {
-                                None
-                            }
-                        })
+                        .filter_map(|(k, v)| nxt_res.get(&k).map(|nxt_v| (k, v + nxt_v)))
                         .collect();
                 }
                 res
@@ -160,7 +149,7 @@ impl<'a> SessionTx<'a> {
             FtsExpr::Or(ls) => {
                 let mut res: FxHashMap<Tuple, f64> = FxHashMap::default();
                 for nxt in ls {
-                    let nxt_res = self.fts_search_impl(nxt, config, filter_code, tokenizer, n)?;
+                    let nxt_res = self.fts_search_impl(nxt, config, n)?;
                     for (k, v) in nxt_res {
                         if let Some(old_v) = res.get_mut(&k) {
                             *old_v = (*old_v).max(v);
@@ -199,10 +188,8 @@ impl<'a> SessionTx<'a> {
                         if cur - p <= *distance {
                             inner_coll.insert(p);
                         }
-                    } else {
-                        if p - cur <= *distance {
-                            inner_coll.insert(cur);
-                        }
+                    } else if p - cur <= *distance {
+                        inner_coll.insert(cur);
                     }
                 }
             }
@@ -230,9 +217,9 @@ impl<'a> SessionTx<'a> {
                     .collect()
             }
             FtsExpr::Not(fst, snd) => {
-                let mut res = self.fts_search_impl(fst, config, filter_code, tokenizer, n)?;
+                let mut res = self.fts_search_impl(fst, config, n)?;
                 for el in self
-                    .fts_search_impl(snd, config, filter_code, tokenizer, n)?
+                    .fts_search_impl(snd, config, n)?
                     .keys()
                 {
                     res.remove(el);
@@ -250,8 +237,8 @@ impl<'a> SessionTx<'a> {
     ) -> f64 {
         let tf = tf as f64;
         match config.score_kind {
-            FtsScoreKind::TF => tf * booster,
-            FtsScoreKind::TFIDF => {
+            FtsScoreKind::Tf => tf * booster,
+            FtsScoreKind::TfIdf => {
                 let n_found_docs = n_found_docs as f64;
                 let idf = (1.0 + (n_total as f64 - n_found_docs + 0.5) / (n_found_docs + 0.5)).ln();
                 tf * idf * booster
@@ -271,13 +258,13 @@ impl<'a> SessionTx<'a> {
         if ast.is_empty() {
             return Ok(vec![]);
         }
-        let n = if config.score_kind == FtsScoreKind::TFIDF {
+        let n = if config.score_kind == FtsScoreKind::TfIdf {
             cache.get_n_for_relation(&config.base_handle, self)?
         } else {
             0
         };
         let mut result: Vec<_> = self
-            .fts_search_impl(&ast, config, filter_code, tokenizer, n)?
+            .fts_search_impl(&ast, config, n)?
             .into_iter()
             .collect();
         result.sort_by_key(|(_, score)| Reverse(OrderedFloat(*score)));
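The renamed variants leave the scoring math untouched: for FtsScoreKind::TfIdf the per-term score is tf * ln(1 + (N - n + 0.5) / (n + 0.5)) * booster, with N the total number of documents and n the number of documents containing the term. As a worked example, N = 1000 and n = 10 give an idf factor of ln(1 + 990.5 / 10.5) ≈ 4.56, so a term frequency of 3 with booster 1 scores about 13.7.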

@@ -1,7 +1,10 @@
+use smartstring::{LazyCompact, SmartString};
 /// The tokenizer module contains all of the tools used to process
 /// text in `tantivy`.
 use std::borrow::{Borrow, BorrowMut};
+use std::iter;
 use std::ops::{Deref, DerefMut};
 
+use rustc_hash::FxHashSet;
+
 use crate::fts::tokenizer::empty_tokenizer::EmptyTokenizer;
@@ -60,7 +63,10 @@ impl TextAnalyzer {
     ///
     /// When creating a `TextAnalyzer` from a `Tokenizer` alone, prefer using
     /// `TextAnalyzer::from(tokenizer)`.
-    pub(crate) fn new<T: Tokenizer>(tokenizer: T, token_filters: Vec<BoxTokenFilter>) -> TextAnalyzer {
+    pub(crate) fn new<T: Tokenizer>(
+        tokenizer: T,
+        token_filters: Vec<BoxTokenFilter>,
+    ) -> TextAnalyzer {
         TextAnalyzer {
             tokenizer: Box::new(tokenizer),
             token_filters,
@@ -96,6 +102,25 @@ impl TextAnalyzer {
         }
         token_stream
     }
 
+    pub(crate) fn unique_ngrams(
+        &self,
+        text: &str,
+        n: usize,
+    ) -> FxHashSet<Vec<SmartString<LazyCompact>>> {
+        let mut token_stream = self.token_stream(text);
+        let mut coll: Vec<SmartString<LazyCompact>> = vec![];
+        while let Some(token) = token_stream.next() {
+            coll.push(SmartString::from(token.text.as_str()));
+        }
+        if n == 1 {
+            coll.iter().map(|x| vec![x.clone()]).collect()
+        } else if n >= coll.len() {
+            iter::once(coll).collect()
+        } else {
+            let mut ret = FxHashSet::default();
+            for chunk in coll.windows(n) {
+                ret.insert(chunk.to_vec());
+            }
+            ret
+        }
+    }
 }
 
 impl Clone for TextAnalyzer {
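unique_ngrams is what feeds MinHash for string values: the text is tokenized, then reduced to the set of distinct n-grams of tokens. Illustratively, assuming a tokenizer that splits on whitespace:

// analyzer.unique_ngrams("the quick brown fox", 2)
// => {["the", "quick"], ["quick", "brown"], ["brown", "fox"]}

Note the two edge cases handled in the code: n == 1 yields singleton lists, and n >= the token count collapses everything into the single full token list.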

@@ -21,7 +21,7 @@ pub(crate) fn parse_fts_query(q: &str) -> Result<FtsExpr> {
     let pairs = pairs.next().unwrap().into_inner();
     let pairs: Vec<_> = pairs
         .filter(|r| r.as_rule() != Rule::EOI)
-        .map(|r| parse_fts_expr(r))
+        .map(parse_fts_expr)
         .try_collect()?;
     Ok(if pairs.len() == 1 {
         pairs.into_iter().next().unwrap()
@@ -157,7 +157,7 @@ mod tests {
         assert!(matches!(res, FtsExpr::Not(_, _)));
         let src = " NEAR(abc def \"ghi\"^22.8) ";
         let res = parse_fts_query(src).unwrap().flatten();
-        assert!(matches!(res, FtsExpr::Near(FtsNear{distance: 10, ..})));
+        assert!(matches!(res, FtsExpr::Near(FtsNear { distance: 10, .. })));
         println!("{:#?}", res);
     }
 }

@@ -11,6 +11,7 @@ use std::sync::Arc;
 
 use itertools::Itertools;
 use miette::{bail, ensure, miette, Diagnostic, Result};
+use ordered_float::OrderedFloat;
 use smartstring::{LazyCompact, SmartString};
 use thiserror::Error;
 
@@ -41,8 +42,10 @@ pub(crate) enum SysOp {
     CreateIndex(Symbol, Symbol, Vec<Symbol>),
     CreateVectorIndex(HnswIndexConfig),
     CreateFtsIndex(FtsIndexConfig),
+    CreateMinHashLshIndex(MinHashLshConfig),
     RemoveIndex(Symbol, Symbol),
 }
 
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub(crate) struct FtsIndexConfig {
     pub(crate) base_relation: SmartString<LazyCompact>,
@@ -52,6 +55,20 @@ pub(crate) struct FtsIndexConfig {
     pub(crate) filters: Vec<TokenizerConfig>,
 }
 
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub(crate) struct MinHashLshConfig {
+    pub(crate) base_relation: SmartString<LazyCompact>,
+    pub(crate) index_name: SmartString<LazyCompact>,
+    pub(crate) extractor: String,
+    pub(crate) tokenizer: TokenizerConfig,
+    pub(crate) filters: Vec<TokenizerConfig>,
+    pub(crate) n_gram: usize,
+    pub(crate) n_perm: usize,
+    pub(crate) false_positive_weight: OrderedFloat<f64>,
+    pub(crate) false_negative_weight: OrderedFloat<f64>,
+    pub(crate) target_threshold: OrderedFloat<f64>,
+}
+
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub(crate) struct HnswIndexConfig {
     pub(crate) base_relation: SmartString<LazyCompact>,
@@ -186,7 +203,175 @@ pub(crate) fn parse_sys(
             SysOp::SetTriggers(rel, puts, rms, replaces)
         }
         Rule::lsh_idx_op => {
-            todo!()
+            let inner = inner.into_inner().next().unwrap();
+            match inner.as_rule() {
+                Rule::index_create_adv => {
+                    let mut inner = inner.into_inner();
+                    let rel = inner.next().unwrap();
+                    let name = inner.next().unwrap();
+                    let mut filters = vec![];
+                    let mut tokenizer = TokenizerConfig {
+                        name: Default::default(),
+                        args: Default::default(),
+                    };
+                    let mut extractor = "".to_string();
+                    let mut n_gram = 1;
+                    let mut n_perm = 200;
+                    let mut target_threshold = 0.9;
+                    let mut false_positive_weight = 1.0;
+                    let mut false_negative_weight = 1.0;
+                    for opt_pair in inner {
+                        let mut opt_inner = opt_pair.into_inner();
+                        let opt_name = opt_inner.next().unwrap();
+                        let opt_val = opt_inner.next().unwrap();
+                        match opt_name.as_str() {
+                            "false_positive_weight" => {
+                                let mut expr = build_expr(opt_val, param_pool)?;
+                                expr.partial_eval()?;
+                                let v = expr.eval_to_const()?;
+                                false_positive_weight = v.get_float().ok_or_else(|| {
+                                    miette!("false_positive_weight must be a float")
+                                })?;
+                            }
+                            "false_negative_weight" => {
+                                let mut expr = build_expr(opt_val, param_pool)?;
+                                expr.partial_eval()?;
+                                let v = expr.eval_to_const()?;
+                                false_negative_weight = v.get_float().ok_or_else(|| {
+                                    miette!("false_negative_weight must be a float")
+                                })?;
+                            }
+                            "n_gram" => {
+                                let mut expr = build_expr(opt_val, param_pool)?;
+                                expr.partial_eval()?;
+                                let v = expr.eval_to_const()?;
+                                n_gram = v
+                                    .get_int()
+                                    .ok_or_else(|| miette!("n_gram must be an integer"))?
+                                    as usize;
+                            }
+                            "n_perm" => {
+                                let mut expr = build_expr(opt_val, param_pool)?;
+                                expr.partial_eval()?;
+                                let v = expr.eval_to_const()?;
+                                n_perm = v
+                                    .get_int()
+                                    .ok_or_else(|| miette!("n_perm must be an integer"))?
+                                    as usize;
+                            }
+                            "target_threshold" => {
+                                let mut expr = build_expr(opt_val, param_pool)?;
+                                expr.partial_eval()?;
+                                let v = expr.eval_to_const()?;
+                                target_threshold = v
+                                    .get_float()
+                                    .ok_or_else(|| miette!("target_threshold must be a float"))?;
+                            }
+                            "extractor" => {
+                                let mut ex = build_expr(opt_val, param_pool)?;
+                                ex.partial_eval()?;
+                                extractor = ex.to_string();
+                            }
+                            "tokenizer" => {
+                                let mut expr = build_expr(opt_val, param_pool)?;
+                                expr.partial_eval()?;
+                                match expr {
+                                    Expr::UnboundApply { op, args, .. } => {
+                                        let mut targs = vec![];
+                                        for arg in args.iter() {
+                                            let v = arg.clone().eval_to_const()?;
+                                            targs.push(v);
+                                        }
+                                        tokenizer.name = op;
+                                        tokenizer.args = targs;
+                                    }
+                                    Expr::Binding { var, .. } => {
+                                        tokenizer.name = var.name;
+                                        tokenizer.args = vec![];
+                                    }
+                                    _ => bail!("Tokenizer must be a symbol or a call for an existing tokenizer"),
+                                }
+                            }
+                            "filters" => {
+                                let mut expr = build_expr(opt_val, param_pool)?;
+                                expr.partial_eval()?;
+                                match expr {
+                                    Expr::Apply { op, args, .. } => {
+                                        if op.name != "LIST" {
+                                            bail!("Filters must be a list of filters");
+                                        }
+                                        for arg in args.iter() {
+                                            match arg {
+                                                Expr::UnboundApply { op, args, .. } => {
+                                                    let mut targs = vec![];
+                                                    for arg in args.iter() {
+                                                        let v = arg.clone().eval_to_const()?;
+                                                        targs.push(v);
+                                                    }
+                                                    filters.push(TokenizerConfig {
+                                                        name: op.clone(),
+                                                        args: targs,
+                                                    })
+                                                }
+                                                Expr::Binding { var, .. } => {
+                                                    filters.push(TokenizerConfig {
+                                                        name: var.name.clone(),
+                                                        args: vec![],
+                                                    })
+                                                }
+                                                _ => bail!("Tokenizer must be a symbol or a call for an existing tokenizer"),
+                                            }
+                                        }
+                                    }
+                                    _ => bail!("Filters must be a list of filters"),
+                                }
+                            }
+                            _ => bail!("Unknown option {} for LSH index", opt_name.as_str()),
+                        }
+                    }
+                    ensure!(
+                        false_positive_weight > 0.,
+                        "false_positive_weight must be positive"
+                    );
+                    ensure!(
+                        false_negative_weight > 0.,
+                        "false_negative_weight must be positive"
+                    );
+                    ensure!(n_gram > 0, "n_gram must be positive");
+                    ensure!(n_perm > 0, "n_perm must be positive");
+                    ensure!(
+                        target_threshold > 0. && target_threshold < 1.,
+                        "target_threshold must be between 0 and 1"
+                    );
+                    let total_weights = false_positive_weight + false_negative_weight;
+                    false_positive_weight /= total_weights;
+                    false_negative_weight /= total_weights;
+                    let config = MinHashLshConfig {
+                        base_relation: SmartString::from(rel.as_str()),
+                        index_name: SmartString::from(name.as_str()),
+                        extractor,
+                        tokenizer,
+                        filters,
+                        n_gram,
+                        n_perm,
+                        false_positive_weight: false_positive_weight.into(),
+                        false_negative_weight: false_negative_weight.into(),
+                        target_threshold: target_threshold.into(),
+                    };
+                    SysOp::CreateMinHashLshIndex(config)
+                }
+                Rule::index_drop => {
+                    let mut inner = inner.into_inner();
+                    let rel = inner.next().unwrap();
+                    let name = inner.next().unwrap();
+                    SysOp::RemoveIndex(
+                        Symbol::new(rel.as_str(), rel.extract_span()),
+                        Symbol::new(name.as_str(), name.extract_span()),
+                    )
+                }
+                r => unreachable!("{:?}", r),
+            }
         }
         Rule::fts_idx_op => {
             let inner = inner.into_inner().next().unwrap();
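For reference, a plausible script-level invocation of the new op, using the option names parsed above; the `::lsh create` spelling mirrors the existing `::fts create` form and, like the relation, index, tokenizer, and filter names, is an assumption rather than something shown in this hunk:

::lsh create docs:lsh_idx {
    extractor: title,
    tokenizer: Simple,
    filters: [Lowercase],
    n_gram: 3,
    n_perm: 200,
    target_threshold: 0.8,
}

Unspecified weights default to 1.0 each and are normalized to sum to 1 before being stored in the config.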

@@ -539,6 +539,40 @@ impl<'a> SessionTx<'a> {
                         ret = ret.filter(Expr::build_and(post_filters, s.span))?;
                     }
                 }
+                MagicAtom::LshSearch(s) => {
+                    debug_assert!(
+                        seen_variables.contains(&s.query),
+                        "LSH search query must be bound"
+                    );
+                    let mut own_bindings = vec![];
+                    let mut post_filters = vec![];
+                    for var in s.all_bindings() {
+                        if seen_variables.contains(var) {
+                            let rk = gen_symb(var.span);
+                            post_filters.push(Expr::build_equate(
+                                vec![
+                                    Expr::Binding {
+                                        var: var.clone(),
+                                        tuple_pos: None,
+                                    },
+                                    Expr::Binding {
+                                        var: rk.clone(),
+                                        tuple_pos: None,
+                                    },
+                                ],
+                                var.span,
+                            ));
+                            own_bindings.push(rk);
+                        } else {
+                            seen_variables.insert(var.clone());
+                            own_bindings.push(var.clone());
+                        }
+                    }
+                    ret = ret.lsh_search(s.clone(), own_bindings)?;
+                    if !post_filters.is_empty() {
+                        ret = ret.filter(Expr::build_and(post_filters, s.span))?;
+                    }
+                }
                 MagicAtom::Unification(u) => {
                     if seen_variables.contains(&u.binding) {
                         let expr = if u.one_many_unif {

@@ -180,6 +180,10 @@ fn magic_rewrite_ruleset(
                     seen_bindings.extend(s.all_bindings().cloned());
                     collected_atoms.push(MagicAtom::FtsSearch(s));
                 }
+                MagicAtom::LshSearch(s) => {
+                    seen_bindings.extend(s.all_bindings().cloned());
+                    collected_atoms.push(MagicAtom::LshSearch(s));
+                }
                 MagicAtom::Rule(r_app) => {
                     if r_app.name.has_bound_adornment() {
                         // we are guaranteed to have a magic rule application
@@ -539,6 +543,14 @@ impl NormalFormAtom {
                 }
                 MagicAtom::FtsSearch(s.clone())
             }
+            NormalFormAtom::LshSearch(s) => {
+                for arg in s.all_bindings() {
+                    if !seen_bindings.contains(arg) {
+                        seen_bindings.insert(arg.clone());
+                    }
+                }
+                MagicAtom::LshSearch(s.clone())
+            }
             NormalFormAtom::Predicate(p) => {
                 // predicate cannot introduce new bindings

@@ -23,6 +23,7 @@ use crate::data::symb::Symbol;
 use crate::data::tuple::{Tuple, TupleIter};
 use crate::data::value::{DataValue, ValidityTs};
 use crate::parse::SourceSpan;
+use crate::runtime::minhash_lsh::LshSearch;
 use crate::runtime::relation::RelationHandle;
 use crate::runtime::temp_store::EpochStore;
 use crate::runtime::transact::SessionTx;
@@ -40,6 +41,7 @@ pub(crate) enum RelAlgebra {
     Unification(UnificationRA),
     HnswSearch(HnswSearchRA),
     FtsSearch(FtsSearchRA),
+    LshSearch(LshSearchRA),
 }
 
 impl RelAlgebra {
@@ -56,6 +58,7 @@ impl RelAlgebra {
             RelAlgebra::StoredWithValidity(i) => i.span,
             RelAlgebra::HnswSearch(i) => i.hnsw_search.span,
             RelAlgebra::FtsSearch(i) => i.fts_search.span,
+            RelAlgebra::LshSearch(i) => i.lsh_search.span,
         }
     }
 }
@@ -290,6 +293,11 @@ impl Debug for RelAlgebra {
                 .field(&bindings)
                 .field(&s.fts_search.idx_handle.name)
                 .finish(),
+            RelAlgebra::LshSearch(s) => f
+                .debug_tuple("LshSearch")
+                .field(&bindings)
+                .field(&s.lsh_search.idx_handle.name)
+                .finish(),
             RelAlgebra::StoredWithValidity(r) => f
                 .debug_tuple("StoredWithValidity")
                 .field(&bindings)
@@ -362,6 +370,9 @@ impl RelAlgebra {
             RelAlgebra::FtsSearch(s) => {
                 s.fill_binding_indices_and_compile()?;
             }
+            RelAlgebra::LshSearch(s) => {
+                s.fill_binding_indices_and_compile()?;
+            }
             RelAlgebra::StoredWithValidity(v) => {
                 v.fill_binding_indices_and_compile()?;
             }
@@ -459,7 +470,8 @@ impl RelAlgebra {
             | RelAlgebra::NegJoin(_)
             | RelAlgebra::Unification(_)
             | RelAlgebra::HnswSearch(_)
-            | RelAlgebra::FtsSearch(_)) => {
+            | RelAlgebra::FtsSearch(_)
+            | RelAlgebra::LshSearch(_)) => {
                 let span = filter.span();
                 RelAlgebra::Filter(FilteredRA {
                     parent: Box::new(s),
@@ -624,6 +636,18 @@ impl RelAlgebra {
             own_bindings,
         }))
     }
+    pub(crate) fn lsh_search(
+        self,
+        lsh_search: LshSearch,
+        own_bindings: Vec<Symbol>,
+    ) -> Result<Self> {
+        Ok(Self::LshSearch(LshSearchRA {
+            parent: Box::new(self),
+            lsh_search,
+            filter_bytecode: None,
+            own_bindings,
+        }))
+    }
     pub(crate) fn join(
         self,
         right: RelAlgebra,
@@ -875,6 +899,78 @@ pub(crate) struct HnswSearchRA {
     pub(crate) own_bindings: Vec<Symbol>,
 }
 
+#[derive(Debug)]
+pub(crate) struct LshSearchRA {
+    pub(crate) parent: Box<RelAlgebra>,
+    pub(crate) lsh_search: LshSearch,
+    pub(crate) filter_bytecode: Option<(Vec<Bytecode>, SourceSpan)>,
+    pub(crate) own_bindings: Vec<Symbol>,
+}
+
+impl LshSearchRA {
+    fn fill_binding_indices_and_compile(&mut self) -> Result<()> {
+        self.parent.fill_binding_indices_and_compile()?;
+        if self.lsh_search.filter.is_some() {
+            let bindings: BTreeMap<_, _> = self
+                .own_bindings
+                .iter()
+                .cloned()
+                .enumerate()
+                .map(|(a, b)| (b, a))
+                .collect();
+            let filter = self.lsh_search.filter.as_mut().unwrap();
+            filter.fill_binding_indices(&bindings)?;
+            self.filter_bytecode = Some((filter.compile()?, filter.span()));
+        }
+        Ok(())
+    }
+    fn iter<'a>(
+        &'a self,
+        tx: &'a SessionTx<'_>,
+        delta_rule: Option<&MagicSymbol>,
+        stores: &'a BTreeMap<MagicSymbol, EpochStore>,
+    ) -> Result<TupleIter<'a>> {
+        let bindings = self.parent.bindings_after_eliminate();
+        let mut bind_idx = usize::MAX;
+        for (i, b) in bindings.iter().enumerate() {
+            if *b == self.lsh_search.query {
+                bind_idx = i;
+                break;
+            }
+        }
+        let config = self.lsh_search.clone();
+        let filter_code = self.filter_bytecode.clone();
+        let mut stack = vec![];
+        let it = self
+            .parent
+            .iter(tx, delta_rule, stores)?
+            .map_ok(move |tuple| -> Result<_> {
+                let q = match tuple[bind_idx].clone() {
+                    DataValue::List(l) => l,
+                    d => bail!("Expected list for LSH search, got {:?}", d),
+                };
+                let res = tx.lsh_search(&q, &config, &mut stack, &filter_code)?;
+                Ok(res.into_iter().map(move |t| {
+                    let mut r = tuple.clone();
+                    r.extend(t);
+                    r
+                }))
+            })
+            .map(flatten_err)
+            .flatten_ok();
+        Ok(Box::new(it))
+    }
+}
+
 #[derive(Debug)]
 pub(crate) struct FtsSearchRA {
     pub(crate) parent: Box<RelAlgebra>,
@@ -1745,6 +1841,7 @@ impl RelAlgebra {
             RelAlgebra::Unification(r) => r.do_eliminate_temp_vars(used),
             RelAlgebra::HnswSearch(_) => Ok(()),
             RelAlgebra::FtsSearch(_) => Ok(()),
+            RelAlgebra::LshSearch(_) => Ok(()),
         }
     }
@@ -1761,6 +1858,7 @@ impl RelAlgebra {
             RelAlgebra::Unification(u) => Some(&u.to_eliminate),
             RelAlgebra::HnswSearch(_) => None,
             RelAlgebra::FtsSearch(_) => None,
+            RelAlgebra::LshSearch(_) => None,
         }
     }
@@ -1800,6 +1898,11 @@ impl RelAlgebra {
                 bindings.extend_from_slice(&s.own_bindings);
                 bindings
             }
+            RelAlgebra::LshSearch(s) => {
+                let mut bindings = s.parent.bindings_after_eliminate();
+                bindings.extend_from_slice(&s.own_bindings);
+                bindings
+            }
         }
     }
 
     pub(crate) fn iter<'a>(
@@ -1820,6 +1923,7 @@ impl RelAlgebra {
             RelAlgebra::Unification(r) => r.iter(tx, delta_rule, stores),
             RelAlgebra::HnswSearch(r) => r.iter(tx, delta_rule, stores),
             RelAlgebra::FtsSearch(r) => r.iter(tx, delta_rule, stores),
+            RelAlgebra::LshSearch(r) => r.iter(tx, delta_rule, stores),
         }
     }
 }
@@ -2001,6 +2105,7 @@ impl InnerJoin {
             }
             RelAlgebra::HnswSearch(_) => "hnsw_search_join",
             RelAlgebra::FtsSearch(_) => "fts_search_join",
+            RelAlgebra::LshSearch(_) => "lsh_search_join",
             RelAlgebra::StoredWithValidity(_) => {
                 let join_indices = self
                     .joiner
@@ -2113,7 +2218,8 @@ impl InnerJoin {
             | RelAlgebra::Filter(_)
             | RelAlgebra::Unification(_)
             | RelAlgebra::HnswSearch(_)
-            | RelAlgebra::FtsSearch(_) => {
+            | RelAlgebra::FtsSearch(_)
+            | RelAlgebra::LshSearch(_) => {
                 self.materialized_join(tx, eliminate_indices, delta_rule, stores)
             }
             RelAlgebra::Reorder(_) => {

@@ -88,6 +88,14 @@ impl NormalFormInlineRule {
                         pending.push(NormalFormAtom::FtsSearch(s));
                     }
                 }
+                NormalFormAtom::LshSearch(s) => {
+                    if seen_variables.contains(&s.query) {
+                        seen_variables.extend(s.all_bindings().cloned());
+                        round_1_collected.push(NormalFormAtom::LshSearch(s));
+                    } else {
+                        pending.push(NormalFormAtom::LshSearch(s));
+                    }
+                }
             }
         }
@@ -124,6 +132,10 @@ impl NormalFormInlineRule {
                     seen_variables.extend(s.all_bindings().cloned());
                     collected.push(NormalFormAtom::FtsSearch(s));
                 }
+                NormalFormAtom::LshSearch(s) => {
+                    seen_variables.extend(s.all_bindings().cloned());
+                    collected.push(NormalFormAtom::LshSearch(s));
+                }
             }
         }
 
         for atom in last_pending.iter() {
             match atom {
@@ -158,6 +170,14 @@ impl NormalFormInlineRule {
                         pending.push(NormalFormAtom::FtsSearch(s.clone()));
                     }
                 }
+                NormalFormAtom::LshSearch(s) => {
+                    if seen_variables.contains(&s.query) {
+                        seen_variables.extend(s.all_bindings().cloned());
+                        collected.push(NormalFormAtom::LshSearch(s.clone()));
+                    } else {
+                        pending.push(NormalFormAtom::LshSearch(s.clone()));
+                    }
+                }
                 NormalFormAtom::Predicate(p) => {
                     if p.bindings()?.is_subset(&seen_variables) {
                         collected.push(NormalFormAtom::Predicate(p.clone()));
@@ -206,6 +226,9 @@ impl NormalFormInlineRule {
                 NormalFormAtom::FtsSearch(s) => {
                     bail!(UnboundVariable(s.span))
                 }
+                NormalFormAtom::LshSearch(s) => {
+                    bail!(UnboundVariable(s.span))
+                }
             }
         }
     }

@@ -31,7 +31,8 @@ impl NormalFormAtom {
             | NormalFormAtom::Predicate(_)
             | NormalFormAtom::Unification(_)
             | NormalFormAtom::HnswSearch(_)
-            | NormalFormAtom::FtsSearch(_) => Default::default(),
+            | NormalFormAtom::FtsSearch(_)
+            | NormalFormAtom::LshSearch(_) => Default::default(),
             NormalFormAtom::Rule(r) => BTreeMap::from([(&r.name, false)]),
             NormalFormAtom::NegatedRule(r) => BTreeMap::from([(&r.name, true)]),
         }

@@ -43,7 +43,7 @@ use crate::fts::TokenizerCache;
 use crate::parse::sys::SysOp;
 use crate::parse::{parse_script, CozoScript, SourceSpan};
 use crate::query::compile::{CompiledProgram, CompiledRule, CompiledRuleSet};
-use crate::query::ra::{FilteredRA, FtsSearchRA, HnswSearchRA, InnerJoin, NegJoin, RelAlgebra, ReorderRA, StoredRA, StoredWithValidityRA, TempStoreRA, UnificationRA};
+use crate::query::ra::{FilteredRA, FtsSearchRA, HnswSearchRA, InnerJoin, LshSearchRA, NegJoin, RelAlgebra, ReorderRA, StoredRA, StoredWithValidityRA, TempStoreRA, UnificationRA};
 #[allow(unused_imports)]
 use crate::runtime::callback::{
     CallbackCollector, CallbackDeclaration, CallbackOp, EventCallbackRegistry,
@@ -1082,6 +1082,18 @@ impl<'s, S: Storage<'s>> Db<S> {
                         .map(|f| f.to_string())
                         .collect_vec()),
                 ),
+                RelAlgebra::LshSearch(LshSearchRA { lsh_search, .. }) => (
+                    "lsh_index",
+                    json!(format!(":{}", lsh_search.query.name)),
+                    json!(lsh_search.query.name),
+                    json!(lsh_search
+                        .filter
+                        .iter()
+                        .map(|f| f.to_string())
+                        .collect_vec()),
+                ),
             };
             ret_for_relation.push(json!({
                 STRATUM: stratum,
@@ -1217,6 +1229,20 @@ impl<'s, S: Storage<'s>> Db<S> {
                     vec![vec![DataValue::from(OK_STR)]],
                 ))
             }
+            SysOp::CreateMinHashLshIndex(config) => {
+                let lock = self
+                    .obtain_relation_locks(iter::once(&config.base_relation))
+                    .pop()
+                    .unwrap();
+                let _guard = lock.write().unwrap();
+                let mut tx = self.transact_write()?;
+                tx.create_minhash_lsh_index(config)?;
+                tx.commit_tx()?;
+                Ok(NamedRows::new(
+                    vec![STATUS_STR.to_string()],
+                    vec![vec![DataValue::from(OK_STR)]],
+                ))
+            }
             SysOp::RemoveIndex(rel_name, idx_name) => {
                 let lock = self
                     .obtain_relation_locks(iter::once(&rel_name.name))

@@ -0,0 +1,374 @@
/*
 * Copyright 2023, The Cozo Project Authors.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
 * If a copy of the MPL was not distributed with this file,
 * You can obtain one at https://mozilla.org/MPL/2.0/.
 */

// Some ideas are from https://github.com/schelterlabs/rust-minhash

use crate::data::expr::{eval_bytecode, eval_bytecode_pred, Bytecode};
use crate::data::tuple::Tuple;
use crate::fts::tokenizer::TextAnalyzer;
use crate::fts::TokenizerConfig;
use crate::runtime::relation::RelationHandle;
use crate::runtime::transact::SessionTx;
use crate::{DataValue, Expr, SourceSpan, Symbol};
use miette::{bail, miette, Result};
use quadrature::integrate;
use rand::{thread_rng, RngCore};
use rustc_hash::FxHashMap;
use smartstring::{LazyCompact, SmartString};
use std::cmp::min;
use std::hash::{Hash, Hasher};
use twox_hash::XxHash32;

impl<'a> SessionTx<'a> {
    pub(crate) fn del_lsh_index_item(
        &mut self,
        tuple: &[DataValue],
        bytes: Option<Vec<u8>>,
        idx_handle: &RelationHandle,
        inv_idx_handle: &RelationHandle,
        manifest: &MinHashLshIndexManifest,
    ) -> Result<()> {
        let bytes = match bytes {
            None => {
                if let Some(mut found) = inv_idx_handle.get_val_only(self, tuple)? {
                    let inv_key = inv_idx_handle.encode_key_for_store(tuple, Default::default())?;
                    self.store_tx.del(&inv_key)?;
                    match found.pop() {
                        Some(DataValue::Bytes(b)) => b,
                        _ => unreachable!(),
                    }
                } else {
                    return Ok(());
                }
            }
            Some(b) => b,
        };
        let mut key = Vec::with_capacity(bytes.len() + 2);
        key.push(DataValue::Bot);
        key.push(DataValue::Bot);
        key.extend_from_slice(tuple);
        for (i, chunk) in bytes
            .chunks_exact(manifest.r * std::mem::size_of::<u32>())
            .enumerate()
        {
            key[0] = DataValue::from(i as i64);
            key[1] = DataValue::Bytes(chunk.to_vec());
            let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?;
            self.store_tx.del(&key_bytes)?;
        }
        Ok(())
    }

    pub(crate) fn put_lsh_index_item(
        &mut self,
        tuple: &[DataValue],
        extractor: &[Bytecode],
        stack: &mut Vec<DataValue>,
        tokenizer: &TextAnalyzer,
        rel_handle: &RelationHandle,
        idx_handle: &RelationHandle,
        inv_idx_handle: &RelationHandle,
        manifest: &MinHashLshIndexManifest,
        hash_perms: &HashPermutations,
    ) -> Result<()> {
        if let Some(mut found) =
            inv_idx_handle.get_val_only(self, &tuple[..rel_handle.metadata.keys.len()])?
        {
            let bytes = match found.pop() {
                Some(DataValue::Bytes(b)) => b,
                _ => unreachable!(),
            };
            self.del_lsh_index_item(tuple, Some(bytes), idx_handle, inv_idx_handle, manifest)?;
        }
        let to_index = eval_bytecode(extractor, tuple, stack)?;
        let min_hash = match to_index {
            DataValue::Null => return Ok(()),
            DataValue::List(l) => HashValues::new(l.iter(), hash_perms),
            DataValue::Str(s) => {
                let n_grams = tokenizer.unique_ngrams(&s, manifest.n_gram);
                HashValues::new(n_grams.iter(), hash_perms)
            }
            _ => bail!("Cannot put value {:?} into a LSH index", to_index),
        };
        let bytes = min_hash.get_bytes();
        let inv_key_part = &tuple[..rel_handle.metadata.keys.len()];
        let inv_val_part = vec![DataValue::Bytes(bytes.to_vec())];
        let inv_key = inv_idx_handle.encode_key_for_store(inv_key_part, Default::default())?;
        let inv_val =
            inv_idx_handle.encode_val_only_for_store(&inv_val_part, Default::default())?;
        self.store_tx.put(&inv_key, &inv_val)?;
        let mut key = Vec::with_capacity(bytes.len() + 2);
        key.push(DataValue::Bot);
        key.push(DataValue::Bot);
        key.extend_from_slice(inv_key_part);
        let chunk_size = manifest.r * std::mem::size_of::<u32>();
        for i in 0..manifest.b {
            let byte_range = &bytes[i * chunk_size..(i + 1) * chunk_size];
            key[0] = DataValue::from(i as i64);
            key[1] = DataValue::Bytes(byte_range.to_vec());
            let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?;
            self.store_tx.put(&key_bytes, &[])?;
        }
        Ok(())
    }

    pub(crate) fn lsh_search(
        &self,
        tuple: &[DataValue],
        config: &LshSearch,
        stack: &mut Vec<DataValue>,
        filter_code: &Option<(Vec<Bytecode>, SourceSpan)>,
    ) -> Result<Vec<Tuple>> {
        let bytes = if let Some(mut found) = config
            .inv_idx_handle
            .get_val_only(self, &tuple[..config.base_handle.metadata.keys.len()])?
        {
            match found.pop() {
                Some(DataValue::Bytes(b)) => b,
                _ => unreachable!(),
            }
        } else {
            return Ok(vec![]);
        };
        let chunk_size = config.manifest.r * std::mem::size_of::<u32>();
        let mut key_prefix = Vec::with_capacity(2);
        let mut found_tuples: FxHashMap<_, usize> = FxHashMap::default();
        for (i, chunk) in bytes.chunks_exact(chunk_size).enumerate() {
            key_prefix.clear();
            key_prefix.push(DataValue::from(i as i64));
            key_prefix.push(DataValue::Bytes(chunk.to_vec()));
            for ks in config.idx_handle.scan_prefix(self, &key_prefix) {
                let ks = ks?;
                let key_part = &ks[2..];
                if key_part != tuple {
                    let found = found_tuples.entry(key_part.to_vec()).or_default();
                    *found += 1;
                }
            }
        }
        let mut ret = vec![];
        for (key, count) in found_tuples {
            let similarity = count as f64 / config.manifest.r as f64;
            if similarity < config.min_similarity {
                continue;
            }
            let mut orig_tuple = config
                .base_handle
                .get(self, &key)?
                .ok_or_else(|| miette!("Tuple not found in base LSH relation"))?;
            if let Some((filter_code, span)) = filter_code {
                if !eval_bytecode_pred(filter_code, &orig_tuple, stack, *span)? {
                    continue;
                }
            }
            if config.bind_similarity.is_some() {
                orig_tuple.push(DataValue::from(similarity));
            }
            ret.push(orig_tuple);
            if let Some(k) = config.k {
                if ret.len() >= k {
                    break;
                }
            }
        }
        Ok(ret)
    }
}
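
// A note on the banding scheme above: `put_lsh_index_item` writes one index
// row per band, keyed by (band index, band bytes, primary key), where the
// band bytes are `r` consecutive 32-bit MinHash values. `lsh_search` then
// looks up the stored signature of the query tuple, scans each band prefix,
// and counts per candidate how many bands collide. Under the standard MinHash
// analysis, two sets with Jaccard similarity s collide in at least one of the
// b bands with probability 1 - (1 - s^r)^b, the S-curve that
// `find_optimal_params` below tunes b and r against.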
#[derive(Clone, Debug)]
pub(crate) struct LshSearch {
    pub(crate) base_handle: RelationHandle,
    pub(crate) idx_handle: RelationHandle,
    pub(crate) inv_idx_handle: RelationHandle,
    pub(crate) manifest: MinHashLshIndexManifest,
    pub(crate) bindings: Vec<Symbol>,
    pub(crate) k: Option<usize>,
    pub(crate) query: Symbol,
    pub(crate) bind_similarity: Option<Symbol>,
    pub(crate) min_similarity: f64,
    // pub(crate) lax_mode: bool,
    pub(crate) filter: Option<Expr>,
    pub(crate) span: SourceSpan,
}

impl LshSearch {
    pub(crate) fn all_bindings(&self) -> impl Iterator<Item = &Symbol> {
        self.bindings.iter().chain(self.bind_similarity.iter())
    }
}

pub(crate) struct HashValues(pub(crate) Vec<u32>);

pub(crate) struct HashPermutations(pub(crate) Vec<u32>);

#[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)]
pub(crate) struct MinHashLshIndexManifest {
    pub(crate) base_relation: SmartString<LazyCompact>,
    pub(crate) index_name: SmartString<LazyCompact>,
    pub(crate) extractor: String,
    pub(crate) n_gram: usize,
    pub(crate) tokenizer: TokenizerConfig,
    pub(crate) filters: Vec<TokenizerConfig>,
    pub(crate) num_perm: usize,
    pub(crate) b: usize,
    pub(crate) r: usize,
    pub(crate) threshold: f64,
    pub(crate) perms: Vec<u8>,
}

impl MinHashLshIndexManifest {
    pub(crate) fn get_hash_perms(&self) -> HashPermutations {
        HashPermutations::from_bytes(&self.perms)
    }
}

#[derive(Clone, Debug)]
pub(crate) struct LshParams {
    pub b: usize,
    pub r: usize,
}

#[derive(Clone)]
pub(crate) struct Weights(pub(crate) f64, pub(crate) f64);

const _ALLOWED_INTEGRATE_ERR: f64 = 0.001;

// code is mostly from https://github.com/schelterlabs/rust-minhash/blob/81ea3fec24fd888a330a71b6932623643346b591/src/minhash_lsh.rs
impl LshParams {
    pub fn find_optimal_params(threshold: f64, num_perm: usize, weights: &Weights) -> LshParams {
        let Weights(false_positive_weight, false_negative_weight) = weights;
        let mut min_error = f64::INFINITY;
        let mut opt = LshParams { b: 0, r: 0 };
        for b in 1..num_perm + 1 {
            let max_r = num_perm / b;
            for r in 1..max_r + 1 {
                let false_pos = LshParams::false_positive_probability(threshold, b, r);
                let false_neg = LshParams::false_negative_probability(threshold, b, r);
                let error = false_pos * false_positive_weight + false_neg * false_negative_weight;
                if error < min_error {
                    min_error = error;
                    opt = LshParams { b, r };
                }
            }
        }
        opt
    }

    fn false_positive_probability(threshold: f64, b: usize, r: usize) -> f64 {
        let _probability = |s| -> f64 { 1. - f64::powf(1. - f64::powi(s, r as i32), b as f64) };
        integrate(_probability, 0.0, threshold, _ALLOWED_INTEGRATE_ERR).integral
    }

    fn false_negative_probability(threshold: f64, b: usize, r: usize) -> f64 {
        let _probability =
            |s| -> f64 { 1. - (1. - f64::powf(1. - f64::powi(s, r as i32), b as f64)) };
        integrate(_probability, threshold, 1.0, _ALLOWED_INTEGRATE_ERR).integral
    }
}
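
// The optimization above is exhaustive over all (b, r) with b * r <= num_perm:
// a pair with true similarity s becomes a candidate with probability
// p(s) = 1 - (1 - s^r)^b, so the false-positive mass is the integral of p(s)
// for s in [0, threshold], and the false-negative mass is the integral of
// 1 - p(s) for s in [threshold, 1]. The weighted sum of the two integrals is
// what gets minimized.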
impl HashPermutations {
    pub(crate) fn new(n_perms: usize) -> Self {
        let mut rng = thread_rng();
        let mut perms = Vec::with_capacity(n_perms);
        for _ in 0..n_perms {
            perms.push(rng.next_u32());
        }
        Self(perms)
    }

    pub(crate) fn as_bytes(&self) -> &[u8] {
        unsafe {
            std::slice::from_raw_parts(
                self.0.as_ptr() as *const u8,
                self.0.len() * std::mem::size_of::<u32>(),
            )
        }
    }

    // this is the inverse of `as_bytes`
    pub(crate) fn from_bytes(bytes: &[u8]) -> Self {
        unsafe {
            let ptr = bytes.as_ptr() as *const u32;
            let len = bytes.len() / std::mem::size_of::<u32>();
            let perms = std::slice::from_raw_parts(ptr, len);
            Self(perms.to_vec())
        }
    }
}

impl HashValues {
    pub(crate) fn new<T: Hash>(values: impl Iterator<Item = T>, perms: &HashPermutations) -> Self {
        let mut ret = Self::init(perms);
        ret.update(values, perms);
        ret
    }

    pub(crate) fn init(perms: &HashPermutations) -> Self {
        Self(vec![u32::MAX; perms.0.len()])
    }

    pub(crate) fn update<T: Hash>(
        &mut self,
        values: impl Iterator<Item = T>,
        perms: &HashPermutations,
    ) {
        for v in values {
            for (i, seed) in perms.0.iter().enumerate() {
                let mut hasher = XxHash32::with_seed(*seed);
                v.hash(&mut hasher);
                let hash = hasher.finish() as u32;
                self.0[i] = min(self.0[i], hash);
            }
        }
    }

    #[cfg(test)]
    pub(crate) fn jaccard(&self, other_minhash: &Self) -> f32 {
        // `zip_eq` comes from itertools; the import is scoped here since this
        // method is only compiled for tests.
        use itertools::Itertools;

        let matches = self
            .0
            .iter()
            .zip_eq(&other_minhash.0)
            .filter(|(left, right)| left == right)
            .count();
        matches as f32 / self.0.len() as f32
    }

    pub(crate) fn get_bytes(&self) -> &[u8] {
        unsafe {
            std::slice::from_raw_parts(
                self.0.as_ptr() as *const u8,
                self.0.len() * std::mem::size_of::<u32>(),
            )
        }
    }

    // pub(crate) fn get_byte_chunks(&self, n_chunks: usize) -> impl Iterator<Item = &[u8]> {
    //     let chunk_size = self.0.len() * std::mem::size_of::<u32>() / n_chunks;
    //     self.get_bytes().chunks_exact(chunk_size)
    // }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_minhash() {
        let perms = HashPermutations::new(20000);
        let mut m1 = HashValues::new([1, 2, 3, 4, 5, 6].iter(), &perms);
        let mut m2 = HashValues::new([4, 3, 2, 1, 5, 6].iter(), &perms);
        assert_eq!(m1.0, m2.0);
        // println!("{:?}", &m1.0);
        // println!("{:?}", &m2.0);
        assert_eq!(m1.jaccard(&m2), 1.0);
        m1.update([7, 8, 9].iter(), &perms);
        assert!(m1.jaccard(&m2) < 1.0);
        println!("{:?}", m1.jaccard(&m2));
        m2.update([17, 18, 19].iter(), &perms);
        assert!(m1.jaccard(&m2) < 1.0);
        println!("{:?}", m1.jaccard(&m2));
        // println!("{:?}", m2.get_byte_chunks(2).collect_vec());
        assert_eq!(perms.0, HashPermutations::from_bytes(perms.as_bytes()).0);
    }
}
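One caveat worth noting on the unsafe `as_bytes`/`from_bytes` pair: the raw `u32`-to-byte reinterpretation makes the persisted permutations and signatures native-endian, so index data written on a little-endian machine would not round-trip on a big-endian one. That is presumably acceptable here, since the bytes never leave the database files of a single deployment.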

@@ -13,5 +13,6 @@ pub(crate) mod relation;
 pub(crate) mod temp_store;
 pub(crate) mod transact;
 pub(crate) mod hnsw;
+pub(crate) mod minhash_lsh;
 
 #[cfg(test)]
 mod tests;

@@ -26,10 +26,11 @@ use crate::data::tuple::{decode_tuple_from_key, Tuple, TupleT, ENCODED_KEY_MIN_LEN};
 use crate::data::value::{DataValue, ValidityTs};
 use crate::fts::FtsIndexManifest;
 use crate::parse::expr::build_expr;
-use crate::parse::sys::{FtsIndexConfig, HnswIndexConfig};
+use crate::parse::sys::{FtsIndexConfig, HnswIndexConfig, MinHashLshConfig};
 use crate::parse::{CozoScriptParser, Rule, SourceSpan};
 use crate::query::compile::IndexPositionUse;
 use crate::runtime::hnsw::HnswIndexManifest;
+use crate::runtime::minhash_lsh::{HashPermutations, LshParams, MinHashLshIndexManifest, Weights};
 use crate::runtime::transact::SessionTx;
 use crate::{NamedRows, StoreTx};
 
@@ -83,6 +84,10 @@ pub(crate) struct RelationHandle {
     pub(crate) hnsw_indices:
         BTreeMap<SmartString<LazyCompact>, (RelationHandle, HnswIndexManifest)>,
     pub(crate) fts_indices: BTreeMap<SmartString<LazyCompact>, (RelationHandle, FtsIndexManifest)>,
+    pub(crate) lsh_indices: BTreeMap<
+        SmartString<LazyCompact>,
+        (RelationHandle, RelationHandle, MinHashLshIndexManifest),
+    >,
 }
 
 impl RelationHandle {
@@ -90,9 +95,13 @@ impl RelationHandle {
         self.indices.contains_key(index_name)
             || self.hnsw_indices.contains_key(index_name)
             || self.fts_indices.contains_key(index_name)
+            || self.lsh_indices.contains_key(index_name)
     }
     pub(crate) fn has_no_index(&self) -> bool {
-        self.indices.is_empty() && self.hnsw_indices.is_empty() && self.fts_indices.is_empty()
+        self.indices.is_empty()
+            && self.hnsw_indices.is_empty()
+            && self.fts_indices.is_empty()
+            && self.lsh_indices.is_empty()
     }
 }
@@ -254,10 +263,7 @@ impl RelationHandle {
         }
         Ok(ret)
     }
-    pub(crate) fn encode_partial_key_for_store(
-        &self,
-        tuple: &[DataValue],
-    ) -> Vec<u8> {
+    pub(crate) fn encode_partial_key_for_store(&self, tuple: &[DataValue]) -> Vec<u8> {
         let mut ret = self.encode_key_prefix(tuple.len());
         for val in tuple {
             ret.encode_datavalue(val);
@@ -392,6 +398,25 @@ impl RelationHandle {
         }
     }
 
+    pub(crate) fn get_val_only(
+        &self,
+        tx: &SessionTx<'_>,
+        key: &[DataValue],
+    ) -> Result<Option<Tuple>> {
+        let key_data = key.encode_as_key(self.id);
+        if self.is_temp {
+            Ok(tx
+                .temp_store_tx
+                .get(&key_data, false)?
+                .map(|val_data| rmp_serde::from_slice(&val_data[ENCODED_KEY_MIN_LEN..]).unwrap()))
+        } else {
+            Ok(tx
+                .store_tx
+                .get(&key_data, false)?
+                .map(|val_data| rmp_serde::from_slice(&val_data[ENCODED_KEY_MIN_LEN..]).unwrap()))
+        }
+    }
+
     pub(crate) fn exists(&self, tx: &SessionTx<'_>, key: &[DataValue]) -> Result<bool> {
         let key_data = key.encode_as_key(self.id);
         if self.is_temp {
@@ -594,6 +619,7 @@ impl<'a> SessionTx<'a> {
             indices: Default::default(),
             hnsw_indices: Default::default(),
             fts_indices: Default::default(),
+            lsh_indices: Default::default(),
         };
 
         let name_key = vec![DataValue::Str(meta.name.clone())].encode_as_key(RelationId::SYSTEM);
@@ -694,6 +720,141 @@ impl<'a> SessionTx<'a> {
         Ok(())
     }
 
+    pub(crate) fn create_minhash_lsh_index(&mut self, config: MinHashLshConfig) -> Result<()> {
+        // Get relation handle
+        let mut rel_handle = self.get_relation(&config.base_relation, true)?;
+
+        // Check if index already exists
+        if rel_handle.has_index(&config.index_name) {
+            bail!(IndexAlreadyExists(
+                config.index_name.to_string(),
+                config.index_name.to_string()
+            ));
+        }
+
+        let inv_idx_keys = rel_handle.metadata.keys.clone();
+        let inv_idx_vals = vec![ColumnDef {
+            name: SmartString::from("minhash"),
+            typing: NullableColType {
+                coltype: ColType::Bytes,
+                nullable: false,
+            },
+            default_gen: None,
+        }];
+        let mut idx_keys = vec![
+            ColumnDef {
+                name: SmartString::from("perm"),
+                typing: NullableColType {
+                    coltype: ColType::Int,
+                    nullable: false,
+                },
+                default_gen: None,
+            },
+            ColumnDef {
+                name: SmartString::from("hash"),
+                typing: NullableColType {
+                    coltype: ColType::Bytes,
+                    nullable: false,
+                },
+                default_gen: None,
+            },
+        ];
+        for k in rel_handle.metadata.keys.iter() {
+            idx_keys.push(ColumnDef {
+                name: format!("src_{}", k.name).into(),
+                typing: k.typing.clone(),
+                default_gen: None,
+            });
+        }
+        let idx_vals = vec![];
+
+        let idx_handle = self.write_idx_relation(
+            &config.base_relation,
+            &config.index_name,
+            idx_keys,
+            idx_vals,
+        )?;
+        // The inverse index lives in a companion relation; per the drop path
+        // below, it is addressed with an extra `:inv` segment.
+        let inv_idx_handle = self.write_idx_relation(
+            &config.base_relation,
+            &SmartString::from(format!("{}:inv", config.index_name)),
+            inv_idx_keys,
+            inv_idx_vals,
+        )?;
+
+        // add index to relation
+        let params = LshParams::find_optimal_params(
+            config.target_threshold.0,
+            config.n_perm,
+            &Weights(
+                config.false_positive_weight.0,
+                config.false_negative_weight.0,
+            ),
+        );
+        let perms = HashPermutations::new(config.n_perm);
+        let manifest = MinHashLshIndexManifest {
+            base_relation: config.base_relation,
+            index_name: config.index_name,
+            extractor: config.extractor,
+            n_gram: config.n_gram,
+            tokenizer: config.tokenizer,
+            filters: config.filters,
+            num_perm: config.n_perm,
+            b: params.b,
+            r: params.r,
+            threshold: config.target_threshold.0,
+            perms: perms.as_bytes().to_vec(),
+        };
+
+        // populate index
+        let tokenizer =
+            self.tokenizers
+                .get(&idx_handle.name, &manifest.tokenizer, &manifest.filters)?;
+        let parsed = CozoScriptParser::parse(Rule::expr, &manifest.extractor)
+            .into_diagnostic()?
+            .next()
+            .unwrap();
+        let mut code_expr = build_expr(parsed, &Default::default())?;
+        let binding_map = rel_handle.raw_binding_map();
+        code_expr.fill_binding_indices(&binding_map)?;
+        let extractor = code_expr.compile()?;
+        let mut stack = vec![];
+
+        let existing: Vec<_> = rel_handle.scan_all(self).try_collect()?;
+        let hash_perms = manifest.get_hash_perms();
+        for tuple in existing {
+            self.put_lsh_index_item(
+                &tuple,
+                &extractor,
+                &mut stack,
+                &tokenizer,
+                &rel_handle,
+                &idx_handle,
+                &inv_idx_handle,
+                &manifest,
+                &hash_perms,
+            )?;
+        }
+
+        rel_handle.lsh_indices.insert(
+            manifest.index_name.clone(),
+            (idx_handle, inv_idx_handle, manifest),
+        );
+
+        // update relation metadata
+        let new_encoded =
+            vec![DataValue::from(&rel_handle.name as &str)].encode_as_key(RelationId::SYSTEM);
+        let mut meta_val = vec![];
+        rel_handle
+            .serialize(&mut Serializer::new(&mut meta_val))
+            .unwrap();
+        self.store_tx.put(&new_encoded, &meta_val)?;
+
+        Ok(())
+    }
+
     pub(crate) fn create_fts_index(&mut self, config: FtsIndexConfig) -> Result<()> {
         // Get relation handle
         let mut rel_handle = self.get_relation(&config.base_relation, true)?;
@@ -748,7 +909,7 @@ impl<'a> SessionTx<'a> {
             },
             ColumnDef {
                 name: SmartString::from("position"),
-                typing: col_type.clone(),
+                typing: col_type,
                 default_gen: None,
             },
             ColumnDef {
@@ -1192,8 +1353,10 @@ impl<'a> SessionTx<'a> {
         idx_name: &Symbol,
     ) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
         let mut rel = self.get_relation(rel_name, true)?;
+        let is_lsh = rel.lsh_indices.contains_key(&idx_name.name);
         if rel.indices.remove(&idx_name.name).is_none()
             && rel.hnsw_indices.remove(&idx_name.name).is_none()
+            && rel.lsh_indices.remove(&idx_name.name).is_none()
         {
             #[derive(Debug, Error, Diagnostic)]
             #[error("index {0} for relation {1} not found")]
@@ -1203,7 +1366,13 @@ impl<'a> SessionTx<'a> {
             bail!(IndexNotFound(idx_name.to_string(), rel_name.to_string()));
         }
 
-        let to_clean = self.destroy_relation(&format!("{}:{}", rel_name.name, idx_name.name))?;
+        let mut to_clean =
+            self.destroy_relation(&format!("{}:{}", rel_name.name, idx_name.name))?;
+        if is_lsh {
+            to_clean.extend(
+                self.destroy_relation(&format!("{}:{}:inv", rel_name.name, idx_name.name))?,
+            );
+        }
 
         let new_encoded =
             vec![DataValue::from(&rel_name.name as &str)].encode_as_key(RelationId::SYSTEM);

@@ -62,6 +62,13 @@ pub trait StoreTx<'s>: Sync {
     /// the key has not been modified outside the transaction.
     fn get(&self, key: &[u8], for_update: bool) -> Result<Option<Vec<u8>>>;
 
+    /// Get multiple keys. If `for_update` is `true` (only possible in a write transaction),
+    /// then the database needs to guarantee that `commit()` can only succeed if
+    /// the keys have not been modified outside the transaction.
+    fn multi_get(&self, keys: &[Vec<u8>], for_update: bool) -> Result<Vec<Option<Vec<u8>>>> {
+        keys.iter().map(|k| self.get(k, for_update)).collect()
+    }
+
     /// Put a key-value pair into the storage. In case of existing key,
     /// the storage engine needs to overwrite the old value.
     fn put(&mut self, key: &[u8], val: &[u8]) -> Result<()>;
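The default `multi_get` is just a loop over `get`, which keeps the trait addition non-breaking for existing engines; backends with a native batch-read primitive can override it. Usage is straightforward (here `tx` stands for any `StoreTx` implementor and the keys are illustrative):

// Fetch several encoded keys in one call; absent keys come back as None.
let vals = tx.multi_get(&[key_a, key_b], false)?;
assert_eq!(vals.len(), 2);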
