From 5379c76a70f91ad078b4c7adabd1b63485b99f68 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Wed, 28 Sep 2022 22:05:18 +0800 Subject: [PATCH] algo supports named bindings --- src/cozoscript.pest | 4 ++- src/data/program.rs | 5 ++++ src/parse/query.rs | 26 ++++++++++++++++- src/query/logical.rs | 31 +++++++++++++++++--- src/query/magic.rs | 67 ++++++++++++++++++++++++++++++++++--------- src/query/stratify.rs | 8 ++++-- src/runtime/db.rs | 2 +- 7 files changed, 120 insertions(+), 23 deletions(-) diff --git a/src/cozoscript.pest b/src/cozoscript.pest index 778e34ac..d2d7974f 100644 --- a/src/cozoscript.pest +++ b/src/cozoscript.pest @@ -46,9 +46,11 @@ head_arg = {aggr_arg | var} aggr_arg = {ident ~ "(" ~ var ~ ("," ~ expr)* ~ ")"} algo_arg = _{algo_rel | algo_opt_pair} algo_opt_pair = {ident ~ ":" ~ expr} -algo_rel = {algo_rule_rel | algo_relation_rel } +algo_rel = {algo_rule_rel | algo_relation_rel | algo_named_relation_rel } algo_rule_rel = {ident ~ "[" ~ (var ~ ",")* ~ var? ~ "]"} algo_relation_rel = {relation_ident ~ "[" ~ (var ~ ",")* ~ var? ~ "]"} +algo_named_relation_rel = {relation_ident ~ "{" ~ (algo_named_relation_arg_pair ~ ",")* ~ algo_named_relation_arg_pair? ~ "}"} +algo_named_relation_arg_pair = {ident ~ (":" ~ ident)?} rule_body = {(disjunction ~ ",")* ~ disjunction?} rule_apply = {underscore_ident ~ "[" ~ apply_args ~ "]"} diff --git a/src/data/program.rs b/src/data/program.rs index 2380b2e3..3aa710f8 100644 --- a/src/data/program.rs +++ b/src/data/program.rs @@ -419,6 +419,11 @@ pub(crate) enum AlgoRuleArg { bindings: Vec, span: SourceSpan, }, + NamedStored { + name: Symbol, + bindings: BTreeMap, Symbol>, + span: SourceSpan, + } } #[derive(Debug, Clone)] diff --git a/src/parse/query.rs b/src/parse/query.rs index 4aafe991..fe9ef6d5 100644 --- a/src/parse/query.rs +++ b/src/parse/query.rs @@ -670,7 +670,6 @@ fn parse_atom(src: Pair<'_>, param_pool: &BTreeMap) -> Result }) } - fn parse_rule_head( src: Pair<'_>, param_pool: &BTreeMap, @@ -781,6 +780,31 @@ fn parse_algo_rule( span, }) } + Rule::algo_named_relation_rel => { + let mut els = inner.into_inner(); + let name = els.next().unwrap(); + let bindings = els + .map(|v| { + let mut vs = v.into_inner(); + let kp = vs.next().unwrap(); + let k = SmartString::from(kp.as_str()); + let v = match vs.next() { + Some(vp) => Symbol::new(vp.as_str(), vp.extract_span()), + None => Symbol::new(k.clone(), kp.extract_span()), + }; + (k, v) + }) + .collect(); + + rule_args.push(AlgoRuleArg::NamedStored { + name: Symbol::new( + name.as_str().strip_prefix(':').unwrap(), + name.extract_span(), + ), + bindings, + span, + }) + } _ => unreachable!(), } } diff --git a/src/query/logical.rs b/src/query/logical.rs index d3208cd1..b717696f 100644 --- a/src/query/logical.rs +++ b/src/query/logical.rs @@ -1,14 +1,15 @@ use std::collections::BTreeSet; use itertools::Itertools; -use miette::{bail, Result}; +use miette::{bail, ensure, Diagnostic, Result}; +use thiserror::Error; use crate::data::expr::Expr; use crate::data::program::{ InputAtom, InputNamedFieldRelationApplyAtom, InputRelationApplyAtom, InputRuleApplyAtom, - NormalFormAtom, NormalFormRelationApplyAtom, NormalFormRuleApplyAtom, TempSymbGen, - Unification, + NormalFormAtom, NormalFormRelationApplyAtom, NormalFormRuleApplyAtom, TempSymbGen, Unification, }; +use crate::parse::SourceSpan; use crate::query::reorder::UnsafeNegation; use crate::runtime::transact::SessionTx; @@ -132,6 +133,19 @@ impl InputAtom { tx: &SessionTx, ) -> Result { let stored = tx.get_relation(&name, false)?; + let fields: BTreeSet<_> = stored + .metadata + .keys + .iter() + .chain(stored.metadata.non_keys.iter()) + .map(|col| &col.name) + .collect(); + for k in args.keys() { + ensure!( + fields.contains(k), + NamedFieldNotFound(name.to_string(), k.to_string(), span) + ); + } let mut new_args = vec![]; for col_def in stored .metadata @@ -269,7 +283,7 @@ impl InputRelationApplyAtom { let mut seen_variables = BTreeSet::new(); for arg in self.args { match arg { - Expr::Binding {var, ..} => { + Expr::Binding { var, .. } => { if seen_variables.insert(var.clone()) { args.push(var); } else { @@ -319,3 +333,12 @@ impl InputRelationApplyAtom { Disjunction::conj(ret) } } + +#[derive(Debug, Error, Diagnostic)] +#[error("stored relation '{0}' does not have field '{1}'")] +#[diagnostic(code(eval::named_field_not_found))] +pub(crate) struct NamedFieldNotFound( + pub(crate) String, + pub(crate) String, + #[label] pub(crate) SourceSpan, +); diff --git a/src/query/magic.rs b/src/query/magic.rs index acb77b90..5530b8ae 100644 --- a/src/query/magic.rs +++ b/src/query/magic.rs @@ -2,17 +2,20 @@ use std::collections::BTreeSet; use std::mem; use itertools::Itertools; -use miette::{Result}; +use miette::{ensure, Result}; use smallvec::SmallVec; +use smartstring::SmartString; use crate::data::program::{ - AlgoRuleArg, MagicAlgoApply, MagicAlgoRuleArg, MagicAtom, MagicProgram, - MagicRelationApplyAtom, MagicRule, MagicRuleApplyAtom, MagicRulesOrAlgo, MagicSymbol, - NormalFormAlgoOrRules, NormalFormAtom, NormalFormProgram, NormalFormRule, - StratifiedMagicProgram, StratifiedNormalFormProgram, + AlgoRuleArg, MagicAlgoApply, MagicAlgoRuleArg, MagicAtom, MagicProgram, MagicRelationApplyAtom, + MagicRule, MagicRuleApplyAtom, MagicRulesOrAlgo, MagicSymbol, NormalFormAlgoOrRules, + NormalFormAtom, NormalFormProgram, NormalFormRule, StratifiedMagicProgram, + StratifiedNormalFormProgram, }; use crate::data::symb::{Symbol, PROG_ENTRY}; use crate::parse::SourceSpan; +use crate::query::logical::NamedFieldNotFound; +use crate::runtime::transact::SessionTx; impl NormalFormProgram { pub(crate) fn exempt_aggr_rules_for_magic_sets(&self, exempt_rules: &mut BTreeSet) { @@ -35,14 +38,12 @@ impl NormalFormProgram { } impl StratifiedNormalFormProgram { - pub(crate) fn magic_sets_rewrite( - self, - ) -> Result { + pub(crate) fn magic_sets_rewrite(self, tx: &SessionTx) -> Result { let mut exempt_rules = BTreeSet::from([Symbol::new(PROG_ENTRY, SourceSpan(0, 0))]); let mut collected = vec![]; for prog in self.0 { prog.exempt_aggr_rules_for_magic_sets(&mut exempt_rules); - let adorned = prog.adorn(&exempt_rules)?; + let adorned = prog.adorn(&exempt_rules, tx)?; collected.push(adorned.magic_rewrite()); exempt_rules.extend(prog.get_downstream_rules()); } @@ -273,10 +274,7 @@ impl NormalFormProgram { } downstream_rules } - fn adorn( - &self, - upstream_rules: &BTreeSet, - ) -> Result { + fn adorn(&self, upstream_rules: &BTreeSet, tx: &SessionTx) -> Result { let rules_to_rewrite: BTreeSet<_> = self .prog .keys() @@ -329,6 +327,49 @@ impl NormalFormProgram { bindings: bindings.clone(), span: *span, }, + AlgoRuleArg::NamedStored { + name, + bindings, + span, + } => { + let relation = tx.get_relation(&name, false)?; + let fields: BTreeSet<_> = relation + .metadata + .keys + .iter() + .chain(relation.metadata.non_keys.iter()) + .map(|col| &col.name) + .collect(); + for k in bindings.keys() { + ensure!( + fields.contains(&k), + NamedFieldNotFound( + name.to_string(), + k.to_string(), + *span + ) + ); + } + let new_bindings = relation + .metadata + .keys + .iter() + .chain(relation.metadata.non_keys.iter()) + .enumerate() + .map(|(i, col)| match bindings.get(&col.name) { + None => Symbol::new( + SmartString::from(format!("{}", i)), + Default::default(), + ), + Some(k) => k.clone(), + }) + .collect_vec(); + MagicAlgoRuleArg::Stored { + name: name.clone(), + bindings: new_bindings, + span: *span, + } + } }) }) .try_collect()?, diff --git a/src/query/stratify.rs b/src/query/stratify.rs index 1da40271..c0ca2e0a 100644 --- a/src/query/stratify.rs +++ b/src/query/stratify.rs @@ -126,7 +126,7 @@ fn convert_normal_form_program_to_graph( AlgoRuleArg::InMem { name, .. } => { ret.insert(name, true); } - AlgoRuleArg::Stored { .. } => {} + AlgoRuleArg::Stored { .. } | AlgoRuleArg::NamedStored { .. } => {} } } (k, ret) @@ -149,9 +149,11 @@ fn verify_no_cycle(g: &StratifiedGraph<&'_ Symbol>, sccs: &[BTreeSet<&Symbol>]) #[derive(Debug, Error, Diagnostic)] #[error("Query is unstratifiable")] #[diagnostic(code(eval::unstratifiable))] - #[diagnostic(help("The rule '{0}' is in the strongly connected component {1:?},\n\ + #[diagnostic(help( + "The rule '{0}' is in the strongly connected component {1:?},\n\ and is involved in at least one forbidden dependency \n\ - (negation, non-meet aggregation, or algorithm-application)."))] + (negation, non-meet aggregation, or algorithm-application)." + ))] struct UnStratifiableProgram(String, Vec); ensure!( diff --git a/src/runtime/db.rs b/src/runtime/db.rs index 3a43e341..a3157746 100644 --- a/src/runtime/db.rs +++ b/src/runtime/db.rs @@ -332,7 +332,7 @@ impl Db { let program = input_program .to_normalized_program(tx)? .stratify()? - .magic_sets_rewrite()?; + .magic_sets_rewrite(tx)?; debug!("{:#?}", program); let (compiled, stores) = tx.stratified_magic_compile(&program, &input_program.const_rules)?;