underscores are ignored for all unification; stricter checks for fixed rule bindings

main
Ziyang Hu 2 years ago
parent 6ac4ec33c2
commit 6b246d8f27

@ -81,6 +81,9 @@ impl Symbol {
/// Reports whether this symbol's name is exactly `?`.
pub(crate) fn is_prog_entry(&self) -> bool {
    // NOTE(review): `?` looks like the head name of the query's entry rule — confirm.
    const PROG_ENTRY_NAME: &str = "?";
    self.name == PROG_ENTRY_NAME
}
/// Reports whether this symbol is the bare underscore `_`.
///
/// Callers use this to detect placeholder bindings that must not be
/// unified with one another.
pub(crate) fn is_ignored_symbol(&self) -> bool {
    const IGNORED_NAME: &str = "_";
    self.name == IGNORED_NAME
}
pub(crate) fn ensure_valid_field(&self) -> Result<()> {
if self.name.contains('(') || self.name.contains(')') {
#[derive(Debug, Error, Diagnostic)]

@ -8,7 +8,7 @@
use std::cmp::Reverse;
use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
use std::collections::{BTreeMap, BTreeSet};
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::rc::Rc;
@ -96,7 +96,7 @@ fn merge_spans(symbs: &[Symbol]) -> SourceSpan {
pub(crate) fn parse_query(
src: Pairs<'_>,
param_pool: &BTreeMap<String, DataValue>,
fixedrithms: &BTreeMap<String, Arc<Box<dyn FixedRule>>>,
fixed_rules: &BTreeMap<String, Arc<Box<dyn FixedRule>>>,
cur_vld: ValidityTs,
) -> Result<InputProgram> {
let mut progs: BTreeMap<Symbol, InputInlineRulesOrFixed> = Default::default();
@ -150,7 +150,7 @@ pub(crate) fn parse_query(
}
Rule::fixed_rule => {
let rule_span = pair.extract_span();
let (name, apply) = parse_fixed_rule(pair, param_pool, fixedrithms, cur_vld)?;
let (name, apply) = parse_fixed_rule(pair, param_pool, fixed_rules, cur_vld)?;
match progs.entry(name) {
Entry::Vacant(e) => {
@ -451,8 +451,9 @@ fn parse_rule(
ensure!(!head.is_empty(), EmptyRuleHead(head_span));
let body = src.next().unwrap();
let mut body_clauses = vec![];
let mut ignored_counter = 0;
for atom_src in body.into_inner() {
body_clauses.push(parse_disjunction(atom_src, param_pool, cur_vld)?)
body_clauses.push(parse_disjunction(atom_src, param_pool, cur_vld, &mut ignored_counter)?)
}
Ok((
@ -470,11 +471,12 @@ fn parse_disjunction(
pair: Pair<'_>,
param_pool: &BTreeMap<String, DataValue>,
cur_vld: ValidityTs,
ignored_counter: &mut u32,
) -> Result<InputAtom> {
let span = pair.extract_span();
let res: Vec<_> = pair
.into_inner()
.map(|v| parse_atom(v, param_pool, cur_vld))
.map(|v| parse_atom(v, param_pool, cur_vld, ignored_counter))
.try_collect()?;
Ok(if res.len() == 1 {
res.into_iter().next().unwrap()
@ -487,23 +489,24 @@ fn parse_atom(
src: Pair<'_>,
param_pool: &BTreeMap<String, DataValue>,
cur_vld: ValidityTs,
ignored_counter: &mut u32
) -> Result<InputAtom> {
Ok(match src.as_rule() {
Rule::rule_body => {
let span = src.extract_span();
let grouped: Vec<_> = src
.into_inner()
.map(|v| parse_disjunction(v, param_pool, cur_vld))
.map(|v| parse_disjunction(v, param_pool, cur_vld, ignored_counter))
.try_collect()?;
InputAtom::Conjunction {
inner: grouped,
span,
}
}
Rule::disjunction => parse_disjunction(src, param_pool, cur_vld)?,
Rule::disjunction => parse_disjunction(src, param_pool, cur_vld, ignored_counter)?,
Rule::negation => {
let span = src.extract_span();
let inner = parse_atom(src.into_inner().next().unwrap(), param_pool, cur_vld)?;
let inner = parse_atom(src.into_inner().next().unwrap(), param_pool, cur_vld, ignored_counter)?;
InputAtom::Negation {
inner: inner.into(),
span,
@ -517,10 +520,15 @@ fn parse_atom(
let span = src.extract_span();
let mut src = src.into_inner();
let var = src.next().unwrap();
let mut symb = Symbol::new(var.as_str(), var.extract_span());
if symb.is_ignored_symbol() {
symb.name = format!("*^*{}", *ignored_counter).into();
*ignored_counter += 1;
}
let expr = build_expr(src.next().unwrap(), param_pool)?;
InputAtom::Unification {
inner: Unification {
binding: Symbol::new(var.as_str(), var.extract_span()),
binding: symb,
expr,
one_many_unif: false,
span,
@ -531,10 +539,15 @@ fn parse_atom(
let span = src.extract_span();
let mut src = src.into_inner();
let var = src.next().unwrap();
let mut symb = Symbol::new(var.as_str(), var.extract_span());
if symb.is_ignored_symbol() {
symb.name = format!("*^*{}", *ignored_counter).into();
*ignored_counter += 1;
}
let expr = build_expr(src.next().unwrap(), param_pool)?;
InputAtom::Unification {
inner: Unification {
binding: Symbol::new(var.as_str(), var.extract_span()),
binding: symb,
expr,
one_many_unif: true,
span,
@ -690,7 +703,7 @@ struct BadValiditySpecification(#[label] SourceSpan);
fn parse_fixed_rule(
src: Pair<'_>,
param_pool: &BTreeMap<String, DataValue>,
fixedrithms: &BTreeMap<String, Arc<Box<dyn FixedRule>>>,
fixed_rules: &BTreeMap<String, Arc<Box<dyn FixedRule>>>,
cur_vld: ValidityTs,
) -> Result<(Symbol, FixedRuleApply)> {
let mut src = src.into_inner();
@ -701,10 +714,18 @@ fn parse_fixed_rule(
#[diagnostic(code(parser::fixed_aggr_conflict))]
struct AggrInfixedError(#[label] SourceSpan);
#[derive(Debug, Error, Diagnostic)]
#[error("fixed rule cannot have duplicate bindings")]
#[diagnostic(code(parser::duplicate_bindings_for_fixed_rule))]
struct DuplicateBindingError(#[label] SourceSpan);
for (a, v) in aggr.iter().zip(head.iter()) {
ensure!(a.is_none(), AggrInfixedError(v.span))
}
let mut seen_bindings = BTreeSet::new();
let mut binding_gen_id = 0;
let name_pair = src.next().unwrap();
let fixed_name = &name_pair.as_str();
let mut rule_args: Vec<FixedRuleArg> = vec![];
@ -721,9 +742,21 @@ fn parse_fixed_rule(
Rule::fixed_rule_rel => {
let mut els = inner.into_inner();
let name = els.next().unwrap();
let bindings = els
.map(|v| Symbol::new(v.as_str(), v.extract_span()))
.collect_vec();
let mut bindings = Vec::with_capacity(els.size_hint().1.unwrap_or(4));
for v in els {
let s = v.as_str();
if s == "_" {
let symb = Symbol::new(format!("*_*{}", binding_gen_id) ,v.extract_span());
binding_gen_id += 1;
bindings.push(symb);
} else {
if !seen_bindings.insert(s) {
bail!(DuplicateBindingError(v.extract_span()))
}
let symb = Symbol::new(s, v.extract_span());
bindings.push(symb);
}
}
rule_args.push(FixedRuleArg::InMem {
name: Symbol::new(name.as_str(), name.extract_span()),
bindings,
@ -738,7 +771,17 @@ fn parse_fixed_rule(
for v in els {
match v.as_rule() {
Rule::var => {
bindings.push(Symbol::new(v.as_str(), v.extract_span()))
let s = v.as_str();
if s == "_" {
let symb = Symbol::new(format!("*_*{}", binding_gen_id) ,v.extract_span());
binding_gen_id += 1;
bindings.push(symb);
} else {
if !seen_bindings.insert(s) {
bail!(DuplicateBindingError(v.extract_span()))
}
bindings.push(Symbol::new(v.as_str(), v.extract_span()))
}
}
Rule::validity_clause => {
let vld_inner = v.into_inner().next().unwrap();
@ -770,8 +813,18 @@ fn parse_fixed_rule(
let kp = vs.next().unwrap();
let k = SmartString::from(kp.as_str());
let v = match vs.next() {
Some(vp) => Symbol::new(vp.as_str(), vp.extract_span()),
None => Symbol::new(k.clone(), kp.extract_span()),
Some(vp) => {
if !seen_bindings.insert(vp.as_str()) {
bail!(DuplicateBindingError(vp.extract_span()))
}
Symbol::new(vp.as_str(), vp.extract_span())
}
None => {
if !seen_bindings.insert(kp.as_str()) {
bail!(DuplicateBindingError(kp.extract_span()))
}
Symbol::new(k.clone(), kp.extract_span())
}
};
bindings.insert(k, v);
}
@ -810,7 +863,7 @@ fn parse_fixed_rule(
let fixed = FixedRuleHandle::new(fixed_name, name_pair.extract_span());
let fixed_impl = fixedrithms
let fixed_impl = fixed_rules
.get(&fixed.name as &str)
.ok_or_else(|| FixedRuleNotFoundError(fixed.name.to_string(), name_pair.extract_span()))?;
fixed_impl.init_options(&mut options, args_list_span)?;

@ -237,7 +237,10 @@ impl InputRuleApplyAtom {
for arg in self.args {
match arg {
Expr::Binding { var, .. } => {
if seen_variables.insert(var.clone()) {
if var.is_ignored_symbol() {
let dup = gen.next(var.span);
args.push(dup);
} else if seen_variables.insert(var.clone()) {
args.push(var);
} else {
let dup = gen.next(var.span);
@ -294,7 +297,9 @@ impl InputRelationApplyAtom {
for arg in self.args {
match arg {
Expr::Binding { var, .. } => {
if seen_variables.insert(var.clone()) {
if var.is_ignored_symbol() {
args.push(gen.next(var.span));
} else if seen_variables.insert(var.clone()) {
args.push(var);
} else {
let span = var.span;

@ -1391,4 +1391,75 @@ grandparent[gcld, gp] := parent[gcld, p], parent[p, gp]
.run_script("?[uid] <- [[1]] :rm status {uid}", Default::default())
.is_ok());
}
/// Fixed-rule relation arguments accept `_` placeholders and distinct
/// names, but reject the same named binding appearing twice.
#[test]
fn strict_checks_for_fixed_rules_args() {
    // Each case pairs a script with whether running it should succeed.
    let cases = [
        // Repeated `_` placeholders are allowed.
        (
            r#"
r[] <- [[1, 2]]
?[] <~ PageRank(r[_, _])
"#,
            true,
        ),
        // Distinct named bindings are allowed.
        (
            r#"
r[] <- [[1, 2]]
?[] <~ PageRank(r[a, b])
"#,
            true,
        ),
        // The same named binding twice must be rejected.
        (
            r#"
r[] <- [[1, 2]]
?[] <~ PageRank(r[a, a])
"#,
            false,
        ),
    ];

    for (script, should_succeed) in cases {
        // A fresh in-memory database per script keeps the cases independent.
        let db = new_cozo_mem().unwrap();
        let outcome = db.run_script(script, Default::default());
        assert_eq!(outcome.is_ok(), should_succeed);
    }
}
/// `_` must behave as an ignored placeholder: separate occurrences are never
/// unified with each other, and `_` cannot be used as an output binding.
#[test]
fn do_not_unify_underscore() {
    // `_` appears in both atoms; if the two were unified this would be a join.
    const UNDERSCORE_IN_TWO_ATOMS: &str = r#"
r1[] <- [[1, 'a'], [2, 'b']]
r2[] <- [[2, 'B'], [3, 'C']]
?[l1, l2] := r1[_ , l1], r2[_ , l2]
"#;
    // `_` used as the head binding of the entry rule.
    const UNDERSCORE_IN_HEAD: &str = r#"
?[_] := _ = 1
"#;
    // Two `_` unifications with different values next to a real binding.
    const CONFLICTING_UNDERSCORE_VALUES: &str = r#"
?[x] := x = 1, _ = 1, _ = 2
"#;

    let db = new_cozo_mem().unwrap();

    // 2 rows x 2 rows with no join condition gives 4 result rows.
    let unjoined = db
        .run_script(UNDERSCORE_IN_TWO_ATOMS, Default::default())
        .unwrap();
    assert_eq!(unjoined.rows.len(), 4);

    // Exposing `_` in the rule head must be an error.
    assert!(db
        .run_script(UNDERSCORE_IN_HEAD, Default::default())
        .is_err());

    // `_ = 1` and `_ = 2` do not contradict each other, so the single `x = 1` row survives.
    let single = db
        .run_script(CONFLICTING_UNDERSCORE_VALUES, Default::default())
        .unwrap();
    assert_eq!(single.rows.len(), 1);
}
}

Loading…
Cancel
Save