diff --git a/Cargo.lock b/Cargo.lock index 3add23c7..4b406b28 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -335,7 +335,7 @@ checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" [[package]] name = "cozo" -version = "0.1.0" +version = "0.1.1" dependencies = [ "approx", "base64", diff --git a/Cargo.toml b/Cargo.toml index 0a1f02a3..757bba22 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "cozo" -version = "0.1.0" +version = "0.1.1" edition = "2021" description = "A general-purpose, transactional, relational database that uses Datalog and focuses on graph data" authors = ["Ziyang Hu"] diff --git a/docs/source/queries.rst b/docs/source/queries.rst index 1e8b7416..0bee5908 100644 --- a/docs/source/queries.rst +++ b/docs/source/queries.rst @@ -309,5 +309,6 @@ The rest of the query options are explained in the following. The query returns nothing if the output relation contains at least one row, otherwise, execution aborts with an error. - Implies ``:limit 1`` to ensure early termination if possible. - Useful for transactions and triggers. \ No newline at end of file + Useful for transactions and triggers. + You should consider adding ``:limit 1`` to the query to ensure early termination + if you do not need to check all return tuples. \ No newline at end of file diff --git a/src/query/eval.rs b/src/query/eval.rs index f17d8461..e71fba88 100644 --- a/src/query/eval.rs +++ b/src/query/eval.rs @@ -9,7 +9,7 @@ use log::{debug, trace}; use miette::Result; use crate::data::program::{MagicAlgoApply, MagicSymbol, NoEntryError}; -use crate::data::symb::{PROG_ENTRY, Symbol}; +use crate::data::symb::{Symbol, PROG_ENTRY}; use crate::parse::SourceSpan; use crate::query::compile::{AggrKind, CompiledProgram, CompiledRule, CompiledRuleSet}; use crate::runtime::db::Poison; @@ -17,19 +17,29 @@ use crate::runtime::in_mem::InMemRelation; use crate::runtime::transact::SessionTx; pub(crate) struct QueryLimiter { - limit: Option, + total: Option, + skip: Option, counter: usize, } impl QueryLimiter { - pub(crate) fn incr(&mut self) -> bool { - if let Some(limit) = self.limit { + pub(crate) fn is_used(&self) -> bool { + self.total.is_some() || self.skip.is_some() + } + pub(crate) fn incr_and_should_stop(&mut self) -> bool { + if let Some(limit) = self.total { self.counter += 1; self.counter >= limit } else { false } } + pub(crate) fn should_skip_next(&self) -> bool { + match self.skip { + None => false, + Some(i) => i > self.counter, + } + } } impl SessionTx { @@ -37,33 +47,42 @@ impl SessionTx { &self, strata: &[CompiledProgram], stores: &BTreeMap, - num_to_take: Option, + total_num_to_take: Option, + num_to_skip: Option, poison: Poison, - ) -> Result { + ) -> Result<(InMemRelation, bool)> { let ret_area = stores .get(&MagicSymbol::Muggle { inner: Symbol::new(PROG_ENTRY, SourceSpan(0, 0)), }) .ok_or(NoEntryError)? .clone(); - + let mut early_return = false; for (idx, cur_prog) in strata.iter().enumerate() { debug!("stratum {}", idx); - self.semi_naive_magic_evaluate(cur_prog, stores, num_to_take, poison.clone())?; + early_return = self.semi_naive_magic_evaluate( + cur_prog, + stores, + total_num_to_take, + num_to_skip, + poison.clone(), + )?; } - Ok(ret_area) + Ok((ret_area, early_return)) } fn semi_naive_magic_evaluate( &self, prog: &CompiledProgram, stores: &BTreeMap, - num_to_take: Option, + total_num_to_take: Option, + num_to_skip: Option, poison: Poison, - ) -> Result<()> { + ) -> Result { let mut changed: BTreeMap<_, _> = prog.keys().map(|k| (k, false)).collect(); let mut prev_changed = changed.clone(); let mut limiter = QueryLimiter { - limit: num_to_take, + total: total_num_to_take, + skip: num_to_skip, counter: 0, }; @@ -74,7 +93,7 @@ impl SessionTx { match compiled_ruleset { CompiledRuleSet::Rules(ruleset) => { let aggr_kind = compiled_ruleset.aggr_kind(); - self.initial_rule_eval( + if self.initial_rule_eval( k, ruleset, aggr_kind, @@ -82,7 +101,9 @@ impl SessionTx { &mut changed, &mut limiter, poison.clone(), - )?; + )? { + return Ok(true); + }; } CompiledRuleSet::Algo(algo_apply) => { self.algo_application_eval(k, algo_apply, stores, poison.clone())?; @@ -103,7 +124,7 @@ impl SessionTx { AggrKind::Normal => false, AggrKind::Meet => true, }; - self.incremental_rule_eval( + if self.incremental_rule_eval( k, ruleset, epoch, @@ -113,7 +134,9 @@ impl SessionTx { &mut changed, &mut limiter, poison.clone(), - )?; + )? { + return Ok(true); + }; } CompiledRuleSet::Algo(_) => unreachable!(), @@ -124,7 +147,7 @@ impl SessionTx { break; } } - Ok(()) + Ok(limiter.is_used()) } fn algo_application_eval( &self, @@ -146,10 +169,10 @@ impl SessionTx { changed: &mut BTreeMap<&MagicSymbol, bool>, limiter: &mut QueryLimiter, poison: Poison, - ) -> Result<()> { + ) -> Result { let store = stores.get(rule_symb).unwrap(); let use_delta = BTreeSet::default(); - let should_check_limit = limiter.limit.is_some() && rule_symb.is_prog_entry(); + let should_check_limit = limiter.total.is_some() && rule_symb.is_prog_entry(); match aggr_kind { AggrKind::None | AggrKind::Meet => { let is_meet = aggr_kind == AggrKind::Meet; @@ -166,10 +189,10 @@ impl SessionTx { store.aggr_meet_put(&item, &mut aggr, 0)?; } else if should_check_limit { if !store.exists(&item, 0) { - store.put(item, 0); - if limiter.incr() { + store.put_with_skip(item, limiter.should_skip_next()); + if limiter.incr_and_should_stop() { trace!("early stopping due to result count limit exceeded"); - return Ok(()); + return Ok(true); } } } else { @@ -188,7 +211,7 @@ impl SessionTx { rule_symb, rule_n ); for (serial, item_res) in - rule.relation.iter(self, Some(0), &use_delta)?.enumerate() + rule.relation.iter(self, Some(0), &use_delta)?.enumerate() { let item = item_res?; trace!("item for {:?}.{}: {:?} at {}", rule_symb, rule_n, item, 0); @@ -207,11 +230,11 @@ impl SessionTx { }, poison, )? { - return Ok(()); + return Ok(true); } } } - Ok(()) + Ok(false) } fn incremental_rule_eval( &self, @@ -224,9 +247,9 @@ impl SessionTx { changed: &mut BTreeMap<&MagicSymbol, bool>, limiter: &mut QueryLimiter, poison: Poison, - ) -> Result<()> { + ) -> Result { let store = stores.get(rule_symb).unwrap(); - let should_check_limit = limiter.limit.is_some() && rule_symb.is_prog_entry(); + let should_check_limit = limiter.total.is_some() && rule_symb.is_prog_entry(); for (rule_n, rule) in ruleset.iter().enumerate() { let mut should_do_calculation = false; for d_rule in &rule.contained_rules { @@ -281,16 +304,16 @@ impl SessionTx { ); *changed.get_mut(rule_symb).unwrap() = true; store.put(item.clone(), epoch); - store.put(item, 0); - if should_check_limit && limiter.incr() { + store.put_with_skip(item, limiter.should_skip_next()); + if should_check_limit && limiter.incr_and_should_stop() { trace!("early stopping due to result count limit exceeded"); - return Ok(()); + return Ok(true); } } poison.check()?; } } } - Ok(()) + Ok(false) } } diff --git a/src/runtime/db.rs b/src/runtime/db.rs index 3d4c9698..9e9ca559 100644 --- a/src/runtime/db.rs +++ b/src/runtime/db.rs @@ -534,7 +534,7 @@ impl Db { running_queries: self.running_queries.clone(), }; - let result = tx.stratified_magic_evaluate( + let (result, early_return) = tx.stratified_magic_evaluate( &compiled, &stores, if input_program.out_opts.sorters.is_empty() { @@ -542,6 +542,11 @@ impl Db { } else { None }, + if input_program.out_opts.sorters.is_empty() { + input_program.out_opts.offset + } else { + None + }, poison, )?; if let Some(assertion) = &input_program.out_opts.assertion { @@ -613,12 +618,15 @@ impl Db { Ok((json!({ "rows": ret, "headers": json_headers }), clean_ups)) } } else { - let scan = if input_program.out_opts.limit.is_some() + let scan = if early_return { + let limit = input_program.out_opts.limit.unwrap_or(usize::MAX); + Right(Left(result.scan_early_returned().take(limit))) + } else if input_program.out_opts.limit.is_some() || input_program.out_opts.offset.is_some() { let limit = input_program.out_opts.limit.unwrap_or(usize::MAX); let offset = input_program.out_opts.offset.unwrap_or(0); - Right(result.scan_all().skip(offset).take(limit)) + Right(Right(result.scan_all().skip(offset).take(limit))) } else { Left(result.scan_all()) }; diff --git a/src/runtime/in_mem.rs b/src/runtime/in_mem.rs index d466d243..926ac5bc 100644 --- a/src/runtime/in_mem.rs +++ b/src/runtime/in_mem.rs @@ -7,8 +7,8 @@ use std::collections::BTreeMap; use std::fmt::{Debug, Formatter}; use std::iter; use std::ops::Bound::Included; -use std::sync::{Arc, RwLock}; use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::{Arc, RwLock}; use either::{Left, Right}; use itertools::Itertools; @@ -16,7 +16,7 @@ use miette::Result; use crate::data::aggr::Aggregation; use crate::data::program::MagicSymbol; -use crate::data::tuple::{Tuple}; +use crate::data::tuple::Tuple; use crate::data::value::DataValue; use crate::query::eval::QueryLimiter; use crate::runtime::db::Poison; @@ -46,11 +46,7 @@ impl Debug for InMemRelation { } impl InMemRelation { - pub(crate) fn new( - id: StoredRelationId, - rule_name: MagicSymbol, - arity: usize, - ) -> InMemRelation { + pub(crate) fn new(id: StoredRelationId, rule_name: MagicSymbol, arity: usize) -> InMemRelation { Self { epoch_size: Default::default(), mem_db: Default::default(), @@ -139,6 +135,16 @@ impl InMemRelation { let mut target = db.get(epoch as usize).unwrap().try_write().unwrap(); target.insert(tuple, Tuple::default()); } + pub(crate) fn put_with_skip(&self, tuple: Tuple, should_skip: bool) { + self.ensure_mem_db_for_epoch(0); + let db = self.mem_db.try_read().unwrap(); + let mut target = db.get(0).unwrap().try_write().unwrap(); + if should_skip { + target.insert(tuple, Tuple(vec![DataValue::Guard])); + } else { + target.insert(tuple, Tuple::default()); + } + } pub(crate) fn normal_aggr_put( &self, tuple: &Tuple, @@ -180,9 +186,7 @@ impl InMemRelation { let db_target = self.mem_db.try_read().unwrap(); let target = db_target.get(0); let it = match target { - None => { - Left(iter::empty()) - } + None => Left(iter::empty()), Some(target) => { let target = target.try_read().unwrap(); Right(target.clone().into_iter().map(|(k, v)| { @@ -260,8 +264,8 @@ impl InMemRelation { let res_tpl = Tuple(aggr_res); if let Some(lmt) = limiter.borrow_mut() { if !store.exists(&res_tpl, 0) { - store.put(res_tpl, 0); - if lmt.incr() { + store.put_with_skip(res_tpl, lmt.should_skip_next()); + if lmt.incr_and_should_stop() { return Ok(true); } } @@ -272,7 +276,7 @@ impl InMemRelation { Ok(false) } - pub(crate) fn scan_all_for_epoch(&self, epoch: u32) -> impl Iterator> { + pub(crate) fn scan_all_for_epoch(&self, epoch: u32) -> impl Iterator> { self.ensure_mem_db_for_epoch(epoch); let db = self .mem_db @@ -303,17 +307,50 @@ impl InMemRelation { } }) } - pub(crate) fn scan_all(&self) -> impl Iterator> { + pub(crate) fn scan_all(&self) -> impl Iterator> { self.scan_all_for_epoch(0) } - pub(crate) fn scan_prefix(&self, prefix: &Tuple) -> impl Iterator> { + pub(crate) fn scan_early_returned(&self) -> impl Iterator> { + self.ensure_mem_db_for_epoch(0); + let db = self + .mem_db + .try_read() + .unwrap() + .get(0) + .unwrap() + .clone() + .try_read() + .unwrap() + .clone(); + db.into_iter().filter_map(|(k, v)| { + if v.0.is_empty() { + Some(Ok(k)) + } else if v.0.last() == Some(&DataValue::Guard) { + None + } else { + let combined = + k.0.iter() + .zip(v.0.iter()) + .map(|(kel, vel)| { + if matches!(kel, DataValue::Guard) { + vel.clone() + } else { + kel.clone() + } + }) + .collect_vec(); + Some(Ok(Tuple(combined))) + } + }) + } + pub(crate) fn scan_prefix(&self, prefix: &Tuple) -> impl Iterator> { self.scan_prefix_for_epoch(prefix, 0) } pub(crate) fn scan_prefix_for_epoch( &self, prefix: &Tuple, epoch: u32, - ) -> impl Iterator> { + ) -> impl Iterator> { let mut upper = prefix.0.clone(); upper.push(DataValue::Bot); let upper = Tuple(upper); @@ -349,7 +386,7 @@ impl InMemRelation { lower: &[DataValue], upper: &[DataValue], epoch: u32, - ) -> impl Iterator> { + ) -> impl Iterator> { self.ensure_mem_db_for_epoch(epoch); let mut prefix_bound = prefix.clone(); prefix_bound.0.extend_from_slice(lower); diff --git a/tests/air_routes.rs b/tests/air_routes.rs index 7520e515..5e45b78e 100644 --- a/tests/air_routes.rs +++ b/tests/air_routes.rs @@ -6,6 +6,7 @@ use std::str::FromStr; use std::time::Instant; use approx::AbsDiffEq; +use env_logger::Env; use lazy_static::lazy_static; use serde_json::json; @@ -1460,6 +1461,7 @@ fn longest_routes() { #[test] fn longest_routes_from_each_airports() { + env_logger::Builder::from_env(Env::default().default_filter_or("info")).init(); check_db(); let longest_routes_from_each_airports = Instant::now(); @@ -1851,3 +1853,67 @@ fn furthest_from_lhr() { ); dbg!(furthest_from_lhr.elapsed()); } + +#[test] +fn skip_limit() { + check_db(); + let res = TEST_DB + .run_script( + r#" + ?[a] := a in [9, 9, 8, 9, 8, 7, 7, 6, 5, 9, 4, 4, 3] + "#, + &Default::default(), + ) + .unwrap(); + let rows = res.get("rows").unwrap(); + assert_eq!(*rows, json!([[3], [4], [5], [6], [7], [8], [9]])); + + let res = TEST_DB + .run_script( + r#" + ?[a] := a in [9, 9, 8, 9, 8, 7, 7, 6, 5, 9, 4, 4, 3] + "#, + &Default::default(), + ) + .unwrap(); + let rows = res.get("rows").unwrap(); + assert_eq!(*rows, json!([[3], [4], [5], [6], [7], [8], [9]])); + + let res = TEST_DB + .run_script( + r#" + ?[a] := a in [9, 9, 8, 9, 8, 7, 7, 6, 5, 9, 4, 4, 3] + :limit 2 + "#, + &Default::default(), + ) + .unwrap(); + let rows = res.get("rows").unwrap(); + assert_eq!(*rows, json!([[8], [9]])); + + let res = TEST_DB + .run_script( + r#" + ?[a] := a in [9, 9, 8, 9, 8, 7, 7, 6, 5, 9, 4, 4, 3] + :limit 2 + :offset 1 + "#, + &Default::default(), + ) + .unwrap(); + let rows = res.get("rows").unwrap(); + assert_eq!(*rows, json!([[7], [8]])); + + let res = TEST_DB + .run_script( + r#" + ?[a] := a in [9, 9, 8, 9, 8, 7, 7, 6, 5, 9, 4, 4, 3] + :limit 100 + :offset 1 + "#, + &Default::default(), + ) + .unwrap(); + let rows = res.get("rows").unwrap(); + assert_eq!(*rows, json!([[3], [4], [5], [6], [7], [8]])); +}