HNSW compilation redux

main
Ziyang Hu 1 year ago
parent 2141731cab
commit 84163915a1

@ -276,6 +276,13 @@ impl Expr {
span,
}
}
pub(crate) fn build_and(exprs: Vec<Expr>, span: SourceSpan) -> Self {
Expr::Apply {
op: &OP_AND,
args: exprs.into(),
span,
}
}
pub(crate) fn build_is_in(exprs: Vec<Expr>, span: SourceSpan) -> Self {
Expr::Apply {
op: &OP_IS_IN,

@ -210,42 +210,6 @@ impl<'a> SessionTx<'a> {
debug_assert_eq!(prev_joiner_vars.len(), right_joiner_vars.len());
ret = ret.join(right, prev_joiner_vars, right_joiner_vars, rule_app.span);
}
MagicAtom::HnswSearch(s) => {
// already existing vars
let mut prev_joiner_vars = vec![];
// vars introduced by right and joined
let mut right_joiner_vars = vec![];
// used to split in case we need to join again
let mut right_joiner_vars_pos = vec![];
// vars introduced by right, regardless of joining
let mut right_vars = vec![];
// used for choosing indices
let mut join_indices = vec![];
for (i, var) in s.all_bindings().enumerate() {
if seen_variables.contains(var) {
prev_joiner_vars.push(var.clone());
let rk = gen_symb(var.span);
right_vars.push(rk.clone());
right_joiner_vars.push(rk);
right_joiner_vars_pos.push(i);
join_indices.push(IndexPositionUse::Join)
} else {
seen_variables.insert(var.clone());
right_vars.push(var.clone());
if var.is_generated_ignored_symbol() {
join_indices.push(IndexPositionUse::Ignored)
} else {
join_indices.push(IndexPositionUse::BindForLater)
}
}
}
// scan original relation
let right = RelAlgebra::hnsw_search(right_vars, s.clone())?;
debug_assert_eq!(prev_joiner_vars.len(), right_joiner_vars.len());
ret = ret.join(right, prev_joiner_vars, right_joiner_vars, s.span);
}
MagicAtom::Relation(rel_app) => {
let store = self.get_relation(&rel_app.name, false)?;
if store.access_level < AccessLevel::ReadOnly {
@ -507,6 +471,40 @@ impl<'a> SessionTx<'a> {
MagicAtom::Predicate(p) => {
ret = ret.filter(p.clone());
}
MagicAtom::HnswSearch(s) => {
    // The query variable is expected to be bound before this atom is
    // compiled; this is only checked in debug builds.
    debug_assert!(
        seen_variables.contains(&s.query),
        "HNSW search query must be bound"
    );
    // Symbols handed to the search node, one per entry of `all_bindings()`
    // and in the same order.
    let mut own_bindings = vec![];
    // Equalities between already-bound variables and the fresh symbols
    // that stand in for them in the search output.
    let mut post_filters = vec![];
    for var in s.all_bindings() {
        if seen_variables.contains(var) {
            // The variable is already bound: give the search a fresh symbol
            // for this position and equate the two afterwards.
            let rk = gen_symb(var.span);
            post_filters.push(Expr::build_equate(
                vec![
                    Expr::Binding {
                        var: var.clone(),
                        tuple_pos: None,
                    },
                    Expr::Binding {
                        var: rk.clone(),
                        tuple_pos: None,
                    },
                ],
                var.span,
            ));
            own_bindings.push(rk);
        } else {
            // First occurrence: the search introduces this variable.
            seen_variables.insert(var.clone());
            own_bindings.push(var.clone());
        }
    }
    ret = ret.hnsw_search(s.clone(), own_bindings)?;
    // Only add a filter node when at least one equality was generated.
    if !post_filters.is_empty() {
        ret = ret.filter(Expr::build_and(post_filters, s.span));
    }
}
MagicAtom::Unification(u) => {
if seen_variables.contains(&u.binding) {
let expr = if u.one_many_unif {

@ -349,8 +349,8 @@ impl RelAlgebra {
RelAlgebra::Stored(v) => {
v.fill_binding_indices_and_compile()?;
}
RelAlgebra::HnswSearch(_) => {
todo!("fill_binding_indices_and_compile for HnswSearch")
RelAlgebra::HnswSearch(s) => {
s.fill_binding_indices_and_compile()?;
}
RelAlgebra::StoredWithValidity(v) => {
v.fill_binding_indices_and_compile()?;
@ -402,12 +402,6 @@ impl RelAlgebra {
span,
})
}
pub(crate) fn hnsw_search(bindings: Vec<Symbol>, search: HnswSearch) -> Result<Self> {
Ok(Self::HnswSearch(HnswSearchRA {
bindings,
hnsw_search: search,
}))
}
pub(crate) fn relation(
bindings: Vec<Symbol>,
storage: RelationHandle,
@ -453,7 +447,8 @@ impl RelAlgebra {
s @ (RelAlgebra::Fixed(_)
| RelAlgebra::Reorder(_)
| RelAlgebra::NegJoin(_)
| RelAlgebra::Unification(_)) => {
| RelAlgebra::Unification(_)
| RelAlgebra::HnswSearch(_)) => {
let span = filter.span();
RelAlgebra::Filter(FilteredRA {
parent: Box::new(s),
@ -575,9 +570,6 @@ impl RelAlgebra {
}
joined
}
RelAlgebra::HnswSearch(_) => {
todo!("filter on HnswSearch")
}
}
}
pub(crate) fn unify(
@ -597,6 +589,18 @@ impl RelAlgebra {
span,
})
}
/// Consumes `self` and returns a `RelAlgebra::HnswSearch` node that has it
/// as parent.
///
/// `hnsw_search` and `own_bindings` are stored on the new node unchanged;
/// `filter_bytecode` starts out as `None`. The visible body always returns
/// `Ok`.
pub(crate) fn hnsw_search(
    self,
    hnsw_search: HnswSearch,
    own_bindings: Vec<Symbol>,
) -> Result<Self> {
    let parent = Box::new(self);
    let node = HnswSearchRA {
        parent,
        hnsw_search,
        filter_bytecode: None,
        own_bindings,
    };
    Ok(Self::HnswSearch(node))
}
pub(crate) fn join(
self,
right: RelAlgebra,
@ -842,8 +846,28 @@ pub(crate) struct StoredRA {
/// Relational-algebra node for an HNSW search.
#[derive(Debug)]
pub(crate) struct HnswSearchRA {
    // NOTE(review): the parent-based `RelAlgebra::hnsw_search` constructor
    // does not initialise this field; confirm it is still needed alongside
    // `own_bindings`.
    pub(crate) bindings: Vec<Symbol>,
    /// The relation this search is applied on top of.
    pub(crate) parent: Box<RelAlgebra>,
    /// The search specification (query variable, optional filter, ...).
    pub(crate) hnsw_search: HnswSearch,
    /// Compiled form of `hnsw_search.filter` together with the filter's span;
    /// `None` until `fill_binding_indices_and_compile` has run with a filter
    /// present.
    pub(crate) filter_bytecode: Option<(Vec<Bytecode>, SourceSpan)>,
    /// Symbols contributed by the search itself; the filter's binding indices
    /// are positions within this vector.
    pub(crate) own_bindings: Vec<Symbol>,
}
impl HnswSearchRA {
    /// Resolves binding indices in the optional search filter and compiles it.
    ///
    /// When `hnsw_search.filter` is present, every symbol in `own_bindings`
    /// is mapped to its position in that vector, the filter's binding
    /// indices are filled from that map, and the compiled bytecode is stored
    /// in `filter_bytecode` together with the filter's span. Without a
    /// filter nothing is changed.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the filter's `fill_binding_indices`.
    fn fill_binding_indices_and_compile(&mut self) -> Result<()> {
        // Borrow the filter once instead of testing and unwrapping it; the
        // other fields touched below are disjoint from it.
        if let Some(filter) = self.hnsw_search.filter.as_mut() {
            // symbol -> position within `own_bindings`
            let bindings: BTreeMap<_, _> = self
                .own_bindings
                .iter()
                .cloned()
                .enumerate()
                .map(|(idx, symbol)| (symbol, idx))
                .collect();
            filter.fill_binding_indices(&bindings)?;
            self.filter_bytecode = Some((filter.compile(), filter.span()));
        }
        Ok(())
    }
}
#[derive(Debug)]
@ -1626,7 +1650,11 @@ impl RelAlgebra {
bindings.push(u.binding.clone());
bindings
}
RelAlgebra::HnswSearch(s) => s.bindings.clone(),
RelAlgebra::HnswSearch(s) => {
let mut bindings = s.parent.bindings_after_eliminate();
bindings.extend_from_slice(&s.own_bindings);
bindings
}
}
}
pub(crate) fn iter<'a>(
@ -1645,7 +1673,7 @@ impl RelAlgebra {
RelAlgebra::Filter(r) => r.iter(tx, delta_rule, stores),
RelAlgebra::NegJoin(r) => r.iter(tx, delta_rule, stores),
RelAlgebra::Unification(r) => r.iter(tx, delta_rule, stores),
RelAlgebra::HnswSearch(_) => {
RelAlgebra::HnswSearch(r) => {
todo!()
}
}
@ -1936,7 +1964,10 @@ impl InnerJoin {
self.materialized_join(tx, eliminate_indices, delta_rule, stores)
}
}
RelAlgebra::Join(_) | RelAlgebra::Filter(_) | RelAlgebra::Unification(_) => {
RelAlgebra::Join(_)
| RelAlgebra::Filter(_)
| RelAlgebra::Unification(_)
| RelAlgebra::HnswSearch(_) => {
self.materialized_join(tx, eliminate_indices, delta_rule, stores)
}
RelAlgebra::Reorder(_) => {
@ -1945,9 +1976,6 @@ impl InnerJoin {
RelAlgebra::NegJoin(_) => {
panic!("joining on NegJoin")
}
RelAlgebra::HnswSearch(_) => {
todo!("joining on HnswSearch")
}
}
}
fn materialized_join<'a>(

@ -43,8 +43,8 @@ use crate::parse::sys::SysOp;
use crate::parse::{parse_script, CozoScript, SourceSpan};
use crate::query::compile::{CompiledProgram, CompiledRule, CompiledRuleSet};
use crate::query::ra::{
FilteredRA, InnerJoin, NegJoin, RelAlgebra, ReorderRA, StoredRA, StoredWithValidityRA,
TempStoreRA, UnificationRA,
FilteredRA, HnswSearchRA, InnerJoin, NegJoin, RelAlgebra, ReorderRA, StoredRA,
StoredWithValidityRA, TempStoreRA, UnificationRA,
};
#[allow(unused_imports)]
use crate::runtime::callback::{
@ -1055,9 +1055,18 @@ impl<'s, S: Storage<'s>> Db<S> {
json!(expr.to_string()),
)
}
RelAlgebra::HnswSearch(_) => {
todo!("HnswSearch")
}
RelAlgebra::HnswSearch(HnswSearchRA {
hnsw_search,
..
}) => (
"hnsw_index",
json!(format!(":{}", hnsw_search.query.name)),
json!(hnsw_search.query.name),
json!(hnsw_search.filter
.iter()
.map(|f| f.to_string())
.collect_vec()),
),
};
ret_for_relation.push(json!({
STRATUM: stratum,

@ -822,12 +822,16 @@ fn test_vec_index() {
Default::default(),
)
.unwrap();
db.run_script(
r"
?[v] := ~a:vec{k: 'a', v | query: [1,1,1,1,1,1,1,1]}
let res = db
.run_script(
r"
#::explain {
?[v] := ~a:vec{k: 'a', v | query: q}, q = vec([1,1,1,1,1,1,1,1])
#}
",
Default::default(),
)
.unwrap();
println!("{:#?}", db.export_relations(["a", "a:vec"].iter()));
Default::default(),
)
.unwrap();
println!("{:#?}", res.into_json()["rows"]);
// println!("{:#?}", db.export_relations(["a", "a:vec"].iter()));
}

Loading…
Cancel
Save