Finished preliminary implementation of HNSW (Hierarchical Navigable Small World) vector index

main
Ziyang Hu 1 year ago
parent a03af029db
commit 776e8e26a5

@ -1209,6 +1209,35 @@ impl Display for InputAtom {
write!(f, ":{name}")?;
f.debug_list().entries(args).finish()?;
}
InputAtom::HnswSearch { inner } => {
write!(f, "~{}:{}{{", inner.relation, inner.index)?;
for (binding, expr) in &inner.bindings {
write!(f, "{binding}: {expr}, ")?;
}
write!(f, "| ")?;
write!(f, " query: {}, ", inner.query)?;
write!(f, " k: {}, ", inner.k)?;
write!(f, " ef: {}, ", inner.ef)?;
if let Some(radius) = &inner.radius {
write!(f, " radius: {}, ", radius)?;
}
if let Some(filter) = &inner.filter {
write!(f, " filter: {}, ", filter)?;
}
if let Some(bind_distance) = &inner.bind_distance {
write!(f, " bind_distance: {}, ", bind_distance)?;
}
if let Some(bind_field) = &inner.bind_field {
write!(f, " bind_field: {}, ", bind_field)?;
}
if let Some(bind_field_idx) = &inner.bind_field_idx {
write!(f, " bind_field_idx: {}, ", bind_field_idx)?;
}
if let Some(bind_vector) = &inner.bind_vector {
write!(f, " bind_vector: {}, ", bind_vector)?;
}
write!(f, "}}")?;
}
InputAtom::Predicate { inner } => {
write!(f, "{inner}")?;
}
@ -1248,9 +1277,6 @@ impl Display for InputAtom {
}
write!(f, "{expr}")?;
}
InputAtom::HnswSearch { .. } => {
todo!()
}
}
Ok(())
}
@ -1277,9 +1303,7 @@ impl InputAtom {
InputAtom::Relation { inner, .. } => inner.span,
InputAtom::Predicate { inner, .. } => inner.span(),
InputAtom::Unification { inner, .. } => inner.span,
InputAtom::HnswSearch { .. } => {
todo!()
}
InputAtom::HnswSearch { inner, .. } => inner.span,
}
}
}

@ -42,38 +42,6 @@ pub(crate) struct HnswIndexManifest {
pub(crate) keep_pruned_connections: bool,
}
// #[derive(Clone)]
// pub(crate) struct HnswKnnQueryOptions {
// k: usize,
// ef: usize,
// bind_field: bool,
// bind_field_idx: bool,
// bind_distance: bool,
// bind_vector: bool,
// radius: Option<f64>,
// orig_table_binding: Vec<usize>,
// filter: Option<Vec<Bytecode>>,
// }
// impl HnswKnnQueryOptions {
// fn return_tuple_len(&self) -> usize {
// let mut ret_tuple_len = self.orig_table_binding.len();
// if self.bind_field {
// ret_tuple_len += 1;
// }
// if self.bind_field_idx {
// ret_tuple_len += 1;
// }
// if self.bind_distance {
// ret_tuple_len += 1;
// }
// if self.bind_vector {
// ret_tuple_len += 1;
// }
// ret_tuple_len
// }
// }
impl HnswIndexManifest {
fn get_random_level(&self) -> i64 {
let mut rng = rand::thread_rng();
@ -353,13 +321,13 @@ impl<'a> SessionTx<'a> {
Some(bytes) => bytes,
None => bail!("Indexed vector not found, this signifies a bug in the index implementation"),
};
let target_self_val: Vec<DataValue> =
let mut target_self_val: Vec<DataValue> =
rmp_serde::from_slice(&target_self_val_bytes[ENCODED_KEY_MIN_LEN..])
.unwrap();
let target_degree = target_self_val[0].get_float().unwrap() as usize;
let mut target_degree = target_self_val[0].get_float().unwrap() as usize + 1;
if target_degree > m_max {
// shrink links
self.hnsw_shrink_neighbour(
target_degree = self.hnsw_shrink_neighbour(
&neighbour.0,
neighbour.1,
neighbour.2,
@ -370,6 +338,13 @@ impl<'a> SessionTx<'a> {
orig_table,
)?;
}
// update degree
target_self_val[0] = DataValue::from(target_degree as f64);
self.store_tx.put(
&target_self_key_bytes,
&idx_table
.encode_val_only_for_store(&target_self_val, Default::default())?,
)?;
}
}
} else {
@ -398,7 +373,7 @@ impl<'a> SessionTx<'a> {
manifest: &HnswIndexManifest,
idx_table: &RelationHandle,
orig_table: &RelationHandle,
) -> Result<()> {
) -> Result<usize> {
let orig_key = orig_table.encode_key_for_store(target, Default::default())?;
let orig_val = match self.store_tx.get(&orig_key, false)? {
Some(bytes) => bytes,
@ -434,6 +409,7 @@ impl<'a> SessionTx<'a> {
for (new, _) in &new_candidates {
new_candidate_set.insert(new.clone());
}
let new_degree = new_candidates.len();
for (new, Reverse(OrderedFloat(new_dist))) in new_candidates {
if !old_candidate_set.contains(&new) {
let mut new_key = Vec::with_capacity(orig_table.metadata.keys.len() * 2 + 5);
@ -489,7 +465,7 @@ impl<'a> SessionTx<'a> {
}
}
Ok(())
Ok(new_degree)
}
fn hnsw_compare_vector(
&self,
@ -844,6 +820,26 @@ impl<'a> SessionTx<'a> {
in_key.push(DataValue::from(subidx as i64));
let in_key_bytes = idx_table.encode_key_for_store(&in_key, Default::default())?;
self.store_tx.del(&in_key_bytes)?;
let mut neighbour_self_key = vec![DataValue::from(layer)];
for _ in 0..2 {
neighbour_self_key.extend_from_slice(&neighbour_key);
neighbour_self_key.push(DataValue::from(neighbour_idx as i64));
neighbour_self_key.push(DataValue::from(neighbour_subidx as i64));
}
let neighbour_val_bytes = self
.store_tx
.get(
&idx_table.encode_key_for_store(&neighbour_self_key, Default::default())?,
false,
)?
.unwrap();
let mut neighbour_val: Vec<DataValue> =
rmp_serde::from_slice(&neighbour_val_bytes[ENCODED_KEY_MIN_LEN..]).unwrap();
neighbour_val[0] = DataValue::from(neighbour_val[0].get_float().unwrap() - 1.);
self.store_tx.put(
&idx_table.encode_key_for_store(&neighbour_self_key, Default::default())?,
&idx_table.encode_val_only_for_store(&neighbour_val, Default::default())?,
)?;
}
}
@ -895,7 +891,6 @@ impl<'a> SessionTx<'a> {
filter_bytecode: &Option<(Vec<Bytecode>, SourceSpan)>,
stack: &mut Vec<DataValue>,
) -> Result<Vec<Tuple>> {
println!("hnsw_knn for {:?} on {:#?}", q, config);
if q.len() != config.manifest.vec_dim {
bail!("query vector dimension mismatch");
}
@ -916,7 +911,6 @@ impl<'a> SessionTx<'a> {
)
.next();
if let Some(ep) = ep_res {
println!("found ep");
let ep = ep?;
let bottom_level = ep[0].get_int().unwrap();
let ep_key = ep[1..config.base_handle.metadata.keys.len() + 1].to_vec();
@ -957,7 +951,6 @@ impl<'a> SessionTx<'a> {
&mut found_nn,
)?;
if found_nn.is_empty() {
println!("no candidates found");
return Ok(vec![]);
}
@ -972,7 +965,6 @@ impl<'a> SessionTx<'a> {
while let Some(((cand_tuple, cand_idx, cand_subidx), OrderedFloat(distance))) =
found_nn.pop()
{
println!("found candidate {:?} at distance {}", cand_tuple, distance);
if let Some(r) = config.radius {
if distance > r {
continue;
@ -1027,6 +1019,9 @@ impl<'a> SessionTx<'a> {
ret.push(cand_tuple);
}
ret.reverse();
ret.truncate(config.k);
Ok(ret)
} else {
Ok(vec![])

@ -781,14 +781,15 @@ fn test_vec_index() {
let db = DbInstance::new("mem", "", "").unwrap();
db.run_script(
r"
?[k, v] <- [['a', [1,2,3,4,5,6,7,8]],
['b', [2,3,4,5,6,7,8,9]],
['bb', [2,3,4,5,6,7,8,9]],
['c', [2,3,4,5,6,7,8,19]],
['a', [2,3,4,5,6,7,8,9]],
['b', [1,1,1,1,1,1,1,1]]]
:create a {k: String => v: <F32; 8>}
?[k, v] <- [['a', [1,2]],
['b', [2,3]],
['bb', [2,3]],
['c', [3,4]],
['x', [0,0.1]],
['a', [112,0]],
['b', [1,1]]]
:create a {k: String => v: <F32; 2>}
",
Default::default(),
)
@ -796,13 +797,15 @@ fn test_vec_index() {
db.run_script(
r"
::hnsw create a:vec {
dim: 8,
dim: 2,
m: 50,
dtype: F32,
fields: [v],
distance: Cosine,
distance: L2,
ef_construction: 20,
filter: k != 'k1'
filter: k != 'k1',
#extend_candidates: true,
#keep_pruned_connections: true,
}",
Default::default(),
)
@ -810,28 +813,39 @@ fn test_vec_index() {
db.run_script(
r"
?[k, v] <- [
['a2', [1,2,3,4,5,6,7,8]],
['b2', [2,3,4,5,6,7,8,9]],
['bb2', [2,3,4,5,6,7,8,9]],
['c2', [2,3,4,5,6,7,8,19]],
['a2', [2,3,4,5,6,7,8,9]],
['b2', [1,1,1,1,1,1,1,1]]
['a2', [1,25]],
['b2', [2,34]],
['bb2', [2,33]],
['c2', [2,32]],
['a2', [2,31]],
['b2', [1,10]]
]
:put a {k => v}
",
Default::default(),
)
.unwrap();
println!("all links");
for (_, nrows) in db.export_relations(["a:vec"].iter()).unwrap() {
let nrows = nrows.rows;
for row in nrows {
println!("{} {} -> {} {}", row[0], row[1], row[4], row[7]);
}
}
let res = db
.run_script(
r"
#::explain {
?[k, dist, v] := ~a:vec{k, v | query: q, k: 10, ef: 20, bind_distance: dist}, q = vec([1,1,1,1,1,1,1,1])
?[dist, k, v] := ~a:vec{k, v | query: q, k: 2, ef: 20, bind_distance: dist}, q = vec([200, 34])
#}
",
Default::default(),
)
.unwrap();
println!("res: {:#?}", res.into_json()["rows"]);
// println!("{:#?}", db.export_relations(["a", "a:vec"].iter()));
println!("results");
for row in res.into_json()["rows"].as_array().unwrap() {
println!("{} {} {}", row[0], row[1], row[2]);
}
}

Loading…
Cancel
Save