API for counting

main
Ziyang Hu 1 year ago
parent e983b72b40
commit 4f56ebe505

@ -40,23 +40,27 @@ impl<'a> SessionTx<'a> {
}
};
let mut token_stream = tokenizer.token_stream(&to_index);
let mut collector: HashMap<_, (Vec<_>, Vec<_>), _> = FxHashMap::default();
let mut collector: HashMap<_, (Vec<_>, Vec<_>, Vec<_>), _> = FxHashMap::default();
let mut count = 0i64;
while let Some(token) = token_stream.next() {
let text = SmartString::<LazyCompact>::from(&token.text);
let (fr, to) = collector.entry(text).or_default();
let (fr, to, position) = collector.entry(text).or_default();
fr.push(DataValue::from(token.offset_from as i64));
to.push(DataValue::from(token.offset_to as i64));
position.push(DataValue::from(token.position as i64));
count += 1;
}
let mut key = Vec::with_capacity(1 + rel_handle.metadata.keys.len());
key.push(DataValue::Bot);
for k in &tuple[..rel_handle.metadata.keys.len()] {
key.push(k.clone());
}
let mut val = vec![DataValue::Bot, DataValue::Bot];
for (text, (from, to)) in collector {
let mut val = vec![DataValue::Bot, DataValue::Bot, DataValue::Bot, DataValue::from(count)];
for (text, (from, to, position)) in collector {
key[0] = DataValue::Str(text);
val[0] = DataValue::List(from);
val[1] = DataValue::List(to);
val[2] = DataValue::List(position);
let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?;
let val_bytes = idx_handle.encode_val_only_for_store(&val, Default::default())?;
self.store_tx.put(&key_bytes, &val_bytes)?;

@ -714,17 +714,35 @@ impl<'a> SessionTx<'a> {
});
}
let non_idx_keys: Vec<ColumnDef> = vec![
ColumnDef {
name: SmartString::from("offset_from"),
typing: NullableColType {
let col_type = NullableColType {
coltype: ColType::List {
eltype: Box::new(NullableColType {
coltype: ColType::Int,
nullable: false,
}),
len: None,
},
nullable: false,
};
let non_idx_keys: Vec<ColumnDef> = vec![
ColumnDef {
name: SmartString::from("offset_from"),
typing: col_type.clone(),
default_gen: None,
},
ColumnDef {
name: SmartString::from("offset_to"),
typing: col_type.clone(),
default_gen: None,
},
ColumnDef {
name: SmartString::from("position"),
typing: col_type.clone(),
default_gen: None,
},
ColumnDef {
name: SmartString::from("total_length"),
typing: NullableColType {
coltype: ColType::Int,
nullable: false,

@ -919,7 +919,8 @@ fn test_fts_indexing() {
db.run_script(
r"?[k, v] <- [
['b', 'the world is square!'],
['c', 'see you at the end of the world!']
['c', 'see you at the end of the world!'],
['d', 'the world is the world and makes the world go around']
] :put a {k => v}",
Default::default(),
)
@ -927,7 +928,8 @@ fn test_fts_indexing() {
let res = db
.run_script(
r"
?[word, src_k, offset_from, offset_to] := *a:fts{word, src_k, offset_from, offset_to}
?[word, src_k, offset_from, offset_to, position, total_length] :=
*a:fts{word, src_k, offset_from, offset_to, position, total_length}
",
Default::default(),
)

@ -254,6 +254,22 @@ impl<'s> StoreTx<'s> for MemTx<'s> {
}
}
fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize>
where
's: 'a,
{
Ok(match self {
MemTx::Reader(rdr) => rdr.range(lower.to_vec()..upper.to_vec()).count(),
MemTx::Writer(wtr, cache) => (CacheIterRaw {
change_iter: cache.range(lower.to_vec()..upper.to_vec()).fuse(),
db_iter: wtr.range(lower.to_vec()..upper.to_vec()).fuse(),
change_cache: None,
db_cache: None,
})
.count(),
})
}
fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
where
's: 'a,
@ -452,7 +468,8 @@ impl<'a> Iterator for SkipIterator<'a> {
match nxt {
None => return None,
Some((candidate_key, candidate_val)) => {
let (ret, nxt_bound) = check_key_for_validity(candidate_key, self.valid_at, self.size_hint);
let (ret, nxt_bound) =
check_key_for_validity(candidate_key, self.valid_at, self.size_hint);
self.next_bound = nxt_bound;
if let Some(mut nk) = ret {
extend_tuple_from_v(&mut nk, candidate_val);

@ -137,6 +137,11 @@ pub trait StoreTx<'s>: Sync {
where
's: 'a;
/// Return the number of rows in the range.
fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize>
where
's: 'a;
/// Scan for all rows. The rows are required to be in ascending order.
fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
where

@ -242,6 +242,20 @@ impl<'s> StoreTx<'s> for RocksDbTx {
})
}
fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize> where 's: 'a {
let mut inner = self.db_tx.iterator().upper_bound(upper).start();
inner.seek(lower);
let mut count = 0;
while let Some(k) = inner.key()? {
if k >= upper {
break;
}
count += 1;
inner.next();
}
Ok(count)
}
fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
where
's: 'a,

@ -250,6 +250,25 @@ impl<'s> StoreTx<'s> for SledTx {
}
}
fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize>
where
's: 'a,
{
Ok(if let Some(changes) = &self.changes {
let change_iter = changes.range(lower.to_vec()..upper.to_vec()).fuse();
let db_iter = self.db.range(lower.to_vec()..upper.to_vec()).fuse();
(SledIterRaw {
change_iter,
db_iter,
change_cache: None,
db_cache: None,
})
.count()
} else {
self.db.range(lower.to_vec()..upper.to_vec()).count()
})
}
fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
where
's: 'a,

@ -146,7 +146,7 @@ pub struct SqliteTx<'a> {
unsafe impl Sync for SqliteTx<'_> {}
const N_QUERIES: usize = 6;
const N_QUERIES: usize = 7;
const N_CACHED_QUERIES: usize = 4;
const QUERIES: [&str; N_QUERIES] = [
"select v from cozo where k = ?;",
@ -155,6 +155,7 @@ const QUERIES: [&str; N_QUERIES] = [
"select 1 from cozo where k = ?;",
"select k, v from cozo where k >= ? and k < ? order by k;",
"select k, v from cozo where k >= ? and k < ? order by k limit 1;",
"select count(*) from cozo where k >= ? and k < ?;"
];
const GET_QUERY: usize = 0;
@ -163,6 +164,7 @@ const DEL_QUERY: usize = 2;
const EXISTS_QUERY: usize = 3;
const RANGE_QUERY: usize = 4;
const SKIP_RANGE_QUERY: usize = 5;
const COUNT_RANGE_QUERY: usize = 6;
impl Drop for SqliteTx<'_> {
fn drop(&mut self) {
@ -319,6 +321,21 @@ impl<'s> StoreTx<'s> for SqliteTx<'s> {
Box::new(RawIter(statement))
}
fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize> where 's: 'a {
let query = QUERIES[COUNT_RANGE_QUERY];
let mut statement = self.conn.as_ref().unwrap().prepare(query).unwrap();
statement.bind((1, lower)).unwrap();
statement.bind((2, upper)).unwrap();
match statement.next() {
Ok(State::Done) => bail!("range count query returned no rows"),
Ok(State::Row) => {
let k = statement.read::<i64, _>(0).unwrap();
Ok(k as usize)
}
Err(err) => bail!(err),
}
}
fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
where
's: 'a,

@ -132,6 +132,10 @@ impl<'s> StoreTx<'s> for TempTx {
)
}
fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize> where 's: 'a {
Ok(self.store.range(lower.to_vec()..upper.to_vec()).count())
}
fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
where
's: 'a,

@ -191,6 +191,13 @@ impl<'s> StoreTx<'s> for TiKvTx {
Box::new(BatchScannerRaw::new(self.tx.clone(), lower, upper))
}
fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize>
where
's: 'a,
{
Ok(BatchScannerRaw::new(self.tx.clone(), lower, upper).count())
}
fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
where
's: 'a,

Loading…
Cancel
Save