diff --git a/src/data/aggr.rs b/src/data/aggr.rs index 7f125ff8..dc45fe0e 100644 --- a/src/data/aggr.rs +++ b/src/data/aggr.rs @@ -55,6 +55,88 @@ macro_rules! define_aggr { }; } +define_aggr!(AGGR_AND, true); + +struct AggrAnd { + accum: bool, +} + +impl Default for AggrAnd { + fn default() -> Self { + Self { accum: true } + } +} + +impl NormalAggrObj for AggrAnd { + fn set(&mut self, value: &DataValue) -> Result<()> { + match value { + DataValue::Bool(v) => self.accum &= *v, + v => bail!("cannot compute 'and' for {:?}", v), + } + Ok(()) + } + + fn get(&self) -> Result { + Ok(DataValue::Bool(self.accum)) + } +} + +struct MeetAggrAnd; + +impl MeetAggrObj for MeetAggrAnd { + fn update(&self, left: &mut DataValue, right: &DataValue) -> Result { + match (left, right) { + (DataValue::Bool(l), DataValue::Bool(r)) => { + let old = *l; + *l &= *r; + Ok(old == *l) + } + (u, v) => bail!("cannot compute 'and' for {:?} and {:?}", u, v), + } + } +} + +define_aggr!(AGGR_OR, true); + +struct AggrOr { + accum: bool, +} + +impl Default for AggrAnd { + fn default() -> Self { + Self { accum: false } + } +} + +impl NormalAggrObj for AggrOr { + fn set(&mut self, value: &DataValue) -> Result<()> { + match value { + DataValue::Bool(v) => self.accum |= *v, + v => bail!("cannot compute 'or' for {:?}", v), + } + Ok(()) + } + + fn get(&self) -> Result { + Ok(DataValue::Bool(self.accum)) + } +} + +struct MeetAggrOr; + +impl MeetAggrObj for MeetAggrOr { + fn update(&self, left: &mut DataValue, right: &DataValue) -> Result { + match (left, right) { + (DataValue::Bool(l), DataValue::Bool(r)) => { + let old = *l; + *l |= *r; + Ok(old == *l) + } + (u, v) => bail!("cannot compute 'or' for {:?} and {:?}", u, v), + } + } +} + define_aggr!(AGGR_UNIQUE, false); #[derive(Default)] @@ -501,6 +583,73 @@ impl MeetAggrObj for MeetAggrMinCost { } } +define_aggr!(AGGR_MAX_COST, true); + +struct AggrMaxCost { + found: DataValue, + cost: DataValue, +} + +impl Default for AggrMaxCost { + fn default() -> Self { + Self { + found: DataValue::Null, + cost: DataValue::Null, + } + } +} + +impl NormalAggrObj for AggrMaxCost { + fn set(&mut self, value: &DataValue) -> Result<()> { + match value { + DataValue::List(l) => { + ensure!( + l.len() == 2, + "'max_cost' requires a list of exactly two items as argument" + ); + let c = &l[1]; + if *c > self.cost { + self.cost = c.clone(); + self.found = l[0].clone(); + } + Ok(()) + } + v => bail!("cannot compute 'max_cost' on {:?}", v), + } + } + + fn get(&self) -> Result { + Ok(DataValue::List(vec![self.found.clone(), self.cost.clone()])) + } +} + +struct MeetAggrMaxCost; + +impl MeetAggrObj for MeetAggrMaxCost { + fn update(&self, left: &mut DataValue, right: &DataValue) -> Result { + Ok(match (left, right) { + (DataValue::List(prev), DataValue::List(l)) => { + ensure!( + l.len() == 2 && prev.len() == 2, + "'max_cost' requires a list of length 2 as argument, got {:?}, {:?}", + prev, + l + ); + let cur_cost = l.get(1).unwrap(); + let prev_cost = prev.get(1).unwrap(); + + if prev_cost >= cur_cost { + false + } else { + *prev = l.clone(); + true + } + } + (u, v) => bail!("cannot compute 'max_cost' on {:?}, {:?}", u, v), + }) + } +} + define_aggr!(AGGR_SHORTEST, true); #[derive(Default)] @@ -614,6 +763,8 @@ pub(crate) fn get_aggr(name: &str) -> Option<&'static Aggregation> { impl Aggregation { pub(crate) fn meet_init(&mut self, _args: &[DataValue]) -> Result<()> { self.meet_op.replace(match self.name { + name if name == AGGR_AND.name => Box::new(MeetAggrAnd), + name if name == AGGR_OR.name => Box::new(MeetAggrOr), name if name == AGGR_MIN.name => Box::new(MeetAggrMin), name if name == AGGR_MAX.name => Box::new(MeetAggrMax), name if name == AGGR_CHOICE.name => Box::new(MeetAggrChoice), @@ -621,6 +772,7 @@ impl Aggregation { name if name == AGGR_INTERSECTION.name => Box::new(MeetAggrIntersection), name if name == AGGR_SHORTEST.name => Box::new(MeetAggrShortest), name if name == AGGR_MIN_COST.name => Box::new(MeetAggrMinCost), + name if name == AGGR_MAX_COST.name => Box::new(MeetAggrMaxCost), name if name == AGGR_COALESCE.name => Box::new(MeetAggrCoalesce), _ => unreachable!(), }); @@ -628,6 +780,8 @@ impl Aggregation { } pub(crate) fn normal_init(&mut self, args: &[DataValue]) -> Result<()> { self.normal_op.replace(match self.name { + name if name == AGGR_AND.name => Box::new(AggrAnd::default()), + name if name == AGGR_OR.name => Box::new(AggrOr::default()), name if name == AGGR_COUNT.name => Box::new(AggrCount::default()), name if name == AGGR_GROUP_COUNT.name => Box::new(AggrGroupCount::default()), name if name == AGGR_COUNT_UNIQUE.name => Box::new(AggrCountUnique::default()), @@ -641,6 +795,7 @@ impl Aggregation { name if name == AGGR_INTERSECTION.name => Box::new(AggrIntersection::default()), name if name == AGGR_SHORTEST.name => Box::new(AggrShortest::default()), name if name == AGGR_MIN_COST.name => Box::new(AggrMinCost::default()), + name if name == AGGR_MAX_COST.name => Box::new(AggrMaxCost::default()), name if name == AGGR_COALESCE.name => Box::new(AggrCoalesce::default()), name if name == AGGR_COLLECT.name => Box::new({ if args.len() == 0 {