vector functions

main
Ziyang Hu 1 year ago
parent ba950d4a19
commit f579b971da

@ -704,6 +704,7 @@ pub(crate) fn get_op(name: &str) -> Option<&'static Op> {
"max" => &OP_MAX,
"min" => &OP_MIN,
"pow" => &OP_POW,
"sqrt" => &OP_SQRT,
"exp" => &OP_EXP,
"exp2" => &OP_EXP2,
"ln" => &OP_LN,
@ -788,6 +789,10 @@ pub(crate) fn get_op(name: &str) -> Option<&'static Op> {
"to_int" => &OP_TO_INT,
"to_float" => &OP_TO_FLOAT,
"to_string" => &OP_TO_STRING,
"l2_dist" => &OP_L2_DIST,
"l2_normalize" => &OP_L2_NORMALIZE,
"ip_dist" => &OP_IP_DIST,
"cos_dist" => &OP_COS_DIST,
"rand_float" => &OP_RAND_FLOAT,
"rand_bernoulli" => &OP_RAND_BERNOULLI,
"rand_int" => &OP_RAND_INT,
@ -805,6 +810,8 @@ pub(crate) fn get_op(name: &str) -> Option<&'static Op> {
"now" => &OP_NOW,
"format_timestamp" => &OP_FORMAT_TIMESTAMP,
"parse_timestamp" => &OP_PARSE_TIMESTAMP,
"vec" => &OP_VEC,
"rand_vec" => &OP_RAND_VEC,
_ => return None,
})
}

@ -27,7 +27,8 @@ use uuid::v1::Timestamp;
use crate::data::expr::Op;
use crate::data::json::JsonValue;
use crate::data::value::{DataValue, Num, RegexWrapper, UuidWrapper, Validity, ValidityTs};
use crate::data::relation::VecElementType;
use crate::data::value::{DataValue, Num, RegexWrapper, UuidWrapper, Validity, ValidityTs, Vector};
macro_rules! define_op {
($name:ident, $min_arity:expr, $vararg:expr) => {
@ -158,6 +159,7 @@ pub(crate) fn op_add(args: &[DataValue]) -> Result<DataValue> {
match arg {
DataValue::Num(Num::Int(i)) => i_accum += i,
DataValue::Num(Num::Float(f)) => f_accum += f,
DataValue::Vec(_) => return add_vecs(args),
_ => bail!("addition requires numbers"),
}
}
@ -168,6 +170,58 @@ pub(crate) fn op_add(args: &[DataValue]) -> Result<DataValue> {
}
}
fn add_vecs(args: &[DataValue]) -> Result<DataValue> {
if args.len() == 1 {
return Ok(args[0].clone());
}
let (last, first) = args.split_last().unwrap();
let first = add_vecs(first)?;
match (first, last) {
(DataValue::Vec(a), DataValue::Vec(b)) => {
if a.len() != b.len() {
bail!("can only add vectors of the same length");
}
match (a, b) {
(Vector::F32(a), Vector::F32(b)) => Ok(DataValue::Vec(Vector::F32(a + b))),
(Vector::F64(a), Vector::F64(b)) => Ok(DataValue::Vec(Vector::F64(a + b))),
(Vector::F32(a), Vector::F64(b)) => {
let a = a.mapv(|x| x as f64);
Ok(DataValue::Vec(Vector::F64(a + b)))
}
(Vector::F64(a), Vector::F32(b)) => {
let b = b.mapv(|x| x as f64);
Ok(DataValue::Vec(Vector::F64(a + b)))
}
}
}
(DataValue::Vec(a), b) => {
let f = b
.get_float()
.ok_or_else(|| miette!("can only add numbers to vectors"))?;
match a.clone() {
Vector::F32(mut v) => {
v += f as f32;
Ok(DataValue::Vec(Vector::F32(v)))
}
Vector::F64(mut v) => {
v += f;
Ok(DataValue::Vec(Vector::F64(v)))
}
}
}
(a, DataValue::Vec(b)) => {
let f = a
.get_float()
.ok_or_else(|| miette!("can only add numbers to vectors"))?;
match b {
Vector::F32(v) => Ok(DataValue::Vec(Vector::F32(v + f as f32))),
Vector::F64(v) => Ok(DataValue::Vec(Vector::F64(v + f))),
}
}
_ => bail!("addition requires numbers"),
}
}
define_op!(OP_MAX, 1, true);
pub(crate) fn op_max(args: &[DataValue]) -> Result<DataValue> {
let res = args
@ -213,6 +267,48 @@ pub(crate) fn op_sub(args: &[DataValue]) -> Result<DataValue> {
(DataValue::Num(Num::Float(a)), DataValue::Num(Num::Int(b))) => {
DataValue::Num(Num::Float(a - (*b as f64)))
}
(DataValue::Vec(a), DataValue::Vec(b)) => match (a, b) {
(Vector::F32(a), Vector::F32(b)) => DataValue::Vec(Vector::F32(a - b)),
(Vector::F64(a), Vector::F64(b)) => DataValue::Vec(Vector::F64(a - b)),
(Vector::F32(a), Vector::F64(b)) => {
let a = a.mapv(|x| x as f64);
DataValue::Vec(Vector::F64(a - b))
}
(Vector::F64(a), Vector::F32(b)) => {
let b = b.mapv(|x| x as f64);
DataValue::Vec(Vector::F64(a - b))
}
},
(DataValue::Vec(a), b) => {
let b = b
.get_float()
.ok_or_else(|| miette!("can only subtract numbers from vectors"))?;
match a.clone() {
Vector::F32(mut v) => {
v -= b as f32;
DataValue::Vec(Vector::F32(v))
}
Vector::F64(mut v) => {
v -= b;
DataValue::Vec(Vector::F64(v))
}
}
}
(a, DataValue::Vec(b)) => {
let a = a
.get_float()
.ok_or_else(|| miette!("can only subtract vectors from numbers"))?;
match b.clone() {
Vector::F32(mut v) => {
v -= a as f32;
DataValue::Vec(Vector::F32(-v))
}
Vector::F64(mut v) => {
v -= a;
DataValue::Vec(Vector::F64(-v))
}
}
}
_ => bail!("subtraction requires numbers"),
})
}
@ -225,6 +321,7 @@ pub(crate) fn op_mul(args: &[DataValue]) -> Result<DataValue> {
match arg {
DataValue::Num(Num::Int(i)) => i_accum *= i,
DataValue::Num(Num::Float(f)) => f_accum *= f,
DataValue::Vec(_) => return mul_vecs(args),
_ => bail!("multiplication requires numbers"),
}
}
@ -235,6 +332,58 @@ pub(crate) fn op_mul(args: &[DataValue]) -> Result<DataValue> {
}
}
fn mul_vecs(args: &[DataValue]) -> Result<DataValue> {
if args.len() == 1 {
return Ok(args[0].clone());
}
let (last, first) = args.split_last().unwrap();
let first = add_vecs(first)?;
match (first, last) {
(DataValue::Vec(a), DataValue::Vec(b)) => {
if a.len() != b.len() {
bail!("can only add vectors of the same length");
}
match (a, b) {
(Vector::F32(a), Vector::F32(b)) => Ok(DataValue::Vec(Vector::F32(a * b))),
(Vector::F64(a), Vector::F64(b)) => Ok(DataValue::Vec(Vector::F64(a * b))),
(Vector::F32(a), Vector::F64(b)) => {
let a = a.mapv(|x| x as f64);
Ok(DataValue::Vec(Vector::F64(a * b)))
}
(Vector::F64(a), Vector::F32(b)) => {
let b = b.mapv(|x| x as f64);
Ok(DataValue::Vec(Vector::F64(a * b)))
}
}
}
(DataValue::Vec(a), b) => {
let f = b
.get_float()
.ok_or_else(|| miette!("can only add numbers to vectors"))?;
match a.clone() {
Vector::F32(mut v) => {
v *= f as f32;
Ok(DataValue::Vec(Vector::F32(v)))
}
Vector::F64(mut v) => {
v *= f;
Ok(DataValue::Vec(Vector::F64(v)))
}
}
}
(a, DataValue::Vec(b)) => {
let f = a
.get_float()
.ok_or_else(|| miette!("can only add numbers to vectors"))?;
match b {
Vector::F32(v) => Ok(DataValue::Vec(Vector::F32(v * f as f32))),
Vector::F64(v) => Ok(DataValue::Vec(Vector::F64(v * f))),
}
}
_ => bail!("addition requires numbers"),
}
}
define_op!(OP_DIV, 2, false);
pub(crate) fn op_div(args: &[DataValue]) -> Result<DataValue> {
Ok(match (&args[0], &args[1]) {
@ -250,6 +399,42 @@ pub(crate) fn op_div(args: &[DataValue]) -> Result<DataValue> {
(DataValue::Num(Num::Float(a)), DataValue::Num(Num::Int(b))) => {
DataValue::Num(Num::Float(a / (*b as f64)))
}
(DataValue::Vec(a), DataValue::Vec(b)) => match (a, b) {
(Vector::F32(a), Vector::F32(b)) => DataValue::Vec(Vector::F32(a / b)),
(Vector::F64(a), Vector::F64(b)) => DataValue::Vec(Vector::F64(a / b)),
(Vector::F32(a), Vector::F64(b)) => {
let a = a.mapv(|x| x as f64);
DataValue::Vec(Vector::F64(a / b))
}
(Vector::F64(a), Vector::F32(b)) => {
let b = b.mapv(|x| x as f64);
DataValue::Vec(Vector::F64(a / b))
}
},
(DataValue::Vec(a), b) => {
let b = b
.get_float()
.ok_or_else(|| miette!("can only subtract numbers from vectors"))?;
match a.clone() {
Vector::F32(mut v) => {
v /= b as f32;
DataValue::Vec(Vector::F32(v))
}
Vector::F64(mut v) => {
v /= b;
DataValue::Vec(Vector::F64(v))
}
}
}
(a, DataValue::Vec(b)) => {
let a = a
.get_float()
.ok_or_else(|| miette!("can only subtract vectors from numbers"))?;
match b {
Vector::F32(v) => DataValue::Vec(Vector::F32(a as f32 / v)),
Vector::F64(v) => DataValue::Vec(Vector::F64(a / v)),
}
}
_ => bail!("division requires numbers"),
})
}
@ -259,6 +444,8 @@ pub(crate) fn op_minus(args: &[DataValue]) -> Result<DataValue> {
Ok(match &args[0] {
DataValue::Num(Num::Int(i)) => DataValue::Num(Num::Int(-(*i))),
DataValue::Num(Num::Float(f)) => DataValue::Num(Num::Float(-(*f))),
DataValue::Vec(Vector::F64(v)) => DataValue::Vec(Vector::F64(0. - v)),
DataValue::Vec(Vector::F32(v)) => DataValue::Vec(Vector::F32(0. - v)),
_ => bail!("minus can only be applied to numbers"),
})
}
@ -268,6 +455,8 @@ pub(crate) fn op_abs(args: &[DataValue]) -> Result<DataValue> {
Ok(match &args[0] {
DataValue::Num(Num::Int(i)) => DataValue::Num(Num::Int(i.abs())),
DataValue::Num(Num::Float(f)) => DataValue::Num(Num::Float(f.abs())),
DataValue::Vec(Vector::F64(v)) => DataValue::Vec(Vector::F64(v.mapv(|x| x.abs()))),
DataValue::Vec(Vector::F32(v)) => DataValue::Vec(Vector::F32(v.mapv(|x| x.abs()))),
_ => bail!("'abs' requires numbers"),
})
}
@ -323,6 +512,12 @@ pub(crate) fn op_exp(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.exp()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.exp()))))
}
_ => bail!("'exp' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.exp())))
@ -333,6 +528,12 @@ pub(crate) fn op_exp2(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.exp2()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.exp2()))))
}
_ => bail!("'exp2' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.exp2())))
@ -343,6 +544,12 @@ pub(crate) fn op_ln(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.ln()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.ln()))))
}
_ => bail!("'ln' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.ln())))
@ -353,6 +560,12 @@ pub(crate) fn op_log2(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.log2()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.log2()))))
}
_ => bail!("'log2' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.log2())))
@ -363,6 +576,12 @@ pub(crate) fn op_log10(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.log10()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.log10()))))
}
_ => bail!("'log10' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.log10())))
@ -373,6 +592,12 @@ pub(crate) fn op_sin(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.sin()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.sin()))))
}
_ => bail!("'sin' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.sin())))
@ -383,6 +608,12 @@ pub(crate) fn op_cos(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.cos()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.cos()))))
}
_ => bail!("'cos' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.cos())))
@ -393,6 +624,12 @@ pub(crate) fn op_tan(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.tan()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.tan()))))
}
_ => bail!("'tan' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.tan())))
@ -403,6 +640,12 @@ pub(crate) fn op_asin(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.asin()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.asin()))))
}
_ => bail!("'asin' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.asin())))
@ -413,6 +656,12 @@ pub(crate) fn op_acos(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.acos()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.acos()))))
}
_ => bail!("'acos' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.acos())))
@ -423,6 +672,12 @@ pub(crate) fn op_atan(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.atan()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.atan()))))
}
_ => bail!("'atan' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.atan())))
@ -449,6 +704,12 @@ pub(crate) fn op_sinh(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.sinh()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.sinh()))))
}
_ => bail!("'sinh' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.sinh())))
@ -459,6 +720,12 @@ pub(crate) fn op_cosh(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.cosh()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.cosh()))))
}
_ => bail!("'cosh' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.cosh())))
@ -469,6 +736,12 @@ pub(crate) fn op_tanh(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.tanh()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.tanh()))))
}
_ => bail!("'tanh' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.tanh())))
@ -479,6 +752,12 @@ pub(crate) fn op_asinh(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.asinh()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.asinh()))))
}
_ => bail!("'asinh' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.asinh())))
@ -489,6 +768,12 @@ pub(crate) fn op_acosh(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.acosh()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.acosh()))))
}
_ => bail!("'acosh' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.acosh())))
@ -499,16 +784,50 @@ pub(crate) fn op_atanh(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.atanh()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.atanh()))))
}
_ => bail!("'atanh' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.atanh())))
}
define_op!(OP_SQRT, 1, false);
pub(crate) fn op_sqrt(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.sqrt()))))
}
DataValue::Vec(Vector::F64(v)) => {
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.sqrt()))))
}
_ => bail!("'sqrt' requires numbers"),
};
Ok(DataValue::Num(Num::Float(a.sqrt())))
}
define_op!(OP_POW, 2, false);
pub(crate) fn op_pow(args: &[DataValue]) -> Result<DataValue> {
let a = match &args[0] {
DataValue::Num(Num::Int(i)) => *i as f64,
DataValue::Num(Num::Float(f)) => *f,
DataValue::Vec(Vector::F32(v)) => {
let b = args[1]
.get_float()
.ok_or_else(|| miette!("'pow' requires numbers"))?;
return Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x.powf(b as f32)))));
}
DataValue::Vec(Vector::F64(v)) => {
let b = args[1]
.get_float()
.ok_or_else(|| miette!("'pow' requires numbers"))?;
return Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x.powf(b)))));
}
_ => bail!("'pow' requires numbers"),
};
let b = match &args[1] {
@ -938,6 +1257,11 @@ pub(crate) fn op_is_list(args: &[DataValue]) -> Result<DataValue> {
)))
}
define_op!(OP_IS_VEC, 1, false);
pub(crate) fn op_is_vec(args: &[DataValue]) -> Result<DataValue> {
Ok(DataValue::from(matches!(args[0], DataValue::Vec(_))))
}
define_op!(OP_APPEND, 2, false);
pub(crate) fn op_append(args: &[DataValue]) -> Result<DataValue> {
match &args[0] {
@ -984,6 +1308,7 @@ pub(crate) fn op_length(args: &[DataValue]) -> Result<DataValue> {
DataValue::List(l) => l.len() as i64,
DataValue::Str(s) => s.chars().count() as i64,
DataValue::Bytes(b) => b.len() as i64,
DataValue::Vec(v) => v.len() as i64,
_ => bail!("'length' requires lists"),
}))
}
@ -1353,6 +1678,178 @@ pub(crate) fn op_to_string(args: &[DataValue]) -> Result<DataValue> {
})
}
define_op!(OP_VEC, 1, true);
pub(crate) fn op_vec(args: &[DataValue]) -> Result<DataValue> {
let t = match args.get(1) {
Some(DataValue::Str(s)) => match s as &str {
"F32" | "Float" => VecElementType::F32,
"F64" | "Double" => VecElementType::F64,
_ => bail!("'vec' does not recognize type {}", s),
},
None => VecElementType::F32,
_ => bail!("'vec' requires a string as second argument"),
};
match &args[0] {
DataValue::List(l) => match t {
VecElementType::F32 => {
let mut res_arr = ndarray::Array1::zeros(l.len());
for (mut row, el) in res_arr.axis_iter_mut(ndarray::Axis(0)).zip(l.iter()) {
let f = el
.get_float()
.ok_or_else(|| miette!("'vec' requires a list of numbers"))?;
row.fill(f as f32);
}
Ok(DataValue::Vec(Vector::F32(res_arr)))
}
VecElementType::F64 => {
let mut res_arr = ndarray::Array1::zeros(l.len());
for (mut row, el) in res_arr.axis_iter_mut(ndarray::Axis(0)).zip(l.iter()) {
let f = el
.get_float()
.ok_or_else(|| miette!("'vec' requires a list of numbers"))?;
row.fill(f as f64);
}
Ok(DataValue::Vec(Vector::F64(res_arr)))
}
},
DataValue::Vec(v) => match (t, v) {
(VecElementType::F32, Vector::F32(v)) => Ok(DataValue::Vec(Vector::F32(v.clone()))),
(VecElementType::F64, Vector::F64(v)) => Ok(DataValue::Vec(Vector::F64(v.clone()))),
(VecElementType::F32, Vector::F64(v)) => {
Ok(DataValue::Vec(Vector::F32(v.mapv(|x| x as f32))))
}
(VecElementType::F64, Vector::F32(v)) => {
Ok(DataValue::Vec(Vector::F64(v.mapv(|x| x as f64))))
}
},
_ => bail!("'vec' requires a list or a vector"),
}
}
define_op!(OP_RAND_VEC, 1, true);
pub(crate) fn op_rand_vec(args: &[DataValue]) -> Result<DataValue> {
let len = args[0]
.get_int()
.ok_or_else(|| miette!("'rand_vec' requires an integer"))? as usize;
let t = match args.get(1) {
Some(DataValue::Str(s)) => match s as &str {
"F32" | "Float" => VecElementType::F32,
"F64" | "Double" => VecElementType::F64,
_ => bail!("'vec' does not recognize type {}", s),
},
None => VecElementType::F32,
_ => bail!("'vec' requires a string as second argument"),
};
let mut rng = thread_rng();
match t {
VecElementType::F32 => {
let mut res_arr = ndarray::Array1::zeros(len);
for mut row in res_arr.axis_iter_mut(ndarray::Axis(0)) {
row.fill(rng.gen::<f64>() as f32);
}
Ok(DataValue::Vec(Vector::F32(res_arr)))
}
VecElementType::F64 => {
let mut res_arr = ndarray::Array1::zeros(len);
for mut row in res_arr.axis_iter_mut(ndarray::Axis(0)) {
row.fill(rng.gen::<f64>());
}
Ok(DataValue::Vec(Vector::F64(res_arr)))
}
}
}
define_op!(OP_L2_NORMALIZE, 1, false);
pub(crate) fn op_l2_normalize(args: &[DataValue]) -> Result<DataValue> {
let a = &args[0];
match a {
DataValue::Vec(Vector::F32(a)) => {
let norm = a.dot(a).sqrt();
Ok(DataValue::Vec(Vector::F32(a / norm)))
}
DataValue::Vec(Vector::F64(a)) => {
let norm = a.dot(a).sqrt();
Ok(DataValue::Vec(Vector::F64(a / norm)))
}
_ => bail!("'l2_normalize' requires a vector"),
}
}
define_op!(OP_L2_DIST, 2, false);
pub(crate) fn op_l2_dist(args: &[DataValue]) -> Result<DataValue> {
let a = &args[0];
let b = &args[1];
match (a, b) {
(DataValue::Vec(Vector::F32(a)), DataValue::Vec(Vector::F32(b))) => {
if a.len() != b.len() {
bail!("'l2_dist' requires two vectors of the same length");
}
let diff = a - b;
Ok(DataValue::from(diff.dot(&diff) as f64))
}
(DataValue::Vec(Vector::F64(a)), DataValue::Vec(Vector::F64(b))) => {
if a.len() != b.len() {
bail!("'l2_dist' requires two vectors of the same length");
}
let diff = a - b;
Ok(DataValue::from(diff.dot(&diff) as f64))
}
_ => bail!("'l2_dist' requires two vectors of the same type"),
}
}
define_op!(OP_IP_DIST, 2, false);
pub(crate) fn op_ip_dist(args: &[DataValue]) -> Result<DataValue> {
let a = &args[0];
let b = &args[1];
match (a, b) {
(DataValue::Vec(Vector::F32(a)), DataValue::Vec(Vector::F32(b))) => {
if a.len() != b.len() {
bail!("'ip_dist' requires two vectors of the same length");
}
let dot = a.dot(b);
Ok(DataValue::from(1. - dot as f64))
}
(DataValue::Vec(Vector::F64(a)), DataValue::Vec(Vector::F64(b))) => {
if a.len() != b.len() {
bail!("'ip_dist' requires two vectors of the same length");
}
let dot = a.dot(b);
Ok(DataValue::from(1. - dot as f64))
}
_ => bail!("'ip_dist' requires two vectors of the same type"),
}
}
define_op!(OP_COS_DIST, 2, false);
pub(crate) fn op_cos_dist(args: &[DataValue]) -> Result<DataValue> {
let a = &args[0];
let b = &args[1];
match (a, b) {
(DataValue::Vec(Vector::F32(a)), DataValue::Vec(Vector::F32(b))) => {
if a.len() != b.len() {
bail!("'cos_dist' requires two vectors of the same length");
}
let a_norm = a.dot(a) as f64;
let b_norm = b.dot(b) as f64;
let dot = a.dot(b) as f64;
Ok(DataValue::from(1. - dot / (a_norm * b_norm).sqrt()))
}
(DataValue::Vec(Vector::F64(a)), DataValue::Vec(Vector::F64(b))) => {
if a.len() != b.len() {
bail!("'cos_dist' requires two vectors of the same length");
}
let a_norm = a.dot(a);
let b_norm = b.dot(b);
let dot = a.dot(b);
Ok(DataValue::from(1. - dot / (a_norm * b_norm).sqrt()))
}
_ => bail!("'cos_dist' requires two vectors of the same type"),
}
}
define_op!(OP_RAND_FLOAT, 0, false);
pub(crate) fn op_rand_float(_args: &[DataValue]) -> Result<DataValue> {
Ok(thread_rng().gen::<f64>().into())

@ -756,6 +756,27 @@ fn test_vec_types() {
json!([1., 2., 3., 4., 5., 6., 7., 8.]),
res.into_json()["rows"][0][1]
);
let res = db
.run_script("?[v] <- [[vec([1,2,3,4,5,6,7,8])]]", Default::default())
.unwrap();
assert_eq!(
json!([1., 2., 3., 4., 5., 6., 7., 8.]),
res.into_json()["rows"][0][0]
);
let res = db
.run_script("?[v] <- [[rand_vec(5)]]", Default::default())
.unwrap();
assert_eq!(
5,
res.into_json()["rows"][0][0].as_array().unwrap().len()
);
let res = db
.run_script(r#"
val[v] <- [[vec([1,2,3,4,5,6,7,8])]]
?[x,y,z] := val[v], x=l2_dist(v, v), y=cos_dist(v, v), nv = l2_normalize(v), z=ip_dist(nv, nv)
"#, Default::default())
.unwrap();
println!("{}", res.into_json());
}
#[test]

Loading…
Cancel
Save