python wrapper no longer goes through strings

main
Ziyang Hu 2 years ago
parent c3c674950f
commit 0c7f2a3d27

@ -17,8 +17,9 @@ use serde::{Deserialize, Deserializer, Serialize};
use smartstring::{LazyCompact, SmartString}; use smartstring::{LazyCompact, SmartString};
use uuid::Uuid; use uuid::Uuid;
/// UUID value in the database
#[derive(Clone, Hash, Eq, PartialEq, serde_derive::Deserialize, serde_derive::Serialize)] #[derive(Clone, Hash, Eq, PartialEq, serde_derive::Deserialize, serde_derive::Serialize)]
pub struct UuidWrapper(pub(crate) Uuid); pub struct UuidWrapper(pub Uuid);
impl PartialOrd<Self> for UuidWrapper { impl PartialOrd<Self> for UuidWrapper {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
@ -37,8 +38,9 @@ impl Ord for UuidWrapper {
} }
} }
/// A Regex in the database. Used internally in functions.
#[derive(Clone)] #[derive(Clone)]
pub struct RegexWrapper(pub(crate) Regex); pub struct RegexWrapper(pub Regex);
impl Hash for RegexWrapper { impl Hash for RegexWrapper {
fn hash<H: Hasher>(&self, state: &mut H) { fn hash<H: Hasher>(&self, state: &mut H) {
@ -84,6 +86,7 @@ impl PartialOrd for RegexWrapper {
} }
} }
/// Timestamp part of validity
#[derive( #[derive(
Copy, Copy,
Clone, Clone,
@ -98,6 +101,7 @@ impl PartialOrd for RegexWrapper {
)] )]
pub struct ValidityTs(pub Reverse<i64>); pub struct ValidityTs(pub Reverse<i64>);
/// Validity for time travel
#[derive( #[derive(
Copy, Copy,
Clone, Clone,
@ -110,25 +114,39 @@ pub struct ValidityTs(pub Reverse<i64>);
Hash, Hash,
)] )]
pub struct Validity { pub struct Validity {
pub(crate) timestamp: ValidityTs, /// Timestamp, sorted descendingly
pub(crate) is_assert: Reverse<bool>, pub timestamp: ValidityTs,
/// Whether this validity is an assertion, sorted descendingly
pub is_assert: Reverse<bool>,
} }
/// A Value in the database
#[derive( #[derive(
Clone, PartialEq, Eq, PartialOrd, Ord, serde_derive::Deserialize, serde_derive::Serialize, Hash, Clone, PartialEq, Eq, PartialOrd, Ord, serde_derive::Deserialize, serde_derive::Serialize, Hash,
)] )]
pub enum DataValue { pub enum DataValue {
/// null
Null, Null,
/// boolean
Bool(bool), Bool(bool),
/// number, may be int or float
Num(Num), Num(Num),
/// string
Str(SmartString<LazyCompact>), Str(SmartString<LazyCompact>),
/// bytes
#[serde(with = "serde_bytes")] #[serde(with = "serde_bytes")]
Bytes(Vec<u8>), Bytes(Vec<u8>),
/// UUID
Uuid(UuidWrapper), Uuid(UuidWrapper),
/// Regex, used internally only
Regex(RegexWrapper), Regex(RegexWrapper),
/// list
List(Vec<DataValue>), List(Vec<DataValue>),
/// set, used internally only
Set(BTreeSet<DataValue>), Set(BTreeSet<DataValue>),
/// validity
Validity(Validity), Validity(Validity),
/// bottom type, used internally only
Bot, Bot,
} }
@ -162,9 +180,12 @@ impl From<bool> for DataValue {
} }
} }
/// Representing a number
#[derive(Copy, Clone, serde_derive::Deserialize, serde_derive::Serialize)] #[derive(Copy, Clone, serde_derive::Deserialize, serde_derive::Serialize)]
pub enum Num { pub enum Num {
/// intger number
Int(i64), Int(i64),
/// float number
Float(f64), Float(f64),
} }
@ -340,6 +361,7 @@ impl DataValue {
_ => None, _ => None,
} }
} }
/// Returns bool if this one is.
pub fn get_bool(&self) -> Option<bool> { pub fn get_bool(&self) -> Option<bool> {
match self { match self {
DataValue::Bool(b) => Some(*b), DataValue::Bool(b) => Some(*b),

@ -45,6 +45,7 @@ use miette::{
}; };
use serde_json::json; use serde_json::json;
pub use data::value::{DataValue, Num, RegexWrapper, UuidWrapper, Validity, ValidityTs};
pub use fixed_rule::FixedRule; pub use fixed_rule::FixedRule;
pub use runtime::db::Db; pub use runtime::db::Db;
pub use runtime::db::NamedRows; pub use runtime::db::NamedRows;
@ -153,7 +154,7 @@ impl DbInstance {
pub fn run_script( pub fn run_script(
&self, &self,
payload: &str, payload: &str,
params: BTreeMap<String, JsonValue>, params: BTreeMap<String, DataValue>,
) -> Result<NamedRows> { ) -> Result<NamedRows> {
match self { match self {
DbInstance::Mem(db) => db.run_script(payload, params), DbInstance::Mem(db) => db.run_script(payload, params),
@ -172,7 +173,7 @@ impl DbInstance {
pub fn run_script_fold_err( pub fn run_script_fold_err(
&self, &self,
payload: &str, payload: &str,
params: BTreeMap<String, JsonValue>, params: BTreeMap<String, DataValue>,
) -> JsonValue { ) -> JsonValue {
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
let start = Instant::now(); let start = Instant::now();
@ -198,7 +199,10 @@ impl DbInstance {
BTreeMap::default() BTreeMap::default()
} else { } else {
match serde_json::from_str::<BTreeMap<String, JsonValue>>(params) { match serde_json::from_str::<BTreeMap<String, JsonValue>>(params) {
Ok(map) => map, Ok(map) => map
.into_iter()
.map(|(k, v)| (k, DataValue::from(v)))
.collect(),
Err(_) => { Err(_) => {
return json!({"ok": false, "message": "params argument is not a JSON map"}) return json!({"ok": false, "message": "params argument is not a JSON map"})
.to_string() .to_string()

@ -163,13 +163,9 @@ impl<'s, S: Storage<'s>> Db<S> {
pub fn run_script( pub fn run_script(
&'s self, &'s self,
payload: &str, payload: &str,
params: BTreeMap<String, JsonValue>, params: BTreeMap<String, DataValue>,
) -> Result<NamedRows> { ) -> Result<NamedRows> {
let cur_vld = current_validity(); let cur_vld = current_validity();
let params = params
.into_iter()
.map(|(k, v)| (k, DataValue::from(v)))
.collect();
self.do_run_script(payload, &params, cur_vld) self.do_run_script(payload, &params, cur_vld)
} }
/// Export relations to JSON data. /// Export relations to JSON data.

@ -8,9 +8,95 @@
use pyo3::exceptions::PyException; use pyo3::exceptions::PyException;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use std::collections::BTreeMap;
use cozo::*; use cozo::*;
fn py_to_value(ob: &PyAny) -> PyResult<DataValue> {
Ok(if ob.is_none() {
DataValue::Null
} else if let Ok(i) = ob.extract::<i64>() {
DataValue::from(i)
} else if let Ok(f) = ob.extract::<f64>() {
DataValue::from(f)
} else if let Ok(s) = ob.extract::<String>() {
DataValue::from(s)
} else if let Ok(b) = ob.extract::<Vec<u8>>() {
DataValue::Bytes(b)
} else if let Ok(l) = ob.downcast::<PyList>() {
let mut coll = Vec::with_capacity(l.len());
for el in l {
let el = py_to_value(el)?;
coll.push(el)
}
DataValue::List(coll)
} else if let Ok(d) = ob.downcast::<PyDict>() {
let mut coll = Vec::with_capacity(d.len());
for (k, v) in d {
let k = py_to_value(k)?;
let v = py_to_value(v)?;
coll.push(DataValue::List(vec![k, v]))
}
DataValue::List(coll)
} else {
return Err(PyException::new_err(format!("Cannot convert {} into Cozo value", ob)).into());
})
}
fn convert_params(ob: &PyDict) -> PyResult<BTreeMap<String, DataValue>> {
let mut ret = BTreeMap::new();
for (k, v) in ob {
let k: String = k.extract()?;
let v = py_to_value(v)?;
ret.insert(k, v);
}
Ok(ret)
}
fn value_to_py(val: DataValue, py: Python<'_>) -> PyObject {
match val {
DataValue::Null => py.None(),
DataValue::Bool(b) => b.into_py(py),
DataValue::Num(num) => match num {
Num::Int(i) => i.into_py(py),
Num::Float(f) => f.into_py(py),
},
DataValue::Str(s) => s.as_str().into_py(py),
DataValue::Bytes(b) => b.into_py(py),
DataValue::Uuid(uuid) => uuid.0.to_string().into_py(py),
DataValue::Regex(rx) => rx.0.as_str().into_py(py),
DataValue::List(l) => {
let vs: Vec<_> = l.into_iter().map(|v| value_to_py(v, py)).collect();
vs.into_py(py)
}
DataValue::Set(l) => {
let vs: Vec<_> = l.into_iter().map(|v| value_to_py(v, py)).collect();
vs.into_py(py)
}
DataValue::Validity(vld) => {
[vld.timestamp.0 .0.into_py(py), vld.is_assert.0.into_py(py)].into_py(py)
}
DataValue::Bot => py.None(),
}
}
fn named_rows_to_py(named_rows: NamedRows, py: Python<'_>) -> PyObject {
let rows = named_rows
.rows
.into_iter()
.map(|row| {
row.into_iter()
.map(|val| value_to_py(val, py))
.collect::<Vec<_>>()
.into_py(py)
})
.collect::<Vec<_>>()
.into_py(py);
let headers = named_rows.headers.into_py(py);
BTreeMap::from([("rows", rows), ("headers", headers)]).into_py(py)
}
#[pyclass] #[pyclass]
struct CozoDbPy { struct CozoDbPy {
db: Option<DbInstance>, db: Option<DbInstance>,
@ -27,6 +113,20 @@ impl CozoDbPy {
Err(err) => Err(PyException::new_err(format!("{err:?}"))), Err(err) => Err(PyException::new_err(format!("{err:?}"))),
} }
} }
pub fn run_script(&self, py: Python<'_>, query: &str, params: &PyDict) -> PyResult<PyObject> {
if let Some(db) = &self.db {
let params = convert_params(params)?;
match py.allow_threads(|| db.run_script(query, params)) {
Ok(rows) => Ok(named_rows_to_py(rows, py)),
Err(err) => {
let reports = format_error_as_json(err, Some(query)).to_string();
Err(PyException::new_err(reports))
}
}
} else {
Err(PyException::new_err(DB_CLOSED_MSG))
}
}
pub fn run_query(&self, py: Python<'_>, query: &str, params: &str) -> String { pub fn run_query(&self, py: Python<'_>, query: &str, params: &str) -> String {
if let Some(db) = &self.db { if let Some(db) = &self.db {
py.allow_threads(|| db.run_script_str(query, params)) py.allow_threads(|| db.run_script_str(query, params))

Loading…
Cancel
Save