class OpClass:
    """Builder for operator-application nodes of the JSON query language.

    An instance created with ``op_name=None`` is a factory: attribute access
    on it (``factory.Foo``) returns a new instance bound to ``'Foo'``.  A bound
    instance is callable and returns ``{'op': <name>, 'args': [<args>...]}``.

    Raises:
        Exception: on attribute access on an already-bound instance, or when
            an unbound instance is called.
    """

    def __init__(self, op_name):
        self._op_name = op_name

    def __getattr__(self, name):
        # Only reached for attributes that are not found by normal lookup.
        if self._op_name is not None:
            raise Exception("cannot nest op name")
        return self.__class__(name)

    def __call__(self, *args):
        if self._op_name is None:
            raise Exception("you need to set the op name first")
        return {'op': self._op_name, 'args': list(args)}


Gt = OpClass('Gt')
Lt = OpClass('Lt')
Ge = OpClass('Ge')
Le = OpClass('Le')
Eq = OpClass('Eq')
Neq = OpClass('Neq')
Add = OpClass('Add')
Sub = OpClass('Sub')
Mul = OpClass('Mul')
Div = OpClass('Div')
StrCat = OpClass('StrCat')


class AggrClass:
    """Builder for aggregation entries placed in a rule head.

    Mirrors ``OpClass``: an instance created with ``aggr_name=None`` is a
    factory whose attributes are bound aggregations; a bound instance is called
    with the symbol to aggregate and returns
    ``{'aggr': <name>, 'symb': <symbol>}``.

    Raises:
        Exception: on attribute access on an already-bound instance, or when
            an unbound instance is called.
    """

    def __init__(self, aggr_name):
        self._aggr_name = aggr_name

    def __getattr__(self, name):
        # Only reached for attributes that are not found by normal lookup.
        if self._aggr_name is not None:
            raise Exception("cannot nest aggr name")
        return self.__class__(name)

    def __call__(self, symb):
        if self._aggr_name is None:
            # The message used to talk about a "predicate name", a leftover
            # from the class this one was copied from.
            raise Exception("you need to set the aggr name first")
        return {'aggr': self._aggr_name, 'symb': symb}


Count = AggrClass('Count')
Min = AggrClass('Min')
Aggregation { @@ -18,23 +26,74 @@ impl Debug for Aggregation { } macro_rules! define_aggr { - ($name:ident, $init:ident) => { + ($name:ident, $is_meet:expr) => { const $name: Aggregation = Aggregation { name: stringify!($name), - init_state: $init, combine: ::casey::lower!($name), + is_meet: $is_meet, }; }; } -fn init_zero() -> DataValue { - DataValue::Int(0) +define_aggr!(AGGR_COUNT, false); +fn aggr_count(accum: &DataValue, current: &DataValue) -> Result { + match (accum, current) { + (DataValue::Bottom, DataValue::Bottom) => Ok(DataValue::Int(0)), + (DataValue::Bottom, _) => Ok(DataValue::Int(1)), + (DataValue::Int(i), DataValue::Bottom) => Ok(DataValue::Int(*i)), + (DataValue::Int(i), _) => Ok(DataValue::Int(*i + 1)), + _ => unreachable!(), + } } -define_aggr!(AGGR_COUNT, init_zero); -fn aggr_count(existing: &DataValue, _: &DataValue) -> Result { - match existing { - DataValue::Int(i) => Ok(DataValue::Int(*i + 1)), - _ => unreachable!(), +define_aggr!(AGGR_SUM, false); +fn aggr_sum(accum: &DataValue, current: &DataValue) -> Result { + match (accum, current) { + (DataValue::Bottom, DataValue::Bottom) => Ok(DataValue::Int(0)), + (DataValue::Bottom, DataValue::Int(i)) => Ok(DataValue::Int(*i)), + (DataValue::Bottom, DataValue::Float(f)) => Ok(DataValue::Float(f.0.into())), + (DataValue::Int(i), DataValue::Bottom) => Ok(DataValue::Int(*i)), + (DataValue::Float(f), DataValue::Bottom) => Ok(DataValue::Float(f.0.into())), + (DataValue::Int(i), DataValue::Int(j)) => Ok(DataValue::Int(*i + *j)), + (DataValue::Int(j), DataValue::Float(i)) | (DataValue::Float(i), DataValue::Int(j)) => { + Ok(DataValue::Float((i.0 + (*j as f64)).into())) + } + (DataValue::Float(i), DataValue::Float(j)) => Ok(DataValue::Float((i.0 + j.0).into())), + (i, j) => bail!( + "cannot compute min: encountered value {:?} for aggregate {:?}", + j, + i + ), } } + +define_aggr!(AGGR_MIN, false); +fn aggr_min(accum: &DataValue, current: &DataValue) -> Result { + match (accum, current) { + 
(DataValue::Bottom, DataValue::Bottom) => Ok(DataValue::Float(f64::infinity().into())), + (DataValue::Bottom, DataValue::Int(i)) => Ok(DataValue::Int(*i)), + (DataValue::Bottom, DataValue::Float(f)) => Ok(DataValue::Float(f.0.into())), + (DataValue::Int(i), DataValue::Bottom) => Ok(DataValue::Int(*i)), + (DataValue::Float(f), DataValue::Bottom) => Ok(DataValue::Float(f.0.into())), + (DataValue::Int(i), DataValue::Int(j)) => Ok(DataValue::Int(min(*i, *j))), + (DataValue::Int(j), DataValue::Float(i)) | (DataValue::Float(i), DataValue::Int(j)) => { + Ok(DataValue::Float(min(i.clone(), (*j as f64).into()))) + } + (DataValue::Float(i), DataValue::Float(j)) => { + Ok(DataValue::Float(min(i.clone(), j.clone()))) + } + (i, j) => bail!( + "cannot compute min: encountered value {:?} for aggregate {:?}", + j, + i + ), + } +} + +pub(crate) fn get_aggr(name: &str) -> Option<&'static Aggregation> { + Some(match name { + "Count" => &AGGR_COUNT, + "Min" => &AGGR_MIN, + _ => return None, + }) +} diff --git a/src/data/program.rs b/src/data/program.rs index f1037d0d..e06e3851 100644 --- a/src/data/program.rs +++ b/src/data/program.rs @@ -3,8 +3,8 @@ use std::fmt::{Debug, Formatter}; use anyhow::Result; use smallvec::SmallVec; -use crate::data::aggr::Aggregation; +use crate::data::aggr::Aggregation; use crate::data::attr::Attribute; use crate::data::expr::Expr; use crate::data::id::{EntityId, Validity}; @@ -63,9 +63,14 @@ pub(crate) struct NormalFormProgram { #[derive(Debug, Clone)] pub(crate) struct StratifiedMagicProgram(pub(crate) Vec); +#[derive(Debug, Clone, Default)] +pub(crate) struct MagicRuleSet { + pub(crate) rules: Vec, +} + #[derive(Debug, Clone)] pub(crate) struct MagicProgram { - pub(crate) prog: BTreeMap>, + pub(crate) prog: BTreeMap, } #[derive(Clone, Ord, PartialOrd, Eq, PartialEq)] diff --git a/src/data/symb.rs b/src/data/symb.rs index 5af756e6..7b740da6 100644 --- a/src/data/symb.rs +++ b/src/data/symb.rs @@ -51,6 +51,14 @@ impl Symbol { ); Ok(()) } + pub(crate) fn 
validate_query_var(&self) -> Result<()> { + ensure!( + self.is_query_var(), + "query var must start with '?': {}", + self.0 + ); + Ok(()) + } pub(crate) fn is_prog_entry(&self) -> bool { self.0 == "?" } diff --git a/src/data/tuple.rs b/src/data/tuple.rs index ba19fb0d..424613d6 100644 --- a/src/data/tuple.rs +++ b/src/data/tuple.rs @@ -12,6 +12,7 @@ use crate::runtime::temp_store::TempStoreId; pub(crate) const SCRATCH_DB_KEY_PREFIX_LEN: usize = 6; +#[derive(Clone, Ord, PartialOrd, Eq, PartialEq)] pub(crate) struct Tuple(pub(crate) Vec); impl Debug for Tuple { diff --git a/src/parse/query.rs b/src/parse/query.rs index 90bfa35d..41637435 100644 --- a/src/parse/query.rs +++ b/src/parse/query.rs @@ -5,6 +5,7 @@ use anyhow::{anyhow, bail, ensure, Result}; use itertools::Itertools; use serde_json::{json, Map}; +use crate::data::aggr::get_aggr; use crate::data::attr::Attribute; use crate::data::expr::{get_op, Expr}; use crate::data::id::{EntityId, Validity}; @@ -334,10 +335,35 @@ impl SessionTx { let mut rule_aggr = vec![]; for head_item in rule_head_vec { if let Some(s) = head_item.as_str() { - rule_head.push(Symbol::from(s)); + let symbol = Symbol::from(s); + symbol.validate_query_var()?; + rule_head.push(symbol); rule_aggr.push(None); + } else if let Some(m) = head_item.as_object() { + let s = m + .get("symb") + .ok_or_else(|| anyhow!("expect field 'symb' in rule head map"))? + .as_str() + .ok_or_else(|| { + anyhow!("expect field 'symb' in rule head map to be a symbol") + })?; + let symbol = Symbol::from(s); + symbol.validate_query_var()?; + + let aggr = m + .get("aggr") + .ok_or_else(|| anyhow!("expect field 'aggr' in rule head map"))? + .as_str() + .ok_or_else(|| { + anyhow!("expect field 'aggr' in rule head map to be a symbol") + })?; + let aggr = get_aggr(aggr) + .ok_or_else(|| anyhow!("aggregation {} not found", aggr))? 
+ .clone(); + rule_head.push(symbol); + rule_aggr.push(Some(aggr)); } else { - todo!() + bail!("cannot parse {} as rule head", head_item); } } let rule_body: Vec = args diff --git a/src/query/compile.rs b/src/query/compile.rs index 7945524e..6c4245c7 100644 --- a/src/query/compile.rs +++ b/src/query/compile.rs @@ -3,32 +3,83 @@ use std::collections::{BTreeMap, BTreeSet}; use anyhow::{anyhow, ensure, Result}; use itertools::Itertools; +use crate::data::aggr::Aggregation; use crate::data::expr::Expr; +use crate::data::program::{MagicAtom, MagicRule, MagicSymbol, StratifiedMagicProgram}; use crate::data::symb::Symbol; -use crate::data::program::{MagicAtom, MagicSymbol, MagicRule, StratifiedMagicProgram}; use crate::query::relation::Relation; use crate::runtime::temp_store::TempStore; use crate::runtime::transact::SessionTx; -pub(crate) type CompiledProgram = - BTreeMap, BTreeSet, Relation)>>; +pub(crate) type CompiledProgram = BTreeMap; + +#[derive(Debug)] +pub(crate) struct CompiledRuleSet { + pub(crate) rules: Vec, +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub(crate) enum AggrKind { + None, + Normal, + Meet, +} + +impl CompiledRuleSet { + pub(crate) fn aggr_kind(&self) -> AggrKind { + let mut is_aggr = false; + for rule in &self.rules { + for aggr in &rule.aggr { + if aggr.is_some() { + is_aggr = true; + break; + } + } + } + if !is_aggr { + return AggrKind::None; + } + if !self.rules.iter().map(|r| &r.aggr).all_equal() { + return AggrKind::Normal; + } + for aggr in self.rules[0].aggr.iter() { + if let Some(aggr) = aggr { + if !aggr.is_meet { + return AggrKind::Normal; + } + } + } + return AggrKind::Meet; + } +} + +#[derive(Debug)] +pub(crate) struct CompiledRule { + pub(crate) aggr: Vec>, + pub(crate) relation: Relation, + pub(crate) contained_rules: BTreeSet, +} + +impl CompiledRule { + pub(crate) fn is_aggr(&self) -> bool { + self.aggr.iter().any(|a| a.is_some()) + } +} impl SessionTx { pub(crate) fn stratified_magic_compile( &mut self, prog: 
&StratifiedMagicProgram, ) -> Result<(Vec, BTreeMap)> { - let stores = prog - .0 - .iter() - .flat_map(|p| p.prog.iter()) - .map(|(k, s)| { - ( - k.clone(), - (self.new_throwaway(s[0].head.len(), 0, k.clone())), - ) - }) - .collect::>(); + let mut stores: BTreeMap = Default::default(); + for stratum in prog.0.iter() { + for (name, ruleset) in &stratum.prog { + stores.insert( + name.clone(), + self.new_rule_store(name.clone(), ruleset.rules[0].head.len()), + ); + } + } let compiled: Vec<_> = prog .0 @@ -38,27 +89,21 @@ impl SessionTx { cur_prog .prog .iter() - .map( - |(k, body)| -> Result<( - MagicSymbol, - Vec<(Vec, BTreeSet, Relation)>, - )> { - let mut collected = Vec::with_capacity(body.len()); - for (rule_idx, rule) in body.iter().enumerate() { - let header = &rule.head; - let mut relation = self.compile_magic_rule_body( - &rule, k, rule_idx, &stores, &header, - )?; - relation.fill_predicate_binding_indices(); - collected.push(( - rule.head.clone(), - rule.contained_rules(), - relation, - )); - } - Ok((k.clone(), collected)) - }, - ) + .map(|(k, body)| -> Result<(MagicSymbol, CompiledRuleSet)> { + let mut collected = Vec::with_capacity(body.rules.len()); + for (rule_idx, rule) in body.rules.iter().enumerate() { + let header = &rule.head; + let mut relation = + self.compile_magic_rule_body(&rule, k, rule_idx, &stores, &header)?; + relation.fill_predicate_binding_indices(); + collected.push(CompiledRule { + aggr: rule.aggr.clone(), + relation, + contained_rules: rule.contained_rules(), + }) + } + Ok((k.clone(), CompiledRuleSet { rules: collected })) + }) .try_collect() }) .try_collect()?; @@ -117,10 +162,10 @@ impl SessionTx { .ok_or_else(|| anyhow!("undefined rule {:?} encountered", rule_app.name))? 
.clone(); ensure!( - store.key_size == rule_app.args.len(), + store.arity == rule_app.args.len(), "arity mismatch in rule application {:?}, expect {}, found {}", rule_app.name, - store.key_size, + store.arity, rule_app.args.len() ); let mut prev_joiner_vars = vec![]; @@ -188,10 +233,10 @@ impl SessionTx { .ok_or_else(|| anyhow!("undefined rule encountered: {:?}", rule_app.name))? .clone(); ensure!( - store.key_size == rule_app.args.len(), + store.arity == rule_app.args.len(), "arity mismatch for {:?}, expect {}, got {}", rule_app.name, - store.key_size, + store.arity, rule_app.args.len() ); diff --git a/src/query/eval.rs b/src/query/eval.rs index f50b5a8c..ffc6b06f 100644 --- a/src/query/eval.rs +++ b/src/query/eval.rs @@ -4,9 +4,9 @@ use std::mem; use anyhow::{anyhow, Result}; use log::{debug, log_enabled, trace, Level}; -use crate::data::symb::PROG_ENTRY; use crate::data::program::MagicSymbol; -use crate::query::compile::CompiledProgram; +use crate::data::symb::PROG_ENTRY; +use crate::query::compile::{AggrKind, CompiledProgram}; use crate::runtime::temp_store::TempStore; use crate::runtime::transact::SessionTx; @@ -36,8 +36,8 @@ impl SessionTx { ) -> Result<()> { if log_enabled!(Level::Debug) { for (k, vs) in prog.iter() { - for (i, (binding, _, rel)) in vs.iter().enumerate() { - debug!("{:?}.{} {:?}: {:#?}", k, i, binding, rel) + for (i, compiled) in vs.rules.iter().enumerate() { + debug!("{:?}.{} {:?}", k, i, compiled) } } } @@ -48,16 +48,50 @@ impl SessionTx { for epoch in 0u32.. 
{ debug!("epoch {}", epoch); if epoch == 0 { - for (k, rules) in prog.iter() { + for (k, ruleset) in prog.iter() { + let aggr_kind = ruleset.aggr_kind(); let store = stores.get(k).unwrap(); let use_delta = BTreeSet::default(); - for (rule_n, (_head, _deriving_rules, relation)) in rules.iter().enumerate() { - debug!("initial calculation for rule {:?}.{}", k, rule_n); - for item_res in relation.iter(self, Some(0), &use_delta) { - let item = item_res?; - trace!("item for {:?}.{}: {:?} at {}", k, rule_n, item, epoch); - store.put(&item, 0)?; - *changed.get_mut(k).unwrap() = true; + match aggr_kind { + AggrKind::None | AggrKind::Meet => { + let is_meet = aggr_kind == AggrKind::Meet; + for (rule_n, rule) in ruleset.rules.iter().enumerate() { + debug!("initial calculation for rule {:?}.{}", k, rule_n); + for item_res in rule.relation.iter(self, Some(0), &use_delta) { + let item = item_res?; + trace!("item for {:?}.{}: {:?} at {}", k, rule_n, item, epoch); + if is_meet { + store.aggr_meet_put(&item, &rule.aggr, 0)?; + } else { + store.put(&item, 0)?; + } + *changed.get_mut(k).unwrap() = true; + } + } + } + AggrKind::Normal => { + for (rule_n, rule) in ruleset.rules.iter().enumerate() { + debug!("Calculation for normal aggr rule {:?}.{}", k, rule_n); + let rule_is_aggr = rule.is_aggr(); + let store_to_use = if rule_is_aggr { + self.new_temp_store() + } else { + store.clone() + }; + for item_res in rule.relation.iter(self, Some(0), &use_delta) { + let item = item_res?; + trace!("item for {:?}.{}: {:?} at {}", k, rule_n, item, epoch); + if rule_is_aggr { + store_to_use.normal_aggr_put(&item, &rule.aggr)?; + } else { + store_to_use.put(&item, 0)?; + } + *changed.get_mut(k).unwrap() = true; + } + if rule_is_aggr { + store_to_use.normal_aggr_scan_and_put(&rule.aggr, store)?; + } + } } } } @@ -67,11 +101,11 @@ impl SessionTx { *v = false; } - for (k, rules) in prog.iter() { + for (k, ruleset) in prog.iter() { let store = stores.get(k).unwrap(); - for (rule_n, (_head, 
deriving_rules, relation)) in rules.iter().enumerate() { + for (rule_n, rule) in ruleset.rules.iter().enumerate() { let mut should_do_calculation = false; - for d_rule in deriving_rules { + for d_rule in &rule.contained_rules { if let Some(changed) = prev_changed.get(d_rule) { if *changed { should_do_calculation = true; @@ -84,27 +118,47 @@ impl SessionTx { continue; } for (delta_key, delta_store) in stores.iter() { - if !deriving_rules.contains(delta_key) { + if !rule.contained_rules.contains(delta_key) { continue; } + let is_meet_aggr = match ruleset.aggr_kind() { + AggrKind::None => false, + AggrKind::Normal => unreachable!(), + AggrKind::Meet => true, + }; + debug!("with delta {:?} for rule {:?}.{}", delta_key, k, rule_n); let use_delta = BTreeSet::from([delta_store.id]); - for item_res in relation.iter(self, Some(epoch), &use_delta) { + for item_res in rule.relation.iter(self, Some(epoch), &use_delta) { let item = item_res?; // improvement: the clauses can actually be evaluated in parallel - if store.exists(&item, 0)? { - trace!( - "item for {:?}.{}: {:?} at {}, rederived", - k, - rule_n, - item, - epoch - ); + if is_meet_aggr { + let aggr_changed = + store.aggr_meet_put(&item, &rule.aggr, epoch)?; + if aggr_changed { + *changed.get_mut(k).unwrap() = true; + } } else { - trace!("item for {:?}.{}: {:?} at {}", k, rule_n, item, epoch); - *changed.get_mut(k).unwrap() = true; - store.put(&item, epoch)?; - store.put(&item, 0)?; + if store.exists(&item, 0)? 
{ + trace!( + "item for {:?}.{}: {:?} at {}, rederived", + k, + rule_n, + item, + epoch + ); + } else { + trace!( + "item for {:?}.{}: {:?} at {}", + k, + rule_n, + item, + epoch + ); + *changed.get_mut(k).unwrap() = true; + store.put(&item, epoch)?; + store.put(&item, 0)?; + } } } } diff --git a/src/query/magic.rs b/src/query/magic.rs index a37137b7..45fa9e8d 100644 --- a/src/query/magic.rs +++ b/src/query/magic.rs @@ -4,21 +4,37 @@ use std::mem; use itertools::Itertools; use smallvec::SmallVec; -use crate::data::symb::{Symbol, PROG_ENTRY}; use crate::data::program::{ - MagicAtom, MagicAttrTripleAtom, MagicSymbol, MagicProgram, MagicRule, MagicRuleApplyAtom, - NormalFormAtom, NormalFormProgram, NormalFormRule, StratifiedMagicProgram, + MagicAtom, MagicAttrTripleAtom, MagicProgram, MagicRule, MagicRuleApplyAtom, MagicRuleSet, + MagicSymbol, NormalFormAtom, NormalFormProgram, NormalFormRule, StratifiedMagicProgram, StratifiedNormalFormProgram, }; +use crate::data::symb::{Symbol, PROG_ENTRY}; + +impl NormalFormProgram { + pub(crate) fn exempt_aggr_rules(&self, exempt_rules: &mut BTreeSet) { + for (name, rule_set) in self.prog.iter() { + 'outer: for rule in rule_set.iter() { + for aggr in rule.aggr.iter() { + if aggr.is_some() { + exempt_rules.insert(name.clone()); + continue 'outer; + } + } + } + } + } +} impl StratifiedNormalFormProgram { pub(crate) fn magic_sets_rewrite(self) -> StratifiedMagicProgram { - let mut upstream_rules = BTreeSet::from([PROG_ENTRY.clone()]); + let mut exempt_rules = BTreeSet::from([PROG_ENTRY.clone()]); let mut collected = vec![]; for prog in self.0 { - let adorned = prog.adorn(&upstream_rules); + prog.exempt_aggr_rules(&mut exempt_rules); + let adorned = prog.adorn(&exempt_rules); collected.push(adorned.magic_rewrite()); - upstream_rules.extend(prog.get_downstream_rules()); + exempt_rules.extend(prog.get_downstream_rules()); } StratifiedMagicProgram(collected) } @@ -29,7 +45,7 @@ impl MagicProgram { let mut ret_prog = MagicProgram { prog: 
Default::default(), }; - for (rule_head, rules) in self.prog { + for (rule_head, ruleset) in self.prog { // at this point, rule_head must be Muggle or Magic, the remaining options are impossible let rule_name = rule_head.as_plain_symbol(); let adornment = rule_head.magic_adornment(); @@ -37,7 +53,7 @@ impl MagicProgram { // can only be true if rule is magic and args are not all free let rule_has_bound_args = rule_head.has_bound_adornment(); - for (rule_idx, rule) in rules.into_iter().enumerate() { + for (rule_idx, rule) in ruleset.rules.into_iter().enumerate() { let mut sup_idx = 0; let mut make_sup_kw = || { let ret = MagicSymbol::Sup { @@ -75,12 +91,14 @@ impl MagicProgram { ret_prog.prog.insert( sup_kw.clone(), - vec![MagicRule { - head: sup_args.clone(), - aggr: sup_aggr, - body: sup_body, - vld: rule.vld, - }], + MagicRuleSet { + rules: vec![MagicRule { + head: sup_args.clone(), + aggr: sup_aggr, + body: sup_body, + vld: rule.vld, + }], + }, ); seen_bindings.extend(sup_args.iter().cloned()); @@ -117,7 +135,7 @@ impl MagicProgram { mem::swap(&mut sup_rule_atoms, &mut collected_atoms); // add the sup rule to the program, this clears all collected atoms - sup_rule_entry.push(MagicRule { + sup_rule_entry.rules.push(MagicRule { head: args.clone(), aggr: vec![None; args.len()], body: sup_rule_atoms, @@ -152,7 +170,7 @@ impl MagicProgram { ) .collect_vec(); let inp_aggr = vec![None; inp_args.len()]; - inp_entry.push(MagicRule { + inp_entry.rules.push(MagicRule { head: inp_args, aggr: inp_aggr, body: vec![sup_rule_app], @@ -166,7 +184,7 @@ impl MagicProgram { } let entry = ret_prog.prog.entry(rule_head.clone()).or_default(); - entry.push(MagicRule { + entry.rules.push(MagicRule { head: rule.head, aggr: rule.aggr, body: collected_atoms, @@ -229,7 +247,9 @@ impl NormalFormProgram { MagicSymbol::Muggle { inner: rule_name.clone(), }, - adorned_rules, + MagicRuleSet { + rules: adorned_rules, + }, ); } @@ -251,7 +271,12 @@ impl NormalFormProgram { rule.adorn(&mut 
pending_adornment, &rules_to_rewrite, seen_bindings); adorned_rules.push(adorned_rule); } - adorned_prog.prog.insert(head, adorned_rules); + adorned_prog.prog.insert( + head, + MagicRuleSet { + rules: adorned_rules, + }, + ); } adorned_prog } diff --git a/src/query/relation.rs b/src/query/relation.rs index 38f6e656..32bd69f6 100644 --- a/src/query/relation.rs +++ b/src/query/relation.rs @@ -989,7 +989,7 @@ impl TripleRelation { eliminate_indices: BTreeSet, ) -> TupleIter<'a> { // [f, b] where b is not indexed - let throwaway = tx.temp_area(); + let throwaway = tx.new_temp_store(); for item in tx.triple_a_before_scan(self.attr.id, self.vld) { match item { Err(e) => return Box::new([Err(e)].into_iter()), @@ -1528,7 +1528,7 @@ impl InnerJoin { .sorted_by_key(|(_, b)| **b) .map(|(a, _)| a) .collect_vec(); - let throwaway = tx.temp_area(); + let throwaway = tx.new_temp_store(); for item in self.right.iter(tx, epoch, use_delta) { match item { Ok(tuple) => { diff --git a/src/query/stratify.rs b/src/query/stratify.rs index 56cf9951..7c084376 100644 --- a/src/query/stratify.rs +++ b/src/query/stratify.rs @@ -4,8 +4,8 @@ use std::collections::{BTreeMap, BTreeSet}; use anyhow::{ensure, Result}; use itertools::Itertools; -use crate::data::symb::{Symbol, PROG_ENTRY}; use crate::data::program::{NormalFormAtom, NormalFormProgram, StratifiedNormalFormProgram}; +use crate::data::symb::{Symbol, PROG_ENTRY}; use crate::query::graph::{ generalized_kahn, reachable_components, strongly_connected_components, Graph, StratifiedGraph, }; @@ -31,17 +31,45 @@ fn convert_normal_form_program_to_graph( .iter() .map(|(k, ruleset)| { let mut ret: BTreeMap<&Symbol, bool> = BTreeMap::default(); + let has_aggr = ruleset + .iter() + .any(|rule| rule.aggr.iter().any(|a| a.is_some())); + let is_meet = has_aggr + && ruleset.iter().map(|rule| &rule.aggr).all_equal() + && ruleset.iter().all(|rule| { + rule.aggr.iter().all(|v| match v { + None => true, + Some(v) => v.is_meet, + }) + }); for rule in ruleset 
{ for atom in &rule.body { let contained = atom.contained_rules(); - for (found_key, negated) in contained { + for (found_key, is_negated) in contained { match ret.entry(found_key) { Entry::Vacant(e) => { - e.insert(negated); + if has_aggr { + if is_meet && k == found_key { + e.insert(is_negated); + } else { + e.insert(true); + } + } else { + e.insert(is_negated); + } } Entry::Occupied(mut e) => { let old = *e.get(); - e.insert(old || negated); + let new_val = if has_aggr { + if is_meet && k == found_key { + is_negated + } else { + true + } + } else { + is_negated + }; + e.insert(old || new_val); } } } diff --git a/src/runtime/temp_store.rs b/src/runtime/temp_store.rs index 88c8b82a..2b063196 100644 --- a/src/runtime/temp_store.rs +++ b/src/runtime/temp_store.rs @@ -1,12 +1,16 @@ use std::fmt::{Debug, Formatter}; +use anyhow::Result; +use itertools::Itertools; use log::error; use cozorocks::{DbIter, RawRocksDb, RocksDbStatus}; +use crate::data::aggr::Aggregation; use crate::data::program::MagicSymbol; use crate::data::tuple::{EncodedTuple, Tuple}; use crate::data::value::DataValue; +use crate::utils::swap_result_option; #[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] pub(crate) struct TempStoreId(pub(crate) u32); @@ -21,9 +25,8 @@ impl Debug for TempStoreId { pub(crate) struct TempStore { pub(crate) db: RawRocksDb, pub(crate) id: TempStoreId, - pub(crate) key_size: usize, - pub(crate) val_size: usize, pub(crate) rule_name: MagicSymbol, + pub(crate) arity: usize, } impl Debug for TempStore { @@ -33,17 +36,169 @@ impl Debug for TempStore { } impl TempStore { + pub(crate) fn aggr_meet_put( + &self, + tuple: &Tuple, + aggrs: &[Option], + epoch: u32, + ) -> anyhow::Result { + let key = Tuple( + aggrs + .iter() + .enumerate() + .map(|(i, ma)| { + if ma.is_none() { + tuple.0[i].clone() + } else { + DataValue::Bottom + } + }) + .collect_vec(), + ); + let key_encoded = key.encode_as_key_for_epoch(self.id, 0); + let prev_aggr = swap_result_option( + self.db + 
.get(&key_encoded)? + .map(|slice| EncodedTuple(&slice).decode()), + )?; + + if let Some(prev_aggr) = prev_aggr { + let tuple_to_store = Tuple( + aggrs + .iter() + .enumerate() + .map(|(i, aggr)| { + if let Some(aggr_op) = aggr { + let op = aggr_op.combine; + op(&prev_aggr.0[i], &tuple.0[i]) + } else { + Ok(DataValue::Bottom) + } + }) + .try_collect()?, + ); + + if prev_aggr == tuple_to_store { + Ok(false) + } else { + let tuple_data = tuple_to_store.encode_as_key_for_epoch(self.id, 0); + self.db.put(&key_encoded, &tuple_data)?; + if epoch != 0 { + let key_encoded = key.encode_as_key_for_epoch(self.id, epoch); + self.db.put(&key_encoded, &tuple_data)?; + } + Ok(true) + } + } else { + let tuple_to_store = Tuple( + aggrs + .iter() + .enumerate() + .map(|(i, aggr)| { + if let Some(aggr_op) = aggr { + let op = aggr_op.combine; + op(&DataValue::Bottom, &tuple.0[i]) + } else { + Ok(DataValue::Bottom) + } + }) + .try_collect()?, + ); + let tuple_data = tuple_to_store.encode_as_key_for_epoch(self.id, 0); + self.db.put(&key_encoded, &tuple_data)?; + if epoch != 0 { + let key_encoded = key.encode_as_key_for_epoch(self.id, epoch); + self.db.put(&key_encoded, &tuple_data)?; + } + Ok(true) + } + } pub(crate) fn put(&self, tuple: &Tuple, epoch: u32) -> Result<(), RocksDbStatus> { let key_encoded = tuple.encode_as_key_for_epoch(self.id, epoch); self.db.put(&key_encoded, &[]) } + pub(crate) fn normal_aggr_put( + &self, + tuple: &Tuple, + aggrs: &[Option], + ) -> Result<(), RocksDbStatus> { + let mut vals = vec![]; + for (idx, agg) in aggrs.iter().enumerate() { + if agg.is_none() { + vals.push(tuple.0[idx].clone()); + } + } + for (idx, agg) in aggrs.iter().enumerate() { + if agg.is_some() { + vals.push(tuple.0[idx].clone()); + } + } + self.db + .put(&Tuple(vals).encode_as_key_for_epoch(self.id, 0), &[]) + } pub(crate) fn exists(&self, tuple: &Tuple, epoch: u32) -> Result { let key_encoded = tuple.encode_as_key_for_epoch(self.id, epoch); self.db.exists(&key_encoded) } - pub(crate) 
fn scan_all(&self) -> impl Iterator> { - self.scan_all_for_epoch(0) + + pub(crate) fn normal_aggr_scan_and_put( + &self, + aggrs: &[Option], + store: &TempStore, + ) -> Result<()> { + let (lower, upper) = EncodedTuple::bounds_for_prefix_and_epoch(self.id, 0); + let mut it = self + .db + .iterator() + .upper_bound(&upper) + .prefix_same_as_start(true) + .start(); + it.seek(&lower); + let it = TempStoreIter { it, started: false }; + let aggrs = aggrs.to_vec(); + let key_indices = aggrs + .iter() + .enumerate() + .filter_map(|(i, aggr)| if aggr.is_none() { Some(i) } else { None }) + .collect_vec(); + let grouped = it.group_by(move |t_res| { + if let Ok(tuple) = t_res { + Some( + key_indices + .iter() + .map(|i| tuple.0[*i].clone()) + .collect_vec(), + ) + } else { + None + } + }); + for (key, group) in grouped.into_iter() { + if key.is_some() { + let mut aggr_res = vec![DataValue::Bottom; aggrs.len()]; + for tup_res in group.into_iter() { + let tuple = tup_res.unwrap().0; + for (i, val) in tuple.into_iter().enumerate() { + if let Some(aggr_op) = &aggrs[i] { + aggr_res[i] = (aggr_op.combine)(&aggr_res[i], &val)?; + } else { + aggr_res[i] = val; + } + } + } + for (i, aggr) in aggrs.iter().enumerate() { + if let Some(aggr_op) = aggr { + aggr_res[i] = (aggr_op.combine)(&aggr_res[i], &DataValue::Bottom)?; + } + } + store.put(&Tuple(aggr_res), 0)?; + } else { + return group.into_iter().next().unwrap().map(|_| ()); + } + } + Ok(()) } + pub(crate) fn scan_all_for_epoch( &self, epoch: u32, @@ -56,7 +211,10 @@ impl TempStore { .prefix_same_as_start(true) .start(); it.seek(&lower); - ThrowawayIter { it, started: false } + TempStoreIter { it, started: false } + } + pub(crate) fn scan_all(&self) -> impl Iterator> { + self.scan_all_for_epoch(0) } pub(crate) fn scan_prefix( &self, @@ -81,16 +239,16 @@ impl TempStore { .prefix_same_as_start(true) .start(); it.seek(&lower); - ThrowawayIter { it, started: false } + TempStoreIter { it, started: false } } } -struct ThrowawayIter { 
+struct TempStoreIter { it: DbIter, started: bool, } -impl Iterator for ThrowawayIter { +impl Iterator for TempStoreIter { type Item = anyhow::Result; fn next(&mut self) -> Option { @@ -102,9 +260,26 @@ impl Iterator for ThrowawayIter { match self.it.pair() { Err(e) => Some(Err(e.into())), Ok(None) => None, - Ok(Some((k_slice, _v_slice))) => match EncodedTuple(k_slice).decode() { + Ok(Some((k_slice, v_slice))) => match EncodedTuple(k_slice).decode() { Err(e) => Some(Err(e)), - Ok(t) => Some(Ok(t)), + Ok(t) => { + if v_slice.len() == 0 { + Some(Ok(t)) + } else { + match EncodedTuple(v_slice).decode() { + Err(e) => Some(Err(e)), + Ok(vt) => Some(Ok(Tuple( + t.0.into_iter() + .zip(vt.0) + .map(|(kv, vv)| match kv { + DataValue::Bottom => vv, + kv => kv, + }) + .collect_vec(), + ))), + } + } + } }, } } diff --git a/src/runtime/transact.rs b/src/runtime/transact.rs index b943d9c9..2e93b66f 100644 --- a/src/runtime/transact.rs +++ b/src/runtime/transact.rs @@ -15,8 +15,8 @@ use crate::data::encode::{ encode_sentinel_attr_by_id, encode_sentinel_entity_attr, encode_tx, EncodedVec, }; use crate::data::id::{AttrId, EntityId, TxId, Validity}; -use crate::data::symb::Symbol; use crate::data::program::MagicSymbol; +use crate::data::symb::Symbol; use crate::data::value::DataValue; use crate::runtime::temp_store::{TempStore, TempStoreId}; @@ -67,31 +67,24 @@ impl TxLog { } impl SessionTx { - pub(crate) fn new_throwaway( - &self, - key_size: usize, - val_size: usize, - rule_name: MagicSymbol, - ) -> TempStore { + pub(crate) fn new_rule_store(&self, rule_name: MagicSymbol, arity: usize) -> TempStore { let old_count = self.temp_store_id.fetch_add(1, Ordering::AcqRel); let old_count = old_count & 0x00ff_ffffu32; TempStore { db: self.temp_store.clone(), id: TempStoreId(old_count), - key_size, - val_size, + arity, rule_name, } } - pub(crate) fn temp_area(&self) -> TempStore { + pub(crate) fn new_temp_store(&self) -> TempStore { let old_count = self.temp_store_id.fetch_add(1, 
Ordering::AcqRel); let old_count = old_count & 0x00ff_ffffu32; TempStore { db: self.temp_store.clone(), id: TempStoreId(old_count), - key_size: 0, - val_size: 0, + arity: 0, rule_name: MagicSymbol::Muggle { inner: Symbol::from(""), },