diff --git a/cozo-core/src/query/compile.rs b/cozo-core/src/query/compile.rs index 46afa49d..52e01bed 100644 --- a/cozo-core/src/query/compile.rs +++ b/cozo-core/src/query/compile.rs @@ -22,7 +22,6 @@ use crate::data::symb::Symbol; use crate::data::value::DataValue; use crate::parse::SourceSpan; use crate::query::ra::RelAlgebra; -use crate::runtime::in_mem::InMemRelation; use crate::runtime::relation::{AccessLevel, InsufficientAccessLevel}; use crate::runtime::transact::SessionTx; @@ -42,6 +41,12 @@ pub(crate) enum AggrKind { } impl CompiledRuleSet { + pub(crate) fn arity(&self) -> usize { + match self { + CompiledRuleSet::Rules(rs) => rs[0].aggr.len(), + CompiledRuleSet::Algo(algo) => algo.arity, + } + } pub(crate) fn aggr_kind(&self) -> AggrKind { match self { CompiledRuleSet::Rules(rules) => { @@ -94,15 +99,17 @@ impl<'a> SessionTx<'a> { pub(crate) fn stratified_magic_compile( &mut self, prog: &StratifiedMagicProgram, - ) -> Result<(Vec, BTreeMap)> { - let mut stores: BTreeMap = Default::default(); + ) -> Result> { + // let mut stores: BTreeMap = Default::default(); + let mut store_arities: BTreeMap<&MagicSymbol, usize> = Default::default(); for stratum in prog.0.iter() { for (name, ruleset) in &stratum.prog { - stores.insert( - name.clone(), - self.new_rule_store(ruleset.arity()?), - ); + // stores.insert( + // name.clone(), + // self.new_rule_store(ruleset.arity()?), + // ); + store_arities.insert(name, ruleset.arity()?); } } @@ -121,7 +128,7 @@ impl<'a> SessionTx<'a> { for rule in body.iter() { let header = &rule.head; let mut relation = - self.compile_magic_rule_body(rule, k, &stores, header)?; + self.compile_magic_rule_body(rule, k, &store_arities, header)?; relation.fill_binding_indices().with_context(|| { format!( "error encountered when filling binding indices for {:#?}", @@ -145,13 +152,13 @@ impl<'a> SessionTx<'a> { .try_collect() }) .try_collect()?; - Ok((compiled, stores)) + Ok(compiled) } pub(crate) fn compile_magic_rule_body( &mut self, rule: &MagicInlineRule, rule_name: &MagicSymbol, - stores: &BTreeMap, + store_arities: &BTreeMap<&MagicSymbol, usize>, ret_vars: &[Symbol], ) -> Result { let mut ret = RelAlgebra::unit(rule_name.symbol().span); @@ -165,7 +172,7 @@ impl<'a> SessionTx<'a> { for atom in &rule.body { match atom { MagicAtom::Rule(rule_app) => { - let store = stores.get(&rule_app.name).ok_or_else(|| { + let store_arity = store_arities.get(&rule_app.name).ok_or_else(|| { RuleNotFound( rule_app.name.symbol().to_string(), rule_app.name.symbol().span, @@ -173,10 +180,10 @@ impl<'a> SessionTx<'a> { })?; ensure!( - store.arity == rule_app.args.len(), + *store_arity == rule_app.args.len(), ArityMismatch( rule_app.name.symbol().to_string(), - store.arity, + *store_arity, rule_app.args.len(), rule_app.span ) @@ -241,19 +248,17 @@ impl<'a> SessionTx<'a> { ret = ret.join(right, prev_joiner_vars, right_joiner_vars, rel_app.span); } MagicAtom::NegatedRule(rule_app) => { - let store = stores - .get(&rule_app.name) - .ok_or_else(|| { - RuleNotFound( - rule_app.name.symbol().to_string(), - rule_app.name.symbol().span, - ) - })?; + let store_arity = store_arities.get(&rule_app.name).ok_or_else(|| { + RuleNotFound( + rule_app.name.symbol().to_string(), + rule_app.name.symbol().span, + ) + })?; ensure!( - store.arity == rule_app.args.len(), + *store_arity == rule_app.args.len(), ArityMismatch( rule_app.name.symbol().to_string(), - store.arity, + *store_arity, rule_app.args.len(), rule_app.span ) @@ -274,7 +279,8 @@ impl<'a> SessionTx<'a> { } } - let right = RelAlgebra::derived(right_vars, rule_app.name.clone(), rule_app.span); + let right = + RelAlgebra::derived(right_vars, rule_app.name.clone(), rule_app.span); debug_assert_eq!(prev_joiner_vars.len(), right_joiner_vars.len()); ret = ret.neg_join(right, prev_joiner_vars, right_joiner_vars, rule_app.span); } diff --git a/cozo-core/src/query/eval.rs b/cozo-core/src/query/eval.rs index 17bdc256..c3197c89 100644 --- a/cozo-core/src/query/eval.rs +++ b/cozo-core/src/query/eval.rs @@ -52,28 +52,32 @@ impl<'a> SessionTx<'a> { pub(crate) fn stratified_magic_evaluate( &self, strata: &[CompiledProgram], - stores: &BTreeMap, + store_lifetimes: BTreeMap, total_num_to_take: Option, num_to_skip: Option, poison: Poison, ) -> Result<(InMemRelation, bool)> { + let mut stores = BTreeMap::new(); let mut early_return = false; - for (idx, cur_prog) in strata.iter().enumerate() { - debug!("stratum {}", idx); + for (stratum, cur_prog) in strata.iter().enumerate() { + debug!("stratum {}", stratum); + for (rule_name, rule_set) in cur_prog { + stores.insert(rule_name.clone(), self.new_rule_store(rule_set.arity())); + } + early_return = self.semi_naive_magic_evaluate( cur_prog, - stores, + &stores, total_num_to_take, num_to_skip, poison.clone(), )?; } let ret_area = stores - .get(&MagicSymbol::Muggle { + .remove(&MagicSymbol::Muggle { inner: Symbol::new(PROG_ENTRY, SourceSpan(0, 0)), }) - .ok_or(NoEntryError)? - .clone(); + .ok_or(NoEntryError)?; Ok((ret_area, early_return)) } /// returns true if early return is activated diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index 2029556a..7bcc4b2f 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -650,7 +650,7 @@ impl<'s, S: Storage<'s>> Db { let mut tx = self.transact()?; let (stratified_program, _) = prog.to_normalized_program(&tx)?.stratify()?; let program = stratified_program.magic_sets_rewrite(&tx)?; - let (compiled, _) = tx.stratified_magic_compile(&program)?; + let compiled = tx.stratified_magic_compile(&program)?; tx.commit_tx()?; self.explain_compiled(&compiled) } @@ -789,10 +789,10 @@ impl<'s, S: Storage<'s>> Db { }; // query compilation - let (stratified_program, _store_lifetimes) = + let (stratified_program, store_lifetimes) = input_program.to_normalized_program(tx)?.stratify()?; let program = stratified_program.magic_sets_rewrite(tx)?; - let (compiled, stores) = tx.stratified_magic_compile(&program)?; + let compiled = tx.stratified_magic_compile(&program)?; // poison is used to terminate queries early let poison = Poison::default(); @@ -841,7 +841,7 @@ impl<'s, S: Storage<'s>> Db { // the real evaluation let (result, early_return) = tx.stratified_magic_evaluate( &compiled, - &stores, + store_lifetimes, total_num_to_take, num_to_skip, poison,