From 2b50e2aaf497df3e46d7fedb9d736aa28df499e9 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Mon, 1 Aug 2022 19:35:00 +0800 Subject: [PATCH] well-ordered rules --- src/data/program.rs | 11 ++- src/query/logical.rs | 63 ++++++++++---- src/query/mod.rs | 1 + src/query/reorder.rs | 198 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 256 insertions(+), 17 deletions(-) create mode 100644 src/query/reorder.rs diff --git a/src/data/program.rs b/src/data/program.rs index b4a627a5..131adefe 100644 --- a/src/data/program.rs +++ b/src/data/program.rs @@ -49,7 +49,7 @@ impl InputProgram { body: conj.0, vld: rule.vld, }; - collected_rules.push(normalized_rule); + collected_rules.push(normalized_rule.convert_to_well_ordered_rule()?); } } prog.insert(k, collected_rules); @@ -224,3 +224,12 @@ pub struct Unification { pub(crate) binding: Keyword, pub(crate) expr: Expr, } + +impl Unification { + pub(crate) fn is_const(&self) -> bool { + matches!(self, Expr::Const(_)) + } + pub(crate) fn bindings_in_expr(&self) -> BTreeSet { + self.expr.bindings() + } +} diff --git a/src/query/logical.rs b/src/query/logical.rs index 5b0c1e45..39a6652b 100644 --- a/src/query/logical.rs +++ b/src/query/logical.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeSet; + use anyhow::{bail, Result}; use itertools::Itertools; @@ -103,15 +105,15 @@ impl InputAtom { } result } - InputAtom::AttrTriple(a) => Disjunction::conj(a.normalize(false, gen)), - InputAtom::Rule(r) => Disjunction::conj(r.normalize(false, gen)), + InputAtom::AttrTriple(a) => a.normalize(false, gen), + InputAtom::Rule(r) => r.normalize(false, gen), InputAtom::Predicate(mut p) => { p.partial_eval()?; Disjunction::singlet(NormalFormAtom::Predicate(p)) } InputAtom::Negation(n) => match *n { - InputAtom::Rule(r) => Disjunction::conj(r.normalize(true, gen)), - InputAtom::AttrTriple(r) => Disjunction::conj(r.normalize(true, gen)), + InputAtom::Rule(r) => r.normalize(true, gen), + InputAtom::AttrTriple(r) => r.normalize(true, gen), _ => unreachable!(), }, InputAtom::Unification(u) => Disjunction::singlet(NormalFormAtom::Unification(u)), @@ -120,12 +122,25 @@ impl InputAtom { } impl InputRuleApplyAtom { - fn normalize(mut self, is_negated: bool, gen: &mut TempKwGen) -> Vec { + fn normalize(mut self, is_negated: bool, gen: &mut TempKwGen) -> Disjunction { let mut ret = Vec::with_capacity(self.args.len() + 1); let mut args = Vec::with_capacity(self.args.len()); + let mut seen_variables = BTreeSet::new(); for arg in self.args { match arg { - InputTerm::Var(kw) => args.push(kw), + InputTerm::Var(kw) => { + if seen_variables.insert(kw.clone()) { + args.push(kw); + } else { + let dup = gen.next(); + let unif = NormalFormAtom::Unification(Unification { + binding: dup.clone(), + expr: Expr::Binding(kw, None), + }); + ret.push(unif); + args.push(dup); + } + } InputTerm::Const(val) => { let kw = gen.next(); args.push(kw.clone()); @@ -149,12 +164,12 @@ impl InputRuleApplyAtom { args, }) }); - ret + Disjunction::conj(ret) } } impl InputAttrTripleAtom { - fn normalize(mut self, is_negated: bool, gen: &mut TempKwGen) -> Vec { + fn normalize(mut self, is_negated: bool, gen: &mut TempKwGen) -> Disjunction { let wrap = |atom| { if is_negated { NormalFormAtom::NegatedAttrTriple(atom) @@ -162,7 +177,7 @@ impl InputAttrTripleAtom { NormalFormAtom::AttrTriple(atom) } }; - match (self.entity, self.value) { + Disjunction::conj(match (self.entity, self.value) { (InputTerm::Const(eid), InputTerm::Const(val)) => { let ekw = gen.next(); let vkw = gen.next(); @@ -211,14 +226,30 @@ impl InputAttrTripleAtom { vec![ue, ret] } (InputTerm::Var(ekw), InputTerm::Var(vkw)) => { - let ret = wrap(NormalFormAttrTripleAtom { - attr: self.attr, - entity: ekw, - value: vkw, - }); - vec![ret] + if ekw == vkw { + let dup = gen.next(); + let atom = NormalFormAttrTripleAtom { + attr: self.attr, + entity: ekw, + value: dup.clone(), + }; + vec![ + NormalFormAtom::Unification(Unification { + binding: dup, + expr: Expr::Binding(vkw, None), + }), + wrap(atom), + ] + } else { + let ret = wrap(NormalFormAttrTripleAtom { + attr: self.attr, + entity: ekw, + value: vkw, + }); + vec![ret] + } } - } + }) } } diff --git a/src/query/mod.rs b/src/query/mod.rs index b33e386a..ca624ba8 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -21,6 +21,7 @@ pub(crate) mod magic; pub(crate) mod pull; pub(crate) mod relation; pub(crate) mod stratify; +pub(crate) mod reorder; impl SessionTx { pub fn run_query(&mut self, payload: &JsonValue) -> Result> { diff --git a/src/query/reorder.rs b/src/query/reorder.rs new file mode 100644 index 00000000..27195f46 --- /dev/null +++ b/src/query/reorder.rs @@ -0,0 +1,198 @@ +use std::collections::BTreeSet; +use std::mem; + +use anyhow::{ensure, Result}; + +use crate::data::program::{NormalFormAtom, NormalFormRule}; + +impl NormalFormRule { + pub(crate) fn convert_to_well_ordered_rule(self) -> Result { + let mut seen_variables: BTreeSet<_> = self.head.iter().cloned().collect(); + let mut round_1_collected = vec![]; + let mut pending = vec![]; + + for atom in self.body { + match atom { + a @ NormalFormAtom::Unification(ref u) => { + if u.is_const() { + seen_variables.insert(u.binding.clone()); + round_1_collected.push(a); + } else { + let unif_vars = u.bindings_in_expr(); + if unif_vars.is_subset(&seen_variables) { + seen_variables.insert(u.binding.clone()); + round_1_collected.push(a); + } else { + pending.push(a); + } + } + } + a @ NormalFormAtom::AttrTriple(ref t) => { + seen_variables.insert(t.value.clone()); + seen_variables.insert(t.entity.clone()); + round_1_collected.push(a); + } + a @ NormalFormAtom::Rule(ref r) => { + for arg in &r.args { + seen_variables.insert(arg.clone()); + } + round_1_collected.push(a) + } + a @ (NormalFormAtom::NegatedAttrTriple(_) + | NormalFormAtom::NegatedRule(_) + | NormalFormAtom::Predicate(_)) => { + pending.push(a); + } + } + } + + let mut collected = vec![]; + seen_variables = self.head.iter().cloned().collect(); + let mut last_pending = pending; + let mut pending = vec![]; + for atom in round_1_collected { + mem::swap(&mut last_pending, &mut pending); + pending.clear(); + match atom { + a @ NormalFormAtom::AttrTriple(ref t) => { + seen_variables.insert(t.value.clone()); + seen_variables.insert(t.entity.clone()); + collected.push(a) + } + a @ NormalFormAtom::Rule(ref r) => { + seen_variables.extend(r.args.iter().cloned()); + collected.push(a) + } + a @ (NormalFormAtom::NegatedAttrTriple(_) + | NormalFormAtom::NegatedRule(_) + | NormalFormAtom::Predicate(_)) => { + unreachable!() + } + a @ NormalFormAtom::Unification(ref u) => { + seen_variables.insert(u.binding.clone()); + collected.push(a); + } + } + for atom in last_pending { + match atom { + NormalFormAtom::AttrTriple(_) | NormalFormAtom::Rule(_) => unreachable!(), + a @ NormalFormAtom::NegatedAttrTriple(ref t) => { + if seen_variables.contains(&t.value) && seen_variables.contains(&t.entity) { + collected.push(a); + } else { + pending.push(a); + } + } + a @ NormalFormAtom::NegatedRule(ref r) => { + if r.args.iter().map(|a| seen_variables.contains(a)).all() { + collected.push(a); + } else { + pending.push(a); + } + } + a @ NormalFormAtom::Predicate(ref p) => { + if p.bindings().is_subset(&seen_variables) { + collected.push(a); + } else { + pending.push(a); + } + } + a @ NormalFormAtom::Unification(ref u) => { + if u.bindings_in_expr().is_subset(&seen_variables) { + collected.push(a); + } else { + pending.push(a); + } + } + } + } + } + + ensure!( + pending.is_empty(), + "found unsafe atoms in rule: {:?}", + pending + ); + + Ok(NormalFormRule { + head: self.head, + aggr: self.aggr, + body: collected, + vld: self.vld, + }) + } +} + +// fn reorder_rule_body_for_negations(clauses: Vec) -> Result> { +// let (negations, others): (Vec<_>, _) = clauses.into_iter().partition(|a| a.is_negation()); +// let mut seen_bindings = BTreeSet::new(); +// for a in &others { +// a.collect_bindings(&mut seen_bindings); +// } +// let mut negations_with_meta = negations +// .into_iter() +// .map(|p| { +// let p = p.into_negated().unwrap(); +// let mut bindings = Default::default(); +// p.collect_bindings(&mut bindings); +// let valid_bindings: BTreeSet<_> = +// bindings.intersection(&seen_bindings).cloned().collect(); +// (Some(p), valid_bindings) +// }) +// .collect_vec(); +// let mut ret = vec![]; +// seen_bindings.clear(); +// for a in others { +// a.collect_bindings(&mut seen_bindings); +// ret.push(a); +// for (negated, pred_bindings) in negations_with_meta.iter_mut() { +// if negated.is_none() { +// continue; +// } +// if seen_bindings.is_superset(pred_bindings) { +// let negated = negated.take().unwrap(); +// ret.push(Atom::Negation(Box::new(negated))); +// } +// } +// } +// Ok(ret) +// } +// +// fn reorder_rule_body_for_predicates(clauses: Vec) -> Result> { +// let (predicates, others): (Vec<_>, _) = clauses.into_iter().partition(|a| a.is_predicate()); +// let mut predicates_with_meta = predicates +// .into_iter() +// .map(|p| { +// let p = p.into_predicate().unwrap(); +// let bindings = p.bindings(); +// (Some(p), bindings) +// }) +// .collect_vec(); +// let mut seen_bindings = BTreeSet::new(); +// let mut ret = vec![]; +// for a in others { +// a.collect_bindings(&mut seen_bindings); +// ret.push(a); +// for (pred, pred_bindings) in predicates_with_meta.iter_mut() { +// if pred.is_none() { +// continue; +// } +// if seen_bindings.is_superset(pred_bindings) { +// let pred = pred.take().unwrap(); +// ret.push(Atom::Predicate(pred)); +// } +// } +// } +// for (p, bindings) in predicates_with_meta { +// ensure!( +// p.is_none(), +// "unsafe bindings {:?} found in predicate {:?}", +// bindings +// .difference(&seen_bindings) +// .cloned() +// .collect::>(), +// p.unwrap() +// ); +// } +// Ok(ret) +// }