diff --git a/src/algo/bfs.rs b/src/algo/bfs.rs
index 48431a74..d0ef7173 100644
--- a/src/algo/bfs.rs
+++ b/src/algo/bfs.rs
@@ -14,16 +14,8 @@ use crate::runtime::transact::SessionTx;
 pub(crate) struct Bfs;
 
 impl AlgoImpl for Bfs {
-    fn name(&self) -> Symbol {
-        Symbol::from("bfs")
-    }
-
-    fn arity(&self) -> usize {
-        1
-    }
-
     fn run(
-        &self,
+        &mut self,
         tx: &mut SessionTx,
         rels: &[MagicAlgoRuleArg],
         opts: &BTreeMap<Symbol, Expr>,
diff --git a/src/algo/connected_components.rs b/src/algo/connected_components.rs
new file mode 100644
index 00000000..3ac7b34e
--- /dev/null
+++ b/src/algo/connected_components.rs
@@ -0,0 +1,172 @@
+use std::collections::BTreeMap;
+
+use anyhow::{anyhow, bail, Result};
+
+use crate::algo::AlgoImpl;
+use crate::data::expr::Expr;
+use crate::data::program::{MagicAlgoRuleArg, MagicSymbol};
+use crate::data::symb::Symbol;
+use crate::data::tuple::Tuple;
+use crate::data::value::DataValue;
+use crate::runtime::derived::DerivedRelStore;
+use crate::runtime::transact::SessionTx;
+
+/// Connected components of an undirected graph, tracked in a union-find forest.
+///
+/// `union_find` maps every node seen so far to its parent; a node that maps to
+/// itself is the root (representative) of its component.
+#[derive(Default)]
+pub(crate) struct ConnectedComponents {
+    union_find: BTreeMap<DataValue, DataValue>,
+}
+
+impl ConnectedComponents {
+    /// Follows parent links from `node` up to the root of its tree without mutating.
+    ///
+    /// A node that is not in the forest is returned unchanged. The walk stops at the
+    /// first self-parented entry: roots map to themselves, so stopping only on a
+    /// missing key would never terminate.
+    fn root_of<'a>(&'a self, node: &'a DataValue) -> &'a DataValue {
+        let mut root = node;
+        while let Some(parent) = self.union_find.get(root) {
+            if parent == root {
+                break;
+            }
+            root = parent;
+        }
+        root
+    }
+
+    /// Returns the root of the component containing `data`.
+    ///
+    /// An unseen `data` is registered as a new single-node component and is its own
+    /// root. Otherwise every node on the path from `data` to the root is re-pointed
+    /// directly at the root (path compression).
+    pub(crate) fn find(&mut self, data: &DataValue) -> DataValue {
+        if !self.union_find.contains_key(data) {
+            self.union_find.insert(data.clone(), data.clone());
+            return data.clone();
+        }
+
+        let root = self.root_of(data).clone();
+
+        let mut current = data.clone();
+        while current != root {
+            let parent = self
+                .union_find
+                .get_mut(&current)
+                .expect("every node on a path to the root is in the forest");
+            // Re-point `current` at the root and continue from its old parent.
+            current = std::mem::replace(parent, root.clone());
+        }
+
+        root
+    }
+}
+
+impl AlgoImpl for ConnectedComponents {
+    /// Writes one tuple per node to `out`, pairing the node with its component id.
+    ///
+    /// * `rels[0]`: required edges relation; the first two columns of each tuple are
+    ///   the endpoints of an edge, and direction is ignored.
+    /// * `rels[1]`: optional nodes relation; a node (first column) that appears in no
+    ///   edge gets a component of its own.
+    /// * option `mode`: `"key_first"` (the default) emits `(node, id)`, while
+    ///   `"group_first"` emits `(id, node)`.
+    ///
+    /// Component ids are consecutive integers starting at 0.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the edges relation is missing, a tuple is too short, `mode` is
+    /// anything other than the two strings above, or reading a relation fails.
+    fn run(
+        &mut self,
+        tx: &mut SessionTx,
+        rels: &[MagicAlgoRuleArg],
+        opts: &BTreeMap<Symbol, Expr>,
+        stores: &BTreeMap<MagicSymbol, DerivedRelStore>,
+        out: &DerivedRelStore,
+    ) -> Result<()> {
+        let edges = rels
+            .first()
+            .ok_or_else(|| anyhow!("'connected_components' missing edges relation"))?;
+
+        let reverse_mode = match opts.get(&Symbol::from("mode")) {
+            None => false,
+            Some(Expr::Const(DataValue::String(s))) => match s as &str {
+                "group_first" => true,
+                "key_first" => false,
+                v => bail!("unexpected option 'mode' for 'connected_components': {}", v),
+            },
+            Some(v) => bail!(
+                "unexpected option 'mode' for 'connected_components': {:?}",
+                v
+            ),
+        };
+
+        for tuple in edges.iter(tx, stores)? {
+            let mut tuple = tuple?.0.into_iter();
+            let from = tuple
+                .next()
+                .ok_or_else(|| anyhow!("edges relation for 'connected_components' too short"))?;
+            let to = tuple
+                .next()
+                .ok_or_else(|| anyhow!("edges relation for 'connected_components' too short"))?;
+
+            let to_root = self.find(&to);
+            let from_root = self.find(&from);
+            if to_root != from_root {
+                // Union: hang the root of `from`'s tree under the root of `to`'s tree.
+                self.union_find.insert(from_root, to_root);
+            }
+        }
+
+        let mut counter = 0i64;
+        let mut group_ids: BTreeMap<&DataValue, i64> = BTreeMap::new();
+
+        for node in self.union_find.keys() {
+            // Number components by root: a stored parent can be stale, because a union
+            // only re-points the root of the absorbed tree.
+            let root = self.root_of(node);
+            let grp_id = match group_ids.get(root) {
+                Some(id) => *id,
+                None => {
+                    let id = counter;
+                    group_ids.insert(root, id);
+                    counter += 1;
+                    id
+                }
+            };
+            let tuple = if reverse_mode {
+                Tuple(vec![DataValue::from(grp_id), node.clone()])
+            } else {
+                Tuple(vec![node.clone(), DataValue::from(grp_id)])
+            };
+            out.put(tuple, 0);
+        }
+
+        if let Some(nodes) = rels.get(1) {
+            for tuple in nodes.iter(tx, stores)? {
+                let tuple = tuple?;
+                let node = tuple.0.into_iter().next().ok_or_else(|| {
+                    anyhow!("nodes relation for 'connected_components' too short")
+                })?;
+                // Skip nodes that were already emitted: members of an edge component,
+                // or an isolated node listed more than once.
+                if self.union_find.contains_key(&node) {
+                    continue;
+                }
+                self.union_find.insert(node.clone(), node.clone());
+                let tuple = if reverse_mode {
+                    Tuple(vec![DataValue::from(counter), node])
+                } else {
+                    Tuple(vec![node, DataValue::from(counter)])
+                };
+                out.put(tuple, 0);
+                counter += 1;
+            }
+        }
+        Ok(())
+    }
+}
diff --git a/src/algo/degree_centrality.rs b/src/algo/degree_centrality.rs
index e6d53c95..9edd5edd 100644
--- a/src/algo/degree_centrality.rs
+++ b/src/algo/degree_centrality.rs
@@ -14,16 +14,8 @@ use crate::runtime::transact::SessionTx;
 pub(crate) struct DegreeCentrality;
 
 impl AlgoImpl for DegreeCentrality {
-    fn name(&self) -> Symbol {
-        Symbol::from("degree_centrality")
-    }
-
-    fn arity(&self) -> usize {
-        4
-    }
-
     fn run(
-        &self,
+        &mut self,
         tx: &mut SessionTx,
         rels: &[MagicAlgoRuleArg],
         _opts: &BTreeMap<Symbol, Expr>,
diff --git a/src/algo/dfs.rs b/src/algo/dfs.rs
index 2e064237..8f1f6edd 100644
--- a/src/algo/dfs.rs
+++ b/src/algo/dfs.rs
@@ -14,16 +14,8 @@ use crate::runtime::transact::SessionTx;
 pub(crate) struct Dfs;
 
 impl AlgoImpl for Dfs {
-    fn name(&self) -> Symbol {
-        Symbol::from("dfs")
-    }
-
-    fn arity(&self) -> usize {
-        1
-    }
-
     fn run(
-        &self,
+        &mut self,
         tx: &mut SessionTx,
         rels: &[MagicAlgoRuleArg],
         opts: &BTreeMap<Symbol, Expr>,
diff --git a/src/algo/mod.rs b/src/algo/mod.rs
index ce37ab94..bc3e7cd0 100644
--- a/src/algo/mod.rs
+++ b/src/algo/mod.rs
@@ -1,5 +1,4 @@
 use std::collections::BTreeMap;
-use std::sync::Arc;
 
 use anyhow::{anyhow, bail, ensure, Result};
 use itertools::Itertools;
@@ -7,6 +6,7 @@ use itertools::Itertools;
 use crate::algo::bfs::Bfs;
 use crate::algo::degree_centrality::DegreeCentrality;
 use crate::algo::dfs::Dfs;
+use crate::algo::top_sort::TopSort;
 use crate::data::expr::Expr;
 use crate::data::id::{EntityId, Validity};
 use crate::data::program::{MagicAlgoRuleArg, MagicSymbol, TripleDir};
@@ -17,15 +17,15 @@ use crate::runtime::derived::DerivedRelStore;
 use
crate::runtime::transact::SessionTx;
 
 pub(crate) mod bfs;
-mod degree_centrality;
+pub(crate) mod connected_components;
+pub(crate) mod degree_centrality;
 pub(crate) mod dfs;
 pub(crate) mod page_rank;
+pub(crate) mod top_sort;
 
 pub(crate) trait AlgoImpl {
-    fn name(&self) -> Symbol;
-    fn arity(&self) -> usize;
     fn run(
-        &self,
+        &mut self,
         tx: &mut SessionTx,
         rels: &[MagicAlgoRuleArg],
         opts: &BTreeMap<Symbol, Expr>,
@@ -34,14 +34,62 @@ pub(crate) trait AlgoImpl {
     ) -> Result<()>;
 }
 
-pub(crate) fn get_algo(name: &str) -> Result<Arc<dyn AlgoImpl>> {
-    Ok(match name {
-        "degree_centrality" => Arc::new(DegreeCentrality),
-        "dfs" => Arc::new(Dfs),
-        "bfs" => Arc::new(Bfs),
-        "page_rank" => todo!(),
-        name => bail!("algorithm '{}' not found", name),
-    })
+/// A clonable reference to a built-in algorithm, identified by name.
+///
+/// The name is not checked on construction; [`AlgoHandle::arity`] and
+/// [`AlgoHandle::get_impl`] report an unknown name as an error.
+#[derive(Clone, Debug)]
+pub(crate) struct AlgoHandle {
+    pub(crate) name: Symbol,
+}
+
+impl AlgoHandle {
+    /// Wraps `name` without validating it.
+    pub(crate) fn new(name: &str) -> Self {
+        AlgoHandle {
+            name: Symbol::from(name),
+        }
+    }
+
+    /// Returns the arity of the relation the algorithm produces.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the algorithm is unknown, or is known but not implemented yet.
+    pub(crate) fn arity(&self) -> Result<usize> {
+        let name = &self.name.0 as &str;
+        Ok(match name {
+            "degree_centrality" => 4,
+            "depth_first_search" | "dfs" => 1,
+            "breadth_first_search" | "bfs" => 1,
+            "top_sort" => 2,
+            "connected_components" => 2,
+            "strongly_connected_components" | "scc" => 2,
+            // The name comes from the user's query, so report an error, don't panic.
+            "page_rank" => bail!("algorithm '{}' is not implemented yet", name),
+            name => bail!("algorithm '{}' not found", name),
+        })
+    }
+
+    /// Creates a fresh instance of the algorithm to run.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the algorithm is unknown, or is known but not implemented yet.
+    pub(crate) fn get_impl(&self) -> Result<Box<dyn AlgoImpl>> {
+        let name = &self.name.0 as &str;
+        Ok(match name {
+            "degree_centrality" => Box::new(DegreeCentrality),
+            "depth_first_search" | "dfs" => Box::new(Dfs),
+            "breadth_first_search" | "bfs" => Box::new(Bfs),
+            "top_sort" => Box::new(TopSort),
+            // The name comes from the user's query, so report an error, don't panic.
+            "connected_components" | "strongly_connected_components" | "scc" | "page_rank" => {
+                bail!("algorithm '{}' is not implemented yet", name)
+            }
+            name => bail!("algorithm '{}' not found", name),
+        })
+    }
 }
 
 impl MagicAlgoRuleArg {
diff --git a/src/algo/top_sort.rs b/src/algo/top_sort.rs
new file mode 100644
index 00000000..0aed09ae
--- /dev/null
+++ b/src/algo/top_sort.rs
@@ -0,0 +1,25 @@
+use std::collections::BTreeMap;
+
+use anyhow::Result;
+
+use
crate::algo::AlgoImpl;
+use crate::data::expr::Expr;
+use crate::data::program::{MagicAlgoRuleArg, MagicSymbol};
+use crate::data::symb::Symbol;
+use crate::runtime::derived::DerivedRelStore;
+use crate::runtime::transact::SessionTx;
+
+pub(crate) struct TopSort; // Placeholder for `top_sort`: `run` is not implemented yet.
+
+impl AlgoImpl for TopSort {
+    fn run(
+        &mut self,
+        _tx: &mut SessionTx,
+        _rels: &[MagicAlgoRuleArg],
+        _opts: &BTreeMap<Symbol, Expr>,
+        _stores: &BTreeMap<MagicSymbol, DerivedRelStore>,
+        _out: &DerivedRelStore,
+    ) -> Result<()> {
+        todo!() // Stub: panics when called; none of the arguments are read.
+    }
+}
diff --git a/src/cozoscript.pest b/src/cozoscript.pest
index 55ac3e26..5e785a8d 100644
--- a/src/cozoscript.pest
+++ b/src/cozoscript.pest
@@ -26,8 +26,8 @@ view_ident = @{":" ~ ident}
 compound_ident = {ident ~ ("." ~ ident)?}
 
 rule = {rule_head ~ ("@" ~ expr)? ~ ":=" ~ rule_body ~ ";"}
-const_rule = {ident ~ "<-" ~ expr ~ ";" }
-algo_rule = {ident ~ "<-" ~ algo_ident ~ "(" ~ (algo_arg ~ ",")* ~ algo_arg? ~ ")" ~ ";" }
+const_rule = {(prog_entry | ident) ~ "<-" ~ expr ~ ";" }
+algo_rule = {(prog_entry | ident) ~ "<-" ~ algo_ident ~ "(" ~ (algo_arg ~ ",")* ~ algo_arg? ~ ")" ~ ";" }
 rule_head = {(prog_entry | ident) ~ "[" ~ (head_arg ~ ",")* ~ head_arg?
~ "]"} head_arg = {aggr_arg | var} diff --git a/src/data/program.rs b/src/data/program.rs index 6faf001f..961f5ffa 100644 --- a/src/data/program.rs +++ b/src/data/program.rs @@ -1,13 +1,12 @@ use std::collections::btree_map::Entry; use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Debug, Formatter}; -use std::sync::Arc; use anyhow::{anyhow, bail, ensure, Result}; use itertools::Itertools; use smallvec::SmallVec; -use crate::algo::AlgoImpl; +use crate::algo::AlgoHandle; use crate::data::aggr::Aggregation; use crate::data::attr::Attribute; use crate::data::expr::Expr; @@ -35,7 +34,7 @@ pub(crate) enum InputRulesOrAlgo { #[derive(Clone)] pub(crate) struct AlgoApply { - pub(crate) algo: Arc, + pub(crate) algo: AlgoHandle, pub(crate) rule_args: Vec, pub(crate) options: BTreeMap, } @@ -43,7 +42,7 @@ pub(crate) struct AlgoApply { impl Debug for AlgoApply { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("AlgoApply") - .field("algo", &self.algo.name()) + .field("algo", &self.algo.name) .field("rules", &self.rule_args) .field("options", &self.options) .finish() @@ -52,7 +51,7 @@ impl Debug for AlgoApply { #[derive(Clone)] pub(crate) struct MagicAlgoApply { - pub(crate) algo: Arc, + pub(crate) algo: AlgoHandle, pub(crate) rule_args: Vec, pub(crate) options: BTreeMap, } @@ -60,7 +59,7 @@ pub(crate) struct MagicAlgoApply { impl Debug for MagicAlgoApply { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("AlgoApply") - .field("algo", &self.algo.name()) + .field("algo", &self.algo.name) .field("rules", &self.rule_args) .field("options", &self.options) .finish() @@ -143,7 +142,7 @@ impl InputProgram { .ok_or_else(|| anyhow!("program entry point not found"))? 
{ InputRulesOrAlgo::Rules(rules) => rules[0].head.len(), - InputRulesOrAlgo::Algo(algo_apply) => algo_apply.algo.arity(), + InputRulesOrAlgo::Algo(algo_apply) => algo_apply.algo.arity()?, }, ) } @@ -258,11 +257,11 @@ impl Default for MagicRulesOrAlgo { } impl MagicRulesOrAlgo { - pub(crate) fn arity(&self) -> usize { - match self { + pub(crate) fn arity(&self) -> Result { + Ok(match self { MagicRulesOrAlgo::Rules(r) => r.first().unwrap().head.len(), - MagicRulesOrAlgo::Algo(algo) => algo.algo.arity(), - } + MagicRulesOrAlgo::Algo(algo) => algo.algo.arity()?, + }) } pub(crate) fn mut_rules(&mut self) -> Option<&mut Vec> { match self { diff --git a/src/parse/query.rs b/src/parse/query.rs index d8bafa89..0285e74d 100644 --- a/src/parse/query.rs +++ b/src/parse/query.rs @@ -5,7 +5,7 @@ use anyhow::{anyhow, bail, ensure, Result}; use itertools::Itertools; use serde_json::{json, Map}; -use crate::algo::get_algo; +use crate::algo::AlgoHandle; use crate::data::aggr::get_aggr; use crate::data::attr::Attribute; use crate::data::expr::{get_op, Expr, OP_LIST}; @@ -99,8 +99,11 @@ impl SessionTx { let rules_payload = q .as_array() .ok_or_else(|| anyhow!("expect field 'q' to be an array in query {}", payload))?; - ensure!(!rules_payload.is_empty(), "no rules in {}", payload); - let mut input_prog = if rules_payload.first().unwrap().is_array() { + let mut input_prog = if rules_payload.is_empty() { + InputProgram { + prog: Default::default(), + } + } else if rules_payload.first().unwrap().is_array() { let q = json!([{"rule": "?", "args": rules_payload}]); self.parse_input_rule_sets(&q, vld, params_pool)? 
} else { @@ -196,11 +199,13 @@ impl SessionTx { } } let out_name = Symbol::from(out_symbol); - out_name.validate_not_reserved()?; + if !out_name.is_prog_entry() { + out_name.validate_not_reserved()?; + } match input_prog.prog.entry(out_name) { Entry::Vacant(v) => { v.insert(InputRulesOrAlgo::Algo(AlgoApply { - algo: get_algo(name_symbol)?, + algo: AlgoHandle::new(name_symbol), rule_args: relations, options, })); diff --git a/src/query/compile.rs b/src/query/compile.rs index 9078aa70..bbde1239 100644 --- a/src/query/compile.rs +++ b/src/query/compile.rs @@ -92,7 +92,7 @@ impl SessionTx { ); stores.insert( name.clone(), - self.new_rule_store(name.clone(), ruleset.arity()), + self.new_rule_store(name.clone(), ruleset.arity()?), ); } } diff --git a/src/query/eval.rs b/src/query/eval.rs index a09622cd..f6d064f5 100644 --- a/src/query/eval.rs +++ b/src/query/eval.rs @@ -137,7 +137,7 @@ impl SessionTx { algo_apply: &MagicAlgoApply, stores: &BTreeMap, ) -> Result<()> { - let algo_impl = &algo_apply.algo; + let mut algo_impl = algo_apply.algo.get_impl()?; let out = stores .get(rule_symb) .ok_or_else(|| anyhow!("cannot find algo store {:?}", rule_symb))?; diff --git a/tests/air_routes.rs b/tests/air_routes.rs index c084c9e8..0dbedf28 100644 --- a/tests/air_routes.rs +++ b/tests/air_routes.rs @@ -100,8 +100,7 @@ fn air_routes() -> Result<()> { let dfs_time = Instant::now(); let res = db.run_script(r#" starting <- [['PEK']]; - res <- dfs!(:flies_to_code[], [?id Result<()> { let bfs_time = Instant::now(); let res = db.run_script(r#" starting <- [['PEK']]; - res <- bfs!(:flies_to_code[], [?id