random functions

main
Ziyang Hu 2 years ago
parent bf112ea399
commit 315917aa67

@ -23,4 +23,4 @@
* [x] louvain modularity
* [x] direct loading of data
* [ ] serial agg function
* [ ] random function
* [x] random function

@ -8,6 +8,7 @@ use std::str::FromStr;
use anyhow::{anyhow, bail, ensure, Result};
use itertools::Itertools;
use num_traits::FloatConst;
use rand::prelude::*;
use smartstring::SmartString;
use crate::data::symb::Symbol;
@ -1425,6 +1426,22 @@ fn op_slice(args: &[DataValue]) -> Result<DataValue> {
Ok(DataValue::List(l[m..n].to_vec()))
}
define_op!(OP_CHARS, 1, false, false);
fn op_chars(args: &[DataValue]) -> Result<DataValue> {
Ok(DataValue::List(
args[0]
.get_string()
.ok_or_else(|| anyhow!("'chars' can only be applied to string, got {:?}", args))?
.chars()
.map(|c| {
let mut s = SmartString::new();
s.push(c);
DataValue::String(s)
})
.collect_vec(),
))
}
define_op!(OP_NTH_CHAR, 2, false, false);
fn op_nth_char(args: &[DataValue]) -> Result<DataValue> {
let l = args[0].get_string().ok_or_else(|| {
@ -1533,6 +1550,60 @@ fn op_to_float(args: &[DataValue]) -> Result<DataValue> {
})
}
define_op!(OP_RAND_FLOAT, 0, false, false);
fn op_rand_float(_args: &[DataValue]) -> Result<DataValue> {
Ok(thread_rng().gen::<f64>().into())
}
define_op!(OP_RAND_BERNOULLI, 0, true, false);
fn op_rand_bernoulli(args: &[DataValue]) -> Result<DataValue> {
let prob = match args.get(0) {
None => 0.5,
Some(DataValue::Number(n)) => {
let f = n.get_float();
ensure!(
f >= 0. && f <= 1.,
"'rand_bernoulli' requires number between 0. and 1., got {}",
f
);
f
}
Some(v) => bail!(
"'rand_bernoulli' requires number between 0. and 1., got {:?}",
v
),
};
Ok(DataValue::Bool(thread_rng().gen_bool(prob)))
}
define_op!(OP_RAND_INT, 2, false, false);
fn op_rand_int(args: &[DataValue]) -> Result<DataValue> {
let lower = &args[0].get_int().ok_or_else(|| {
anyhow!(
"first argument to 'rand_int' must be an integer, got args {:?}",
args
)
})?;
let upper = &args[1].get_int().ok_or_else(|| {
anyhow!(
"second argument to 'rand_int' must be an integer, got args {:?}",
args
)
})?;
Ok(thread_rng().gen_range(*lower..=*upper).into())
}
define_op!(OP_RAND_CHOOSE, 1, false, true);
fn op_rand_choose(args: &[DataValue]) -> Result<DataValue> {
match &args[0] {
DataValue::List(l) => Ok(l
.choose(&mut thread_rng())
.cloned()
.unwrap_or(DataValue::Null)),
v => bail!("'rand_choice' can only be applied to list, got {:?}", v),
}
}
pub(crate) fn get_op(name: &str) -> Option<&'static Op> {
Some(match name {
"list" => &OP_LIST,
@ -1612,6 +1683,7 @@ pub(crate) fn get_op(name: &str) -> Option<&'static Op> {
"nth" => &OP_NTH,
"maybe_nth" => &OP_MAYBE_NTH,
"nth_char" => &OP_NTH_CHAR,
"chars" => &OP_CHARS,
"maybe_nth_char" => &OP_MAYBE_NTH_CHAR,
"slice" => &OP_SLICE,
"str_slice" => &OP_STR_SLICE,
@ -1628,6 +1700,10 @@ pub(crate) fn get_op(name: &str) -> Option<&'static Op> {
"chunks_exact" => &OP_CHUNKS_EXACT,
"windows" => &OP_WINDOWS,
"to_float" => &OP_TO_FLOAT,
"rand_float" => &OP_RAND_FLOAT,
"rand_bernoulli" => &OP_RAND_BERNOULLI,
"rand_int" => &OP_RAND_INT,
"rand_choose" => &OP_RAND_CHOOSE,
_ => return None,
})
}

Loading…
Cancel
Save