From 9489f46e39ef2259ab6659f0ac243f3816ba4860 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Sun, 25 Sep 2022 18:54:22 +0800 Subject: [PATCH] complete CSV reader --- Cargo.toml | 1 + src/algo/csv.rs | 165 ++++++++++++++++++++++++++++++++++++++++++++ src/algo/mod.rs | 40 ++++++++++- src/cozoscript.pest | 3 +- src/data/program.rs | 6 +- src/data/value.rs | 8 ++- src/parse/mod.rs | 9 ++- src/parse/query.rs | 7 +- src/parse/schema.rs | 8 +-- 9 files changed, 234 insertions(+), 13 deletions(-) create mode 100644 src/algo/csv.rs diff --git a/Cargo.toml b/Cargo.toml index f503a058..0d7a3717 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ unicode-normalization = "0.1.21" thiserror = "1.0.34" uuid = { version = "1.1.2", features = ["v1", "v4", "serde"] } sqlx = { version = "0.6.2", features = ["runtime-tokio-native-tls", "postgres", "mysql", "mssql", "sqlite"] } +csv = "1.1.6" cozorocks = { path = "cozorocks" } #[target.'cfg(not(target_env = "msvc"))'.dependencies] diff --git a/src/algo/csv.rs b/src/algo/csv.rs new file mode 100644 index 00000000..a489493e --- /dev/null +++ b/src/algo/csv.rs @@ -0,0 +1,165 @@ +use std::collections::BTreeMap; + +use csv::StringRecord; +use miette::{bail, ensure, IntoDiagnostic, Result}; +use smartstring::SmartString; + +use crate::algo::AlgoImpl; +use crate::data::functions::{op_to_float, op_to_uuid}; +use crate::data::program::{MagicAlgoApply, MagicSymbol, WrongAlgoOptionError}; +use crate::data::relation::{ColType, NullableColType}; +use crate::data::tuple::Tuple; +use crate::data::value::DataValue; +use crate::parse::parse_type; +use crate::runtime::db::Poison; +use crate::runtime::derived::DerivedRelStore; +use crate::runtime::transact::SessionTx; + +pub(crate) struct CsvReader; + +impl AlgoImpl for CsvReader { + fn run( + &mut self, + _tx: &SessionTx, + algo: &MagicAlgoApply, + _stores: &BTreeMap, + out: &DerivedRelStore, + _poison: Poison, + ) -> Result<()> { + let delimiter = algo.string_option("delimiter", Some(","))?; + let delimiter = delimiter.as_bytes(); + ensure!( + delimiter.len() == 1, + WrongAlgoOptionError { + name: "delimiter".to_string(), + span: algo.span, + algo_name: "CsvReader".to_string(), + help: "'delimiter' must be a single-byte string".to_string() + } + ); + let delimiter = delimiter[0]; + let prepend_index = algo.bool_option("prepend_index", Some(false))?; + let has_headers = algo.bool_option("has_headers", Some(true))?; + let types_opts = algo.expr_option("types", None)?.eval_to_const()?; + let typing = NullableColType { + coltype: ColType::List { + eltype: Box::new(NullableColType { + coltype: ColType::String, + nullable: false, + }), + len: None, + }, + nullable: false, + }; + let types_opts = typing.coerce(types_opts)?; + let mut types = vec![]; + for type_str in types_opts.get_list().unwrap() { + let type_str = type_str.get_string().unwrap(); + let typ = parse_type(type_str).map_err(|e| WrongAlgoOptionError { + name: "types".to_string(), + span: algo.span, + algo_name: "CsvReader".to_string(), + help: e.to_string(), + })?; + types.push(typ); + } + + let mut rdr_builder = csv::ReaderBuilder::new(); + rdr_builder + .delimiter(delimiter) + .has_headers(has_headers) + .flexible(true); + + let mut counter = -1i64; + let out_tuple_size = if prepend_index { + types.len() + 1 + } else { + types.len() + }; + let mut process_row = |row: StringRecord| -> Result<()> { + let mut out_tuple = Tuple(Vec::with_capacity(out_tuple_size)); + if prepend_index { + counter += 1; + out_tuple.0.push(DataValue::from(counter)); + } + for (i, typ) in types.iter().enumerate() { + match row.get(i) { + None => { + if typ.nullable { + out_tuple.0.push(DataValue::Null) + } else { + bail!( + "encountered null value when processing CSV when non-null required" + ) + } + } + Some(s) => { + let dv = DataValue::Str(SmartString::from(s)); + match &typ.coltype { + ColType::Any | ColType::String => out_tuple.0.push(dv), + ColType::Uuid => out_tuple.0.push(match op_to_uuid(&[dv]) { + Ok(uuid) => uuid, + Err(err) => { + if typ.nullable { + DataValue::Null + } else { + bail!(err) + } + } + }), + ColType::Float => out_tuple.0.push(match op_to_float(&[dv]) { + Ok(data) => data, + Err(err) => { + if typ.nullable { + DataValue::Null + } else { + bail!(err) + } + } + }), + ColType::Int => { + let f = op_to_float(&[dv]).unwrap_or(DataValue::Null); + match f.get_int() { + None => {if typ.nullable { + out_tuple.0.push(DataValue::Null) + } else { + bail!("cannot convert {} to type {}", s, typ) + }} + Some(i) => { + out_tuple.0.push(DataValue::from(i)) + } + }; + } + _ => bail!("cannot convert {} to type {}", s, typ), + } + } + } + } + out.put(out_tuple, 0); + Ok(()) + }; + + let url = algo.string_option("url", None)?; + match url.strip_prefix("file://") { + Some(file_path) => { + let mut rdr = rdr_builder.from_path(file_path).into_diagnostic()?; + for record in rdr.records() { + let record = record.into_diagnostic()?; + process_row(record)?; + } + } + None => { + let content = reqwest::blocking::get(&url as &str) + .into_diagnostic()? + .text() + .into_diagnostic()?; + let mut rdr = rdr_builder.from_reader(content.as_bytes()); + for record in rdr.records() { + let record = record.into_diagnostic()?; + process_row(record)?; + } + } + } + Ok(()) + } +} diff --git a/src/algo/mod.rs b/src/algo/mod.rs index 634cd52b..a02fc7b6 100644 --- a/src/algo/mod.rs +++ b/src/algo/mod.rs @@ -8,6 +8,7 @@ use thiserror::Error; use crate::algo::all_pairs_shortest_path::{BetweennessCentrality, ClosenessCentrality}; use crate::algo::astar::ShortestPathAStar; use crate::algo::bfs::Bfs; +use crate::algo::csv::CsvReader; use crate::algo::degree_centrality::DegreeCentrality; use crate::algo::dfs::Dfs; use crate::algo::jlines::JsonReader; @@ -25,7 +26,9 @@ use crate::algo::triangles::ClusteringCoefficients; use crate::algo::yen::KShortestPathYen; use crate::data::expr::Expr; use crate::data::functions::OP_LIST; -use crate::data::program::{AlgoRuleArg, MagicAlgoApply, MagicAlgoRuleArg, MagicSymbol}; +use crate::data::program::{ + AlgoOptionNotFoundError, AlgoRuleArg, MagicAlgoApply, MagicAlgoRuleArg, MagicSymbol, +}; use crate::data::symb::Symbol; use crate::data::tuple::{Tuple, TupleIter}; use crate::data::value::DataValue; @@ -37,6 +40,7 @@ use crate::runtime::transact::SessionTx; pub(crate) mod all_pairs_shortest_path; pub(crate) mod astar; pub(crate) mod bfs; +pub(crate) mod csv; pub(crate) mod degree_centrality; pub(crate) mod dfs; pub(crate) mod jlines; @@ -125,6 +129,38 @@ impl AlgoHandle { )), } } + n @ "CsvReader" => { + let with_row_num = match opts.get("prepend_index") { + None => 0, + Some(Expr::Const { + val: DataValue::Bool(true), + .. + }) => 1, + Some(Expr::Const { + val: DataValue::Bool(false), + .. + }) => 0, + _ => bail!(CannotDetermineArity( + n.to_string(), + "invalid option 'prepend_index' given, expect a boolean".to_string(), + self.name.span + )), + }; + let columns = opts.get("types").ok_or_else(|| AlgoOptionNotFoundError { + name: "types".to_string(), + span: self.name.span, + algo_name: n.to_string(), + })?; + let columns = columns.clone().eval_to_const()?; + if let Some(l) = columns.get_list() { + return Ok(l.len() + with_row_num); + } + bail!(CannotDetermineArity( + n.to_string(), + "invalid option 'types' given, expect positive number or list".to_string(), + self.name.span + )) + } n @ "JsonReader" => { let with_row_num = match opts.get("prepend_index") { None => 0, @@ -186,7 +222,7 @@ impl AlgoHandle { "RandomWalk" => Box::new(RandomWalk), "ReorderSort" => Box::new(ReorderSort), "JsonReader" => Box::new(JsonReader), - "CsvReader" => todo!(), + "CsvReader" => Box::new(CsvReader), "RemoteCozo" => todo!(), "SqlDb" => todo!(), name => bail!(AlgoNotFoundError(name.to_string(), self.name.span)), diff --git a/src/cozoscript.pest b/src/cozoscript.pest index 0336cb73..513ce59b 100644 --- a/src/cozoscript.pest +++ b/src/cozoscript.pest @@ -56,7 +56,7 @@ negation = {"not" ~ atom} apply = {ident ~ "(" ~ apply_args ~ ")"} apply_args = {(expr ~ ",")* ~ expr?} named_apply_args = {(named_apply_pair ~ ",")* ~ named_apply_pair?} -named_apply_pair = {ident ~ ":" ~ expr} +named_apply_pair = {ident ~ (":" ~ expr)?} grouped = _{"(" ~ rule_body ~ ")"} expr = {unary ~ (operation ~ unary)*} @@ -163,6 +163,7 @@ table_schema = {"{" ~ table_cols ~ ("=>" ~ table_cols)? ~ "}"} table_cols = {(table_col ~ ",")* ~ table_col?} table_col = {ident ~ (":" ~ col_type)? ~ (("default" ~ expr) | ("=" ~ out_arg))?} col_type = {(any_type | int_type | float_type | string_type | bytes_type | uuid_type | list_type | tuple_type) ~ "?"?} +col_type_with_term = {SOI ~ col_type ~ EOI} any_type = {"Any"} int_type = {"Int"} float_type = {"Float"} diff --git a/src/data/program.rs b/src/data/program.rs index af43a325..4f705334 100644 --- a/src/data/program.rs +++ b/src/data/program.rs @@ -145,10 +145,10 @@ pub(crate) struct MagicAlgoApply { #[error("Cannot find a required named option '{name}' for '{algo_name}'")] #[diagnostic(code(algo::arg_not_found))] pub(crate) struct AlgoOptionNotFoundError { - name: String, + pub(crate) name: String, #[label] - span: SourceSpan, - algo_name: String, + pub(crate) span: SourceSpan, + pub(crate) algo_name: String, } #[derive(Error, Diagnostic, Debug)] diff --git a/src/data/value.rs b/src/data/value.rs index d81ebb9f..bdaa2ea2 100644 --- a/src/data/value.rs +++ b/src/data/value.rs @@ -139,7 +139,13 @@ impl Num { pub(crate) fn get_int(&self) -> Option { match self { Num::I(i) => Some(*i), - _ => None, + Num::F(f) => { + if f.round() == *f { + Some(*f as i64) + } else { + None + } + } } } pub(crate) fn get_float(&self) -> f64 { diff --git a/src/parse/mod.rs b/src/parse/mod.rs index 631539e2..4d4945ce 100644 --- a/src/parse/mod.rs +++ b/src/parse/mod.rs @@ -1,13 +1,15 @@ use std::cmp::{max, min}; use std::collections::BTreeMap; -use miette::{Diagnostic, Result}; +use miette::{Diagnostic, Result, IntoDiagnostic}; use pest::error::InputLocation; use pest::Parser; use crate::data::program::InputProgram; +use crate::data::relation::NullableColType; use crate::data::value::DataValue; use crate::parse::query::parse_query; +use crate::parse::schema::parse_nullable_type; use crate::parse::sys::{parse_sys, SysOp}; pub(crate) mod expr; @@ -62,6 +64,11 @@ pub(crate) struct ParseError { pub(crate) span: SourceSpan, } +pub(crate) fn parse_type(src: &str) -> Result { + let parsed = CozoScriptParser::parse(Rule::col_type_with_term, src).into_diagnostic()?.next().unwrap(); + parse_nullable_type(parsed.into_inner().next().unwrap()) +} + pub(crate) fn parse_script( src: &str, param_pool: &BTreeMap, diff --git a/src/parse/query.rs b/src/parse/query.rs index aee96d39..8ca79fe1 100644 --- a/src/parse/query.rs +++ b/src/parse/query.rs @@ -653,7 +653,12 @@ fn parse_atom(src: Pair<'_>, param_pool: &BTreeMap) -> Result let mut inner = pair.into_inner(); let name_p = inner.next().unwrap(); let name = SmartString::from(name_p.as_str()); - let arg = parse_rule_arg(inner.next().unwrap(), param_pool)?; + let arg = match inner.next() { + Some(a) => parse_rule_arg(a, param_pool)?, + None => InputTerm::Var { + name: Symbol::new(name.clone(), name_p.extract_span()), + }, + }; Ok((name, arg)) }, ) diff --git a/src/parse/schema.rs b/src/parse/schema.rs index 8d216ffc..8b9f9c47 100644 --- a/src/parse/schema.rs +++ b/src/parse/schema.rs @@ -73,7 +73,7 @@ fn parse_col(pair: Pair<'_>) -> Result<(ColumnDef, Symbol)> { let mut binding_candidate = None; for nxt in src { match nxt.as_rule() { - Rule::col_type => typing = parse_type(nxt)?, + Rule::col_type => typing = parse_nullable_type(nxt)?, Rule::expr => default_gen = Some(build_expr(nxt, &Default::default())?), Rule::out_arg => binding_candidate = Some(Symbol::new(nxt.as_str(), nxt.extract_span())), r => unreachable!("{:?}", r) @@ -88,7 +88,7 @@ fn parse_col(pair: Pair<'_>) -> Result<(ColumnDef, Symbol)> { }, binding)) } -fn parse_type(pair: Pair<'_>) -> Result { +pub(crate) fn parse_nullable_type(pair: Pair<'_>) -> Result { let nullable = pair.as_str().ends_with('?'); let coltype = parse_type_inner(pair.into_inner().next().unwrap())?; Ok(NullableColType { @@ -107,7 +107,7 @@ fn parse_type_inner(pair: Pair<'_>) -> Result { Rule::uuid_type => ColType::Uuid, Rule::list_type => { let mut inner = pair.into_inner(); - let eltype = parse_type(inner.next().unwrap())?; + let eltype = parse_nullable_type(inner.next().unwrap())?; let len = match inner.next() { None => None, Some(len_p) => { @@ -129,7 +129,7 @@ fn parse_type_inner(pair: Pair<'_>) -> Result { ColType::List { eltype: eltype.into(), len } } Rule::tuple_type => { - ColType::Tuple(pair.into_inner().map(|p| parse_type(p)).try_collect()?) + ColType::Tuple(pair.into_inner().map(|p| parse_nullable_type(p)).try_collect()?) } _ => unreachable!() })