From 4f56ebe5054359579a4fb985a23720ab0c198017 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Wed, 26 Apr 2023 01:22:00 +0800 Subject: [PATCH] API for counting --- cozo-core/src/fts/indexing.rs | 12 ++++++++---- cozo-core/src/runtime/relation.rs | 26 ++++++++++++++++++++++---- cozo-core/src/runtime/tests.rs | 6 ++++-- cozo-core/src/storage/mem.rs | 19 ++++++++++++++++++- cozo-core/src/storage/mod.rs | 5 +++++ cozo-core/src/storage/rocks.rs | 14 ++++++++++++++ cozo-core/src/storage/sled.rs | 19 +++++++++++++++++++ cozo-core/src/storage/sqlite.rs | 19 ++++++++++++++++++- cozo-core/src/storage/temp.rs | 4 ++++ cozo-core/src/storage/tikv.rs | 7 +++++++ 10 files changed, 119 insertions(+), 12 deletions(-) diff --git a/cozo-core/src/fts/indexing.rs b/cozo-core/src/fts/indexing.rs index a3464ee3..d077809a 100644 --- a/cozo-core/src/fts/indexing.rs +++ b/cozo-core/src/fts/indexing.rs @@ -40,23 +40,27 @@ impl<'a> SessionTx<'a> { } }; let mut token_stream = tokenizer.token_stream(&to_index); - let mut collector: HashMap<_, (Vec<_>, Vec<_>), _> = FxHashMap::default(); + let mut collector: HashMap<_, (Vec<_>, Vec<_>, Vec<_>), _> = FxHashMap::default(); + let mut count = 0i64; while let Some(token) = token_stream.next() { let text = SmartString::::from(&token.text); - let (fr, to) = collector.entry(text).or_default(); + let (fr, to, position) = collector.entry(text).or_default(); fr.push(DataValue::from(token.offset_from as i64)); to.push(DataValue::from(token.offset_to as i64)); + position.push(DataValue::from(token.position as i64)); + count += 1; } let mut key = Vec::with_capacity(1 + rel_handle.metadata.keys.len()); key.push(DataValue::Bot); for k in &tuple[..rel_handle.metadata.keys.len()] { key.push(k.clone()); } - let mut val = vec![DataValue::Bot, DataValue::Bot]; - for (text, (from, to)) in collector { + let mut val = vec![DataValue::Bot, DataValue::Bot, DataValue::Bot, DataValue::from(count)]; + for (text, (from, to, position)) in collector { key[0] = DataValue::Str(text); val[0] = DataValue::List(from); val[1] = DataValue::List(to); + val[2] = DataValue::List(position); let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?; let val_bytes = idx_handle.encode_val_only_for_store(&val, Default::default())?; self.store_tx.put(&key_bytes, &val_bytes)?; diff --git a/cozo-core/src/runtime/relation.rs b/cozo-core/src/runtime/relation.rs index c92986b4..131be264 100644 --- a/cozo-core/src/runtime/relation.rs +++ b/cozo-core/src/runtime/relation.rs @@ -714,17 +714,35 @@ impl<'a> SessionTx<'a> { }); } + let col_type = NullableColType { + coltype: ColType::List { + eltype: Box::new(NullableColType { + coltype: ColType::Int, + nullable: false, + }), + len: None, + }, + nullable: false, + }; + let non_idx_keys: Vec = vec![ ColumnDef { name: SmartString::from("offset_from"), - typing: NullableColType { - coltype: ColType::Int, - nullable: false, - }, + typing: col_type.clone(), default_gen: None, }, ColumnDef { name: SmartString::from("offset_to"), + typing: col_type.clone(), + default_gen: None, + }, + ColumnDef { + name: SmartString::from("position"), + typing: col_type.clone(), + default_gen: None, + }, + ColumnDef { + name: SmartString::from("total_length"), typing: NullableColType { coltype: ColType::Int, nullable: false, diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index 3ee961af..6fa52f18 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -919,7 +919,8 @@ fn test_fts_indexing() { db.run_script( r"?[k, v] <- [ ['b', 'the world is square!'], - ['c', 'see you at the end of the world!'] + ['c', 'see you at the end of the world!'], + ['d', 'the world is the world and makes the world go around'] ] :put a {k => v}", Default::default(), ) @@ -927,7 +928,8 @@ fn test_fts_indexing() { let res = db .run_script( r" - ?[word, src_k, offset_from, offset_to] := *a:fts{word, src_k, offset_from, offset_to} + ?[word, src_k, offset_from, offset_to, position, total_length] := + *a:fts{word, src_k, offset_from, offset_to, position, total_length} ", Default::default(), ) diff --git a/cozo-core/src/storage/mem.rs b/cozo-core/src/storage/mem.rs index afa905d2..a9dad332 100644 --- a/cozo-core/src/storage/mem.rs +++ b/cozo-core/src/storage/mem.rs @@ -254,6 +254,22 @@ impl<'s> StoreTx<'s> for MemTx<'s> { } } + fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result + where + 's: 'a, + { + Ok(match self { + MemTx::Reader(rdr) => rdr.range(lower.to_vec()..upper.to_vec()).count(), + MemTx::Writer(wtr, cache) => (CacheIterRaw { + change_iter: cache.range(lower.to_vec()..upper.to_vec()).fuse(), + db_iter: wtr.range(lower.to_vec()..upper.to_vec()).fuse(), + change_cache: None, + db_cache: None, + }) + .count(), + }) + } + fn total_scan<'a>(&'a self) -> Box, Vec)>> + 'a> where 's: 'a, @@ -452,7 +468,8 @@ impl<'a> Iterator for SkipIterator<'a> { match nxt { None => return None, Some((candidate_key, candidate_val)) => { - let (ret, nxt_bound) = check_key_for_validity(candidate_key, self.valid_at, self.size_hint); + let (ret, nxt_bound) = + check_key_for_validity(candidate_key, self.valid_at, self.size_hint); self.next_bound = nxt_bound; if let Some(mut nk) = ret { extend_tuple_from_v(&mut nk, candidate_val); diff --git a/cozo-core/src/storage/mod.rs b/cozo-core/src/storage/mod.rs index c5ff7bc9..1840eda0 100644 --- a/cozo-core/src/storage/mod.rs +++ b/cozo-core/src/storage/mod.rs @@ -137,6 +137,11 @@ pub trait StoreTx<'s>: Sync { where 's: 'a; + /// Return the number of rows in the range. + fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result + where + 's: 'a; + /// Scan for all rows. The rows are required to be in ascending order. fn total_scan<'a>(&'a self) -> Box, Vec)>> + 'a> where diff --git a/cozo-core/src/storage/rocks.rs b/cozo-core/src/storage/rocks.rs index a885534a..4c0ca9ec 100644 --- a/cozo-core/src/storage/rocks.rs +++ b/cozo-core/src/storage/rocks.rs @@ -242,6 +242,20 @@ impl<'s> StoreTx<'s> for RocksDbTx { }) } + fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result where 's: 'a { + let mut inner = self.db_tx.iterator().upper_bound(upper).start(); + inner.seek(lower); + let mut count = 0; + while let Some(k) = inner.key()? { + if k >= upper { + break; + } + count += 1; + inner.next(); + } + Ok(count) + } + fn total_scan<'a>(&'a self) -> Box, Vec)>> + 'a> where 's: 'a, diff --git a/cozo-core/src/storage/sled.rs b/cozo-core/src/storage/sled.rs index dba2de9f..37b882db 100644 --- a/cozo-core/src/storage/sled.rs +++ b/cozo-core/src/storage/sled.rs @@ -250,6 +250,25 @@ impl<'s> StoreTx<'s> for SledTx { } } + fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result + where + 's: 'a, + { + Ok(if let Some(changes) = &self.changes { + let change_iter = changes.range(lower.to_vec()..upper.to_vec()).fuse(); + let db_iter = self.db.range(lower.to_vec()..upper.to_vec()).fuse(); + (SledIterRaw { + change_iter, + db_iter, + change_cache: None, + db_cache: None, + }) + .count() + } else { + self.db.range(lower.to_vec()..upper.to_vec()).count() + }) + } + fn total_scan<'a>(&'a self) -> Box, Vec)>> + 'a> where 's: 'a, diff --git a/cozo-core/src/storage/sqlite.rs b/cozo-core/src/storage/sqlite.rs index 5e4e2fb3..9daa7348 100644 --- a/cozo-core/src/storage/sqlite.rs +++ b/cozo-core/src/storage/sqlite.rs @@ -146,7 +146,7 @@ pub struct SqliteTx<'a> { unsafe impl Sync for SqliteTx<'_> {} -const N_QUERIES: usize = 6; +const N_QUERIES: usize = 7; const N_CACHED_QUERIES: usize = 4; const QUERIES: [&str; N_QUERIES] = [ "select v from cozo where k = ?;", @@ -155,6 +155,7 @@ const QUERIES: [&str; N_QUERIES] = [ "select 1 from cozo where k = ?;", "select k, v from cozo where k >= ? and k < ? order by k;", "select k, v from cozo where k >= ? and k < ? order by k limit 1;", + "select count(*) from cozo where k >= ? and k < ?;" ]; const GET_QUERY: usize = 0; @@ -163,6 +164,7 @@ const DEL_QUERY: usize = 2; const EXISTS_QUERY: usize = 3; const RANGE_QUERY: usize = 4; const SKIP_RANGE_QUERY: usize = 5; +const COUNT_RANGE_QUERY: usize = 6; impl Drop for SqliteTx<'_> { fn drop(&mut self) { @@ -319,6 +321,21 @@ impl<'s> StoreTx<'s> for SqliteTx<'s> { Box::new(RawIter(statement)) } + fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result where 's: 'a { + let query = QUERIES[COUNT_RANGE_QUERY]; + let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap(); + statement.bind((1, lower)).unwrap(); + statement.bind((2, upper)).unwrap(); + match statement.next() { + Ok(State::Done) => bail!("range count query returned no rows"), + Ok(State::Row) => { + let k = statement.read::(0).unwrap(); + Ok(k as usize) + } + Err(err) => bail!(err), + } + } + fn total_scan<'a>(&'a self) -> Box, Vec)>> + 'a> where 's: 'a, diff --git a/cozo-core/src/storage/temp.rs b/cozo-core/src/storage/temp.rs index 0ce26f1b..15a5a001 100644 --- a/cozo-core/src/storage/temp.rs +++ b/cozo-core/src/storage/temp.rs @@ -132,6 +132,10 @@ impl<'s> StoreTx<'s> for TempTx { ) } + fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result where 's: 'a { + Ok(self.store.range(lower.to_vec()..upper.to_vec()).count()) + } + fn total_scan<'a>(&'a self) -> Box, Vec)>> + 'a> where 's: 'a, diff --git a/cozo-core/src/storage/tikv.rs b/cozo-core/src/storage/tikv.rs index 30ca0012..d65e8ed8 100644 --- a/cozo-core/src/storage/tikv.rs +++ b/cozo-core/src/storage/tikv.rs @@ -191,6 +191,13 @@ impl<'s> StoreTx<'s> for TiKvTx { Box::new(BatchScannerRaw::new(self.tx.clone(), lower, upper)) } + fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result + where + 's: 'a, + { + Ok(BatchScannerRaw::new(self.tx.clone(), lower, upper).count()) + } + fn total_scan<'a>(&'a self) -> Box, Vec)>> + 'a> where 's: 'a,