From 84163915a1b859fcf382bc1c62aac2e808932b66 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Tue, 18 Apr 2023 14:41:21 +0800 Subject: [PATCH] HNSW compilation redux --- cozo-core/src/data/expr.rs | 7 ++++ cozo-core/src/query/compile.rs | 70 +++++++++++++++++----------------- cozo-core/src/query/ra.rs | 66 +++++++++++++++++++++++--------- cozo-core/src/runtime/db.rs | 19 ++++++--- cozo-core/src/runtime/tests.rs | 18 +++++---- 5 files changed, 113 insertions(+), 67 deletions(-) diff --git a/cozo-core/src/data/expr.rs b/cozo-core/src/data/expr.rs index 060c1baf..db84be36 100644 --- a/cozo-core/src/data/expr.rs +++ b/cozo-core/src/data/expr.rs @@ -276,6 +276,13 @@ impl Expr { span, } } + pub(crate) fn build_and(exprs: Vec, span: SourceSpan) -> Self { + Expr::Apply { + op: &OP_AND, + args: exprs.into(), + span, + } + } pub(crate) fn build_is_in(exprs: Vec, span: SourceSpan) -> Self { Expr::Apply { op: &OP_IS_IN, diff --git a/cozo-core/src/query/compile.rs b/cozo-core/src/query/compile.rs index fdd86a69..d0110b67 100644 --- a/cozo-core/src/query/compile.rs +++ b/cozo-core/src/query/compile.rs @@ -210,42 +210,6 @@ impl<'a> SessionTx<'a> { debug_assert_eq!(prev_joiner_vars.len(), right_joiner_vars.len()); ret = ret.join(right, prev_joiner_vars, right_joiner_vars, rule_app.span); } - MagicAtom::HnswSearch(s) => { - // already existing vars - let mut prev_joiner_vars = vec![]; - // vars introduced by right and joined - let mut right_joiner_vars = vec![]; - // used to split in case we need to join again - let mut right_joiner_vars_pos = vec![]; - // vars introduced by right, regardless of joining - let mut right_vars = vec![]; - // used for choosing indices - let mut join_indices = vec![]; - - for (i, var) in s.all_bindings().enumerate() { - if seen_variables.contains(var) { - prev_joiner_vars.push(var.clone()); - let rk = gen_symb(var.span); - right_vars.push(rk.clone()); - right_joiner_vars.push(rk); - right_joiner_vars_pos.push(i); - join_indices.push(IndexPositionUse::Join) - } else { - seen_variables.insert(var.clone()); - right_vars.push(var.clone()); - if var.is_generated_ignored_symbol() { - join_indices.push(IndexPositionUse::Ignored) - } else { - join_indices.push(IndexPositionUse::BindForLater) - } - } - } - - // scan original relation - let right = RelAlgebra::hnsw_search(right_vars, s.clone())?; - debug_assert_eq!(prev_joiner_vars.len(), right_joiner_vars.len()); - ret = ret.join(right, prev_joiner_vars, right_joiner_vars, s.span); - } MagicAtom::Relation(rel_app) => { let store = self.get_relation(&rel_app.name, false)?; if store.access_level < AccessLevel::ReadOnly { @@ -507,6 +471,40 @@ impl<'a> SessionTx<'a> { MagicAtom::Predicate(p) => { ret = ret.filter(p.clone()); } + MagicAtom::HnswSearch(s) => { + debug_assert!( + seen_variables.contains(&s.query), + "HNSW search query must be bound" + ); + let mut own_bindings = vec![]; + let mut post_filters = vec![]; + for var in s.all_bindings() { + if seen_variables.contains(var) { + let rk = gen_symb(var.span); + post_filters.push(Expr::build_equate( + vec![ + Expr::Binding { + var: var.clone(), + tuple_pos: None, + }, + Expr::Binding { + var: rk.clone(), + tuple_pos: None, + }, + ], + var.span, + )); + own_bindings.push(rk); + } else { + seen_variables.insert(var.clone()); + own_bindings.push(var.clone()); + } + } + ret = ret.hnsw_search(s.clone(), own_bindings)?; + if post_filters.len() > 0 { + ret = ret.filter(Expr::build_and(post_filters, s.span)); + } + } MagicAtom::Unification(u) => { if seen_variables.contains(&u.binding) { let expr = if u.one_many_unif { diff --git a/cozo-core/src/query/ra.rs b/cozo-core/src/query/ra.rs index f7bd12bd..68c2af94 100644 --- a/cozo-core/src/query/ra.rs +++ b/cozo-core/src/query/ra.rs @@ -349,8 +349,8 @@ impl RelAlgebra { RelAlgebra::Stored(v) => { v.fill_binding_indices_and_compile()?; } - RelAlgebra::HnswSearch(_) => { - todo!("fill_binding_indices_and_compile for HnswSearch") + RelAlgebra::HnswSearch(s) => { + s.fill_binding_indices_and_compile()?; } RelAlgebra::StoredWithValidity(v) => { v.fill_binding_indices_and_compile()?; @@ -402,12 +402,6 @@ impl RelAlgebra { span, }) } - pub(crate) fn hnsw_search(bindings: Vec, search: HnswSearch) -> Result { - Ok(Self::HnswSearch(HnswSearchRA { - bindings, - hnsw_search: search, - })) - } pub(crate) fn relation( bindings: Vec, storage: RelationHandle, @@ -453,7 +447,8 @@ impl RelAlgebra { s @ (RelAlgebra::Fixed(_) | RelAlgebra::Reorder(_) | RelAlgebra::NegJoin(_) - | RelAlgebra::Unification(_)) => { + | RelAlgebra::Unification(_) + | RelAlgebra::HnswSearch(_)) => { let span = filter.span(); RelAlgebra::Filter(FilteredRA { parent: Box::new(s), @@ -575,9 +570,6 @@ impl RelAlgebra { } joined } - RelAlgebra::HnswSearch(_) => { - todo!("filter on HnswSearch") - } } } pub(crate) fn unify( @@ -597,6 +589,18 @@ impl RelAlgebra { span, }) } + pub(crate) fn hnsw_search( + self, + hnsw_search: HnswSearch, + own_bindings: Vec, + ) -> Result { + Ok(Self::HnswSearch(HnswSearchRA { + parent: Box::new(self), + hnsw_search, + filter_bytecode: None, + own_bindings, + })) + } pub(crate) fn join( self, right: RelAlgebra, @@ -842,8 +846,28 @@ pub(crate) struct StoredRA { #[derive(Debug)] pub(crate) struct HnswSearchRA { - pub(crate) bindings: Vec, + pub(crate) parent: Box, pub(crate) hnsw_search: HnswSearch, + pub(crate) filter_bytecode: Option<(Vec, SourceSpan)>, + pub(crate) own_bindings: Vec, +} + +impl HnswSearchRA { + fn fill_binding_indices_and_compile(&mut self) -> Result<()> { + if self.hnsw_search.filter.is_some() { + let bindings: BTreeMap<_, _> = self + .own_bindings + .iter() + .cloned() + .enumerate() + .map(|(a, b)| (b, a)) + .collect(); + let filter = self.hnsw_search.filter.as_mut().unwrap(); + filter.fill_binding_indices(&bindings)?; + self.filter_bytecode = Some((filter.compile(), filter.span())); + } + Ok(()) + } } #[derive(Debug)] @@ -1626,7 +1650,11 @@ impl RelAlgebra { bindings.push(u.binding.clone()); bindings } - RelAlgebra::HnswSearch(s) => s.bindings.clone(), + RelAlgebra::HnswSearch(s) => { + let mut bindings = s.parent.bindings_after_eliminate(); + bindings.extend_from_slice(&s.own_bindings); + bindings + } } } pub(crate) fn iter<'a>( @@ -1645,7 +1673,7 @@ impl RelAlgebra { RelAlgebra::Filter(r) => r.iter(tx, delta_rule, stores), RelAlgebra::NegJoin(r) => r.iter(tx, delta_rule, stores), RelAlgebra::Unification(r) => r.iter(tx, delta_rule, stores), - RelAlgebra::HnswSearch(_) => { + RelAlgebra::HnswSearch(r) => { todo!() } } @@ -1936,7 +1964,10 @@ impl InnerJoin { self.materialized_join(tx, eliminate_indices, delta_rule, stores) } } - RelAlgebra::Join(_) | RelAlgebra::Filter(_) | RelAlgebra::Unification(_) => { + RelAlgebra::Join(_) + | RelAlgebra::Filter(_) + | RelAlgebra::Unification(_) + | RelAlgebra::HnswSearch(_) => { self.materialized_join(tx, eliminate_indices, delta_rule, stores) } RelAlgebra::Reorder(_) => { @@ -1945,9 +1976,6 @@ impl InnerJoin { RelAlgebra::NegJoin(_) => { panic!("joining on NegJoin") } - RelAlgebra::HnswSearch(_) => { - todo!("joining on HnswSearch") - } } } fn materialized_join<'a>( diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index 69e25034..35f5d5e3 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -43,8 +43,8 @@ use crate::parse::sys::SysOp; use crate::parse::{parse_script, CozoScript, SourceSpan}; use crate::query::compile::{CompiledProgram, CompiledRule, CompiledRuleSet}; use crate::query::ra::{ - FilteredRA, InnerJoin, NegJoin, RelAlgebra, ReorderRA, StoredRA, StoredWithValidityRA, - TempStoreRA, UnificationRA, + FilteredRA, HnswSearchRA, InnerJoin, NegJoin, RelAlgebra, ReorderRA, StoredRA, + StoredWithValidityRA, TempStoreRA, UnificationRA, }; #[allow(unused_imports)] use crate::runtime::callback::{ @@ -1055,9 +1055,18 @@ impl<'s, S: Storage<'s>> Db { json!(expr.to_string()), ) } - RelAlgebra::HnswSearch(_) => { - todo!("HnswSearch") - } + RelAlgebra::HnswSearch(HnswSearchRA { + hnsw_search, + .. + }) => ( + "hnsw_index", + json!(format!(":{}", hnsw_search.query.name)), + json!(hnsw_search.query.name), + json!(hnsw_search.filter + .iter() + .map(|f| f.to_string()) + .collect_vec()), + ), }; ret_for_relation.push(json!({ STRATUM: stratum, diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index 90d65fda..6d5d5f12 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -822,12 +822,16 @@ fn test_vec_index() { Default::default(), ) .unwrap(); - db.run_script( - r" - ?[v] := ~a:vec{k: 'a', v | query: [1,1,1,1,1,1,1,1]} + let res = db + .run_script( + r" + #::explain { + ?[v] := ~a:vec{k: 'a', v | query: q}, q = vec([1,1,1,1,1,1,1,1]) + #} ", - Default::default(), - ) - .unwrap(); - println!("{:#?}", db.export_relations(["a", "a:vec"].iter())); + Default::default(), + ) + .unwrap(); + println!("{:#?}", res.into_json()["rows"]); + // println!("{:#?}", db.export_relations(["a", "a:vec"].iter())); }