diff --git a/src/ast/mod.rs b/src/ast.rs similarity index 85% rename from src/ast/mod.rs rename to src/ast.rs index 8f4d2514..e1b9494f 100644 --- a/src/ast/mod.rs +++ b/src/ast.rs @@ -4,16 +4,37 @@ use pest::prec_climber::{Assoc, PrecClimber, Operator}; use crate::parser::Parser; use crate::parser::Rule; use lazy_static::lazy_static; -use crate::ast::eval_op::*; use crate::ast::Expr::{Apply, Const}; -use crate::ast::op::Op; use crate::error::CozoError; use crate::error::CozoError::ReservedIdent; -use crate::typing::{PrimitiveType, Typing}; +use crate::typing::{BaseType, Typing}; use crate::value::Value; -mod eval_op; -mod op; + + +#[derive(PartialEq, Debug)] +pub enum Op { + Add, + Sub, + Mul, + Div, + Eq, + Neq, + Gt, + Lt, + Ge, + Le, + Neg, + Minus, + Mod, + Or, + And, + Coalesce, + Pow, + Call, + IsNull, + NotNull +} lazy_static! { @@ -78,38 +99,11 @@ pub enum Expr<'a> { Const(Value<'a>), } -impl<'a> Expr<'a> { - pub fn eval(&self) -> Result, CozoError> { - match self { - Apply(op, args) => { - match op { - Op::Add => add_exprs(args), - Op::Sub => sub_exprs(args), - Op::Mul => mul_exprs(args), - Op::Div => div_exprs(args), - Op::Eq => eq_exprs(args), - Op::Neq => ne_exprs(args), - Op::Gt => gt_exprs(args), - Op::Lt => lt_exprs(args), - Op::Ge => ge_exprs(args), - Op::Le => le_exprs(args), - Op::Neg => negate_expr(args), - Op::Minus => minus_expr(args), - Op::Mod => mod_exprs(args), - Op::Or => or_expr(args), - Op::And => and_expr(args), - Op::Coalesce => coalesce_exprs(args), - Op::Pow => pow_exprs(args), - Op::IsNull => is_null_expr(args), - Op::NotNull => not_null_expr(args), - Op::Call => unimplemented!(), - } - } - Const(v) => Ok(Const(v.clone())) - } - } +pub trait ExprVisitor<'a, T> { + fn visit_expr(&mut self, ex: &Expr<'a>) -> T; } + fn build_expr_infix<'a>(lhs: Result, CozoError>, op: Pair, rhs: Result, CozoError>) -> Result, CozoError> { let lhs = lhs?; let rhs = rhs?; @@ -279,8 +273,8 @@ fn build_col_entry(pair: Pair) -> Result<(Col, bool), CozoError> { let (name, is_key) = parse_col_name(pairs.next().unwrap())?; Ok((Col { name, - typ: Typing::Primitive(PrimitiveType::Int), - default: None + typ: Typing::Base(BaseType::Int), + default: None, }, is_key)) } @@ -370,16 +364,6 @@ mod tests { assert_eq!(parse_expr_from_str(r#####"r###"x"yz"###"#####).unwrap(), Const(Value::RefString(r##"x"yz"##))); } - #[test] - fn operators() { - println!("{:#?}", parse_expr_from_str("1/10+(-2+3)*4^5").unwrap().eval().unwrap()); - println!("{:#?}", parse_expr_from_str("true && false").unwrap().eval().unwrap()); - println!("{:#?}", parse_expr_from_str("true || false").unwrap().eval().unwrap()); - println!("{:#?}", parse_expr_from_str("true || null").unwrap().eval().unwrap()); - println!("{:#?}", parse_expr_from_str("null || true").unwrap().eval().unwrap()); - println!("{:#?}", parse_expr_from_str("true && null").unwrap().eval().unwrap()); - } - #[test] fn definitions() { println!("{:#?}", build_statements_from_str(r#" diff --git a/src/ast/eval_op.rs b/src/ast/eval_op.rs deleted file mode 100644 index 4e47538c..00000000 --- a/src/ast/eval_op.rs +++ /dev/null @@ -1,451 +0,0 @@ -use crate::ast::Expr; -use crate::ast::Expr::*; -use crate::ast::op::Op; -use crate::error::CozoError; -use crate::error::CozoError::*; -use crate::value::Value::*; - -pub fn add_exprs<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - match exprs { - [a, b] => { - let a = a.eval()?; - let b = b.eval()?; - if a == Const(Null) || b == Const(Null) { - return Ok(Const(Null)); - } - Ok(Const(match (a, b) { - (Const(a), Const(b)) => { - match (a, b) { - (Int(va), Int(vb)) => Int(va + vb), - (Float(va), Int(vb)) => Float(va + vb as f64), - (Int(va), Float(vb)) => Float(va as f64 + vb), - (Float(va), Float(vb)) => Float(va + vb), - (OwnString(va), OwnString(vb)) => OwnString(Box::new(*va + &*vb)), - (OwnString(va), RefString(vb)) => OwnString(Box::new(*va + &*vb)), - (RefString(va), OwnString(vb)) => OwnString(Box::new(va.to_string() + &*vb)), - (RefString(va), RefString(vb)) => OwnString(Box::new(va.to_string() + &*vb)), - (_, _) => return Err(CozoError::TypeError) - } - } - (a, b) => return Ok(Apply(Op::Add, vec![a, b])) - })) - } - _ => unreachable!() - } -} - -pub fn sub_exprs<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - match exprs { - [a, b] => { - let a = a.eval()?; - let b = b.eval()?; - if a == Const(Null) || b == Const(Null) { - return Ok(Const(Null)); - } - Ok(Const(match (a, b) { - (Const(a), Const(b)) => { - match (a, b) { - (Int(va), Int(vb)) => Int(va - vb), - (Float(va), Int(vb)) => Float(va - vb as f64), - (Int(va), Float(vb)) => Float(va as f64 - vb), - (Float(va), Float(vb)) => Float(va - vb), - (_, _) => return Err(CozoError::TypeError) - } - } - (a, b) => return Ok(Apply(Op::Sub, vec![a, b])) - })) - } - _ => unreachable!() - } -} - -pub fn mul_exprs<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - match exprs { - [a, b] => { - let a = a.eval()?; - let b = b.eval()?; - if a == Const(Null) || b == Const(Null) { - return Ok(Const(Null)); - } - Ok(Const(match (a, b) { - (Const(a), Const(b)) => { - match (a, b) { - (Int(va), Int(vb)) => Int(va * vb), - (Float(va), Int(vb)) => Float(va * vb as f64), - (Int(va), Float(vb)) => Float(va as f64 * vb), - (Float(va), Float(vb)) => Float(va * vb), - (_, _) => return Err(CozoError::TypeError) - } - } - (a, b) => return Ok(Apply(Op::Mul, vec![a, b])) - })) - } - _ => unreachable!() - } -} - - -pub fn div_exprs<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - match exprs { - [a, b] => { - let a = a.eval()?; - let b = b.eval()?; - if a == Const(Null) || b == Const(Null) { - return Ok(Const(Null)); - } - Ok(Const(match (a, b) { - (Const(a), Const(b)) => { - match (a, b) { - (Int(va), Int(vb)) => Float(va as f64 / vb as f64), - (Float(va), Int(vb)) => Float(va / vb as f64), - (Int(va), Float(vb)) => Float(va as f64 / vb), - (Float(va), Float(vb)) => Float(va / vb), - (_, _) => return Err(CozoError::TypeError) - } - } - (a, b) => return Ok(Apply(Op::Div, vec![a, b])) - })) - } - _ => unreachable!() - } -} - -pub fn mod_exprs<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - match exprs { - [a, b] => { - let a = a.eval()?; - let b = b.eval()?; - if a == Const(Null) || b == Const(Null) { - return Ok(Const(Null)); - } - Ok(Const(match (a, b) { - (Const(a), Const(b)) => { - match (a, b) { - (Int(a), Int(b)) => Int(a % b), - (_, _) => return Err(CozoError::TypeError) - } - } - (a, b) => return Ok(Apply(Op::Mod, vec![a, b])) - })) - } - _ => unreachable!() - } -} - - -pub fn eq_exprs<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - match exprs { - [a, b] => { - let a = a.eval()?; - let b = b.eval()?; - if a == Const(Null) || b == Const(Null) { - return Ok(Const(Null)); - } - match (a, b) { - (Const(a), Const(b)) => Ok(Const(Bool(a == b))), - (a, b) => Ok(Apply(Op::Eq, vec![a, b])) - } - } - _ => unreachable!() - } -} - - -pub fn ne_exprs<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - match exprs { - [a, b] => { - let a = a.eval()?; - let b = b.eval()?; - if a == Const(Null) || b == Const(Null) { - return Ok(Const(Null)); - } - match (a, b) { - (Const(a), Const(b)) => Ok(Const(Bool(a == b))), - (a, b) => Ok(Apply(Op::Neq, vec![a, b])) - } - } - _ => unreachable!() - } -} - - -pub fn gt_exprs<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - match exprs { - [a, b] => { - let a = a.eval()?; - let b = b.eval()?; - if a == Const(Null) || b == Const(Null) { - return Ok(Const(Null)); - } - match (a, b) { - (Const(a), Const(b)) => { - match (a, b) { - (Int(a), Int(b)) => Ok(Const(Bool(a > b))), - (Float(a), Int(b)) => Ok(Const(Bool(a > b as f64))), - (Int(a), Float(b)) => Ok(Const(Bool(a as f64 > b))), - (Float(a), Float(b)) => Ok(Const(Bool(a > b))), - (_, _) => Err(CozoError::TypeError) - } - } - (a, b) => Ok(Apply(Op::Gt, vec![a, b])) - } - } - _ => unreachable!() - } -} - - -pub fn ge_exprs<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - match exprs { - [a, b] => { - let a = a.eval()?; - let b = b.eval()?; - if a == Const(Null) || b == Const(Null) { - return Ok(Const(Null)); - } - match (a, b) { - (Const(a), Const(b)) => { - match (a, b) { - (Int(a), Int(b)) => Ok(Const(Bool(a >= b))), - (Float(a), Int(b)) => Ok(Const(Bool(a >= b as f64))), - (Int(a), Float(b)) => Ok(Const(Bool(a as f64 >= b))), - (Float(a), Float(b)) => Ok(Const(Bool(a >= b))), - (_, _) => Err(CozoError::TypeError) - } - } - (a, b) => Ok(Apply(Op::Ge, vec![a, b])) - } - } - _ => unreachable!() - } -} - - -pub fn lt_exprs<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - match exprs { - [a, b] => { - let a = a.eval()?; - let b = b.eval()?; - if a == Const(Null) || b == Const(Null) { - return Ok(Const(Null)); - } - match (a, b) { - (Const(a), Const(b)) => { - match (a, b) { - (Int(a), Int(b)) => Ok(Const(Bool(a < b))), - (Float(a), Int(b)) => Ok(Const(Bool(a < b as f64))), - (Int(a), Float(b)) => Ok(Const(Bool((a as f64) < b))), - (Float(a), Float(b)) => Ok(Const(Bool(a < b))), - (_, _) => Err(CozoError::TypeError) - } - } - (a, b) => Ok(Apply(Op::Lt, vec![a, b])) - } - } - _ => unreachable!() - } -} - -pub fn le_exprs<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - match exprs { - [a, b] => { - let a = a.eval()?; - let b = b.eval()?; - if a == Const(Null) || b == Const(Null) { - return Ok(Const(Null)); - } - match (a, b) { - (Const(a), Const(b)) => { - match (a, b) { - (Int(a), Int(b)) => Ok(Const(Bool(a <= b))), - (Float(a), Int(b)) => Ok(Const(Bool(a <= b as f64))), - (Int(a), Float(b)) => Ok(Const(Bool((a as f64) <= b))), - (Float(a), Float(b)) => Ok(Const(Bool(a <= b))), - (_, _) => Err(CozoError::TypeError) - } - } - (a, b) => Ok(Apply(Op::Le, vec![a, b])) - } - } - _ => unreachable!() - } -} - - -pub fn pow_exprs<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - match exprs { - [a, b] => { - let a = a.eval()?; - let b = b.eval()?; - if a == Const(Null) || b == Const(Null) { - return Ok(Const(Null)); - } - match (a, b) { - (Const(a), Const(b)) => { - match (a, b) { - (Int(a), Int(b)) => Ok(Const(Float((a as f64).powf(b as f64)))), - (Float(a), Int(b)) => Ok(Const(Float(a.powi(b as i32)))), - (Int(a), Float(b)) => Ok(Const(Float((a as f64).powf(b)))), - (Float(a), Float(b)) => Ok(Const(Float(a.powf(b)))), - (_, _) => Err(CozoError::TypeError) - } - } - (a, b) => Ok(Apply(Op::Pow, vec![a, b])) - } - } - _ => unreachable!() - } -} - -pub fn coalesce_exprs<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - match exprs { - [a, b] => { - let a = a.eval()?; - let b = b.eval()?; - if a == Const(Null) { - return Ok(b); - } - if b == Const(Null) { - return Ok(a); - } - if let a @ Const(_) = a { - return Ok(a); - } - return Ok(Apply(Op::Coalesce, vec![a, b])); - } - _ => unreachable!() - } -} - -pub fn negate_expr<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - Ok(match exprs { - [a] => { - match a.eval()? { - Const(Null) => Const(Null), - Const(Bool(b)) => Const(Bool(!b)), - Const(_) => return Err(TypeError), - Apply(Op::Neg, v) => v.into_iter().next().unwrap(), - Apply(Op::IsNull, v) => Apply(Op::NotNull, v), - Apply(Op::NotNull, v) => Apply(Op::IsNull, v), - Apply(Op::Eq, v) => Apply(Op::Neq, v), - Apply(Op::Neq, v) => Apply(Op::Eq, v), - Apply(Op::Gt, v) => Apply(Op::Le, v), - Apply(Op::Ge, v) => Apply(Op::Lt, v), - Apply(Op::Le, v) => Apply(Op::Gt, v), - Apply(Op::Lt, v) => Apply(Op::Ge, v), - v => Apply(Op::Neg, vec![v]) - } - } - _ => unreachable!() - }) -} - - -pub fn minus_expr<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - Ok(match exprs { - [a] => { - match a.eval()? { - Const(Null) => Const(Null), - Const(Int(i)) => Const(Int(-i)), - Const(Float(f)) => Const(Float(-f)), - Const(_) => return Err(TypeError), - Apply(Op::Minus, v) => v.into_iter().next().unwrap(), - v => Apply(Op::Minus, vec![v]) - } - } - _ => unreachable!() - }) -} - - -pub fn is_null_expr<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - Ok(match exprs { - [a] => { - match a.eval()? { - Const(Null) => Const(Bool(true)), - Const(_) => Const(Bool(false)), - v => Apply(Op::IsNull, vec![v]) - } - } - _ => unreachable!() - }) -} - -pub fn not_null_expr<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - Ok(match exprs { - [a] => { - match a.eval()? { - Const(Null) => Const(Bool(false)), - Const(_) => Const(Bool(true)), - v => Apply(Op::IsNull, vec![v]) - } - } - _ => unreachable!() - }) -} - -pub fn or_expr<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - let mut unevaluated = vec![]; - let mut has_null = false; - for expr in exprs { - match expr.eval()? { - Const(Bool(true)) => return Ok(Const(Bool(true))), - Const(Bool(false)) => {} - Const(Null) => { has_null = true}, - Const(_) => return Err(TypeError), - Apply(Op::Or, vs) => { - for el in vs { - match el { - Const(Null) => has_null = true, - Const(_) => unreachable!(), - v => unevaluated.push(v) - } - } - } - v => unevaluated.push(v) - } - } - match (has_null, unevaluated.len()) { - (true, 0) => Ok(Const(Null)), - (false, 0) => Ok(Const(Bool(false))), - (false, _) => Ok(Apply(Op::Or, unevaluated)), - (true, _) => { - unevaluated.push(Const(Null)); - Ok(Apply(Op::Or, unevaluated)) - } - } -} - - - -pub fn and_expr<'a>(exprs: &[Expr<'a>]) -> Result, CozoError> { - let mut unevaluated = vec![]; - let mut no_null = true; - for expr in exprs { - match expr.eval()? { - Const(Bool(false)) => return Ok(Const(Bool(false))), - Const(Bool(true)) => {}, - Const(Null) => no_null = false, - Const(_) => return Err(TypeError), - Apply(Op::Or, vs) => { - for el in vs { - match el { - Const(Null) => no_null = false, - Const(_) => unreachable!(), - v => unevaluated.push(v) - } - } - } - v => unevaluated.push(v) - } - } - match (no_null, unevaluated.len()) { - (true, 0) => Ok(Const(Bool(true))), - (false, 0) => Ok(Const(Null)), - (true, _) => Ok(Apply(Op::Add, unevaluated)), - (false, _) => { - unevaluated.push(Const(Null)); - Ok(Apply(Op::And, unevaluated)) - } - } -} \ No newline at end of file diff --git a/src/ast/op.rs b/src/ast/op.rs deleted file mode 100644 index 1de2cbc3..00000000 --- a/src/ast/op.rs +++ /dev/null @@ -1,24 +0,0 @@ - -#[derive(PartialEq, Debug)] -pub enum Op { - Add, - Sub, - Mul, - Div, - Eq, - Neq, - Gt, - Lt, - Ge, - Le, - Neg, - Minus, - Mod, - Or, - And, - Coalesce, - Pow, - Call, - IsNull, - NotNull -} diff --git a/src/eval.rs b/src/eval.rs index e69de29b..e357d227 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -0,0 +1,503 @@ +use crate::ast::Expr; +use crate::ast::Expr::*; +use crate::error::CozoError; +use crate::error::CozoError::*; +use crate::value::Value::*; +use crate::ast::*; + +struct Evaluator; + +impl<'a> ExprVisitor<'a, Result, CozoError>> for Evaluator { + fn visit_expr(&mut self, ex: &Expr<'a>) -> Result, CozoError> { + match ex { + Apply(op, args) => { + match op { + Op::Add => self.add_exprs(args), + Op::Sub => self.sub_exprs(args), + Op::Mul => self.mul_exprs(args), + Op::Div => self.div_exprs(args), + Op::Eq => self.eq_exprs(args), + Op::Neq => self.ne_exprs(args), + Op::Gt => self.gt_exprs(args), + Op::Lt => self.lt_exprs(args), + Op::Ge => self.ge_exprs(args), + Op::Le => self.le_exprs(args), + Op::Neg => self.negate_expr(args), + Op::Minus => self.minus_expr(args), + Op::Mod => self.mod_exprs(args), + Op::Or => self.or_expr(args), + Op::And => self.and_expr(args), + Op::Coalesce => self.coalesce_exprs(args), + Op::Pow => self.pow_exprs(args), + Op::IsNull => self.test_null_expr(args), + Op::NotNull => self.not_null_expr(args), + Op::Call => unimplemented!(), + } + } + Const(v) => Ok(Const(v.clone())) + } + } +} + +impl Evaluator { + fn add_exprs<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + match exprs { + [a, b] => { + let a = self.visit_expr(a)?; + let b = self.visit_expr(b)?; + if a == Const(Null) || b == Const(Null) { + return Ok(Const(Null)); + } + Ok(Const(match (a, b) { + (Const(a), Const(b)) => { + match (a, b) { + (Int(va), Int(vb)) => Int(va + vb), + (Float(va), Int(vb)) => Float(va + vb as f64), + (Int(va), Float(vb)) => Float(va as f64 + vb), + (Float(va), Float(vb)) => Float(va + vb), + (OwnString(va), OwnString(vb)) => OwnString(Box::new(*va + &*vb)), + (OwnString(va), RefString(vb)) => OwnString(Box::new(*va + &*vb)), + (RefString(va), OwnString(vb)) => OwnString(Box::new(va.to_string() + &*vb)), + (RefString(va), RefString(vb)) => OwnString(Box::new(va.to_string() + &*vb)), + (_, _) => return Err(CozoError::TypeError) + } + } + (a, b) => return Ok(Apply(Op::Add, vec![a, b])) + })) + } + _ => unreachable!() + } + } + + fn sub_exprs<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + match exprs { + [a, b] => { + let a = self.visit_expr(a)?; + let b = self.visit_expr(b)?; + if a == Const(Null) || b == Const(Null) { + return Ok(Const(Null)); + } + Ok(Const(match (a, b) { + (Const(a), Const(b)) => { + match (a, b) { + (Int(va), Int(vb)) => Int(va - vb), + (Float(va), Int(vb)) => Float(va - vb as f64), + (Int(va), Float(vb)) => Float(va as f64 - vb), + (Float(va), Float(vb)) => Float(va - vb), + (_, _) => return Err(CozoError::TypeError) + } + } + (a, b) => return Ok(Apply(Op::Sub, vec![a, b])) + })) + } + _ => unreachable!() + } + } + + fn mul_exprs<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + match exprs { + [a, b] => { + let a = self.visit_expr(a)?; + let b = self.visit_expr(b)?; + if a == Const(Null) || b == Const(Null) { + return Ok(Const(Null)); + } + Ok(Const(match (a, b) { + (Const(a), Const(b)) => { + match (a, b) { + (Int(va), Int(vb)) => Int(va * vb), + (Float(va), Int(vb)) => Float(va * vb as f64), + (Int(va), Float(vb)) => Float(va as f64 * vb), + (Float(va), Float(vb)) => Float(va * vb), + (_, _) => return Err(CozoError::TypeError) + } + } + (a, b) => return Ok(Apply(Op::Mul, vec![a, b])) + })) + } + _ => unreachable!() + } + } + + + fn div_exprs<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + match exprs { + [a, b] => { + let a = self.visit_expr(a)?; + let b = self.visit_expr(b)?; + if a == Const(Null) || b == Const(Null) { + return Ok(Const(Null)); + } + Ok(Const(match (a, b) { + (Const(a), Const(b)) => { + match (a, b) { + (Int(va), Int(vb)) => Float(va as f64 / vb as f64), + (Float(va), Int(vb)) => Float(va / vb as f64), + (Int(va), Float(vb)) => Float(va as f64 / vb), + (Float(va), Float(vb)) => Float(va / vb), + (_, _) => return Err(CozoError::TypeError) + } + } + (a, b) => return Ok(Apply(Op::Div, vec![a, b])) + })) + } + _ => unreachable!() + } + } + + fn mod_exprs<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + match exprs { + [a, b] => { + let a = self.visit_expr(a)?; + let b = self.visit_expr(b)?; + if a == Const(Null) || b == Const(Null) { + return Ok(Const(Null)); + } + Ok(Const(match (a, b) { + (Const(a), Const(b)) => { + match (a, b) { + (Int(a), Int(b)) => Int(a % b), + (_, _) => return Err(CozoError::TypeError) + } + } + (a, b) => return Ok(Apply(Op::Mod, vec![a, b])) + })) + } + _ => unreachable!() + } + } + + + fn eq_exprs<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + match exprs { + [a, b] => { + let a = self.visit_expr(a)?; + let b = self.visit_expr(b)?; + if a == Const(Null) || b == Const(Null) { + return Ok(Const(Null)); + } + match (a, b) { + (Const(a), Const(b)) => Ok(Const(Bool(a == b))), + (a, b) => Ok(Apply(Op::Eq, vec![a, b])) + } + } + _ => unreachable!() + } + } + + + fn ne_exprs<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + match exprs { + [a, b] => { + let a = self.visit_expr(a)?; + let b = self.visit_expr(b)?; + if a == Const(Null) || b == Const(Null) { + return Ok(Const(Null)); + } + match (a, b) { + (Const(a), Const(b)) => Ok(Const(Bool(a != b))), + (a, b) => Ok(Apply(Op::Neq, vec![a, b])) + } + } + _ => unreachable!() + } + } + + + fn gt_exprs<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + match exprs { + [a, b] => { + let a = self.visit_expr(a)?; + let b = self.visit_expr(b)?; + if a == Const(Null) || b == Const(Null) { + return Ok(Const(Null)); + } + match (a, b) { + (Const(a), Const(b)) => { + match (a, b) { + (Int(a), Int(b)) => Ok(Const(Bool(a > b))), + (Float(a), Int(b)) => Ok(Const(Bool(a > b as f64))), + (Int(a), Float(b)) => Ok(Const(Bool(a as f64 > b))), + (Float(a), Float(b)) => Ok(Const(Bool(a > b))), + (_, _) => Err(CozoError::TypeError) + } + } + (a, b) => Ok(Apply(Op::Gt, vec![a, b])) + } + } + _ => unreachable!() + } + } + + + fn ge_exprs<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + match exprs { + [a, b] => { + let a = self.visit_expr(a)?; + let b = self.visit_expr(b)?; + if a == Const(Null) || b == Const(Null) { + return Ok(Const(Null)); + } + match (a, b) { + (Const(a), Const(b)) => { + match (a, b) { + (Int(a), Int(b)) => Ok(Const(Bool(a >= b))), + (Float(a), Int(b)) => Ok(Const(Bool(a >= b as f64))), + (Int(a), Float(b)) => Ok(Const(Bool(a as f64 >= b))), + (Float(a), Float(b)) => Ok(Const(Bool(a >= b))), + (_, _) => Err(CozoError::TypeError) + } + } + (a, b) => Ok(Apply(Op::Ge, vec![a, b])) + } + } + _ => unreachable!() + } + } + + + fn lt_exprs<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + match exprs { + [a, b] => { + let a = self.visit_expr(a)?; + let b = self.visit_expr(b)?; + if a == Const(Null) || b == Const(Null) { + return Ok(Const(Null)); + } + match (a, b) { + (Const(a), Const(b)) => { + match (a, b) { + (Int(a), Int(b)) => Ok(Const(Bool(a < b))), + (Float(a), Int(b)) => Ok(Const(Bool(a < b as f64))), + (Int(a), Float(b)) => Ok(Const(Bool((a as f64) < b))), + (Float(a), Float(b)) => Ok(Const(Bool(a < b))), + (_, _) => Err(CozoError::TypeError) + } + } + (a, b) => Ok(Apply(Op::Lt, vec![a, b])) + } + } + _ => unreachable!() + } + } + + fn le_exprs<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + match exprs { + [a, b] => { + let a = self.visit_expr(a)?; + let b = self.visit_expr(b)?; + if a == Const(Null) || b == Const(Null) { + return Ok(Const(Null)); + } + match (a, b) { + (Const(a), Const(b)) => { + match (a, b) { + (Int(a), Int(b)) => Ok(Const(Bool(a <= b))), + (Float(a), Int(b)) => Ok(Const(Bool(a <= b as f64))), + (Int(a), Float(b)) => Ok(Const(Bool((a as f64) <= b))), + (Float(a), Float(b)) => Ok(Const(Bool(a <= b))), + (_, _) => Err(CozoError::TypeError) + } + } + (a, b) => Ok(Apply(Op::Le, vec![a, b])) + } + } + _ => unreachable!() + } + } + + + fn pow_exprs<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + match exprs { + [a, b] => { + let a = self.visit_expr(a)?; + let b = self.visit_expr(b)?; + if a == Const(Null) || b == Const(Null) { + return Ok(Const(Null)); + } + match (a, b) { + (Const(a), Const(b)) => { + match (a, b) { + (Int(a), Int(b)) => Ok(Const(Float((a as f64).powf(b as f64)))), + (Float(a), Int(b)) => Ok(Const(Float(a.powi(b as i32)))), + (Int(a), Float(b)) => Ok(Const(Float((a as f64).powf(b)))), + (Float(a), Float(b)) => Ok(Const(Float(a.powf(b)))), + (_, _) => Err(CozoError::TypeError) + } + } + (a, b) => Ok(Apply(Op::Pow, vec![a, b])) + } + } + _ => unreachable!() + } + } + + fn coalesce_exprs<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + match exprs { + [a, b] => { + let a = self.visit_expr(a)?; + let b = self.visit_expr(b)?; + if a == Const(Null) { + return Ok(b); + } + if b == Const(Null) { + return Ok(a); + } + if let a @ Const(_) = a { + return Ok(a); + } + return Ok(Apply(Op::Coalesce, vec![a, b])); + } + _ => unreachable!() + } + } + + fn negate_expr<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + Ok(match exprs { + [a] => { + match self.visit_expr(a)? { + Const(Null) => Const(Null), + Const(Bool(b)) => Const(Bool(!b)), + Const(_) => return Err(TypeError), + Apply(Op::Neg, v) => v.into_iter().next().unwrap(), + Apply(Op::IsNull, v) => Apply(Op::NotNull, v), + Apply(Op::NotNull, v) => Apply(Op::IsNull, v), + Apply(Op::Eq, v) => Apply(Op::Neq, v), + Apply(Op::Neq, v) => Apply(Op::Eq, v), + Apply(Op::Gt, v) => Apply(Op::Le, v), + Apply(Op::Ge, v) => Apply(Op::Lt, v), + Apply(Op::Le, v) => Apply(Op::Gt, v), + Apply(Op::Lt, v) => Apply(Op::Ge, v), + v => Apply(Op::Neg, vec![v]) + } + } + _ => unreachable!() + }) + } + + + fn minus_expr<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + Ok(match exprs { + [a] => { + match self.visit_expr(a)? { + Const(Null) => Const(Null), + Const(Int(i)) => Const(Int(-i)), + Const(Float(f)) => Const(Float(-f)), + Const(_) => return Err(TypeError), + Apply(Op::Minus, v) => v.into_iter().next().unwrap(), + v => Apply(Op::Minus, vec![v]) + } + } + _ => unreachable!() + }) + } + + + fn test_null_expr<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + Ok(match exprs { + [a] => { + match self.visit_expr(a)? { + Const(Null) => Const(Bool(true)), + Const(_) => Const(Bool(false)), + v => Apply(Op::IsNull, vec![v]) + } + } + _ => unreachable!() + }) + } + + fn not_null_expr<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + Ok(match exprs { + [a] => { + match self.visit_expr(a)? { + Const(Null) => Const(Bool(false)), + Const(_) => Const(Bool(true)), + v => Apply(Op::IsNull, vec![v]) + } + } + _ => unreachable!() + }) + } + + fn or_expr<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + let mut unevaluated = vec![]; + let mut has_null = false; + for expr in exprs { + match self.visit_expr(expr)? { + Const(Bool(true)) => return Ok(Const(Bool(true))), + Const(Bool(false)) => {} + Const(Null) => { has_null = true } + Const(_) => return Err(TypeError), + Apply(Op::Or, vs) => { + for el in vs { + match el { + Const(Null) => has_null = true, + Const(_) => unreachable!(), + v => unevaluated.push(v) + } + } + } + v => unevaluated.push(v) + } + } + match (has_null, unevaluated.len()) { + (true, 0) => Ok(Const(Null)), + (false, 0) => Ok(Const(Bool(false))), + (false, _) => Ok(Apply(Op::Or, unevaluated)), + (true, _) => { + unevaluated.push(Const(Null)); + Ok(Apply(Op::Or, unevaluated)) + } + } + } + + + fn and_expr<'a>(&mut self, exprs: &[Expr<'a>]) -> Result, CozoError> { + let mut unevaluated = vec![]; + let mut no_null = true; + for expr in exprs { + match self.visit_expr(expr)? { + Const(Bool(false)) => return Ok(Const(Bool(false))), + Const(Bool(true)) => {} + Const(Null) => no_null = false, + Const(_) => return Err(TypeError), + Apply(Op::Or, vs) => { + for el in vs { + match el { + Const(Null) => no_null = false, + Const(_) => unreachable!(), + v => unevaluated.push(v) + } + } + } + v => unevaluated.push(v) + } + } + match (no_null, unevaluated.len()) { + (true, 0) => Ok(Const(Bool(true))), + (false, 0) => Ok(Const(Null)), + (true, _) => Ok(Apply(Op::Add, unevaluated)), + (false, _) => { + unevaluated.push(Const(Null)); + Ok(Apply(Op::And, unevaluated)) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn operators() { + let mut ev = Evaluator {}; + + println!("{:#?}", ev.visit_expr(&parse_expr_from_str("1/10+(-2+3)*4^5").unwrap()).unwrap()); + println!("{:#?}", ev.visit_expr(&parse_expr_from_str("true && false").unwrap()).unwrap()); + println!("{:#?}", ev.visit_expr(&parse_expr_from_str("true || false").unwrap()).unwrap()); + println!("{:#?}", ev.visit_expr(&parse_expr_from_str("true || null").unwrap()).unwrap()); + println!("{:#?}", ev.visit_expr(&parse_expr_from_str("null || true").unwrap()).unwrap()); + println!("{:#?}", ev.visit_expr(&parse_expr_from_str("true && null").unwrap()).unwrap()); + } +} \ No newline at end of file diff --git a/src/function.rs b/src/function.rs index fcd2b0f8..035dee7a 100644 --- a/src/function.rs +++ b/src/function.rs @@ -1,6 +1,6 @@ use std::collections::BTreeMap; use std::fmt::{Debug, Formatter}; -use crate::typing::{PrimitiveType, Typing}; +use crate::typing::{BaseType, Typing}; use crate::value::Value; use lazy_static::lazy_static; @@ -55,16 +55,16 @@ lazy_static! { ret.insert("_add_int", Function { args: vec![], - var_arg: Some(Typing::Primitive(PrimitiveType::Int)), - ret_type: Typing::Primitive(PrimitiveType::Int), + var_arg: Some(Typing::Base(BaseType::Int)), + ret_type: Typing::Base(BaseType::Int), fn_impl: FunctionImpl::Native("_add_int", add_int) }); ret.insert("_add_float", Function { args: vec![], - var_arg: Some(Typing::Primitive(PrimitiveType::Float)), - ret_type: Typing::Primitive(PrimitiveType::Float), + var_arg: Some(Typing::Base(BaseType::Float)), + ret_type: Typing::Base(BaseType::Float), fn_impl: FunctionImpl::Native("_add_float", add_float) }); diff --git a/src/typing.rs b/src/typing.rs index e8db1dc8..26673ea2 100644 --- a/src/typing.rs +++ b/src/typing.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::env::Env; #[derive(Debug, Eq, PartialEq)] -pub enum PrimitiveType { +pub enum BaseType { Bool, Int, UInt, @@ -39,7 +39,7 @@ pub enum PrimitiveType { #[derive(Debug, Eq, PartialEq)] pub enum Typing { Any, - Primitive(PrimitiveType), + Base(BaseType), HList(Box), Nullable(Box), Tuple(Vec), @@ -49,22 +49,22 @@ pub enum Typing { pub fn define_types>(env: &mut T) { env.define("Any", Typing::Any); - env.define("Bool", Typing::Primitive(PrimitiveType::Bool)); - env.define("Int", Typing::Primitive(PrimitiveType::Int)); - env.define("UInt", Typing::Primitive(PrimitiveType::UInt)); - env.define("Float", Typing::Primitive(PrimitiveType::Float)); - env.define("String", Typing::Primitive(PrimitiveType::String)); - env.define("Bytes", Typing::Primitive(PrimitiveType::U8Arr)); - env.define("U8Arr", Typing::Primitive(PrimitiveType::U8Arr)); - env.define("Uuid", Typing::Primitive(PrimitiveType::Uuid)); - env.define("Timestamp", Typing::Primitive(PrimitiveType::Timestamp)); - env.define("Datetime", Typing::Primitive(PrimitiveType::Datetime)); - env.define("Timezone", Typing::Primitive(PrimitiveType::Timezone)); - env.define("Date", Typing::Primitive(PrimitiveType::Date)); - env.define("Time", Typing::Primitive(PrimitiveType::Time)); - env.define("Duration", Typing::Primitive(PrimitiveType::Duration)); - env.define("BigInt", Typing::Primitive(PrimitiveType::BigInt)); - env.define("BigDecimal", Typing::Primitive(PrimitiveType::BigDecimal)); - env.define("Int", Typing::Primitive(PrimitiveType::Int)); - env.define("Crs", Typing::Primitive(PrimitiveType::Crs)); + env.define("Bool", Typing::Base(BaseType::Bool)); + env.define("Int", Typing::Base(BaseType::Int)); + env.define("UInt", Typing::Base(BaseType::UInt)); + env.define("Float", Typing::Base(BaseType::Float)); + env.define("String", Typing::Base(BaseType::String)); + env.define("Bytes", Typing::Base(BaseType::U8Arr)); + env.define("U8Arr", Typing::Base(BaseType::U8Arr)); + env.define("Uuid", Typing::Base(BaseType::Uuid)); + env.define("Timestamp", Typing::Base(BaseType::Timestamp)); + env.define("Datetime", Typing::Base(BaseType::Datetime)); + env.define("Timezone", Typing::Base(BaseType::Timezone)); + env.define("Date", Typing::Base(BaseType::Date)); + env.define("Time", Typing::Base(BaseType::Time)); + env.define("Duration", Typing::Base(BaseType::Duration)); + env.define("BigInt", Typing::Base(BaseType::BigInt)); + env.define("BigDecimal", Typing::Base(BaseType::BigDecimal)); + env.define("Int", Typing::Base(BaseType::Int)); + env.define("Crs", Typing::Base(BaseType::Crs)); } \ No newline at end of file