underscores are ignored for all unification; stricter checks for fixed rule bindings

main
Ziyang Hu 2 years ago
parent 6ac4ec33c2
commit 6b246d8f27

@ -81,6 +81,9 @@ impl Symbol {
pub(crate) fn is_prog_entry(&self) -> bool { pub(crate) fn is_prog_entry(&self) -> bool {
self.name == "?" self.name == "?"
} }
/// Returns `true` when this symbol's name is exactly `_`, the placeholder
/// that callers treat as an ignored binding.
pub(crate) fn is_ignored_symbol(&self) -> bool {
self.name == "_"
}
pub(crate) fn ensure_valid_field(&self) -> Result<()> { pub(crate) fn ensure_valid_field(&self) -> Result<()> {
if self.name.contains('(') || self.name.contains(')') { if self.name.contains('(') || self.name.contains(')') {
#[derive(Debug, Error, Diagnostic)] #[derive(Debug, Error, Diagnostic)]

@ -8,7 +8,7 @@
use std::cmp::Reverse; use std::cmp::Reverse;
use std::collections::btree_map::Entry; use std::collections::btree_map::Entry;
use std::collections::BTreeMap; use std::collections::{BTreeMap, BTreeSet};
use std::error::Error; use std::error::Error;
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use std::rc::Rc; use std::rc::Rc;
@ -96,7 +96,7 @@ fn merge_spans(symbs: &[Symbol]) -> SourceSpan {
pub(crate) fn parse_query( pub(crate) fn parse_query(
src: Pairs<'_>, src: Pairs<'_>,
param_pool: &BTreeMap<String, DataValue>, param_pool: &BTreeMap<String, DataValue>,
fixedrithms: &BTreeMap<String, Arc<Box<dyn FixedRule>>>, fixed_rules: &BTreeMap<String, Arc<Box<dyn FixedRule>>>,
cur_vld: ValidityTs, cur_vld: ValidityTs,
) -> Result<InputProgram> { ) -> Result<InputProgram> {
let mut progs: BTreeMap<Symbol, InputInlineRulesOrFixed> = Default::default(); let mut progs: BTreeMap<Symbol, InputInlineRulesOrFixed> = Default::default();
@ -150,7 +150,7 @@ pub(crate) fn parse_query(
} }
Rule::fixed_rule => { Rule::fixed_rule => {
let rule_span = pair.extract_span(); let rule_span = pair.extract_span();
let (name, apply) = parse_fixed_rule(pair, param_pool, fixedrithms, cur_vld)?; let (name, apply) = parse_fixed_rule(pair, param_pool, fixed_rules, cur_vld)?;
match progs.entry(name) { match progs.entry(name) {
Entry::Vacant(e) => { Entry::Vacant(e) => {
@ -451,8 +451,9 @@ fn parse_rule(
ensure!(!head.is_empty(), EmptyRuleHead(head_span)); ensure!(!head.is_empty(), EmptyRuleHead(head_span));
let body = src.next().unwrap(); let body = src.next().unwrap();
let mut body_clauses = vec![]; let mut body_clauses = vec![];
let mut ignored_counter = 0;
for atom_src in body.into_inner() { for atom_src in body.into_inner() {
body_clauses.push(parse_disjunction(atom_src, param_pool, cur_vld)?) body_clauses.push(parse_disjunction(atom_src, param_pool, cur_vld, &mut ignored_counter)?)
} }
Ok(( Ok((
@ -470,11 +471,12 @@ fn parse_disjunction(
pair: Pair<'_>, pair: Pair<'_>,
param_pool: &BTreeMap<String, DataValue>, param_pool: &BTreeMap<String, DataValue>,
cur_vld: ValidityTs, cur_vld: ValidityTs,
ignored_counter: &mut u32,
) -> Result<InputAtom> { ) -> Result<InputAtom> {
let span = pair.extract_span(); let span = pair.extract_span();
let res: Vec<_> = pair let res: Vec<_> = pair
.into_inner() .into_inner()
.map(|v| parse_atom(v, param_pool, cur_vld)) .map(|v| parse_atom(v, param_pool, cur_vld, ignored_counter))
.try_collect()?; .try_collect()?;
Ok(if res.len() == 1 { Ok(if res.len() == 1 {
res.into_iter().next().unwrap() res.into_iter().next().unwrap()
@ -487,23 +489,24 @@ fn parse_atom(
src: Pair<'_>, src: Pair<'_>,
param_pool: &BTreeMap<String, DataValue>, param_pool: &BTreeMap<String, DataValue>,
cur_vld: ValidityTs, cur_vld: ValidityTs,
ignored_counter: &mut u32
) -> Result<InputAtom> { ) -> Result<InputAtom> {
Ok(match src.as_rule() { Ok(match src.as_rule() {
Rule::rule_body => { Rule::rule_body => {
let span = src.extract_span(); let span = src.extract_span();
let grouped: Vec<_> = src let grouped: Vec<_> = src
.into_inner() .into_inner()
.map(|v| parse_disjunction(v, param_pool, cur_vld)) .map(|v| parse_disjunction(v, param_pool, cur_vld, ignored_counter))
.try_collect()?; .try_collect()?;
InputAtom::Conjunction { InputAtom::Conjunction {
inner: grouped, inner: grouped,
span, span,
} }
} }
Rule::disjunction => parse_disjunction(src, param_pool, cur_vld)?, Rule::disjunction => parse_disjunction(src, param_pool, cur_vld, ignored_counter)?,
Rule::negation => { Rule::negation => {
let span = src.extract_span(); let span = src.extract_span();
let inner = parse_atom(src.into_inner().next().unwrap(), param_pool, cur_vld)?; let inner = parse_atom(src.into_inner().next().unwrap(), param_pool, cur_vld, ignored_counter)?;
InputAtom::Negation { InputAtom::Negation {
inner: inner.into(), inner: inner.into(),
span, span,
@ -517,10 +520,15 @@ fn parse_atom(
let span = src.extract_span(); let span = src.extract_span();
let mut src = src.into_inner(); let mut src = src.into_inner();
let var = src.next().unwrap(); let var = src.next().unwrap();
let mut symb = Symbol::new(var.as_str(), var.extract_span());
if symb.is_ignored_symbol() {
symb.name = format!("*^*{}", *ignored_counter).into();
*ignored_counter += 1;
}
let expr = build_expr(src.next().unwrap(), param_pool)?; let expr = build_expr(src.next().unwrap(), param_pool)?;
InputAtom::Unification { InputAtom::Unification {
inner: Unification { inner: Unification {
binding: Symbol::new(var.as_str(), var.extract_span()), binding: symb,
expr, expr,
one_many_unif: false, one_many_unif: false,
span, span,
@ -531,10 +539,15 @@ fn parse_atom(
let span = src.extract_span(); let span = src.extract_span();
let mut src = src.into_inner(); let mut src = src.into_inner();
let var = src.next().unwrap(); let var = src.next().unwrap();
let mut symb = Symbol::new(var.as_str(), var.extract_span());
if symb.is_ignored_symbol() {
symb.name = format!("*^*{}", *ignored_counter).into();
*ignored_counter += 1;
}
let expr = build_expr(src.next().unwrap(), param_pool)?; let expr = build_expr(src.next().unwrap(), param_pool)?;
InputAtom::Unification { InputAtom::Unification {
inner: Unification { inner: Unification {
binding: Symbol::new(var.as_str(), var.extract_span()), binding: symb,
expr, expr,
one_many_unif: true, one_many_unif: true,
span, span,
@ -690,7 +703,7 @@ struct BadValiditySpecification(#[label] SourceSpan);
fn parse_fixed_rule( fn parse_fixed_rule(
src: Pair<'_>, src: Pair<'_>,
param_pool: &BTreeMap<String, DataValue>, param_pool: &BTreeMap<String, DataValue>,
fixedrithms: &BTreeMap<String, Arc<Box<dyn FixedRule>>>, fixed_rules: &BTreeMap<String, Arc<Box<dyn FixedRule>>>,
cur_vld: ValidityTs, cur_vld: ValidityTs,
) -> Result<(Symbol, FixedRuleApply)> { ) -> Result<(Symbol, FixedRuleApply)> {
let mut src = src.into_inner(); let mut src = src.into_inner();
@ -701,10 +714,18 @@ fn parse_fixed_rule(
#[diagnostic(code(parser::fixed_aggr_conflict))] #[diagnostic(code(parser::fixed_aggr_conflict))]
struct AggrInfixedError(#[label] SourceSpan); struct AggrInfixedError(#[label] SourceSpan);
// Diagnostic raised while parsing a fixed rule when a binding name that was
// already recorded for this rule is encountered again. The labelled span is
// the span of the repeated occurrence.
#[derive(Debug, Error, Diagnostic)]
#[error("fixed rule cannot have duplicate bindings")]
#[diagnostic(code(parser::duplicate_bindings_for_fixed_rule))]
struct DuplicateBindingError(#[label] SourceSpan);
for (a, v) in aggr.iter().zip(head.iter()) { for (a, v) in aggr.iter().zip(head.iter()) {
ensure!(a.is_none(), AggrInfixedError(v.span)) ensure!(a.is_none(), AggrInfixedError(v.span))
} }
let mut seen_bindings = BTreeSet::new();
let mut binding_gen_id = 0;
let name_pair = src.next().unwrap(); let name_pair = src.next().unwrap();
let fixed_name = &name_pair.as_str(); let fixed_name = &name_pair.as_str();
let mut rule_args: Vec<FixedRuleArg> = vec![]; let mut rule_args: Vec<FixedRuleArg> = vec![];
@ -721,9 +742,21 @@ fn parse_fixed_rule(
Rule::fixed_rule_rel => { Rule::fixed_rule_rel => {
let mut els = inner.into_inner(); let mut els = inner.into_inner();
let name = els.next().unwrap(); let name = els.next().unwrap();
let bindings = els let mut bindings = Vec::with_capacity(els.size_hint().1.unwrap_or(4));
.map(|v| Symbol::new(v.as_str(), v.extract_span())) for v in els {
.collect_vec(); let s = v.as_str();
if s == "_" {
let symb = Symbol::new(format!("*_*{}", binding_gen_id) ,v.extract_span());
binding_gen_id += 1;
bindings.push(symb);
} else {
if !seen_bindings.insert(s) {
bail!(DuplicateBindingError(v.extract_span()))
}
let symb = Symbol::new(s, v.extract_span());
bindings.push(symb);
}
}
rule_args.push(FixedRuleArg::InMem { rule_args.push(FixedRuleArg::InMem {
name: Symbol::new(name.as_str(), name.extract_span()), name: Symbol::new(name.as_str(), name.extract_span()),
bindings, bindings,
@ -738,7 +771,17 @@ fn parse_fixed_rule(
for v in els { for v in els {
match v.as_rule() { match v.as_rule() {
Rule::var => { Rule::var => {
bindings.push(Symbol::new(v.as_str(), v.extract_span())) let s = v.as_str();
if s == "_" {
let symb = Symbol::new(format!("*_*{}", binding_gen_id) ,v.extract_span());
binding_gen_id += 1;
bindings.push(symb);
} else {
if !seen_bindings.insert(s) {
bail!(DuplicateBindingError(v.extract_span()))
}
bindings.push(Symbol::new(v.as_str(), v.extract_span()))
}
} }
Rule::validity_clause => { Rule::validity_clause => {
let vld_inner = v.into_inner().next().unwrap(); let vld_inner = v.into_inner().next().unwrap();
@ -770,8 +813,18 @@ fn parse_fixed_rule(
let kp = vs.next().unwrap(); let kp = vs.next().unwrap();
let k = SmartString::from(kp.as_str()); let k = SmartString::from(kp.as_str());
let v = match vs.next() { let v = match vs.next() {
Some(vp) => Symbol::new(vp.as_str(), vp.extract_span()), Some(vp) => {
None => Symbol::new(k.clone(), kp.extract_span()), if !seen_bindings.insert(vp.as_str()) {
bail!(DuplicateBindingError(vp.extract_span()))
}
Symbol::new(vp.as_str(), vp.extract_span())
}
None => {
if !seen_bindings.insert(kp.as_str()) {
bail!(DuplicateBindingError(kp.extract_span()))
}
Symbol::new(k.clone(), kp.extract_span())
}
}; };
bindings.insert(k, v); bindings.insert(k, v);
} }
@ -810,7 +863,7 @@ fn parse_fixed_rule(
let fixed = FixedRuleHandle::new(fixed_name, name_pair.extract_span()); let fixed = FixedRuleHandle::new(fixed_name, name_pair.extract_span());
let fixed_impl = fixedrithms let fixed_impl = fixed_rules
.get(&fixed.name as &str) .get(&fixed.name as &str)
.ok_or_else(|| FixedRuleNotFoundError(fixed.name.to_string(), name_pair.extract_span()))?; .ok_or_else(|| FixedRuleNotFoundError(fixed.name.to_string(), name_pair.extract_span()))?;
fixed_impl.init_options(&mut options, args_list_span)?; fixed_impl.init_options(&mut options, args_list_span)?;

@ -237,7 +237,10 @@ impl InputRuleApplyAtom {
for arg in self.args { for arg in self.args {
match arg { match arg {
Expr::Binding { var, .. } => { Expr::Binding { var, .. } => {
if seen_variables.insert(var.clone()) { if var.is_ignored_symbol() {
let dup = gen.next(var.span);
args.push(dup);
} else if seen_variables.insert(var.clone()) {
args.push(var); args.push(var);
} else { } else {
let dup = gen.next(var.span); let dup = gen.next(var.span);
@ -294,7 +297,9 @@ impl InputRelationApplyAtom {
for arg in self.args { for arg in self.args {
match arg { match arg {
Expr::Binding { var, .. } => { Expr::Binding { var, .. } => {
if seen_variables.insert(var.clone()) { if var.is_ignored_symbol() {
args.push(gen.next(var.span));
} else if seen_variables.insert(var.clone()) {
args.push(var); args.push(var);
} else { } else {
let span = var.span; let span = var.span;

@ -1391,4 +1391,75 @@ grandparent[gcld, gp] := parent[gcld, p], parent[p, gp]
.run_script("?[uid] <- [[1]] :rm status {uid}", Default::default()) .run_script("?[uid] <- [[1]] :rm status {uid}", Default::default())
.is_ok()); .is_ok());
} }
#[test]
fn strict_checks_for_fixed_rules_args() {
    // Each entry pairs a script with whether it is expected to run successfully:
    // `_` placeholders and distinct names are accepted, a repeated name is not.
    let cases = [
        (
            r#"
r[] <- [[1, 2]]
?[] <~ PageRank(r[_, _])
"#,
            true,
        ),
        (
            r#"
r[] <- [[1, 2]]
?[] <~ PageRank(r[a, b])
"#,
            true,
        ),
        (
            r#"
r[] <- [[1, 2]]
?[] <~ PageRank(r[a, a])
"#,
            false,
        ),
    ];

    for (script, should_succeed) in cases {
        // Every case gets its own fresh in-memory database.
        let db = new_cozo_mem().unwrap();
        let outcome = db.run_script(script, Default::default());
        assert_eq!(outcome.is_ok(), should_succeed);
    }
}
#[test]
fn do_not_unify_underscore() {
    let db = new_cozo_mem().unwrap();

    // The `_` placeholders in the two rule applications must not be joined
    // with each other, so both rows of r1 pair with both rows of r2.
    let joined = db
        .run_script(
            r#"
r1[] <- [[1, 'a'], [2, 'b']]
r2[] <- [[2, 'B'], [3, 'C']]
?[l1, l2] := r1[_ , l1], r2[_ , l2]
"#,
            Default::default(),
        )
        .unwrap();
    assert_eq!(joined.rows.len(), 4);

    // Using `_` as a head variable is expected to be rejected.
    let rejected = db
        .run_script(
            r#"
?[_] := _ = 1
"#,
            Default::default(),
        )
        .is_err();
    assert!(rejected);

    // Two `_` unifications with different values must not conflict.
    let single = db
        .run_script(
            r#"
?[x] := x = 1, _ = 1, _ = 2
"#,
            Default::default(),
        )
        .unwrap();
    assert_eq!(single.rows.len(), 1);
}
} }

Loading…
Cancel
Save