diff --git a/cozo-core/src/data/symb.rs b/cozo-core/src/data/symb.rs index 5359d9d1..a49f6710 100644 --- a/cozo-core/src/data/symb.rs +++ b/cozo-core/src/data/symb.rs @@ -81,6 +81,9 @@ impl Symbol { pub(crate) fn is_prog_entry(&self) -> bool { self.name == "?" } + pub(crate) fn is_ignored_symbol(&self) -> bool { + self.name == "_" + } pub(crate) fn ensure_valid_field(&self) -> Result<()> { if self.name.contains('(') || self.name.contains(')') { #[derive(Debug, Error, Diagnostic)] diff --git a/cozo-core/src/parse/query.rs b/cozo-core/src/parse/query.rs index 95608edb..b562d9a5 100644 --- a/cozo-core/src/parse/query.rs +++ b/cozo-core/src/parse/query.rs @@ -8,7 +8,7 @@ use std::cmp::Reverse; use std::collections::btree_map::Entry; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use std::error::Error; use std::fmt::{Display, Formatter}; use std::rc::Rc; @@ -96,7 +96,7 @@ fn merge_spans(symbs: &[Symbol]) -> SourceSpan { pub(crate) fn parse_query( src: Pairs<'_>, param_pool: &BTreeMap, - fixedrithms: &BTreeMap>>, + fixed_rules: &BTreeMap>>, cur_vld: ValidityTs, ) -> Result { let mut progs: BTreeMap = Default::default(); @@ -150,7 +150,7 @@ pub(crate) fn parse_query( } Rule::fixed_rule => { let rule_span = pair.extract_span(); - let (name, apply) = parse_fixed_rule(pair, param_pool, fixedrithms, cur_vld)?; + let (name, apply) = parse_fixed_rule(pair, param_pool, fixed_rules, cur_vld)?; match progs.entry(name) { Entry::Vacant(e) => { @@ -451,8 +451,9 @@ fn parse_rule( ensure!(!head.is_empty(), EmptyRuleHead(head_span)); let body = src.next().unwrap(); let mut body_clauses = vec![]; + let mut ignored_counter = 0; for atom_src in body.into_inner() { - body_clauses.push(parse_disjunction(atom_src, param_pool, cur_vld)?) + body_clauses.push(parse_disjunction(atom_src, param_pool, cur_vld, &mut ignored_counter)?) } Ok(( @@ -470,11 +471,12 @@ fn parse_disjunction( pair: Pair<'_>, param_pool: &BTreeMap, cur_vld: ValidityTs, + ignored_counter: &mut u32, ) -> Result { let span = pair.extract_span(); let res: Vec<_> = pair .into_inner() - .map(|v| parse_atom(v, param_pool, cur_vld)) + .map(|v| parse_atom(v, param_pool, cur_vld, ignored_counter)) .try_collect()?; Ok(if res.len() == 1 { res.into_iter().next().unwrap() @@ -487,23 +489,24 @@ fn parse_atom( src: Pair<'_>, param_pool: &BTreeMap, cur_vld: ValidityTs, + ignored_counter: &mut u32 ) -> Result { Ok(match src.as_rule() { Rule::rule_body => { let span = src.extract_span(); let grouped: Vec<_> = src .into_inner() - .map(|v| parse_disjunction(v, param_pool, cur_vld)) + .map(|v| parse_disjunction(v, param_pool, cur_vld, ignored_counter)) .try_collect()?; InputAtom::Conjunction { inner: grouped, span, } } - Rule::disjunction => parse_disjunction(src, param_pool, cur_vld)?, + Rule::disjunction => parse_disjunction(src, param_pool, cur_vld, ignored_counter)?, Rule::negation => { let span = src.extract_span(); - let inner = parse_atom(src.into_inner().next().unwrap(), param_pool, cur_vld)?; + let inner = parse_atom(src.into_inner().next().unwrap(), param_pool, cur_vld, ignored_counter)?; InputAtom::Negation { inner: inner.into(), span, @@ -517,10 +520,15 @@ fn parse_atom( let span = src.extract_span(); let mut src = src.into_inner(); let var = src.next().unwrap(); + let mut symb = Symbol::new(var.as_str(), var.extract_span()); + if symb.is_ignored_symbol() { + symb.name = format!("*^*{}", *ignored_counter).into(); + *ignored_counter += 1; + } let expr = build_expr(src.next().unwrap(), param_pool)?; InputAtom::Unification { inner: Unification { - binding: Symbol::new(var.as_str(), var.extract_span()), + binding: symb, expr, one_many_unif: false, span, @@ -531,10 +539,15 @@ fn parse_atom( let span = src.extract_span(); let mut src = src.into_inner(); let var = src.next().unwrap(); + let mut symb = Symbol::new(var.as_str(), var.extract_span()); + if symb.is_ignored_symbol() { + symb.name = format!("*^*{}", *ignored_counter).into(); + *ignored_counter += 1; + } let expr = build_expr(src.next().unwrap(), param_pool)?; InputAtom::Unification { inner: Unification { - binding: Symbol::new(var.as_str(), var.extract_span()), + binding: symb, expr, one_many_unif: true, span, @@ -690,7 +703,7 @@ struct BadValiditySpecification(#[label] SourceSpan); fn parse_fixed_rule( src: Pair<'_>, param_pool: &BTreeMap, - fixedrithms: &BTreeMap>>, + fixed_rules: &BTreeMap>>, cur_vld: ValidityTs, ) -> Result<(Symbol, FixedRuleApply)> { let mut src = src.into_inner(); @@ -701,10 +714,18 @@ fn parse_fixed_rule( #[diagnostic(code(parser::fixed_aggr_conflict))] struct AggrInfixedError(#[label] SourceSpan); + #[derive(Debug, Error, Diagnostic)] + #[error("fixed rule cannot have duplicate bindings")] + #[diagnostic(code(parser::duplicate_bindings_for_fixed_rule))] + struct DuplicateBindingError(#[label] SourceSpan); + for (a, v) in aggr.iter().zip(head.iter()) { ensure!(a.is_none(), AggrInfixedError(v.span)) } + let mut seen_bindings = BTreeSet::new(); + let mut binding_gen_id = 0; + let name_pair = src.next().unwrap(); let fixed_name = &name_pair.as_str(); let mut rule_args: Vec = vec![]; @@ -721,9 +742,21 @@ fn parse_fixed_rule( Rule::fixed_rule_rel => { let mut els = inner.into_inner(); let name = els.next().unwrap(); - let bindings = els - .map(|v| Symbol::new(v.as_str(), v.extract_span())) - .collect_vec(); + let mut bindings = Vec::with_capacity(els.size_hint().1.unwrap_or(4)); + for v in els { + let s = v.as_str(); + if s == "_" { + let symb = Symbol::new(format!("*_*{}", binding_gen_id) ,v.extract_span()); + binding_gen_id += 1; + bindings.push(symb); + } else { + if !seen_bindings.insert(s) { + bail!(DuplicateBindingError(v.extract_span())) + } + let symb = Symbol::new(s, v.extract_span()); + bindings.push(symb); + } + } rule_args.push(FixedRuleArg::InMem { name: Symbol::new(name.as_str(), name.extract_span()), bindings, @@ -738,7 +771,17 @@ fn parse_fixed_rule( for v in els { match v.as_rule() { Rule::var => { - bindings.push(Symbol::new(v.as_str(), v.extract_span())) + let s = v.as_str(); + if s == "_" { + let symb = Symbol::new(format!("*_*{}", binding_gen_id) ,v.extract_span()); + binding_gen_id += 1; + bindings.push(symb); + } else { + if !seen_bindings.insert(s) { + bail!(DuplicateBindingError(v.extract_span())) + } + bindings.push(Symbol::new(v.as_str(), v.extract_span())) + } } Rule::validity_clause => { let vld_inner = v.into_inner().next().unwrap(); @@ -770,8 +813,18 @@ fn parse_fixed_rule( let kp = vs.next().unwrap(); let k = SmartString::from(kp.as_str()); let v = match vs.next() { - Some(vp) => Symbol::new(vp.as_str(), vp.extract_span()), - None => Symbol::new(k.clone(), kp.extract_span()), + Some(vp) => { + if !seen_bindings.insert(vp.as_str()) { + bail!(DuplicateBindingError(vp.extract_span())) + } + Symbol::new(vp.as_str(), vp.extract_span()) + } + None => { + if !seen_bindings.insert(kp.as_str()) { + bail!(DuplicateBindingError(kp.extract_span())) + } + Symbol::new(k.clone(), kp.extract_span()) + } }; bindings.insert(k, v); } @@ -810,7 +863,7 @@ fn parse_fixed_rule( let fixed = FixedRuleHandle::new(fixed_name, name_pair.extract_span()); - let fixed_impl = fixedrithms + let fixed_impl = fixed_rules .get(&fixed.name as &str) .ok_or_else(|| FixedRuleNotFoundError(fixed.name.to_string(), name_pair.extract_span()))?; fixed_impl.init_options(&mut options, args_list_span)?; diff --git a/cozo-core/src/query/logical.rs b/cozo-core/src/query/logical.rs index 86388a1c..6802ecc1 100644 --- a/cozo-core/src/query/logical.rs +++ b/cozo-core/src/query/logical.rs @@ -237,7 +237,10 @@ impl InputRuleApplyAtom { for arg in self.args { match arg { Expr::Binding { var, .. } => { - if seen_variables.insert(var.clone()) { + if var.is_ignored_symbol() { + let dup = gen.next(var.span); + args.push(dup); + } else if seen_variables.insert(var.clone()) { args.push(var); } else { let dup = gen.next(var.span); @@ -294,7 +297,9 @@ impl InputRelationApplyAtom { for arg in self.args { match arg { Expr::Binding { var, .. } => { - if seen_variables.insert(var.clone()) { + if var.is_ignored_symbol() { + args.push(gen.next(var.span)); + } else if seen_variables.insert(var.clone()) { args.push(var); } else { let span = var.span; diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index 4afc952a..25e259e9 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -1391,4 +1391,75 @@ grandparent[gcld, gp] := parent[gcld, p], parent[p, gp] .run_script("?[uid] <- [[1]] :rm status {uid}", Default::default()) .is_ok()); } + + #[test] + fn strict_checks_for_fixed_rules_args() { + let db = new_cozo_mem().unwrap(); + let res = db.run_script( + r#" + r[] <- [[1, 2]] + ?[] <~ PageRank(r[_, _]) + "#, + Default::default(), + ); + assert!(res.is_ok()); + + let db = new_cozo_mem().unwrap(); + let res = db.run_script( + r#" + r[] <- [[1, 2]] + ?[] <~ PageRank(r[a, b]) + "#, + Default::default(), + ); + assert!(res.is_ok()); + + let db = new_cozo_mem().unwrap(); + let res = db.run_script( + r#" + r[] <- [[1, 2]] + ?[] <~ PageRank(r[a, a]) + "#, + Default::default(), + ); + assert!(res.is_err()); + } + + #[test] + fn do_not_unify_underscore() { + let db = new_cozo_mem().unwrap(); + let res = db + .run_script( + r#" + r1[] <- [[1, 'a'], [2, 'b']] + r2[] <- [[2, 'B'], [3, 'C']] + + ?[l1, l2] := r1[_ , l1], r2[_ , l2] + "#, + Default::default(), + ) + .unwrap() + .rows; + assert_eq!(res.len(), 4); + + let res = db.run_script( + r#" + ?[_] := _ = 1 + "#, + Default::default(), + ); + assert!(res.is_err()); + + let res = db + .run_script( + r#" + ?[x] := x = 1, _ = 1, _ = 2 + "#, + Default::default(), + ) + .unwrap() + .rows; + + assert_eq!(res.len(), 1); + } }