From bd1ba17527548c6cb11337c8ada0fd18b2d497a4 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Thu, 15 Jun 2023 22:22:01 +0800 Subject: [PATCH] make starts_with and ends_with work with bytes https://github.com/cozodb/cozo/discussions/132 --- cozo-core/src/data/functions.rs | 74 ++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 33 deletions(-) diff --git a/cozo-core/src/data/functions.rs b/cozo-core/src/data/functions.rs index 374a9c9d..c2ced5df 100644 --- a/cozo-core/src/data/functions.rs +++ b/cozo-core/src/data/functions.rs @@ -1359,28 +1359,24 @@ pub(crate) fn op_trim_end(args: &[DataValue]) -> Result { define_op!(OP_STARTS_WITH, 2, false); pub(crate) fn op_starts_with(args: &[DataValue]) -> Result { - let a = match &args[0] { - DataValue::Str(s) => s, - _ => bail!("'starts_with' requires strings"), - }; - let b = match &args[1] { - DataValue::Str(s) => s, - _ => bail!("'starts_with' requires strings"), - }; - Ok(DataValue::from(a.starts_with(b as &str))) + match (&args[0], &args[1]) { + (DataValue::Str(l), DataValue::Str(r)) => Ok(DataValue::from(l.starts_with(r as &str))), + (DataValue::Bytes(l), DataValue::Bytes(r)) => { + Ok(DataValue::from(l.starts_with(r as &[u8]))) + } + _ => bail!("'starts_with' requires strings or bytes"), + } } define_op!(OP_ENDS_WITH, 2, false); pub(crate) fn op_ends_with(args: &[DataValue]) -> Result { - let a = match &args[0] { - DataValue::Str(s) => s, - _ => bail!("'ends_with' requires strings"), - }; - let b = match &args[1] { - DataValue::Str(s) => s, - _ => bail!("'ends_with' requires strings"), - }; - Ok(DataValue::from(a.ends_with(b as &str))) + match (&args[0], &args[1]) { + (DataValue::Str(l), DataValue::Str(r)) => Ok(DataValue::from(l.ends_with(r as &str))), + (DataValue::Bytes(l), DataValue::Bytes(r)) => { + Ok(DataValue::from(l.ends_with(r as &[u8]))) + } + _ => bail!("'ends_with' requires strings or bytes"), + } } define_op!(OP_REGEX, 1, false); @@ -1623,9 +1619,9 @@ pub(crate) fn op_haversine(args: &[DataValue]) -> Result { let lon2 = args[3].get_float().ok_or_else(miette)?; let ret = 2. * f64::asin(f64::sqrt( - f64::sin((lat1 - lat2) / 2.).powi(2) - + f64::cos(lat1) * f64::cos(lat2) * f64::sin((lon1 - lon2) / 2.).powi(2), - )); + f64::sin((lat1 - lat2) / 2.).powi(2) + + f64::cos(lat1) * f64::cos(lat2) * f64::sin((lon1 - lon2) / 2.).powi(2), + )); Ok(DataValue::from(ret)) } @@ -1638,9 +1634,9 @@ pub(crate) fn op_haversine_deg_input(args: &[DataValue]) -> Result { let lon2 = args[3].get_float().ok_or_else(miette)? * f64::PI() / 180.; let ret = 2. * f64::asin(f64::sqrt( - f64::sin((lat1 - lat2) / 2.).powi(2) - + f64::cos(lat1) * f64::cos(lat2) * f64::sin((lon1 - lon2) / 2.).powi(2), - )); + f64::sin((lat1 - lat2) / 2.).powi(2) + + f64::cos(lat1) * f64::cos(lat2) * f64::sin((lon1 - lon2) / 2.).powi(2), + )); Ok(DataValue::from(ret)) } @@ -1971,7 +1967,7 @@ pub(crate) fn op_to_int(args: &[DataValue]) -> Result { .map_err(|_| miette!("The string cannot be interpreted as int"))? .into() } - DataValue::Validity(vld) => DataValue::Num(Num::Int(vld.timestamp.0.0)), + DataValue::Validity(vld) => DataValue::Num(Num::Int(vld.timestamp.0 .0)), v => bail!("'to_int' does not recognize {:?}", v), }) } @@ -2086,16 +2082,28 @@ pub(crate) fn op_vec(args: &[DataValue]) -> Result { } }, DataValue::Str(s) => { - let bytes = STANDARD.decode(s).map_err(|_| miette!("Data is not base64 encoded"))?; + let bytes = STANDARD + .decode(s) + .map_err(|_| miette!("Data is not base64 encoded"))?; match t { VecElementType::F32 => { let f32_count = bytes.len() / mem::size_of::(); - let arr = unsafe { ndarray::ArrayView1::from_shape_ptr(ndarray::Dim([f32_count]), bytes.as_ptr() as *const f32) }; + let arr = unsafe { + ndarray::ArrayView1::from_shape_ptr( + ndarray::Dim([f32_count]), + bytes.as_ptr() as *const f32, + ) + }; Ok(DataValue::Vec(Vector::F32(arr.to_owned()))) } VecElementType::F64 => { let f64_count = bytes.len() / mem::size_of::(); - let arr = unsafe { ndarray::ArrayView1::from_shape_ptr(ndarray::Dim([f64_count]), bytes.as_ptr() as *const f64) }; + let arr = unsafe { + ndarray::ArrayView1::from_shape_ptr( + ndarray::Dim([f64_count]), + bytes.as_ptr() as *const f64, + ) + }; Ok(DataValue::Vec(Vector::F64(arr.to_owned()))) } } @@ -2428,12 +2436,12 @@ pub(crate) fn op_now(_args: &[DataValue]) -> Result { pub(crate) fn current_validity() -> ValidityTs { #[cfg(not(target_arch = "wasm32"))] - let ts_micros = { + let ts_micros = { let now = SystemTime::now(); now.duration_since(UNIX_EPOCH).unwrap().as_micros() as i64 }; #[cfg(target_arch = "wasm32")] - let ts_micros = { (Date::now() * 1000.) as i64 }; + let ts_micros = { (Date::now() * 1000.) as i64 }; ValidityTs(Reverse(ts_micros)) } @@ -2448,7 +2456,7 @@ define_op!(OP_FORMAT_TIMESTAMP, 1, true); pub(crate) fn op_format_timestamp(args: &[DataValue]) -> Result { let dt = { let millis = match &args[0] { - DataValue::Validity(vld) => vld.timestamp.0.0 / 1000, + DataValue::Validity(vld) => vld.timestamp.0 .0 / 1000, v => { let f = v .get_float() @@ -2502,14 +2510,14 @@ pub(crate) fn op_rand_uuid_v1(_args: &[DataValue]) -> Result { let mut rng = rand::thread_rng(); let uuid_ctx = uuid::v1::Context::new(rng.gen()); #[cfg(target_arch = "wasm32")] - let ts = { + let ts = { let since_epoch: f64 = Date::now(); let seconds = since_epoch.floor(); let fractional = (since_epoch - seconds) * 1.0e9; Timestamp::from_unix(uuid_ctx, seconds as u64, fractional as u32) }; #[cfg(not(target_arch = "wasm32"))] - let ts = { + let ts = { let now = SystemTime::now(); let since_epoch = now.duration_since(UNIX_EPOCH).unwrap(); Timestamp::from_unix(uuid_ctx, since_epoch.as_secs(), since_epoch.subsec_nanos())