simple fixed rule

main
Ziyang Hu 2 years ago
parent e98942aa78
commit 1f5949e4a3

@ -13,6 +13,7 @@ use std::sync::Arc;
use either::{Left, Right};
#[cfg(feature = "graph-algo")]
use graph::prelude::{CsrLayout, DirectedCsrGraph, GraphBuilder};
use itertools::Itertools;
use lazy_static::lazy_static;
#[allow(unused_imports)]
use miette::{bail, ensure, Diagnostic, Report, Result};
@ -20,6 +21,7 @@ use smartstring::{LazyCompact, SmartString};
use thiserror::Error;
use crate::data::expr::Expr;
use crate::data::json::JsonValue;
use crate::data::program::{
FixedRuleOptionNotFoundError, MagicFixedRuleApply, MagicFixedRuleRuleArg, MagicSymbol,
WrongFixedRuleOptionError,
@ -34,6 +36,7 @@ use crate::parse::SourceSpan;
use crate::runtime::db::Poison;
use crate::runtime::temp_store::{EpochStore, RegularTempStore};
use crate::runtime::transact::SessionTx;
use crate::NamedRows;
#[cfg(feature = "graph-algo")]
pub(crate) mod algos;
@ -558,6 +561,94 @@ pub trait FixedRule: Send + Sync {
) -> Result<()>;
}
/// Simple wrapper for custom fixed rule. You have less control than implementing [FixedRule] directly,
/// but implementation is simpler.
pub struct SimpleFixedRule {
return_arity: usize,
rule: Box<dyn Fn(Vec<NamedRows>, JsonValue) -> Result<NamedRows> + Send + Sync + 'static>,
}
impl SimpleFixedRule {
/// Construct a SimpleFixedRule.
///
/// * `return_arity`: The return arity of this rule.
/// * `rule`: The rule implementation as a closure.
// The first argument is a vector of input relations, realized into NamedRows,
// and the second argument is a JSON object of passed in options.
// The returned NamedRows is the return relation of the application of this rule.
// Every row of the returned relation must have length equal to `return_arity`.
pub fn new<R>(return_arity: usize, rule: R) -> Self
where
R: Fn(Vec<NamedRows>, JsonValue) -> Result<NamedRows> + Send + Sync + 'static,
{
Self {
return_arity,
rule: Box::new(rule),
}
}
}
impl FixedRule for SimpleFixedRule {
fn arity(
&self,
_options: &BTreeMap<SmartString<LazyCompact>, Expr>,
_rule_head: &[Symbol],
_span: SourceSpan,
) -> Result<usize> {
Ok(self.return_arity)
}
fn run(
&self,
payload: FixedRulePayload<'_, '_>,
out: &'_ mut RegularTempStore,
_poison: Poison,
) -> Result<()> {
let options: JsonValue = payload
.manifest
.options
.iter()
.map(|(k, v)| -> Result<_> {
let val = v.clone().eval_to_const()?;
Ok((k.to_string(), JsonValue::from(val)))
})
.try_collect()?;
let input_arity = payload.manifest.rule_args.len();
let inputs: Vec<_> = (0..input_arity)
.map(|i| -> Result<_> {
let input = payload.get_input(i).unwrap();
let rows: Vec<_> = input.iter()?.try_collect()?;
let mut headers = input
.arg_manifest
.bindings()
.iter()
.map(|s| s.name.to_string())
.collect_vec();
let l = headers.len();
let m = input.arg_manifest.arity(&payload.tx, &payload.stores)?;
for i in l..m {
headers.push(format!("_{i}"));
}
Ok(NamedRows::new(headers, rows))
})
.try_collect()?;
let results: NamedRows = (self.rule)(inputs, options)?;
for row in results.rows {
#[derive(Debug, Error, Diagnostic)]
#[error("arity mismatch: expect {0}, got {1}")]
#[diagnostic(code(parser::simple_fixed_rule_arity_mismatch))]
struct ArityMismatch(#[label] SourceSpan, usize, usize);
ensure!(
row.len() == self.return_arity,
ArityMismatch(payload.span(), self.return_arity, row.len())
);
out.put(row);
}
Ok(())
}
}
#[derive(Debug, Error, Diagnostic)]
#[error("Cannot determine arity for algo {0} since {1}")]
#[diagnostic(code(parser::no_algo_arity))]

@ -66,6 +66,7 @@ pub use storage::{Storage, StoreTx};
pub use crate::data::expr::Expr;
use crate::data::json::JsonValue;
pub use crate::data::symb::Symbol;
pub use crate::fixed_rule::SimpleFixedRule;
pub use crate::parse::SourceSpan;
pub use crate::runtime::callback::CallbackOp;
pub use crate::runtime::db::Poison;

Loading…
Cancel
Save