diff --git a/src/transact/pull.rs b/src/transact/pull.rs index 7e880f78..cec78c5c 100644 --- a/src/transact/pull.rs +++ b/src/transact/pull.rs @@ -10,14 +10,17 @@ use crate::data::value::{StaticValue, Value}; use crate::runtime::transact::SessionTx; use anyhow::Result; use serde_json::{json, Map}; +use std::collections::HashSet; pub(crate) type PullSpecs = Vec; +#[derive(Debug, Clone)] pub(crate) enum PullSpec { PullAll, Attr(AttrPullSpec), } +#[derive(Debug, Clone)] pub(crate) struct AttrPullSpec { pub(crate) attr: Attribute, pub(crate) reverse: bool, @@ -25,11 +28,8 @@ pub(crate) struct AttrPullSpec { pub(crate) cardinality: AttributeCardinality, pub(crate) take: Option, pub(crate) nested: PullSpecs, -} - -pub(crate) struct RecursePullSpec { - pub(crate) parent: Keyword, - pub(crate) max_depth: Option, + pub(crate) recursive: bool, + pub(crate) recursion_limit: Option, } impl SessionTx { @@ -39,14 +39,15 @@ impl SessionTx { vld: Validity, spec: &PullSpec, collector: &mut Map, + recursive_seen: &mut Option>, ) -> Result<()> { match spec { PullSpec::PullAll => self.pull_all(eid, vld, collector), PullSpec::Attr(a_spec) => { if a_spec.reverse { - self.pull_attr_rev(eid, vld, a_spec, collector) + self.pull_attr_rev(eid, vld, a_spec, collector, recursive_seen) } else { - self.pull_attr(eid, vld, a_spec, collector) + self.pull_attr(eid, vld, a_spec, collector, recursive_seen) } } } @@ -57,13 +58,14 @@ impl SessionTx { vld: Validity, spec: &AttrPullSpec, collector: &mut Map, + recursive_seen: &mut Option>, ) -> Result<()> { if spec.cardinality.is_one() { if let Some(found) = self.triple_ea_before_scan(eid, spec.attr.id, vld).next() { let (_, _, value) = found?; - self.pull_attr_collect(spec, value, vld, collector)?; + self.pull_attr_collect(spec, value, vld, collector, recursive_seen)?; } else { - self.pull_attr_collect(spec, Value::Null, vld, collector)?; + self.pull_attr_collect(spec, Value::Null, vld, collector, recursive_seen)?; } } else { let mut collection: Vec = vec![]; @@ -77,7 +79,7 @@ impl SessionTx { } } } - self.pull_attr_collect_many(spec, collection, vld, collector)?; + self.pull_attr_collect_many(spec, collection, vld, collector, recursive_seen)?; } Ok(()) } @@ -87,14 +89,55 @@ impl SessionTx { value: StaticValue, vld: Validity, collector: &mut Map, + recursive_seen: &mut Option>, ) -> Result<()> { - if spec.nested.is_empty() { + if spec.recursive { + let mut nested = spec.nested.clone(); + let mut to_add = spec.clone(); + if let Some(n) = to_add.recursion_limit { + if n == 0 { + return Ok(()); + } else { + to_add.recursion_limit = Some(n - 1); + } + } + + let eid = value.get_entity_id()?; + + if let Some(inner) = recursive_seen { + if !inner.insert(eid) { + collector.insert(spec.name.to_string_no_prefix(), value.into()); + return Ok(()); + } + } + + nested.push(PullSpec::Attr(to_add)); + + let mut sub_collector = Map::default(); + if recursive_seen.is_some() { + for sub_spec in &spec.nested { + self.pull(eid, vld, sub_spec, &mut sub_collector, recursive_seen)?; + } + } else { + let mut recursive_seen_inner = Some(Default::default()); + for sub_spec in &spec.nested { + self.pull( + eid, + vld, + sub_spec, + &mut sub_collector, + &mut recursive_seen_inner, + )?; + } + } + collector.insert(spec.name.to_string_no_prefix(), sub_collector.into()); + } else if spec.nested.is_empty() { collector.insert(spec.name.to_string_no_prefix(), value.into()); } else { let eid = value.get_entity_id()?; let mut sub_collector = Map::default(); for sub_spec in &spec.nested { - self.pull(eid, vld, sub_spec, &mut sub_collector)?; + self.pull(eid, vld, sub_spec, &mut sub_collector, recursive_seen)?; } collector.insert(spec.name.to_string_no_prefix(), sub_collector.into()); } @@ -106,8 +149,52 @@ impl SessionTx { values: Vec, vld: Validity, collector: &mut Map, + recursive_seen: &mut Option>, ) -> Result<()> { - if spec.nested.is_empty() { + if spec.recursive { + if recursive_seen.is_none() { + let mut new_recursive_seen = Some(Default::default()); + return self.pull_attr_collect_many( + spec, + values, + vld, + collector, + &mut new_recursive_seen, + ); + } + + let mut nested = spec.nested.clone(); + let mut to_add = spec.clone(); + if let Some(n) = to_add.recursion_limit { + if n == 0 { + return Ok(()); + } else { + to_add.recursion_limit = Some(n - 1); + } + } + + nested.push(PullSpec::Attr(to_add)); + + let mut sub_collectors = vec![]; + + for value in values { + let eid = value.get_entity_id()?; + + if let Some(inner) = recursive_seen { + if !inner.insert(eid) { + collector.insert(spec.name.to_string_no_prefix(), value.into()); + return Ok(()); + } + } + + let mut sub_collector = Map::default(); + for sub_spec in &spec.nested { + self.pull(eid, vld, sub_spec, &mut sub_collector, recursive_seen)?; + } + sub_collectors.push(sub_collector); + } + collector.insert(spec.name.to_string_no_prefix(), sub_collectors.into()); + } else if spec.nested.is_empty() { collector.insert(spec.name.to_string_no_prefix(), values.into()); } else { let mut sub_collectors = vec![]; @@ -115,7 +202,7 @@ impl SessionTx { let eid = value.get_entity_id()?; let mut sub_collector = Map::default(); for sub_spec in &spec.nested { - self.pull(eid, vld, sub_spec, &mut sub_collector)?; + self.pull(eid, vld, sub_spec, &mut sub_collector, recursive_seen)?; } sub_collectors.push(sub_collector); } @@ -129,6 +216,7 @@ impl SessionTx { vld: Validity, spec: &AttrPullSpec, collector: &mut Map, + recursive_seen: &mut Option>, ) -> Result<()> { if spec.cardinality.is_one() { if let Some(found) = self @@ -136,9 +224,9 @@ impl SessionTx { .next() { let (_, _, value) = found?; - self.pull_attr_collect(spec, Value::EnId(value), vld, collector)?; + self.pull_attr_collect(spec, Value::EnId(value), vld, collector, recursive_seen)?; } else { - self.pull_attr_collect(spec, Value::Null, vld, collector)?; + self.pull_attr_collect(spec, Value::Null, vld, collector, recursive_seen)?; } } else { let mut collection: Vec = vec![]; @@ -152,7 +240,7 @@ impl SessionTx { } } } - self.pull_attr_collect_many(spec, collection.into(), vld, collector)?; + self.pull_attr_collect_many(spec, collection.into(), vld, collector, recursive_seen)?; } Ok(()) }