make LSH work

main
Ziyang Hu 1 year ago
parent a07d4bf960
commit 34e3b52216

@ -1109,43 +1109,17 @@ impl SearchInput {
let filter = self.parameters.remove("filter"); let filter = self.parameters.remove("filter");
let bind_similarity = match self.parameters.remove("bind_similarity") { #[derive(Debug, Error, Diagnostic)]
None => None, #[error("Extra parameters for LSH search: {0:?}")]
Some(Expr::Binding { var, .. }) => Some(var), #[diagnostic(code(parser::extra_parameters_for_lsh_search))]
Some(expr) => { struct ExtraParametersForLshSearch(Vec<String>, #[label] SourceSpan);
let span = expr.span();
let kw = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: kw.clone(),
expr,
one_many_unif: false,
span,
});
conj.push(unif);
Some(kw)
}
};
let min_similarity = match self.parameters.remove("min_similarity") {
None => manifest.threshold,
Some(expr) => {
let min_similarity = expr.eval_to_const()?;
let min_similarity = min_similarity
.get_float()
.ok_or(ExpectedFloatForMinSimilarity(self.span))?;
#[derive(Debug, Error, Diagnostic)] if !self.parameters.is_empty() {
#[error("Expected float for `min_similarity`")] bail!(ExtraParametersForLshSearch(
#[diagnostic(code(parser::expected_float_for_min_similarity))] self.parameters.keys().map(|s| s.to_string()).collect(),
struct ExpectedFloatForMinSimilarity(#[label] SourceSpan); self.span
));
ensure!( }
(0.0..=1.0).contains(&min_similarity),
ExpectedFloatForMinSimilarity(self.span)
);
min_similarity
}
};
conj.push(NormalFormAtom::LshSearch(LshSearch { conj.push(NormalFormAtom::LshSearch(LshSearch {
base_handle, base_handle,
@ -1153,10 +1127,8 @@ impl SearchInput {
manifest, manifest,
bindings, bindings,
k, k,
bind_similarity,
query, query,
span: self.span, span: self.span,
min_similarity,
filter, filter,
})); }));
@ -1303,6 +1275,10 @@ impl SearchInput {
} }
}; };
if !self.parameters.is_empty() {
bail!("Unknown parameters for FTS: {:?}", self.parameters.keys());
}
conj.push(NormalFormAtom::FtsSearch(FtsSearch { conj.push(NormalFormAtom::FtsSearch(FtsSearch {
base_handle, base_handle,
idx_handle, idx_handle,
@ -1527,6 +1503,10 @@ impl SearchInput {
} }
}; };
if !self.parameters.is_empty() {
bail!("Unexpected parameters for HNSW: {:?}", self.parameters);
}
conj.push(NormalFormAtom::HnswSearch(HnswSearch { conj.push(NormalFormAtom::HnswSearch(HnswSearch {
base_handle, base_handle,
idx_handle, idx_handle,

@ -82,7 +82,7 @@ impl TokenizerConfig {
let min_gram = self let min_gram = self
.args .args
.get(0) .get(0)
.ok_or_else(|| miette!("Missing first argument `min_gram`"))? .unwrap_or(&DataValue::from(1))
.get_int() .get_int()
.ok_or_else(|| miette!("First argument `min_gram` must be an integer"))?; .ok_or_else(|| miette!("First argument `min_gram` must be an integer"))?;
let max_gram = self let max_gram = self

@ -332,7 +332,7 @@ pub(crate) fn parse_sys(
_ => bail!("Filters must be a list of filters"), _ => bail!("Filters must be a list of filters"),
} }
} }
_ => bail!("Unknown option {} for FTS index", opt_name.as_str()), _ => bail!("Unknown option {} for LSH index", opt_name.as_str()),
} }
} }
ensure!( ensure!(

@ -19,7 +19,7 @@ use itertools::Itertools;
use miette::{bail, miette, Result}; use miette::{bail, miette, Result};
use quadrature::integrate; use quadrature::integrate;
use rand::{thread_rng, RngCore}; use rand::{thread_rng, RngCore};
use rustc_hash::FxHashMap; use rustc_hash::FxHashSet;
use smartstring::{LazyCompact, SmartString}; use smartstring::{LazyCompact, SmartString};
use std::cmp::min; use std::cmp::min;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
@ -105,11 +105,13 @@ impl<'a> SessionTx<'a> {
let bytes = min_hash.get_bytes(); let bytes = min_hash.get_bytes();
let chunk_size = manifest.r * std::mem::size_of::<u32>(); let chunk_size = manifest.r * std::mem::size_of::<u32>();
let chunks = (0..manifest.b).map(|i| { let chunks = (0..manifest.b)
let mut byte_range = bytes[i * chunk_size..(i + 1) * chunk_size].to_vec(); .map(|i| {
byte_range.extend_from_slice(&(i as u16).to_le_bytes()); let mut byte_range = bytes[i * chunk_size..(i + 1) * chunk_size].to_vec();
byte_range byte_range.extend_from_slice(&(i as u16).to_le_bytes());
}).collect_vec(); byte_range
})
.collect_vec();
let inv_key_part = &tuple[..rel_handle.metadata.keys.len()]; let inv_key_part = &tuple[..rel_handle.metadata.keys.len()];
@ -123,13 +125,14 @@ impl<'a> SessionTx<'a> {
self.store_tx.put(&key_bytes, &[])?; self.store_tx.put(&key_bytes, &[])?;
} }
let inv_val_part = vec![DataValue::List(chunks.into_iter().map(DataValue::Bytes).collect_vec())]; let inv_val_part = vec![DataValue::List(
chunks.into_iter().map(DataValue::Bytes).collect_vec(),
)];
let inv_key = inv_idx_handle.encode_key_for_store(inv_key_part, Default::default())?; let inv_key = inv_idx_handle.encode_key_for_store(inv_key_part, Default::default())?;
let inv_val = let inv_val =
inv_idx_handle.encode_val_only_for_store(&inv_val_part, Default::default())?; inv_idx_handle.encode_val_only_for_store(&inv_val_part, Default::default())?;
self.store_tx.put(&inv_key, &inv_val)?; self.store_tx.put(&inv_key, &inv_val)?;
Ok(()) Ok(())
} }
pub(crate) fn lsh_search( pub(crate) fn lsh_search(
@ -153,26 +156,22 @@ impl<'a> SessionTx<'a> {
_ => bail!("Cannot search for value {:?} in a LSH index", q), _ => bail!("Cannot search for value {:?} in a LSH index", q),
}; };
let chunk_size = config.manifest.r * std::mem::size_of::<u32>(); let chunk_size = config.manifest.r * std::mem::size_of::<u32>();
let mut key_prefix = Vec::with_capacity(2); let mut key_prefix = Vec::with_capacity(1);
let mut found_tuples: FxHashMap<_, usize> = FxHashMap::default(); let mut found_tuples: FxHashSet<_> = FxHashSet::default();
for (i, chunk) in bytes.chunks_exact(chunk_size).enumerate() { for (i, chunk) in bytes.chunks_exact(chunk_size).enumerate() {
key_prefix.clear(); key_prefix.clear();
key_prefix.push(DataValue::from(i as i64)); let mut chunk = chunk.to_vec();
key_prefix.push(DataValue::Bytes(chunk.to_vec())); chunk.extend_from_slice(&(i as u16).to_le_bytes());
key_prefix.push(DataValue::Bytes(chunk));
for ks in config.idx_handle.scan_prefix(self, &key_prefix) { for ks in config.idx_handle.scan_prefix(self, &key_prefix) {
let ks = ks?; let ks = ks?;
let key_part = &ks[2..]; let key_part = &ks[1..];
let found = found_tuples.entry(key_part.to_vec()).or_default(); found_tuples.insert(key_part.to_vec());
*found += 1;
} }
} }
let mut ret = vec![]; let mut ret = vec![];
for (key, count) in found_tuples { for key in found_tuples {
let similarity = count as f64 / config.manifest.r as f64; let orig_tuple = config
if similarity < config.min_similarity {
continue;
}
let mut orig_tuple = config
.base_handle .base_handle
.get(self, &key)? .get(self, &key)?
.ok_or_else(|| miette!("Tuple not found in base LSH relation"))?; .ok_or_else(|| miette!("Tuple not found in base LSH relation"))?;
@ -181,9 +180,6 @@ impl<'a> SessionTx<'a> {
continue; continue;
} }
} }
if config.bind_similarity.is_some() {
orig_tuple.push(DataValue::from(similarity));
}
ret.push(orig_tuple); ret.push(orig_tuple);
if let Some(k) = config.k { if let Some(k) = config.k {
if ret.len() >= k { if ret.len() >= k {
@ -203,16 +199,13 @@ pub(crate) struct LshSearch {
pub(crate) bindings: Vec<Symbol>, pub(crate) bindings: Vec<Symbol>,
pub(crate) k: Option<usize>, pub(crate) k: Option<usize>,
pub(crate) query: Symbol, pub(crate) query: Symbol,
pub(crate) bind_similarity: Option<Symbol>,
pub(crate) min_similarity: f64,
// pub(crate) lax_mode: bool,
pub(crate) filter: Option<Expr>, pub(crate) filter: Option<Expr>,
pub(crate) span: SourceSpan, pub(crate) span: SourceSpan,
} }
impl LshSearch { impl LshSearch {
pub(crate) fn all_bindings(&self) -> impl Iterator<Item = &Symbol> { pub(crate) fn all_bindings(&self) -> impl Iterator<Item = &Symbol> {
self.bindings.iter().chain(self.bind_similarity.iter()) self.bindings.iter()
} }
} }

@ -742,16 +742,14 @@ impl<'a> SessionTx<'a> {
default_gen: None, default_gen: None,
}]; }];
let mut idx_keys = vec![ let mut idx_keys = vec![ColumnDef {
ColumnDef { name: SmartString::from("hash"),
name: SmartString::from("hash"), typing: NullableColType {
typing: NullableColType { coltype: ColType::Bytes,
coltype: ColType::Bytes, nullable: false,
nullable: false,
},
default_gen: None,
}, },
]; default_gen: None,
}];
for k in rel_handle.metadata.keys.iter() { for k in rel_handle.metadata.keys.iter() {
idx_keys.push(ColumnDef { idx_keys.push(ColumnDef {
name: format!("src_{}", k.name).into(), name: format!("src_{}", k.name).into(),
@ -770,7 +768,7 @@ impl<'a> SessionTx<'a> {
let inv_idx_handle = self.write_idx_relation( let inv_idx_handle = self.write_idx_relation(
&config.base_relation, &config.base_relation,
&config.index_name, &format!("{}:inv", config.index_name),
inv_idx_keys, inv_idx_keys,
inv_idx_vals, inv_idx_vals,
)?; )?;

@ -948,6 +948,75 @@ fn test_fts_indexing() {
} }
} }
#[test]
fn test_lsh_indexing() {
let db = DbInstance::new("mem", "", "").unwrap();
db.run_script(r":create a {k: String => v: String}", Default::default())
.unwrap();
db.run_script(
r"?[k, v] <- [['a', 'hello world!'], ['b', 'the world is round']] :put a {k => v}",
Default::default(),
)
.unwrap();
db.run_script(
r"::lsh create a:lsh {extractor: v, tokenizer: NGram, n_gram: 3, target_threshold: 0.5 }",
Default::default(),
)
.unwrap();
db.run_script(
r"?[k, v] <- [
['b', 'the world is square!'],
['c', 'see you at the end of the world!'],
['d', 'the world is the world and makes the world go around'],
['e', 'the world is the world and makes the world not go around']
] :put a {k => v}",
Default::default(),
)
.unwrap();
let res = db.run_script("::columns a:lsh", Default::default()).unwrap();
for row in res.into_json()["rows"].as_array().unwrap() {
println!("{}", row);
}
let _res = db
.run_script(
r"
?[hash, src_k] :=
*a:lsh{src_k, hash}
",
Default::default(),
)
.unwrap();
// for row in res.into_json()["rows"].as_array().unwrap() {
// println!("{}", row);
// }
let _res = db
.run_script(
r"
?[k, minhash] :=
*a:lsh:inv{k, minhash}
",
Default::default(),
)
.unwrap();
// for row in res.into_json()["rows"].as_array().unwrap() {
// println!("{}", row);
// }
let res = db
.run_script(
r"
?[k, v] := ~a:lsh{k, v |
query: 'see him at the end of the world',
}
",
Default::default(),
)
.unwrap();
for row in res.into_json()["rows"].as_array().unwrap() {
println!("{}", row);
}
}
#[test] #[test]
fn test_insertions() { fn test_insertions() {
let db = DbInstance::new("mem", "", "").unwrap(); let db = DbInstance::new("mem", "", "").unwrap();

Loading…
Cancel
Save