diff --git a/cozo-core/src/fixed_rule/mod.rs b/cozo-core/src/fixed_rule/mod.rs index 7cf44433..01f140f6 100644 --- a/cozo-core/src/fixed_rule/mod.rs +++ b/cozo-core/src/fixed_rule/mod.rs @@ -13,6 +13,7 @@ use std::sync::Arc; use either::{Left, Right}; #[cfg(feature = "graph-algo")] use graph::prelude::{CsrLayout, DirectedCsrGraph, GraphBuilder}; +use itertools::Itertools; use lazy_static::lazy_static; #[allow(unused_imports)] use miette::{bail, ensure, Diagnostic, Report, Result}; @@ -20,6 +21,7 @@ use smartstring::{LazyCompact, SmartString}; use thiserror::Error; use crate::data::expr::Expr; +use crate::data::json::JsonValue; use crate::data::program::{ FixedRuleOptionNotFoundError, MagicFixedRuleApply, MagicFixedRuleRuleArg, MagicSymbol, WrongFixedRuleOptionError, @@ -34,6 +36,7 @@ use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::temp_store::{EpochStore, RegularTempStore}; use crate::runtime::transact::SessionTx; +use crate::NamedRows; #[cfg(feature = "graph-algo")] pub(crate) mod algos; @@ -558,6 +561,94 @@ pub trait FixedRule: Send + Sync { ) -> Result<()>; } +/// Simple wrapper for custom fixed rule. You have less control than implementing [FixedRule] directly, +/// but implementation is simpler. +pub struct SimpleFixedRule { + return_arity: usize, + rule: Box, JsonValue) -> Result + Send + Sync + 'static>, +} + +impl SimpleFixedRule { + /// Construct a SimpleFixedRule. + /// + /// * `return_arity`: The return arity of this rule. + /// * `rule`: The rule implementation as a closure. + // The first argument is a vector of input relations, realized into NamedRows, + // and the second argument is a JSON object of passed in options. + // The returned NamedRows is the return relation of the application of this rule. + // Every row of the returned relation must have length equal to `return_arity`. + pub fn new(return_arity: usize, rule: R) -> Self + where + R: Fn(Vec, JsonValue) -> Result + Send + Sync + 'static, + { + Self { + return_arity, + rule: Box::new(rule), + } + } +} + +impl FixedRule for SimpleFixedRule { + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(self.return_arity) + } + + fn run( + &self, + payload: FixedRulePayload<'_, '_>, + out: &'_ mut RegularTempStore, + _poison: Poison, + ) -> Result<()> { + let options: JsonValue = payload + .manifest + .options + .iter() + .map(|(k, v)| -> Result<_> { + let val = v.clone().eval_to_const()?; + Ok((k.to_string(), JsonValue::from(val))) + }) + .try_collect()?; + let input_arity = payload.manifest.rule_args.len(); + let inputs: Vec<_> = (0..input_arity) + .map(|i| -> Result<_> { + let input = payload.get_input(i).unwrap(); + let rows: Vec<_> = input.iter()?.try_collect()?; + let mut headers = input + .arg_manifest + .bindings() + .iter() + .map(|s| s.name.to_string()) + .collect_vec(); + let l = headers.len(); + let m = input.arg_manifest.arity(&payload.tx, &payload.stores)?; + for i in l..m { + headers.push(format!("_{i}")); + } + Ok(NamedRows::new(headers, rows)) + }) + .try_collect()?; + let results: NamedRows = (self.rule)(inputs, options)?; + for row in results.rows { + #[derive(Debug, Error, Diagnostic)] + #[error("arity mismatch: expect {0}, got {1}")] + #[diagnostic(code(parser::simple_fixed_rule_arity_mismatch))] + struct ArityMismatch(#[label] SourceSpan, usize, usize); + + ensure!( + row.len() == self.return_arity, + ArityMismatch(payload.span(), self.return_arity, row.len()) + ); + out.put(row); + } + Ok(()) + } +} + #[derive(Debug, Error, Diagnostic)] #[error("Cannot determine arity for algo {0} since {1}")] #[diagnostic(code(parser::no_algo_arity))] diff --git a/cozo-core/src/lib.rs b/cozo-core/src/lib.rs index d160cb5a..f3274c5c 100644 --- a/cozo-core/src/lib.rs +++ b/cozo-core/src/lib.rs @@ -66,6 +66,7 @@ pub use storage::{Storage, StoreTx}; pub use crate::data::expr::Expr; use crate::data::json::JsonValue; pub use crate::data::symb::Symbol; +pub use crate::fixed_rule::SimpleFixedRule; pub use crate::parse::SourceSpan; pub use crate::runtime::callback::CallbackOp; pub use crate::runtime::db::Poison;