From 43d46ecb5dc57ae94f8c28aa839b61866598e4e2 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Tue, 9 Aug 2022 20:59:28 +0800 Subject: [PATCH] group count; map type; fix normal aggr; --- src/data/aggr.rs | 40 ++++++++++++++++++++++++++++- src/data/json.rs | 3 +++ src/data/value.rs | 21 ++++++++++++--- src/runtime/temp_store.rs | 22 ++++------------ tests/air_routes.rs | 54 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 119 insertions(+), 21 deletions(-) diff --git a/src/data/aggr.rs b/src/data/aggr.rs index 8a05e850..c265d530 100644 --- a/src/data/aggr.rs +++ b/src/data/aggr.rs @@ -1,8 +1,9 @@ -use std::collections::BTreeSet; +use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Debug, Formatter}; use std::ops::Sub; use anyhow::{anyhow, bail, ensure, Result}; +use itertools::Itertools; use crate::data::value::{DataValue, Number}; @@ -52,6 +53,42 @@ fn aggr_unique(accum: &mut DataValue, current: &DataValue, _args: &[DataValue]) }) } +define_aggr!(AGGR_GROUP_COUNT, false); +fn aggr_group_count( + accum: &mut DataValue, + current: &DataValue, + _args: &[DataValue], +) -> Result<bool> { + dbg!(&current); + Ok(match (accum, current) { + (accum @ DataValue::Guard, DataValue::Guard) => { + *accum = DataValue::List(vec![]); + true + } + (accum @ DataValue::Guard, val) => { + *accum = DataValue::Map(BTreeMap::from([(val.clone(), DataValue::from(1))])); + true + } + (accum, DataValue::Guard) => { + *accum = DataValue::List( + accum + .get_map() + .unwrap() + .iter() + .map(|(k, v)| DataValue::List(vec![k.clone(), v.clone()])) + .collect_vec(), + ); + true + } + (DataValue::Map(l), val) => { + let entry = l.entry(val.clone()).or_insert_with(|| DataValue::from(0)); + *entry = DataValue::from(entry.get_int().unwrap() + 1); + true + } + _ => unreachable!(), + }) +} + define_aggr!(AGGR_COUNT_UNIQUE, false); fn aggr_count_unique( accum: &mut DataValue, @@ -348,6 +385,7 @@ fn aggr_choice(accum: &mut DataValue, current: &DataValue, _args: &[DataValue]) pub(crate) fn get_aggr(name: 
&str) -> Option<&'static Aggregation> { Some(match name { "count" => &AGGR_COUNT, + "group_count" => &AGGR_GROUP_COUNT, "count_unique" => &AGGR_COUNT_UNIQUE, "sum" => &AGGR_SUM, "min" => &AGGR_MIN, diff --git a/src/data/json.rs b/src/data/json.rs index a5a5c449..811e6386 100644 --- a/src/data/json.rs +++ b/src/data/json.rs @@ -80,6 +80,9 @@ impl From<DataValue> for JsonValue { DataValue::Regex(r) => { json!(r.0.as_str()) } + DataValue::Map(m) => { + JsonValue::Array(m.into_iter().map(|(k, v)| json!([k, v])).collect()) + } } } } diff --git a/src/data/value.rs b/src/data/value.rs index c6dfdf5d..4253d3b0 100644 --- a/src/data/value.rs +++ b/src/data/value.rs @@ -1,5 +1,5 @@ use std::cmp::{Ordering, Reverse}; -use std::collections::BTreeSet; +use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Debug, Display, Formatter}; use anyhow::{bail, Result}; @@ -19,13 +19,19 @@ use crate::data::triple::StoreOp; pub(crate) struct RegexWrapper(pub(crate) Regex); impl Serialize for RegexWrapper { - fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error> where S: serde::Serializer { + fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error> + where + S: serde::Serializer, + { panic!("serializing regex"); } } impl<'de> Deserialize<'de> for RegexWrapper { - fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error> where D: Deserializer<'de> { + fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error> + where + D: Deserializer<'de>, + { panic!("deserializing regex"); } } @@ -72,6 +78,8 @@ pub(crate) enum DataValue { List(Vec<DataValue>), #[serde(rename = "y")] Set(BTreeSet<DataValue>), + #[serde(rename = "w")] + Map(BTreeMap<DataValue, DataValue>), #[serde(rename = "g")] Guard, #[serde(rename = "o")] @@ -201,6 +209,7 @@ impl Debug for DataValue { } DataValue::List(t) => f.debug_list().entries(t.iter()).finish(), DataValue::Set(t) => f.debug_list().entries(t.iter()).finish(), + DataValue::Map(m) => f.debug_map().entries(m.iter()).finish(), DataValue::DescVal(v) => { write!(f, "desc<{:?}>", v) } @@ -247,6 +256,12 @@ impl DataValue 
{ _ => None, } } + pub(crate) fn get_map(&self) -> Option<&BTreeMap<DataValue, DataValue>> { + match self { + DataValue::Map(m) => Some(m), + _ => None + } + } pub(crate) fn get_int(&self) -> Option<i64> { match self { DataValue::Number(n) => n.get_int(), diff --git a/src/runtime/temp_store.rs b/src/runtime/temp_store.rs index 159c770e..ee7e94b6 100644 --- a/src/runtime/temp_store.rs +++ b/src/runtime/temp_store.rs @@ -153,35 +153,23 @@ impl TempStore { it.seek(&lower); let it = TempStoreIter { it, started: false }; let aggrs = aggrs.to_vec(); - let key_indices = aggrs - .iter() - .enumerate() - .filter_map(|(i, aggr)| if aggr.is_none() { Some(i) } else { None }) - .collect_vec(); + let n_keys = aggrs.iter().filter(|aggr| aggr.is_none()).count(); let grouped = it.group_by(move |t_res| { if let Ok(tuple) = t_res { - Some( - key_indices - .iter() - .map(|i| tuple.0[*i].clone()) - .collect_vec(), - ) + Some(tuple.0[..n_keys].to_vec()) } else { None } }); let mut invert_indices = vec![]; - let mut idx = 0; - for aggr in aggrs.iter() { + for (idx, aggr) in aggrs.iter().enumerate() { if aggr.is_none() { invert_indices.push(idx); - idx += 1; } } - for aggr in aggrs.iter() { + for (idx, aggr) in aggrs.iter().enumerate() { if aggr.is_some() { invert_indices.push(idx); - idx += 1; } } let invert_indices = invert_indices @@ -200,7 +188,7 @@ impl TempStore { if let Some((aggr_op, aggr_args)) = aggr { (aggr_op.combine)(&mut aggr_res[idx], val, aggr_args)?; } else { - aggr_res[idx] = first_tuple.0[idx].clone(); + aggr_res[idx] = first_tuple.0[invert_indices[idx]].clone(); } } for tuple in it { diff --git a/tests/air_routes.rs b/tests/air_routes.rs index 4b825284..4129d4ea 100644 --- a/tests/air_routes.rs +++ b/tests/air_routes.rs @@ -155,6 +155,29 @@ fn air_routes() -> Result<()> { .unwrap() ); + let most_out_routes_time_inv = Instant::now(); + let res = db.run_script( + r#" + route_count[count(?r), ?a, ?x] := [?r route.src ?a], ?x is 1; + ?[?code, ?n] := route_count[?n, ?a, ?_], ?n > 180, [?a 
airport.iata ?code]; + :sort -?n; + "#, + )?; + dbg!(most_out_routes_time_inv.elapsed()); + assert_eq!( + res, + serde_json::Value::from_str( + r#"[ + ["FRA",307],["IST",307],["CDG",293],["AMS",282],["MUC",270],["ORD",264],["DFW",251], + ["PEK",248],["DXB",247],["ATL",242],["DME",232],["LGW",232],["LHR",221],["DEN",216], + ["MAN",216],["LAX",213],["PVG",212],["STN",211],["MAD",206],["VIE",206],["BCN",203], + ["BER",202],["FCO",201],["JFK",201],["DUS",199],["IAH",199],["EWR",197],["MIA",195], + ["YYZ",195],["BRU",194],["CPH",194],["DOH",186],["DUB",185],["CLT",184],["SVO",181] + ]"# + ) + .unwrap() + ); + let most_routes_time = Instant::now(); let res = db.run_script( r#" @@ -482,5 +505,36 @@ fn air_routes() -> Result<()> { .unwrap() ); + let len_of_names_count_time = Instant::now(); + let res = db.run_script( + r#" + ?[sum(?n)] := [?a airport.iata 'AUS'], + [?r route.src ?a], + [?r route.dst ?a2], + [?a2 airport.city ?city_name], + ?n is length(?city_name); + "#, + )?; + dbg!(len_of_names_count_time.elapsed()); + assert_eq!(res, json!([[866]])); + + let group_count_by_out_time = Instant::now(); + let res = db.run_script( + r#" + route_count[count(?r), ?a] := [?r route.src ?a]; + ?[?n, count(?a)] := route_count[?n, ?a]; + :order -?a; + :limit 10; + "#, + )?; + dbg!(group_count_by_out_time.elapsed()); + assert_eq!( + res, + serde_json::Value::from_str( + r#"[[1,777],[2,649],[3,359],[4,232],[5,150],[6,139],[7,100],[8,74],[9,63],[10,59]]"# + ) + .unwrap() + ); + Ok(()) }