Pave the way for custom algorithms

main
Ziyang Hu 2 years ago
parent 8aaae15be8
commit 9cadfc48a7

@ -7,7 +7,9 @@
*/
use std::collections::BTreeMap;
use std::sync::Arc;
use lazy_static::lazy_static;
use miette::{bail, ensure, Diagnostic, Result};
use smartstring::{LazyCompact, SmartString};
use thiserror::Error;
@ -368,7 +370,7 @@ impl<'a, 'b> AlgoPayload<'a, 'b> {
}
/// Trait for an implementation of an algorithm or a utility
pub trait AlgoImpl {
pub trait AlgoImpl: Send + Sync {
/// Called to initialize the options given.
/// Will always be called once, before anything else.
/// You can mutate the options if you need to.
@ -413,59 +415,139 @@ pub(crate) struct AlgoHandle {
pub(crate) name: Symbol,
}
impl AlgoHandle {
pub(crate) fn new(name: &str, span: SourceSpan) -> Self {
AlgoHandle {
name: Symbol::new(name, span),
}
}
pub(crate) fn get_impl(&self) -> Result<Box<dyn AlgoImpl>> {
Ok(match &self.name.name as &str {
lazy_static! {
pub(crate) static ref DEFAULT_ALGOS: Arc<BTreeMap<String, Arc<Box<dyn AlgoImpl>>>> = {
Arc::new(BTreeMap::from([
#[cfg(feature = "graph-algo")]
"ClusteringCoefficients" => Box::new(ClusteringCoefficients),
(
"ClusteringCoefficients".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(ClusteringCoefficients)),
),
#[cfg(feature = "graph-algo")]
"DegreeCentrality" => Box::new(DegreeCentrality),
(
"DegreeCentrality".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(DegreeCentrality)),
),
#[cfg(feature = "graph-algo")]
"ClosenessCentrality" => Box::new(ClosenessCentrality),
(
"ClosenessCentrality".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(ClosenessCentrality)),
),
#[cfg(feature = "graph-algo")]
"BetweennessCentrality" => Box::new(BetweennessCentrality),
(
"BetweennessCentrality".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(BetweennessCentrality)),
),
#[cfg(feature = "graph-algo")]
"DepthFirstSearch" | "DFS" => Box::new(Dfs),
(
"DepthFirstSearch".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(Dfs)),
),
#[cfg(feature = "graph-algo")]
"BreadthFirstSearch" | "BFS" => Box::new(Bfs),
(
"DFS".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(Dfs)),
),
#[cfg(feature = "graph-algo")]
"ShortestPathDijkstra" => Box::new(ShortestPathDijkstra),
(
"BreadthFirstSearch".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(Bfs)),
),
#[cfg(feature = "graph-algo")]
"ShortestPathAStar" => Box::new(ShortestPathAStar),
(
"BFS".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(Bfs)),
),
#[cfg(feature = "graph-algo")]
"KShortestPathYen" => Box::new(KShortestPathYen),
(
"ShortestPathDijkstra".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(ShortestPathDijkstra)),
),
#[cfg(feature = "graph-algo")]
"MinimumSpanningTreePrim" => Box::new(MinimumSpanningTreePrim),
(
"ShortestPathAStar".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(ShortestPathAStar)),
),
#[cfg(feature = "graph-algo")]
"MinimumSpanningForestKruskal" => Box::new(MinimumSpanningForestKruskal),
(
"KShortestPathYen".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(KShortestPathYen)),
),
#[cfg(feature = "graph-algo")]
"TopSort" => Box::new(TopSort),
(
"MinimumSpanningTreePrim".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(MinimumSpanningTreePrim)),
),
#[cfg(feature = "graph-algo")]
"ConnectedComponents" => Box::new(StronglyConnectedComponent::new(false)),
(
"MinimumSpanningForestKruskal".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(MinimumSpanningForestKruskal)),
),
#[cfg(feature = "graph-algo")]
"StronglyConnectedComponents" | "SCC" => {
Box::new(StronglyConnectedComponent::new(true))
}
(
"TopSort".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(TopSort)),
),
#[cfg(feature = "graph-algo")]
"PageRank" => Box::new(PageRank),
(
"ConnectedComponents".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(StronglyConnectedComponent::new(false))),
),
#[cfg(feature = "graph-algo")]
"CommunityDetectionLouvain" => Box::new(CommunityDetectionLouvain),
(
"StronglyConnectedComponents".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(StronglyConnectedComponent::new(true))),
),
#[cfg(feature = "graph-algo")]
"LabelPropagation" => Box::new(LabelPropagation),
(
"SCC".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(StronglyConnectedComponent::new(true))),
),
#[cfg(feature = "graph-algo")]
"RandomWalk" => Box::new(RandomWalk),
"ReorderSort" => Box::new(ReorderSort),
"JsonReader" => Box::new(JsonReader),
"CsvReader" => Box::new(CsvReader),
"Constant" => Box::new(Constant),
name => bail!(AlgoNotFoundError(name.to_string(), self.name.span)),
})
(
"PageRank".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(PageRank)),
),
#[cfg(feature = "graph-algo")]
(
"CommunityDetectionLouvain".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(CommunityDetectionLouvain)),
),
#[cfg(feature = "graph-algo")]
(
"LabelPropagation".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(LabelPropagation)),
),
#[cfg(feature = "graph-algo")]
(
"RandomWalk".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(RandomWalk)),
),
(
"ReorderSort".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(ReorderSort)),
),
(
"JsonReader".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(JsonReader)),
),
(
"CsvReader".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(CsvReader)),
),
(
"Constant".to_string(),
Arc::<Box<dyn AlgoImpl>>::new(Box::new(Constant)),
),
]))
};
}
impl AlgoHandle {
    /// Creates a handle whose `name` is the `Symbol` built from the given
    /// algorithm name and the source span it was parsed from.
    pub(crate) fn new(name: &str, span: SourceSpan) -> Self {
        let name = Symbol::new(name, span);
        Self { name }
    }
}

@ -10,6 +10,7 @@ use std::collections::btree_map::Entry;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Debug, Display, Formatter};
use std::rc::Rc;
use std::sync::Arc;
use miette::{ensure, Diagnostic, Result};
use smallvec::SmallVec;
@ -207,7 +208,7 @@ pub(crate) struct AlgoApply {
pub(crate) head: Vec<Symbol>,
pub(crate) arity: usize,
pub(crate) span: SourceSpan,
pub(crate) algo_impl: Rc<Box<dyn AlgoImpl>>,
pub(crate) algo_impl: Arc<Box<dyn AlgoImpl>>,
}
impl AlgoApply {
@ -232,7 +233,7 @@ pub(crate) struct MagicAlgoApply {
pub(crate) options: Rc<BTreeMap<SmartString<LazyCompact>, Expr>>,
pub(crate) span: SourceSpan,
pub(crate) arity: usize,
pub(crate) algo_impl: Rc<Box<dyn AlgoImpl>>,
pub(crate) algo_impl: Arc<Box<dyn AlgoImpl>>,
}
#[derive(Error, Diagnostic, Debug)]

@ -8,6 +8,7 @@
use std::cmp::{max, min};
use std::collections::BTreeMap;
use std::sync::Arc;
use miette::{bail, ensure, Diagnostic, IntoDiagnostic, Result};
use pest::error::InputLocation;
@ -20,6 +21,7 @@ use crate::data::value::DataValue;
use crate::parse::query::parse_query;
use crate::parse::schema::parse_nullable_type;
use crate::parse::sys::{parse_sys, SysOp};
use crate::AlgoImpl;
pub(crate) mod expr;
pub(crate) mod query;
@ -104,6 +106,7 @@ pub(crate) fn parse_type(src: &str) -> Result<NullableColType> {
pub(crate) fn parse_script(
src: &str,
param_pool: &BTreeMap<String, DataValue>,
algorithms: &BTreeMap<String, Arc<Box<dyn AlgoImpl>>>,
) -> Result<CozoScript> {
let parsed = CozoScriptParser::parse(Rule::script, src)
.map_err(|err| {
@ -117,19 +120,21 @@ pub(crate) fn parse_script(
.unwrap();
Ok(match parsed.as_rule() {
Rule::query_script => {
let q = parse_query(parsed.into_inner(), param_pool)?;
let q = parse_query(parsed.into_inner(), param_pool, algorithms)?;
CozoScript::Multi(vec![q])
}
Rule::multi_script => {
let mut qs = vec![];
for pair in parsed.into_inner() {
if pair.as_rule() != Rule::EOI {
qs.push(parse_query(pair.into_inner(), param_pool)?);
qs.push(parse_query(pair.into_inner(), param_pool, algorithms)?);
}
}
CozoScript::Multi(qs)
}
Rule::sys_script => CozoScript::Sys(parse_sys(parsed.into_inner(), param_pool)?),
Rule::sys_script => {
CozoScript::Sys(parse_sys(parsed.into_inner(), param_pool, algorithms)?)
}
_ => unreachable!(),
})
}

@ -11,6 +11,7 @@ use std::collections::BTreeMap;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::rc::Rc;
use std::sync::Arc;
use either::{Left, Right};
use itertools::Itertools;
@ -19,7 +20,7 @@ use smartstring::{LazyCompact, SmartString};
use thiserror::Error;
use crate::algo::constant::Constant;
use crate::algo::AlgoHandle;
use crate::algo::{AlgoHandle, AlgoNotFoundError};
use crate::data::aggr::{parse_aggr, Aggregation};
use crate::data::expr::Expr;
use crate::data::program::{
@ -34,6 +35,7 @@ use crate::parse::expr::build_expr;
use crate::parse::schema::parse_schema;
use crate::parse::{ExtractSpan, Pair, Pairs, Rule, SourceSpan};
use crate::runtime::relation::InputRelationHandle;
use crate::AlgoImpl;
#[derive(Error, Diagnostic, Debug)]
#[error("Query option {0} is not constant")]
@ -92,6 +94,7 @@ fn merge_spans(symbs: &[Symbol]) -> SourceSpan {
pub(crate) fn parse_query(
src: Pairs<'_>,
param_pool: &BTreeMap<String, DataValue>,
algorithms: &BTreeMap<String, Arc<Box<dyn AlgoImpl>>>,
) -> Result<InputProgram> {
let mut progs: BTreeMap<Symbol, InputInlineRulesOrAlgo> = Default::default();
let mut out_opts: QueryOutOptions = Default::default();
@ -144,7 +147,7 @@ pub(crate) fn parse_query(
}
Rule::algo_rule => {
let rule_span = pair.extract_span();
let (name, apply) = parse_algo_rule(pair, param_pool)?;
let (name, apply) = parse_algo_rule(pair, param_pool, algorithms)?;
match progs.entry(name) {
Entry::Vacant(e) => {
@ -199,7 +202,7 @@ pub(crate) fn parse_query(
let handle = AlgoHandle {
name: Symbol::new("Constant", span),
};
let algo_impl = handle.get_impl()?;
let algo_impl = Box::new(Constant);
algo_impl.init_options(&mut options, span)?;
let arity = algo_impl.arity(&options, &head, span)?;
@ -218,7 +221,7 @@ pub(crate) fn parse_query(
head,
arity,
span,
algo_impl: Rc::new(algo_impl),
algo_impl: Arc::new(algo_impl),
},
},
);
@ -237,7 +240,8 @@ pub(crate) fn parse_query(
Rule::sleep_option => {
#[cfg(feature = "wasm")]
bail!(":sleep is not supported under WASM");
#[cfg(not(feature = "wasm"))] {
#[cfg(not(feature = "wasm"))]
{
let pair = pair.into_inner().next().unwrap();
let span = pair.extract_span();
let sleep = build_expr(pair, param_pool)?
@ -651,6 +655,7 @@ fn parse_rule_head_arg(
fn parse_algo_rule(
src: Pair<'_>,
param_pool: &BTreeMap<String, DataValue>,
algorithms: &BTreeMap<String, Arc<Box<dyn AlgoImpl>>>,
) -> Result<(Symbol, AlgoApply)> {
let mut src = src.into_inner();
let (out_symbol, head, aggr) = parse_rule_head(src.next().unwrap(), param_pool)?;
@ -745,7 +750,8 @@ fn parse_algo_rule(
let algo = AlgoHandle::new(algo_name, name_pair.extract_span());
let algo_impl = algo.get_impl()?;
let algo_impl = algorithms.get(&algo.name as &str)
.ok_or_else(|| AlgoNotFoundError(algo.name.to_string(), name_pair.extract_span()))?;
algo_impl.init_options(&mut options, args_list_span)?;
let arity = algo_impl.arity(&options, &head, name_pair.extract_span())?;
@ -763,7 +769,7 @@ fn parse_algo_rule(
head,
arity,
span: args_list_span,
algo_impl: Rc::new(algo_impl),
algo_impl: algo_impl.clone(),
},
))
}
@ -801,7 +807,7 @@ fn make_empty_const_rule(prog: &mut InputProgram, bindings: &[Symbol]) {
head: bindings.to_vec(),
arity: bindings.len(),
span: Default::default(),
algo_impl: Rc::new(Box::new(Constant)),
algo_impl: Arc::new(Box::new(Constant)),
},
},
);

@ -7,6 +7,7 @@
*/
use std::collections::BTreeMap;
use std::sync::Arc;
use itertools::Itertools;
use miette::{Diagnostic, Result};
@ -18,6 +19,7 @@ use crate::data::value::DataValue;
use crate::parse::query::parse_query;
use crate::parse::{ExtractSpan, Pairs, Rule, SourceSpan};
use crate::runtime::relation::AccessLevel;
use crate::AlgoImpl;
pub(crate) enum SysOp {
Compact,
@ -41,6 +43,7 @@ struct ProcessIdError(String, #[label] SourceSpan);
pub(crate) fn parse_sys(
mut src: Pairs<'_>,
param_pool: &BTreeMap<String, DataValue>,
algorithms: &BTreeMap<String, Arc<Box<dyn AlgoImpl>>>,
) -> Result<SysOp> {
let inner = src.next().unwrap();
Ok(match inner.as_rule() {
@ -55,7 +58,11 @@ pub(crate) fn parse_sys(
SysOp::KillRunning(i)
}
Rule::explain_op => {
let prog = parse_query(inner.into_inner().next().unwrap().into_inner(), param_pool)?;
let prog = parse_query(
inner.into_inner().next().unwrap().into_inner(),
param_pool,
algorithms,
)?;
SysOp::Explain(Box::new(prog))
}
Rule::list_relations_op => SysOp::ListRelations,
@ -93,7 +100,7 @@ pub(crate) fn parse_sys(
"protected" => AccessLevel::Protected,
"read_only" => AccessLevel::ReadOnly,
"hidden" => AccessLevel::Hidden,
_ => unreachable!()
_ => unreachable!(),
};
let mut rels = vec![];
for rel_p in ps {
@ -119,7 +126,7 @@ pub(crate) fn parse_sys(
let op = clause_inner.next().unwrap();
let script = clause_inner.next().unwrap();
let script_str = script.as_str();
parse_query(script.into_inner(), &Default::default())?;
parse_query(script.into_inner(), &Default::default(), algorithms)?;
match op.as_rule() {
Rule::trigger_put => puts.push(script_str.to_string()),
Rule::trigger_rm => rms.push(script_str.to_string()),

@ -8,6 +8,7 @@
use std::collections::BTreeMap;
use std::rc::Rc;
use std::sync::Arc;
use itertools::Itertools;
use miette::{bail, Diagnostic, Result, WrapErr};
@ -42,6 +43,8 @@ impl<'a> SessionTx<'a> {
meta: &InputRelationHandle,
headers: &[Symbol],
) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
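// TODO: pass the database's registered algorithms (`db.algorithms`, as the
// put/rm trigger paths do) instead of an empty map; with an empty map, any
// algorithm rule inside a replace trigger fails lookup at parse time.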
let algorithms = BTreeMap::new();
let mut to_clear = vec![];
let mut replaced_old_triggers = None;
if op == RelationOp::Replace {
@ -57,8 +60,8 @@ impl<'a> SessionTx<'a> {
replaced_old_triggers = Some((old_handle.put_triggers, old_handle.rm_triggers))
}
for trigger in &old_handle.replace_triggers {
let program =
parse_script(trigger, &Default::default())?.get_single_program()?;
let program = parse_script(trigger, &Default::default(), &algorithms)?
.get_single_program()?;
let (_, cleanups) = db.run_query(self, program).map_err(|err| {
if err.source_code().is_some() {
@ -137,7 +140,8 @@ impl<'a> SessionTx<'a> {
if has_triggers && !new_tuples.is_empty() {
for trigger in &relation_store.rm_triggers {
let mut program =
parse_script(trigger, &Default::default())?.get_single_program()?;
parse_script(trigger, &Default::default(), &db.algorithms)?
.get_single_program()?;
let mut bindings = relation_store
.metadata
@ -312,7 +316,8 @@ impl<'a> SessionTx<'a> {
if has_triggers && !new_tuples.is_empty() {
for trigger in &relation_store.put_triggers {
let mut program =
parse_script(trigger, &Default::default())?.get_single_program()?;
parse_script(trigger, &Default::default(), &db.algorithms)?
.get_single_program()?;
let mut bindings = relation_store
.metadata
@ -442,7 +447,7 @@ fn make_const_rule(
head: bindings,
arity: bindings_arity,
span: Default::default(),
algo_impl: Rc::new(Box::new(Constant)),
algo_impl: Arc::new(Box::new(Constant)),
},
},
);

@ -29,7 +29,6 @@ use crate::data::program::{InputProgram, QueryAssertion, RelationOp};
use crate::data::relation::ColumnDef;
use crate::data::tuple::{Tuple, TupleT};
use crate::data::value::{DataValue, LARGEST_UTF_CHAR};
use crate::decode_tuple_from_kv;
use crate::parse::sys::SysOp;
use crate::parse::{parse_script, CozoScript, SourceSpan};
use crate::query::compile::{CompiledProgram, CompiledRule, CompiledRuleSet};
@ -39,6 +38,8 @@ use crate::query::ra::{
use crate::runtime::relation::{AccessLevel, InsufficientAccessLevel, RelationHandle, RelationId};
use crate::runtime::transact::SessionTx;
use crate::storage::{Storage, StoreTx};
use crate::{decode_tuple_from_kv, AlgoImpl};
use crate::algo::DEFAULT_ALGOS;
struct RunningQueryHandle {
started_at: f64,
@ -71,6 +72,7 @@ pub struct Db<S> {
relation_store_id: Arc<AtomicU64>,
queries_count: Arc<AtomicU64>,
running_queries: Arc<Mutex<BTreeMap<u64, RunningQueryHandle>>>,
pub(crate) algorithms: Arc<BTreeMap<String, Arc<Box<dyn AlgoImpl>>>>,
}
impl<S> Debug for Db<S> {
@ -116,6 +118,7 @@ impl<'s, S: Storage<'s>> Db<S> {
relation_store_id: Arc::new(Default::default()),
queries_count: Arc::new(Default::default()),
running_queries: Arc::new(Mutex::new(Default::default())),
algorithms: DEFAULT_ALGOS.clone(),
};
Ok(ret)
}
@ -426,7 +429,7 @@ impl<'s, S: Storage<'s>> Db<S> {
payload: &str,
param_pool: &BTreeMap<String, DataValue>,
) -> Result<NamedRows> {
match parse_script(payload, param_pool)? {
match parse_script(payload, param_pool, &self.algorithms)? {
CozoScript::Multi(ps) => {
let is_write = ps.iter().any(|p| p.out_opts.store_relation.is_some());
let mut cleanups = vec![];
@ -948,13 +951,7 @@ impl<'s, S: Storage<'s>> Db<S> {
if let Some((meta, relation_op)) = &out_opts.store_relation {
let to_clear = tx
.execute_relation(
self,
scan,
*relation_op,
meta,
&entry_head_or_default,
)
.execute_relation(self, scan, *relation_op, meta, &entry_head_or_default)
.wrap_err_with(|| format!("when executing against relation '{}'", meta.name))?;
clean_ups.extend(to_clear);
Ok((

Loading…
Cancel
Save