From a91cfc840d5365b65d719a33c84da2642bdfffb1 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Mon, 5 Dec 2022 16:42:46 +0800 Subject: [PATCH] refactor normal aggr --- cozo-core/src/query/eval.rs | 108 +++++++++++++++++--- cozo-core/src/runtime/in_mem.rs | 173 ++++++-------------------------- 2 files changed, 124 insertions(+), 157 deletions(-) diff --git a/cozo-core/src/query/eval.rs b/cozo-core/src/query/eval.rs index 6e9af2ff..73dc3987 100644 --- a/cozo-core/src/query/eval.rs +++ b/cozo-core/src/query/eval.rs @@ -6,14 +6,19 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ +use std::collections::btree_map::Entry; use std::collections::{BTreeMap, BTreeSet}; use std::mem; +use itertools::Itertools; use log::{debug, trace}; use miette::Result; +use crate::data::aggr::Aggregation; use crate::data::program::{MagicAlgoApply, MagicSymbol, NoEntryError}; use crate::data::symb::{Symbol, PROG_ENTRY}; +use crate::data::tuple::Tuple; +use crate::data::value::DataValue; use crate::parse::SourceSpan; use crate::query::compile::{AggrKind, CompiledProgram, CompiledRule, CompiledRuleSet}; use crate::runtime::db::Poison; @@ -71,6 +76,7 @@ impl<'a> SessionTx<'a> { } Ok((ret_area, early_return)) } + /// returns true if early return is activated fn semi_naive_magic_evaluate( &self, prog: &CompiledProgram, @@ -159,6 +165,7 @@ impl<'a> SessionTx<'a> { let out = stores.get(rule_symb).unwrap(); algo_impl.run(self, algo_apply, stores, out, poison) } + /// returns true if early return is activated fn initial_rule_eval( &self, rule_symb: &MagicSymbol, @@ -204,33 +211,106 @@ impl<'a> SessionTx<'a> { } } AggrKind::Normal => { - let store_to_use = self.new_temp_store(rule_symb.symbol().span); + let mut aggr_work: BTreeMap, Vec> = BTreeMap::new(); + for (rule_n, rule) in ruleset.iter().enumerate() { debug!( "Calculation for normal aggr rule {:?}.{}", rule_symb, rule_n ); - for (serial, item_res) in - rule.relation.iter(self, Some(0), &use_delta)?.enumerate() - { + + let keys_indices = 
rule + .aggr + .iter() + .enumerate() + .filter_map(|(i, a)| if a.is_none() { Some(i) } else { None }) + .collect_vec(); + let extract_keys = |t: &Tuple| -> Vec { + keys_indices.iter().map(|i| t.0[*i].clone()).collect_vec() + }; + + let val_indices_and_aggrs = rule + .aggr + .iter() + .enumerate() + .filter_map(|(i, a)| match a { + None => None, + Some(aggr) => Some((i, aggr.clone())), + }) + .collect_vec(); + + for item_res in rule.relation.iter(self, Some(0), &use_delta)? { let item = item_res?; trace!("item for {:?}.{}: {:?} at {}", rule_symb, rule_n, item, 0); - store_to_use.normal_aggr_put(&item, &rule.aggr, serial); + + let keys = extract_keys(&item); + + match aggr_work.entry(keys) { + Entry::Occupied(mut ent) => { + let aggr_ops = ent.get_mut(); + for (aggr_idx, (tuple_idx, _)) in + val_indices_and_aggrs.iter().enumerate() + { + aggr_ops[aggr_idx] + .normal_op + .as_mut() + .unwrap() + .set(&item.0[*tuple_idx])?; + } + } + Entry::Vacant(ent) => { + let mut aggr_ops = Vec::with_capacity(val_indices_and_aggrs.len()); + for (i, (aggr, params)) in &val_indices_and_aggrs { + let mut cur_aggr = aggr.clone(); + cur_aggr.normal_init(params)?; + cur_aggr.normal_op.as_mut().unwrap().set(&item.0[*i])?; + aggr_ops.push(cur_aggr) + } + ent.insert(aggr_ops); + } + } + *changed.get_mut(rule_symb).unwrap() = true; } poison.check()?; } - if store_to_use.normal_aggr_scan_and_put( - &ruleset[0].aggr, - store, + + let mut inv_indices = Vec::with_capacity(ruleset[0].aggr.len()); + let mut seen_keys = 0usize; + let mut seen_aggrs = 0usize; + for aggr in ruleset[0].aggr.iter() { + if aggr.is_some() { + inv_indices.push((true, seen_aggrs)); + seen_aggrs += 1; + } else { + inv_indices.push((false, seen_keys)); + seen_keys += 1; + } + } + + for (keys, aggrs) in aggr_work { + let tuple_data: Vec<_> = inv_indices + .iter() + .map(|(is_aggr, idx)| { + if *is_aggr { + aggrs[*idx].normal_op.as_ref().unwrap().get() + } else { + Ok(keys[*idx].clone()) + } + }) + .try_collect()?; + let 
tuple = Tuple(tuple_data); if should_check_limit { - Some(limiter) + if !store.exists(&tuple, 0) { + store.put_with_skip(tuple, limiter.should_skip_next()); + if limiter.incr_and_should_stop() { + return Ok(true); + } + } + // else, do nothing } else { - None - }, - poison, - )? { - return Ok(true); + store.put(tuple, 0); + } } } } diff --git a/cozo-core/src/runtime/in_mem.rs b/cozo-core/src/runtime/in_mem.rs index a2e87c0d..4326fae4 100644 --- a/cozo-core/src/runtime/in_mem.rs +++ b/cozo-core/src/runtime/in_mem.rs @@ -6,17 +6,15 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ -use std::borrow::{Borrow, BorrowMut}; +use std::borrow::Borrow; use std::cell::RefCell; use std::collections::BTreeMap; use std::fmt::{Debug, Formatter}; -use std::iter; use std::ops::Bound::Included; use std::rc::Rc; use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; -use either::{Left, Right}; use itertools::Itertools; use miette::Result; @@ -24,8 +22,6 @@ use crate::data::aggr::Aggregation; use crate::data::program::MagicSymbol; use crate::data::tuple::Tuple; use crate::data::value::DataValue; -use crate::query::eval::QueryLimiter; -use crate::runtime::db::Poison; #[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] pub(crate) struct StoredRelationId(pub(crate) u32); @@ -38,8 +34,11 @@ impl Debug for StoredRelationId { #[derive(Clone)] pub(crate) struct InMemRelation { - mem_db: Rc>>>>>, + mem_db: Rc>>>>>, epoch_size: Arc, + // total: Rc>>, + // current: Rc>>, + // prev: Rc>>, pub(crate) id: StoredRelationId, pub(crate) rule_name: MagicSymbol, pub(crate) arity: usize, @@ -90,7 +89,9 @@ impl InMemRelation { let mem_db: &RefCell<_> = self.mem_db.borrow(); let zero_map = mem_db.borrow().get(0).unwrap().clone(); - let mut zero_target = zero_map.try_write().unwrap(); + let zero_target: &RefCell> = zero_map.borrow(); + let mut zero_target = zero_target.borrow_mut(); + let key = Tuple( aggrs .iter() @@ -116,7 +117,8 @@ impl 
InMemRelation { } if changed && epoch != 0 { let epoch_map = mem_db.borrow().get(epoch as usize).unwrap().clone(); - let mut epoch_map = epoch_map.try_write().unwrap(); + let epoch_map: &RefCell> = epoch_map.borrow(); + let mut epoch_map = epoch_map.borrow_mut(); epoch_map.insert(key, prev_aggr.clone()); } Ok(changed) @@ -137,7 +139,8 @@ impl InMemRelation { zero_target.insert(key.clone(), tuple_to_store.clone()); if epoch != 0 { let epoch_map = mem_db.borrow().get(epoch as usize).unwrap().clone(); - let mut epoch_map = epoch_map.try_write().unwrap(); + let epoch_map: &RefCell> = epoch_map.borrow(); + let mut epoch_map = epoch_map.borrow_mut(); epoch_map.insert(key, tuple_to_store); } Ok(true) @@ -147,152 +150,31 @@ impl InMemRelation { self.ensure_mem_db_for_epoch(epoch); let mem_db: &RefCell<_> = self.mem_db.borrow(); let epoch_map = mem_db.borrow().get(epoch as usize).unwrap().clone(); - let mut epoch_map = epoch_map.try_write().unwrap(); + let epoch_map: &RefCell> = epoch_map.borrow(); + let mut epoch_map = epoch_map.borrow_mut(); epoch_map.insert(tuple, Tuple::default()); } pub(crate) fn put_with_skip(&self, tuple: Tuple, should_skip: bool) { self.ensure_mem_db_for_epoch(0); let mem_db: &RefCell<_> = self.mem_db.borrow(); let epoch_map = mem_db.borrow().get(0).unwrap().clone(); - let mut epoch_map = epoch_map.try_write().unwrap(); + let epoch_map: &RefCell> = epoch_map.borrow(); + let mut epoch_map = epoch_map.borrow_mut(); + if should_skip { epoch_map.insert(tuple, Tuple(vec![DataValue::Guard])); } else { epoch_map.insert(tuple, Tuple::default()); } } - pub(crate) fn normal_aggr_put( - &self, - tuple: &Tuple, - aggrs: &[Option<(Aggregation, Vec)>], - serial: usize, - ) { - self.ensure_mem_db_for_epoch(0); - let mut vals = vec![]; - for (idx, agg) in aggrs.iter().enumerate() { - if agg.is_none() { - vals.push(tuple.0[idx].clone()); - } - } - for (idx, agg) in aggrs.iter().enumerate() { - if agg.is_some() { - vals.push(tuple.0[idx].clone()); - } - } - 
vals.push(DataValue::from(serial as i64)); - - let mem_db: &RefCell<_> = self.mem_db.borrow(); - let epoch_map = mem_db.borrow().get(0).unwrap().clone(); - let mut epoch_map = epoch_map.try_write().unwrap(); - epoch_map.insert(Tuple(vals), Tuple::default()); - } pub(crate) fn exists(&self, tuple: &Tuple, epoch: u32) -> bool { self.ensure_mem_db_for_epoch(epoch); let mem_db: &RefCell<_> = self.mem_db.borrow(); let epoch_map = mem_db.borrow().get(epoch as usize).unwrap().clone(); - let epoch_map = epoch_map.try_read().unwrap(); - epoch_map.contains_key(tuple) - } + let epoch_map: &RefCell> = epoch_map.borrow(); + let epoch_map = epoch_map.borrow(); - pub(crate) fn normal_aggr_scan_and_put( - &self, - aggrs: &[Option<(Aggregation, Vec)>], - store: &InMemRelation, - mut limiter: Option<&mut QueryLimiter>, - poison: Poison, - ) -> Result { - let mem_db: &RefCell<_> = self.mem_db.borrow(); - let epoch_map = mem_db.borrow(); - let epoch_map = epoch_map.get(0); - let it = match epoch_map { - None => Left(iter::empty()), - Some(target) => { - let target = target.try_read().unwrap(); - Right(target.clone().into_iter().map(|(k, v)| { - if v.0.is_empty() { - k - } else { - let combined = - k.0.into_iter() - .zip(v.0.into_iter()) - .map(|(kel, vel)| { - if matches!(kel, DataValue::Guard) { - vel - } else { - kel - } - }) - .collect_vec(); - Tuple(combined) - } - })) - } - }; - - let mut aggrs = aggrs.to_vec(); - let n_keys = aggrs.iter().filter(|aggr| aggr.is_none()).count(); - let grouped = it.group_by(move |tuple| tuple.0[..n_keys].to_vec()); - let mut invert_indices = vec![]; - for (idx, aggr) in aggrs.iter().enumerate() { - if aggr.is_none() { - invert_indices.push(idx); - } - } - for (idx, aggr) in aggrs.iter().enumerate() { - if aggr.is_some() { - invert_indices.push(idx); - } - } - let invert_indices = invert_indices - .into_iter() - .enumerate() - .sorted_by_key(|(_a, b)| *b) - .map(|(a, _b)| a) - .collect_vec(); - for (_key, mut group_iter) in grouped.into_iter() { - 
for (aggr, args) in aggrs.iter_mut().flatten() { - aggr.normal_init(args)?; - } - let mut aggr_res = vec![DataValue::Guard; aggrs.len()]; - let first_tuple = group_iter.next().unwrap(); - for (idx, aggr) in aggrs.iter_mut().enumerate() { - let val = &first_tuple.0[invert_indices[idx]]; - if let Some((aggr_op, _aggr_args)) = aggr { - let op = aggr_op.normal_op.as_mut().unwrap(); - op.set(val)?; - } else { - aggr_res[idx] = first_tuple.0[invert_indices[idx]].clone(); - } - } - for tuple in group_iter { - for (idx, aggr) in aggrs.iter_mut().enumerate() { - let val = &tuple.0[invert_indices[idx]]; - if let Some((aggr_op, _aggr_args)) = aggr { - let op = aggr_op.normal_op.as_mut().unwrap(); - op.set(val)?; - } - } - } - poison.check()?; - for (i, aggr) in aggrs.iter().enumerate() { - if let Some((aggr_op, _aggr_args)) = aggr { - let op = aggr_op.normal_op.as_ref().unwrap(); - aggr_res[i] = op.get()?; - } - } - let res_tpl = Tuple(aggr_res); - if let Some(lmt) = limiter.borrow_mut() { - if !store.exists(&res_tpl, 0) { - store.put_with_skip(res_tpl, lmt.should_skip_next()); - if lmt.incr_and_should_stop() { - return Ok(true); - } - } - } else { - store.put(res_tpl, 0); - } - } - Ok(false) + epoch_map.contains_key(tuple) } pub(crate) fn scan_all_for_epoch<'a>( @@ -302,7 +184,8 @@ impl InMemRelation { self.ensure_mem_db_for_epoch(epoch); let mem_db: &RefCell<_> = self.mem_db.borrow(); let epoch_map = mem_db.borrow().get(epoch as usize).unwrap().clone(); - let epoch_map = epoch_map.try_read().unwrap(); + let epoch_map: &RefCell> = epoch_map.borrow(); + let epoch_map = epoch_map.borrow(); epoch_map.clone().into_iter().map(|(k, v)| { if v.0.is_empty() { @@ -330,7 +213,8 @@ impl InMemRelation { self.ensure_mem_db_for_epoch(0); let mem_db: &RefCell<_> = self.mem_db.borrow(); let epoch_map = mem_db.borrow().get(0).unwrap().clone(); - let epoch_map = epoch_map.try_read().unwrap(); + let epoch_map: &RefCell> = epoch_map.borrow(); + let epoch_map = epoch_map.borrow(); 
epoch_map.clone().into_iter().filter_map(|(k, v)| { if v.0.is_empty() { @@ -367,7 +251,9 @@ impl InMemRelation { self.ensure_mem_db_for_epoch(epoch); let mem_db: &RefCell<_> = self.mem_db.borrow(); let epoch_map = mem_db.borrow().get(epoch as usize).unwrap().clone(); - let epoch_map = epoch_map.try_read().unwrap(); + let epoch_map: &RefCell> = epoch_map.borrow(); + let epoch_map = epoch_map.borrow(); + let res = epoch_map .range((Included(prefix), Included(&upper))) .map(|(k, v)| { @@ -406,7 +292,8 @@ impl InMemRelation { let mem_db: &RefCell<_> = self.mem_db.borrow(); let epoch_map = mem_db.borrow().get(epoch as usize).unwrap().clone(); - let epoch_map = epoch_map.try_read().unwrap(); + let epoch_map: &RefCell> = epoch_map.borrow(); + let epoch_map = epoch_map.borrow(); let res = epoch_map .range((Included(&prefix_bound), Included(&upper_bound)))