diff --git a/README.md b/README.md index 303cfbba..c92f26af 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ ## TODO -* [ ] predicates +* [x] predicates * [ ] negation * [ ] aggregation * [ ] stratum diff --git a/src/data/expr.rs b/src/data/expr.rs index 98a82225..85d2daba 100644 --- a/src/data/expr.rs +++ b/src/data/expr.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeSet; +use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Debug, Formatter}; use anyhow::Result; @@ -19,13 +19,26 @@ pub enum ExprError { #[derive(Debug, Clone)] pub enum Expr { - Binding(Keyword), - BoundIdx(usize), + Binding(Keyword, Option), Const(DataValue), Apply(&'static Op, Box<[Expr]>), } impl Expr { + pub(crate) fn fill_binding_indices(&mut self, binding_map: &BTreeMap) { + match self { + Expr::Binding(k, idx) => { + let found_idx = *binding_map.get(k).unwrap(); + *idx = Some(found_idx) + } + Expr::Const(_) => {} + Expr::Apply(_, args) => { + for arg in args.iter_mut() { + arg.fill_binding_indices(binding_map); + } + } + } + } pub(crate) fn bindings(&self) -> BTreeSet { let mut ret = BTreeSet::new(); self.collect_bindings(&mut ret); @@ -33,10 +46,9 @@ impl Expr { } pub(crate) fn collect_bindings(&self, coll: &mut BTreeSet) { match self { - Expr::Binding(b) => { + Expr::Binding(b, _) => { coll.insert(b.clone()); } - Expr::BoundIdx(_) => {} Expr::Const(_) => {} Expr::Apply(_, args) => { for arg in args.iter() { @@ -47,10 +59,7 @@ impl Expr { } pub(crate) fn eval(&self, bindings: &Tuple) -> Result { match self { - Expr::Binding(_) => { - unreachable!() - } - Expr::BoundIdx(i) => Ok(bindings.0[*i].clone()), + Expr::Binding(_, i) => Ok(bindings.0[i.unwrap()].clone()), Expr::Const(d) => Ok(d.clone()), Expr::Apply(op, args) => { let args: Box<[DataValue]> = args.iter().map(|v| v.eval(bindings)).try_collect()?; @@ -77,7 +86,7 @@ pub struct Op { impl Debug for Op { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_tuple(self.name).field(&self.min_arity).finish() + write!(f, "{}", self.name) } } diff --git a/src/parse/query.rs b/src/parse/query.rs index ee8dba7a..1102a822 100644 --- a/src/parse/query.rs +++ b/src/parse/query.rs @@ -163,7 +163,7 @@ impl SessionTx { JsonValue::String(s) => { let kw = Keyword::from(s as &str); if kw.is_reserved() { - Ok(Expr::Binding(kw)) + Ok(Expr::Binding(kw, None)) } else { Ok(Expr::Const(DataValue::String(s.into()))) } diff --git a/src/query/compile.rs b/src/query/compile.rs index 992ca900..b48fff54 100644 --- a/src/query/compile.rs +++ b/src/query/compile.rs @@ -64,7 +64,7 @@ pub enum QueryCompilationError { #[error("op {0} is not a predicate")] NotAPredicate(&'static str), #[error("unsafe bindings in expression {0:?}: {1:?}")] - UnsafeBindingInPredicate(Expr, BTreeSet) + UnsafeBindingInPredicate(Expr, BTreeSet), } #[derive(Clone, Debug)] @@ -140,7 +140,7 @@ impl Atom { pub(crate) fn into_predicate(self) -> Option { match self { Atom::Predicate(e) => Some(e), - _ => None + _ => None, } } pub(crate) fn collect_bindings(&self, coll: &mut BTreeSet) { @@ -417,9 +417,7 @@ impl SessionTx { debug_assert_eq!(prev_joiner_vars.len(), right_joiner_vars.len()); ret = ret.join(right, prev_joiner_vars, right_joiner_vars); } - Atom::Predicate(_) => { - todo!() - } + Atom::Predicate(p) => ret = ret.filter(p.clone()), Atom::Logical(_) => { todo!() } diff --git a/src/query/eval.rs b/src/query/eval.rs index a376f2bf..9acd07f8 100644 --- a/src/query/eval.rs +++ b/src/query/eval.rs @@ -34,8 +34,9 @@ impl SessionTx { let mut collected = Vec::with_capacity(body.rules.len()); for rule in &body.rules { let header = rule.head.iter().map(|t| &t.name).cloned().collect_vec(); - let relation = + let mut relation = self.compile_rule_body(&rule.body, rule.vld, &stores, &header)?; + relation.fill_predicate_binding_indices(); collected.push((rule.head.clone(), rule.contained_rules(), relation)); } Ok((k.clone(), collected)) diff --git a/src/query/relation.rs b/src/query/relation.rs index 8c8902c0..27f273a1 100644 --- a/src/query/relation.rs +++ b/src/query/relation.rs @@ -6,6 +6,7 @@ use anyhow::Result; use itertools::Itertools; use crate::data::attr::Attribute; +use crate::data::expr::Expr; use crate::data::keyword::Keyword; use crate::data::tuple::{Tuple, TupleIter}; use crate::data::value::DataValue; @@ -20,6 +21,44 @@ pub enum Relation { Derived(StoredDerivedRelation), Join(Box), Reorder(ReorderRelation), + Filter(FilteredRelation), +} + +pub struct FilteredRelation { + parent: Box, + pred: Expr, +} + +impl FilteredRelation { + fn fill_binding_indices(&mut self) { + let parent_bindings: BTreeMap<_, _> = self + .parent + .bindings() + .into_iter() + .enumerate() + .map(|(a, b)| (b, a)) + .collect(); + self.pred.fill_binding_indices(&parent_bindings); + } + fn iter<'a>( + &'a self, + tx: &'a SessionTx, + epoch: Option, + use_delta: &BTreeSet, + ) -> TupleIter { + Box::new( + self.parent + .iter(tx, epoch, use_delta) + .filter_map(|tuple| match tuple { + Ok(t) => match self.pred.eval_pred(&t) { + Ok(true) => Some(Ok(t)), + Ok(false) => None, + Err(e) => Some(Err(e)), + }, + Err(e) => Some(Err(e)), + }), + ) + } } struct BindingFormatter<'a>(&'a [Keyword]); @@ -63,8 +102,7 @@ impl Debug for Relation { if r.left.is_unit() { r.right.fmt(f) } else { - f - .debug_tuple("Join") + f.debug_tuple("Join") .field(&BindingFormatter(&r.bindings())) .field(&r.joiner) .field(&r.left) @@ -77,11 +115,34 @@ impl Debug for Relation { .field(&r.new_order) .field(&r.relation) .finish(), + Relation::Filter(r) => f + .debug_tuple("Filter") + .field(&r.pred) + .field(&r.parent) + .finish(), } } } impl Relation { + pub(crate) fn fill_predicate_binding_indices(&mut self) { + match self { + Relation::Fixed(_) => {} + Relation::Triple(_) => {} + Relation::Derived(_) => {} + Relation::Join(r) => { + r.left.fill_predicate_binding_indices(); + r.right.fill_predicate_binding_indices(); + } + Relation::Reorder(r) => { + r.relation.fill_predicate_binding_indices(); + } + Relation::Filter(f) => { + f.parent.fill_predicate_binding_indices(); + f.fill_binding_indices() + } + } + } pub(crate) fn unit() -> Self { Self::Fixed(InlineFixedRelation::unit()) } @@ -126,6 +187,12 @@ impl Relation { new_order, }) } + pub(crate) fn filter(self, filter: Expr) -> Self { + Relation::Filter(FilteredRelation { + parent: Box::new(self), + pred: filter, + }) + } pub(crate) fn join( self, right: Relation, @@ -786,6 +853,7 @@ impl Relation { Relation::Derived(_r) => Ok(()), Relation::Join(r) => r.do_eliminate_temp_vars(used), Relation::Reorder(r) => r.relation.eliminate_temp_vars(used), + Relation::Filter(_) => Ok(()), } } @@ -796,6 +864,7 @@ impl Relation { Relation::Derived(_) => None, Relation::Join(r) => Some(&r.to_eliminate), Relation::Reorder(_) => None, + Relation::Filter(_) => None, } } @@ -817,6 +886,7 @@ impl Relation { Relation::Derived(d) => d.bindings.clone(), Relation::Join(j) => j.bindings(), Relation::Reorder(r) => r.bindings(), + Relation::Filter(r) => r.parent.bindings(), } } pub fn iter<'a>( @@ -834,6 +904,7 @@ impl Relation { Relation::Derived(r) => r.iter(epoch, use_delta), Relation::Join(j) => j.iter(tx, epoch, use_delta), Relation::Reorder(r) => r.iter(tx, epoch, use_delta), + Relation::Filter(r) => r.iter(tx, epoch, use_delta), } } } @@ -909,7 +980,9 @@ impl InnerJoin { self.materialized_join(tx, eliminate_indices, epoch, use_delta) } } - Relation::Join(_) => self.materialized_join(tx, eliminate_indices, epoch, use_delta), + Relation::Join(_) | Relation::Filter(_) => { + self.materialized_join(tx, eliminate_indices, epoch, use_delta) + } Relation::Reorder(_) => { panic!("joining on reordered") } diff --git a/tests/creation.rs b/tests/creation.rs index 28dbd49c..f3faa614 100644 --- a/tests/creation.rs +++ b/tests/creation.rs @@ -141,10 +141,10 @@ fn creation() { }, { "rule": "?", - "args": [["?n", "?a"], {"rule": "ff", "args": [{"person/id": "alice_amorist"}, "?a"]}, ["?a", "person/first_name", "?n"]] + "args": [["?n", "?a"], {"pred": "Neq", "args": ["?n", "Alice"]}, {"rule": "ff", "args": [{"person/id": "alice_amorist"}, "?a"]}, ["?a", "person/first_name", "?n"]] } ], - "out": {"friend": {"pull": "?a", "spec": ["*"]}} + // "out": {"friend": {"pull": "?a", "spec": ["*"]}} }); let mut tx = db.transact().unwrap(); let ret = tx.run_query(&query).unwrap();