fix aggregation problems

main
Ziyang Hu 2 years ago
parent e6ab334caf
commit 582f8213b1

@ -112,7 +112,7 @@ if __name__ == '__main__':
# Unify('?c', 10000239)),
# T.country.code('?c', '?code'),
# T.country.desc('?c', '?desc'))])
res = db.run([Q([Count('?a')],
res = db.run([Q([Max('?n')],
T.route.distance('?a', '?n'))])
end_time = time.time()
print(json.dumps(res, indent=2))

@ -207,6 +207,7 @@ class AggrClass:
Count = AggrClass('Count')
Min = AggrClass('Min')
Max = AggrClass('Max')
def Const(item):
@ -231,4 +232,4 @@ def Unify(binding, expr):
__all__ = ['Gt', 'Lt', 'Ge', 'Le', 'Eq', 'Neq', 'Add', 'Sub', 'Mul', 'Div', 'Q', 'T', 'R', 'Const', 'Conj', 'Disj',
'NotExists', 'CozoDb', 'Typing', 'Cardinality', 'Indexing', 'PutAttr', 'RetractAttr', 'Attribute', 'Put',
'Retract', 'Pull', 'StrCat', 'Unify', 'DefAttrs', 'Count', 'Min']
'Retract', 'Pull', 'StrCat', 'Unify', 'DefAttrs', 'Count', 'Min', 'Max']

@ -1,4 +1,4 @@
use std::cmp::min;
use std::cmp::{max, min};
use std::fmt::{Debug, Formatter};
use anyhow::{bail, Result};
@ -67,14 +67,11 @@ fn aggr_sum(accum: &DataValue, current: &DataValue) -> Result<DataValue> {
}
}
define_aggr!(AGGR_MIN, false);
define_aggr!(AGGR_MIN, true);
fn aggr_min(accum: &DataValue, current: &DataValue) -> Result<DataValue> {
match (accum, current) {
(DataValue::Bottom, DataValue::Bottom) => Ok(DataValue::Float(f64::infinity().into())),
(DataValue::Bottom, DataValue::Int(i)) => Ok(DataValue::Int(*i)),
(DataValue::Bottom, DataValue::Float(f)) => Ok(DataValue::Float(f.0.into())),
(DataValue::Int(i), DataValue::Bottom) => Ok(DataValue::Int(*i)),
(DataValue::Float(f), DataValue::Bottom) => Ok(DataValue::Float(f.0.into())),
(DataValue::Int(i), DataValue::Int(j)) => Ok(DataValue::Int(min(*i, *j))),
(DataValue::Int(j), DataValue::Float(i)) | (DataValue::Float(i), DataValue::Int(j)) => {
Ok(DataValue::Float(min(i.clone(), (*j as f64).into())))
@ -90,10 +87,33 @@ fn aggr_min(accum: &DataValue, current: &DataValue) -> Result<DataValue> {
}
}
define_aggr!(AGGR_MAX, true);
/// Combines the running maximum `accum` with the next value `current`.
///
/// When exactly one side is `DataValue::Bottom`, the other (numeric) side is
/// returned unchanged. Two `Int`s or two `Float`s yield the larger of the two;
/// mixing an `Int` with a `Float` converts the integer to `f64` and yields a
/// `Float`.
///
/// # Errors
///
/// Returns an error for every other operand combination, e.g. a non-numeric
/// value on either side or `Bottom` on both sides.
fn aggr_max(accum: &DataValue, current: &DataValue) -> Result<DataValue> {
    match (accum, current) {
        // `Bottom` paired with a number: pass the number through. The
        // `(Int, Bottom)` case mirrors `(Float, Bottom)` so that integer
        // accumulators are treated the same way as float ones.
        (DataValue::Bottom, DataValue::Int(i)) | (DataValue::Int(i), DataValue::Bottom) => {
            Ok(DataValue::Int(*i))
        }
        (DataValue::Bottom, DataValue::Float(f)) | (DataValue::Float(f), DataValue::Bottom) => {
            Ok(DataValue::Float(f.0.into()))
        }
        (DataValue::Int(i), DataValue::Int(j)) => Ok(DataValue::Int(max(*i, *j))),
        // Mixed operands: compare in floating point and return a float.
        (DataValue::Int(j), DataValue::Float(i)) | (DataValue::Float(i), DataValue::Int(j)) => {
            Ok(DataValue::Float(max(i.clone(), (*j as f64).into())))
        }
        (DataValue::Float(i), DataValue::Float(j)) => {
            Ok(DataValue::Float(max(i.clone(), j.clone())))
        }
        (i, j) => bail!(
            "cannot compute max: encountered value {:?} for aggregate {:?}",
            j,
            i
        ),
    }
}
/// Resolves an aggregation by name.
///
/// The comparison is an exact string match against `"Count"`, `"Sum"`,
/// `"Min"` and `"Max"`; any other name yields `None`.
pub(crate) fn get_aggr(name: &str) -> Option<&'static Aggregation> {
    match name {
        "Count" => Some(&AGGR_COUNT),
        "Sum" => Some(&AGGR_SUM),
        "Min" => Some(&AGGR_MIN),
        "Max" => Some(&AGGR_MAX),
        _ => None,
    }
}

@ -78,11 +78,13 @@ impl SessionTx {
} else {
store.clone()
};
for item_res in rule.relation.iter(self, Some(0), &use_delta) {
for (serial, item_res) in
rule.relation.iter(self, Some(0), &use_delta).enumerate()
{
let item = item_res?;
trace!("item for {:?}.{}: {:?} at {}", k, rule_n, item, epoch);
if rule_is_aggr {
store_to_use.normal_aggr_put(&item, &rule.aggr)?;
store_to_use.normal_aggr_put(&item, &rule.aggr, serial)?;
} else {
store_to_use.put(&item, 0)?;
}

@ -121,6 +121,7 @@ impl TempStore {
&self,
tuple: &Tuple,
aggrs: &[Option<Aggregation>],
serial: usize,
) -> Result<(), RocksDbStatus> {
let mut vals = vec![];
for (idx, agg) in aggrs.iter().enumerate() {
@ -133,6 +134,7 @@ impl TempStore {
vals.push(tuple.0[idx].clone());
}
}
vals.push(DataValue::Int(serial as i64));
self.db
.put(&Tuple(vals).encode_as_key_for_epoch(self.id, 0), &[])
}
@ -173,22 +175,43 @@ impl TempStore {
None
}
});
let mut invert_indices = vec![];
let mut idx = 0;
for aggr in aggrs.iter() {
if aggr.is_none() {
invert_indices.push(idx);
idx += 1;
}
}
for aggr in aggrs.iter() {
if aggr.is_some() {
invert_indices.push(idx);
idx += 1;
}
}
let invert_indices = invert_indices
.into_iter()
.enumerate()
.sorted_by_key(|(_a, b)| *b)
.map(|(a, _b)| a)
.collect_vec();
for (key, group) in grouped.into_iter() {
if key.is_some() {
let mut aggr_res = vec![DataValue::Bottom; aggrs.len()];
for tup_res in group.into_iter() {
let tuple = tup_res.unwrap().0;
for (i, val) in tuple.into_iter().enumerate() {
if let Some(aggr_op) = &aggrs[i] {
aggr_res[i] = (aggr_op.combine)(&aggr_res[i], &val)?;
for tuple in group.into_iter() {
let tuple = tuple?;
for (idx, aggr) in aggrs.iter().enumerate() {
let val = &tuple.0[invert_indices[idx]];
if let Some(aggr_op) = aggr {
aggr_res[idx] = (aggr_op.combine)(&aggr_res[idx], val)?;
} else {
aggr_res[i] = val;
aggr_res[idx] = val.clone();
}
}
}
for (i, aggr) in aggrs.iter().enumerate() {
if let Some(aggr_op) = aggr {
aggr_res[i] = (aggr_op.combine)(&aggr_res[i], &DataValue::Bottom)?;
for (i, aggr) in aggrs.iter().enumerate() {
if let Some(aggr_op) = aggr {
aggr_res[i] = (aggr_op.combine)(&aggr_res[i], &DataValue::Bottom)?;
}
}
}
store.put(&Tuple(aggr_res), 0)?;

Loading…
Cancel
Save