From aa7c852beefe5bcf4b196a0b132cf01612ed7136 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Mon, 5 Dec 2022 13:28:23 +0800 Subject: [PATCH] first layer rwlock -> refcell --- cozo-core/src/query/compile.rs | 2 - cozo-core/src/runtime/in_mem.rs | 124 +++++++++++++++++--------------- 2 files changed, 67 insertions(+), 59 deletions(-) diff --git a/cozo-core/src/query/compile.rs b/cozo-core/src/query/compile.rs index 9d0612a0..ce290851 100644 --- a/cozo-core/src/query/compile.rs +++ b/cozo-core/src/query/compile.rs @@ -34,8 +34,6 @@ pub(crate) enum CompiledRuleSet { Algo(MagicAlgoApply), } -unsafe impl Send for CompiledRuleSet {} - #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub(crate) enum AggrKind { None, diff --git a/cozo-core/src/runtime/in_mem.rs b/cozo-core/src/runtime/in_mem.rs index 41184a4f..a2e87c0d 100644 --- a/cozo-core/src/runtime/in_mem.rs +++ b/cozo-core/src/runtime/in_mem.rs @@ -6,11 +6,13 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ -use std::borrow::BorrowMut; +use std::borrow::{Borrow, BorrowMut}; +use std::cell::RefCell; use std::collections::BTreeMap; use std::fmt::{Debug, Formatter}; use std::iter; use std::ops::Bound::Included; +use std::rc::Rc; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, RwLock}; @@ -36,7 +38,7 @@ impl Debug for StoredRelationId { #[derive(Clone)] pub(crate) struct InMemRelation { - mem_db: Arc>>>>>, + mem_db: Rc>>>>>, epoch_size: Arc, pub(crate) id: StoredRelationId, pub(crate) rule_name: MagicSymbol, @@ -63,11 +65,14 @@ impl InMemRelation { if self.epoch_size.load(Ordering::Relaxed) > epoch { return; } - let l = self.mem_db.try_read().unwrap().len() as i32; + + let mem_db: &RefCell<_> = self.mem_db.borrow(); + + let l = mem_db.borrow().len() as i32; let want = (epoch + 1) as i32; let diff = want - l; if diff > 0 { - let mut db = self.mem_db.try_write().unwrap(); + let mut db = mem_db.borrow_mut(); for _ in 0..diff { db.push(Default::default()); } @@ -81,8 +86,11 @@ impl InMemRelation { epoch: u32, ) -> Result { self.ensure_mem_db_for_epoch(epoch); - let db_target = self.mem_db.try_read().unwrap(); - let mut zero_target = db_target.get(0).unwrap().try_write().unwrap(); + + let mem_db: &RefCell<_> = self.mem_db.borrow(); + let zero_map = mem_db.borrow().get(0).unwrap().clone(); + + let mut zero_target = zero_map.try_write().unwrap(); let key = Tuple( aggrs .iter() @@ -107,8 +115,9 @@ impl InMemRelation { } } if changed && epoch != 0 { - let mut epoch_target = db_target.get(epoch as usize).unwrap().try_write().unwrap(); - epoch_target.insert(key, prev_aggr.clone()); + let epoch_map = mem_db.borrow().get(epoch as usize).unwrap().clone(); + let mut epoch_map = epoch_map.try_write().unwrap(); + epoch_map.insert(key, prev_aggr.clone()); } Ok(changed) } else { @@ -127,26 +136,29 @@ impl InMemRelation { ); zero_target.insert(key.clone(), tuple_to_store.clone()); if epoch != 0 { - let mut zero = db_target.get(epoch as usize).unwrap().try_write().unwrap(); - zero.insert(key, tuple_to_store); + let epoch_map = mem_db.borrow().get(epoch as usize).unwrap().clone(); + let mut epoch_map = epoch_map.try_write().unwrap(); + epoch_map.insert(key, tuple_to_store); } Ok(true) } } pub(crate) fn put(&self, tuple: Tuple, epoch: u32) { self.ensure_mem_db_for_epoch(epoch); - let db = self.mem_db.try_read().unwrap(); - let mut target = db.get(epoch as usize).unwrap().try_write().unwrap(); - target.insert(tuple, Tuple::default()); + let mem_db: &RefCell<_> = self.mem_db.borrow(); + let epoch_map = mem_db.borrow().get(epoch as usize).unwrap().clone(); + let mut epoch_map = epoch_map.try_write().unwrap(); + epoch_map.insert(tuple, Tuple::default()); } pub(crate) fn put_with_skip(&self, tuple: Tuple, should_skip: bool) { self.ensure_mem_db_for_epoch(0); - let db = self.mem_db.try_read().unwrap(); - let mut target = db.get(0).unwrap().try_write().unwrap(); + let mem_db: &RefCell<_> = self.mem_db.borrow(); + let epoch_map = mem_db.borrow().get(0).unwrap().clone(); + let mut epoch_map = epoch_map.try_write().unwrap(); if should_skip { - target.insert(tuple, Tuple(vec![DataValue::Guard])); + epoch_map.insert(tuple, Tuple(vec![DataValue::Guard])); } else { - target.insert(tuple, Tuple::default()); + epoch_map.insert(tuple, Tuple::default()); } } pub(crate) fn normal_aggr_put( @@ -169,15 +181,17 @@ impl InMemRelation { } vals.push(DataValue::from(serial as i64)); - let target = self.mem_db.try_read().unwrap(); - let mut target = target.get(0).unwrap().try_write().unwrap(); - target.insert(Tuple(vals), Tuple::default()); + let mem_db: &RefCell<_> = self.mem_db.borrow(); + let epoch_map = mem_db.borrow().get(0).unwrap().clone(); + let mut epoch_map = epoch_map.try_write().unwrap(); + epoch_map.insert(Tuple(vals), Tuple::default()); } pub(crate) fn exists(&self, tuple: &Tuple, epoch: u32) -> bool { self.ensure_mem_db_for_epoch(epoch); - let target = self.mem_db.try_read().unwrap(); - let target = target.get(epoch as usize).unwrap().try_read().unwrap(); - target.contains_key(tuple) + let mem_db: &RefCell<_> = self.mem_db.borrow(); + let epoch_map = mem_db.borrow().get(epoch as usize).unwrap().clone(); + let epoch_map = epoch_map.try_read().unwrap(); + epoch_map.contains_key(tuple) } pub(crate) fn normal_aggr_scan_and_put( @@ -187,9 +201,10 @@ impl InMemRelation { mut limiter: Option<&mut QueryLimiter>, poison: Poison, ) -> Result { - let db_target = self.mem_db.try_read().unwrap(); - let target = db_target.get(0); - let it = match target { + let mem_db: &RefCell<_> = self.mem_db.borrow(); + let epoch_map = mem_db.borrow(); + let epoch_map = epoch_map.get(0); + let it = match epoch_map { None => Left(iter::empty()), Some(target) => { let target = target.try_read().unwrap(); @@ -280,19 +295,16 @@ impl InMemRelation { Ok(false) } - pub(crate) fn scan_all_for_epoch(&self, epoch: u32) -> impl Iterator> { + pub(crate) fn scan_all_for_epoch<'a>( + &'a self, + epoch: u32, + ) -> impl Iterator> + 'a { self.ensure_mem_db_for_epoch(epoch); - let db = self - .mem_db - .try_read() - .unwrap() - .get(epoch as usize) - .unwrap() - .clone() - .try_read() - .unwrap() - .clone(); - db.into_iter().map(|(k, v)| { + let mem_db: &RefCell<_> = self.mem_db.borrow(); + let epoch_map = mem_db.borrow().get(epoch as usize).unwrap().clone(); + let epoch_map = epoch_map.try_read().unwrap(); + + epoch_map.clone().into_iter().map(|(k, v)| { if v.0.is_empty() { Ok(k) } else { @@ -311,22 +323,16 @@ impl InMemRelation { } }) } - pub(crate) fn scan_all(&self) -> impl Iterator> { + pub(crate) fn scan_all<'a>(&'a self) -> impl Iterator> + 'a { self.scan_all_for_epoch(0) } - pub(crate) fn scan_early_returned(&self) -> impl Iterator> { + pub(crate) fn scan_early_returned<'a>(&'a self) -> impl Iterator> + 'a { self.ensure_mem_db_for_epoch(0); - let db = self - .mem_db - .try_read() - .unwrap() - .get(0) - .unwrap() - .clone() - .try_read() - .unwrap() - .clone(); - db.into_iter().filter_map(|(k, v)| { + let mem_db: &RefCell<_> = self.mem_db.borrow(); + let epoch_map = mem_db.borrow().get(0).unwrap().clone(); + let epoch_map = epoch_map.try_read().unwrap(); + + epoch_map.clone().into_iter().filter_map(|(k, v)| { if v.0.is_empty() { Some(Ok(k)) } else if v.0.last() == Some(&DataValue::Guard) { @@ -359,9 +365,10 @@ impl InMemRelation { upper.push(DataValue::Bot); let upper = Tuple(upper); self.ensure_mem_db_for_epoch(epoch); - let target = self.mem_db.try_read().unwrap(); - let target = target.get(epoch as usize).unwrap().try_read().unwrap(); - let res = target + let mem_db: &RefCell<_> = self.mem_db.borrow(); + let epoch_map = mem_db.borrow().get(epoch as usize).unwrap().clone(); + let epoch_map = epoch_map.try_read().unwrap(); + let res = epoch_map .range((Included(prefix), Included(&upper))) .map(|(k, v)| { if v.0.is_empty() { @@ -396,9 +403,12 @@ impl InMemRelation { prefix_bound.0.extend_from_slice(lower); let mut upper_bound = prefix.clone(); upper_bound.0.extend_from_slice(upper); - let target = self.mem_db.try_read().unwrap(); - let target = target.get(epoch as usize).unwrap().try_read().unwrap(); - let res = target + + let mem_db: &RefCell<_> = self.mem_db.borrow(); + let epoch_map = mem_db.borrow().get(epoch as usize).unwrap().clone(); + let epoch_map = epoch_map.try_read().unwrap(); + + let res = epoch_map .range((Included(&prefix_bound), Included(&upper_bound))) .map(|(k, _v)| Ok(k.clone())) .collect_vec();