diff --git a/cozo-core/src/cozoscript.pest b/cozo-core/src/cozoscript.pest index f407c166..d40a3d81 100644 --- a/cozo-core/src/cozoscript.pest +++ b/cozo-core/src/cozoscript.pest @@ -128,8 +128,9 @@ list = { "[" ~ (expr ~ ",")* ~ expr? ~ "]" } grouping = { "(" ~ expr ~ ")" } option = _{(limit_option|offset_option|sort_option|relation_option|timeout_option|sleep_option| - assert_none_option|assert_some_option) ~ ";"?} + assert_none_option|assert_some_option|disable_magic_rewrite_option) ~ ";"?} out_arg = @{var ~ ("(" ~ var ~ ")")?} +disable_magic_rewrite_option = {":disable_magic_rewrite" ~ expr} limit_option = {":limit" ~ expr} offset_option = {":offset" ~ expr} sort_option = {(":sort" | ":order") ~ (sort_arg ~ ",")* ~ sort_arg } diff --git a/cozo-core/src/data/program.rs b/cozo-core/src/data/program.rs index c998762b..a2684c37 100644 --- a/cozo-core/src/data/program.rs +++ b/cozo-core/src/data/program.rs @@ -24,6 +24,7 @@ use crate::data::value::{DataValue, ValidityTs}; use crate::fixed_rule::{FixedRule, FixedRuleHandle}; use crate::fts::FtsIndexManifest; use crate::parse::SourceSpan; +use crate::query::compile::ContainedRuleMultiplicity; use crate::query::logical::{Disjunction, NamedFieldNotFound}; use crate::runtime::hnsw::HnswIndexManifest; use crate::runtime::minhash_lsh::{LshSearch, MinHashLshIndexManifest}; @@ -445,6 +446,7 @@ impl MagicFixedRuleRuleArg { pub(crate) struct InputProgram { pub(crate) prog: BTreeMap, pub(crate) out_opts: QueryOutOptions, + pub(crate) disable_magic_rewrite: bool, } impl Display for InputProgram { @@ -676,7 +678,13 @@ impl InputProgram { } } } - Ok((NormalFormProgram { prog }, self.out_opts)) + Ok(( + NormalFormProgram { + prog, + disable_magic_rewrite: self.disable_magic_rewrite, + }, + self.out_opts, + )) } } @@ -701,6 +709,7 @@ impl NormalFormRulesOrFixed { #[derive(Debug, Default)] pub(crate) struct NormalFormProgram { pub(crate) prog: BTreeMap, + pub(crate) disable_magic_rewrite: bool, } #[derive(Debug)] @@ -874,12 +883,19 @@ pub(crate) struct MagicInlineRule { } impl MagicInlineRule { - pub(crate) fn contained_rules(&self) -> BTreeSet { - let mut coll = BTreeSet::new(); + pub(crate) fn contained_rules(&self) -> BTreeMap { + let mut coll = BTreeMap::new(); for atom in self.body.iter() { match atom { MagicAtom::Rule(rule) | MagicAtom::NegatedRule(rule) => { - coll.insert(rule.name.clone()); + match coll.entry(rule.name.clone()) { + Entry::Vacant(ent) => { + ent.insert(ContainedRuleMultiplicity::One); + } + Entry::Occupied(mut ent) => { + *ent.get_mut() = ContainedRuleMultiplicity::Many; + } + } } _ => {} } diff --git a/cozo-core/src/parse/query.rs b/cozo-core/src/parse/query.rs index 44a28f0c..62588b52 100644 --- a/cozo-core/src/parse/query.rs +++ b/cozo-core/src/parse/query.rs @@ -53,6 +53,11 @@ struct OptionNotNonNegIntError(&'static str, #[label] SourceSpan); #[diagnostic(code(parser::option_not_pos))] struct OptionNotPosIntError(&'static str, #[label] SourceSpan); +#[derive(Error, Diagnostic, Debug)] +#[error("Query option {0} requires a boolean")] +#[diagnostic(code(parser::option_not_bool))] +struct OptionNotBoolError(&'static str, #[label] SourceSpan); + #[derive(Debug)] struct MultipleRuleDefinitionError(String, Vec); @@ -105,6 +110,7 @@ pub(crate) fn parse_query( ) -> Result { let mut progs: BTreeMap = Default::default(); let mut out_opts: QueryOutOptions = Default::default(); + let mut disable_magic_rewrite = false; let mut stored_relation = None; @@ -358,6 +364,16 @@ pub(crate) fn parse_query( ); out_opts.assertion = Some(QueryAssertion::AssertSome(pair.extract_span())) } + Rule::disable_magic_rewrite_option => { + let pair = pair.into_inner().next().unwrap(); + let span = pair.extract_span(); + let val = build_expr(pair, param_pool)? + .eval_to_const() + .map_err(|err| OptionNotConstantError("disable_magic_rewrite", span, [err]))? + .get_bool() + .ok_or(OptionNotBoolError("disable_magic_rewrite", span))?; + disable_magic_rewrite = val; + } Rule::EOI => break, r => unreachable!("{:?}", r), } @@ -366,6 +382,7 @@ pub(crate) fn parse_query( let mut prog = InputProgram { prog: progs, out_opts, + disable_magic_rewrite, }; if prog.prog.is_empty() { @@ -532,12 +549,7 @@ fn parse_atom( let span = src.extract_span(); let mut src = src.into_inner(); src.next().unwrap(); - let inner = parse_atom( - src.next().unwrap(), - param_pool, - cur_vld, - ignored_counter, - )?; + let inner = parse_atom(src.next().unwrap(), param_pool, cur_vld, ignored_counter)?; InputAtom::Negation { inner: inner.into(), span, diff --git a/cozo-core/src/query/compile.rs b/cozo-core/src/query/compile.rs index f7961cc1..47025f46 100644 --- a/cozo-core/src/query/compile.rs +++ b/cozo-core/src/query/compile.rs @@ -77,11 +77,17 @@ impl CompiledRuleSet { } } +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub(crate) enum ContainedRuleMultiplicity { + One, + Many, +} + #[derive(Debug)] pub(crate) struct CompiledRule { pub(crate) aggr: Vec)>>, pub(crate) relation: RelAlgebra, - pub(crate) contained_rules: BTreeSet, + pub(crate) contained_rules: BTreeMap, } #[derive(Debug, Error, Diagnostic)] diff --git a/cozo-core/src/query/eval.rs b/cozo-core/src/query/eval.rs index 746f6546..015e0ccd 100644 --- a/cozo-core/src/query/eval.rs +++ b/cozo-core/src/query/eval.rs @@ -23,7 +23,9 @@ use crate::data::tuple::Tuple; use crate::data::value::DataValue; use crate::fixed_rule::FixedRulePayload; use crate::parse::SourceSpan; -use crate::query::compile::{AggrKind, CompiledProgram, CompiledRule, CompiledRuleSet}; +use crate::query::compile::{ + AggrKind, CompiledProgram, CompiledRule, CompiledRuleSet, ContainedRuleMultiplicity, +}; use crate::runtime::db::Poison; use crate::runtime::temp_store::{EpochStore, MeetAggrStore, RegularTempStore}; use crate::runtime::transact::SessionTx; @@ -212,6 +214,7 @@ impl<'a> SessionTx<'a> { } } } else { + // Follow up epoch > 0 #[allow(clippy::needless_borrow)] let execution = |(k, compiled_ruleset): (_, &CompiledRuleSet)| -> Result<_> { let new_store = match compiled_ruleset { @@ -510,28 +513,31 @@ impl<'a> SessionTx<'a> { limiter: &QueryLimiter, poison: Poison, ) -> Result<(bool, RegularTempStore)> { + // TODO: handle the case where self-join is involved let prev_store = stores.get(rule_symb).unwrap(); let mut out_store = RegularTempStore::default(); let should_check_limit = limiter.total.is_some() && rule_symb.is_prog_entry(); for (rule_n, rule) in ruleset.iter().enumerate() { - let dependencies_changed = rule - .contained_rules - .iter() - .map(|symb| stores.get(symb).unwrap().has_delta()) - .any(|v| v); + let mut need_complete_run = false; + let mut dependencies_changed = false; + + for (symb, multiplicity) in rule.contained_rules.iter() { + if stores.get(symb).unwrap().has_delta() { + dependencies_changed = true; + if *multiplicity == ContainedRuleMultiplicity::Many { + need_complete_run = true; + break; + } + } + } + if !dependencies_changed { continue; } - for (delta_key, _) in stores.iter() { - if !rule.contained_rules.contains(delta_key) { - continue; - } - debug!( - "with delta {:?} for rule {:?}.{}", - delta_key, rule_symb, rule_n - ); - for item_res in rule.relation.iter(self, Some(delta_key), stores)? { + if need_complete_run { + debug!("complete rule for rule {:?}.{}", rule_symb, rule_n); + for item_res in rule.relation.iter(self, None, stores)? { let item = item_res?; // improvement: the clauses can actually be evaluated in parallel if prev_store.exists(&item) { @@ -562,6 +568,47 @@ impl<'a> SessionTx<'a> { } } poison.check()?; + } else { + for (delta_key, _) in stores.iter() { + if !rule.contained_rules.contains_key(delta_key) { + continue; + } + debug!( + "with delta {:?} for rule {:?}.{}", + delta_key, rule_symb, rule_n + ); + for item_res in rule.relation.iter(self, Some(delta_key), stores)? { + let item = item_res?; + // improvement: the clauses can actually be evaluated in parallel + if prev_store.exists(&item) { + trace!( + "item for {:?}.{}: {:?} at {}, rederived", + rule_symb, + rule_n, + item, + epoch + ); + } else { + trace!( + "item for {:?}.{}: {:?} at {}", + rule_symb, + rule_n, + item, + epoch + ); + if limiter.should_skip_next() { + out_store.put_with_skip(item); + } else { + out_store.put(item); + } + if should_check_limit && limiter.incr_and_should_stop() { + trace!("early stopping due to result count limit exceeded"); + return Ok((true, out_store)); + } + } + } + poison.check()?; + } } } Ok((should_check_limit, out_store)) @@ -573,13 +620,22 @@ impl<'a> SessionTx<'a> { stores: &BTreeMap, poison: Poison, ) -> Result { + // TODO handle the case where self-joins are involved let mut out_store = MeetAggrStore::new(ruleset[0].aggr.clone())?; for (rule_n, rule) in ruleset.iter().enumerate() { - let dependencies_changed = rule - .contained_rules - .iter() - .map(|symb| stores.get(symb).unwrap().has_delta()) - .any(|v| v); + let mut need_complete_run = false; + let mut dependencies_changed = false; + + for (symb, multiplicity) in rule.contained_rules.iter() { + if stores.get(symb).unwrap().has_delta() { + dependencies_changed = true; + if *multiplicity == ContainedRuleMultiplicity::Many { + need_complete_run = true; + break; + } + } + } + if !dependencies_changed { continue; } @@ -589,18 +645,26 @@ impl<'a> SessionTx<'a> { aggr.meet_init(args)?; } - for (delta_key, _) in stores.iter() { - if !rule.contained_rules.contains(delta_key) { - continue; - } - debug!( - "with delta {:?} for rule {:?}.{}", - delta_key, rule_symb, rule_n - ); - for item_res in rule.relation.iter(self, Some(delta_key), stores)? { + if need_complete_run { + debug!("complete run for rule {:?}.{}", rule_symb, rule_n); + for item_res in rule.relation.iter(self, None, stores)? { out_store.meet_put(item_res?)?; } poison.check()?; + } else { + for (delta_key, _) in stores.iter() { + if !rule.contained_rules.contains_key(delta_key) { + continue; + } + debug!( + "with delta {:?} for rule {:?}.{}", + delta_key, rule_symb, rule_n + ); + for item_res in rule.relation.iter(self, Some(delta_key), stores)? { + out_store.meet_put(item_res?)?; + } + poison.check()?; + } } } Ok(out_store) diff --git a/cozo-core/src/query/magic.rs b/cozo-core/src/query/magic.rs index ec9cff90..0ab0480f 100644 --- a/cozo-core/src/query/magic.rs +++ b/cozo-core/src/query/magic.rs @@ -30,6 +30,10 @@ use crate::runtime::transact::SessionTx; impl NormalFormProgram { pub(crate) fn exempt_aggr_rules_for_magic_sets(&self, exempt_rules: &mut BTreeSet) { for (name, rule_set) in self.prog.iter() { + if self.disable_magic_rewrite { + exempt_rules.insert(name.clone()); + continue; + } match rule_set { NormalFormRulesOrFixed::Rules { rules: rule_set } => { 'outer: for rule in rule_set.iter() { @@ -629,3 +633,30 @@ impl NormalFormInlineRule { } } } + +#[cfg(test)] +mod tests { + use crate::DbInstance; + use serde_json::json; + + #[test] + fn strange_case() { + let db = DbInstance::new("mem", "", "").unwrap(); + + let query = r#" + x[A] := A = 1 + y[A, A] := A = 1 + y[A, B] := A = 0, B = 1, x[B] + + ?[C] := y[A, _], y[C, A] + + :disable_magic_rewrite true + "#; + + let res = db + .run_script(query, Default::default()) + .unwrap() + .into_json(); + assert_eq!(res["rows"], json!([[0], [1]])); + } +} diff --git a/cozo-core/src/query/stratify.rs b/cozo-core/src/query/stratify.rs index 5e7e2c90..246677e2 100644 --- a/cozo-core/src/query/stratify.rs +++ b/cozo-core/src/query/stratify.rs @@ -263,8 +263,12 @@ impl NormalFormProgram { .flat_map(|(stratum, indices)| indices.into_iter().map(move |idx| (idx, stratum))) .collect::>(); // 7. translate the stratification into datalog program - let mut ret: Vec = - (0..n_strata).map(|_| Default::default()).collect_vec(); + let mut ret: Vec = (0..n_strata) + .map(|_| NormalFormProgram { + prog: BTreeMap::new(), + disable_magic_rewrite: self.disable_magic_rewrite, + }) + .collect_vec(); let mut store_lifetimes = BTreeMap::new(); for (fr, tos) in &stratified_graph {