non-grouping aggregations now always return exactly one row

main
Ziyang Hu 2 years ago
parent 233acd290f
commit 4ef4404d87

@ -38,6 +38,7 @@ pub(crate) trait NormalAggrObj {
}
/// Interface implemented by the `MeetAggr*` aggregation objects: an
/// accumulator value is updated in place from successive input values.
pub(crate) trait MeetAggrObj {
/// Returns the value an accumulator starts from. Judging by the evaluation
/// code in this change, this is also the value stored when a non-grouping
/// meet aggregation receives no rows at all.
fn init_val(&self) -> DataValue;
/// Folds `right` into the accumulator `left`. The implementations shown in
/// this change return `Ok(true)` when `left` was modified and `Ok(false)`
/// when it was left untouched; an `Err` reports operands the aggregation
/// cannot handle.
fn update(&self, left: &mut DataValue, right: &DataValue) -> Result<bool>;
}
@ -93,6 +94,10 @@ impl NormalAggrObj for AggrAnd {
pub(crate) struct MeetAggrAnd;
impl MeetAggrObj for MeetAggrAnd {
fn init_val(&self) -> DataValue {
DataValue::Bool(true)
}
fn update(&self, left: &mut DataValue, right: &DataValue) -> Result<bool> {
match (left, right) {
(DataValue::Bool(l), DataValue::Bool(r)) => {
@ -129,6 +134,10 @@ impl NormalAggrObj for AggrOr {
pub(crate) struct MeetAggrOr;
impl MeetAggrObj for MeetAggrOr {
fn init_val(&self) -> DataValue {
DataValue::Bool(false)
}
fn update(&self, left: &mut DataValue, right: &DataValue) -> Result<bool> {
match (left, right) {
(DataValue::Bool(l), DataValue::Bool(r)) => {
@ -229,6 +238,10 @@ impl NormalAggrObj for AggrUnion {
pub(crate) struct MeetAggrUnion;
impl MeetAggrObj for MeetAggrUnion {
fn init_val(&self) -> DataValue {
DataValue::Set(BTreeSet::new())
}
fn update(&self, left: &mut DataValue, right: &DataValue) -> Result<bool> {
loop {
if let DataValue::List(l) = left {
@ -294,7 +307,17 @@ impl NormalAggrObj for AggrIntersection {
pub(crate) struct MeetAggrIntersection;
impl MeetAggrObj for MeetAggrIntersection {
fn init_val(&self) -> DataValue {
DataValue::Null
}
fn update(&self, left: &mut DataValue, right: &DataValue) -> Result<bool> {
if *left == DataValue::Null && *right != DataValue::Null {
*left = right.clone();
return Ok(true);
} else if *right == DataValue::Null {
return Ok(false);
}
loop {
if let DataValue::List(l) = left {
let s = l.iter().cloned().collect();
@ -558,14 +581,28 @@ pub(crate) struct AggrMin {
impl Default for AggrMin {
fn default() -> Self {
Self {
found: DataValue::Bot,
found: DataValue::Null,
}
}
}
impl NormalAggrObj for AggrMin {
fn set(&mut self, value: &DataValue) -> Result<()> {
if *value < self.found {
if *value == DataValue::Null {
return Ok(());
}
if self.found == DataValue::Null {
self.found = value.clone();
return Ok(());
}
let f1 = self
.found
.get_float()
.ok_or_else(|| miette!("'min' applied to non-numerical values"))?;
let f2 = value
.get_float()
.ok_or_else(|| miette!("'min' applied to non-numerical values"))?;
if f1 > f2 {
self.found = value.clone();
}
Ok(())
@ -579,8 +616,26 @@ impl NormalAggrObj for AggrMin {
pub(crate) struct MeetAggrMin;
impl MeetAggrObj for MeetAggrMin {
fn init_val(&self) -> DataValue {
DataValue::Null
}
fn update(&self, left: &mut DataValue, right: &DataValue) -> Result<bool> {
Ok(if *right < *left {
if *right == DataValue::Null {
return Ok(false);
}
if *left == DataValue::Null {
*left = right.clone();
return Ok(true);
}
let f1 = left
.get_float()
.ok_or_else(|| miette!("'min' applied to non-numerical values"))?;
let f2 = right
.get_float()
.ok_or_else(|| miette!("'min' applied to non-numerical values"))?;
Ok(if f1 > f2 {
*left = right.clone();
true
} else {
@ -605,7 +660,21 @@ impl Default for AggrMax {
impl NormalAggrObj for AggrMax {
fn set(&mut self, value: &DataValue) -> Result<()> {
if *value > self.found {
// Null inputs are skipped so they can never displace a real maximum.
if *value == DataValue::Null {
    return Ok(());
}
// Nothing recorded yet: the first non-null value becomes the running maximum.
if self.found == DataValue::Null {
    self.found = value.clone();
    return Ok(());
}
// Both sides are compared as floats. The error text names 'max' so the user
// is pointed at the aggregation that actually rejected the value.
let f1 = self
    .found
    .get_float()
    .ok_or_else(|| miette!("'max' applied to non-numerical values"))?;
let f2 = value
    .get_float()
    .ok_or_else(|| miette!("'max' applied to non-numerical values"))?;
if f1 < f2 {
    self.found = value.clone();
}
Ok(())
@ -619,8 +688,26 @@ impl NormalAggrObj for AggrMax {
pub(crate) struct MeetAggrMax;
impl MeetAggrObj for MeetAggrMax {
fn init_val(&self) -> DataValue {
DataValue::Null
}
fn update(&self, left: &mut DataValue, right: &DataValue) -> Result<bool> {
Ok(if *right > *left {
// A null on the right carries no information: the accumulator is unchanged.
if *right == DataValue::Null {
    return Ok(false);
}
// A null accumulator is the initial state; adopt the first real value.
if *left == DataValue::Null {
    *left = right.clone();
    return Ok(true);
}
// Both sides are compared as floats. The error text names 'max' so the user
// is pointed at the aggregation that actually rejected the value.
let f1 = left
    .get_float()
    .ok_or_else(|| miette!("'max' applied to non-numerical values"))?;
let f2 = right
    .get_float()
    .ok_or_else(|| miette!("'max' applied to non-numerical values"))?;
Ok(if f1 < f2 {
*left = right.clone();
true
} else {
@ -629,34 +716,6 @@ impl MeetAggrObj for MeetAggrMax {
}
}
define_aggr!(AGGR_CHOICE, true);
/// State of the `choice` aggregation: the first value it was handed, if any.
#[derive(Default)]
pub(crate) struct AggrChoice {
    found: Option<DataValue>,
}

impl NormalAggrObj for AggrChoice {
    /// Remembers `value` only when nothing has been remembered yet; every
    /// later call leaves the stored value alone.
    fn set(&mut self, value: &DataValue) -> Result<()> {
        match self.found {
            Some(_) => {}
            None => self.found = Some(value.clone()),
        }
        Ok(())
    }

    /// Returns a copy of the remembered value, or an error when `set` was
    /// never called.
    fn get(&self) -> Result<DataValue> {
        match &self.found {
            Some(chosen) => Ok(chosen.clone()),
            None => Err(miette!("empty choice")),
        }
    }
}
/// Meet-aggregation object for `choice`; it carries no state of its own.
pub(crate) struct MeetAggrChoice;
impl MeetAggrObj for MeetAggrChoice {
/// Ignores both arguments and always returns `Ok(false)`, so the accumulator
/// passed as `_left` is never modified by this aggregation.
fn update(&self, _left: &mut DataValue, _right: &DataValue) -> Result<bool> {
Ok(false)
}
}
define_aggr!(AGGR_CHOICE_LAST, true);
pub(crate) struct AggrChoiceLast {
@ -685,7 +744,14 @@ impl NormalAggrObj for AggrChoiceLast {
pub(crate) struct MeetAggrChoiceLast;
impl MeetAggrObj for MeetAggrChoiceLast {
fn init_val(&self) -> DataValue {
DataValue::Null
}
fn update(&self, left: &mut DataValue, right: &DataValue) -> Result<bool> {
if *right == DataValue::Null {
return Ok(false);
}
Ok(if *left == *right {
false
} else {
@ -739,14 +805,14 @@ define_aggr!(AGGR_MIN_COST, true);
pub(crate) struct AggrMinCost {
found: DataValue,
cost: DataValue,
cost: f64,
}
impl Default for AggrMinCost {
fn default() -> Self {
Self {
found: DataValue::Null,
cost: DataValue::Bot,
cost: f64::INFINITY,
}
}
}
@ -760,8 +826,11 @@ impl NormalAggrObj for AggrMinCost {
"'min_cost' requires a list of exactly two items as argument"
);
let c = &l[1];
if *c < self.cost {
self.cost = c.clone();
let cost = c
.get_float()
.ok_or_else(|| miette!("Cost must be numeric"))?;
if cost < self.cost {
self.cost = cost;
self.found = l[0].clone();
}
Ok(())
@ -771,13 +840,20 @@ impl NormalAggrObj for AggrMinCost {
}
fn get(&self) -> Result<DataValue> {
Ok(DataValue::List(vec![self.found.clone(), self.cost.clone()]))
Ok(DataValue::List(vec![
self.found.clone(),
DataValue::from(self.cost),
]))
}
}
pub(crate) struct MeetAggrMinCost;
impl MeetAggrObj for MeetAggrMinCost {
fn init_val(&self) -> DataValue {
DataValue::List(vec![DataValue::Null, DataValue::from(f64::INFINITY)])
}
fn update(&self, left: &mut DataValue, right: &DataValue) -> Result<bool> {
Ok(match (left, right) {
(DataValue::List(prev), DataValue::List(l)) => {
@ -788,7 +864,13 @@ impl MeetAggrObj for MeetAggrMinCost {
l
);
// Element 1 of each pair is the cost. The `unwrap`s presumably rely on the
// length check just above and on the accumulator always being a two-item
// list (confirm against the full function).
let cur_cost = l.get(1).unwrap();
let cur_cost = cur_cost
    .get_float()
    .ok_or_else(|| miette!("'min_cost' must have numerical costs"))?;
let prev_cost = prev.get(1).unwrap();
// Name the aggregation rather than the local variable in the error, so the
// message is meaningful to the person who wrote the query.
let prev_cost = prev_cost
    .get_float()
    .ok_or_else(|| miette!("'min_cost' must have numerical costs"))?;
if prev_cost <= cur_cost {
false
@ -838,7 +920,17 @@ impl NormalAggrObj for AggrShortest {
pub(crate) struct MeetAggrShortest;
impl MeetAggrObj for MeetAggrShortest {
fn init_val(&self) -> DataValue {
DataValue::Null
}
fn update(&self, left: &mut DataValue, right: &DataValue) -> Result<bool> {
if *left == DataValue::Null && *right != DataValue::Null {
*left = right.clone();
return Ok(true);
} else if *right == DataValue::Null {
return Ok(false);
}
match (left, right) {
(DataValue::List(l), DataValue::List(r)) => Ok(if r.len() < l.len() {
*l = r.clone();
@ -851,13 +943,13 @@ impl MeetAggrObj for MeetAggrShortest {
}
}
define_aggr!(AGGR_COALESCE, true);
define_aggr!(AGGR_CHOICE, true);
pub(crate) struct AggrCoalesce {
pub(crate) struct AggrChoice {
found: DataValue,
}
impl Default for AggrCoalesce {
impl Default for AggrChoice {
fn default() -> Self {
Self {
found: DataValue::Null,
@ -865,7 +957,7 @@ impl Default for AggrCoalesce {
}
}
impl NormalAggrObj for AggrCoalesce {
impl NormalAggrObj for AggrChoice {
fn set(&mut self, value: &DataValue) -> Result<()> {
if self.found == DataValue::Null {
self.found = value.clone();
@ -878,9 +970,13 @@ impl NormalAggrObj for AggrCoalesce {
}
}
pub(crate) struct MeetAggrCoalesce;
pub(crate) struct MeetAggrChoice;
impl MeetAggrObj for MeetAggrChoice {
fn init_val(&self) -> DataValue {
DataValue::Null
}
impl MeetAggrObj for MeetAggrCoalesce {
fn update(&self, left: &mut DataValue, right: &DataValue) -> Result<bool> {
Ok(if *left == DataValue::Null && *right != DataValue::Null {
*left = right.clone();
@ -929,12 +1025,20 @@ impl NormalAggrObj for AggrBitAnd {
pub(crate) struct MeetAggrBitAnd;
impl MeetAggrObj for MeetAggrBitAnd {
fn init_val(&self) -> DataValue {
DataValue::Bytes(vec![])
}
fn update(&self, left: &mut DataValue, right: &DataValue) -> Result<bool> {
match (left, right) {
(DataValue::Bytes(left), DataValue::Bytes(right)) => {
if left == right {
return Ok(false);
}
if left.len() == 0 {
*left = right.clone();
return Ok(true);
}
ensure!(
left.len() == right.len(),
"operands of 'bit_and' must have the same lengths, got {:x?} and {:x?}",
@ -990,12 +1094,20 @@ impl NormalAggrObj for AggrBitOr {
pub(crate) struct MeetAggrBitOr;
impl MeetAggrObj for MeetAggrBitOr {
fn init_val(&self) -> DataValue {
DataValue::Bytes(vec![])
}
fn update(&self, left: &mut DataValue, right: &DataValue) -> Result<bool> {
match (left, right) {
(DataValue::Bytes(left), DataValue::Bytes(right)) => {
if left == right {
return Ok(false);
}
if left.len() == 0 {
*left = right.clone();
return Ok(true);
}
ensure!(
left.len() == right.len(),
"operands of 'bit_or' must have the same lengths, got {:x?} and {:x?}",
@ -1070,7 +1182,6 @@ pub(crate) fn parse_aggr(name: &str) -> Option<&'static Aggregation> {
"collect" => &AGGR_COLLECT,
"shortest" => &AGGR_SHORTEST,
"min_cost" => &AGGR_MIN_COST,
"coalesce" => &AGGR_COALESCE,
"bit_and" => &AGGR_BIT_AND,
"bit_or" => &AGGR_BIT_OR,
"bit_xor" => &AGGR_BIT_XOR,
@ -1095,7 +1206,6 @@ impl Aggregation {
name if name == AGGR_INTERSECTION.name => Box::new(MeetAggrIntersection),
name if name == AGGR_SHORTEST.name => Box::new(MeetAggrShortest),
name if name == AGGR_MIN_COST.name => Box::new(MeetAggrMinCost),
name if name == AGGR_COALESCE.name => Box::new(MeetAggrCoalesce),
_ => unreachable!(),
});
Ok(())
@ -1125,7 +1235,6 @@ impl Aggregation {
name if name == AGGR_SHORTEST.name => Box::new(AggrShortest::default()),
name if name == AGGR_MIN_COST.name => Box::new(AggrMinCost::default()),
name if name == AGGR_LATEST_BY.name => Box::new(AggrLatestBy::default()),
name if name == AGGR_COALESCE.name => Box::new(AggrCoalesce::default()),
name if name == AGGR_CHOICE_RAND.name => Box::new(AggrChoiceRand::default()),
name if name == AGGR_COLLECT.name => Box::new({
if args.is_empty() {

@ -361,26 +361,6 @@ fn test_max() {
assert_eq!(v, DataValue::from(10));
}
#[test]
fn test_choice() {
let mut aggr = parse_aggr("choice").unwrap().clone();
aggr.normal_init(&[]).unwrap();
aggr.meet_init(&[]).unwrap();
let mut choice_aggr = aggr.normal_op.unwrap();
choice_aggr.set(&DataValue::from(1)).unwrap();
choice_aggr.set(&DataValue::from(2)).unwrap();
choice_aggr.set(&DataValue::from(3)).unwrap();
assert_eq!(choice_aggr.get().unwrap(), DataValue::from(1));
let m_choice_aggr = aggr.meet_op.unwrap();
let mut v = DataValue::from(5);
m_choice_aggr.update(&mut v, &DataValue::from(1)).unwrap();
m_choice_aggr.update(&mut v, &DataValue::from(2)).unwrap();
m_choice_aggr.update(&mut v, &DataValue::from(3)).unwrap();
assert_eq!(v, DataValue::from(5));
}
#[test]
fn test_choice_last() {
let mut aggr = parse_aggr("choice_last").unwrap().clone();
@ -438,7 +418,7 @@ fn test_min_cost() {
.unwrap();
assert_eq!(
min_cost_aggr.get().unwrap(),
DataValue::List(vec![DataValue::Bool(true), DataValue::from(1)])
DataValue::List(vec![DataValue::Bool(true), DataValue::from(1.)])
);
let m_min_cost_aggr = aggr.meet_op.unwrap();
@ -533,16 +513,16 @@ fn test_shortest() {
}
#[test]
fn test_coalesce() {
let mut aggr = parse_aggr("coalesce").unwrap().clone();
fn test_choice() {
let mut aggr = parse_aggr("choice").unwrap().clone();
aggr.normal_init(&[]).unwrap();
aggr.meet_init(&[]).unwrap();
let mut coalesce_aggr = aggr.normal_op.unwrap();
coalesce_aggr.set(&DataValue::Null).unwrap();
coalesce_aggr.set(&DataValue::from(1)).unwrap();
coalesce_aggr.set(&DataValue::from(2)).unwrap();
assert_eq!(coalesce_aggr.get().unwrap(), DataValue::from(1));
let mut choice_aggr = aggr.normal_op.unwrap();
choice_aggr.set(&DataValue::Null).unwrap();
choice_aggr.set(&DataValue::from(1)).unwrap();
choice_aggr.set(&DataValue::from(2)).unwrap();
assert_eq!(choice_aggr.get().unwrap(), DataValue::from(1));
let m_coalesce_aggr = aggr.meet_op.unwrap();
let mut v = DataValue::Null;

@ -150,7 +150,7 @@ impl<'a> SessionTx<'a> {
CompiledRuleSet::Algo(_) => {
// no need to do anything, algos are only calculated once
},
}
}
}
}
@ -215,6 +215,21 @@ impl<'a> SessionTx<'a> {
}
poison.check()?;
}
if is_meet && store.is_empty() && ruleset[0].aggr.iter().all(|a| a.is_some()) {
let mut aggr = ruleset[0].aggr.clone();
for (aggr, args) in aggr.iter_mut().flatten() {
aggr.meet_init(args)?;
}
let value: Vec<_> = aggr
.iter()
.map(|a| -> Result<DataValue> {
let (aggr, args) = a.as_ref().unwrap();
let op = aggr.meet_op.as_ref().unwrap();
Ok(op.init_val())
})
.try_collect()?;
store.aggr_meet_put(&value, &mut aggr, 0)?;
}
}
AggrKind::Normal => {
let mut aggr_work: BTreeMap<Vec<DataValue>, Vec<Aggregation>> = BTreeMap::new();
@ -294,6 +309,21 @@ impl<'a> SessionTx<'a> {
}
}
if aggr_work.is_empty() && ruleset[0].aggr.iter().all(|v| v.is_some()) {
let empty_result: Vec<_> = ruleset[0]
.aggr
.iter()
.map(|a| {
let (aggr, args) = a.as_ref().unwrap();
let mut aggr = aggr.clone();
aggr.normal_init(args)?;
let op = aggr.normal_op.unwrap();
op.get()
})
.try_collect()?;
store.put(empty_result, 0);
}
for (keys, aggrs) in aggr_work {
let tuple_data: Vec<_> = inv_indices
.iter()

@ -290,8 +290,7 @@ impl<'s, S: Storage<'s>> Db<S> {
col.typing.coerce(DataValue::from(v))
})
.try_collect()?;
let v_store =
handle.encode_val_only_for_store(&vals, Default::default())?;
let v_store = handle.encode_val_only_for_store(&vals, Default::default())?;
tx.tx.put(&k_store, &v_store)?;
}
}
@ -1004,8 +1003,7 @@ impl<'s, S: Storage<'s>> Db<S> {
})
}
fn list_relations(&'s self) -> Result<NamedRows> {
let lower =
vec![DataValue::Str(SmartString::from(""))].encode_as_key(RelationId::SYSTEM);
let lower = vec![DataValue::Str(SmartString::from(""))].encode_as_key(RelationId::SYSTEM);
let upper = vec![DataValue::Str(SmartString::from(String::from(
LARGEST_UTF_CHAR,
)))]
@ -1087,6 +1085,7 @@ mod tests {
use itertools::Itertools;
use serde_json::json;
use crate::data::value::DataValue;
use crate::new_cozo_mem;
#[test]
@ -1101,7 +1100,10 @@ mod tests {
.collect_vec();
assert_eq!(json!(res), json!([3, 5]));
let res = db
.run_script("?[a] := a in [5,3,1,2,4] :limit 2 :offset 1", Default::default())
.run_script(
"?[a] := a in [5,3,1,2,4] :limit 2 :offset 1",
Default::default(),
)
.unwrap()
.rows
.into_iter()
@ -1109,7 +1111,10 @@ mod tests {
.collect_vec();
assert_eq!(json!(res), json!([1, 3]));
let res = db
.run_script("?[a] := a in [5,3,1,2,4] :limit 2 :offset 4", Default::default())
.run_script(
"?[a] := a in [5,3,1,2,4] :limit 2 :offset 4",
Default::default(),
)
.unwrap()
.rows
.into_iter()
@ -1117,7 +1122,10 @@ mod tests {
.collect_vec();
assert_eq!(json!(res), json!([4]));
let res = db
.run_script("?[a] := a in [5,3,1,2,4] :limit 2 :offset 5", Default::default())
.run_script(
"?[a] := a in [5,3,1,2,4] :limit 2 :offset 5",
Default::default(),
)
.unwrap()
.rows
.into_iter()
@ -1125,4 +1133,28 @@ mod tests {
.collect_vec();
assert_eq!(json!(res), json!([]));
}
/// A non-grouping normal aggregation over an empty input must still produce
/// exactly one row; for `count` that row holds zero.
#[test]
fn test_normal_aggr_empty() {
    let db = new_cozo_mem().unwrap();
    let outcome = db.run_script("?[count(a)] := a in []", Default::default());
    let rows = outcome.unwrap().rows;
    let expected = vec![vec![json!(0)]];
    assert_eq!(rows, expected);
}
/// A non-grouping meet aggregation over an empty input must still produce
/// exactly one row, both on its own and next to a normal aggregation.
#[test]
fn test_meet_aggr_empty() {
    let db = new_cozo_mem().unwrap();

    // `min` alone: a single row containing null.
    let min_only = db
        .run_script("?[min(a)] := a in []", Default::default())
        .unwrap();
    assert_eq!(min_only.rows, vec![vec![json!(null)]]);

    // `min` together with `count`: a single row of null and zero.
    let min_with_count = db
        .run_script("?[min(a), count(a)] := a in []", Default::default())
        .unwrap();
    assert_eq!(min_with_count.rows, vec![vec![json!(null), json!(0)]]);
}
}

@ -163,6 +163,14 @@ impl InMemRelation {
epoch_map.insert(tuple, Tuple::default());
}
}
/// Returns `true` when the map at index 0 of this relation's epoch maps
/// contains no entries.
///
/// Panics if there is no map at index 0 (`get(0).unwrap()`).
pub(crate) fn is_empty(&self) -> bool {
// The type annotation makes this `borrow()` yield the `RefCell` itself
// rather than a borrow guard.
let mem_db: &RefCell<_> = self.mem_db.borrow();
// Borrow guard over the collection of per-epoch maps.
let epoch_maps = mem_db.borrow();
// Only the map at index 0 is inspected.
let epoch_map = epoch_maps.get(0).unwrap();
// Same two-step pattern: first reach the `RefCell`, then take a guard on it.
let epoch_map: &RefCell<BTreeMap<_, _>> = epoch_map.borrow();
let epoch_map = epoch_map.borrow();
epoch_map.is_empty()
}
pub(crate) fn exists(&self, tuple: &Tuple, epoch: u32) -> bool {
let mem_db: &RefCell<_> = self.mem_db.borrow();
let epoch_maps = mem_db.borrow();

Loading…
Cancel
Save