From 6362083faf52d2e4b61fc107482061587fe3cac5 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Mon, 8 Aug 2022 22:09:47 +0800 Subject: [PATCH] operators in const rule --- src/parse/query.rs | 29 +++++++++++++++++++++++------ tests/air_routes.rs | 2 +- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/src/parse/query.rs b/src/parse/query.rs index 8deb93c9..de2928b9 100644 --- a/src/parse/query.rs +++ b/src/parse/query.rs @@ -3,7 +3,7 @@ use std::collections::{BTreeMap, BTreeSet}; use anyhow::{anyhow, bail, ensure, Result}; use itertools::Itertools; -use serde_json::{json, Map}; +use serde_json::{json, Map, Value}; use crate::data::aggr::get_aggr; use crate::data::attr::Attribute; @@ -131,8 +131,8 @@ impl SessionTx { anyhow!("data in rule is expected to be an array, got {}", v) })? .iter() - .map(|v| DataValue::from(v)) - .collect(); + .map(|v| Self::parse_const_expr(v)) + .try_collect()?; Ok(Tuple(tuple)) }) .try_collect()?; @@ -326,7 +326,7 @@ impl SessionTx { } } fn parse_input_predicate_atom(payload: &Map) -> Result { - let mut pred = Self::parse_expr(payload)?; + let mut pred = Self::parse_apply_expr(payload)?; if let Expr::Apply(op, _) = &pred { ensure!( op.is_predicate, @@ -355,7 +355,7 @@ impl SessionTx { let expr = Self::parse_expr_arg(expr)?; Ok(InputAtom::Unification(Unification { binding, expr })) } - fn parse_expr(payload: &Map) -> Result { + fn parse_apply_expr(payload: &Map) -> Result { let name = payload .get("op") .ok_or_else(|| anyhow!("expect expression to have key 'pred'"))? @@ -393,6 +393,23 @@ impl SessionTx { Ok(Expr::Apply(op, args)) } + fn parse_const_expr(payload: &JsonValue) -> Result { + Ok(match payload { + Value::Array(arr) => { + let exprs: Vec<_> = arr.iter().map(Self::parse_const_expr).try_collect()?; + DataValue::List(exprs.into()) + } + Value::Object(m) => { + let mut ex = Self::parse_apply_expr(m)?; + ex.partial_eval()?; + match ex { + Expr::Const(c) => c, + v => bail!("failed to convert {:?} to constant", v), + } + } + v => DataValue::from(v), + }) + } fn parse_expr_arg(payload: &JsonValue) -> Result { match payload { JsonValue::String(s) => { @@ -407,7 +424,7 @@ impl SessionTx { if let Some(v) = map.get("const") { Ok(Expr::Const(v.into())) } else if map.contains_key("op") { - Self::parse_expr(map) + Self::parse_apply_expr(map) } else { bail!("expression object must contain either 'const' or 'pred' key"); } diff --git a/tests/air_routes.rs b/tests/air_routes.rs index 009750c1..9462f742 100644 --- a/tests/air_routes.rs +++ b/tests/air_routes.rs @@ -310,7 +310,7 @@ fn air_routes() -> Result<()> { routes_count[?a, count(?r)] := given[?code], [?a airport.iata ?code], [?r route.src ?a]; ?[?code, ?n] := routes_count[?a, ?n], [?a airport.iata ?code]; - given <- [['AUS'],['AMS'],['JFK'],['DUB'],['MEX']]; + given <- [['A' ++ 'U' ++ 'S'],['AMS'],['JFK'],['DUB'],['MEX']]; "#, )?; dbg!(routes_per_airport_time.elapsed());