diff --git a/cozo-core/src/cozoscript.pest b/cozo-core/src/cozoscript.pest index 55bcac73..47313184 100644 --- a/cozo-core/src/cozoscript.pest +++ b/cozo-core/src/cozoscript.pest @@ -272,4 +272,5 @@ fts_and = {"AND"} fts_or = {"OR" | "," | ";"} fts_not = {"NOT"} -expression_script = {SOI ~ expr ~ EOI} \ No newline at end of file +expression_script = {SOI ~ expr ~ EOI} +param_list = {SOI ~ "[" ~ "[" ~ (param ~ ",")* ~ param? ~ "]" ~ "]" ~ EOI} \ No newline at end of file diff --git a/cozo-core/src/parse/query.rs b/cozo-core/src/parse/query.rs index e398e0ed..e11917af 100644 --- a/cozo-core/src/parse/query.rs +++ b/cozo-core/src/parse/query.rs @@ -16,6 +16,7 @@ use std::sync::Arc; use either::{Left, Right}; use itertools::Itertools; use miette::{bail, ensure, Diagnostic, LabeledSpan, Report, Result}; +use pest::Parser; use smartstring::{LazyCompact, SmartString}; use thiserror::Error; @@ -34,7 +35,7 @@ use crate::fixed_rule::utilities::constant::Constant; use crate::fixed_rule::{FixedRuleHandle, FixedRuleNotFoundError}; use crate::parse::expr::build_expr; use crate::parse::schema::parse_schema; -use crate::parse::{ExtractSpan, Pair, Pairs, Rule, SourceSpan}; +use crate::parse::{CozoScriptParser, ExtractSpan, Pair, Pairs, Rule, SourceSpan}; use crate::runtime::relation::InputRelationHandle; use crate::FixedRule; @@ -184,7 +185,7 @@ pub(crate) fn parse_query( Rule::const_rule => { let span = pair.extract_span(); let mut src = pair.into_inner(); - let (name, head, aggr) = parse_rule_head(src.next().unwrap(), param_pool)?; + let (name, mut head, aggr) = parse_rule_head(src.next().unwrap(), param_pool)?; if let Some(found) = progs.get(&name) { let mut found_span = match found { @@ -210,8 +211,9 @@ pub(crate) fn parse_query( for (a, v) in aggr.iter().zip(head.iter()) { ensure!(a.is_none(), AggrInConstRuleError(v.span)); } - - let data = build_expr(src.next().unwrap(), param_pool)?; + let data_part = src.next().unwrap(); + let data_part_str = data_part.as_str(); + let data = build_expr(data_part.clone(), param_pool)?; let mut options = BTreeMap::new(); options.insert(SmartString::from("data"), data); let handle = FixedRuleHandle { @@ -226,6 +228,20 @@ pub(crate) fn parse_query( head.is_empty() || arity == head.len(), FixedRuleHeadArityMismatch(arity, head.len(), span) ); + if head.is_empty() && name.is_prog_entry() { + if let Ok(mut datalist) = + CozoScriptParser::parse(Rule::param_list, data_part_str) + { + for s in datalist.next().unwrap().into_inner() { + if s.as_rule() == Rule::param { + head.push(Symbol::new( + s.as_str().strip_prefix("$").unwrap(), + Default::default(), + )); + } + } + } + } progs.insert( name, InputInlineRulesOrFixed::Fixed { diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index 0d2d209d..d7b42bc3 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -640,10 +640,13 @@ fn test_multi_tx() { let db = DbInstance::default(); let tx = db.multi_transaction(true); tx.run_script(":create a {a}", Default::default()).unwrap(); - tx.run_script("?[a] <- [[1]] :put a {a}", Default::default()).unwrap(); + tx.run_script("?[a] <- [[1]] :put a {a}", Default::default()) + .unwrap(); assert!(tx.run_script(":create a {a}", Default::default()).is_err()); - tx.run_script("?[a] <- [[2]] :put a {a}", Default::default()).unwrap(); - tx.run_script("?[a] <- [[3]] :put a {a}", Default::default()).unwrap(); + tx.run_script("?[a] <- [[2]] :put a {a}", Default::default()) + .unwrap(); + tx.run_script("?[a] <- [[3]] :put a {a}", Default::default()) + .unwrap(); tx.commit().unwrap(); assert_eq!( db.run_default("?[a] := *a[a]").unwrap().into_json()["rows"], @@ -653,10 +656,13 @@ fn test_multi_tx() { let db = DbInstance::default(); let tx = db.multi_transaction(true); tx.run_script(":create a {a}", Default::default()).unwrap(); - tx.run_script("?[a] <- [[1]] :put a {a}", Default::default()).unwrap(); + tx.run_script("?[a] <- [[1]] :put a {a}", Default::default()) + .unwrap(); assert!(tx.run_script(":create a {a}", Default::default()).is_err()); - tx.run_script("?[a] <- [[2]] :put a {a}", Default::default()).unwrap(); - tx.run_script("?[a] <- [[3]] :put a {a}", Default::default()).unwrap(); + tx.run_script("?[a] <- [[2]] :put a {a}", Default::default()) + .unwrap(); + tx.run_script("?[a] <- [[3]] :put a {a}", Default::default()) + .unwrap(); tx.abort().unwrap(); assert!(db.run_default("?[a] := *a[a]").is_err()); } @@ -1208,7 +1214,8 @@ fn as_store_in_imperative_script() { #[test] fn update_shall_not_destroy_values() { let db = DbInstance::default(); - db.run_default(r"?[x, y] <- [[1, 2]] :create z {x => y default 0}").unwrap(); + db.run_default(r"?[x, y] <- [[1, 2]] :create z {x => y default 0}") + .unwrap(); let r = db.run_default(r"?[x, y] := *z {x, y}").unwrap(); assert_eq!(r.into_json()["rows"], json!([[1, 2]])); db.run_default(r"?[x] <- [[1]] :update z {x}").unwrap(); @@ -1216,14 +1223,15 @@ fn update_shall_not_destroy_values() { assert_eq!(r.into_json()["rows"], json!([[1, 2]])); } - #[test] fn update_shall_work() { let db = DbInstance::default(); - db.run_default(r"?[x, y, z] <- [[1, 2, 3]] :create z {x => y, z}").unwrap(); + db.run_default(r"?[x, y, z] <- [[1, 2, 3]] :create z {x => y, z}") + .unwrap(); let r = db.run_default(r"?[x, y, z] := *z {x, y, z}").unwrap(); assert_eq!(r.into_json()["rows"], json!([[1, 2, 3]])); - db.run_default(r"?[x, y] <- [[1, 4]] :update z {x, y}").unwrap(); + db.run_default(r"?[x, y] <- [[1, 4]] :update z {x, y}") + .unwrap(); let r = db.run_default(r"?[x, y, z] := *z {x, y, z}").unwrap(); assert_eq!(r.into_json()["rows"], json!([[1, 4, 3]])); } @@ -1294,7 +1302,8 @@ fn sysop_in_imperatives() { #[test] fn puts() { let db = DbInstance::default(); - db.run_default(r" + db.run_default( + r" :create cm_txt { tid: String => aid: String, @@ -1308,21 +1317,47 @@ fn puts() { format: String default 'text', info_amount: Int, } - ").unwrap(); - db.run_default(r" + ", + ) + .unwrap(); + db.run_default( + r" ?[tid, aid, tag, text, info_amount, dup_for, seg_vecs, seg_pos] := dup_for = null, tid = 'x', aid = 'y', tag = 'z', text = 'w', info_amount = 12, follows_tid = null, for_qs = [], format = 'x', seg_vecs = [], seg_pos = [[0, 10]] :put cm_txt {tid, aid, tag, text, info_amount, seg_vecs, seg_pos, dup_for} - ").unwrap(); + ", + ) + .unwrap(); } #[test] fn short_hand() { let db = DbInstance::default(); db.run_default(r":create x {x => y, z}").unwrap(); - db.run_default(r"?[x, y, z] <- [[1, 2, 3]] :put x {}").unwrap(); + db.run_default(r"?[x, y, z] <- [[1, 2, 3]] :put x {}") + .unwrap(); let r = db.run_default(r"?[x, y, z] := *x {x, y, z}").unwrap(); assert_eq!(r.into_json()["rows"], json!([[1, 2, 3]])); -} \ No newline at end of file +} + +#[test] +fn param_shorthand() { + let db = DbInstance::default(); + db.run_script( + r" + ?[] <- [[$x, $y, $z]] + :create x {} + ", + BTreeMap::from([ + ("x".to_string(), DataValue::from(1)), + ("y".to_string(), DataValue::from(2)), + ("z".to_string(), DataValue::from(3)), + ]), + ScriptMutability::Mutable, + ) + .unwrap(); + let res = db.run_default(r"?[x, y, z] := *x {x, y, z}"); + assert_eq!(res.unwrap().into_json()["rows"], json!([[1, 2, 3]])); +}