tests for custom fixed rules

main
Ziyang Hu 2 years ago
parent bf48a44030
commit 758f7392f0

@ -362,64 +362,23 @@ impl<'a, 'b> FixedRulePayload<'a, 'b> {
}
}
pub fn pos_integer_option(&self, name: &str, default: Option<usize>) -> Result<usize> {
pub fn option_span(&self, name: &str) -> Result<SourceSpan> {
match self.manifest.options.get(name) {
Some(v) => match v.clone().eval_to_const() {
Ok(DataValue::Num(n)) => match n.get_int() {
Some(i) => {
ensure!(
i > 0,
WrongFixedRuleOptionError {
name: name.to_string(),
span: v.span(),
rule_name: self.manifest.fixed_handle.name.to_string(),
help: "a positive integer is required".to_string(),
}
);
Ok(i as usize)
}
None => Err(FixedRuleOptionNotFoundError {
name: name.to_string(),
span: self.span(),
rule_name: self.manifest.fixed_handle.name.to_string(),
}
.into()),
},
_ => Err(WrongFixedRuleOptionError {
name: name.to_string(),
span: v.span(),
rule_name: self.manifest.fixed_handle.name.to_string(),
help: "a positive integer is required".to_string(),
}
.into()),
},
None => match default {
Some(v) => Ok(v),
None => Err(FixedRuleOptionNotFoundError {
name: name.to_string(),
span: self.manifest.span,
rule_name: self.manifest.fixed_handle.name.to_string(),
}
.into()),
},
None => Err(FixedRuleOptionNotFoundError {
name: name.to_string(),
span: self.manifest.span,
rule_name: self.manifest.fixed_handle.name.to_string(),
}
.into()),
Some(v) => Ok(v.span()),
}
}
pub fn non_neg_integer_option(&self, name: &str, default: Option<usize>) -> Result<usize> {
pub fn integer_option(&self, name: &str, default: Option<i64>) -> Result<i64> {
match self.manifest.options.get(name) {
Some(v) => match v.clone().eval_to_const() {
Ok(DataValue::Num(n)) => match n.get_int() {
Some(i) => {
ensure!(
i >= 0,
WrongFixedRuleOptionError {
name: name.to_string(),
span: v.span(),
rule_name: self.manifest.fixed_handle.name.to_string(),
help: "a non-negative integer is required".to_string(),
}
);
Ok(i as usize)
}
Some(i) => Ok(i),
None => Err(FixedRuleOptionNotFoundError {
name: name.to_string(),
span: self.manifest.span,
@ -431,7 +390,7 @@ impl<'a, 'b> FixedRulePayload<'a, 'b> {
name: name.to_string(),
span: v.span(),
rule_name: self.manifest.fixed_handle.name.to_string(),
help: "a non-negative integer is required".to_string(),
help: "an integer is required".to_string(),
}
.into()),
},
@ -446,27 +405,45 @@ impl<'a, 'b> FixedRulePayload<'a, 'b> {
},
}
}
pub fn unit_interval_option(&self, name: &str, default: Option<f64>) -> Result<f64> {
pub fn pos_integer_option(&self, name: &str, default: Option<usize>) -> Result<usize> {
let i = self.integer_option(name, default.map(|i| i as i64))?;
ensure!(
i > 0,
WrongFixedRuleOptionError {
name: name.to_string(),
span: self.option_span(name)?,
rule_name: self.manifest.fixed_handle.name.to_string(),
help: "a positive integer is required".to_string(),
}
);
Ok(i as usize)
}
pub fn non_neg_integer_option(&self, name: &str, default: Option<usize>) -> Result<usize> {
let i = self.integer_option(name, default.map(|i| i as i64))?;
ensure!(
i >= 0,
WrongFixedRuleOptionError {
name: name.to_string(),
span: self.option_span(name)?,
rule_name: self.manifest.fixed_handle.name.to_string(),
help: "a non-negative integer is required".to_string(),
}
);
Ok(i as usize)
}
pub fn float_option(&self, name: &str, default: Option<f64>) -> Result<f64> {
match self.manifest.options.get(name) {
Some(v) => match v.clone().eval_to_const() {
Ok(DataValue::Num(n)) => {
let f = n.get_float();
ensure!(
(0. ..=1.).contains(&f),
WrongFixedRuleOptionError {
name: name.to_string(),
span: v.span(),
rule_name: self.manifest.fixed_handle.name.to_string(),
help: "a number between 0. and 1. is required".to_string(),
}
);
Ok(f)
}
_ => Err(WrongFixedRuleOptionError {
name: name.to_string(),
span: v.span(),
rule_name: self.manifest.fixed_handle.name.to_string(),
help: "a number between 0. and 1. is required".to_string(),
help: "a floating number is required".to_string(),
}
.into()),
},
@ -481,7 +458,20 @@ impl<'a, 'b> FixedRulePayload<'a, 'b> {
},
}
}
pub(crate) fn bool_option(&self, name: &str, default: Option<bool>) -> Result<bool> {
pub fn unit_interval_option(&self, name: &str, default: Option<f64>) -> Result<f64> {
let f = self.float_option(name, default)?;
ensure!(
(0. ..=1.).contains(&f),
WrongFixedRuleOptionError {
name: name.to_string(),
span: self.option_span(name)?,
rule_name: self.manifest.fixed_handle.name.to_string(),
help: "a number between 0. and 1. is required".to_string(),
}
);
Ok(f)
}
pub fn bool_option(&self, name: &str, default: Option<bool>) -> Result<bool> {
match self.manifest.options.get(name) {
Some(v) => match v.clone().eval_to_const() {
Ok(DataValue::Bool(b)) => Ok(b),

@ -7,16 +7,22 @@
*
*/
use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use itertools::Itertools;
use log::debug;
use serde_json::json;
use smartstring::{LazyCompact, SmartString};
use crate::data::expr::Expr;
use crate::data::symb::Symbol;
use crate::data::value::DataValue;
use crate::new_cozo_mem;
use crate::runtime::db::CallbackOp;
use crate::fixed_rule::FixedRulePayload;
use crate::parse::SourceSpan;
use crate::runtime::db::{CallbackOp, Poison};
use crate::{new_cozo_mem, FixedRule, RegularTempStore};
#[test]
fn test_limit_offset() {
@ -570,6 +576,57 @@ fn test_index() {
assert!(joins.contains(&json!(":friends:rev")));
}
#[test]
fn test_custom_rules() {
let mut db = new_cozo_mem().unwrap();
struct Custom;
impl FixedRule for Custom {
fn arity(
&self,
_options: &BTreeMap<SmartString<LazyCompact>, Expr>,
_rule_head: &[Symbol],
_span: SourceSpan,
) -> miette::Result<usize> {
Ok(1)
}
fn run(
&self,
payload: FixedRulePayload<'_, '_>,
out: &'_ mut RegularTempStore,
_poison: Poison,
) -> miette::Result<()> {
let rel = payload.get_input(0)?;
let mult = payload.integer_option("mult", Some(2))?;
for maybe_row in rel.iter()? {
let row = maybe_row?;
let mut sum = 0;
for col in row {
let d = col.get_int().unwrap_or(0);
sum += d;
}
sum *= mult;
out.put(vec![DataValue::from(sum)])
}
Ok(())
}
}
db.register_fixed_rule("SumCols".to_string(), Box::new(Custom))
.unwrap();
let res = db
.run_script(
r#"
rel[] <- [[1,2,3,4],[5,6,7,8]]
?[x] <~ SumCols(rel[], mult: 100)
"#,
Default::default(),
)
.unwrap();
assert_eq!(res.into_json()["rows"], json!([[1000], [2600]]));
}
#[test]
fn test_index_short() {
let db = new_cozo_mem().unwrap();

Loading…
Cancel
Save