From 8c8840957abfd1dec8654ab51ff46d185feb6165 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Mon, 9 Jan 2023 18:38:27 +0800 Subject: [PATCH] implemented imperative script --- cozo-core/src/cozoscript.pest | 28 +-- cozo-core/src/data/expr.rs | 2 +- cozo-core/src/data/program.rs | 17 +- cozo-core/src/parse/imperative.rs | 224 ++++++++++++++++++++++ cozo-core/src/parse/mod.rs | 137 ++++++++++---- cozo-core/src/runtime/db.rs | 301 +++++++++++++++++++++++++++--- cozo-core/src/runtime/relation.rs | 50 ++++- 7 files changed, 676 insertions(+), 83 deletions(-) create mode 100644 cozo-core/src/parse/imperative.rs diff --git a/cozo-core/src/cozoscript.pest b/cozo-core/src/cozoscript.pest index bd408cf2..de0b954f 100644 --- a/cozo-core/src/cozoscript.pest +++ b/cozo-core/src/cozoscript.pest @@ -6,10 +6,10 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ -script = _{sys_script | multi_script | query_script} +script = _{sys_script | imperative_script | query_script} query_script = {SOI ~ (option | rule | const_rule | fixed_rule)+ ~ EOI} query_script_inner = {"{" ~ (option | rule | const_rule | fixed_rule)+ ~ "}"} -multi_script = {SOI ~ query_script_inner+ ~ EOI} +imperative_script = {SOI ~ imperative_stmt+ ~ EOI} sys_script = {SOI ~ "::" ~ (compact_op | list_relations_op | list_relation_op | remove_relations_op | trigger_relation_op | trigger_relation_show_op | rename_relations_op | running_op | kill_op | explain_op | access_level_op) ~ EOI} @@ -199,22 +199,22 @@ validity_type = {"Validity"} list_type = {"[" ~ col_type ~ (";" ~ expr)? ~ "]"} tuple_type = {"(" ~ (col_type ~ ",")* ~ col_type? ~ ")"} -imperative_script = { - break_stmt | continue_stmt | return_stmt | mark_stmt | goto_stmt | - query_script_inner | if_chain | while_block | do_while_block | temp_swap +imperative_stmt = { + break_stmt | continue_stmt | return_stmt | + query_script_inner | ignore_error_script | if_chain | while_block | do_while_block | temp_swap | remove_stmt } -if_chain = {"%if" ~ query_script_inner ~ "%then" ~ imperative_block ~ elif_block* ~ else_block? ~ "%end" } -elif_block = {"%elif" ~ query_script_inner ~ "%then" ~ imperative_block} -else_block = {"%else" ~ imperative_block} -imperative_block = {imperative_script+} +if_chain = {"%if" ~ (ident | query_script_inner) + ~ "%then" ~ imperative_block + ~ ("%else" ~ imperative_block)? ~ "%end" } +imperative_block = {imperative_stmt+} break_stmt = {"%break" ~ ident?} +ignore_error_script = {"%ignore_error" ~ query_script_inner} continue_stmt = {"%continue" ~ ident?} -mark_stmt = {"%mark" ~ ident} -goto_stmt = {"%goto" ~ ident} -return_stmt = {"%return" ~ query_script_inner?} -while_block = {"%while" ~ query_script_inner ~ "%loop" ~ imperative_block ~ "%end"} -do_while_block = {"%loop" ~ imperative_block ~ "%while" ~ query_script_inner ~ "%end"} +return_stmt = {"%return" ~ (ident | query_script_inner)?} +while_block = {("%mark" ~ ident)? ~ "%while" ~ (ident | query_script_inner) ~ "%loop" ~ imperative_block ~ "%end"} +do_while_block = {("%mark" ~ ident)? ~ "%loop" ~ imperative_block ~ "%while" ~ (query_script_inner | ident) ~ "%end"} temp_swap = {"%swap" ~ underscore_ident ~ underscore_ident} +remove_stmt = {"%remove" ~ underscore_ident} /* diff --git a/cozo-core/src/data/expr.rs b/cozo-core/src/data/expr.rs index ef167fb0..312d415e 100644 --- a/cozo-core/src/data/expr.rs +++ b/cozo-core/src/data/expr.rs @@ -220,7 +220,7 @@ impl Display for Expr { #[derive(Debug, Error, Diagnostic)] #[error("Found value {1:?} where a boolean value is expected")] #[diagnostic(code(eval::predicate_not_bool))] -struct PredicateTypeError(#[label] SourceSpan, DataValue); +pub(crate) struct PredicateTypeError(#[label] pub(crate) SourceSpan, pub(crate) DataValue); #[derive(Debug, Error, Diagnostic)] #[error("Cannot build entity ID from {0:?}")] diff --git a/cozo-core/src/data/program.rs b/cozo-core/src/data/program.rs index 9a163a8a..31b0bfc9 100644 --- a/cozo-core/src/data/program.rs +++ b/cozo-core/src/data/program.rs @@ -185,7 +185,7 @@ impl TempSymbGen { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) enum InputInlineRulesOrFixed { Rules { rules: Vec }, Fixed { fixed: FixedRuleApply }, @@ -215,6 +215,7 @@ impl InputInlineRulesOrFixed { // } } +#[derive(Clone)] pub(crate) struct FixedRuleApply { pub(crate) fixed_handle: FixedRuleHandle, pub(crate) rule_args: Vec, @@ -329,6 +330,7 @@ impl Debug for MagicFixedRuleApply { } } +#[derive(Clone)] pub(crate) enum FixedRuleArg { InMem { name: Symbol, @@ -422,7 +424,7 @@ impl MagicFixedRuleRuleArg { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) struct InputProgram { pub(crate) prog: BTreeMap, pub(crate) out_opts: QueryOutOptions, @@ -520,6 +522,14 @@ impl InputProgram { // self.prog.values().any(|rule| rule.used_rule(rule_name)) // } + pub(crate) fn needs_write_tx(&self) -> bool { + if let Some((h, _)) = &self.out_opts.store_relation { + !h.name.name.starts_with('_') + } else { + false + } + } + pub(crate) fn get_entry_arity(&self) -> Result { if let Some(entry) = self.prog.get(&Symbol::new(PROG_ENTRY, SourceSpan(0, 0))) { return match entry { @@ -824,7 +834,7 @@ impl MagicSymbol { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) struct InputInlineRule { pub(crate) head: Vec, pub(crate) aggr: Vec)>>, @@ -861,6 +871,7 @@ impl MagicInlineRule { } } +#[derive(Clone)] pub(crate) enum InputAtom { Rule { inner: InputRuleApplyAtom, diff --git a/cozo-core/src/parse/imperative.rs b/cozo-core/src/parse/imperative.rs new file mode 100644 index 00000000..6b3cf2ae --- /dev/null +++ b/cozo-core/src/parse/imperative.rs @@ -0,0 +1,224 @@ +/* + * Copyright 2023, The Cozo Project Authors. + * + * This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. + * If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. + * + */ + +use crate::parse::query::parse_query; +use crate::parse::{ExtractSpan, ImperativeProgram, ImperativeStmt, Pair, Pairs, Rule, SourceSpan}; +use crate::{DataValue, FixedRule, ValidityTs}; +use either::{Left, Right}; +use itertools::Itertools; +use miette::{Diagnostic, Result}; +use smartstring::SmartString; +use std::collections::BTreeMap; +use std::sync::Arc; +use thiserror::Error; + +pub(crate) fn parse_imperative( + src: Pairs<'_>, + param_pool: &BTreeMap, + fixed_rules: &BTreeMap>>, + cur_vld: ValidityTs, +) -> Result { + let mut collected = vec![]; + + for pair in src { + if pair.as_rule() != Rule::EOI { + collected.push(parse_imperative_stmt( + pair, + param_pool, + fixed_rules, + cur_vld, + )?); + } + } + + Ok(collected) +} + +#[derive(Debug, Error, Diagnostic)] +#[error("cannot manipulate permanent relation in imperative script")] +#[diagnostic(code(parser::manipulate_perm_rel_in_script))] +struct CannotManipulatePermRel(#[label] SourceSpan); + +#[derive(Debug, Error, Diagnostic)] +#[error("duplicate marker found")] +#[diagnostic(code(parser::dup_marker))] +struct DuplicateMarker(#[label] SourceSpan); + +fn parse_imperative_stmt( + pair: Pair<'_>, + param_pool: &BTreeMap, + fixed_rules: &BTreeMap>>, + cur_vld: ValidityTs, +) -> Result { + Ok(match pair.as_rule() { + Rule::break_stmt => { + let span = pair.extract_span(); + let target = pair + .into_inner() + .next() + .map(|p| SmartString::from(p.as_str())); + ImperativeStmt::Break { target, span } + } + Rule::continue_stmt => { + let span = pair.extract_span(); + let target = pair + .into_inner() + .next() + .map(|p| SmartString::from(p.as_str())); + ImperativeStmt::Continue { target, span } + } + Rule::return_stmt => { + // let span = pair.extract_span(); + match pair.into_inner().next() { + None => ImperativeStmt::ReturnNil, + Some(p) => match p.as_rule() { + Rule::ident => { + let rel = SmartString::from(p.as_str()); + ImperativeStmt::ReturnTemp { rel } + } + Rule::query_script_inner => { + let prog = parse_query(p.into_inner(), param_pool, fixed_rules, cur_vld)?; + ImperativeStmt::ReturnProgram { prog } + } + _ => unreachable!(), + }, + } + } + Rule::if_chain => { + let span = pair.extract_span(); + let mut inner = pair.into_inner(); + let condition = inner.next().unwrap(); + let cond = match condition.as_rule() { + Rule::ident => Left(SmartString::from(condition.as_str())), + Rule::query_script_inner => Right(parse_query( + condition.into_inner(), + param_pool, + fixed_rules, + cur_vld, + )?), + _ => unreachable!(), + }; + let body = inner + .next() + .unwrap() + .into_inner() + .map(|p| parse_imperative_stmt(p, param_pool, fixed_rules, cur_vld)) + .try_collect()?; + let else_body = match inner.next() { + None => vec![], + Some(rest) => rest + .into_inner() + .map(|p| parse_imperative_stmt(p, param_pool, fixed_rules, cur_vld)) + .try_collect()?, + }; + ImperativeStmt::If { + condition: cond, + then_branch: body, + else_branch: else_body, + span, + } + } + Rule::while_block => { + let span = pair.extract_span(); + let mut inner = pair.into_inner(); + let mut mark = None; + let mut nxt = inner.next().unwrap(); + if nxt.as_rule() == Rule::ident { + mark = Some(SmartString::from(nxt.as_str())); + nxt = inner.next().unwrap(); + } + let cond = match nxt.as_rule() { + Rule::ident => Left(SmartString::from(nxt.as_str())), + Rule::query_script_inner => Right(parse_query( + nxt.into_inner(), + param_pool, + fixed_rules, + cur_vld, + )?), + _ => unreachable!(), + }; + let body = parse_imperative( + inner.next().unwrap().into_inner(), + param_pool, + fixed_rules, + cur_vld, + )?; + ImperativeStmt::While { + label: mark, + condition: cond, + body, + span, + } + } + Rule::do_while_block => { + let span = pair.extract_span(); + let mut inner = pair.into_inner(); + let mut mark = None; + let mut nxt = inner.next().unwrap(); + if nxt.as_rule() == Rule::ident { + mark = Some(SmartString::from(nxt.as_str())); + nxt = inner.next().unwrap(); + } + let body = parse_imperative( + inner.next().unwrap().into_inner(), + param_pool, + fixed_rules, + cur_vld, + )?; + let cond = match nxt.as_rule() { + Rule::ident => Left(SmartString::from(nxt.as_str())), + Rule::query_script_inner => Right(parse_query( + nxt.into_inner(), + param_pool, + fixed_rules, + cur_vld, + )?), + _ => unreachable!(), + }; + ImperativeStmt::DoWhile { + label: mark, + body, + condition: cond, + span, + } + } + Rule::temp_swap => { + // let span = pair.extract_span(); + let mut pairs = pair.into_inner(); + let left = pairs.next().unwrap(); + let left_name = left.as_str(); + let right = pairs.next().unwrap(); + let right_name = right.as_str(); + + ImperativeStmt::TempSwap { + left: SmartString::from(left_name), + right: SmartString::from(right_name), + } + } + Rule::remove_stmt => { + // let span = pair.extract_span(); + let name_p = pair.into_inner().next().unwrap(); + let name = name_p.as_str(); + + ImperativeStmt::TempRemove { + temp: SmartString::from(name), + } + } + Rule::query_script_inner => { + let prog = parse_query(pair.into_inner(), param_pool, fixed_rules, cur_vld)?; + ImperativeStmt::Program { prog } + } + Rule::ignore_error_script => { + let pair = pair.into_inner().next().unwrap(); + let prog = parse_query(pair.into_inner(), param_pool, fixed_rules, cur_vld)?; + ImperativeStmt::IgnoreErrorProgram { prog } + } + _ => unreachable!(), + }) +} diff --git a/cozo-core/src/parse/mod.rs b/cozo-core/src/parse/mod.rs index 4eb72ec2..f14c5194 100644 --- a/cozo-core/src/parse/mod.rs +++ b/cozo-core/src/parse/mod.rs @@ -6,11 +6,12 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ +use either::Either; use std::cmp::{max, min}; use std::collections::BTreeMap; use std::sync::Arc; -use miette::{bail, ensure, Diagnostic, IntoDiagnostic, Result}; +use miette::{bail, Diagnostic, IntoDiagnostic, Result}; use pest::error::InputLocation; use pest::Parser; use smartstring::{LazyCompact, SmartString}; @@ -19,12 +20,14 @@ use thiserror::Error; use crate::data::program::InputProgram; use crate::data::relation::NullableColType; use crate::data::value::{DataValue, ValidityTs}; +use crate::parse::imperative::parse_imperative; use crate::parse::query::parse_query; use crate::parse::schema::parse_nullable_type; use crate::parse::sys::{parse_sys, SysOp}; use crate::FixedRule; pub(crate) mod expr; +pub(crate) mod imperative; pub(crate) mod query; pub(crate) mod schema; pub(crate) mod sys; @@ -37,31 +40,105 @@ pub(crate) type Pair<'a> = pest::iterators::Pair<'a, Rule>; pub(crate) type Pairs<'a> = pest::iterators::Pairs<'a, Rule>; pub(crate) enum CozoScript { - Multi(Vec), + Single(InputProgram), + Imperative(ImperativeProgram), Sys(SysOp), } -pub(crate) enum ImperativeElement { - JumpIfNot { - goto: usize, +pub(crate) enum ImperativeStmt { + Break { + target: Option>, + span: SourceSpan, }, - Goto { - goto: usize, + Continue { + target: Option>, + span: SourceSpan, }, - Prog { - prog: Box, - ignore_error: bool, + ReturnNil, + ReturnProgram { + prog: InputProgram, + // span: SourceSpan, }, - Swap { + ReturnTemp { + rel: SmartString + }, + Program { + prog: InputProgram + }, + IgnoreErrorProgram { + prog: InputProgram + }, + If { + condition: ImperativeCondition, + then_branch: ImperativeProgram, + else_branch: ImperativeProgram, + span: SourceSpan, + }, + While { + label: Option>, + condition: ImperativeCondition, + body: ImperativeProgram, + span: SourceSpan, + }, + DoWhile { + label: Option>, + body: ImperativeProgram, + condition: ImperativeCondition, + span: SourceSpan, + }, + TempSwap { left: SmartString, right: SmartString, + // span: SourceSpan, }, - Remove { - name: SmartString, + TempRemove { + temp: SmartString, + // span: SourceSpan, }, } -pub(crate) struct ImperativeProgram {} +pub(crate) type ImperativeCondition = Either, InputProgram>; + +pub(crate) type ImperativeProgram = Vec; + +impl ImperativeStmt { + pub(crate) fn needs_write_tx(&self) -> bool { + match self { + ImperativeStmt::ReturnProgram { prog, .. } + | ImperativeStmt::Program { prog, .. } + | ImperativeStmt::IgnoreErrorProgram { prog, .. } => prog.needs_write_tx(), + ImperativeStmt::If { + condition, + then_branch, + else_branch, + .. + } => { + (match condition { + ImperativeCondition::Left(_) => false, + ImperativeCondition::Right(prog) => prog.needs_write_tx(), + }) || then_branch.iter().any(|p| p.needs_write_tx()) + || else_branch.iter().any(|p| p.needs_write_tx()) + } + ImperativeStmt::While { + condition, body, .. + } + | ImperativeStmt::DoWhile { + body, condition, .. + } => { + (match condition { + ImperativeCondition::Left(_) => false, + ImperativeCondition::Right(prog) => prog.needs_write_tx(), + }) || body.iter().any(|p| p.needs_write_tx()) + } + ImperativeStmt::ReturnTemp { .. } + | ImperativeStmt::Break { .. } + | ImperativeStmt::Continue { .. } + | ImperativeStmt::ReturnNil { .. } + | ImperativeStmt::TempSwap { .. } + | ImperativeStmt::TempRemove { .. } => false, + } + } +} impl CozoScript { pub(crate) fn get_single_program(self) -> Result { @@ -70,11 +147,8 @@ impl CozoScript { #[diagnostic(code(parser::expect_singleton))] struct ExpectSingleProgram; match self { - CozoScript::Multi(v) => { - ensure!(v.len() == 1, ExpectSingleProgram); - Ok(v.into_iter().next().unwrap()) - } - CozoScript::Sys(_) => { + CozoScript::Single(s) => Ok(s), + CozoScript::Imperative(_) | CozoScript::Sys(_) => { bail!(ExpectSingleProgram) } } @@ -129,7 +203,7 @@ pub(crate) fn parse_type(src: &str) -> Result { pub(crate) fn parse_script( src: &str, param_pool: &BTreeMap, - algorithms: &BTreeMap>>, + fixed_rules: &BTreeMap>>, cur_vld: ValidityTs, ) -> Result { let parsed = CozoScriptParser::parse(Rule::script, src) @@ -144,27 +218,18 @@ pub(crate) fn parse_script( .unwrap(); Ok(match parsed.as_rule() { Rule::query_script => { - let q = parse_query(parsed.into_inner(), param_pool, algorithms, cur_vld)?; - CozoScript::Multi(vec![q]) + let q = parse_query(parsed.into_inner(), param_pool, fixed_rules, cur_vld)?; + CozoScript::Single(q) } - Rule::multi_script => { - let mut qs = vec![]; - for pair in parsed.into_inner() { - if pair.as_rule() != Rule::EOI { - qs.push(parse_query( - pair.into_inner(), - param_pool, - algorithms, - cur_vld, - )?); - } - } - CozoScript::Multi(qs) + Rule::imperative_script => { + let p = parse_imperative(parsed.into_inner(), param_pool, fixed_rules, cur_vld)?; + CozoScript::Imperative(p) } + Rule::sys_script => CozoScript::Sys(parse_sys( parsed.into_inner(), param_pool, - algorithms, + fixed_rules, cur_vld, )?), _ => unreachable!(), diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index 0356ecd9..5150a2ba 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -17,22 +17,27 @@ use std::thread; #[allow(unused_imports)] use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use either::{Left, Right}; +use either::{Either, Left, Right}; use itertools::Itertools; #[allow(unused_imports)] use miette::{bail, ensure, miette, Diagnostic, IntoDiagnostic, Result, WrapErr}; use serde_json::json; +use smartstring::{LazyCompact, SmartString}; use thiserror::Error; -use crate::data::functions::current_validity; +use crate::data::expr::PredicateTypeError; +use crate::data::functions::{current_validity, op_to_bool}; use crate::data::json::JsonValue; use crate::data::program::{InputProgram, QueryAssertion, RelationOp}; use crate::data::relation::ColumnDef; +use crate::data::symb::Symbol; use crate::data::tuple::{Tuple, TupleT}; use crate::data::value::{DataValue, ValidityTs, LARGEST_UTF_CHAR}; use crate::fixed_rule::DEFAULT_FIXED_RULES; use crate::parse::sys::SysOp; -use crate::parse::{parse_script, CozoScript, SourceSpan}; +use crate::parse::{ + parse_script, CozoScript, ImperativeCondition, ImperativeProgram, ImperativeStmt, SourceSpan, +}; use crate::query::compile::{CompiledProgram, CompiledRule, CompiledRuleSet}; use crate::query::ra::{ FilteredRA, InnerJoin, NegJoin, RelAlgebra, ReorderRA, StoredRA, StoredWithValidityRA, @@ -101,7 +106,7 @@ impl Debug for Db { #[diagnostic(code(db::init))] pub(crate) struct BadDbInit(#[help] pub(crate) String); -#[derive(serde_derive::Serialize, serde_derive::Deserialize, Debug, Clone)] +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Debug, Clone, Default)] /// Rows in a relation, together with headers for the fields. pub struct NamedRows { /// The headers @@ -132,6 +137,12 @@ impl NamedRows { const STATUS_STR: &str = "status"; const OK_STR: &str = "OK"; +enum ControlCode { + Termination(NamedRows), + Break(Option>, SourceSpan), + Continue(Option>, SourceSpan), +} + impl<'s, S: Storage<'s>> Db { /// Create a new database object with the given storage. /// You must call [`initialize`](Self::initialize) immediately after creation. @@ -499,6 +510,51 @@ impl<'s, S: Storage<'s>> Db { }; Ok(ret) } + fn execute_imperative_condition( + &'s self, + p: &ImperativeCondition, + tx: &mut SessionTx<'_>, + cleanups: &mut Vec<(Vec, Vec)>, + cur_vld: ValidityTs, + span: SourceSpan, + ) -> Result { + let res = match p { + Left(rel) => { + let relation = tx.get_relation(rel, false)?; + relation.as_named_rows(tx)? + } + Right(p) => self.execute_single_program(p.clone(), tx, cleanups, cur_vld)?, + }; + Ok(match res.rows.first() { + None => false, + Some(row) => { + if row.is_empty() { + false + } else { + op_to_bool(&row[row.len() - 2..])? + .get_bool() + .ok_or_else(|| PredicateTypeError(span, row.last().cloned().unwrap()))? + } + } + }) + } + fn execute_single_program( + &'s self, + p: InputProgram, + tx: &mut SessionTx<'_>, + cleanups: &mut Vec<(Vec, Vec)>, + cur_vld: ValidityTs, + ) -> Result { + #[allow(unused_variables)] + let sleep_opt = p.out_opts.sleep; + let (q_res, q_cleanups) = self.run_query(tx, p, cur_vld)?; + cleanups.extend(q_cleanups); + #[cfg(not(target_arch = "wasm32"))] + if let Some(secs) = sleep_opt { + thread::sleep(Duration::from_micros((secs * 1000000.) as u64)); + } + Ok(q_res) + } fn do_run_script( &'s self, payload: &str, @@ -506,19 +562,10 @@ impl<'s, S: Storage<'s>> Db { cur_vld: ValidityTs, ) -> Result { match parse_script(payload, param_pool, &self.algorithms, cur_vld)? { - CozoScript::Multi(ps) => { - let is_write = ps.iter().any(|p| { - if let Some((h, _)) = &p.out_opts.store_relation { - !h.name.name.starts_with('_') - } else { - false - } - }); + CozoScript::Single(p) => { + let is_write = p.needs_write_tx(); let mut cleanups = vec![]; - let mut res = NamedRows { - headers: vec![], - rows: vec![], - }; + let res; { let mut tx = if is_write { self.transact_write()? @@ -526,17 +573,8 @@ impl<'s, S: Storage<'s>> Db { self.transact()? }; - for p in ps.into_iter() { - #[allow(unused_variables)] - let sleep_opt = p.out_opts.sleep; - let (q_res, q_cleanups) = self.run_query(&mut tx, p, cur_vld)?; - res = q_res; - cleanups.extend(q_cleanups); - #[cfg(not(target_arch = "wasm32"))] - if let Some(secs) = sleep_opt { - thread::sleep(Duration::from_micros((secs * 1000000.) as u64)); - } - } + res = self.execute_single_program(p, &mut tx, &mut cleanups, cur_vld)?; + if is_write { tx.commit_tx()?; } else { @@ -550,9 +588,212 @@ impl<'s, S: Storage<'s>> Db { } Ok(res) } + CozoScript::Imperative(ps) => { + let is_write = ps.iter().any(|p| p.needs_write_tx()); + let mut cleanups: Vec<(Vec, Vec)> = vec![]; + let ret; + { + let mut tx = if is_write { + self.transact_write()? + } else { + self.transact()? + }; + match self.execute_imperative_stmts(&ps, &mut tx, &mut cleanups, cur_vld)? { + Left(res) => ret = res, + Right(ctrl) => match ctrl { + ControlCode::Termination(res) => { + ret = res; + } + ControlCode::Break(_, span) | ControlCode::Continue(_, span) => { + #[derive(Debug, Error, Diagnostic)] + #[error("control flow has nowhere to go")] + #[diagnostic(code(eval::dangling_ctrl_flow))] + struct DanglingControlFlow(#[label] SourceSpan); + + bail!(DanglingControlFlow(span)) + } + }, + } + + if is_write { + tx.commit_tx()?; + } else { + tx.commit_tx()?; + assert!(cleanups.is_empty(), "non-empty cleanups on read-only tx"); + } + } + for (lower, upper) in cleanups { + self.db.del_range(&lower, &upper)?; + } + Ok(ret) + } CozoScript::Sys(op) => self.run_sys_op(op), } } + fn execute_imperative_stmts( + &'s self, + ps: &ImperativeProgram, + tx: &mut SessionTx<'_>, + cleanups: &mut Vec<(Vec, Vec)>, + cur_vld: ValidityTs, + ) -> Result> { + let mut ret = NamedRows::default(); + for p in ps { + match p { + ImperativeStmt::Break { target, span, .. } => { + return Ok(Right(ControlCode::Break(target.clone(), *span))); + } + ImperativeStmt::Continue { target, span, .. } => { + return Ok(Right(ControlCode::Continue(target.clone(), *span))); + } + ImperativeStmt::ReturnNil { .. } => { + return Ok(Right(ControlCode::Termination(NamedRows::default()))) + } + ImperativeStmt::ReturnProgram { prog, .. } => { + ret = self.execute_single_program(prog.clone(), tx, cleanups, cur_vld)?; + return Ok(Right(ControlCode::Termination(ret))); + } + ImperativeStmt::ReturnTemp { rel, .. } => { + let relation = tx.get_relation(rel, false)?; + return Ok(Right(ControlCode::Termination(relation.as_named_rows(tx)?))); + } + ImperativeStmt::Program { prog, .. } => { + ret = self.execute_single_program(prog.clone(), tx, cleanups, cur_vld)?; + } + ImperativeStmt::IgnoreErrorProgram { prog, .. } => { + match self.execute_single_program(prog.clone(), tx, cleanups, cur_vld) { + Ok(res) => ret = res, + Err(_) => { + ret = NamedRows { + headers: vec!["status".to_string()], + rows: vec![vec![DataValue::from("FAILED")]], + } + } + } + } + ImperativeStmt::If { + condition, + then_branch, + else_branch, + span, + } => { + let cond_val = + self.execute_imperative_condition(condition, tx, cleanups, cur_vld, *span)?; + let to_execute = if cond_val { then_branch } else { else_branch }; + match self.execute_imperative_stmts(to_execute, tx, cleanups, cur_vld)? { + Left(rows) => { + ret = rows; + } + Right(ctrl) => return Ok(Right(ctrl)), + } + } + ImperativeStmt::While { + label, + condition, + body, + span, + } => { + ret = Default::default(); + loop { + let cond_val = self.execute_imperative_condition( + condition, tx, cleanups, cur_vld, *span, + )?; + if cond_val { + match self.execute_imperative_stmts(body, tx, cleanups, cur_vld)? { + Left(_) => {} + Right(ctrl) => match ctrl { + ControlCode::Termination(ret) => { + return Ok(Right(ControlCode::Termination(ret))) + } + ControlCode::Break(break_label, span) => { + if break_label.is_none() || break_label == *label { + break; + } else { + return Ok(Right(ControlCode::Break( + break_label, + span, + ))); + } + } + ControlCode::Continue(cont_label, span) => { + if cont_label.is_none() || cont_label == *label { + continue; + } else { + return Ok(Right(ControlCode::Continue( + cont_label, span, + ))); + } + } + }, + } + } else { + ret = NamedRows::default(); + break; + } + } + } + ImperativeStmt::DoWhile { + label, + body, + condition, + span, + } => { + ret = Default::default(); + loop { + match self.execute_imperative_stmts(body, tx, cleanups, cur_vld)? { + Left(_) => {} + Right(ctrl) => match ctrl { + ControlCode::Termination(ret) => { + return Ok(Right(ControlCode::Termination(ret))) + } + ControlCode::Break(break_label, span) => { + if break_label.is_none() || break_label == *label { + break; + } else { + return Ok(Right(ControlCode::Break(break_label, span))); + } + } + ControlCode::Continue(cont_label, span) => { + if cont_label.is_none() || cont_label == *label { + continue; + } else { + return Ok(Right(ControlCode::Continue(cont_label, span))); + } + } + }, + } + } + let cond_val = + self.execute_imperative_condition(condition, tx, cleanups, cur_vld, *span)?; + if !cond_val { + ret = NamedRows::default(); + break; + } + } + ImperativeStmt::TempSwap { left, right, .. } => { + tx.rename_temp_relation( + Symbol::new(left.clone(), Default::default()), + Symbol::new(SmartString::from("*temp*"), Default::default()), + )?; + tx.rename_temp_relation( + Symbol::new(right.clone(), Default::default()), + Symbol::new(left.clone(), Default::default()), + )?; + tx.rename_temp_relation( + Symbol::new(SmartString::from("*temp*"), Default::default()), + Symbol::new(right.clone(), Default::default()), + )?; + ret = NamedRows::default(); + break; + } + ImperativeStmt::TempRemove { temp, .. } => { + tx.destroy_temp_relation(temp)?; + ret = NamedRows::default(); + } + } + } + return Ok(Left(ret)); + } fn explain_compiled(&self, strata: &[CompiledProgram]) -> Result { let mut ret: Vec = vec![]; const STRATUM: &str = "stratum"; @@ -676,7 +917,11 @@ impl<'s, S: Storage<'s>> Db { rel_stack.push(relation); ("reorder", json!(null), json!(null), json!(null)) } - RelAlgebra::Filter(FilteredRA { parent, filters: pred, .. }) => { + RelAlgebra::Filter(FilteredRA { + parent, + filters: pred, + .. + }) => { rel_stack.push(parent); ( "filter", diff --git a/cozo-core/src/runtime/relation.rs b/cozo-core/src/runtime/relation.rs index 71284266..db10e785 100644 --- a/cozo-core/src/runtime/relation.rs +++ b/cozo-core/src/runtime/relation.rs @@ -9,6 +9,7 @@ use std::fmt::{Debug, Display, Formatter}; use std::sync::atomic::Ordering; +use itertools::Itertools; use log::error; use miette::{bail, ensure, Diagnostic, Result}; use rmp_serde::Serializer; @@ -23,7 +24,7 @@ use crate::data::tuple::{decode_tuple_from_key, Tuple, TupleT, ENCODED_KEY_MIN_L use crate::data::value::{DataValue, ValidityTs}; use crate::parse::SourceSpan; use crate::runtime::transact::SessionTx; -use crate::StoreTx; +use crate::{NamedRows, StoreTx}; #[derive( Copy, @@ -126,6 +127,22 @@ impl RelationHandle { ret.extend(prefix_bytes); ret } + pub(crate) fn as_named_rows(&self, tx: &SessionTx<'_>) -> Result { + let rows: Vec<_> = self.scan_all(tx).try_collect()?; + let mut headers = self + .metadata + .keys + .iter() + .map(|col| col.name.to_string()) + .collect_vec(); + headers.extend( + self.metadata + .non_keys + .iter() + .map(|col| col.name.to_string()), + ); + Ok(NamedRows { headers, rows }) + } #[allow(dead_code)] pub(crate) fn amend_key_prefix(&self, data: &mut [u8]) { let prefix_bytes = self.id.0.to_be_bytes(); @@ -521,6 +538,16 @@ impl<'a> SessionTx<'a> { let upper_bound = Tuple::default().encode_as_key(store.id.next()); Ok((lower_bound, upper_bound)) } + pub(crate) fn destroy_temp_relation(&mut self, name: &str) -> Result<()> { + self.get_relation(name, true)?; + let key = DataValue::from(name); + let encoded = vec![key].encode_as_key(RelationId::SYSTEM); + self.temp_store_tx.del(&encoded)?; + // TODO: at the moment this doesn't actually remove anything from the store + // let lower_bound = Tuple::default().encode_as_key(store.id); + // let upper_bound = Tuple::default().encode_as_key(store.id.next()); + Ok(()) + } pub(crate) fn set_access_level(&mut self, rel: Symbol, level: AccessLevel) -> Result<()> { let mut meta = self.get_relation(&rel, true)?; meta.access_level = level; @@ -563,6 +590,27 @@ impl<'a> SessionTx<'a> { self.store_tx.del(&old_encoded)?; self.store_tx.put(&new_encoded, &meta_val)?; + Ok(()) + } + pub(crate) fn rename_temp_relation(&mut self, old: Symbol, new: Symbol) -> Result<()> { + let new_key = DataValue::Str(new.name.clone()); + let new_encoded = vec![new_key].encode_as_key(RelationId::SYSTEM); + + if self.temp_store_tx.exists(&new_encoded, true)? { + bail!(RelNameConflictError(new.name.to_string())) + }; + + let old_key = DataValue::Str(old.name.clone()); + let old_encoded = vec![old_key].encode_as_key(RelationId::SYSTEM); + + let mut rel = self.get_relation(&old, true)?; + rel.name = new.name; + + let mut meta_val = vec![]; + rel.serialize(&mut Serializer::new(&mut meta_val)).unwrap(); + self.temp_store_tx.del(&old_encoded)?; + self.temp_store_tx.put(&new_encoded, &meta_val)?; + Ok(()) } }