From 582de562e989f6c89b3b936422a1b1d750cada17 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Sun, 28 Aug 2022 18:22:58 +0800 Subject: [PATCH] dijkstra's algorithm --- README.md | 4 ++-- src/algo/mod.rs | 13 ++++++++---- src/algo/shortest_path_dijkstra.rs | 33 ++++++++++++++++++++++-------- src/data/symb.rs | 3 --- src/query/reorder.rs | 32 +++++------------------------ tests/air_routes.rs | 13 ++++++++++++ 6 files changed, 53 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index 33d3a241..b22d3942 100644 --- a/README.md +++ b/README.md @@ -5,11 +5,11 @@ * [ ] graph algorithms * [x] bfs * [x] dfs - * [ ] shortest path, SSSP + * [x] shortest path * [ ] A* * [ ] Yen's k-shortest * [ ] all-pairs shortest path - * [ ] single-source shortest path + * [x] single-source shortest path * [ ] minimum spanning tree * [ ] random walking * [x] degree centrality diff --git a/src/algo/mod.rs b/src/algo/mod.rs index fd3d8ac8..95f07ce0 100644 --- a/src/algo/mod.rs +++ b/src/algo/mod.rs @@ -83,17 +83,19 @@ impl MagicAlgoRuleArg { pub(crate) fn convert_edge_to_weighted_graph( &self, undirected: bool, - allow_negative_edge: bool, + allow_negative_edges: bool, tx: &SessionTx, stores: &BTreeMap, ) -> Result<( Vec>, Vec, BTreeMap, + bool )> { let mut graph: Vec> = vec![]; let mut indices: Vec = vec![]; let mut inv_indices: BTreeMap = Default::default(); + let mut has_neg_edge = false; for tuple in self.iter(tx, stores)? { let mut tuple = tuple?.0.into_iter(); @@ -108,8 +110,11 @@ impl MagicAlgoRuleArg { Some(d) => match d.get_float() { Some(f) => { ensure!(f.is_finite(), "edge weight must be finite, got {}", f); - if !allow_negative_edge { - ensure!(f >= 0., "edge weight must be non-negative, got {}", f); + if f < 0. { + if !allow_negative_edges { + bail!("edge weight must be non-negative, got {}", f); + } + has_neg_edge = true; } f } @@ -139,7 +144,7 @@ impl MagicAlgoRuleArg { to_target.push((from_idx, weight)); } } - Ok((graph, indices, inv_indices)) + Ok((graph, indices, inv_indices, has_neg_edge)) } pub(crate) fn convert_edge_to_graph( &self, diff --git a/src/algo/shortest_path_dijkstra.rs b/src/algo/shortest_path_dijkstra.rs index 8573b097..5462404b 100644 --- a/src/algo/shortest_path_dijkstra.rs +++ b/src/algo/shortest_path_dijkstra.rs @@ -32,11 +32,6 @@ impl AlgoImpl for ShortestPathDijkstra { anyhow!("'shortest_path_dijkstra' requires starting relation as second argument") })?; let termination = rels.get(2); - let allow_negative_edges = match opts.get(&Symbol::from("allow_negative_edges")) { - None => false, - Some(Expr::Const(DataValue::Bool(b))) => *b, - Some(v) => bail!("option 'allow_negative_edges' for 'shortest_path_dijkstra' requires a boolean, got {:?}", v) - }; let undirected = match opts.get(&Symbol::from("undirected")) { None => false, Some(Expr::Const(DataValue::Bool(b))) => *b, @@ -46,8 +41,8 @@ impl AlgoImpl for ShortestPathDijkstra { ), }; - let (graph, indices, inv_indices) = - edges.convert_edge_to_weighted_graph(undirected, allow_negative_edges, tx, stores)?; + let (graph, indices, inv_indices, _) = + edges.convert_edge_to_weighted_graph(undirected, false, tx, stores)?; let mut starting_nodes = BTreeSet::new(); for tuple in starting.iter(tx, stores)? { @@ -79,7 +74,7 @@ impl AlgoImpl for ShortestPathDijkstra { }; for start in starting_nodes { - let res = dijkstra(&graph, start, &termination_nodes); + let res = dijkstra(&graph, start, &termination_nodes, ()); for (target, cost, path) in res { let t = vec![ indices[start].clone(), @@ -118,10 +113,27 @@ impl Ord for HeapState { impl Eq for HeapState {} -fn dijkstra( +trait ForbiddenEdge { + fn is_forbidden(&self, src: usize, dst: usize) -> bool; +} + +impl ForbiddenEdge for () { + fn is_forbidden(&self, _src: usize, _dst: usize) -> bool { + false + } +} + +impl ForbiddenEdge for BTreeSet<(usize, usize)> { + fn is_forbidden(&self, src: usize, dst: usize) -> bool { + self.contains(&(src, dst)) + } +} + +fn dijkstra( edges: &[Vec<(usize, f64)>], start: usize, maybe_goals: &Option>, + forbidden: T, ) -> Vec<(usize, f64, Vec)> { let mut distance = vec![f64::INFINITY; edges.len()]; let mut heap = BinaryHeap::new(); @@ -139,6 +151,9 @@ fn dijkstra( } for (nxt_node, path_weight) in &edges[state.node] { + if forbidden.is_forbidden(state.node, *nxt_node) { + continue; + } let nxt_cost = state.cost + *path_weight; if nxt_cost < distance[*nxt_node] { heap.push(HeapState { diff --git a/src/data/symb.rs b/src/data/symb.rs index 6e3b2769..7b740da6 100644 --- a/src/data/symb.rs +++ b/src/data/symb.rs @@ -43,9 +43,6 @@ impl Symbol { pub(crate) fn is_query_var(&self) -> bool { self.0.starts_with('?') && self.0.len() > 1 } - pub(crate) fn is_ignored_var(&self) -> bool { - self.0 == "?_" - } pub(crate) fn validate_not_reserved(&self) -> Result<()> { ensure!( !self.is_reserved(), diff --git a/src/query/reorder.rs b/src/query/reorder.rs index 8d6f3bc7..1acfc935 100644 --- a/src/query/reorder.rs +++ b/src/query/reorder.rs @@ -4,26 +4,16 @@ use std::mem; use anyhow::{bail, Result}; use crate::data::program::{NormalFormAtom, NormalFormRule}; -use crate::data::symb::Symbol; impl NormalFormRule { pub(crate) fn convert_to_well_ordered_rule(self) -> Result { let mut seen_variables = BTreeSet::default(); let mut round_1_collected = vec![]; let mut pending = vec![]; - let mut symb_count = 0; - let mut process_ignored_symbol = |symb: &mut Symbol| { - if symb.is_ignored_var() { - symb_count += 1; - let mut new_symb = Symbol::from(&format!("_{}", symb_count) as &str); - mem::swap(&mut new_symb, symb); - } - }; for atom in self.body { match atom { - NormalFormAtom::Unification(mut u) => { - process_ignored_symbol(&mut u.binding); + NormalFormAtom::Unification(u) => { if u.is_const() { seen_variables.insert(u.binding.clone()); round_1_collected.push(NormalFormAtom::Unification(u)); @@ -37,42 +27,30 @@ impl NormalFormRule { } } } - NormalFormAtom::AttrTriple(mut t) => { - process_ignored_symbol(&mut t.value); - process_ignored_symbol(&mut t.entity); + NormalFormAtom::AttrTriple(t) => { seen_variables.insert(t.value.clone()); seen_variables.insert(t.entity.clone()); round_1_collected.push(NormalFormAtom::AttrTriple(t)); } NormalFormAtom::Rule(mut r) => { for arg in &mut r.args { - process_ignored_symbol(arg); seen_variables.insert(arg.clone()); } round_1_collected.push(NormalFormAtom::Rule(r)) } NormalFormAtom::View(mut v) => { for arg in &mut v.args { - process_ignored_symbol(arg); seen_variables.insert(arg.clone()); } round_1_collected.push(NormalFormAtom::View(v)) } - NormalFormAtom::NegatedAttrTriple(mut t) => { - process_ignored_symbol(&mut t.value); - process_ignored_symbol(&mut t.entity); + NormalFormAtom::NegatedAttrTriple(t) => { pending.push(NormalFormAtom::NegatedAttrTriple(t)) } - NormalFormAtom::NegatedRule(mut r) => { - for arg in &mut r.args { - process_ignored_symbol(arg); - } + NormalFormAtom::NegatedRule(r) => { pending.push(NormalFormAtom::NegatedRule(r)) } - NormalFormAtom::NegatedView(mut v) => { - for arg in &mut v.args { - process_ignored_symbol(arg); - } + NormalFormAtom::NegatedView(v) => { pending.push(NormalFormAtom::NegatedView(v)) } NormalFormAtom::Predicate(p) => { diff --git a/tests/air_routes.rs b/tests/air_routes.rs index 2ac241a7..45ac6e79 100644 --- a/tests/air_routes.rs +++ b/tests/air_routes.rs @@ -173,6 +173,19 @@ fn air_routes() -> Result<()> { )? ); + let dijkstra_time = Instant::now(); + let res = db.run_script( + r#" + starting <- [['JFK']]; + ending <- [['KUL']]; + res <- shortest_path_dijkstra!(:flies_to_code[], starting[], ending[]); + ?[?path] := res[?src, ?dst, ?cost, ?path]; + "#, + )?; + + dbg!(dijkstra_time.elapsed()); + assert_eq!(*res.get("rows").unwrap(), json!([[["JFK", "CTU", "KUL"]]])); + let starts_with_time = Instant::now(); let res = db.run_script( r#"