diff --git a/cozo-core/src/fixed_rule/mod.rs b/cozo-core/src/fixed_rule/mod.rs index 0efafa8c..a08dbb78 100644 --- a/cozo-core/src/fixed_rule/mod.rs +++ b/cozo-core/src/fixed_rule/mod.rs @@ -362,64 +362,23 @@ impl<'a, 'b> FixedRulePayload<'a, 'b> { } } - pub fn pos_integer_option(&self, name: &str, default: Option) -> Result { + pub fn option_span(&self, name: &str) -> Result { match self.manifest.options.get(name) { - Some(v) => match v.clone().eval_to_const() { - Ok(DataValue::Num(n)) => match n.get_int() { - Some(i) => { - ensure!( - i > 0, - WrongFixedRuleOptionError { - name: name.to_string(), - span: v.span(), - rule_name: self.manifest.fixed_handle.name.to_string(), - help: "a positive integer is required".to_string(), - } - ); - Ok(i as usize) - } - None => Err(FixedRuleOptionNotFoundError { - name: name.to_string(), - span: self.span(), - rule_name: self.manifest.fixed_handle.name.to_string(), - } - .into()), - }, - _ => Err(WrongFixedRuleOptionError { - name: name.to_string(), - span: v.span(), - rule_name: self.manifest.fixed_handle.name.to_string(), - help: "a positive integer is required".to_string(), - } - .into()), - }, - None => match default { - Some(v) => Ok(v), - None => Err(FixedRuleOptionNotFoundError { - name: name.to_string(), - span: self.manifest.span, - rule_name: self.manifest.fixed_handle.name.to_string(), - } - .into()), - }, + None => Err(FixedRuleOptionNotFoundError { + name: name.to_string(), + span: self.manifest.span, + rule_name: self.manifest.fixed_handle.name.to_string(), + } + .into()), + Some(v) => Ok(v.span()), } } - pub fn non_neg_integer_option(&self, name: &str, default: Option) -> Result { + + pub fn integer_option(&self, name: &str, default: Option) -> Result { match self.manifest.options.get(name) { Some(v) => match v.clone().eval_to_const() { Ok(DataValue::Num(n)) => match n.get_int() { - Some(i) => { - ensure!( - i >= 0, - WrongFixedRuleOptionError { - name: name.to_string(), - span: v.span(), - rule_name: self.manifest.fixed_handle.name.to_string(), - help: "a non-negative integer is required".to_string(), - } - ); - Ok(i as usize) - } + Some(i) => Ok(i), None => Err(FixedRuleOptionNotFoundError { name: name.to_string(), span: self.manifest.span, @@ -431,7 +390,7 @@ impl<'a, 'b> FixedRulePayload<'a, 'b> { name: name.to_string(), span: v.span(), rule_name: self.manifest.fixed_handle.name.to_string(), - help: "a non-negative integer is required".to_string(), + help: "an integer is required".to_string(), } .into()), }, @@ -446,27 +405,45 @@ impl<'a, 'b> FixedRulePayload<'a, 'b> { }, } } - pub fn unit_interval_option(&self, name: &str, default: Option) -> Result { + + pub fn pos_integer_option(&self, name: &str, default: Option) -> Result { + let i = self.integer_option(name, default.map(|i| i as i64))?; + ensure!( + i > 0, + WrongFixedRuleOptionError { + name: name.to_string(), + span: self.option_span(name)?, + rule_name: self.manifest.fixed_handle.name.to_string(), + help: "a positive integer is required".to_string(), + } + ); + Ok(i as usize) + } + pub fn non_neg_integer_option(&self, name: &str, default: Option) -> Result { + let i = self.integer_option(name, default.map(|i| i as i64))?; + ensure!( + i >= 0, + WrongFixedRuleOptionError { + name: name.to_string(), + span: self.option_span(name)?, + rule_name: self.manifest.fixed_handle.name.to_string(), + help: "a non-negative integer is required".to_string(), + } + ); + Ok(i as usize) + } + pub fn float_option(&self, name: &str, default: Option) -> Result { match self.manifest.options.get(name) { Some(v) => match v.clone().eval_to_const() { Ok(DataValue::Num(n)) => { let f = n.get_float(); - ensure!( - (0. ..=1.).contains(&f), - WrongFixedRuleOptionError { - name: name.to_string(), - span: v.span(), - rule_name: self.manifest.fixed_handle.name.to_string(), - help: "a number between 0. and 1. is required".to_string(), - } - ); Ok(f) } _ => Err(WrongFixedRuleOptionError { name: name.to_string(), span: v.span(), rule_name: self.manifest.fixed_handle.name.to_string(), - help: "a number between 0. and 1. is required".to_string(), + help: "a floating number is required".to_string(), } .into()), }, @@ -481,7 +458,20 @@ impl<'a, 'b> FixedRulePayload<'a, 'b> { }, } } - pub(crate) fn bool_option(&self, name: &str, default: Option) -> Result { + pub fn unit_interval_option(&self, name: &str, default: Option) -> Result { + let f = self.float_option(name, default)?; + ensure!( + (0. ..=1.).contains(&f), + WrongFixedRuleOptionError { + name: name.to_string(), + span: self.option_span(name)?, + rule_name: self.manifest.fixed_handle.name.to_string(), + help: "a number between 0. and 1. is required".to_string(), + } + ); + Ok(f) + } + pub fn bool_option(&self, name: &str, default: Option) -> Result { match self.manifest.options.get(name) { Some(v) => match v.clone().eval_to_const() { Ok(DataValue::Bool(b)) => Ok(b), diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index 25ba5b16..26d3266b 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -7,16 +7,22 @@ * */ +use std::collections::BTreeMap; use std::sync::{Arc, Mutex}; use std::time::Duration; use itertools::Itertools; use log::debug; use serde_json::json; +use smartstring::{LazyCompact, SmartString}; +use crate::data::expr::Expr; +use crate::data::symb::Symbol; use crate::data::value::DataValue; -use crate::new_cozo_mem; -use crate::runtime::db::CallbackOp; +use crate::fixed_rule::FixedRulePayload; +use crate::parse::SourceSpan; +use crate::runtime::db::{CallbackOp, Poison}; +use crate::{new_cozo_mem, FixedRule, RegularTempStore}; #[test] fn test_limit_offset() { @@ -570,6 +576,57 @@ fn test_index() { assert!(joins.contains(&json!(":friends:rev"))); } +#[test] +fn test_custom_rules() { + let mut db = new_cozo_mem().unwrap(); + struct Custom; + + impl FixedRule for Custom { + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> miette::Result { + Ok(1) + } + + fn run( + &self, + payload: FixedRulePayload<'_, '_>, + out: &'_ mut RegularTempStore, + _poison: Poison, + ) -> miette::Result<()> { + let rel = payload.get_input(0)?; + let mult = payload.integer_option("mult", Some(2))?; + for maybe_row in rel.iter()? { + let row = maybe_row?; + let mut sum = 0; + for col in row { + let d = col.get_int().unwrap_or(0); + sum += d; + } + sum *= mult; + out.put(vec![DataValue::from(sum)]) + } + Ok(()) + } + } + + db.register_fixed_rule("SumCols".to_string(), Box::new(Custom)) + .unwrap(); + let res = db + .run_script( + r#" + rel[] <- [[1,2,3,4],[5,6,7,8]] + ?[x] <~ SumCols(rel[], mult: 100) + "#, + Default::default(), + ) + .unwrap(); + assert_eq!(res.into_json()["rows"], json!([[1000], [2600]])); +} + #[test] fn test_index_short() { let db = new_cozo_mem().unwrap();