diff --git a/cozo-core/src/data/expr.rs b/cozo-core/src/data/expr.rs
index 100a9f4e..3e0113ef 100644
--- a/cozo-core/src/data/expr.rs
+++ b/cozo-core/src/data/expr.rs
@@ -704,6 +704,7 @@ pub(crate) fn get_op(name: &str) -> Option<&'static Op> {
         "max" => &OP_MAX,
         "min" => &OP_MIN,
         "pow" => &OP_POW,
+        "sqrt" => &OP_SQRT,
         "exp" => &OP_EXP,
         "exp2" => &OP_EXP2,
         "ln" => &OP_LN,
@@ -788,6 +789,10 @@ pub(crate) fn get_op(name: &str) -> Option<&'static Op> {
         "to_int" => &OP_TO_INT,
         "to_float" => &OP_TO_FLOAT,
         "to_string" => &OP_TO_STRING,
+        "l2_dist" => &OP_L2_DIST,
+        "l2_normalize" => &OP_L2_NORMALIZE,
+        "ip_dist" => &OP_IP_DIST,
+        "cos_dist" => &OP_COS_DIST,
         "rand_float" => &OP_RAND_FLOAT,
         "rand_bernoulli" => &OP_RAND_BERNOULLI,
         "rand_int" => &OP_RAND_INT,
@@ -805,6 +810,8 @@ pub(crate) fn get_op(name: &str) -> Option<&'static Op> {
         "now" => &OP_NOW,
         "format_timestamp" => &OP_FORMAT_TIMESTAMP,
         "parse_timestamp" => &OP_PARSE_TIMESTAMP,
+        "vec" => &OP_VEC,
+        "rand_vec" => &OP_RAND_VEC,
         _ => return None,
     })
 }
diff --git a/cozo-core/src/data/functions.rs b/cozo-core/src/data/functions.rs
index bd766438..8250a79a 100644
--- a/cozo-core/src/data/functions.rs
+++ b/cozo-core/src/data/functions.rs
@@ -27,7 +27,8 @@ use uuid::v1::Timestamp;
 
 use crate::data::expr::Op;
 use crate::data::json::JsonValue;
-use crate::data::value::{DataValue, Num, RegexWrapper, UuidWrapper, Validity, ValidityTs};
+use crate::data::relation::VecElementType;
+use crate::data::value::{DataValue, Num, RegexWrapper, UuidWrapper, Validity, ValidityTs, Vector};
 
 macro_rules! define_op {
     ($name:ident, $min_arity:expr, $vararg:expr) => {
@@ -158,6 +159,7 @@ pub(crate) fn op_add(args: &[DataValue]) -> Result<DataValue> {
         match arg {
             DataValue::Num(Num::Int(i)) => i_accum += i,
             DataValue::Num(Num::Float(f)) => f_accum += f,
+            DataValue::Vec(_) => return add_vecs(args), // any vector operand switches to element-wise addition
             _ => bail!("addition requires numbers"),
         }
     }
@@ -168,6 +170,58 @@ pub(crate) fn op_add(args: &[DataValue]) -> Result<DataValue> {
     }
 }
 
+fn add_vecs(args: &[DataValue]) -> Result<DataValue> {
+    if args.len() == 1 {
+        return Ok(args[0].clone());
+    }
+    let (last, first) = args.split_last().unwrap(); // callers only pass non-empty argument lists
+    let first = add_vecs(first)?; // fold the leading operands, then combine with the last one
+    match (first, last) {
+        (DataValue::Vec(a), DataValue::Vec(b)) => {
+            if a.len() != b.len() {
+                bail!("can only add vectors of the same length");
+            }
+            match (a, b) {
+                (Vector::F32(a), Vector::F32(b)) => Ok(DataValue::Vec(Vector::F32(a + b))),
+                (Vector::F64(a), Vector::F64(b)) => Ok(DataValue::Vec(Vector::F64(a + b))),
+                (Vector::F32(a), Vector::F64(b)) => {
+                    let a = a.mapv(|x| x as f64);
+                    Ok(DataValue::Vec(Vector::F64(a + b)))
+                }
+                (Vector::F64(a), Vector::F32(b)) => {
+                    let b = b.mapv(|x| x as f64);
+                    Ok(DataValue::Vec(Vector::F64(a + b)))
+                }
+            }
+        }
+        (DataValue::Vec(a), b) => {
+            let f = b
+                .get_float()
+                .ok_or_else(|| miette!("can only add numbers to vectors"))?;
+            match a.clone() {
+                Vector::F32(mut v) => {
+                    v += f as f32;
+                    Ok(DataValue::Vec(Vector::F32(v)))
+                }
+                Vector::F64(mut v) => {
+                    v += f;
+                    Ok(DataValue::Vec(Vector::F64(v)))
+                }
+            }
+        }
+        (a, DataValue::Vec(b)) => {
+            let f = a
+                .get_float()
+                .ok_or_else(|| miette!("can only add numbers to vectors"))?;
+            match b {
+                Vector::F32(v) => Ok(DataValue::Vec(Vector::F32(v + f as f32))),
+                Vector::F64(v) => Ok(DataValue::Vec(Vector::F64(v + f))),
+            }
+        }
+        (a, b) => op_add(&[a, b.clone()]), // neither operand is a vector: plain numeric addition
+    }
+}
+
 define_op!(OP_MAX, 1, true);
 pub(crate) fn op_max(args: &[DataValue]) -> Result<DataValue> {
     let res = args
@@ -213,6 +267,48 @@ pub(crate) fn op_sub(args: &[DataValue]) -> Result<DataValue> {
         (DataValue::Num(Num::Float(a)), DataValue::Num(Num::Int(b))) => {
             DataValue::Num(Num::Float(a - (*b as f64)))
         }
+        (DataValue::Vec(a), DataValue::Vec(b)) => match (a, b) {
+            (Vector::F32(a), Vector::F32(b)) => DataValue::Vec(Vector::F32(a - b)),
+            (Vector::F64(a), Vector::F64(b)) => DataValue::Vec(Vector::F64(a - b)),
+            (Vector::F32(a), Vector::F64(b)) => {
+                let a = a.mapv(|x| x as f64);
+                DataValue::Vec(Vector::F64(a - b))
+            }
+            (Vector::F64(a), Vector::F32(b)) => {
+                let b = b.mapv(|x| x as f64);
+                DataValue::Vec(Vector::F64(a - b))
+            }
+        },
+        (DataValue::Vec(a), b) => {
+            let b = b
+                .get_float()
+                .ok_or_else(|| miette!("can only subtract numbers from vectors"))?;
+            match a.clone() {
+                Vector::F32(mut v) => {
+                    v -= b as f32;
+                    DataValue::Vec(Vector::F32(v))
+                }
+                Vector::F64(mut v) => {
+                    v -= b;
+                    DataValue::Vec(Vector::F64(v))
+                }
+            }
+        }
+        (a, DataValue::Vec(b)) => {
+            let a = a
+                .get_float()
+                .ok_or_else(|| miette!("can only subtract vectors from numbers"))?;
+            match b.clone() {
+                Vector::F32(mut v) => {
+                    v -= a as f32; // computes v - a, negated below to give a - v
+                    DataValue::Vec(Vector::F32(-v))
+                }
+                Vector::F64(mut v) => {
+                    v -= a; // computes v - a, negated below to give a - v
+                    DataValue::Vec(Vector::F64(-v))
+                }
+            }
+        }
         _ => bail!("subtraction requires numbers"),
     })
 }
@@ -225,6 +321,7 @@ pub(crate) fn op_mul(args: &[DataValue]) -> Result<DataValue> {
         match arg {
             DataValue::Num(Num::Int(i)) => i_accum *= i,
             DataValue::Num(Num::Float(f)) => f_accum *= f,
+            DataValue::Vec(_) => return mul_vecs(args), // any vector operand switches to element-wise multiplication
             _ => bail!("multiplication requires numbers"),
         }
     }
@@ -235,6 +332,58 @@ pub(crate) fn op_mul(args: &[DataValue]) -> Result<DataValue> {
     }
 }
 
+fn mul_vecs(args: &[DataValue]) -> Result<DataValue> {
+    if args.len() == 1 {
+        return Ok(args[0].clone());
+    }
+    let (last, first) = args.split_last().unwrap(); // callers only pass non-empty argument lists
+    let first = mul_vecs(first)?; // fold the leading operands with multiplication, not addition
+    match (first, last) {
+        (DataValue::Vec(a), DataValue::Vec(b)) => {
+            if a.len() != b.len() {
+                bail!("can only multiply vectors of the same length");
+            }
+            match (a, b) {
+                (Vector::F32(a), Vector::F32(b)) => Ok(DataValue::Vec(Vector::F32(a * b))),
+                (Vector::F64(a), Vector::F64(b)) => Ok(DataValue::Vec(Vector::F64(a * b))),
+                (Vector::F32(a), Vector::F64(b)) => {
+                    let a = a.mapv(|x| x as f64);
+                    Ok(DataValue::Vec(Vector::F64(a * b)))
+                }
+                (Vector::F64(a), Vector::F32(b)) => {
+                    let b = b.mapv(|x| x as f64);
+                    Ok(DataValue::Vec(Vector::F64(a * b)))
+                }
+            }
+        }
+        (DataValue::Vec(a), b) => {
+            let f = b
+                .get_float()
+                .ok_or_else(|| miette!("can only multiply vectors by numbers"))?;
+            match a.clone() {
+                Vector::F32(mut v) => {
+                    v *= f as f32;
+                    Ok(DataValue::Vec(Vector::F32(v)))
+                }
+                Vector::F64(mut v) => {
+                    v *= f;
+                    Ok(DataValue::Vec(Vector::F64(v)))
+                }
+            }
+        }
+        (a, DataValue::Vec(b)) => {
+            let f = a
+                .get_float()
+                .ok_or_else(|| miette!("can only multiply vectors by numbers"))?;
+            match b {
+                Vector::F32(v) => Ok(DataValue::Vec(Vector::F32(v * f as f32))),
+                Vector::F64(v) => Ok(DataValue::Vec(Vector::F64(v * f))),
+            }
+        }
+        (a, b) => op_mul(&[a, b.clone()]), // neither operand is a vector: plain numeric multiplication
+    }
+}
+
 define_op!(OP_DIV, 2, false);
 pub(crate) fn op_div(args: &[DataValue]) -> Result<DataValue> {
     Ok(match (&args[0], &args[1]) {
@@ -250,6 +399,42 @@ pub(crate) fn op_div(args: &[DataValue]) -> Result<DataValue> {
         (DataValue::Num(Num::Float(a)), DataValue::Num(Num::Int(b))) => {
             DataValue::Num(Num::Float(a / (*b as f64)))
         }
+        (DataValue::Vec(a), DataValue::Vec(b)) => match (a, b) {
+            (Vector::F32(a), Vector::F32(b)) => DataValue::Vec(Vector::F32(a / b)),
+            (Vector::F64(a), Vector::F64(b)) => DataValue::Vec(Vector::F64(a / b)),
+            (Vector::F32(a), Vector::F64(b)) => {
+                let a = a.mapv(|x| x as f64);
+                DataValue::Vec(Vector::F64(a / b))
+            }
+            (Vector::F64(a), Vector::F32(b)) => {
+                let b = b.mapv(|x| x as f64);
+                DataValue::Vec(Vector::F64(a / b))
+            }
+        },
+        (DataValue::Vec(a), b) => {
+            let b = b
+                .get_float()
+                .ok_or_else(|| miette!("can only divide vectors by numbers"))?;
+            match a.clone() {
+                Vector::F32(mut v) => {
+                    v /= b as f32;
+                    DataValue::Vec(Vector::F32(v))
+                }
+                Vector::F64(mut v) => {
+                    v /= b;
+                    DataValue::Vec(Vector::F64(v))
+                }
+            }
+        }
+        (a, DataValue::Vec(b)) => {
+            let a = a
+                .get_float()
+                .ok_or_else(|| miette!("can only divide numbers by vectors"))?;
+            match b {
+                Vector::F32(v) => DataValue::Vec(Vector::F32(a as f32 / v)),
+                Vector::F64(v) => DataValue::Vec(Vector::F64(a / v)),
+            }
+        }
         _ => bail!("division requires numbers"),
     })
 }
@@ -259,6 +444,8 @@ pub(crate) fn op_minus(args: &[DataValue]) -> Result<DataValue> {
     Ok(match &args[0] {
         DataValue::Num(Num::Int(i)) => DataValue::Num(Num::Int(-(*i))),
         DataValue::Num(Num::Float(f)) => DataValue::Num(Num::Float(-(*f))),
+        DataValue::Vec(Vector::F64(v)) => DataValue::Vec(Vector::F64(0. - v)),
+        DataValue::Vec(Vector::F32(v)) => DataValue::Vec(Vector::F32(0. - v)),
         _ => bail!("minus can only be applied to numbers"),
     })
 }
@@ -268,6 +455,8 @@ pub(crate) fn op_abs(args: &[DataValue]) -> Result<DataValue> {
     Ok(match &args[0] {
         DataValue::Num(Num::Int(i)) => DataValue::Num(Num::Int(i.abs())),
         DataValue::Num(Num::Float(f)) => DataValue::Num(Num::Float(f.abs())),
+        DataValue::Vec(Vector::F64(v)) => DataValue::Vec(Vector::F64(v.mapv(|x| x.abs()))),
+        DataValue::Vec(Vector::F32(v)) => DataValue::Vec(Vector::F32(v.mapv(|x| x.abs()))),
         _ => bail!("'abs' requires numbers"),
     })
 }
@@ -323,6 +512,12 @@ pub(crate) fn op_exp(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.exp()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.exp()))))
+        }
         _ => bail!("'exp' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.exp())))
@@ -333,6 +528,12 @@ pub(crate) fn op_exp2(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.exp2()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.exp2()))))
+        }
         _ => bail!("'exp2' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.exp2())))
@@ -343,6 +544,12 @@ pub(crate) fn op_ln(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.ln()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.ln()))))
+        }
         _ => bail!("'ln' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.ln())))
@@ -353,6 +560,12 @@ pub(crate) fn op_log2(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.log2()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.log2()))))
+        }
         _ => bail!("'log2' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.log2())))
@@ -363,6 +576,12 @@ pub(crate) fn op_log10(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.log10()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.log10()))))
+        }
         _ => bail!("'log10' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.log10())))
@@ -373,6 +592,12 @@ pub(crate) fn op_sin(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.sin()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.sin()))))
+        }
         _ => bail!("'sin' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.sin())))
@@ -383,6 +608,12 @@ pub(crate) fn op_cos(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.cos()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.cos()))))
+        }
         _ => bail!("'cos' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.cos())))
@@ -393,6 +624,12 @@ pub(crate) fn op_tan(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.tan()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.tan()))))
+        }
         _ => bail!("'tan' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.tan())))
@@ -403,6 +640,12 @@ pub(crate) fn op_asin(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.asin()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.asin()))))
+        }
         _ => bail!("'asin' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.asin())))
@@ -413,6 +656,12 @@ pub(crate) fn op_acos(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.acos()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.acos()))))
+        }
         _ => bail!("'acos' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.acos())))
@@ -423,6 +672,12 @@ pub(crate) fn op_atan(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.atan()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.atan()))))
+        }
         _ => bail!("'atan' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.atan())))
@@ -449,6 +704,12 @@ pub(crate) fn op_sinh(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.sinh()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.sinh()))))
+        }
         _ => bail!("'sinh' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.sinh())))
@@ -459,6 +720,12 @@ pub(crate) fn op_cosh(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.cosh()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.cosh()))))
+        }
         _ => bail!("'cosh' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.cosh())))
@@ -469,6 +736,12 @@ pub(crate) fn op_tanh(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.tanh()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.tanh()))))
+        }
         _ => bail!("'tanh' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.tanh())))
@@ -479,6 +752,12 @@ pub(crate) fn op_asinh(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.asinh()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.asinh()))))
+        }
         _ => bail!("'asinh' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.asinh())))
@@ -489,6 +768,12 @@ pub(crate) fn op_acosh(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.acosh()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.acosh()))))
+        }
         _ => bail!("'acosh' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.acosh())))
@@ -499,16 +784,50 @@ pub(crate) fn op_atanh(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.atanh()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.atanh()))))
+        }
         _ => bail!("'atanh' requires numbers"),
     };
     Ok(DataValue::Num(Num::Float(a.atanh())))
 }
 
+define_op!(OP_SQRT, 1, false);
+pub(crate) fn op_sqrt(args: &[DataValue]) -> Result<DataValue> {
+    let a = match &args[0] {
+        DataValue::Num(Num::Int(i)) => *i as f64,
+        DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.sqrt()))))
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.sqrt()))))
+        }
+        _ => bail!("'sqrt' requires numbers"),
+    };
+    Ok(DataValue::Num(Num::Float(a.sqrt())))
+}
+
 define_op!(OP_POW, 2, false);
 pub(crate) fn op_pow(args: &[DataValue]) -> Result<DataValue> {
     let a = match &args[0] {
         DataValue::Num(Num::Int(i)) => *i as f64,
         DataValue::Num(Num::Float(f)) => *f,
+        DataValue::Vec(Vector::F32(v)) => {
+            let b = args[1]
+                .get_float()
+                .ok_or_else(|| miette!("'pow' requires numbers"))?;
+            return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.powf(b as f32)))));
+        }
+        DataValue::Vec(Vector::F64(v)) => {
+            let b = args[1]
+                .get_float()
+                .ok_or_else(|| miette!("'pow' requires numbers"))?;
+            return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.powf(b)))));
+        }
         _ => bail!("'pow' requires numbers"),
     };
     let b = match &args[1] {
@@ -938,6 +1257,11 @@ pub(crate) fn op_is_list(args: &[DataValue]) -> Result<DataValue> {
     )))
 }
 
+define_op!(OP_IS_VEC, 1, false);
+pub(crate) fn op_is_vec(args: &[DataValue]) -> Result<DataValue> {
+    Ok(DataValue::from(matches!(args[0], DataValue::Vec(_))))
+}
+
 define_op!(OP_APPEND, 2, false);
 pub(crate) fn op_append(args: &[DataValue]) -> Result<DataValue> {
     match &args[0] {
@@ -984,6 +1308,7 @@ pub(crate) fn op_length(args: &[DataValue]) -> Result<DataValue> {
         DataValue::List(l) => l.len() as i64,
         DataValue::Str(s) => s.chars().count() as i64,
         DataValue::Bytes(b) => b.len() as i64,
+        DataValue::Vec(v) => v.len() as i64,
         _ => bail!("'length' requires lists"),
     }))
 }
@@ -1353,6 +1678,179 @@ pub(crate) fn op_to_string(args: &[DataValue]) -> Result<DataValue> {
     })
 }
 
+define_op!(OP_VEC, 1, true);
+pub(crate) fn op_vec(args: &[DataValue]) -> Result<DataValue> {
+    let t = match args.get(1) {
+        Some(DataValue::Str(s)) => match s as &str {
+            "F32" | "Float" => VecElementType::F32,
+            "F64" | "Double" => VecElementType::F64,
+            _ => bail!("'vec' does not recognize type {}", s),
+        },
+        None => VecElementType::F32, // element type defaults to F32 when omitted
+        _ => bail!("'vec' requires a string as second argument"),
+    };
+
+    match &args[0] {
+        DataValue::List(l) => match t {
+            VecElementType::F32 => {
+                let mut res_arr = ndarray::Array1::zeros(l.len());
+                for (mut row, el) in res_arr.axis_iter_mut(ndarray::Axis(0)).zip(l.iter()) {
+                    let f = el
+                        .get_float()
+                        .ok_or_else(|| miette!("'vec' requires a list of numbers"))?;
+                    row.fill(f as f32);
+                }
+                Ok(DataValue::Vec(Vector::F32(res_arr)))
+            }
+            VecElementType::F64 => {
+                let mut res_arr = ndarray::Array1::zeros(l.len());
+                for (mut row, el) in res_arr.axis_iter_mut(ndarray::Axis(0)).zip(l.iter()) {
+                    let f = el
+                        .get_float()
+                        .ok_or_else(|| miette!("'vec' requires a list of numbers"))?;
+                    row.fill(f as f64);
+                }
+                Ok(DataValue::Vec(Vector::F64(res_arr)))
+            }
+        },
+        DataValue::Vec(v) => match (t, v) {
+            (VecElementType::F32, Vector::F32(v)) => Ok(DataValue::Vec(Vector::F32(v.clone()))),
+            (VecElementType::F64, Vector::F64(v)) => Ok(DataValue::Vec(Vector::F64(v.clone()))),
+            (VecElementType::F32, Vector::F64(v)) => {
+                Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x as f32))))
+            }
+            (VecElementType::F64, Vector::F32(v)) => {
+                Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x as f64))))
+            }
+        },
+        _ => bail!("'vec' requires a list or a vector"),
+    }
+}
+
+define_op!(OP_RAND_VEC, 1, true);
+pub(crate) fn op_rand_vec(args: &[DataValue]) -> Result<DataValue> {
+    let len = args[0]
+        .get_int()
+        .and_then(|i| usize::try_from(i).ok()) // reject negative lengths instead of wrapping to a huge usize
+        .ok_or_else(|| miette!("'rand_vec' requires a non-negative integer"))?;
+    let t = match args.get(1) {
+        Some(DataValue::Str(s)) => match s as &str {
+            "F32" | "Float" => VecElementType::F32,
+            "F64" | "Double" => VecElementType::F64,
+            _ => bail!("'rand_vec' does not recognize type {}", s),
+        },
+        None => VecElementType::F32, // element type defaults to F32 when omitted
+        _ => bail!("'rand_vec' requires a string as second argument"),
+    };
+
+    let mut rng = thread_rng();
+    match t {
+        VecElementType::F32 => {
+            let mut res_arr = ndarray::Array1::zeros(len);
+            for mut row in res_arr.axis_iter_mut(ndarray::Axis(0)) {
+                row.fill(rng.gen::<f64>() as f32);
+            }
+            Ok(DataValue::Vec(Vector::F32(res_arr)))
+        }
+        VecElementType::F64 => {
+            let mut res_arr = ndarray::Array1::zeros(len);
+            for mut row in res_arr.axis_iter_mut(ndarray::Axis(0)) {
+                row.fill(rng.gen::<f64>());
+            }
+            Ok(DataValue::Vec(Vector::F64(res_arr)))
+        }
+    }
+}
+
+define_op!(OP_L2_NORMALIZE, 1, false);
+pub(crate) fn op_l2_normalize(args: &[DataValue]) -> Result<DataValue> {
+    let a = &args[0];
+    match a {
+        DataValue::Vec(Vector::F32(a)) => {
+            let norm = a.dot(a).sqrt();
+            Ok(DataValue::Vec(Vector::F32(a / norm)))
+        }
+        DataValue::Vec(Vector::F64(a)) => {
+            let norm = a.dot(a).sqrt();
+            Ok(DataValue::Vec(Vector::F64(a / norm)))
+        }
+        _ => bail!("'l2_normalize' requires a vector"),
+    }
+}
+
+define_op!(OP_L2_DIST, 2, false);
+pub(crate) fn op_l2_dist(args: &[DataValue]) -> Result<DataValue> {
+    let a = &args[0];
+    let b = &args[1];
+    match (a, b) {
+        (DataValue::Vec(Vector::F32(a)), DataValue::Vec(Vector::F32(b))) => {
+            if a.len() != b.len() {
+                bail!("'l2_dist' requires two vectors of the same length");
+            }
+            let diff = a - b;
+            Ok(DataValue::from(diff.dot(&diff) as f64)) // squared distance: no square root is taken
+        }
+        (DataValue::Vec(Vector::F64(a)), DataValue::Vec(Vector::F64(b))) => {
+            if a.len() != b.len() {
+                bail!("'l2_dist' requires two vectors of the same length");
+            }
+            let diff = a - b;
+            Ok(DataValue::from(diff.dot(&diff) as f64)) // squared distance: no square root is taken
+        }
+        _ => bail!("'l2_dist' requires two vectors of the same type"),
+    }
+}
+
+define_op!(OP_IP_DIST, 2, false);
+pub(crate) fn op_ip_dist(args: &[DataValue]) -> Result<DataValue> {
+    let a = &args[0];
+    let b = &args[1];
+    match (a, b) {
+        (DataValue::Vec(Vector::F32(a)), DataValue::Vec(Vector::F32(b))) => {
+            if a.len() != b.len() {
+                bail!("'ip_dist' requires two vectors of the same length");
+            }
+            let dot = a.dot(b);
+            Ok(DataValue::from(1. - dot as f64))
+        }
+        (DataValue::Vec(Vector::F64(a)), DataValue::Vec(Vector::F64(b))) => {
+            if a.len() != b.len() {
+                bail!("'ip_dist' requires two vectors of the same length");
+            }
+            let dot = a.dot(b);
+            Ok(DataValue::from(1. - dot as f64))
+        }
+        _ => bail!("'ip_dist' requires two vectors of the same type"),
+    }
+}
+
+define_op!(OP_COS_DIST, 2, false);
+pub(crate) fn op_cos_dist(args: &[DataValue]) -> Result<DataValue> {
+    let a = &args[0];
+    let b = &args[1];
+    match (a, b) {
+        (DataValue::Vec(Vector::F32(a)), DataValue::Vec(Vector::F32(b))) => {
+            if a.len() != b.len() {
+                bail!("'cos_dist' requires two vectors of the same length");
+            }
+            let a_norm = a.dot(a) as f64; // squared norm; the square root is taken once below
+            let b_norm = b.dot(b) as f64;
+            let dot = a.dot(b) as f64;
+            Ok(DataValue::from(1. - dot / (a_norm * b_norm).sqrt()))
+        }
+        (DataValue::Vec(Vector::F64(a)), DataValue::Vec(Vector::F64(b))) => {
+            if a.len() != b.len() {
+                bail!("'cos_dist' requires two vectors of the same length");
+            }
+            let a_norm = a.dot(a); // squared norm; the square root is taken once below
+            let b_norm = b.dot(b);
+            let dot = a.dot(b);
+            Ok(DataValue::from(1. - dot / (a_norm * b_norm).sqrt()))
+        }
+        _ => bail!("'cos_dist' requires two vectors of the same type"),
+    }
+}
+
 define_op!(OP_RAND_FLOAT, 0, false);
 pub(crate) fn op_rand_float(_args: &[DataValue]) -> Result<DataValue> {
     Ok(thread_rng().gen::<f64>().into())
diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs
index 264ef891..e2000dec 100644
--- a/cozo-core/src/runtime/tests.rs
+++ b/cozo-core/src/runtime/tests.rs
@@ -756,6 +756,27 @@ fn test_vec_types() {
         json!([1., 2., 3., 4., 5., 6., 7., 8.]),
         res.into_json()["rows"][0][1]
     );
+    let res = db
+        .run_script("?[v] <- [[vec([1,2,3,4,5,6,7,8])]]", Default::default())
+        .unwrap();
+    assert_eq!(
+        json!([1., 2., 3., 4., 5., 6., 7., 8.]),
+        res.into_json()["rows"][0][0]
+    );
+    let res = db
+        .run_script("?[v] <- [[rand_vec(5)]]", Default::default())
+        .unwrap();
+    assert_eq!(
+        5,
+        res.into_json()["rows"][0][0].as_array().unwrap().len()
+    );
+    let res = db
+        .run_script(r#"
+        val[v] <- [[vec([1,2,3,4,5,6,7,8])]]
+        ?[x,y,z] := val[v], x=l2_dist(v, v), y=cos_dist(v, v), nv = l2_normalize(v), z=ip_dist(nv, nv)
+    "#, Default::default())
+        .unwrap();
+    println!("{}", res.into_json());
 }
 
 #[test]