fix limit/offset logic

main
Ziyang Hu 2 years ago
parent 5d69ff63eb
commit 3d8da482b8

2
Cargo.lock generated

@ -335,7 +335,7 @@ checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc"
[[package]]
name = "cozo"
version = "0.1.0"
version = "0.1.1"
dependencies = [
"approx",
"base64",

@ -1,6 +1,6 @@
[package]
name = "cozo"
version = "0.1.0"
version = "0.1.1"
edition = "2021"
description = "A general-purpose, transactional, relational database that uses Datalog and focuses on graph data"
authors = ["Ziyang Hu"]

@ -309,5 +309,6 @@ The rest of the query options are explained in the following.
The query returns nothing if the output relation contains at least one row;
otherwise, execution aborts with an error.
Implies ``:limit 1`` to ensure early termination if possible.
Useful for transactions and triggers.
Useful for transactions and triggers.
Consider adding ``:limit 1`` to the query to allow early termination
if you do not need to check all returned tuples.

@ -9,7 +9,7 @@ use log::{debug, trace};
use miette::Result;
use crate::data::program::{MagicAlgoApply, MagicSymbol, NoEntryError};
use crate::data::symb::{PROG_ENTRY, Symbol};
use crate::data::symb::{Symbol, PROG_ENTRY};
use crate::parse::SourceSpan;
use crate::query::compile::{AggrKind, CompiledProgram, CompiledRule, CompiledRuleSet};
use crate::runtime::db::Poison;
@ -17,19 +17,29 @@ use crate::runtime::in_mem::InMemRelation;
use crate::runtime::transact::SessionTx;
pub(crate) struct QueryLimiter {
limit: Option<usize>,
total: Option<usize>,
skip: Option<usize>,
counter: usize,
}
impl QueryLimiter {
pub(crate) fn incr(&mut self) -> bool {
if let Some(limit) = self.limit {
/// Reports whether this limiter carries any bound at all, i.e. whether
/// at least one of `total` and `skip` is set.
pub(crate) fn is_used(&self) -> bool {
    self.total.or(self.skip).is_some()
}
/// Records one more row and reports whether the `total` bound has been reached.
///
/// Returns `true` once `counter` is at least `total`. When `total` is `None`
/// the counter is left untouched and the answer is always `false`.
pub(crate) fn incr_and_should_stop(&mut self) -> bool {
    match self.total {
        None => false,
        Some(cap) => {
            self.counter += 1;
            self.counter >= cap
        }
    }
}
/// Reports whether the next row still falls inside the skipped prefix,
/// i.e. whether `skip` is set and strictly greater than the current `counter`.
pub(crate) fn should_skip_next(&self) -> bool {
    self.skip.map_or(false, |to_skip| to_skip > self.counter)
}
}
impl SessionTx {
@ -37,33 +47,42 @@ impl SessionTx {
&self,
strata: &[CompiledProgram],
stores: &BTreeMap<MagicSymbol, InMemRelation>,
num_to_take: Option<usize>,
total_num_to_take: Option<usize>,
num_to_skip: Option<usize>,
poison: Poison,
) -> Result<InMemRelation> {
) -> Result<(InMemRelation, bool)> {
let ret_area = stores
.get(&MagicSymbol::Muggle {
inner: Symbol::new(PROG_ENTRY, SourceSpan(0, 0)),
})
.ok_or(NoEntryError)?
.clone();
let mut early_return = false;
for (idx, cur_prog) in strata.iter().enumerate() {
debug!("stratum {}", idx);
self.semi_naive_magic_evaluate(cur_prog, stores, num_to_take, poison.clone())?;
early_return = self.semi_naive_magic_evaluate(
cur_prog,
stores,
total_num_to_take,
num_to_skip,
poison.clone(),
)?;
}
Ok(ret_area)
Ok((ret_area, early_return))
}
fn semi_naive_magic_evaluate(
&self,
prog: &CompiledProgram,
stores: &BTreeMap<MagicSymbol, InMemRelation>,
num_to_take: Option<usize>,
total_num_to_take: Option<usize>,
num_to_skip: Option<usize>,
poison: Poison,
) -> Result<()> {
) -> Result<bool> {
let mut changed: BTreeMap<_, _> = prog.keys().map(|k| (k, false)).collect();
let mut prev_changed = changed.clone();
let mut limiter = QueryLimiter {
limit: num_to_take,
total: total_num_to_take,
skip: num_to_skip,
counter: 0,
};
@ -74,7 +93,7 @@ impl SessionTx {
match compiled_ruleset {
CompiledRuleSet::Rules(ruleset) => {
let aggr_kind = compiled_ruleset.aggr_kind();
self.initial_rule_eval(
if self.initial_rule_eval(
k,
ruleset,
aggr_kind,
@ -82,7 +101,9 @@ impl SessionTx {
&mut changed,
&mut limiter,
poison.clone(),
)?;
)? {
return Ok(true);
};
}
CompiledRuleSet::Algo(algo_apply) => {
self.algo_application_eval(k, algo_apply, stores, poison.clone())?;
@ -103,7 +124,7 @@ impl SessionTx {
AggrKind::Normal => false,
AggrKind::Meet => true,
};
self.incremental_rule_eval(
if self.incremental_rule_eval(
k,
ruleset,
epoch,
@ -113,7 +134,9 @@ impl SessionTx {
&mut changed,
&mut limiter,
poison.clone(),
)?;
)? {
return Ok(true);
};
}
CompiledRuleSet::Algo(_) => unreachable!(),
@ -124,7 +147,7 @@ impl SessionTx {
break;
}
}
Ok(())
Ok(limiter.is_used())
}
fn algo_application_eval(
&self,
@ -146,10 +169,10 @@ impl SessionTx {
changed: &mut BTreeMap<&MagicSymbol, bool>,
limiter: &mut QueryLimiter,
poison: Poison,
) -> Result<()> {
) -> Result<bool> {
let store = stores.get(rule_symb).unwrap();
let use_delta = BTreeSet::default();
let should_check_limit = limiter.limit.is_some() && rule_symb.is_prog_entry();
let should_check_limit = limiter.total.is_some() && rule_symb.is_prog_entry();
match aggr_kind {
AggrKind::None | AggrKind::Meet => {
let is_meet = aggr_kind == AggrKind::Meet;
@ -166,10 +189,10 @@ impl SessionTx {
store.aggr_meet_put(&item, &mut aggr, 0)?;
} else if should_check_limit {
if !store.exists(&item, 0) {
store.put(item, 0);
if limiter.incr() {
store.put_with_skip(item, limiter.should_skip_next());
if limiter.incr_and_should_stop() {
trace!("early stopping due to result count limit exceeded");
return Ok(());
return Ok(true);
}
}
} else {
@ -188,7 +211,7 @@ impl SessionTx {
rule_symb, rule_n
);
for (serial, item_res) in
rule.relation.iter(self, Some(0), &use_delta)?.enumerate()
rule.relation.iter(self, Some(0), &use_delta)?.enumerate()
{
let item = item_res?;
trace!("item for {:?}.{}: {:?} at {}", rule_symb, rule_n, item, 0);
@ -207,11 +230,11 @@ impl SessionTx {
},
poison,
)? {
return Ok(());
return Ok(true);
}
}
}
Ok(())
Ok(false)
}
fn incremental_rule_eval(
&self,
@ -224,9 +247,9 @@ impl SessionTx {
changed: &mut BTreeMap<&MagicSymbol, bool>,
limiter: &mut QueryLimiter,
poison: Poison,
) -> Result<()> {
) -> Result<bool> {
let store = stores.get(rule_symb).unwrap();
let should_check_limit = limiter.limit.is_some() && rule_symb.is_prog_entry();
let should_check_limit = limiter.total.is_some() && rule_symb.is_prog_entry();
for (rule_n, rule) in ruleset.iter().enumerate() {
let mut should_do_calculation = false;
for d_rule in &rule.contained_rules {
@ -281,16 +304,16 @@ impl SessionTx {
);
*changed.get_mut(rule_symb).unwrap() = true;
store.put(item.clone(), epoch);
store.put(item, 0);
if should_check_limit && limiter.incr() {
store.put_with_skip(item, limiter.should_skip_next());
if should_check_limit && limiter.incr_and_should_stop() {
trace!("early stopping due to result count limit exceeded");
return Ok(());
return Ok(true);
}
}
poison.check()?;
}
}
}
Ok(())
Ok(false)
}
}

@ -534,7 +534,7 @@ impl Db {
running_queries: self.running_queries.clone(),
};
let result = tx.stratified_magic_evaluate(
let (result, early_return) = tx.stratified_magic_evaluate(
&compiled,
&stores,
if input_program.out_opts.sorters.is_empty() {
@ -542,6 +542,11 @@ impl Db {
} else {
None
},
if input_program.out_opts.sorters.is_empty() {
input_program.out_opts.offset
} else {
None
},
poison,
)?;
if let Some(assertion) = &input_program.out_opts.assertion {
@ -613,12 +618,15 @@ impl Db {
Ok((json!({ "rows": ret, "headers": json_headers }), clean_ups))
}
} else {
let scan = if input_program.out_opts.limit.is_some()
let scan = if early_return {
let limit = input_program.out_opts.limit.unwrap_or(usize::MAX);
Right(Left(result.scan_early_returned().take(limit)))
} else if input_program.out_opts.limit.is_some()
|| input_program.out_opts.offset.is_some()
{
let limit = input_program.out_opts.limit.unwrap_or(usize::MAX);
let offset = input_program.out_opts.offset.unwrap_or(0);
Right(result.scan_all().skip(offset).take(limit))
Right(Right(result.scan_all().skip(offset).take(limit)))
} else {
Left(result.scan_all())
};

@ -7,8 +7,8 @@ use std::collections::BTreeMap;
use std::fmt::{Debug, Formatter};
use std::iter;
use std::ops::Bound::Included;
use std::sync::{Arc, RwLock};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, RwLock};
use either::{Left, Right};
use itertools::Itertools;
@ -16,7 +16,7 @@ use miette::Result;
use crate::data::aggr::Aggregation;
use crate::data::program::MagicSymbol;
use crate::data::tuple::{Tuple};
use crate::data::tuple::Tuple;
use crate::data::value::DataValue;
use crate::query::eval::QueryLimiter;
use crate::runtime::db::Poison;
@ -46,11 +46,7 @@ impl Debug for InMemRelation {
}
impl InMemRelation {
pub(crate) fn new(
id: StoredRelationId,
rule_name: MagicSymbol,
arity: usize,
) -> InMemRelation {
pub(crate) fn new(id: StoredRelationId, rule_name: MagicSymbol, arity: usize) -> InMemRelation {
Self {
epoch_size: Default::default(),
mem_db: Default::default(),
@ -139,6 +135,16 @@ impl InMemRelation {
let mut target = db.get(epoch as usize).unwrap().try_write().unwrap();
target.insert(tuple, Tuple::default());
}
/// Inserts `tuple` into the epoch-0 store.
///
/// When `should_skip` is true the associated value is a one-element tuple
/// holding `DataValue::Guard`; otherwise it is the empty default tuple.
/// `scan_early_returned` drops entries whose value ends in `DataValue::Guard`.
pub(crate) fn put_with_skip(&self, tuple: Tuple, should_skip: bool) {
    self.ensure_mem_db_for_epoch(0);
    let epochs = self.mem_db.try_read().unwrap();
    let mut rows = epochs.get(0).unwrap().try_write().unwrap();
    let marker = if should_skip {
        Tuple(vec![DataValue::Guard])
    } else {
        Tuple::default()
    };
    rows.insert(tuple, marker);
}
pub(crate) fn normal_aggr_put(
&self,
tuple: &Tuple,
@ -180,9 +186,7 @@ impl InMemRelation {
let db_target = self.mem_db.try_read().unwrap();
let target = db_target.get(0);
let it = match target {
None => {
Left(iter::empty())
}
None => Left(iter::empty()),
Some(target) => {
let target = target.try_read().unwrap();
Right(target.clone().into_iter().map(|(k, v)| {
@ -260,8 +264,8 @@ impl InMemRelation {
let res_tpl = Tuple(aggr_res);
if let Some(lmt) = limiter.borrow_mut() {
if !store.exists(&res_tpl, 0) {
store.put(res_tpl, 0);
if lmt.incr() {
store.put_with_skip(res_tpl, lmt.should_skip_next());
if lmt.incr_and_should_stop() {
return Ok(true);
}
}
@ -272,7 +276,7 @@ impl InMemRelation {
Ok(false)
}
pub(crate) fn scan_all_for_epoch(&self, epoch: u32) -> impl Iterator<Item=Result<Tuple>> {
pub(crate) fn scan_all_for_epoch(&self, epoch: u32) -> impl Iterator<Item = Result<Tuple>> {
self.ensure_mem_db_for_epoch(epoch);
let db = self
.mem_db
@ -303,17 +307,50 @@ impl InMemRelation {
}
})
}
pub(crate) fn scan_all(&self) -> impl Iterator<Item=Result<Tuple>> {
pub(crate) fn scan_all(&self) -> impl Iterator<Item = Result<Tuple>> {
self.scan_all_for_epoch(0)
}
pub(crate) fn scan_prefix(&self, prefix: &Tuple) -> impl Iterator<Item=Result<Tuple>> {
/// Iterates over a snapshot of the epoch-0 store, omitting rows that were
/// marked as skipped.
///
/// For each `(key, value)` entry:
/// * an empty value yields the key unchanged;
/// * a value whose last element is `DataValue::Guard` is dropped;
/// * any other value is merged with the key position by position, taking the
///   value's element wherever the key holds `DataValue::Guard`.
pub(crate) fn scan_early_returned(&self) -> impl Iterator<Item = Result<Tuple>> {
    self.ensure_mem_db_for_epoch(0);
    // Copy the epoch-0 map out so the returned iterator owns its data and
    // holds no lock guards.
    let snapshot = {
        let epochs = self.mem_db.try_read().unwrap();
        let epoch_zero = epochs.get(0).unwrap().clone();
        let rows = epoch_zero.try_read().unwrap().clone();
        rows
    };
    snapshot
        .into_iter()
        .filter_map(|(key, val)| match val.0.last() {
            None => Some(Ok(key)),
            Some(tail) if tail == &DataValue::Guard => None,
            Some(_) => {
                let merged = key
                    .0
                    .iter()
                    .zip(val.0.iter())
                    .map(|(key_el, val_el)| {
                        let chosen = if matches!(key_el, DataValue::Guard) {
                            val_el
                        } else {
                            key_el
                        };
                        chosen.clone()
                    })
                    .collect::<Vec<_>>();
                Some(Ok(Tuple(merged)))
            }
        })
}
pub(crate) fn scan_prefix(&self, prefix: &Tuple) -> impl Iterator<Item = Result<Tuple>> {
self.scan_prefix_for_epoch(prefix, 0)
}
pub(crate) fn scan_prefix_for_epoch(
&self,
prefix: &Tuple,
epoch: u32,
) -> impl Iterator<Item=Result<Tuple>> {
) -> impl Iterator<Item = Result<Tuple>> {
let mut upper = prefix.0.clone();
upper.push(DataValue::Bot);
let upper = Tuple(upper);
@ -349,7 +386,7 @@ impl InMemRelation {
lower: &[DataValue],
upper: &[DataValue],
epoch: u32,
) -> impl Iterator<Item=Result<Tuple>> {
) -> impl Iterator<Item = Result<Tuple>> {
self.ensure_mem_db_for_epoch(epoch);
let mut prefix_bound = prefix.clone();
prefix_bound.0.extend_from_slice(lower);

@ -6,6 +6,7 @@ use std::str::FromStr;
use std::time::Instant;
use approx::AbsDiffEq;
use env_logger::Env;
use lazy_static::lazy_static;
use serde_json::json;
@ -1460,6 +1461,7 @@ fn longest_routes() {
#[test]
fn longest_routes_from_each_airports() {
env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();
check_db();
let longest_routes_from_each_airports = Instant::now();
@ -1851,3 +1853,67 @@ fn furthest_from_lhr() {
);
dbg!(furthest_from_lhr.elapsed());
}
/// Runs the same single-rule query with different `:limit` / `:offset`
/// options and checks the returned rows for each combination.
#[test]
fn skip_limit() {
    check_db();

    // Each case pairs a script with the exact rows it is expected to return.
    // The option-free script is listed twice, matching the sequence of runs
    // this test has always performed.
    let cases = [
        (
            "\n?[a] := a in [9, 9, 8, 9, 8, 7, 7, 6, 5, 9, 4, 4, 3]\n",
            json!([[3], [4], [5], [6], [7], [8], [9]]),
        ),
        (
            "\n?[a] := a in [9, 9, 8, 9, 8, 7, 7, 6, 5, 9, 4, 4, 3]\n",
            json!([[3], [4], [5], [6], [7], [8], [9]]),
        ),
        (
            "\n?[a] := a in [9, 9, 8, 9, 8, 7, 7, 6, 5, 9, 4, 4, 3]\n:limit 2\n",
            json!([[8], [9]]),
        ),
        (
            "\n?[a] := a in [9, 9, 8, 9, 8, 7, 7, 6, 5, 9, 4, 4, 3]\n:limit 2\n:offset 1\n",
            json!([[7], [8]]),
        ),
        (
            "\n?[a] := a in [9, 9, 8, 9, 8, 7, 7, 6, 5, 9, 4, 4, 3]\n:limit 100\n:offset 1\n",
            json!([[3], [4], [5], [6], [7], [8]]),
        ),
    ];

    for (script, expected) in &cases {
        let res = TEST_DB.run_script(*script, &Default::default()).unwrap();
        let rows = res.get("rows").unwrap();
        assert_eq!(rows, expected);
    }
}

Loading…
Cancel
Save