diff --git a/python/cozo_helper.py b/python/cozo_helper.py index ced01236..cbd62059 100644 --- a/python/cozo_helper.py +++ b/python/cozo_helper.py @@ -70,13 +70,17 @@ Ge = PredicateClass('Ge') Le = PredicateClass('Le') Eq = PredicateClass('Eq') Neq = PredicateClass('Neq') +Add = PredicateClass('Add') +Sub = PredicateClass('Sub') +Mul = PredicateClass('Mul') +Div = PredicateClass('Div') def Const(item): return {'const': item} -__all__ = ['Gt', 'Lt', 'Ge', 'Le', 'Eq', 'Neq', 'Q', 'T', 'R', 'Const'] +__all__ = ['Gt', 'Lt', 'Ge', 'Le', 'Eq', 'Neq', 'Add', 'Sub', 'Mul', 'Div', 'Q', 'T', 'R', 'Const'] if __name__ == '__main__': import json diff --git a/src/data/expr.rs b/src/data/expr.rs index c45f7599..568f4b06 100644 --- a/src/data/expr.rs +++ b/src/data/expr.rs @@ -14,18 +14,18 @@ pub enum ExprError { } #[derive(Debug, Clone)] -pub(crate) enum Expr { +pub enum Expr { Const(Term), Apply(&'static Op, Box<[Expr]>), } #[derive(Clone)] -pub(crate) struct Op { - name: &'static str, - min_arity: usize, - vararg: bool, - is_predicate: bool, - inner: fn(&[DataValue]) -> Result, +pub struct Op { + pub(crate) name: &'static str, + pub(crate) min_arity: usize, + pub(crate) vararg: bool, + pub(crate) is_predicate: bool, + pub(crate) inner: fn(&[DataValue]) -> Result, } impl Debug for Op { @@ -56,6 +56,26 @@ fn op_neq(args: &[DataValue]) -> Result { Ok(DataValue::Bool(!args.iter().all_equal())) } +define_op!(OP_GT, 2, false, true); +fn op_gt(args: &[DataValue]) -> Result { + Ok(DataValue::Bool(args[0] > args[1])) +} + +define_op!(OP_GE, 2, false, true); +fn op_ge(args: &[DataValue]) -> Result { + Ok(DataValue::Bool(args[0] >= args[1])) +} + +define_op!(OP_LT, 2, false, true); +fn op_lt(args: &[DataValue]) -> Result { + Ok(DataValue::Bool(args[0] < args[1])) +} + +define_op!(OP_LE, 2, false, true); +fn op_le(args: &[DataValue]) -> Result { + Ok(DataValue::Bool(args[0] <= args[1])) +} + define_op!(OP_ADD, 0, true, false); fn op_add(args: &[DataValue]) -> Result { let mut i_accum = 0i64; diff --git a/src/parse/query.rs b/src/parse/query.rs index e9959deb..76d5e0d4 100644 --- a/src/parse/query.rs +++ b/src/parse/query.rs @@ -6,6 +6,9 @@ use itertools::Itertools; use serde_json::Map; use crate::data::attr::Attribute; +use crate::data::expr::{ + Expr, OP_ADD, OP_DIV, OP_EQ, OP_GE, OP_GT, OP_LE, OP_LT, OP_MUL, OP_NEQ, OP_SUB, +}; use crate::data::json::JsonValue; use crate::data::keyword::{Keyword, PROG_ENTRY}; use crate::data::value::DataValue; @@ -72,6 +75,111 @@ impl SessionTx { } } } + fn parse_predicate_atom(payload: &Map) -> Result { + let pred = Self::parse_expr(payload)?; + if let Expr::Apply(op, _) = &pred { + if !op.is_predicate { + return Err(QueryCompilationError::NotAPredicate(op.name).into()); + } + } + Ok(Atom::Predicate(pred)) + } + fn parse_expr(payload: &Map) -> Result { + let name = payload + .get("pred") + .ok_or_else(|| { + QueryCompilationError::UnexpectedForm( + JsonValue::Object(payload.clone()), + "expect key 'pred'".to_string(), + ) + })? + .as_str() + .ok_or_else(|| { + QueryCompilationError::UnexpectedForm( + JsonValue::Object(payload.clone()), + "expect key 'pred' to be the name of a predicate".to_string(), + ) + })?; + + let op = match name { + "Add" => &OP_ADD, + "Sub" => &OP_SUB, + "Mul" => &OP_MUL, + "Div" => &OP_DIV, + "Eq" => &OP_EQ, + "Neq" => &OP_NEQ, + "Gt" => &OP_GT, + "Ge" => &OP_GE, + "Lt" => &OP_LT, + "Le" => &OP_LE, + s => return Err(QueryCompilationError::UnknownOperator(s.to_string()).into()), + }; + + let args: Box<[Expr]> = payload + .get("args") + .ok_or_else(|| { + QueryCompilationError::UnexpectedForm( + JsonValue::Object(payload.clone()), + "expect key 'args'".to_string(), + ) + })? + .as_array() + .ok_or_else(|| { + QueryCompilationError::UnexpectedForm( + JsonValue::Object(payload.clone()), + "expect key 'args' to be an array".to_string(), + ) + })? + .iter() + .map(Self::parse_expr_arg) + .try_collect()?; + + if op.vararg { + if args.len() < op.min_arity { + return Err(QueryCompilationError::PredicateArityMismatch( + op.name, + op.min_arity, + args.len(), + ) + .into()); + } + } else if args.len() != op.min_arity { + return Err(QueryCompilationError::PredicateArityMismatch( + op.name, + op.min_arity, + args.len(), + ) + .into()); + } + + Ok(Expr::Apply(op, args)) + } + fn parse_expr_arg(payload: &JsonValue) -> Result { + match payload { + JsonValue::String(s) => { + let kw = Keyword::from(s as &str); + if kw.is_reserved() { + Ok(Expr::Const(Term::Var(kw))) + } else { + Ok(Expr::Const(Term::Const(DataValue::String(s.into())))) + } + } + JsonValue::Object(map) => { + if let Some(v) = map.get("const") { + Ok(Expr::Const(Term::Const(v.into()))) + } else if map.contains_key("pred") { + Self::parse_expr(map) + } else { + Err(QueryCompilationError::UnexpectedForm( + JsonValue::Object(map.clone()), + "must contain either 'const' or 'pred' key".to_string(), + ) + .into()) + } + } + v => Ok(Expr::Const(Term::Const(v.into()))), + } + } fn parse_rule_atom(&mut self, payload: &Map, vld: Validity) -> Result { let rule_name = payload .get("rule") @@ -232,7 +340,9 @@ impl SessionTx { if map.contains_key("rule") { self.parse_rule_atom(map, vld) } else if map.contains_key("pred") { - dbg!(map); + Self::parse_predicate_atom(map) + } else if map.contains_key("logical") { + dbg!("x"); todo!() } else { todo!() diff --git a/src/query/compile.rs b/src/query/compile.rs index efeb7f96..742dad06 100644 --- a/src/query/compile.rs +++ b/src/query/compile.rs @@ -5,7 +5,6 @@ use std::ops::Sub; use anyhow::Result; use itertools::Itertools; -use crate::{EntityId, Validity}; use crate::data::attr::Attribute; use crate::data::json::JsonValue; use crate::data::keyword::Keyword; @@ -13,6 +12,8 @@ use crate::data::value::DataValue; use crate::query::relation::Relation; use crate::runtime::temp_store::TempStore; use crate::runtime::transact::SessionTx; +use crate::{EntityId, Validity}; +use crate::data::expr::Expr; /// example ruleset in python and javascript /// ```python @@ -56,10 +57,16 @@ pub enum QueryCompilationError { EntryHeadsNotIdentical, #[error("required binding not found: {0}")] BindingNotFound(Keyword), + #[error("unknown operator '{0}'")] + UnknownOperator(String), + #[error("op {0} arity mismatch: expected {1} arguments, found {2}")] + PredicateArityMismatch(&'static str, usize, usize), + #[error("op {0} is not a predicate")] + NotAPredicate(&'static str), } #[derive(Clone, Debug)] -pub(crate) enum Term { +pub enum Term { Var(Keyword), Const(T), } @@ -93,16 +100,27 @@ pub struct RuleApplyAtom { } #[derive(Clone, Debug)] -pub struct PredicateAtom { - pub(crate) left: Term, - pub(crate) right: Term, +pub enum LogicalAtom { + AttrTriple(AttrTripleAtom), + Rule(RuleApplyAtom), + Negation(Box), + Conjunction(Vec), + Disjunction(Vec), +} + +#[derive(Clone, Debug)] +pub struct BindUnification { + left: Term, + right: Expr, } #[derive(Clone, Debug)] pub enum Atom { AttrTriple(AttrTripleAtom), Rule(RuleApplyAtom), - Predicate(PredicateAtom), + Predicate(Expr), + Logical(LogicalAtom), + BindUnify(BindUnification), } #[derive(Clone, Debug)] @@ -358,6 +376,12 @@ impl SessionTx { Atom::Predicate(_) => { todo!() } + Atom::Logical(_) => { + todo!() + } + Atom::BindUnify(_) => { + todo!() + } } }