API for counting

main
Ziyang Hu 1 year ago
parent e983b72b40
commit 4f56ebe505

@ -40,23 +40,27 @@ impl<'a> SessionTx<'a> {
} }
}; };
let mut token_stream = tokenizer.token_stream(&to_index); let mut token_stream = tokenizer.token_stream(&to_index);
let mut collector: HashMap<_, (Vec<_>, Vec<_>), _> = FxHashMap::default(); let mut collector: HashMap<_, (Vec<_>, Vec<_>, Vec<_>), _> = FxHashMap::default();
let mut count = 0i64;
while let Some(token) = token_stream.next() { while let Some(token) = token_stream.next() {
let text = SmartString::<LazyCompact>::from(&token.text); let text = SmartString::<LazyCompact>::from(&token.text);
let (fr, to) = collector.entry(text).or_default(); let (fr, to, position) = collector.entry(text).or_default();
fr.push(DataValue::from(token.offset_from as i64)); fr.push(DataValue::from(token.offset_from as i64));
to.push(DataValue::from(token.offset_to as i64)); to.push(DataValue::from(token.offset_to as i64));
position.push(DataValue::from(token.position as i64));
count += 1;
} }
let mut key = Vec::with_capacity(1 + rel_handle.metadata.keys.len()); let mut key = Vec::with_capacity(1 + rel_handle.metadata.keys.len());
key.push(DataValue::Bot); key.push(DataValue::Bot);
for k in &tuple[..rel_handle.metadata.keys.len()] { for k in &tuple[..rel_handle.metadata.keys.len()] {
key.push(k.clone()); key.push(k.clone());
} }
let mut val = vec![DataValue::Bot, DataValue::Bot]; let mut val = vec![DataValue::Bot, DataValue::Bot, DataValue::Bot, DataValue::from(count)];
for (text, (from, to)) in collector { for (text, (from, to, position)) in collector {
key[0] = DataValue::Str(text); key[0] = DataValue::Str(text);
val[0] = DataValue::List(from); val[0] = DataValue::List(from);
val[1] = DataValue::List(to); val[1] = DataValue::List(to);
val[2] = DataValue::List(position);
let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?; let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?;
let val_bytes = idx_handle.encode_val_only_for_store(&val, Default::default())?; let val_bytes = idx_handle.encode_val_only_for_store(&val, Default::default())?;
self.store_tx.put(&key_bytes, &val_bytes)?; self.store_tx.put(&key_bytes, &val_bytes)?;

@ -714,17 +714,35 @@ impl<'a> SessionTx<'a> {
}); });
} }
let col_type = NullableColType {
coltype: ColType::List {
eltype: Box::new(NullableColType {
coltype: ColType::Int,
nullable: false,
}),
len: None,
},
nullable: false,
};
let non_idx_keys: Vec<ColumnDef> = vec![ let non_idx_keys: Vec<ColumnDef> = vec![
ColumnDef { ColumnDef {
name: SmartString::from("offset_from"), name: SmartString::from("offset_from"),
typing: NullableColType { typing: col_type.clone(),
coltype: ColType::Int,
nullable: false,
},
default_gen: None, default_gen: None,
}, },
ColumnDef { ColumnDef {
name: SmartString::from("offset_to"), name: SmartString::from("offset_to"),
typing: col_type.clone(),
default_gen: None,
},
ColumnDef {
name: SmartString::from("position"),
typing: col_type.clone(),
default_gen: None,
},
ColumnDef {
name: SmartString::from("total_length"),
typing: NullableColType { typing: NullableColType {
coltype: ColType::Int, coltype: ColType::Int,
nullable: false, nullable: false,

@ -919,7 +919,8 @@ fn test_fts_indexing() {
db.run_script( db.run_script(
r"?[k, v] <- [ r"?[k, v] <- [
['b', 'the world is square!'], ['b', 'the world is square!'],
['c', 'see you at the end of the world!'] ['c', 'see you at the end of the world!'],
['d', 'the world is the world and makes the world go around']
] :put a {k => v}", ] :put a {k => v}",
Default::default(), Default::default(),
) )
@ -927,7 +928,8 @@ fn test_fts_indexing() {
let res = db let res = db
.run_script( .run_script(
r" r"
?[word, src_k, offset_from, offset_to] := *a:fts{word, src_k, offset_from, offset_to} ?[word, src_k, offset_from, offset_to, position, total_length] :=
*a:fts{word, src_k, offset_from, offset_to, position, total_length}
", ",
Default::default(), Default::default(),
) )

@ -254,6 +254,22 @@ impl<'s> StoreTx<'s> for MemTx<'s> {
} }
} }
fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize>
where
's: 'a,
{
Ok(match self {
MemTx::Reader(rdr) => rdr.range(lower.to_vec()..upper.to_vec()).count(),
MemTx::Writer(wtr, cache) => (CacheIterRaw {
change_iter: cache.range(lower.to_vec()..upper.to_vec()).fuse(),
db_iter: wtr.range(lower.to_vec()..upper.to_vec()).fuse(),
change_cache: None,
db_cache: None,
})
.count(),
})
}
fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a> fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
where where
's: 'a, 's: 'a,
@ -452,7 +468,8 @@ impl<'a> Iterator for SkipIterator<'a> {
match nxt { match nxt {
None => return None, None => return None,
Some((candidate_key, candidate_val)) => { Some((candidate_key, candidate_val)) => {
let (ret, nxt_bound) = check_key_for_validity(candidate_key, self.valid_at, self.size_hint); let (ret, nxt_bound) =
check_key_for_validity(candidate_key, self.valid_at, self.size_hint);
self.next_bound = nxt_bound; self.next_bound = nxt_bound;
if let Some(mut nk) = ret { if let Some(mut nk) = ret {
extend_tuple_from_v(&mut nk, candidate_val); extend_tuple_from_v(&mut nk, candidate_val);

@ -137,6 +137,11 @@ pub trait StoreTx<'s>: Sync {
where where
's: 'a; 's: 'a;
/// Return the number of rows in the range.
fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize>
where
's: 'a;
/// Scan for all rows. The rows are required to be in ascending order. /// Scan for all rows. The rows are required to be in ascending order.
fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a> fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
where where

@ -242,6 +242,20 @@ impl<'s> StoreTx<'s> for RocksDbTx {
}) })
} }
fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize> where 's: 'a {
let mut inner = self.db_tx.iterator().upper_bound(upper).start();
inner.seek(lower);
let mut count = 0;
while let Some(k) = inner.key()? {
if k >= upper {
break;
}
count += 1;
inner.next();
}
Ok(count)
}
fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a> fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
where where
's: 'a, 's: 'a,

@ -250,6 +250,25 @@ impl<'s> StoreTx<'s> for SledTx {
} }
} }
fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize>
where
's: 'a,
{
Ok(if let Some(changes) = &self.changes {
let change_iter = changes.range(lower.to_vec()..upper.to_vec()).fuse();
let db_iter = self.db.range(lower.to_vec()..upper.to_vec()).fuse();
(SledIterRaw {
change_iter,
db_iter,
change_cache: None,
db_cache: None,
})
.count()
} else {
self.db.range(lower.to_vec()..upper.to_vec()).count()
})
}
fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a> fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
where where
's: 'a, 's: 'a,

@ -146,7 +146,7 @@ pub struct SqliteTx<'a> {
unsafe impl Sync for SqliteTx<'_> {} unsafe impl Sync for SqliteTx<'_> {}
const N_QUERIES: usize = 6; const N_QUERIES: usize = 7;
const N_CACHED_QUERIES: usize = 4; const N_CACHED_QUERIES: usize = 4;
const QUERIES: [&str; N_QUERIES] = [ const QUERIES: [&str; N_QUERIES] = [
"select v from cozo where k = ?;", "select v from cozo where k = ?;",
@ -155,6 +155,7 @@ const QUERIES: [&str; N_QUERIES] = [
"select 1 from cozo where k = ?;", "select 1 from cozo where k = ?;",
"select k, v from cozo where k >= ? and k < ? order by k;", "select k, v from cozo where k >= ? and k < ? order by k;",
"select k, v from cozo where k >= ? and k < ? order by k limit 1;", "select k, v from cozo where k >= ? and k < ? order by k limit 1;",
"select count(*) from cozo where k >= ? and k < ?;"
]; ];
const GET_QUERY: usize = 0; const GET_QUERY: usize = 0;
@ -163,6 +164,7 @@ const DEL_QUERY: usize = 2;
const EXISTS_QUERY: usize = 3; const EXISTS_QUERY: usize = 3;
const RANGE_QUERY: usize = 4; const RANGE_QUERY: usize = 4;
const SKIP_RANGE_QUERY: usize = 5; const SKIP_RANGE_QUERY: usize = 5;
const COUNT_RANGE_QUERY: usize = 6;
impl Drop for SqliteTx<'_> { impl Drop for SqliteTx<'_> {
fn drop(&mut self) { fn drop(&mut self) {
@ -319,6 +321,21 @@ impl<'s> StoreTx<'s> for SqliteTx<'s> {
Box::new(RawIter(statement)) Box::new(RawIter(statement))
} }
fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize> where 's: 'a {
let query = QUERIES[COUNT_RANGE_QUERY];
let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap();
statement.bind((1, lower)).unwrap();
statement.bind((2, upper)).unwrap();
match statement.next() {
Ok(State::Done) => bail!("range count query returned no rows"),
Ok(State::Row) => {
let k = statement.read::<i64, _>(0).unwrap();
Ok(k as usize)
}
Err(err) => bail!(err),
}
}
fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a> fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
where where
's: 'a, 's: 'a,

@ -132,6 +132,10 @@ impl<'s> StoreTx<'s> for TempTx {
) )
} }
fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize> where 's: 'a {
Ok(self.store.range(lower.to_vec()..upper.to_vec()).count())
}
fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a> fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
where where
's: 'a, 's: 'a,

@ -191,6 +191,13 @@ impl<'s> StoreTx<'s> for TiKvTx {
Box::new(BatchScannerRaw::new(self.tx.clone(), lower, upper)) Box::new(BatchScannerRaw::new(self.tx.clone(), lower, upper))
} }
fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize>
where
's: 'a,
{
Ok(BatchScannerRaw::new(self.tx.clone(), lower, upper).count())
}
fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a> fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
where where
's: 'a, 's: 'a,

Loading…
Cancel
Save