diff --git a/src/algo/all_pairs_shortest_path.rs b/src/algo/all_pairs_shortest_path.rs index 820b8d63..1dfdfd00 100644 --- a/src/algo/all_pairs_shortest_path.rs +++ b/src/algo/all_pairs_shortest_path.rs @@ -6,12 +6,16 @@ use miette::Result; use ordered_float::OrderedFloat; use priority_queue::PriorityQueue; use rayon::prelude::*; +use smartstring::{LazyCompact, SmartString}; -use crate::algo::AlgoImpl; use crate::algo::shortest_path_dijkstra::dijkstra_keep_ties; +use crate::algo::AlgoImpl; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -75,6 +79,15 @@ impl AlgoImpl for BetweennessCentrality { Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(2) + } } pub(crate) struct ClosenessCentrality; @@ -116,6 +129,15 @@ impl AlgoImpl for ClosenessCentrality { } Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(2) + } } pub(crate) fn dijkstra_cost_only( diff --git a/src/algo/astar.rs b/src/algo/astar.rs index 1484cbbd..b13ed6b5 100644 --- a/src/algo/astar.rs +++ b/src/algo/astar.rs @@ -4,12 +4,15 @@ use std::collections::BTreeMap; use miette::{ensure, Result}; use ordered_float::OrderedFloat; use priority_queue::PriorityQueue; +use smartstring::{LazyCompact, SmartString}; use crate::algo::{AlgoImpl, BadExprValueError, NodeNotFoundError}; use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicAlgoRuleArg, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -63,6 +66,15 @@ impl AlgoImpl for ShortestPathAStar { Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(4) + } } fn astar( diff --git a/src/algo/bfs.rs b/src/algo/bfs.rs index 55f5a4cb..a389f5ef 100644 --- a/src/algo/bfs.rs +++ b/src/algo/bfs.rs @@ -1,11 +1,15 @@ use std::collections::{BTreeMap, BTreeSet, VecDeque}; -use miette::Result; +use miette::{Result}; +use smartstring::{LazyCompact, SmartString}; use crate::algo::{AlgoImpl, NodeNotFoundError}; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -96,4 +100,13 @@ impl AlgoImpl for Bfs { } Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(1) + } } diff --git a/src/algo/csv.rs b/src/algo/csv.rs index c89d8a3e..a79a2a9f 100644 --- a/src/algo/csv.rs +++ b/src/algo/csv.rs @@ -2,15 +2,19 @@ use std::collections::BTreeMap; use csv::StringRecord; use miette::{bail, ensure, IntoDiagnostic, Result}; -use smartstring::SmartString; +use smartstring::{LazyCompact, SmartString}; -use crate::algo::AlgoImpl; +use crate::algo::{AlgoImpl, CannotDetermineArity}; +use crate::data::expr::Expr; use crate::data::functions::{op_to_float, op_to_uuid}; -use crate::data::program::{MagicAlgoApply, MagicSymbol, WrongAlgoOptionError}; +use crate::data::program::{ + AlgoOptionNotFoundError, MagicAlgoApply, MagicSymbol, WrongAlgoOptionError, +}; use crate::data::relation::{ColType, NullableColType}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; -use crate::parse::parse_type; +use crate::parse::{parse_type, SourceSpan}; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -149,9 +153,7 @@ impl AlgoImpl for CsvReader { } } None => { - let content = minreq::get(&url as &str) - .send() - .into_diagnostic()?; + let content = minreq::get(&url as &str).send().into_diagnostic()?; let mut rdr = rdr_builder.from_reader(content.as_bytes()); for record in rdr.records() { let record = record.into_diagnostic()?; @@ -161,4 +163,44 @@ impl AlgoImpl for CsvReader { } Ok(()) } + + fn arity( + &self, + options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + span: SourceSpan, + ) -> Result { + let with_row_num = match options.get("prepend_index") { + None => 0, + Some(Expr::Const { + val: DataValue::Bool(true), + .. + }) => 1, + Some(Expr::Const { + val: DataValue::Bool(false), + .. + }) => 0, + _ => bail!(CannotDetermineArity( + "CsvReader".to_string(), + "invalid option 'prepend_index' given, expect a boolean".to_string(), + span + )), + }; + let columns = options + .get("types") + .ok_or_else(|| AlgoOptionNotFoundError { + name: "types".to_string(), + span, + algo_name: "CsvReader".to_string(), + })?; + let columns = columns.clone().eval_to_const()?; + if let Some(l) = columns.get_list() { + return Ok(l.len() + with_row_num); + } + bail!(CannotDetermineArity( + "CsvReader".to_string(), + "invalid option 'types' given, expect positive number or list".to_string(), + span + )) + } } diff --git a/src/algo/degree_centrality.rs b/src/algo/degree_centrality.rs index 0adfa358..47e882ce 100644 --- a/src/algo/degree_centrality.rs +++ b/src/algo/degree_centrality.rs @@ -1,11 +1,15 @@ use std::collections::BTreeMap; use miette::Result; +use smartstring::{LazyCompact, SmartString}; use crate::algo::AlgoImpl; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -60,4 +64,13 @@ impl AlgoImpl for DegreeCentrality { } Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(4) + } } diff --git a/src/algo/dfs.rs b/src/algo/dfs.rs index 1da3ba7b..4625f4f1 100644 --- a/src/algo/dfs.rs +++ b/src/algo/dfs.rs @@ -1,11 +1,15 @@ use std::collections::{BTreeMap, BTreeSet}; use miette::Result; +use smartstring::{LazyCompact, SmartString}; use crate::algo::{AlgoImpl, NodeNotFoundError}; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -99,4 +103,13 @@ impl AlgoImpl for Dfs { } Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(1) + } } diff --git a/src/algo/jlines.rs b/src/algo/jlines.rs index 5211e109..587614ca 100644 --- a/src/algo/jlines.rs +++ b/src/algo/jlines.rs @@ -5,11 +5,14 @@ use std::{fs, io}; use itertools::Itertools; use miette::{bail, miette, Diagnostic, IntoDiagnostic, Result}; +use smartstring::{LazyCompact, SmartString}; use thiserror::Error; -use crate::algo::AlgoImpl; +use crate::algo::{AlgoImpl, CannotDetermineArity}; +use crate::data::expr::Expr; use crate::data::json::JsonValue; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; use crate::parse::SourceSpan; @@ -121,4 +124,43 @@ impl AlgoImpl for JsonReader { } Ok(()) } + + fn arity( + &self, + opts: &BTreeMap, Expr>, + _rule_head: &[Symbol], + span: SourceSpan, + ) -> Result { + let with_row_num = match opts.get("prepend_index") { + None => 0, + Some(Expr::Const { + val: DataValue::Bool(true), + .. + }) => 1, + Some(Expr::Const { + val: DataValue::Bool(false), + .. + }) => 0, + _ => bail!(CannotDetermineArity( + "JsonReader".to_string(), + "invalid option 'prepend_index' given, expect a boolean".to_string(), + span + )), + }; + let fields = opts.get("fields").ok_or_else(|| { + CannotDetermineArity( + "JsonReader".to_string(), + "option 'fields' not provided".to_string(), + span, + ) + })?; + Ok(match fields.clone().eval_to_const()? { + DataValue::List(l) => l.len() + with_row_num, + _ => bail!(CannotDetermineArity( + "JsonReader".to_string(), + "invalid option 'fields' given, expect a list".to_string(), + span + )), + }) + } } diff --git a/src/algo/kruskal.rs b/src/algo/kruskal.rs index 77a7f1ca..d6001624 100644 --- a/src/algo/kruskal.rs +++ b/src/algo/kruskal.rs @@ -5,11 +5,15 @@ use itertools::Itertools; use miette::Result; use ordered_float::OrderedFloat; use priority_queue::PriorityQueue; +use smartstring::{LazyCompact, SmartString}; use crate::algo::AlgoImpl; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -45,6 +49,15 @@ impl AlgoImpl for MinimumSpanningForestKruskal { Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(3) + } } fn kruskal(edges: &[Vec<(usize, f64)>], poison: Poison) -> Result> { diff --git a/src/algo/label_propagation.rs b/src/algo/label_propagation.rs index 1bb41cbd..492a8b3c 100644 --- a/src/algo/label_propagation.rs +++ b/src/algo/label_propagation.rs @@ -3,11 +3,15 @@ use std::collections::BTreeMap; use itertools::Itertools; use miette::Result; use rand::prelude::*; +use smartstring::{LazyCompact, SmartString}; use crate::algo::AlgoImpl; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -35,6 +39,15 @@ impl AlgoImpl for LabelPropagation { } Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(2) + } } fn label_propagation( diff --git a/src/algo/louvain.rs b/src/algo/louvain.rs index 83ea960d..b969b0ff 100644 --- a/src/algo/louvain.rs +++ b/src/algo/louvain.rs @@ -3,11 +3,15 @@ use std::collections::{BTreeMap, BTreeSet}; use itertools::Itertools; use log::debug; use miette::Result; +use smartstring::{LazyCompact, SmartString}; use crate::algo::AlgoImpl; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -59,6 +63,15 @@ impl AlgoImpl for CommunityDetectionLouvain { Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(2) + } } fn louvain( diff --git a/src/algo/mod.rs b/src/algo/mod.rs index 73af7650..1396085b 100644 --- a/src/algo/mod.rs +++ b/src/algo/mod.rs @@ -1,6 +1,5 @@ use std::collections::BTreeMap; -use either::Either; use miette::{bail, ensure, Diagnostic, Result}; use smartstring::{LazyCompact, SmartString}; use thiserror::Error; @@ -25,10 +24,7 @@ use crate::algo::top_sort::TopSort; use crate::algo::triangles::ClusteringCoefficients; use crate::algo::yen::KShortestPathYen; use crate::data::expr::Expr; -use crate::data::functions::OP_LIST; -use crate::data::program::{ - AlgoOptionNotFoundError, AlgoRuleArg, MagicAlgoApply, MagicAlgoRuleArg, MagicSymbol, -}; +use crate::data::program::{MagicAlgoApply, MagicAlgoRuleArg, MagicSymbol}; use crate::data::symb::Symbol; use crate::data::tuple::{Tuple, TupleIter}; use crate::data::value::DataValue; @@ -66,8 +62,23 @@ pub(crate) trait AlgoImpl { out: &InMemRelation, poison: Poison, ) -> Result<()>; + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result; } +#[derive(Debug, Error, Diagnostic)] +#[error("Cannot determine arity for algo {0} since {1}")] +#[diagnostic(code(parser::no_algo_arity))] +pub(crate) struct CannotDetermineArity( + pub(crate) String, + pub(crate) String, + #[label] pub(crate) SourceSpan, +); + #[derive(Clone, Debug)] pub(crate) struct AlgoHandle { pub(crate) name: Symbol, @@ -79,124 +90,6 @@ impl AlgoHandle { name: Symbol::new(name, span), } } - pub(crate) fn arity( - &self, - _args: Either<&[AlgoRuleArg], &[MagicAlgoRuleArg]>, - opts: &BTreeMap, Expr>, - ) -> Result { - #[derive(Debug, Error, Diagnostic)] - #[error("Cannot determine arity for algo {0} since {1}")] - #[diagnostic(code(parser::no_algo_arity))] - struct CannotDetermineArity(String, String, #[label] SourceSpan); - - Ok(match &self.name.name as &str { - "ClusteringCoefficients" => 4, - "DegreeCentrality" => 4, - "ClosenessCentrality" => 2, - "BetweennessCentrality" => 2, - "DepthFirstSearch" | "DFS" => 1, - "BreadthFirstSearch" | "BFS" => 1, - "ShortestPathDijkstra" => 4, - "ShortestPathAStar" => 4, - "KShortestPathYen" => 4, - "MinimumSpanningTreePrim" => 3, - "MinimumSpanningTreeKruskal" => 3, - "TopSort" => 2, - "ConnectedComponents" => 2, - "StronglyConnectedComponents" | "SCC" => 2, - "PageRank" => 2, - "CommunityDetectionLouvain" => 2, - "LabelPropagation" => 2, - "RandomWalk" => 3, - n @ "ReorderSort" => { - let out_opts = opts.get("out").ok_or_else(|| { - CannotDetermineArity( - n.to_string(), - "option 'out' not provided".to_string(), - self.name.span, - ) - })?; - match out_opts { - Expr::Const { - val: DataValue::List(l), - .. - } => l.len() + 1, - Expr::Apply { op, args, .. } if **op == OP_LIST => args.len() + 1, - _ => bail!(CannotDetermineArity( - n.to_string(), - "invalid option 'out' given, expect a list".to_string(), - self.name.span - )), - } - } - n @ "CsvReader" => { - let with_row_num = match opts.get("prepend_index") { - None => 0, - Some(Expr::Const { - val: DataValue::Bool(true), - .. - }) => 1, - Some(Expr::Const { - val: DataValue::Bool(false), - .. - }) => 0, - _ => bail!(CannotDetermineArity( - n.to_string(), - "invalid option 'prepend_index' given, expect a boolean".to_string(), - self.name.span - )), - }; - let columns = opts.get("types").ok_or_else(|| AlgoOptionNotFoundError { - name: "types".to_string(), - span: self.name.span, - algo_name: n.to_string(), - })?; - let columns = columns.clone().eval_to_const()?; - if let Some(l) = columns.get_list() { - return Ok(l.len() + with_row_num); - } - bail!(CannotDetermineArity( - n.to_string(), - "invalid option 'types' given, expect positive number or list".to_string(), - self.name.span - )) - } - n @ "JsonReader" => { - let with_row_num = match opts.get("prepend_index") { - None => 0, - Some(Expr::Const { - val: DataValue::Bool(true), - .. - }) => 1, - Some(Expr::Const { - val: DataValue::Bool(false), - .. - }) => 0, - _ => bail!(CannotDetermineArity( - n.to_string(), - "invalid option 'prepend_index' given, expect a boolean".to_string(), - self.name.span - )), - }; - let fields = opts.get("fields").ok_or_else(|| { - CannotDetermineArity( - n.to_string(), - "option 'fields' not provided".to_string(), - self.name.span, - ) - })?; - match fields.clone().eval_to_const()? { - DataValue::List(l) => l.len() + with_row_num, - _ => bail!(CannotDetermineArity( - n.to_string(), - "invalid option 'fields' given, expect a list".to_string(), - self.name.span - )), - } - } - n => bail!(AlgoNotFoundError(n.to_string(), self.name.span)), - }) - } pub(crate) fn get_impl(&self) -> Result> { Ok(match &self.name.name as &str { diff --git a/src/algo/pagerank.rs b/src/algo/pagerank.rs index 84751f41..de49412b 100644 --- a/src/algo/pagerank.rs +++ b/src/algo/pagerank.rs @@ -4,11 +4,15 @@ use std::mem; use approx::AbsDiffEq; use miette::Result; use nalgebra::{Dynamic, OMatrix, U1}; +use smartstring::{LazyCompact, SmartString}; use crate::algo::AlgoImpl; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -39,6 +43,15 @@ impl AlgoImpl for PageRank { } Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(2) + } } fn pagerank( diff --git a/src/algo/prim.rs b/src/algo/prim.rs index d7d6ab17..1d5c8d23 100644 --- a/src/algo/prim.rs +++ b/src/algo/prim.rs @@ -5,10 +5,13 @@ use miette::Diagnostic; use miette::Result; use ordered_float::OrderedFloat; use priority_queue::PriorityQueue; +use smartstring::{LazyCompact, SmartString}; use thiserror::Error; use crate::algo::AlgoImpl; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; use crate::parse::SourceSpan; @@ -68,6 +71,15 @@ impl AlgoImpl for MinimumSpanningTreePrim { } Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(3) + } } fn prim( diff --git a/src/algo/random_walk.rs b/src/algo/random_walk.rs index 206cac13..46f96701 100644 --- a/src/algo/random_walk.rs +++ b/src/algo/random_walk.rs @@ -4,11 +4,15 @@ use itertools::Itertools; use miette::{bail, ensure, Result}; use rand::distributions::WeightedIndex; use rand::prelude::*; +use smartstring::{LazyCompact, SmartString}; use crate::algo::{AlgoImpl, BadExprValueError, NodeNotFoundError}; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -119,4 +123,13 @@ impl AlgoImpl for RandomWalk { } Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(3) + } } diff --git a/src/algo/reorder_sort.rs b/src/algo/reorder_sort.rs index c668968a..204eab6a 100644 --- a/src/algo/reorder_sort.rs +++ b/src/algo/reorder_sort.rs @@ -2,11 +2,13 @@ use std::collections::BTreeMap; use itertools::Itertools; use miette::{bail, Result}; +use smartstring::{LazyCompact, SmartString}; -use crate::algo::AlgoImpl; +use crate::algo::{AlgoImpl, CannotDetermineArity}; use crate::data::expr::Expr; use crate::data::functions::OP_LIST; use crate::data::program::{MagicAlgoApply, MagicSymbol, WrongAlgoOptionError}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; use crate::parse::SourceSpan; @@ -112,4 +114,31 @@ impl AlgoImpl for ReorderSort { } Ok(()) } + + fn arity( + &self, + opts: &BTreeMap, Expr>, + _rule_head: &[Symbol], + span: SourceSpan, + ) -> Result { + let out_opts = opts.get("out").ok_or_else(|| { + CannotDetermineArity( + "ReorderSort".to_string(), + "option 'out' not provided".to_string(), + span, + ) + })?; + Ok(match out_opts { + Expr::Const { + val: DataValue::List(l), + .. + } => l.len() + 1, + Expr::Apply { op, args, .. } if **op == OP_LIST => args.len() + 1, + _ => bail!(CannotDetermineArity( + "ReorderSort".to_string(), + "invalid option 'out' given, expect a list".to_string(), + span + )), + }) + } } diff --git a/src/algo/shortest_path_dijkstra.rs b/src/algo/shortest_path_dijkstra.rs index 5f9700fe..fa098ba2 100644 --- a/src/algo/shortest_path_dijkstra.rs +++ b/src/algo/shortest_path_dijkstra.rs @@ -8,11 +8,15 @@ use ordered_float::OrderedFloat; use priority_queue::PriorityQueue; use rayon::prelude::*; use smallvec::{smallvec, SmallVec}; +use smartstring::{LazyCompact, SmartString}; use crate::algo::AlgoImpl; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -139,6 +143,15 @@ impl AlgoImpl for ShortestPathDijkstra { Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(4) + } } #[derive(PartialEq)] diff --git a/src/algo/strongly_connected_components.rs b/src/algo/strongly_connected_components.rs index 86139365..0364bc77 100644 --- a/src/algo/strongly_connected_components.rs +++ b/src/algo/strongly_connected_components.rs @@ -3,11 +3,15 @@ use std::collections::BTreeMap; use itertools::Itertools; use miette::Result; +use smartstring::{LazyCompact, SmartString}; use crate::algo::AlgoImpl; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -62,6 +66,15 @@ impl AlgoImpl for StronglyConnectedComponent { Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(2) + } } pub(crate) struct TarjanScc<'a> { diff --git a/src/algo/top_sort.rs b/src/algo/top_sort.rs index c58d459b..e81a00fc 100644 --- a/src/algo/top_sort.rs +++ b/src/algo/top_sort.rs @@ -1,11 +1,15 @@ use std::collections::BTreeMap; use miette::Result; +use smartstring::{LazyCompact, SmartString}; use crate::algo::AlgoImpl; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -35,6 +39,15 @@ impl AlgoImpl for TopSort { Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(2) + } } pub(crate) fn kahn(graph: &[Vec], poison: Poison) -> Result> { diff --git a/src/algo/triangles.rs b/src/algo/triangles.rs index de4811cb..39b32b7c 100644 --- a/src/algo/triangles.rs +++ b/src/algo/triangles.rs @@ -2,11 +2,15 @@ use std::collections::{BTreeMap, BTreeSet}; use miette::Result; use rayon::prelude::*; +use smartstring::{LazyCompact, SmartString}; use crate::algo::AlgoImpl; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -41,6 +45,15 @@ impl AlgoImpl for ClusteringCoefficients { Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(4) + } } fn clustering_coefficients( diff --git a/src/algo/yen.rs b/src/algo/yen.rs index 8ce16d6c..5f713593 100644 --- a/src/algo/yen.rs +++ b/src/algo/yen.rs @@ -3,12 +3,16 @@ use std::collections::{BTreeMap, BTreeSet}; use itertools::Itertools; use miette::Result; use rayon::prelude::*; +use smartstring::{LazyCompact, SmartString}; use crate::algo::shortest_path_dijkstra::dijkstra; use crate::algo::AlgoImpl; +use crate::data::expr::Expr; use crate::data::program::{MagicAlgoApply, MagicSymbol}; +use crate::data::symb::Symbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; +use crate::parse::SourceSpan; use crate::runtime::db::Poison; use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; @@ -96,6 +100,15 @@ impl AlgoImpl for KShortestPathYen { } Ok(()) } + + fn arity( + &self, + _options: &BTreeMap, Expr>, + _rule_head: &[Symbol], + _span: SourceSpan, + ) -> Result { + Ok(4) + } } fn k_shortest_path_yen( diff --git a/src/data/program.rs b/src/data/program.rs index f95700fe..f7d4b985 100644 --- a/src/data/program.rs +++ b/src/data/program.rs @@ -2,13 +2,12 @@ use std::collections::btree_map::Entry; use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Debug, Display, Formatter}; -use either::{Left, Right}; use miette::{ensure, Diagnostic, Result}; use smallvec::SmallVec; use smartstring::{LazyCompact, SmartString}; use thiserror::Error; -use crate::algo::AlgoHandle; +use crate::algo::{AlgoHandle, AlgoImpl}; use crate::data::aggr::Aggregation; use crate::data::expr::Expr; use crate::data::relation::StoredRelationMetadata; @@ -193,18 +192,33 @@ impl InputRulesOrAlgo { } } -#[derive(Clone)] pub(crate) struct AlgoApply { pub(crate) algo: AlgoHandle, pub(crate) rule_args: Vec, pub(crate) options: BTreeMap, Expr>, pub(crate) head: Vec, + pub(crate) arity: usize, pub(crate) span: SourceSpan, + pub(crate) algo_impl: Box, +} + +impl Clone for AlgoApply { + fn clone(&self) -> Self { + Self { + algo: self.algo.clone(), + rule_args: self.rule_args.clone(), + options: self.options.clone(), + head: self.head.clone(), + arity: self.arity, + span: self.span.clone(), + algo_impl: self.algo.get_impl().unwrap(), + } + } } impl AlgoApply { pub(crate) fn arity(&self) -> Result { - self.algo.arity(Left(&self.rule_args), &self.options) + self.algo_impl.arity(&self.options, &self.head, self.span) } } @@ -224,6 +238,7 @@ pub(crate) struct MagicAlgoApply { pub(crate) rule_args: Vec, pub(crate) options: BTreeMap, Expr>, pub(crate) span: SourceSpan, + pub(crate) arity: usize, } #[derive(Error, Diagnostic, Debug)] @@ -249,9 +264,6 @@ pub(crate) struct WrongAlgoOptionError { } impl MagicAlgoApply { - pub(crate) fn arity(&self) -> Result { - self.algo.arity(Right(&self.rule_args), &self.options) - } pub(crate) fn relation_with_min_len( &self, idx: usize, @@ -879,7 +891,7 @@ impl MagicRulesOrAlgo { pub(crate) fn arity(&self) -> Result { Ok(match self { MagicRulesOrAlgo::Rules { rules } => rules.first().unwrap().head.len(), - MagicRulesOrAlgo::Algo { algo } => algo.arity()?, + MagicRulesOrAlgo::Algo { algo } => algo.arity, }) } pub(crate) fn mut_rules(&mut self) -> Option<&mut Vec> { diff --git a/src/parse/query.rs b/src/parse/query.rs index fe9ef6d5..60218d25 100644 --- a/src/parse/query.rs +++ b/src/parse/query.rs @@ -820,7 +820,6 @@ fn parse_algo_rule( } let algo = AlgoHandle::new(algo_name, name_pair.extract_span()); - let algo_arity = algo.arity(Left(&rule_args), &options)?; #[derive(Debug, Error, Diagnostic)] #[error("Algorithm rule head arity mismatch")] @@ -828,9 +827,12 @@ fn parse_algo_rule( #[diagnostic(help("Expected arity: {0}, number of arguments given: {1}"))] struct AlgoRuleHeadArityMismatch(usize, usize, #[label] SourceSpan); + let algo_impl = algo.get_impl()?; + let arity = algo_impl.arity(&options, &head, args_list_span)?; + ensure!( - head.is_empty() || algo_arity == head.len(), - AlgoRuleHeadArityMismatch(algo_arity, head.len(), args_list_span) + head.is_empty() || arity == head.len(), + AlgoRuleHeadArityMismatch(arity, head.len(), args_list_span) ); Ok(( @@ -840,7 +842,9 @@ fn parse_algo_rule( rule_args, options, head, + arity, span: args_list_span, + algo_impl, }, )) } diff --git a/src/query/magic.rs b/src/query/magic.rs index 2b61ab54..5d5a14d1 100644 --- a/src/query/magic.rs +++ b/src/query/magic.rs @@ -374,6 +374,7 @@ impl NormalFormProgram { }) .try_collect()?, options: algo_apply.options.clone(), + arity: algo_apply.arity }, }, );