HNSW through magic rewrite

main
Ziyang Hu 1 year ago
parent bf479c3a64
commit 1b7d96f93f

@ -23,7 +23,9 @@ use crate::data::symb::{Symbol, PROG_ENTRY};
use crate::data::value::{DataValue, ValidityTs};
use crate::fixed_rule::{FixedRule, FixedRuleHandle};
use crate::parse::SourceSpan;
use crate::runtime::relation::InputRelationHandle;
use crate::query::logical::Disjunction;
use crate::runtime::hnsw::HnswIndexManifest;
use crate::runtime::relation::{InputRelationHandle, RelationHandle};
use crate::runtime::temp_store::EpochStore;
use crate::runtime::transact::SessionTx;
@ -908,12 +910,12 @@ pub(crate) enum InputAtom {
inner: Unification,
},
HnswSearch {
inner: HnswSearch,
inner: HnswSearchInput,
},
}
#[derive(Clone)]
pub(crate) struct HnswSearch {
pub(crate) struct HnswSearchInput {
pub(crate) relation: Symbol,
pub(crate) index: Symbol,
pub(crate) bindings: BTreeMap<SmartString<LazyCompact>, Expr>,
@ -929,6 +931,236 @@ pub(crate) struct HnswSearch {
pub(crate) span: SourceSpan,
}
/// An HNSW search atom after normalization: names have been resolved to
/// relation handles and every argument has been reduced to a plain symbol.
///
/// Built by `HnswSearchInput::normalize`, which emits the unifications needed
/// to turn non-variable argument expressions into symbols.
#[derive(Clone, Debug)]
pub(crate) struct HnswSearch {
/// Handle of the relation named in the search atom.
pub(crate) base_handle: RelationHandle,
/// Handle stored alongside the manifest under the index name in the base
/// relation's `hnsw_indices`.
pub(crate) idx_handle: RelationHandle,
/// Manifest stored alongside `idx_handle` for the selected index.
pub(crate) manifest: HnswIndexManifest,
/// One symbol per column of the base relation, keys first and then
/// non-keys; columns the query did not mention get an ignored symbol.
pub(crate) bindings: Vec<Symbol>,
/// Copied unchanged from the parsed atom.
// NOTE(review): presumably the number of neighbours to return — confirm.
pub(crate) k: usize,
/// Copied unchanged from the parsed atom.
// NOTE(review): presumably the HNSW search-breadth parameter — confirm.
pub(crate) ef: usize,
/// Symbol carrying the query value; rule-body reordering only schedules
/// the search once this symbol is bound.
pub(crate) query: Symbol,
/// Optional symbol supplied through the atom's `bind_field` argument.
pub(crate) bind_field: Option<Symbol>,
/// Optional symbol supplied through the atom's `bind_field_idx` argument.
pub(crate) bind_field_idx: Option<Symbol>,
/// Optional symbol supplied through the atom's `bind_distance` argument.
pub(crate) bind_distance: Option<Symbol>,
/// Optional symbol supplied through the atom's `bind_vector` argument.
pub(crate) bind_vector: Option<Symbol>,
/// Copied unchanged from the parsed atom.
// NOTE(review): presumably a distance cut-off — confirm against the executor.
pub(crate) radius: Option<f64>,
/// Copied unchanged from the parsed atom.
// NOTE(review): presumably a predicate applied to candidate rows — confirm.
pub(crate) filter: Option<Expr>,
/// Source location of the atom, used when reporting errors about it.
pub(crate) span: SourceSpan,
}
impl HnswSearch {
    /// Iterates over every symbol this atom refers to besides `query`.
    ///
    /// The per-column `bindings` come first, followed by whichever of
    /// `bind_field`, `bind_distance`, `bind_field_idx` and `bind_vector`
    /// are present, in that order.
    pub(crate) fn all_bindings(&self) -> impl Iterator<Item = &Symbol> {
        // Optional slots, listed in the order they must be yielded.
        let optional_slots = [
            &self.bind_field,
            &self.bind_distance,
            &self.bind_field_idx,
            &self.bind_vector,
        ];
        // `&Option<Symbol>` iterates over zero or one `&Symbol`, so
        // flattening skips the slots that are `None`.
        let present_slots = IntoIterator::into_iter(optional_slots).flatten();
        self.bindings.iter().chain(present_slots)
    }
}
impl HnswSearchInput {
pub(crate) fn normalize(
mut self,
gen: &mut TempSymbGen,
tx: &SessionTx<'_>,
) -> Result<Disjunction> {
let base_handle = tx.get_relation(&self.relation, false)?;
let (idx_handle, manifest) = base_handle
.hnsw_indices
.get(&self.index.name)
.ok_or_else(|| {
#[derive(Debug, Error, Diagnostic)]
#[error("hnsw index {name} not found on relation {relation}")]
#[diagnostic(code(eval::hnsw_index_not_found))]
struct HnswIndexNotFound {
relation: String,
name: String,
#[label]
span: SourceSpan,
}
HnswIndexNotFound {
relation: self.relation.name.to_string(),
name: self.index.name.to_string(),
span: self.index.span,
}
})?
.clone();
let mut conj = Vec::with_capacity(self.bindings.len() + 8);
let mut bindings = Vec::with_capacity(self.bindings.len());
let mut seen_variables = BTreeSet::new();
for col in base_handle
.metadata
.keys
.iter()
.chain(base_handle.metadata.non_keys.iter())
{
if let Some(arg) = self.bindings.remove(&col.name) {
match arg {
Expr::Binding { var, .. } => {
if var.is_ignored_symbol() {
bindings.push(gen.next_ignored(var.span));
} else if seen_variables.insert(var.clone()) {
bindings.push(var);
} else {
let span = var.span;
let dup = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: dup.clone(),
expr: Expr::Binding {
var,
tuple_pos: None,
},
one_many_unif: false,
span,
});
conj.push(unif);
bindings.push(dup);
}
}
expr => {
let span = expr.span();
let kw = gen.next(span);
bindings.push(kw.clone());
let unif = NormalFormAtom::Unification(Unification {
binding: kw,
expr,
one_many_unif: false,
span,
});
conj.push(unif)
}
}
} else {
bindings.push(gen.next_ignored(self.span));
}
}
let query = match self.query {
Expr::Binding { var, .. } => var,
expr => {
let span = expr.span();
let kw = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: kw.clone(),
expr,
one_many_unif: false,
span,
});
conj.push(unif);
kw
}
};
let bind_field = match self.bind_field {
None => None,
Some(Expr::Binding { var, .. }) => Some(var),
Some(expr) => {
let span = expr.span();
let kw = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: kw.clone(),
expr,
one_many_unif: false,
span,
});
conj.push(unif);
Some(kw)
}
};
let bind_field_idx = match self.bind_field_idx {
None => None,
Some(Expr::Binding { var, .. }) => Some(var),
Some(expr) => {
let span = expr.span();
let kw = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: kw.clone(),
expr,
one_many_unif: false,
span,
});
conj.push(unif);
Some(kw)
}
};
let bind_distance = match self.bind_distance {
None => None,
Some(Expr::Binding { var, .. }) => Some(var),
Some(expr) => {
let span = expr.span();
let kw = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: kw.clone(),
expr,
one_many_unif: false,
span,
});
conj.push(unif);
Some(kw)
}
};
let bind_vector = match self.bind_vector {
None => None,
Some(Expr::Binding { var, .. }) => Some(var),
Some(expr) => {
let span = expr.span();
let kw = gen.next(span);
let unif = NormalFormAtom::Unification(Unification {
binding: kw.clone(),
expr,
one_many_unif: false,
span,
});
conj.push(unif);
Some(kw)
}
};
conj.push(NormalFormAtom::HnswSearch(HnswSearch {
base_handle,
idx_handle,
manifest,
bindings,
k: self.k,
ef: self.ef,
query,
bind_field,
bind_field_idx,
bind_distance,
bind_vector,
radius: self.radius,
filter: self.filter,
span: self.span,
}));
// ret.push(if is_negated {
// NormalFormAtom::NegatedRelation(NormalFormRelationApplyAtom {
// name: self.name,
// args,
// valid_at: self.valid_at,
// span: self.span,
// })
// } else {
// NormalFormAtom::Relation(NormalFormRelationApplyAtom {
// name: self.name,
// args,
// valid_at: self.valid_at,
// span: self.span,
// })
// });
// Disjunction::conj(ret)
Ok(Disjunction::conj(conj))
}
}
impl Debug for InputAtom {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{self}")
@ -1043,6 +1275,7 @@ pub(crate) enum NormalFormAtom {
NegatedRelation(NormalFormRelationApplyAtom),
Predicate(Expr),
Unification(Unification),
HnswSearch(HnswSearch),
}
#[derive(Debug, Clone)]
@ -1053,6 +1286,7 @@ pub(crate) enum MagicAtom {
NegatedRule(MagicRuleApplyAtom),
NegatedRelation(MagicRelationApplyAtom),
Unification(Unification),
HnswSearch(HnswSearch),
}
#[derive(Clone, Debug)]

@ -23,7 +23,7 @@ use crate::data::aggr::{parse_aggr, Aggregation};
use crate::data::expr::Expr;
use crate::data::functions::{str2vld, MAX_VALIDITY_TS};
use crate::data::program::{
FixedRuleApply, FixedRuleArg, HnswSearch, InputAtom, InputInlineRule, InputInlineRulesOrFixed,
FixedRuleApply, FixedRuleArg, HnswSearchInput, InputAtom, InputInlineRule, InputInlineRulesOrFixed,
InputNamedFieldRelationApplyAtom, InputProgram, InputRelationApplyAtom, InputRuleApplyAtom,
QueryAssertion, QueryOutOptions, RelationOp, SortDir, Unification,
};
@ -642,7 +642,7 @@ fn parse_atom(
.map(|arg| extract_named_apply_arg(arg, param_pool))
.try_collect()?;
let mut opts = HnswSearch {
let mut opts = HnswSearchInput {
relation,
index,
bindings,

@ -210,6 +210,9 @@ impl<'a> SessionTx<'a> {
debug_assert_eq!(prev_joiner_vars.len(), right_joiner_vars.len());
ret = ret.join(right, prev_joiner_vars, right_joiner_vars, rule_app.span);
}
MagicAtom::HnswSearch(s) => {
todo!("HNSW search")
}
MagicAtom::Relation(rel_app) => {
let store = self.get_relation(&rel_app.name, false)?;
if store.access_level < AccessLevel::ReadOnly {

@ -47,7 +47,7 @@ impl Disjunction {
inner: vec![Conjunction(vec![atom])],
}
}
fn conj(atoms: Vec<NormalFormAtom>) -> Self {
pub(crate) fn conj(atoms: Vec<NormalFormAtom>) -> Self {
Disjunction {
inner: vec![Conjunction(atoms)],
}
@ -121,9 +121,11 @@ impl InputAtom {
InputAtom::Unification { inner } => {
bail!(UnsafeNegation(inner.span))
}
InputAtom::HnswSearch { .. } => todo!(),
InputAtom::HnswSearch { inner } => {
bail!(UnsafeNegation(inner.span))
}
},
InputAtom::HnswSearch { .. } => todo!(),
InputAtom::HnswSearch { inner } => InputAtom::HnswSearch { inner },
})
}
@ -227,9 +229,7 @@ impl InputAtom {
InputAtom::Unification { inner: u } => {
Disjunction::singlet(NormalFormAtom::Unification(u))
}
InputAtom::HnswSearch { .. } => {
todo!()
}
InputAtom::HnswSearch { inner } => inner.normalize(gen, tx)?,
})
}
}

@ -172,6 +172,10 @@ fn magic_rewrite_ruleset(
seen_bindings.insert(u.binding.clone());
collected_atoms.push(MagicAtom::Unification(u));
}
MagicAtom::HnswSearch(s) => {
seen_bindings.extend(s.all_bindings().cloned());
collected_atoms.push(MagicAtom::HnswSearch(s));
}
MagicAtom::Rule(r_app) => {
if r_app.name.has_bound_adornment() {
// we are guaranteed to have a magic rule application
@ -515,6 +519,14 @@ impl NormalFormAtom {
}
MagicAtom::Relation(v)
}
NormalFormAtom::HnswSearch(s) => {
for arg in s.all_bindings() {
if !seen_bindings.contains(arg) {
seen_bindings.insert(arg.clone());
}
}
MagicAtom::HnswSearch(s.clone())
}
NormalFormAtom::Predicate(p) => {
// predicate cannot introduce new bindings
MagicAtom::Predicate(p.clone())

@ -36,6 +36,7 @@ impl NormalFormInlineRule {
let mut round_1_collected = vec![];
let mut pending = vec![];
// first round: collect all unifications that are completely bounded
for atom in self.body {
match atom {
NormalFormAtom::Unification(u) => {
@ -58,8 +59,8 @@ impl NormalFormInlineRule {
}
round_1_collected.push(NormalFormAtom::Rule(r))
}
NormalFormAtom::Relation(mut v) => {
for arg in &mut v.args {
NormalFormAtom::Relation(v) => {
for arg in &v.args {
seen_variables.insert(arg.clone());
}
round_1_collected.push(NormalFormAtom::Relation(v))
@ -71,12 +72,21 @@ impl NormalFormInlineRule {
NormalFormAtom::Predicate(p) => {
pending.push(NormalFormAtom::Predicate(p));
}
NormalFormAtom::HnswSearch(s) => {
if seen_variables.contains(&s.query) {
seen_variables.extend(s.all_bindings().cloned());
round_1_collected.push(NormalFormAtom::HnswSearch(s));
} else {
pending.push(NormalFormAtom::HnswSearch(s));
}
}
}
}
let mut collected = vec![];
seen_variables.clear();
let mut last_pending = vec![];
// second round: insert pending where possible
for atom in round_1_collected {
mem::swap(&mut last_pending, &mut pending);
pending.clear();
@ -98,6 +108,10 @@ impl NormalFormInlineRule {
seen_variables.insert(u.binding.clone());
collected.push(NormalFormAtom::Unification(u));
}
NormalFormAtom::HnswSearch(s) => {
seen_variables.extend(s.all_bindings().cloned());
collected.push(NormalFormAtom::HnswSearch(s));
}
}
for atom in last_pending.iter() {
match atom {
@ -116,6 +130,14 @@ impl NormalFormInlineRule {
pending.push(NormalFormAtom::NegatedRelation(v.clone()));
}
}
NormalFormAtom::HnswSearch(s) => {
if seen_variables.contains(&s.query) {
seen_variables.extend(s.all_bindings().cloned());
collected.push(NormalFormAtom::HnswSearch(s.clone()));
} else {
pending.push(NormalFormAtom::HnswSearch(s.clone()));
}
}
NormalFormAtom::Predicate(p) => {
if p.bindings().is_subset(&seen_variables) {
collected.push(NormalFormAtom::Predicate(p.clone()));
@ -158,6 +180,9 @@ impl NormalFormInlineRule {
NormalFormAtom::Unification(u) => {
bail!(UnboundVariable(u.span))
}
NormalFormAtom::HnswSearch(s) => {
bail!(UnboundVariable(s.span))
}
}
}
}

@ -29,7 +29,8 @@ impl NormalFormAtom {
NormalFormAtom::Relation(_)
| NormalFormAtom::NegatedRelation(_)
| NormalFormAtom::Predicate(_)
| NormalFormAtom::Unification(_) => Default::default(),
| NormalFormAtom::Unification(_)
| NormalFormAtom::HnswSearch(_) => Default::default(),
NormalFormAtom::Rule(r) => BTreeMap::from([(&r.name, false)]),
NormalFormAtom::NegatedRule(r) => BTreeMap::from([(&r.name, true)]),
}

Loading…
Cancel
Save