complete LSH
parent
683b66c0a7
commit
ba98a5c137
@ -0,0 +1,374 @@
|
||||
/*
|
||||
* Copyright 2023, The Cozo Project Authors.
|
||||
*
|
||||
* This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
|
||||
* If a copy of the MPL was not distributed with this file,
|
||||
* You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
*/
|
||||
|
||||
// Some ideas are from https://github.com/schelterlabs/rust-minhash
|
||||
|
||||
use crate::data::expr::{eval_bytecode, Bytecode, eval_bytecode_pred};
|
||||
use crate::data::tuple::Tuple;
|
||||
use crate::fts::tokenizer::TextAnalyzer;
|
||||
use crate::fts::TokenizerConfig;
|
||||
use crate::runtime::relation::RelationHandle;
|
||||
use crate::runtime::transact::SessionTx;
|
||||
use crate::{DataValue, Expr, SourceSpan, Symbol};
|
||||
use miette::{bail, miette, Result};
|
||||
use quadrature::integrate;
|
||||
use rand::{thread_rng, RngCore};
|
||||
use rustc_hash::FxHashMap;
|
||||
use smartstring::{LazyCompact, SmartString};
|
||||
use std::cmp::min;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use twox_hash::XxHash32;
|
||||
|
||||
impl<'a> SessionTx<'a> {
|
||||
pub(crate) fn del_lsh_index_item(
|
||||
&mut self,
|
||||
tuple: &[DataValue],
|
||||
bytes: Option<Vec<u8>>,
|
||||
idx_handle: &RelationHandle,
|
||||
inv_idx_handle: &RelationHandle,
|
||||
manifest: &MinHashLshIndexManifest,
|
||||
) -> Result<()> {
|
||||
let bytes = match bytes {
|
||||
None => {
|
||||
if let Some(mut found) = inv_idx_handle.get_val_only(self, tuple)? {
|
||||
let inv_key = inv_idx_handle.encode_key_for_store(tuple, Default::default())?;
|
||||
self.store_tx.del(&inv_key)?;
|
||||
match found.pop() {
|
||||
Some(DataValue::Bytes(b)) => b,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Some(b) => b,
|
||||
};
|
||||
|
||||
let mut key = Vec::with_capacity(bytes.len() + 2);
|
||||
key.push(DataValue::Bot);
|
||||
key.push(DataValue::Bot);
|
||||
key.extend_from_slice(tuple);
|
||||
for (i, chunk) in bytes
|
||||
.chunks_exact(manifest.r * std::mem::size_of::<u32>())
|
||||
.enumerate()
|
||||
{
|
||||
key[0] = DataValue::from(i as i64);
|
||||
key[1] = DataValue::Bytes(chunk.to_vec());
|
||||
let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?;
|
||||
self.store_tx.del(&key_bytes)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
pub(crate) fn put_lsh_index_item(
|
||||
&mut self,
|
||||
tuple: &[DataValue],
|
||||
extractor: &[Bytecode],
|
||||
stack: &mut Vec<DataValue>,
|
||||
tokenizer: &TextAnalyzer,
|
||||
rel_handle: &RelationHandle,
|
||||
idx_handle: &RelationHandle,
|
||||
inv_idx_handle: &RelationHandle,
|
||||
manifest: &MinHashLshIndexManifest,
|
||||
hash_perms: &HashPermutations,
|
||||
) -> Result<()> {
|
||||
if let Some(mut found) =
|
||||
inv_idx_handle.get_val_only(self, &tuple[..rel_handle.metadata.keys.len()])?
|
||||
{
|
||||
let bytes = match found.pop() {
|
||||
Some(DataValue::Bytes(b)) => b,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.del_lsh_index_item(tuple, Some(bytes), idx_handle, inv_idx_handle, manifest)?;
|
||||
}
|
||||
let to_index = eval_bytecode(extractor, tuple, stack)?;
|
||||
let min_hash = match to_index {
|
||||
DataValue::Null => return Ok(()),
|
||||
DataValue::List(l) => HashValues::new(l.iter(), hash_perms),
|
||||
DataValue::Str(s) => {
|
||||
let n_grams = tokenizer.unique_ngrams(&s, manifest.n_gram);
|
||||
HashValues::new(n_grams.iter(), hash_perms)
|
||||
}
|
||||
_ => bail!("Cannot put value {:?} into a LSH index", to_index),
|
||||
};
|
||||
let bytes = min_hash.get_bytes();
|
||||
let inv_key_part = &tuple[..rel_handle.metadata.keys.len()];
|
||||
let inv_val_part = vec![DataValue::Bytes(bytes.to_vec())];
|
||||
let inv_key = inv_idx_handle.encode_key_for_store(inv_key_part, Default::default())?;
|
||||
let inv_val =
|
||||
inv_idx_handle.encode_val_only_for_store(&inv_val_part, Default::default())?;
|
||||
self.store_tx.put(&inv_key, &inv_val)?;
|
||||
|
||||
let mut key = Vec::with_capacity(bytes.len() + 2);
|
||||
key.push(DataValue::Bot);
|
||||
key.push(DataValue::Bot);
|
||||
key.extend_from_slice(inv_key_part);
|
||||
let chunk_size = manifest.r * std::mem::size_of::<u32>();
|
||||
for i in 0..manifest.b {
|
||||
let byte_range = &bytes[i * chunk_size..(i + 1) * chunk_size];
|
||||
key[0] = DataValue::from(i as i64);
|
||||
key[1] = DataValue::Bytes(byte_range.to_vec());
|
||||
let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?;
|
||||
self.store_tx.put(&key_bytes, &[])?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
pub(crate) fn lsh_search(
|
||||
&self,
|
||||
tuple: &[DataValue],
|
||||
config: &LshSearch,
|
||||
stack: &mut Vec<DataValue>,
|
||||
filter_code: &Option<(Vec<Bytecode>, SourceSpan)>,
|
||||
) -> Result<Vec<Tuple>> {
|
||||
let bytes = if let Some(mut found) = config
|
||||
.inv_idx_handle
|
||||
.get_val_only(self, &tuple[..config.base_handle.metadata.keys.len()])?
|
||||
{
|
||||
match found.pop() {
|
||||
Some(DataValue::Bytes(b)) => b,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
return Ok(vec![]);
|
||||
};
|
||||
let chunk_size = config.manifest.r * std::mem::size_of::<u32>();
|
||||
let mut key_prefix = Vec::with_capacity(2);
|
||||
let mut found_tuples: FxHashMap<_, usize> = FxHashMap::default();
|
||||
for (i, chunk) in bytes.chunks_exact(chunk_size).enumerate() {
|
||||
key_prefix.clear();
|
||||
key_prefix.push(DataValue::from(i as i64));
|
||||
key_prefix.push(DataValue::Bytes(chunk.to_vec()));
|
||||
for ks in config.idx_handle.scan_prefix(self, &key_prefix) {
|
||||
let ks = ks?;
|
||||
let key_part = &ks[2..];
|
||||
if key_part != tuple {
|
||||
let found = found_tuples.entry(key_part.to_vec()).or_default();
|
||||
*found += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut ret = vec![];
|
||||
for (key, count) in found_tuples {
|
||||
let similarity = count as f64 / config.manifest.r as f64;
|
||||
if similarity < config.min_similarity {
|
||||
continue;
|
||||
}
|
||||
let mut orig_tuple = config
|
||||
.base_handle
|
||||
.get(self, &key)?
|
||||
.ok_or_else(|| miette!("Tuple not found in base LSH relation"))?;
|
||||
if let Some((filter_code, span)) = filter_code {
|
||||
if !eval_bytecode_pred(filter_code, &orig_tuple, stack, *span)? {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if config.bind_similarity.is_some() {
|
||||
orig_tuple.push(DataValue::from(similarity));
|
||||
}
|
||||
ret.push(orig_tuple);
|
||||
if let Some(k) = config.k {
|
||||
if ret.len() >= k {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct LshSearch {
|
||||
pub(crate) base_handle: RelationHandle,
|
||||
pub(crate) idx_handle: RelationHandle,
|
||||
pub(crate) inv_idx_handle: RelationHandle,
|
||||
pub(crate) manifest: MinHashLshIndexManifest,
|
||||
pub(crate) bindings: Vec<Symbol>,
|
||||
pub(crate) k: Option<usize>,
|
||||
pub(crate) query: Symbol,
|
||||
pub(crate) bind_similarity: Option<Symbol>,
|
||||
pub(crate) min_similarity: f64,
|
||||
// pub(crate) lax_mode: bool,
|
||||
pub(crate) filter: Option<Expr>,
|
||||
pub(crate) span: SourceSpan,
|
||||
}
|
||||
|
||||
impl LshSearch {
|
||||
pub(crate) fn all_bindings(&self) -> impl Iterator<Item = &Symbol> {
|
||||
self.bindings.iter().chain(self.bind_similarity.iter())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct HashValues(pub(crate) Vec<u32>);
|
||||
pub(crate) struct HashPermutations(pub(crate) Vec<u32>);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)]
|
||||
pub(crate) struct MinHashLshIndexManifest {
|
||||
pub(crate) base_relation: SmartString<LazyCompact>,
|
||||
pub(crate) index_name: SmartString<LazyCompact>,
|
||||
pub(crate) extractor: String,
|
||||
pub(crate) n_gram: usize,
|
||||
pub(crate) tokenizer: TokenizerConfig,
|
||||
pub(crate) filters: Vec<TokenizerConfig>,
|
||||
|
||||
pub(crate) num_perm: usize,
|
||||
pub(crate) b: usize,
|
||||
pub(crate) r: usize,
|
||||
pub(crate) threshold: f64,
|
||||
pub(crate) perms: Vec<u8>,
|
||||
}
|
||||
|
||||
impl MinHashLshIndexManifest {
|
||||
pub(crate) fn get_hash_perms(&self) -> HashPermutations {
|
||||
HashPermutations::from_bytes(&self.perms)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct LshParams {
|
||||
pub b: usize,
|
||||
pub r: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct Weights(pub(crate) f64, pub(crate) f64);
|
||||
|
||||
const _ALLOWED_INTEGRATE_ERR: f64 = 0.001;
|
||||
|
||||
// code is mostly from https://github.com/schelterlabs/rust-minhash/blob/81ea3fec24fd888a330a71b6932623643346b591/src/minhash_lsh.rs
|
||||
impl LshParams {
|
||||
pub fn find_optimal_params(threshold: f64, num_perm: usize, weights: &Weights) -> LshParams {
|
||||
let Weights(false_positive_weight, false_negative_weight) = weights;
|
||||
let mut min_error = f64::INFINITY;
|
||||
let mut opt = LshParams { b: 0, r: 0 };
|
||||
for b in 1..num_perm + 1 {
|
||||
let max_r = num_perm / b;
|
||||
for r in 1..max_r + 1 {
|
||||
let false_pos = LshParams::false_positive_probability(threshold, b, r);
|
||||
let false_neg = LshParams::false_negative_probability(threshold, b, r);
|
||||
let error = false_pos * false_positive_weight + false_neg * false_negative_weight;
|
||||
if error < min_error {
|
||||
min_error = error;
|
||||
opt = LshParams { b, r };
|
||||
}
|
||||
}
|
||||
}
|
||||
opt
|
||||
}
|
||||
|
||||
fn false_positive_probability(threshold: f64, b: usize, r: usize) -> f64 {
|
||||
let _probability =
|
||||
|s| -> f64 { 1. - f64::powf(1. - f64::powi(s, r as i32), b as f64) };
|
||||
integrate(_probability, 0.0, threshold, _ALLOWED_INTEGRATE_ERR).integral
|
||||
}
|
||||
|
||||
fn false_negative_probability(threshold: f64, b: usize, r: usize) -> f64 {
|
||||
let _probability =
|
||||
|s| -> f64 { 1. - (1. - f64::powf(1. - f64::powi(s, r as i32), b as f64)) };
|
||||
integrate(_probability, threshold, 1.0, _ALLOWED_INTEGRATE_ERR).integral
|
||||
}
|
||||
}
|
||||
|
||||
impl HashPermutations {
|
||||
pub(crate) fn new(n_perms: usize) -> Self {
|
||||
let mut rng = thread_rng();
|
||||
let mut perms = Vec::with_capacity(n_perms);
|
||||
for _ in 0..n_perms {
|
||||
perms.push(rng.next_u32());
|
||||
}
|
||||
Self(perms)
|
||||
}
|
||||
pub(crate) fn as_bytes(&self) -> &[u8] {
|
||||
unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
self.0.as_ptr() as *const u8,
|
||||
self.0.len() * std::mem::size_of::<u32>(),
|
||||
)
|
||||
}
|
||||
}
|
||||
// this is the inverse of `as_bytes`
|
||||
pub(crate) fn from_bytes(bytes: &[u8]) -> Self {
|
||||
unsafe {
|
||||
let ptr = bytes.as_ptr() as *const u32;
|
||||
let len = bytes.len() / std::mem::size_of::<u32>();
|
||||
let perms = std::slice::from_raw_parts(ptr, len);
|
||||
Self(perms.to_vec())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HashValues {
|
||||
pub(crate) fn new<T: Hash>(values: impl Iterator<Item = T>, perms: &HashPermutations) -> Self {
|
||||
let mut ret = Self::init(perms);
|
||||
ret.update(values, perms);
|
||||
ret
|
||||
}
|
||||
pub(crate) fn init(perms: &HashPermutations) -> Self {
|
||||
Self(vec![u32::MAX; perms.0.len()])
|
||||
}
|
||||
pub(crate) fn update<T: Hash>(
|
||||
&mut self,
|
||||
values: impl Iterator<Item = T>,
|
||||
perms: &HashPermutations,
|
||||
) {
|
||||
for v in values {
|
||||
for (i, seed) in perms.0.iter().enumerate() {
|
||||
let mut hasher = XxHash32::with_seed(*seed);
|
||||
v.hash(&mut hasher);
|
||||
let hash = hasher.finish() as u32;
|
||||
self.0[i] = min(self.0[i], hash);
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
pub(crate) fn jaccard(&self, other_minhash: &Self) -> f32 {
|
||||
let matches = self
|
||||
.0
|
||||
.iter()
|
||||
.zip_eq(&other_minhash.0)
|
||||
.filter(|(left, right)| left == right)
|
||||
.count();
|
||||
let result = matches as f32 / self.0.len() as f32;
|
||||
result
|
||||
}
|
||||
pub(crate) fn get_bytes(&self) -> &[u8] {
|
||||
unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
self.0.as_ptr() as *const u8,
|
||||
self.0.len() * std::mem::size_of::<u32>(),
|
||||
)
|
||||
}
|
||||
}
|
||||
// pub(crate) fn get_byte_chunks(&self, n_chunks: usize) -> impl Iterator<Item = &[u8]> {
|
||||
// let chunk_size = self.0.len() * std::mem::size_of::<u32>() / n_chunks;
|
||||
// self.get_bytes().chunks_exact(chunk_size)
|
||||
// }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_minhash() {
|
||||
let perms = HashPermutations::new(20000);
|
||||
let mut m1 = HashValues::new([1, 2, 3, 4, 5, 6].iter(), &perms);
|
||||
let mut m2 = HashValues::new([4, 3, 2, 1, 5, 6].iter(), &perms);
|
||||
assert_eq!(m1.0, m2.0);
|
||||
// println!("{:?}", &m1.0);
|
||||
// println!("{:?}", &m2.0);
|
||||
assert_eq!(m1.jaccard(&m2), 1.0);
|
||||
m1.update([7, 8, 9].iter(), &perms);
|
||||
assert!(m1.jaccard(&m2) < 1.0);
|
||||
println!("{:?}", m1.jaccard(&m2));
|
||||
m2.update([17, 18, 19].iter(), &perms);
|
||||
assert!(m1.jaccard(&m2) < 1.0);
|
||||
println!("{:?}", m1.jaccard(&m2));
|
||||
// println!("{:?}", m2.get_byte_chunks(2).collect_vec());
|
||||
assert_eq!(perms.0, HashPermutations::from_bytes(perms.as_bytes()).0);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue