diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index 3c0b0950..fd932c9a 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -51,7 +51,7 @@ use crate::query::ra::{ use crate::runtime::callback::{ CallbackCollector, CallbackDeclaration, CallbackOp, EventCallbackRegistry, }; -use crate::runtime::relation::{extend_tuple_from_v, AccessLevel, InsufficientAccessLevel, RelationHandle, RelationId, InputRelationHandle}; +use crate::runtime::relation::{extend_tuple_from_v, AccessLevel, InsufficientAccessLevel, RelationHandle, RelationId}; use crate::runtime::transact::SessionTx; use crate::storage::temp::TempStorage; use crate::storage::{Storage, StoreTx}; @@ -1493,7 +1493,7 @@ impl<'s, S: Storage<'s>> Db { ) .wrap_err_with(|| format!("when executing against relation '{}'", meta.name))?; clean_ups.extend(to_clear); - let returned_rows = Self::get_returning_rows(callback_collector, meta, returning); + let returned_rows = tx.get_returning_rows(callback_collector, &meta.name, returning)?; Ok((returned_rows, clean_ups)) } else { // not sorting outputs @@ -1548,7 +1548,7 @@ impl<'s, S: Storage<'s>> Db { ) .wrap_err_with(|| format!("when executing against relation '{}'", meta.name))?; clean_ups.extend(to_clear); - let returned_rows = Self::get_returning_rows(callback_collector, meta, returning); + let returned_rows = tx.get_returning_rows(callback_collector, &meta.name, returning)?; Ok((returned_rows, clean_ups)) } else { @@ -1567,59 +1567,6 @@ impl<'s, S: Storage<'s>> Db { } } } - - fn get_returning_rows(callback_collector: &mut CallbackCollector, meta: &InputRelationHandle, returning: &ReturnMutation) -> NamedRows { - let returned_rows = { - match returning { - ReturnMutation::NotReturning => { - NamedRows::new( - vec![STATUS_STR.to_string()], - vec![vec![DataValue::from(OK_STR)]], - ) - } - ReturnMutation::Returning => { - let target_len = meta.metadata.keys.len() + meta.metadata.non_keys.len(); - let mut returned_rows = Vec::new(); - if let Some(collected) = callback_collector.get(&meta.name.name) { - for (kind, insertions, deletions) in collected { - let (pos_key, neg_key) = match kind { - CallbackOp::Put => { ("inserted", "replaced") } - CallbackOp::Rm => { ("requested", "deleted") } - }; - for row in &insertions.rows { - let mut v = Vec::with_capacity(target_len + 1); - v.push(DataValue::from(pos_key)); - v.extend_from_slice(&row[..target_len]); - while v.len() <= target_len { - v.push(DataValue::Null); - } - returned_rows.push(v); - } - for row in &deletions.rows { - let mut v = Vec::with_capacity(target_len + 1); - v.push(DataValue::from(neg_key)); - v.extend_from_slice(&row[..target_len]); - while v.len() <= target_len { - v.push(DataValue::Null); - } - returned_rows.push(v); - } - } - } - let mut header = vec!["_kind".to_string()]; - header.extend(meta.metadata.keys - .iter() - .chain(meta.metadata.non_keys.iter()) - .map(|s| s.name.to_string())); - NamedRows::new( - header, - returned_rows, - ) - } - } - }; - returned_rows - } pub(crate) fn list_running(&self) -> Result { let rows = self .running_queries diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index f037c3f7..991954c0 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -1256,7 +1256,23 @@ fn returning() { // for row in res.into_json()["rows"].as_array().unwrap() { // println!("{}", row); // } - assert_eq!(res.into_json()["rows"], json!([["requested", 1], ["requested", 4], ["deleted", 1]])); + assert_eq!(res.into_json()["rows"], json!([["requested", 1, null], ["requested", 4, null], ["deleted", 1, 3]])); + db.run_script(r":create todo{id:Uuid default rand_uuid_v1() => label: String, done: Bool}", Default::default()) + .unwrap(); + let res = db + .run_script( + r"?[label,done] <- [['milk',false]] :put todo{label,done} :returning", + Default::default(), + ) + .unwrap(); + assert_eq!(res.rows[0].len(), 4); + for title in res.headers.iter() { + print!("{} ", title); + } + println!(); + for row in res.into_json()["rows"].as_array().unwrap() { + println!("{}", row); + } } #[test] diff --git a/cozo-core/src/runtime/transact.rs b/cozo-core/src/runtime/transact.rs index ad7f052a..38f7f9c0 100644 --- a/cozo-core/src/runtime/transact.rs +++ b/cozo-core/src/runtime/transact.rs @@ -10,10 +10,13 @@ use std::sync::atomic::{AtomicU32, AtomicU64}; use std::sync::Arc; use miette::{bail, Result}; +use crate::data::program::ReturnMutation; use crate::data::tuple::TupleT; use crate::data::value::DataValue; use crate::fts::TokenizerCache; +use crate::{CallbackOp, NamedRows}; +use crate::runtime::callback::CallbackCollector; use crate::runtime::relation::RelationId; use crate::storage::temp::TempTx; use crate::storage::StoreTx; @@ -33,7 +36,64 @@ fn storage_version_key() -> Vec { storage_version_tuple.encode_as_key(RelationId::SYSTEM) } +const STATUS_STR: &str = "status"; +const OK_STR: &str = "OK"; + impl<'a> SessionTx<'a> { + pub(crate) fn get_returning_rows(&self, callback_collector: &mut CallbackCollector, rel: &str, returning: &ReturnMutation) -> Result { + let returned_rows = { + match returning { + ReturnMutation::NotReturning => { + NamedRows::new( + vec![STATUS_STR.to_string()], + vec![vec![DataValue::from(OK_STR)]], + ) + } + ReturnMutation::Returning => { + let meta = self.get_relation(rel, false)?; + let target_len = meta.metadata.keys.len() + meta.metadata.non_keys.len(); + let mut returned_rows = Vec::new(); + if let Some(collected) = callback_collector.get(&meta.name) { + for (kind, insertions, deletions) in collected { + let (pos_key, neg_key) = match kind { + CallbackOp::Put => { ("inserted", "replaced") } + CallbackOp::Rm => { ("requested", "deleted") } + }; + for row in &insertions.rows { + let mut v = Vec::with_capacity(target_len + 1); + v.push(DataValue::from(pos_key)); + v.extend_from_slice(row); + while v.len() <= target_len { + v.push(DataValue::Null); + } + returned_rows.push(v); + } + for row in &deletions.rows { + let mut v = Vec::with_capacity(target_len + 1); + v.push(DataValue::from(neg_key)); + v.extend_from_slice(row); + while v.len() <= target_len { + v.push(DataValue::Null); + } + returned_rows.push(v); + } + } + } + let mut header = vec!["_kind".to_string()]; + header.extend(meta.metadata.keys + .iter() + .chain(meta.metadata.non_keys.iter()) + .map(|s| s.name.to_string())); + NamedRows::new( + header, + returned_rows, + ) + } + } + }; + Ok(returned_rows) + } + pub(crate) fn init_storage(&mut self) -> Result { let tuple = vec![DataValue::Null]; let t_encoded = tuple.encode_as_key(RelationId::SYSTEM);