operators in const rule

main
Ziyang Hu 2 years ago
parent db75893c5b
commit 6362083faf

@ -3,7 +3,7 @@ use std::collections::{BTreeMap, BTreeSet};
use anyhow::{anyhow, bail, ensure, Result};
use itertools::Itertools;
use serde_json::{json, Map};
use serde_json::{json, Map, Value};
use crate::data::aggr::get_aggr;
use crate::data::attr::Attribute;
@ -131,8 +131,8 @@ impl SessionTx {
anyhow!("data in rule is expected to be an array, got {}", v)
})?
.iter()
.map(|v| DataValue::from(v))
.collect();
.map(|v| Self::parse_const_expr(v))
.try_collect()?;
Ok(Tuple(tuple))
})
.try_collect()?;
@ -326,7 +326,7 @@ impl SessionTx {
}
}
fn parse_input_predicate_atom(payload: &Map<String, JsonValue>) -> Result<InputAtom> {
let mut pred = Self::parse_expr(payload)?;
let mut pred = Self::parse_apply_expr(payload)?;
if let Expr::Apply(op, _) = &pred {
ensure!(
op.is_predicate,
@ -355,7 +355,7 @@ impl SessionTx {
let expr = Self::parse_expr_arg(expr)?;
Ok(InputAtom::Unification(Unification { binding, expr }))
}
fn parse_expr(payload: &Map<String, JsonValue>) -> Result<Expr> {
fn parse_apply_expr(payload: &Map<String, JsonValue>) -> Result<Expr> {
let name = payload
.get("op")
.ok_or_else(|| anyhow!("expect expression to have key 'pred'"))?
@ -393,6 +393,23 @@ impl SessionTx {
Ok(Expr::Apply(op, args))
}
fn parse_const_expr(payload: &JsonValue) -> Result<DataValue> {
Ok(match payload {
Value::Array(arr) => {
let exprs: Vec<_> = arr.iter().map(Self::parse_const_expr).try_collect()?;
DataValue::List(exprs.into())
}
Value::Object(m) => {
let mut ex = Self::parse_apply_expr(m)?;
ex.partial_eval()?;
match ex {
Expr::Const(c) => c,
v => bail!("failed to convert {:?} to constant", v),
}
}
v => DataValue::from(v),
})
}
fn parse_expr_arg(payload: &JsonValue) -> Result<Expr> {
match payload {
JsonValue::String(s) => {
@ -407,7 +424,7 @@ impl SessionTx {
if let Some(v) = map.get("const") {
Ok(Expr::Const(v.into()))
} else if map.contains_key("op") {
Self::parse_expr(map)
Self::parse_apply_expr(map)
} else {
bail!("expression object must contain either 'const' or 'pred' key");
}

@ -310,7 +310,7 @@ fn air_routes() -> Result<()> {
routes_count[?a, count(?r)] := given[?code], [?a airport.iata ?code], [?r route.src ?a];
?[?code, ?n] := routes_count[?a, ?n], [?a airport.iata ?code];
given <- [['AUS'],['AMS'],['JFK'],['DUB'],['MEX']];
given <- [['A' ++ 'U' ++ 'S'],['AMS'],['JFK'],['DUB'],['MEX']];
"#,
)?;
dbg!(routes_per_airport_time.elapsed());

Loading…
Cancel
Save