|
|
|
@ -13,6 +13,7 @@ use std::sync::Arc;
|
|
|
|
|
use either::{Left, Right};
|
|
|
|
|
#[cfg(feature = "graph-algo")]
|
|
|
|
|
use graph::prelude::{CsrLayout, DirectedCsrGraph, GraphBuilder};
|
|
|
|
|
use itertools::Itertools;
|
|
|
|
|
use lazy_static::lazy_static;
|
|
|
|
|
#[allow(unused_imports)]
|
|
|
|
|
use miette::{bail, ensure, Diagnostic, Report, Result};
|
|
|
|
@ -20,6 +21,7 @@ use smartstring::{LazyCompact, SmartString};
|
|
|
|
|
use thiserror::Error;
|
|
|
|
|
|
|
|
|
|
use crate::data::expr::Expr;
|
|
|
|
|
use crate::data::json::JsonValue;
|
|
|
|
|
use crate::data::program::{
|
|
|
|
|
FixedRuleOptionNotFoundError, MagicFixedRuleApply, MagicFixedRuleRuleArg, MagicSymbol,
|
|
|
|
|
WrongFixedRuleOptionError,
|
|
|
|
@ -34,6 +36,7 @@ use crate::parse::SourceSpan;
|
|
|
|
|
use crate::runtime::db::Poison;
|
|
|
|
|
use crate::runtime::temp_store::{EpochStore, RegularTempStore};
|
|
|
|
|
use crate::runtime::transact::SessionTx;
|
|
|
|
|
use crate::NamedRows;
|
|
|
|
|
|
|
|
|
|
#[cfg(feature = "graph-algo")]
|
|
|
|
|
pub(crate) mod algos;
|
|
|
|
@ -558,6 +561,94 @@ pub trait FixedRule: Send + Sync {
|
|
|
|
|
) -> Result<()>;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Simple wrapper for custom fixed rule. You have less control than implementing [FixedRule] directly,
|
|
|
|
|
/// but implementation is simpler.
|
|
|
|
|
pub struct SimpleFixedRule {
|
|
|
|
|
return_arity: usize,
|
|
|
|
|
rule: Box<dyn Fn(Vec<NamedRows>, JsonValue) -> Result<NamedRows> + Send + Sync + 'static>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl SimpleFixedRule {
|
|
|
|
|
/// Construct a SimpleFixedRule.
|
|
|
|
|
///
|
|
|
|
|
/// * `return_arity`: The return arity of this rule.
|
|
|
|
|
/// * `rule`: The rule implementation as a closure.
|
|
|
|
|
// The first argument is a vector of input relations, realized into NamedRows,
|
|
|
|
|
// and the second argument is a JSON object of passed in options.
|
|
|
|
|
// The returned NamedRows is the return relation of the application of this rule.
|
|
|
|
|
// Every row of the returned relation must have length equal to `return_arity`.
|
|
|
|
|
pub fn new<R>(return_arity: usize, rule: R) -> Self
|
|
|
|
|
where
|
|
|
|
|
R: Fn(Vec<NamedRows>, JsonValue) -> Result<NamedRows> + Send + Sync + 'static,
|
|
|
|
|
{
|
|
|
|
|
Self {
|
|
|
|
|
return_arity,
|
|
|
|
|
rule: Box::new(rule),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl FixedRule for SimpleFixedRule {
|
|
|
|
|
fn arity(
|
|
|
|
|
&self,
|
|
|
|
|
_options: &BTreeMap<SmartString<LazyCompact>, Expr>,
|
|
|
|
|
_rule_head: &[Symbol],
|
|
|
|
|
_span: SourceSpan,
|
|
|
|
|
) -> Result<usize> {
|
|
|
|
|
Ok(self.return_arity)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn run(
|
|
|
|
|
&self,
|
|
|
|
|
payload: FixedRulePayload<'_, '_>,
|
|
|
|
|
out: &'_ mut RegularTempStore,
|
|
|
|
|
_poison: Poison,
|
|
|
|
|
) -> Result<()> {
|
|
|
|
|
let options: JsonValue = payload
|
|
|
|
|
.manifest
|
|
|
|
|
.options
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|(k, v)| -> Result<_> {
|
|
|
|
|
let val = v.clone().eval_to_const()?;
|
|
|
|
|
Ok((k.to_string(), JsonValue::from(val)))
|
|
|
|
|
})
|
|
|
|
|
.try_collect()?;
|
|
|
|
|
let input_arity = payload.manifest.rule_args.len();
|
|
|
|
|
let inputs: Vec<_> = (0..input_arity)
|
|
|
|
|
.map(|i| -> Result<_> {
|
|
|
|
|
let input = payload.get_input(i).unwrap();
|
|
|
|
|
let rows: Vec<_> = input.iter()?.try_collect()?;
|
|
|
|
|
let mut headers = input
|
|
|
|
|
.arg_manifest
|
|
|
|
|
.bindings()
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|s| s.name.to_string())
|
|
|
|
|
.collect_vec();
|
|
|
|
|
let l = headers.len();
|
|
|
|
|
let m = input.arg_manifest.arity(&payload.tx, &payload.stores)?;
|
|
|
|
|
for i in l..m {
|
|
|
|
|
headers.push(format!("_{i}"));
|
|
|
|
|
}
|
|
|
|
|
Ok(NamedRows::new(headers, rows))
|
|
|
|
|
})
|
|
|
|
|
.try_collect()?;
|
|
|
|
|
let results: NamedRows = (self.rule)(inputs, options)?;
|
|
|
|
|
for row in results.rows {
|
|
|
|
|
#[derive(Debug, Error, Diagnostic)]
|
|
|
|
|
#[error("arity mismatch: expect {0}, got {1}")]
|
|
|
|
|
#[diagnostic(code(parser::simple_fixed_rule_arity_mismatch))]
|
|
|
|
|
struct ArityMismatch(#[label] SourceSpan, usize, usize);
|
|
|
|
|
|
|
|
|
|
ensure!(
|
|
|
|
|
row.len() == self.return_arity,
|
|
|
|
|
ArityMismatch(payload.span(), self.return_arity, row.len())
|
|
|
|
|
);
|
|
|
|
|
out.put(row);
|
|
|
|
|
}
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Error, Diagnostic)]
|
|
|
|
|
#[error("Cannot determine arity for algo {0} since {1}")]
|
|
|
|
|
#[diagnostic(code(parser::no_algo_arity))]
|
|
|
|
|