make LSH work

main
Ziyang Hu 1 year ago
parent a07d4bf960
commit 34e3b52216

@ -1109,43 +1109,17 @@ impl SearchInput {
let filter = self.parameters.remove("filter");
let bind_similarity = match self.parameters.remove("bind_similarity") {
None => None,
Some(Expr::Binding { var, .. }) => Some(var),
Some(expr) => {
let span = expr.span();
let kw = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: kw.clone(),
expr,
one_many_unif: false,
span,
});
conj.push(unif);
Some(kw)
}
};
let min_similarity = match self.parameters.remove("min_similarity") {
None => manifest.threshold,
Some(expr) => {
let min_similarity = expr.eval_to_const()?;
let min_similarity = min_similarity
.get_float()
.ok_or(ExpectedFloatForMinSimilarity(self.span))?;
#[derive(Debug, Error, Diagnostic)]
#[error("Extra parameters for LSH search: {0:?}")]
#[diagnostic(code(parser::extra_parameters_for_lsh_search))]
struct ExtraParametersForLshSearch(Vec<String>, #[label] SourceSpan);
#[derive(Debug, Error, Diagnostic)]
#[error("Expected float for `min_similarity`")]
#[diagnostic(code(parser::expected_float_for_min_similarity))]
struct ExpectedFloatForMinSimilarity(#[label] SourceSpan);
ensure!(
(0.0..=1.0).contains(&min_similarity),
ExpectedFloatForMinSimilarity(self.span)
);
min_similarity
}
};
if !self.parameters.is_empty() {
bail!(ExtraParametersForLshSearch(
self.parameters.keys().map(|s| s.to_string()).collect(),
self.span
));
}
conj.push(NormalFormAtom::LshSearch(LshSearch {
base_handle,
@ -1153,10 +1127,8 @@ impl SearchInput {
manifest,
bindings,
k,
bind_similarity,
query,
span: self.span,
min_similarity,
filter,
}));
@ -1303,6 +1275,10 @@ impl SearchInput {
}
};
if !self.parameters.is_empty() {
bail!("Unknown parameters for FTS: {:?}", self.parameters.keys());
}
conj.push(NormalFormAtom::FtsSearch(FtsSearch {
base_handle,
idx_handle,
@ -1527,6 +1503,10 @@ impl SearchInput {
}
};
if !self.parameters.is_empty() {
bail!("Unexpected parameters for HNSW: {:?}", self.parameters);
}
conj.push(NormalFormAtom::HnswSearch(HnswSearch {
base_handle,
idx_handle,

@ -82,7 +82,7 @@ impl TokenizerConfig {
let min_gram = self
.args
.get(0)
.ok_or_else(|| miette!("Missing first argument `min_gram`"))?
.unwrap_or(&DataValue::from(1))
.get_int()
.ok_or_else(|| miette!("First argument `min_gram` must be an integer"))?;
let max_gram = self

@ -332,7 +332,7 @@ pub(crate) fn parse_sys(
_ => bail!("Filters must be a list of filters"),
}
}
_ => bail!("Unknown option {} for FTS index", opt_name.as_str()),
_ => bail!("Unknown option {} for LSH index", opt_name.as_str()),
}
}
ensure!(

@ -19,7 +19,7 @@ use itertools::Itertools;
use miette::{bail, miette, Result};
use quadrature::integrate;
use rand::{thread_rng, RngCore};
use rustc_hash::FxHashMap;
use rustc_hash::FxHashSet;
use smartstring::{LazyCompact, SmartString};
use std::cmp::min;
use std::hash::{Hash, Hasher};
@ -105,11 +105,13 @@ impl<'a> SessionTx<'a> {
let bytes = min_hash.get_bytes();
let chunk_size = manifest.r * std::mem::size_of::<u32>();
let chunks = (0..manifest.b).map(|i| {
let mut byte_range = bytes[i * chunk_size..(i + 1) * chunk_size].to_vec();
byte_range.extend_from_slice(&(i as u16).to_le_bytes());
byte_range
}).collect_vec();
let chunks = (0..manifest.b)
.map(|i| {
let mut byte_range = bytes[i * chunk_size..(i + 1) * chunk_size].to_vec();
byte_range.extend_from_slice(&(i as u16).to_le_bytes());
byte_range
})
.collect_vec();
let inv_key_part = &tuple[..rel_handle.metadata.keys.len()];
@ -123,13 +125,14 @@ impl<'a> SessionTx<'a> {
self.store_tx.put(&key_bytes, &[])?;
}
let inv_val_part = vec![DataValue::List(chunks.into_iter().map(DataValue::Bytes).collect_vec())];
let inv_val_part = vec![DataValue::List(
chunks.into_iter().map(DataValue::Bytes).collect_vec(),
)];
let inv_key = inv_idx_handle.encode_key_for_store(inv_key_part, Default::default())?;
let inv_val =
inv_idx_handle.encode_val_only_for_store(&inv_val_part, Default::default())?;
self.store_tx.put(&inv_key, &inv_val)?;
Ok(())
}
pub(crate) fn lsh_search(
@ -153,26 +156,22 @@ impl<'a> SessionTx<'a> {
_ => bail!("Cannot search for value {:?} in a LSH index", q),
};
let chunk_size = config.manifest.r * std::mem::size_of::<u32>();
let mut key_prefix = Vec::with_capacity(2);
let mut found_tuples: FxHashMap<_, usize> = FxHashMap::default();
let mut key_prefix = Vec::with_capacity(1);
let mut found_tuples: FxHashSet<_> = FxHashSet::default();
for (i, chunk) in bytes.chunks_exact(chunk_size).enumerate() {
key_prefix.clear();
key_prefix.push(DataValue::from(i as i64));
key_prefix.push(DataValue::Bytes(chunk.to_vec()));
let mut chunk = chunk.to_vec();
chunk.extend_from_slice(&(i as u16).to_le_bytes());
key_prefix.push(DataValue::Bytes(chunk));
for ks in config.idx_handle.scan_prefix(self, &key_prefix) {
let ks = ks?;
let key_part = &ks[2..];
let found = found_tuples.entry(key_part.to_vec()).or_default();
*found += 1;
let key_part = &ks[1..];
found_tuples.insert(key_part.to_vec());
}
}
let mut ret = vec![];
for (key, count) in found_tuples {
let similarity = count as f64 / config.manifest.r as f64;
if similarity < config.min_similarity {
continue;
}
let mut orig_tuple = config
for key in found_tuples {
let orig_tuple = config
.base_handle
.get(self, &key)?
.ok_or_else(|| miette!("Tuple not found in base LSH relation"))?;
@ -181,9 +180,6 @@ impl<'a> SessionTx<'a> {
continue;
}
}
if config.bind_similarity.is_some() {
orig_tuple.push(DataValue::from(similarity));
}
ret.push(orig_tuple);
if let Some(k) = config.k {
if ret.len() >= k {
@ -203,16 +199,13 @@ pub(crate) struct LshSearch {
pub(crate) bindings: Vec<Symbol>,
pub(crate) k: Option<usize>,
pub(crate) query: Symbol,
pub(crate) bind_similarity: Option<Symbol>,
pub(crate) min_similarity: f64,
// pub(crate) lax_mode: bool,
pub(crate) filter: Option<Expr>,
pub(crate) span: SourceSpan,
}
impl LshSearch {
pub(crate) fn all_bindings(&self) -> impl Iterator<Item = &Symbol> {
self.bindings.iter().chain(self.bind_similarity.iter())
self.bindings.iter()
}
}

@ -742,16 +742,14 @@ impl<'a> SessionTx<'a> {
default_gen: None,
}];
let mut idx_keys = vec![
ColumnDef {
name: SmartString::from("hash"),
typing: NullableColType {
coltype: ColType::Bytes,
nullable: false,
},
default_gen: None,
let mut idx_keys = vec![ColumnDef {
name: SmartString::from("hash"),
typing: NullableColType {
coltype: ColType::Bytes,
nullable: false,
},
];
default_gen: None,
}];
for k in rel_handle.metadata.keys.iter() {
idx_keys.push(ColumnDef {
name: format!("src_{}", k.name).into(),
@ -770,7 +768,7 @@ impl<'a> SessionTx<'a> {
let inv_idx_handle = self.write_idx_relation(
&config.base_relation,
&config.index_name,
&format!("{}:inv", config.index_name),
inv_idx_keys,
inv_idx_vals,
)?;

@ -948,6 +948,75 @@ fn test_fts_indexing() {
}
}
#[test]
fn test_lsh_indexing() {
let db = DbInstance::new("mem", "", "").unwrap();
db.run_script(r":create a {k: String => v: String}", Default::default())
.unwrap();
db.run_script(
r"?[k, v] <- [['a', 'hello world!'], ['b', 'the world is round']] :put a {k => v}",
Default::default(),
)
.unwrap();
db.run_script(
r"::lsh create a:lsh {extractor: v, tokenizer: NGram, n_gram: 3, target_threshold: 0.5 }",
Default::default(),
)
.unwrap();
db.run_script(
r"?[k, v] <- [
['b', 'the world is square!'],
['c', 'see you at the end of the world!'],
['d', 'the world is the world and makes the world go around'],
['e', 'the world is the world and makes the world not go around']
] :put a {k => v}",
Default::default(),
)
.unwrap();
let res = db.run_script("::columns a:lsh", Default::default()).unwrap();
for row in res.into_json()["rows"].as_array().unwrap() {
println!("{}", row);
}
let _res = db
.run_script(
r"
?[hash, src_k] :=
*a:lsh{src_k, hash}
",
Default::default(),
)
.unwrap();
// for row in res.into_json()["rows"].as_array().unwrap() {
// println!("{}", row);
// }
let _res = db
.run_script(
r"
?[k, minhash] :=
*a:lsh:inv{k, minhash}
",
Default::default(),
)
.unwrap();
// for row in res.into_json()["rows"].as_array().unwrap() {
// println!("{}", row);
// }
let res = db
.run_script(
r"
?[k, v] := ~a:lsh{k, v |
query: 'see him at the end of the world',
}
",
Default::default(),
)
.unwrap();
for row in res.into_json()["rows"].as_array().unwrap() {
println!("{}", row);
}
}
#[test]
fn test_insertions() {
let db = DbInstance::new("mem", "", "").unwrap();

Loading…
Cancel
Save