From 34e3b52216d3da0d1dc6d9c699aa486216e3d9f7 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Sun, 30 Apr 2023 11:35:43 +0800 Subject: [PATCH] make LSH work --- cozo-core/src/data/program.rs | 56 ++++++++-------------- cozo-core/src/fts/mod.rs | 2 +- cozo-core/src/parse/sys.rs | 2 +- cozo-core/src/runtime/minhash_lsh.rs | 49 +++++++++----------- cozo-core/src/runtime/relation.rs | 18 ++++---- cozo-core/src/runtime/tests.rs | 69 ++++++++++++++++++++++++++++ 6 files changed, 118 insertions(+), 78 deletions(-) diff --git a/cozo-core/src/data/program.rs b/cozo-core/src/data/program.rs index 493f3eb1..c998762b 100644 --- a/cozo-core/src/data/program.rs +++ b/cozo-core/src/data/program.rs @@ -1109,43 +1109,17 @@ impl SearchInput { let filter = self.parameters.remove("filter"); - let bind_similarity = match self.parameters.remove("bind_similarity") { - None => None, - Some(Expr::Binding { var, .. }) => Some(var), - Some(expr) => { - let span = expr.span(); - let kw = gen.next(span); - let unif = NormalFormAtom::Unification(Unification { - binding: kw.clone(), - expr, - one_many_unif: false, - span, - }); - conj.push(unif); - Some(kw) - } - }; - - let min_similarity = match self.parameters.remove("min_similarity") { - None => manifest.threshold, - Some(expr) => { - let min_similarity = expr.eval_to_const()?; - let min_similarity = min_similarity - .get_float() - .ok_or(ExpectedFloatForMinSimilarity(self.span))?; + #[derive(Debug, Error, Diagnostic)] + #[error("Extra parameters for LSH search: {0:?}")] + #[diagnostic(code(parser::extra_parameters_for_lsh_search))] + struct ExtraParametersForLshSearch(Vec, #[label] SourceSpan); - #[derive(Debug, Error, Diagnostic)] - #[error("Expected float for `min_similarity`")] - #[diagnostic(code(parser::expected_float_for_min_similarity))] - struct ExpectedFloatForMinSimilarity(#[label] SourceSpan); - - ensure!( - (0.0..=1.0).contains(&min_similarity), - ExpectedFloatForMinSimilarity(self.span) - ); - min_similarity - } - }; + if !self.parameters.is_empty() { + bail!(ExtraParametersForLshSearch( + self.parameters.keys().map(|s| s.to_string()).collect(), + self.span + )); + } conj.push(NormalFormAtom::LshSearch(LshSearch { base_handle, @@ -1153,10 +1127,8 @@ impl SearchInput { manifest, bindings, k, - bind_similarity, query, span: self.span, - min_similarity, filter, })); @@ -1303,6 +1275,10 @@ impl SearchInput { } }; + if !self.parameters.is_empty() { + bail!("Unknown parameters for FTS: {:?}", self.parameters.keys()); + } + conj.push(NormalFormAtom::FtsSearch(FtsSearch { base_handle, idx_handle, @@ -1527,6 +1503,10 @@ impl SearchInput { } }; + if !self.parameters.is_empty() { + bail!("Unexpected parameters for HNSW: {:?}", self.parameters); + } + conj.push(NormalFormAtom::HnswSearch(HnswSearch { base_handle, idx_handle, diff --git a/cozo-core/src/fts/mod.rs b/cozo-core/src/fts/mod.rs index f52ca08b..1dd031c3 100644 --- a/cozo-core/src/fts/mod.rs +++ b/cozo-core/src/fts/mod.rs @@ -82,7 +82,7 @@ impl TokenizerConfig { let min_gram = self .args .get(0) - .ok_or_else(|| miette!("Missing first argument `min_gram`"))? + .unwrap_or(&DataValue::from(1)) .get_int() .ok_or_else(|| miette!("First argument `min_gram` must be an integer"))?; let max_gram = self diff --git a/cozo-core/src/parse/sys.rs b/cozo-core/src/parse/sys.rs index 0cbaee45..fda2e2d6 100644 --- a/cozo-core/src/parse/sys.rs +++ b/cozo-core/src/parse/sys.rs @@ -332,7 +332,7 @@ pub(crate) fn parse_sys( _ => bail!("Filters must be a list of filters"), } } - _ => bail!("Unknown option {} for FTS index", opt_name.as_str()), + _ => bail!("Unknown option {} for LSH index", opt_name.as_str()), } } ensure!( diff --git a/cozo-core/src/runtime/minhash_lsh.rs b/cozo-core/src/runtime/minhash_lsh.rs index a96fb9c4..5a27360a 100644 --- a/cozo-core/src/runtime/minhash_lsh.rs +++ b/cozo-core/src/runtime/minhash_lsh.rs @@ -19,7 +19,7 @@ use itertools::Itertools; use miette::{bail, miette, Result}; use quadrature::integrate; use rand::{thread_rng, RngCore}; -use rustc_hash::FxHashMap; +use rustc_hash::FxHashSet; use smartstring::{LazyCompact, SmartString}; use std::cmp::min; use std::hash::{Hash, Hasher}; @@ -105,11 +105,13 @@ impl<'a> SessionTx<'a> { let bytes = min_hash.get_bytes(); let chunk_size = manifest.r * std::mem::size_of::(); - let chunks = (0..manifest.b).map(|i| { - let mut byte_range = bytes[i * chunk_size..(i + 1) * chunk_size].to_vec(); - byte_range.extend_from_slice(&(i as u16).to_le_bytes()); - byte_range - }).collect_vec(); + let chunks = (0..manifest.b) + .map(|i| { + let mut byte_range = bytes[i * chunk_size..(i + 1) * chunk_size].to_vec(); + byte_range.extend_from_slice(&(i as u16).to_le_bytes()); + byte_range + }) + .collect_vec(); let inv_key_part = &tuple[..rel_handle.metadata.keys.len()]; @@ -123,13 +125,14 @@ impl<'a> SessionTx<'a> { self.store_tx.put(&key_bytes, &[])?; } - let inv_val_part = vec![DataValue::List(chunks.into_iter().map(DataValue::Bytes).collect_vec())]; + let inv_val_part = vec![DataValue::List( + chunks.into_iter().map(DataValue::Bytes).collect_vec(), + )]; let inv_key = inv_idx_handle.encode_key_for_store(inv_key_part, Default::default())?; let inv_val = inv_idx_handle.encode_val_only_for_store(&inv_val_part, Default::default())?; self.store_tx.put(&inv_key, &inv_val)?; - Ok(()) } pub(crate) fn lsh_search( @@ -153,26 +156,22 @@ impl<'a> SessionTx<'a> { _ => bail!("Cannot search for value {:?} in a LSH index", q), }; let chunk_size = config.manifest.r * std::mem::size_of::(); - let mut key_prefix = Vec::with_capacity(2); - let mut found_tuples: FxHashMap<_, usize> = FxHashMap::default(); + let mut key_prefix = Vec::with_capacity(1); + let mut found_tuples: FxHashSet<_> = FxHashSet::default(); for (i, chunk) in bytes.chunks_exact(chunk_size).enumerate() { key_prefix.clear(); - key_prefix.push(DataValue::from(i as i64)); - key_prefix.push(DataValue::Bytes(chunk.to_vec())); + let mut chunk = chunk.to_vec(); + chunk.extend_from_slice(&(i as u16).to_le_bytes()); + key_prefix.push(DataValue::Bytes(chunk)); for ks in config.idx_handle.scan_prefix(self, &key_prefix) { let ks = ks?; - let key_part = &ks[2..]; - let found = found_tuples.entry(key_part.to_vec()).or_default(); - *found += 1; + let key_part = &ks[1..]; + found_tuples.insert(key_part.to_vec()); } } let mut ret = vec![]; - for (key, count) in found_tuples { - let similarity = count as f64 / config.manifest.r as f64; - if similarity < config.min_similarity { - continue; - } - let mut orig_tuple = config + for key in found_tuples { + let orig_tuple = config .base_handle .get(self, &key)? .ok_or_else(|| miette!("Tuple not found in base LSH relation"))?; @@ -181,9 +180,6 @@ impl<'a> SessionTx<'a> { continue; } } - if config.bind_similarity.is_some() { - orig_tuple.push(DataValue::from(similarity)); - } ret.push(orig_tuple); if let Some(k) = config.k { if ret.len() >= k { @@ -203,16 +199,13 @@ pub(crate) struct LshSearch { pub(crate) bindings: Vec, pub(crate) k: Option, pub(crate) query: Symbol, - pub(crate) bind_similarity: Option, - pub(crate) min_similarity: f64, - // pub(crate) lax_mode: bool, pub(crate) filter: Option, pub(crate) span: SourceSpan, } impl LshSearch { pub(crate) fn all_bindings(&self) -> impl Iterator { - self.bindings.iter().chain(self.bind_similarity.iter()) + self.bindings.iter() } } diff --git a/cozo-core/src/runtime/relation.rs b/cozo-core/src/runtime/relation.rs index 807a92aa..105ea7ba 100644 --- a/cozo-core/src/runtime/relation.rs +++ b/cozo-core/src/runtime/relation.rs @@ -742,16 +742,14 @@ impl<'a> SessionTx<'a> { default_gen: None, }]; - let mut idx_keys = vec![ - ColumnDef { - name: SmartString::from("hash"), - typing: NullableColType { - coltype: ColType::Bytes, - nullable: false, - }, - default_gen: None, + let mut idx_keys = vec![ColumnDef { + name: SmartString::from("hash"), + typing: NullableColType { + coltype: ColType::Bytes, + nullable: false, }, - ]; + default_gen: None, + }]; for k in rel_handle.metadata.keys.iter() { idx_keys.push(ColumnDef { name: format!("src_{}", k.name).into(), @@ -770,7 +768,7 @@ impl<'a> SessionTx<'a> { let inv_idx_handle = self.write_idx_relation( &config.base_relation, - &config.index_name, + &format!("{}:inv", config.index_name), inv_idx_keys, inv_idx_vals, )?; diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index baa54c5c..68f84e0e 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -948,6 +948,75 @@ fn test_fts_indexing() { } } + +#[test] +fn test_lsh_indexing() { + let db = DbInstance::new("mem", "", "").unwrap(); + db.run_script(r":create a {k: String => v: String}", Default::default()) + .unwrap(); + db.run_script( + r"?[k, v] <- [['a', 'hello world!'], ['b', 'the world is round']] :put a {k => v}", + Default::default(), + ) + .unwrap(); + db.run_script( + r"::lsh create a:lsh {extractor: v, tokenizer: NGram, n_gram: 3, target_threshold: 0.5 }", + Default::default(), + ) + .unwrap(); + db.run_script( + r"?[k, v] <- [ + ['b', 'the world is square!'], + ['c', 'see you at the end of the world!'], + ['d', 'the world is the world and makes the world go around'], + ['e', 'the world is the world and makes the world not go around'] + ] :put a {k => v}", + Default::default(), + ) + .unwrap(); + let res = db.run_script("::columns a:lsh", Default::default()).unwrap(); + for row in res.into_json()["rows"].as_array().unwrap() { + println!("{}", row); + } + let _res = db + .run_script( + r" + ?[hash, src_k] := + *a:lsh{src_k, hash} + ", + Default::default(), + ) + .unwrap(); + // for row in res.into_json()["rows"].as_array().unwrap() { + // println!("{}", row); + // } + let _res = db + .run_script( + r" + ?[k, minhash] := + *a:lsh:inv{k, minhash} + ", + Default::default(), + ) + .unwrap(); + // for row in res.into_json()["rows"].as_array().unwrap() { + // println!("{}", row); + // } + let res = db + .run_script( + r" + ?[k, v] := ~a:lsh{k, v | + query: 'see him at the end of the world', + } + ", + Default::default(), + ) + .unwrap(); + for row in res.into_json()["rows"].as_array().unwrap() { + println!("{}", row); + } +} + #[test] fn test_insertions() { let db = DbInstance::new("mem", "", "").unwrap();