avoid copying programs

main
Ziyang Hu 2 years ago
parent 6beec31ada
commit 847ac55cfc

@ -184,7 +184,7 @@ impl TempSymbGen {
}
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) enum InputInlineRulesOrAlgo {
Rules { rules: Vec<InputInlineRule> },
Algo { algo: AlgoApply },
@ -209,20 +209,6 @@ pub(crate) struct AlgoApply {
pub(crate) algo_impl: Box<dyn AlgoImpl>,
}
impl Clone for AlgoApply {
fn clone(&self) -> Self {
Self {
algo: self.algo.clone(),
rule_args: self.rule_args.clone(),
options: self.options.clone(),
head: self.head.clone(),
arity: self.arity,
span: self.span,
algo_impl: self.algo.get_impl().unwrap(),
}
}
}
impl AlgoApply {
pub(crate) fn arity(&self) -> Result<usize> {
self.algo_impl.arity(&self.options, &self.head, self.span)
@ -239,7 +225,6 @@ impl Debug for AlgoApply {
}
}
#[derive(Clone)]
pub(crate) struct MagicAlgoApply {
pub(crate) algo: AlgoHandle,
pub(crate) rule_args: Vec<MagicAlgoRuleArg>,
@ -325,7 +310,6 @@ impl Debug for MagicAlgoApply {
}
}
#[derive(Clone)]
pub(crate) enum AlgoRuleArg {
InMem {
name: Symbol,
@ -374,7 +358,7 @@ impl Display for AlgoRuleArg {
}
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) enum MagicAlgoRuleArg {
InMem {
name: MagicSymbol,
@ -415,7 +399,7 @@ impl MagicAlgoRuleArg {
}
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) struct InputProgram {
pub(crate) prog: BTreeMap<Symbol, InputInlineRulesOrAlgo>,
pub(crate) out_opts: QueryOutOptions,
@ -568,9 +552,12 @@ impl InputProgram {
Err(NoEntryError.into())
}
pub(crate) fn to_normalized_program(&self, tx: &SessionTx<'_>) -> Result<NormalFormProgram> {
pub(crate) fn into_normalized_program(
self,
tx: &SessionTx<'_>,
) -> Result<(NormalFormProgram, QueryOutOptions)> {
let mut prog: BTreeMap<Symbol, _> = Default::default();
for (k, rules_or_algo) in &self.prog {
for (k, rules_or_algo) in self.prog {
match rules_or_algo {
InputInlineRulesOrAlgo::Rules { rules } => {
let mut collected_rules = vec![];
@ -581,7 +568,7 @@ impl InputProgram {
Symbol::new(&format!("***{}", counter) as &str, span)
};
let normalized_body = InputAtom::Conjunction {
inner: rule.body.clone(),
inner: rule.body,
span: rule.span,
}
.disjunctive_normal_form(tx)?;
@ -631,23 +618,18 @@ impl InputProgram {
);
}
InputInlineRulesOrAlgo::Algo { algo: algo_apply } => {
prog.insert(
k.clone(),
NormalFormAlgoOrRules::Algo {
algo: algo_apply.clone(),
},
);
prog.insert(k.clone(), NormalFormAlgoOrRules::Algo { algo: algo_apply });
}
}
}
Ok(NormalFormProgram { prog })
Ok((NormalFormProgram { prog }, self.out_opts))
}
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) struct StratifiedNormalFormProgram(pub(crate) Vec<NormalFormProgram>);
#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) enum NormalFormAlgoOrRules {
Rules { rules: Vec<NormalFormInlineRule> },
Algo { algo: AlgoApply },
@ -662,15 +644,15 @@ impl NormalFormAlgoOrRules {
}
}
#[derive(Debug, Clone, Default)]
#[derive(Debug, Default)]
pub(crate) struct NormalFormProgram {
pub(crate) prog: BTreeMap<Symbol, NormalFormAlgoOrRules>,
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) struct StratifiedMagicProgram(pub(crate) Vec<MagicProgram>);
#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) enum MagicRulesOrAlgo {
Rules { rules: Vec<MagicInlineRule> },
Algo { algo: MagicAlgoApply },
@ -697,7 +679,7 @@ impl MagicRulesOrAlgo {
}
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) struct MagicProgram {
pub(crate) prog: BTreeMap<MagicSymbol, MagicRulesOrAlgo>,
}
@ -815,7 +797,7 @@ impl MagicSymbol {
}
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) struct InputInlineRule {
pub(crate) head: Vec<Symbol>,
pub(crate) aggr: Vec<Option<(Aggregation, Vec<DataValue>)>>,
@ -823,14 +805,14 @@ pub(crate) struct InputInlineRule {
pub(crate) span: SourceSpan,
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) struct NormalFormInlineRule {
pub(crate) head: Vec<Symbol>,
pub(crate) aggr: Vec<Option<(Aggregation, Vec<DataValue>)>>,
pub(crate) body: Vec<NormalFormAtom>,
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) struct MagicInlineRule {
pub(crate) head: Vec<Symbol>,
pub(crate) aggr: Vec<Option<(Aggregation, Vec<DataValue>)>>,
@ -852,7 +834,6 @@ impl MagicInlineRule {
}
}
#[derive(Clone)]
pub(crate) enum InputAtom {
Rule {
inner: InputRuleApplyAtom,

@ -98,29 +98,24 @@ struct ArityMismatch(String, usize, usize, #[label] SourceSpan);
impl<'a> SessionTx<'a> {
pub(crate) fn stratified_magic_compile(
&mut self,
prog: &StratifiedMagicProgram,
prog: StratifiedMagicProgram,
) -> Result<Vec<CompiledProgram>> {
// let mut stores: BTreeMap<MagicSymbol, InMemRelation> = Default::default();
let mut store_arities: BTreeMap<&MagicSymbol, usize> = Default::default();
let mut store_arities: BTreeMap<MagicSymbol, usize> = Default::default();
for stratum in prog.0.iter() {
for (name, ruleset) in &stratum.prog {
// stores.insert(
// name.clone(),
// self.new_rule_store(ruleset.arity()?),
// );
store_arities.insert(name, ruleset.arity()?);
store_arities.insert(name.clone(), ruleset.arity()?);
}
}
let compiled: Vec<_> = prog
.0
.iter()
.into_iter()
.rev()
.map(|cur_prog| -> Result<CompiledProgram> {
cur_prog
.prog
.iter()
.into_iter()
.map(|(k, body)| -> Result<(MagicSymbol, CompiledRuleSet)> {
match body {
MagicRulesOrAlgo::Rules { rules: body } => {
@ -128,7 +123,7 @@ impl<'a> SessionTx<'a> {
for rule in body.iter() {
let header = &rule.head;
let mut relation =
self.compile_magic_rule_body(rule, k, &store_arities, header)?;
self.compile_magic_rule_body(rule, &k, &store_arities, header)?;
relation.fill_binding_indices().with_context(|| {
format!(
"error encountered when filling binding indices for {:#?}",
@ -141,11 +136,11 @@ impl<'a> SessionTx<'a> {
contained_rules: rule.contained_rules(),
})
}
Ok((k.clone(), CompiledRuleSet::Rules(collected)))
Ok((k, CompiledRuleSet::Rules(collected)))
}
MagicRulesOrAlgo::Algo { algo: algo_apply } => {
Ok((k.clone(), CompiledRuleSet::Algo(algo_apply.clone())))
Ok((k, CompiledRuleSet::Algo(algo_apply)))
}
}
})
@ -158,7 +153,7 @@ impl<'a> SessionTx<'a> {
&mut self,
rule: &MagicInlineRule,
rule_name: &MagicSymbol,
store_arities: &BTreeMap<&MagicSymbol, usize>,
store_arities: &BTreeMap<MagicSymbol, usize>,
ret_vars: &[Symbol],
) -> Result<RelAlgebra> {
let mut ret = RelAlgebra::unit(rule_name.symbol().span);

@ -215,7 +215,7 @@ fn make_scc_reduced_graph<'a>(
impl NormalFormProgram {
/// returns the stratified program and the store lifetimes of the intermediate relations
pub(crate) fn stratify(
pub(crate) fn into_stratified_program(
self,
) -> Result<(StratifiedNormalFormProgram, BTreeMap<MagicSymbol, usize>)> {
// prerequisite: the program is already in disjunctive normal form
@ -256,7 +256,8 @@ impl NormalFormProgram {
.flat_map(|(stratum, indices)| indices.into_iter().map(move |idx| (idx, stratum)))
.collect::<BTreeMap<_, _>>();
// 7. translate the stratification into datalog program
let mut ret: Vec<NormalFormProgram> = vec![Default::default(); n_strata];
let mut ret: Vec<NormalFormProgram> =
(0..n_strata).map(|_| Default::default()).collect_vec();
let mut store_lifetimes = BTreeMap::new();
for (fr, tos) in &stratified_graph {

@ -650,9 +650,10 @@ impl<'s, S: Storage<'s>> Db<S> {
match op {
SysOp::Explain(prog) => {
let mut tx = self.transact()?;
let (stratified_program, _) = prog.to_normalized_program(&tx)?.stratify()?;
let (normalized_program, _) = prog.into_normalized_program(&tx)?;
let (stratified_program, _) = normalized_program.into_stratified_program()?;
let program = stratified_program.magic_sets_rewrite(&tx)?;
let compiled = tx.stratified_magic_compile(&program)?;
let compiled = tx.stratified_magic_compile(program)?;
tx.commit_tx()?;
self.explain_compiled(&compiled)
}
@ -791,14 +792,15 @@ impl<'s, S: Storage<'s>> Db<S> {
};
// query compilation
let (stratified_program, store_lifetimes) =
input_program.to_normalized_program(tx)?.stratify()?;
let entry_head_or_default = input_program.get_entry_out_head_or_default()?;
let (normalized_program, out_opts) = input_program.into_normalized_program(tx)?;
let (stratified_program, store_lifetimes) = normalized_program.into_stratified_program()?;
let program = stratified_program.magic_sets_rewrite(tx)?;
let compiled = tx.stratified_magic_compile(&program)?;
let compiled = tx.stratified_magic_compile(program)?;
// poison is used to terminate queries early
let poison = Poison::default();
if let Some(secs) = input_program.out_opts.timeout {
if let Some(secs) = out_opts.timeout {
poison.set_timeout(secs)?;
}
// give the query an ID and store it so that it can be queried and cancelled
@ -828,14 +830,14 @@ impl<'s, S: Storage<'s>> Db<S> {
running_queries: self.running_queries.clone(),
};
let total_num_to_take = if input_program.out_opts.sorters.is_empty() {
input_program.out_opts.num_to_take()
let total_num_to_take = if out_opts.sorters.is_empty() {
out_opts.num_to_take()
} else {
None
};
let num_to_skip = if input_program.out_opts.sorters.is_empty() {
input_program.out_opts.offset
let num_to_skip = if out_opts.sorters.is_empty() {
out_opts.offset
} else {
None
};
@ -850,7 +852,7 @@ impl<'s, S: Storage<'s>> Db<S> {
)?;
// deal with assertions
if let Some(assertion) = &input_program.out_opts.assertion {
if let Some(assertion) = &out_opts.assertion {
match assertion {
QueryAssertion::AssertNone(span) => {
if let Some(tuple) = result_store.all_iter().next() {
@ -875,29 +877,28 @@ impl<'s, S: Storage<'s>> Db<S> {
}
}
if !input_program.out_opts.sorters.is_empty() {
if !out_opts.sorters.is_empty() {
// sort outputs if required
let entry_head = input_program.get_entry_out_head()?;
let sorted_result =
tx.sort_and_collect(result_store, &input_program.out_opts.sorters, &entry_head)?;
let sorted_iter = if let Some(offset) = input_program.out_opts.offset {
tx.sort_and_collect(result_store, &out_opts.sorters, &entry_head_or_default)?;
let sorted_iter = if let Some(offset) = out_opts.offset {
Left(sorted_result.into_iter().skip(offset))
} else {
Right(sorted_result.into_iter())
};
let sorted_iter = if let Some(limit) = input_program.out_opts.limit {
let sorted_iter = if let Some(limit) = out_opts.limit {
Left(sorted_iter.take(limit))
} else {
Right(sorted_iter)
};
if let Some((meta, relation_op)) = &input_program.out_opts.store_relation {
if let Some((meta, relation_op)) = &out_opts.store_relation {
let to_clear = tx
.execute_relation(
self,
sorted_iter,
*relation_op,
meta,
&input_program.get_entry_out_head_or_default()?,
&entry_head_or_default,
)
.wrap_err_with(|| format!("when executing against relation '{}'", meta.name))?;
clean_ups.extend(to_clear);
@ -915,25 +916,25 @@ impl<'s, S: Storage<'s>> Db<S> {
tuple.into_iter().map(JsonValue::from).collect()
})
.collect_vec();
let headers: Vec<String> = match input_program.get_entry_out_head() {
Ok(headers) => headers.into_iter().map(|v| v.name.to_string()).collect(),
Err(_) => match rows.get(0) {
None => vec![],
Some(row) => (0..row.len()).map(|i| format!("_{}", i)).collect_vec(),
Ok((
NamedRows {
headers: entry_head_or_default
.iter()
.map(|s| s.to_string())
.collect_vec(),
rows,
},
};
Ok((NamedRows { headers, rows }, clean_ups))
clean_ups,
))
}
} else {
let scan = if early_return {
Right(Left(
result_store.early_returned_iter().map(|t| t.into_tuple()),
))
} else if input_program.out_opts.limit.is_some()
|| input_program.out_opts.offset.is_some()
{
let limit = input_program.out_opts.limit.unwrap_or(usize::MAX);
let offset = input_program.out_opts.offset.unwrap_or(0);
} else if out_opts.limit.is_some() || out_opts.offset.is_some() {
let limit = out_opts.limit.unwrap_or(usize::MAX);
let offset = out_opts.offset.unwrap_or(0);
Right(Right(
result_store
.all_iter()
@ -945,14 +946,14 @@ impl<'s, S: Storage<'s>> Db<S> {
Left(result_store.all_iter().map(|t| t.into_tuple()))
};
if let Some((meta, relation_op)) = &input_program.out_opts.store_relation {
if let Some((meta, relation_op)) = &out_opts.store_relation {
let to_clear = tx
.execute_relation(
self,
scan,
*relation_op,
meta,
&input_program.get_entry_out_head_or_default()?,
&entry_head_or_default,
)
.wrap_err_with(|| format!("when executing against relation '{}'", meta.name))?;
clean_ups.extend(to_clear);
@ -970,15 +971,16 @@ impl<'s, S: Storage<'s>> Db<S> {
})
.collect_vec();
let headers: Vec<String> = match input_program.get_entry_out_head() {
Ok(headers) => headers.into_iter().map(|v| v.name.to_string()).collect(),
Err(_) => match rows.get(0) {
None => vec![],
Some(row) => (0..row.len()).map(|i| format!("_{}", i)).collect_vec(),
Ok((
NamedRows {
headers: entry_head_or_default
.iter()
.map(|s| s.to_string())
.collect_vec(),
rows,
},
};
Ok((NamedRows { headers, rows }, clean_ups))
clean_ups,
))
}
}
}
@ -1192,19 +1194,26 @@ mod tests {
let _ = env_logger::builder().is_test(true).try_init();
let db = new_cozo_mem().unwrap();
let res = db.run_script(r#"
let res = db
.run_script(
r#"
y[a] := a in [1,2,3]
x[sum(a)] := y[a]
x[sum(a)] := a in [4,5,6]
?[sum(a)] := x[a]
"#, Default::default()).unwrap().rows;
"#,
Default::default(),
)
.unwrap()
.rows;
assert_eq!(res[0][0], json!(21.))
}
#[test]
fn test_conditions() {
let _ = env_logger::builder().is_test(true).try_init();
let db = new_cozo_mem().unwrap();
db.run_script(r#"
db.run_script(
r#"
{
?[code] <- [['a'],['b'],['c']]
:create airport {code}
@ -1213,12 +1222,21 @@ mod tests {
?[fr, to, dist] <- [['a', 'b', 1.1], ['a', 'c', 0.5], ['b', 'c', 9.1]]
:create route {fr, to => dist}
}
"#, Default::default()).unwrap();
"#,
Default::default(),
)
.unwrap();
debug!("real test begins");
let res = db.run_script(r#"
let res = db
.run_script(
r#"
r[code, dist] := *airport{code}, *route{fr: code, dist};
?[dist] := r['a', dist], dist > 0.5, dist <= 1.1;
"#, Default::default()).unwrap().rows;
"#,
Default::default(),
)
.unwrap()
.rows;
assert_eq!(res[0][0], json!(1.1))
}
}

@ -6,6 +6,7 @@
* You can obtain one at https://mozilla.org/MPL/2.0/.
*/
#![warn(rust_2018_idioms, future_incompatible)]
#![allow(clippy::missing_safety_doc)]
use std::collections::BTreeMap;
use std::ffi::{c_char, CStr, CString};

@ -30,6 +30,7 @@ pub struct CozoDb {
#[wasm_bindgen]
impl CozoDb {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
utils::set_panic_hook();
let db = DbInstance::new("mem", "", "").unwrap();

Loading…
Cancel
Save