diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml deleted file mode 100644 index 81378f4f..00000000 --- a/.idea/inspectionProfiles/Project_Default.xml +++ /dev/null @@ -1,71 +0,0 @@ - - - - \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index bca968f4..d351d4b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ authors = ["Ziyang Hu"] [dependencies] casey = "0.3.3" +either = "1.7.0" uuid = { version = "1.1.2", features = ["v1", "v4", "serde"] } rand = "0.8.5" anyhow = "1.0" diff --git a/src/data/keyword.rs b/src/data/keyword.rs index 51cebb75..3f5eaa04 100644 --- a/src/data/keyword.rs +++ b/src/data/keyword.rs @@ -1,13 +1,10 @@ use std::fmt::{Debug, Display, Formatter}; -use std::str::Utf8Error; use anyhow::{ensure, Result}; use lazy_static::lazy_static; use serde_derive::{Deserialize, Serialize}; use smartstring::{LazyCompact, SmartString}; -use crate::data::json::JsonValue; - #[derive(Clone, PartialEq, PartialOrd, Eq, Ord, Deserialize, Serialize, Hash)] pub struct Keyword(pub(crate) SmartString); diff --git a/src/data/program.rs b/src/data/program.rs index d8531ed7..b4a627a5 100644 --- a/src/data/program.rs +++ b/src/data/program.rs @@ -1,5 +1,7 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; +use anyhow::Result; +use itertools::Itertools; use smallvec::SmallVec; use smartstring::{LazyCompact, SmartString}; @@ -9,29 +11,88 @@ use crate::data::keyword::Keyword; use crate::data::value::DataValue; use crate::{EntityId, Validity}; +#[derive(Default)] +pub(crate) struct TempKwGen { + last_id: u32, +} + +impl TempKwGen { + pub(crate) fn next(&mut self) -> Keyword { + self.last_id += 1; + Keyword::from(&format!("*{}", self.last_id) as &str) + } +} + #[derive(Clone, Debug, Default)] pub enum Aggregation { #[default] Todo, } +#[derive(Debug, Clone)] pub(crate) struct InputProgram { - prog: BTreeMap>, + pub(crate) prog: BTreeMap>, +} + +impl InputProgram { + pub(crate) fn to_normalized_program(self) -> Result { + let mut prog: BTreeMap<_, _> = Default::default(); + for (k, rules) in self.prog { + let mut collected_rules = vec![]; + for rule in rules { + let normalized_body = + InputAtom::Conjunction(rule.body).disjunctive_normal_form()?; + for conj in normalized_body.0 { + let normalized_rule = NormalFormRule { + head: rule.head.clone(), + aggr: rule.aggr.clone(), + body: conj.0, + vld: rule.vld, + }; + collected_rules.push(normalized_rule); + } + } + prog.insert(k, collected_rules); + } + Ok(NormalFormProgram { prog }) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct StratifiedNormalFormProgram(pub(crate) Vec); + +impl StratifiedNormalFormProgram { + pub(crate) fn magic_sets_rewrite(self) -> Result { + Ok(StratifiedMagicProgram( + self.0 + .into_iter() + .map(|p| p.magic_sets_rewrite()) + .try_collect()?, + )) + } +} + +#[derive(Debug, Clone, Default)] +pub(crate) struct NormalFormProgram { + pub(crate) prog: BTreeMap>, } -pub(crate) struct StratifiedNormalFormProgram(Vec); - -pub(crate) struct NormalFormProgram { - prog: BTreeMap>, +impl NormalFormProgram { + pub(crate) fn magic_sets_rewrite(self) -> Result { + todo!() + } } +#[derive(Debug, Clone)] pub(crate) struct StratifiedMagicProgram(Vec); +#[derive(Debug, Clone)] pub(crate) struct MagicProgram { prog: BTreeMap>, keep_rules: Vec, } +#[derive(Clone, Debug)] enum MagicKeyword { Muggle { name: SmartString, @@ -51,27 +112,31 @@ enum MagicKeyword { }, } +#[derive(Debug, Clone)] pub(crate) struct InputRule { - head: Vec, - aggr: Vec>, - body: Vec, - vld: Validity, + pub(crate) head: Vec, + pub(crate) aggr: Vec>, + pub(crate) body: Vec, + pub(crate) vld: Validity, } +#[derive(Debug, Clone)] pub(crate) struct NormalFormRule { - head: Vec, - aggr: Vec>, - body: Vec, - vld: Validity, + pub(crate) head: Vec, + pub(crate) aggr: Vec>, + pub(crate) body: Vec, + pub(crate) vld: Validity, } +#[derive(Debug, Clone)] pub(crate) struct MagicRule { - head: Vec, - aggr: Vec>, - body: Vec, - vld: Validity, + pub(crate) head: Vec, + pub(crate) aggr: Vec>, + pub(crate) body: Vec, + pub(crate) vld: Validity, } +#[derive(Debug, Clone)] pub(crate) enum InputAtom { AttrTriple(InputAttrTripleAtom), Rule(InputRuleApplyAtom), @@ -82,14 +147,23 @@ pub(crate) enum InputAtom { Unification(Unification), } +impl InputAtom { + pub(crate) fn is_negation(&self) -> bool { + matches!(self, InputAtom::Negation(_)) + } +} + +#[derive(Debug, Clone)] pub(crate) enum NormalFormAtom { AttrTriple(NormalFormAttrTripleAtom), Rule(NormalFormRuleApplyAtom), + NegatedAttrTriple(NormalFormAttrTripleAtom), + NegatedRule(NormalFormRuleApplyAtom), Predicate(Expr), - Negation(Box), Unification(Unification), } +#[derive(Debug, Clone)] pub(crate) enum MagicAtom { AttrTriple(MagicAttrTripleAtom), Rule(MagicRuleApplyAtom), @@ -105,12 +179,14 @@ pub struct InputAttrTripleAtom { pub(crate) value: InputTerm, } +#[derive(Debug, Clone)] pub struct NormalFormAttrTripleAtom { - attr: Attribute, - entity: Keyword, - value: Keyword, + pub(crate) attr: Attribute, + pub(crate) entity: Keyword, + pub(crate) value: Keyword, } +#[derive(Debug, Clone)] pub(crate) struct MagicAttrTripleAtom { attr: Attribute, entity: Keyword, @@ -125,11 +201,13 @@ pub struct InputRuleApplyAtom { pub(crate) args: Vec>, } +#[derive(Clone, Debug)] pub struct NormalFormRuleApplyAtom { - name: Keyword, - args: Vec, + pub(crate) name: Keyword, + pub(crate) args: Vec, } +#[derive(Clone, Debug)] pub(crate) struct MagicRuleApplyAtom { name: MagicKeyword, args: Vec, @@ -141,7 +219,8 @@ pub enum InputTerm { Const(T), } +#[derive(Clone, Debug)] pub struct Unification { - binding: Keyword, - expr: Expr, + pub(crate) binding: Keyword, + pub(crate) expr: Expr, } diff --git a/src/parse/query.rs b/src/parse/query.rs index e679bad1..4ee2dac4 100644 --- a/src/parse/query.rs +++ b/src/parse/query.rs @@ -3,22 +3,172 @@ use std::collections::{BTreeMap, BTreeSet}; use anyhow::{anyhow, bail, ensure, Result}; use itertools::Itertools; -use serde_json::Map; +use serde_json::{json, Map}; use crate::data::attr::Attribute; use crate::data::expr::{get_op, Expr}; use crate::data::json::JsonValue; use crate::data::keyword::{Keyword, PROG_ENTRY}; +use crate::data::program::{InputAtom, InputAttrTripleAtom, InputProgram, InputRule, InputRuleApplyAtom, InputTerm, NormalFormProgram}; use crate::data::value::DataValue; use crate::query::compile::{ Atom, AttrTripleAtom, BindingHeadTerm, DatalogProgram, Rule, RuleApplyAtom, RuleSet, Term, }; use crate::query::magic::magic_sets_rewrite; +use crate::query::pull::PullSpecs; use crate::runtime::transact::SessionTx; +use crate::utils::{swap_option_result, swap_result_option}; use crate::{EntityId, Validity}; +pub(crate) type OutSpec = (Vec<(usize, Option)>, Option>); + impl SessionTx { - pub fn parse_rule_sets( + pub(crate) fn parse_query( + &mut self, + payload: &JsonValue, + ) -> Result<(InputProgram, Option)> { + let vld = match payload.get("since") { + None => Validity::current(), + Some(v) => Validity::try_from(v)?, + }; + let q = payload + .get("q") + .ok_or_else(|| anyhow!("expect field 'q' in query {}", payload))?; + let rules_payload = q + .as_array() + .ok_or_else(|| anyhow!("expect field 'q' to be an array in query {}", payload))?; + ensure!(!rules_payload.is_empty(), "no rules in {}", payload); + let input_prog = if rules_payload.first().unwrap().is_array() { + let q = json!([{"rule": "?", "args": rules_payload}]); + self.parse_input_rule_sets(&q, vld)? + } else { + self.parse_input_rule_sets(q, vld)? + }; + let entry_bindings = &input_prog + .prog + .get(&PROG_ENTRY) + .ok_or_else(|| anyhow!("program has no entry point"))? + .first() + .unwrap() + .head; + let out_spec = payload + .get("out") + .map(|spec| self.parse_query_out_spec(spec, entry_bindings)); + let out_spec = swap_result_option(out_spec)?; + Ok((input_prog, out_spec)) + } + fn parse_query_out_spec( + &mut self, + payload: &JsonValue, + entry_bindings: &[Keyword], + ) -> Result { + match payload { + JsonValue::Object(out_spec_map) => { + let out_spec = out_spec_map.values().cloned().collect_vec(); + let pull_specs = self.parse_pull_specs_for_query_spec(&out_spec, entry_bindings)?; + let map_keys = out_spec_map.keys().cloned().collect_vec(); + Ok((pull_specs, Some(map_keys))) + } + JsonValue::Array(out_spec) => { + let pull_specs = self.parse_pull_specs_for_query_spec(out_spec, entry_bindings)?; + Ok((pull_specs, None)) + } + v => bail!("out spec should be an array, found {}", v), + } + } + + pub(crate) fn parse_pull_specs_for_query_spec( + &mut self, + out_spec: &Vec, + entry_bindings: &[Keyword], + ) -> Result)>> { + let entry_bindings: BTreeMap<_, _> = entry_bindings + .iter() + .enumerate() + .map(|(i, h)| (h, i)) + .collect(); + out_spec + .iter() + .map(|spec| -> Result<(usize, Option)> { + match spec { + JsonValue::String(s) => { + let kw = Keyword::from(s as &str); + let idx = *entry_bindings + .get(&kw) + .ok_or_else(|| anyhow!("binding {} not found", kw))?; + Ok((idx, None)) + } + JsonValue::Object(m) => { + let kw = m + .get("pull") + .ok_or_else(|| anyhow!("expect field 'pull' in {:?}", m))? + .as_str() + .ok_or_else(|| anyhow!("expect 'pull' to be a binding in {:?}", m))?; + let kw = Keyword::from(kw); + let idx = *entry_bindings + .get(&kw) + .ok_or_else(|| anyhow!("binding {} not found", kw))?; + let spec = m + .get("spec") + .ok_or_else(|| anyhow!("expect field 'spec' in {:?}", m))?; + let specs = self.parse_pull(spec, 0)?; + Ok((idx, Some(specs))) + } + v => bail!("expect binding or map, got {:?}", v), + } + }) + .try_collect() + } + + pub(crate) fn parse_input_rule_sets( + &mut self, + payload: &JsonValue, + default_vld: Validity, + ) -> Result { + let rules = payload + .as_array() + .ok_or_else(|| anyhow!("expect array for rules, got {}", payload))? + .iter() + .map(|o| self.parse_input_rule_definition(o, default_vld)); + let mut collected: BTreeMap> = BTreeMap::new(); + for res in rules { + let (name, rule) = res?; + match collected.entry(name) { + Entry::Vacant(e) => { + e.insert(vec![rule]); + } + Entry::Occupied(mut e) => { + e.get_mut().push(rule); + } + } + } + let ret: BTreeMap> = collected + .into_iter() + .map(|(name, rules)| -> Result<(Keyword, Vec)> { + let mut arities = rules.iter().map(|r| r.head.len()); + let arity = arities.next().unwrap(); + for other in arities { + if other != arity { + bail!("arity mismatch for rules under the name of {}", name); + } + } + Ok((name, rules)) + }) + .try_collect()?; + + match ret.get(&PROG_ENTRY as &Keyword) { + None => bail!("no entry defined for datalog program"), + Some(ruleset) => { + if !ruleset.iter().map(|r| &r.head).all_equal() { + bail!("all heads for the entry query must be identical"); + } else { + Ok(InputProgram { prog: ret }) + } + } + } + } + + pub(crate) fn parse_rule_sets( &mut self, payload: &JsonValue, default_vld: Validity, @@ -71,6 +221,18 @@ impl SessionTx { } } } + fn parse_input_predicate_atom(payload: &Map) -> Result { + let mut pred = Self::parse_expr(payload)?; + if let Expr::Apply(op, _) = &pred { + ensure!( + op.is_predicate, + "non-predicate expression in predicate position: {}", + op.name + ); + } + pred.partial_eval()?; + Ok(InputAtom::Predicate(pred)) + } fn parse_predicate_atom(payload: &Map) -> Result { let mut pred = Self::parse_expr(payload)?; if let Expr::Apply(op, _) = &pred { @@ -143,6 +305,52 @@ impl SessionTx { v => Ok(Expr::Const(v.into())), } } + fn parse_input_rule_atom( + &mut self, + payload: &Map, + vld: Validity, + ) -> Result { + let rule_name = payload + .get("rule") + .ok_or_else(|| anyhow!("expect key 'rule' in rule atom"))? + .as_str() + .ok_or_else(|| anyhow!("expect value for key 'rule' to be a string"))? + .into(); + let args = payload + .get("args") + .ok_or_else(|| anyhow!("expect key 'args' in rule atom"))? + .as_array() + .ok_or_else(|| anyhow!("expect value for key 'args' to be an array"))? + .iter() + .map(|value_rep| -> Result> { + if let Some(s) = value_rep.as_str() { + let var = Keyword::from(s); + if s.starts_with(['?', '_']) { + return Ok(InputTerm::Var(var)); + } else { + ensure!( + !var.is_reserved(), + "{} is a reserved string value and must be quoted", + s + ) + } + } + if let Some(o) = value_rep.as_object() { + return if let Some(c) = o.get("const") { + Ok(InputTerm::Const(c.into())) + } else { + let eid = self.parse_eid_from_map(o, vld)?; + Ok(InputTerm::Const(DataValue::EnId(eid))) + }; + } + Ok(InputTerm::Const(value_rep.into())) + }) + .try_collect()?; + Ok(InputAtom::Rule(InputRuleApplyAtom { + name: rule_name, + args, + })) + } fn parse_rule_atom(&mut self, payload: &Map, vld: Validity) -> Result { let rule_name = payload .get("rule") @@ -186,6 +394,65 @@ impl SessionTx { adornment: None, })) } + fn parse_input_rule_definition( + &mut self, + payload: &JsonValue, + default_vld: Validity, + ) -> Result<(Keyword, InputRule)> { + let rule_name = payload + .get("rule") + .ok_or_else(|| anyhow!("expect key 'rule' in rule definition"))?; + let rule_name = Keyword::try_from(rule_name)?; + if !rule_name.is_prog_entry() { + rule_name.validate_not_reserved()?; + } + let vld = payload + .get("at") + .map(Validity::try_from) + .unwrap_or(Ok(default_vld))?; + let args = payload + .get("args") + .ok_or_else(|| anyhow!("expect key 'args' in rule definition"))? + .as_array() + .ok_or_else(|| anyhow!("expect value for key 'args' to be an array"))?; + let mut args = args.iter(); + let rule_head_payload = args + .next() + .ok_or_else(|| anyhow!("expect value for key 'args' to be a non-empty array"))?; + let rule_head_vec = rule_head_payload + .as_array() + .ok_or_else(|| anyhow!("expect rule head to be an array, got {}", rule_head_payload))?; + let mut rule_head = vec![]; + let mut rule_aggr = vec![]; + for head_item in rule_head_vec { + if let Some(s) = head_item.as_str() { + rule_head.push(Keyword::from(s)); + rule_aggr.push(None); + } else { + todo!() + } + } + let rule_body: Vec = args + .map(|el| self.parse_input_atom(el, default_vld)) + .try_collect()?; + + ensure!( + rule_head.len() == rule_head.iter().collect::>().len(), + "duplicate variables in rule head: {:?}", + rule_head + ); + + Ok(( + rule_name, + InputRule { + head: rule_head, + aggr: rule_aggr, + body: rule_body, + vld, + }, + )) + } + fn parse_rule_definition( &mut self, payload: &JsonValue, @@ -335,6 +602,37 @@ impl SessionTx { Ok(ret) } + fn parse_input_atom(&mut self, payload: &JsonValue, vld: Validity) -> Result { + match payload { + JsonValue::Array(arr) => match arr as &[JsonValue] { + [entity_rep, attr_rep, value_rep] => { + self.parse_input_triple_atom(entity_rep, attr_rep, value_rep, vld) + } + _ => unimplemented!(), + }, + JsonValue::Object(map) => { + if map.contains_key("rule") { + self.parse_input_rule_atom(map, vld) + } else if map.contains_key("pred") { + Self::parse_input_predicate_atom(map) + } else if map.contains_key("conj") + || map.contains_key("disj") + || map.contains_key("not_exists") + { + ensure!( + map.len() == 1, + "arity mismatch for atom definition {:?}: expect only one key", + map + ); + self.parse_input_logical_atom(map, vld) + } else { + bail!("unexpected atom definition {:?}", map); + } + } + v => bail!("expected atom definition {:?}", v), + } + } + fn parse_atom(&mut self, payload: &JsonValue, vld: Validity) -> Result { match payload { JsonValue::Array(arr) => match arr as &[JsonValue] { @@ -365,6 +663,29 @@ impl SessionTx { v => bail!("expected atom definition {:?}", v), } } + fn parse_input_logical_atom(&mut self, map: &Map, vld: Validity) -> Result { + let (k, v) = map.iter().next().unwrap(); + Ok(match k as &str { + "not_exists" => { + let arg = self.parse_input_atom(v, vld)?; + InputAtom::Negation(Box::new(arg)) + } + n @ ("conj" | "disj") => { + let args = v + .as_array() + .ok_or_else(|| anyhow!("expect array argument for atom {}", n))? + .iter() + .map(|a| self.parse_input_atom(a, vld)) + .try_collect()?; + if k == "conj" { + InputAtom::Conjunction(args) + } else { + InputAtom::Disjunction(args) + } + } + _ => unreachable!(), + }) + } fn parse_logical_atom(&mut self, map: &Map, vld: Validity) -> Result { let (k, v) = map.iter().next().unwrap(); Ok(match k as &str { @@ -388,6 +709,22 @@ impl SessionTx { _ => unreachable!(), }) } + fn parse_input_triple_atom( + &mut self, + entity_rep: &JsonValue, + attr_rep: &JsonValue, + value_rep: &JsonValue, + vld: Validity, + ) -> Result { + let entity = self.parse_input_triple_atom_entity(entity_rep, vld)?; + let attr = self.parse_triple_atom_attr(attr_rep)?; + let value = self.parse_input_triple_clause_value(value_rep, &attr, vld)?; + Ok(InputAtom::AttrTriple(InputAttrTripleAtom { + attr, + entity, + value, + })) + } fn parse_triple_atom( &mut self, entity_rep: &JsonValue, @@ -445,6 +782,30 @@ impl SessionTx { let value = attr.val_type.coerce_value(v.into())?; Ok(value) } + fn parse_input_triple_clause_value( + &mut self, + value_rep: &JsonValue, + attr: &Attribute, + vld: Validity, + ) -> Result> { + if let Some(s) = value_rep.as_str() { + let var = Keyword::from(s); + if s.starts_with(['?', '_']) { + return Ok(InputTerm::Var(var)); + } else { + ensure!(!var.is_reserved(), "reserved string {} must be quoted", s); + } + } + if let Some(o) = value_rep.as_object() { + return if attr.val_type.is_ref_type() { + let eid = self.parse_eid_from_map(o, vld)?; + Ok(InputTerm::Const(DataValue::EnId(eid))) + } else { + Ok(InputTerm::Const(self.parse_value_from_map(o, attr)?)) + }; + } + Ok(InputTerm::Const(attr.val_type.coerce_value(value_rep.into())?)) + } fn parse_triple_clause_value( &mut self, value_rep: &JsonValue, @@ -469,6 +830,28 @@ impl SessionTx { } Ok(Term::Const(attr.val_type.coerce_value(value_rep.into())?)) } + fn parse_input_triple_atom_entity( + &mut self, + entity_rep: &JsonValue, + vld: Validity, + ) -> Result> { + if let Some(s) = entity_rep.as_str() { + let var = Keyword::from(s); + if s.starts_with(['?', '_']) { + return Ok(InputTerm::Var(var)); + } else { + ensure!(!var.is_reserved(), "reserved string {} must be quoted", s); + } + } + if let Some(u) = entity_rep.as_u64() { + return Ok(InputTerm::Const(EntityId(u))); + } + if let Some(o) = entity_rep.as_object() { + let eid = self.parse_eid_from_map(o, vld)?; + return Ok(InputTerm::Const(eid)); + } + todo!() + } fn parse_triple_atom_entity( &mut self, entity_rep: &JsonValue, diff --git a/src/query/logical.rs b/src/query/logical.rs index a7ba88b8..5b0c1e45 100644 --- a/src/query/logical.rs +++ b/src/query/logical.rs @@ -1,6 +1,226 @@ +use anyhow::{bail, Result}; use itertools::Itertools; +use crate::data::expr::Expr; +use crate::data::program::{ + InputAtom, InputAttrTripleAtom, InputRuleApplyAtom, InputTerm, NormalFormAtom, + NormalFormAttrTripleAtom, NormalFormRuleApplyAtom, TempKwGen, Unification, +}; +use crate::data::value::DataValue; use crate::query::compile::Atom; +use crate::EntityId; + +pub(crate) struct Disjunction(pub(crate) Vec); + +impl Disjunction { + fn conjunctive_to_disjunctive_de_morgen(self, other: Self) -> Self { + // invariants: self and other are both already in disjunctive normal form, which are to be conjuncted together + // the return value must be in disjunctive normal form + let mut ret = vec![]; + let right_args = other.0.into_iter().map(|a| a.0).collect_vec(); + for left in self.0 { + let left = left.0; + for right in &right_args { + let mut current = left.clone(); + current.extend_from_slice(right); + ret.push(Conjunction(current)) + } + } + Disjunction(ret) + } + fn singlet(atom: NormalFormAtom) -> Self { + Disjunction(vec![Conjunction(vec![atom])]) + } + fn conj(atoms: Vec) -> Self { + Disjunction(vec![Conjunction(atoms)]) + } +} + +pub(crate) struct Conjunction(pub(crate) Vec); + +impl InputAtom { + pub(crate) fn negation_normal_form(self) -> Result { + Ok(match self { + a @ (InputAtom::AttrTriple(_) | InputAtom::Rule(_) | InputAtom::Predicate(_)) => a, + InputAtom::Conjunction(args) => InputAtom::Conjunction( + args.into_iter() + .map(|a| a.negation_normal_form()) + .try_collect()?, + ), + InputAtom::Disjunction(args) => InputAtom::Disjunction( + args.into_iter() + .map(|a| a.negation_normal_form()) + .try_collect()?, + ), + InputAtom::Unification(unif) => InputAtom::Unification(unif), + InputAtom::Negation(arg) => match *arg { + a @ (InputAtom::AttrTriple(_) | InputAtom::Rule(_)) => { + InputAtom::Negation(Box::new(a)) + } + InputAtom::Predicate(p) => InputAtom::Predicate(p.negate()), + InputAtom::Negation(inner) => inner.negation_normal_form()?, + InputAtom::Conjunction(args) => InputAtom::Disjunction( + args.into_iter() + .map(|a| InputAtom::Negation(Box::new(a)).negation_normal_form()) + .try_collect()?, + ), + InputAtom::Disjunction(args) => InputAtom::Conjunction( + args.into_iter() + .map(|a| InputAtom::Negation(Box::new(a)).negation_normal_form()) + .try_collect()?, + ), + InputAtom::Unification(unif) => { + bail!("unification not allowed in negation: {:?}", unif) + } + }, + }) + } + + pub(crate) fn disjunctive_normal_form(self) -> Result { + let mut gen = TempKwGen::default(); + self.negation_normal_form()? + .do_disjunctive_normal_form(&mut gen) + } + + fn do_disjunctive_normal_form(self, gen: &mut TempKwGen) -> Result { + // invariants: the input is already in negation normal form + // the return value is a disjunction of conjunctions, with no nesting + Ok(match self { + InputAtom::Disjunction(args) => { + let mut ret = vec![]; + for arg in args { + for a in arg.do_disjunctive_normal_form(gen)?.0 { + ret.push(a); + } + } + Disjunction(ret) + } + InputAtom::Conjunction(args) => { + let mut args = args.into_iter().map(|a| a.do_disjunctive_normal_form(gen)); + let mut result = args.next().unwrap()?; + for a in args { + result = result.conjunctive_to_disjunctive_de_morgen(a?) + } + result + } + InputAtom::AttrTriple(a) => Disjunction::conj(a.normalize(false, gen)), + InputAtom::Rule(r) => Disjunction::conj(r.normalize(false, gen)), + InputAtom::Predicate(mut p) => { + p.partial_eval()?; + Disjunction::singlet(NormalFormAtom::Predicate(p)) + } + InputAtom::Negation(n) => match *n { + InputAtom::Rule(r) => Disjunction::conj(r.normalize(true, gen)), + InputAtom::AttrTriple(r) => Disjunction::conj(r.normalize(true, gen)), + _ => unreachable!(), + }, + InputAtom::Unification(u) => Disjunction::singlet(NormalFormAtom::Unification(u)), + }) + } +} + +impl InputRuleApplyAtom { + fn normalize(mut self, is_negated: bool, gen: &mut TempKwGen) -> Vec { + let mut ret = Vec::with_capacity(self.args.len() + 1); + let mut args = Vec::with_capacity(self.args.len()); + for arg in self.args { + match arg { + InputTerm::Var(kw) => args.push(kw), + InputTerm::Const(val) => { + let kw = gen.next(); + args.push(kw.clone()); + let unif = NormalFormAtom::Unification(Unification { + binding: kw, + expr: Expr::Const(val), + }); + ret.push(unif) + } + } + } + + ret.push(if is_negated { + NormalFormAtom::NegatedRule(NormalFormRuleApplyAtom { + name: self.name, + args, + }) + } else { + NormalFormAtom::Rule(NormalFormRuleApplyAtom { + name: self.name, + args, + }) + }); + ret + } +} + +impl InputAttrTripleAtom { + fn normalize(mut self, is_negated: bool, gen: &mut TempKwGen) -> Vec { + let wrap = |atom| { + if is_negated { + NormalFormAtom::NegatedAttrTriple(atom) + } else { + NormalFormAtom::AttrTriple(atom) + } + }; + match (self.entity, self.value) { + (InputTerm::Const(eid), InputTerm::Const(val)) => { + let ekw = gen.next(); + let vkw = gen.next(); + let atom = NormalFormAttrTripleAtom { + attr: self.attr, + entity: ekw.clone(), + value: vkw.clone(), + }; + let ret = wrap(atom); + let ue = NormalFormAtom::Unification(Unification { + binding: ekw, + expr: Expr::Const(DataValue::EnId(eid)), + }); + let uv = NormalFormAtom::Unification(Unification { + binding: vkw, + expr: Expr::Const(val), + }); + vec![ue, uv, ret] + } + (InputTerm::Var(ekw), InputTerm::Const(val)) => { + let vkw = gen.next(); + let atom = NormalFormAttrTripleAtom { + attr: self.attr, + entity: ekw, + value: vkw.clone(), + }; + let ret = wrap(atom); + let uv = NormalFormAtom::Unification(Unification { + binding: vkw, + expr: Expr::Const(val), + }); + vec![uv, ret] + } + (InputTerm::Const(eid), InputTerm::Var(vkw)) => { + let ekw = gen.next(); + let atom = NormalFormAttrTripleAtom { + attr: self.attr, + entity: ekw.clone(), + value: vkw, + }; + let ret = wrap(atom); + let ue = NormalFormAtom::Unification(Unification { + binding: ekw, + expr: Expr::Const(DataValue::EnId(eid)), + }); + vec![ue, ret] + } + (InputTerm::Var(ekw), InputTerm::Var(vkw)) => { + let ret = wrap(NormalFormAttrTripleAtom { + attr: self.attr, + entity: ekw, + value: vkw, + }); + vec![ret] + } + } + } +} impl Atom { pub(crate) fn negation_normal_form(self) -> Self { diff --git a/src/query/mod.rs b/src/query/mod.rs index 7a74dbad..b33e386a 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -140,7 +140,7 @@ impl SessionTx { Some(v) => bail!("out spec should be an array, found {}", v), } } - fn parse_pull_specs_for_query( + pub(crate) fn parse_pull_specs_for_query( &mut self, out_spec: &Vec, prog: &DatalogProgram, diff --git a/src/query/stratify.rs b/src/query/stratify.rs index e1bd2679..c3512409 100644 --- a/src/query/stratify.rs +++ b/src/query/stratify.rs @@ -5,11 +5,25 @@ use anyhow::{ensure, Result}; use itertools::Itertools; use crate::data::keyword::{Keyword, PROG_ENTRY}; +use crate::data::program::{NormalFormAtom, NormalFormProgram, StratifiedNormalFormProgram}; use crate::query::compile::{Atom, DatalogProgram}; use crate::query::graph::{ generalized_kahn, reachable_components, strongly_connected_components, Graph, StratifiedGraph, }; +impl NormalFormAtom { + fn contained_rules(&self) -> BTreeMap<&Keyword, bool> { + match self { + NormalFormAtom::AttrTriple(_) + | NormalFormAtom::Predicate(_) + | NormalFormAtom::Unification(_) + | NormalFormAtom::NegatedAttrTriple(_) => Default::default(), + NormalFormAtom::Rule(r) => BTreeMap::from([(&r.name, false)]), + NormalFormAtom::NegatedRule(r) => BTreeMap::from([(&r.name, true)]), + } + } +} + impl Atom { fn contained_rules(&self) -> BTreeMap<&Keyword, bool> { match self { @@ -22,26 +36,40 @@ impl Atom { .collect(), Atom::Conjunction(_args) | Atom::Disjunction(_args) => { panic!("expect program in disjunctive normal form"); - // let mut ret: BTreeMap<&Keyword, bool> = Default::default(); - // for arg in args { - // for (k, v) in arg.contained_rules() { - // match ret.entry(k) { - // Entry::Vacant(e) => { - // e.insert(v); - // } - // Entry::Occupied(mut e) => { - // let old = *e.get(); - // e.insert(old || v); - // } - // } - // } - // } - // ret } } } } +fn convert_normal_form_program_to_graph( + nf_prog: &NormalFormProgram, +) -> StratifiedGraph<&'_ Keyword> { + nf_prog + .prog + .iter() + .map(|(k, ruleset)| { + let mut ret: BTreeMap<&Keyword, bool> = BTreeMap::default(); + for rule in ruleset { + for atom in &rule.body { + let contained = atom.contained_rules(); + for (found_key, negated) in contained { + match ret.entry(found_key) { + Entry::Vacant(e) => { + e.insert(negated); + } + Entry::Occupied(mut e) => { + let old = *e.get(); + e.insert(old || negated); + } + } + } + } + } + (k, ret) + }) + .collect() +} + fn convert_program_to_graph(prog: &DatalogProgram) -> StratifiedGraph<&'_ Keyword> { prog.iter() .map(|(k, ruleset)| { @@ -93,11 +121,11 @@ fn verify_no_cycle(g: &StratifiedGraph<&'_ Keyword>, sccs: &[BTreeSet<&Keyword>] fn make_scc_reduced_graph<'a>( sccs: &[BTreeSet<&'a Keyword>], graph: &StratifiedGraph<&Keyword>, -) -> (BTreeMap<&'a Keyword, usize>, StratifiedGraph) { +) -> (BTreeMap, StratifiedGraph) { let indices = sccs .iter() .enumerate() - .flat_map(|(idx, scc)| scc.iter().map(move |k| (*k, idx))) + .flat_map(|(idx, scc)| scc.iter().map(move |k| ((*k).clone(), idx))) .collect::>(); let mut ret: BTreeMap> = Default::default(); for (from, tos) in graph { @@ -122,6 +150,120 @@ fn make_scc_reduced_graph<'a>( (indices, ret) } +impl NormalFormProgram { + pub(crate) fn stratify(self) -> Result { + // prerequisite: the program is already in disjunctive normal form + // 0. build a graph of the program + let prog_entry: &Keyword = &PROG_ENTRY; + let stratified_graph = convert_normal_form_program_to_graph(&self); + let graph = reduce_to_graph(&stratified_graph); + ensure!( + graph.contains_key(prog_entry), + "program graph does not have an entry" + ); + + // 1. find reachable clauses starting from the query + let reachable: BTreeSet<_> = reachable_components(&graph, &prog_entry) + .into_iter() + .map(|k| (*k).clone()) + .collect(); + // 2. prune the graph of unreachable clauses + let stratified_graph: StratifiedGraph<_> = stratified_graph + .into_iter() + .filter(|(k, _)| reachable.contains(k)) + .collect(); + let graph: Graph<_> = graph + .into_iter() + .filter(|(k, _)| reachable.contains(k)) + .collect(); + // 3. find SCC of the clauses + let sccs: Vec> = strongly_connected_components(&graph) + .into_iter() + .map(|scc| scc.into_iter().cloned().collect()) + .collect_vec(); + // 4. for each SCC, verify that no neg/agg edges are present so that it is really stratifiable + verify_no_cycle(&stratified_graph, &sccs)?; + // 5. build a reduced graph for the SCC's + let (invert_indices, reduced_graph) = make_scc_reduced_graph(&sccs, &stratified_graph); + // 6. topological sort the reduced graph to get a stratification + let sort_result = generalized_kahn(&reduced_graph, stratified_graph.len()); + let n_strata = sort_result.len(); + let invert_sort_result = sort_result + .into_iter() + .enumerate() + .flat_map(|(stratum, indices)| indices.into_iter().map(move |idx| (idx, stratum))) + .collect::>(); + // 7. translate the stratification into datalog program + let mut ret: Vec = vec![Default::default(); n_strata]; + for (name, ruleset) in self.prog { + if let Some(scc_idx) = invert_indices.get(&name) { + if let Some(stratum_idx) = invert_sort_result.get(scc_idx) { + let target = ret.get_mut(*stratum_idx).unwrap(); + target.prog.insert(name, ruleset); + } + } + } + + Ok(StratifiedNormalFormProgram(ret)) + } +} + +pub(crate) fn convert_to_stratify_program(prog: &DatalogProgram) -> Result> { + // prerequisite: the program is already in disjunctive normal form + // 0. build a graph of the program + let prog_entry: &Keyword = &PROG_ENTRY; + let stratified_graph = convert_program_to_graph(&prog); + let graph = reduce_to_graph(&stratified_graph); + ensure!( + graph.contains_key(prog_entry), + "program graph does not have an entry" + ); + + // 1. find reachable clauses starting from the query + let reachable: BTreeSet<_> = reachable_components(&graph, &prog_entry) + .into_iter() + .map(|k| (*k).clone()) + .collect(); + // 2. prune the graph of unreachable clauses + let stratified_graph: StratifiedGraph<_> = stratified_graph + .into_iter() + .filter(|(k, _)| reachable.contains(k)) + .collect(); + let graph: Graph<_> = graph + .into_iter() + .filter(|(k, _)| reachable.contains(k)) + .collect(); + // 3. find SCC of the clauses + let sccs: Vec> = strongly_connected_components(&graph) + .into_iter() + .map(|scc| scc.into_iter().cloned().collect()) + .collect_vec(); + // 4. for each SCC, verify that no neg/agg edges are present so that it is really stratifiable + verify_no_cycle(&stratified_graph, &sccs)?; + // 5. build a reduced graph for the SCC's + let (invert_indices, reduced_graph) = make_scc_reduced_graph(&sccs, &stratified_graph); + // 6. topological sort the reduced graph to get a stratification + let sort_result = generalized_kahn(&reduced_graph, stratified_graph.len()); + let n_strata = sort_result.len(); + let invert_sort_result = sort_result + .into_iter() + .enumerate() + .flat_map(|(stratum, indices)| indices.into_iter().map(move |idx| (idx, stratum))) + .collect::>(); + // 7. translate the stratification into datalog program + let mut ret: Vec = vec![Default::default(); n_strata]; + for (name, ruleset) in prog { + if let Some(scc_idx) = invert_indices.get(&name) { + if let Some(stratum_idx) = invert_sort_result.get(scc_idx) { + let target = ret.get_mut(*stratum_idx).unwrap(); + target.insert(name.clone(), ruleset.clone()); + } + } + } + + Ok(ret) +} + pub(crate) fn stratify_program(prog: &DatalogProgram) -> Result> { // prerequisite: the program is already in disjunctive normal form // 0. build a graph of the program diff --git a/src/utils.rs b/src/utils.rs index 79dfbed7..bf3dcb6c 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -6,3 +6,12 @@ pub(crate) fn swap_option_result(d: Result, E>) -> Option Some(Err(e)), } } + +#[inline(always)] +pub(crate) fn swap_result_option(d: Option>) -> Result, E> { + match d { + None => Ok(None), + Some(Ok(v)) => Ok(Some(v)), + Some(Err(e)) => Err(e) + } +}