group count; map type; fix normal aggr;

main
Ziyang Hu 2 years ago
parent b5d3545160
commit 43d46ecb5d

@ -1,8 +1,9 @@
use std::collections::BTreeSet;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Debug, Formatter};
use std::ops::Sub;
use anyhow::{anyhow, bail, ensure, Result};
use itertools::Itertools;
use crate::data::value::{DataValue, Number};
@ -52,6 +53,42 @@ fn aggr_unique(accum: &mut DataValue, current: &DataValue, _args: &[DataValue])
})
}
define_aggr!(AGGR_GROUP_COUNT, false);
fn aggr_group_count(
accum: &mut DataValue,
current: &DataValue,
_args: &[DataValue],
) -> Result<bool> {
dbg!(&current);
Ok(match (accum, current) {
(accum @ DataValue::Guard, DataValue::Guard) => {
*accum = DataValue::List(vec![]);
true
}
(accum @ DataValue::Guard, val) => {
*accum = DataValue::Map(BTreeMap::from([(val.clone(), DataValue::from(1))]));
true
}
(accum, DataValue::Guard) => {
*accum = DataValue::List(
accum
.get_map()
.unwrap()
.iter()
.map(|(k, v)| DataValue::List(vec![k.clone(), v.clone()]))
.collect_vec(),
);
true
}
(DataValue::Map(l), val) => {
let entry = l.entry(val.clone()).or_insert_with(|| DataValue::from(0));
*entry = DataValue::from(entry.get_int().unwrap() + 1);
true
}
_ => unreachable!(),
})
}
define_aggr!(AGGR_COUNT_UNIQUE, false);
fn aggr_count_unique(
accum: &mut DataValue,
@ -348,6 +385,7 @@ fn aggr_choice(accum: &mut DataValue, current: &DataValue, _args: &[DataValue])
pub(crate) fn get_aggr(name: &str) -> Option<&'static Aggregation> {
Some(match name {
"count" => &AGGR_COUNT,
"group_count" => &AGGR_GROUP_COUNT,
"count_unique" => &AGGR_COUNT_UNIQUE,
"sum" => &AGGR_SUM,
"min" => &AGGR_MIN,

@ -80,6 +80,9 @@ impl From<DataValue> for JsonValue {
DataValue::Regex(r) => {
json!(r.0.as_str())
}
DataValue::Map(m) => {
JsonValue::Array(m.into_iter().map(|(k, v)| json!([k, v])).collect())
}
}
}
}

@ -1,5 +1,5 @@
use std::cmp::{Ordering, Reverse};
use std::collections::BTreeSet;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Debug, Display, Formatter};
use anyhow::{bail, Result};
@ -19,13 +19,19 @@ use crate::data::triple::StoreOp;
pub(crate) struct RegexWrapper(pub(crate) Regex);
// Serializing a `RegexWrapper` is unsupported: `serialize` ignores the
// serializer it is handed and unconditionally panics.
impl Serialize for RegexWrapper {
fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error> where S: serde::Serializer {
fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
// Never returns normally; nothing is written to `_serializer`.
panic!("serializing regex");
}
}
// Deserializing a `RegexWrapper` is unsupported: `deserialize` ignores the
// deserializer it is handed and unconditionally panics.
impl<'de> Deserialize<'de> for RegexWrapper {
fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error> where D: Deserializer<'de> {
fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
// Never returns normally; nothing is read from `_deserializer`.
panic!("deserializing regex");
}
}
@ -72,6 +78,8 @@ pub(crate) enum DataValue {
List(Vec<DataValue>),
#[serde(rename = "y")]
Set(BTreeSet<DataValue>),
#[serde(rename = "w")]
Map(BTreeMap<DataValue, DataValue>),
#[serde(rename = "g")]
Guard,
#[serde(rename = "o")]
@ -201,6 +209,7 @@ impl Debug for DataValue {
}
DataValue::List(t) => f.debug_list().entries(t.iter()).finish(),
DataValue::Set(t) => f.debug_list().entries(t.iter()).finish(),
DataValue::Map(m) => f.debug_map().entries(m.iter()).finish(),
DataValue::DescVal(v) => {
write!(f, "desc<{:?}>", v)
}
@ -247,6 +256,12 @@ impl DataValue {
_ => None,
}
}
/// Borrows the underlying map when this value is the `Map` variant.
///
/// Returns `None` for every other variant.
pub(crate) fn get_map(&self) -> Option<&BTreeMap<DataValue, DataValue>> {
    if let DataValue::Map(entries) = self {
        Some(entries)
    } else {
        None
    }
}
pub(crate) fn get_int(&self) -> Option<i64> {
match self {
DataValue::Number(n) => n.get_int(),

@ -153,35 +153,23 @@ impl TempStore {
it.seek(&lower);
let it = TempStoreIter { it, started: false };
let aggrs = aggrs.to_vec();
let key_indices = aggrs
.iter()
.enumerate()
.filter_map(|(i, aggr)| if aggr.is_none() { Some(i) } else { None })
.collect_vec();
let n_keys = aggrs.iter().filter(|aggr| aggr.is_none()).count();
let grouped = it.group_by(move |t_res| {
if let Ok(tuple) = t_res {
Some(
key_indices
.iter()
.map(|i| tuple.0[*i].clone())
.collect_vec(),
)
Some(tuple.0[..n_keys].to_vec())
} else {
None
}
});
let mut invert_indices = vec![];
let mut idx = 0;
for aggr in aggrs.iter() {
for (idx, aggr) in aggrs.iter().enumerate() {
if aggr.is_none() {
invert_indices.push(idx);
idx += 1;
}
}
for aggr in aggrs.iter() {
for (idx, aggr) in aggrs.iter().enumerate() {
if aggr.is_some() {
invert_indices.push(idx);
idx += 1;
}
}
let invert_indices = invert_indices
@ -200,7 +188,7 @@ impl TempStore {
if let Some((aggr_op, aggr_args)) = aggr {
(aggr_op.combine)(&mut aggr_res[idx], val, aggr_args)?;
} else {
aggr_res[idx] = first_tuple.0[idx].clone();
aggr_res[idx] = first_tuple.0[invert_indices[idx]].clone();
}
}
for tuple in it {

@ -155,6 +155,29 @@ fn air_routes() -> Result<()> {
.unwrap()
);
let most_out_routes_time_inv = Instant::now();
let res = db.run_script(
r#"
route_count[count(?r), ?a, ?x] := [?r route.src ?a], ?x is 1;
?[?code, ?n] := route_count[?n, ?a, ?_], ?n > 180, [?a airport.iata ?code];
:sort -?n;
"#,
)?;
dbg!(most_out_routes_time_inv.elapsed());
assert_eq!(
res,
serde_json::Value::from_str(
r#"[
["FRA",307],["IST",307],["CDG",293],["AMS",282],["MUC",270],["ORD",264],["DFW",251],
["PEK",248],["DXB",247],["ATL",242],["DME",232],["LGW",232],["LHR",221],["DEN",216],
["MAN",216],["LAX",213],["PVG",212],["STN",211],["MAD",206],["VIE",206],["BCN",203],
["BER",202],["FCO",201],["JFK",201],["DUS",199],["IAH",199],["EWR",197],["MIA",195],
["YYZ",195],["BRU",194],["CPH",194],["DOH",186],["DUB",185],["CLT",184],["SVO",181]
]"#
)
.unwrap()
);
let most_routes_time = Instant::now();
let res = db.run_script(
r#"
@ -482,5 +505,36 @@ fn air_routes() -> Result<()> {
.unwrap()
);
let len_of_names_count_time = Instant::now();
let res = db.run_script(
r#"
?[sum(?n)] := [?a airport.iata 'AUS'],
[?r route.src ?a],
[?r route.dst ?a2],
[?a2 airport.city ?city_name],
?n is length(?city_name);
"#,
)?;
dbg!(len_of_names_count_time.elapsed());
assert_eq!(res, json!([[866]]));
let group_count_by_out_time = Instant::now();
let res = db.run_script(
r#"
route_count[count(?r), ?a] := [?r route.src ?a];
?[?n, count(?a)] := route_count[?n, ?a];
:order -?a;
:limit 10;
"#,
)?;
dbg!(group_count_by_out_time.elapsed());
assert_eq!(
res,
serde_json::Value::from_str(
r#"[[1,777],[2,649],[3,359],[4,232],[5,150],[6,139],[7,100],[8,74],[9,63],[10,59]]"#
)
.unwrap()
);
Ok(())
}

Loading…
Cancel
Save