diff --git a/cozo-bin/src/server.rs b/cozo-bin/src/server.rs index af1ff5cb..21800066 100644 --- a/cozo-bin/src/server.rs +++ b/cozo-bin/src/server.rs @@ -10,7 +10,9 @@ use std::collections::BTreeMap; use std::convert::Infallible; use std::net::{Ipv6Addr, SocketAddr}; use std::str::FromStr; -use std::sync::{mpsc, Arc, Mutex}; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::{Arc, Mutex}; +use std::thread; use axum::extract::{Path, Query, State}; use axum::http::StatusCode; @@ -19,13 +21,16 @@ use axum::response::{Html, Sse}; use axum::routing::{get, post, put}; use axum::{Json, Router}; use clap::Args; -use futures::stream::{self, Stream}; +use futures::stream::Stream; use itertools::Itertools; use log::{info, warn}; +use miette::miette; use serde_json::json; use tokio::task::spawn_blocking; -use cozo::{DataValue, DbInstance, NamedRows}; +use cozo::{ + format_error_as_json, DataValue, DbInstance, MultiTransaction, NamedRows, SimpleFixedRule, +}; #[derive(Args, Debug)] pub(crate) struct ServerArgs { @@ -56,13 +61,24 @@ pub(crate) struct ServerArgs { port: u16, } -type RuleCallbackStore = BTreeMap>>; -type DbState = (DbInstance, Arc>); +#[derive(Clone)] +struct DbState { + db: DbInstance, + rule_senders: Arc>>>>, + rule_counter: Arc, + tx_counter: Arc, + txs: Arc>>>, +} pub(crate) async fn server_main(args: ServerArgs) { let db = DbInstance::new(&args.engine, args.path, &args.config).unwrap(); - let rule_channels: Arc> = Default::default(); - let state = (db, rule_channels); + let state = DbState { + db, + rule_senders: Default::default(), + rule_counter: Default::default(), + tx_counter: Default::default(), + txs: Default::default(), + }; let app = Router::new() .fallback(not_found) .route("/", get(root)) @@ -72,8 +88,13 @@ pub(crate) async fn server_main(args: ServerArgs) { .route("/backup", post(backup)) .route("/import-from-backup", post(import_from_backup)) .route("/changes/:relation", get(observe_changes)) - .route("/rules/:name", get(register_rule)) // sse + post - .route("/rules/:name/:id", post(rule_result)) + .route("/rules/:name", get(register_rule)) + .route( + "/rule-result/:id", + post(post_rule_result).delete(post_rule_err), + ) // +keep alive + .route("/transact", post(start_transact)) + .route("/transact/:id", post(transact_query).put(finish_query)) // +keep alive .with_state(state); let addr = if Ipv6Addr::from_str(&args.bind).is_ok() { SocketAddr::from_str(&format!("[{}]:{}", args.bind, args.port)).unwrap() @@ -96,6 +117,79 @@ pub(crate) async fn server_main(args: ServerArgs) { .unwrap(); } +#[derive(serde_derive::Deserialize)] +struct StartTransactPayload { + write: bool, +} + +async fn start_transact( + State(st): State, + Query(payload): Query, +) -> (StatusCode, Json) { + let tx = st.db.multi_transaction(payload.write); + let id = st.tx_counter.fetch_add(1, Ordering::SeqCst); + st.txs.lock().unwrap().insert(id, Arc::new(tx)); + (StatusCode::OK, json!({"ok": true, "id": id}).into()) +} + +async fn transact_query( + State(st): State, + Path(id): Path, + Json(payload): Json, +) -> (StatusCode, Json) { + let tx = match st.txs.lock().unwrap().get(&id) { + None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()), + Some(tx) => tx.clone(), + }; + let src = payload.script.clone(); + let result = spawn_blocking(move || { + let params = payload + .params + .into_iter() + .map(|(k, v)| (k, DataValue::from(v))) + .collect(); + let query = payload.script; + tx.run_script(&query, params) + }) + .await; + match result { + Ok(Ok(res)) => (StatusCode::OK, res.into_json().into()), + Ok(Err(err)) => ( + StatusCode::BAD_REQUEST, + format_error_as_json(err, Some(&src)).into(), + ), + Err(err) => internal_error(err), + } +} + +#[derive(serde_derive::Deserialize)] +struct FinishTransactPayload { + abort: bool, +} + +async fn finish_query( + State(st): State, + Path(id): Path, + Json(payload): Json, +) -> (StatusCode, Json) { + let tx = match st.txs.lock().unwrap().remove(&id) { + None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()), + Some(tx) => tx, + }; + let res = if payload.abort { + tx.abort() + } else { + tx.commit() + }; + match res { + Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), + Err(err) => ( + StatusCode::BAD_REQUEST, + json!({"ok": false, "message": err.to_string()}).into(), + ), + } +} + #[derive(serde_derive::Deserialize)] struct QueryPayload { script: String, @@ -103,7 +197,7 @@ struct QueryPayload { } async fn text_query( - State((db, _)): State, + State(st): State, Json(payload): Json, ) -> (StatusCode, Json) { let params = payload @@ -111,7 +205,7 @@ async fn text_query( .into_iter() .map(|(k, v)| (k, DataValue::from(v))) .collect(); - let result = spawn_blocking(move || db.run_script_fold_err(&payload.script, params)).await; + let result = spawn_blocking(move || st.db.run_script_fold_err(&payload.script, params)).await; match result { Ok(res) => wrap_json(res), Err(err) => internal_error(err), @@ -119,7 +213,7 @@ async fn text_query( } async fn export_relations( - State((db, _)): State, + State(st): State, Path(relations): Path, ) -> (StatusCode, Json) { let relations = relations @@ -132,7 +226,7 @@ async fn export_relations( } }) .collect_vec(); - let result = spawn_blocking(move || db.export_relations(relations.iter())).await; + let result = spawn_blocking(move || st.db.export_relations(relations.iter())).await; match result { Ok(Ok(s)) => { let ret = json!({"ok": true, "data": s}); @@ -147,7 +241,7 @@ async fn export_relations( } async fn import_relations( - State((db, _)): State, + State(st): State, Json(payload): Json, ) -> (StatusCode, Json) { let payload = match payload.as_object() { @@ -175,7 +269,7 @@ async fn import_relations( } }; - let result = spawn_blocking(move || db.import_relations(payload)).await; + let result = spawn_blocking(move || st.db.import_relations(payload)).await; match result { Ok(Ok(_)) => (StatusCode::OK, json!({"ok": true}).into()), Ok(Err(err)) => { @@ -191,10 +285,10 @@ struct BackupPayload { } async fn backup( - State((db, _)): State, + State(st): State, Json(payload): Json, ) -> (StatusCode, Json) { - let result = spawn_blocking(move || db.backup_db(payload.path)).await; + let result = spawn_blocking(move || st.db.backup_db(payload.path)).await; match result { Ok(Ok(())) => { @@ -214,11 +308,11 @@ struct BackupImportPayload { relations: Vec, } async fn import_from_backup( - State((db, _)): State, + State(st): State, Json(payload): Json, ) -> (StatusCode, Json) { let result = - spawn_blocking(move || db.import_from_backup(&payload.path, &payload.relations)).await; + spawn_blocking(move || st.db.import_from_backup(&payload.path, &payload.relations)).await; match result { Ok(Ok(())) => { @@ -238,55 +332,119 @@ struct RuleRegisterOptions { arity: usize, } -async fn rule_result( - State((store, _)): State, - Path(name): Path, - Path(id): Path, +async fn post_rule_result( + State(st): State, + Path(id): Path, + Json(res): Json, ) -> (StatusCode, Json) { - todo!() + let res = match NamedRows::from_json(&res) { + Ok(res) => res, + Err(err) => { + if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { + let _ = ch.send(Err(miette!("downstream posted malformed result"))); + } + return ( + StatusCode::BAD_REQUEST, + json!({"ok": false, "message": err.to_string()}).into(), + ); + } + }; + if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { + match ch.send(Ok(res)) { + Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + json!({"ok": false, "message": err.to_string()}).into(), + ), + } + } else { + (StatusCode::NOT_FOUND, json!({"ok": false}).into()) + } +} + +async fn post_rule_err( + State(st): State, + Path(id): Path, +) -> (StatusCode, Json) { + if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { + match ch.send(Err(miette!("downstream cancelled computation"))) { + Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + json!({"ok": false, "message": err.to_string()}).into(), + ), + } + } else { + (StatusCode::NOT_FOUND, json!({"ok": false}).into()) + } } async fn register_rule( - State((db, cbs)): State, + State(st): State, Path(name): Path, Query(rule_opts): Query, ) -> Sse>> { - let (id, recv) = db.register_callback(&name, None); - let (sender, mut receiver) = tokio::sync::mpsc::channel(1); + let (rule, task_receiver) = SimpleFixedRule::rule_with_channel(rule_opts.arity); + let (down_sender, mut down_receiver) = tokio::sync::mpsc::channel(1); + let mut errored = None; + + if let Err(err) = st.db.register_fixed_rule(name.clone(), rule) { + errored = Some(err); + } else { + let rule_senders = st.rule_senders.clone(); + let rule_counter = st.rule_counter.clone(); + + thread::spawn(move || { + for (inputs, options, sender) in task_receiver { + let id = rule_counter.fetch_add(1, Ordering::AcqRel); + let inputs: serde_json::Value = + inputs.into_iter().map(|ip| ip.into_json()).collect(); + let options: serde_json::Value = options + .into_iter() + .map(|(k, v)| (k, serde_json::Value::from(v))) + .collect(); + if down_sender.blocking_send((id, inputs, options)).is_err() { + let _ = sender.send(Err(miette!("cannot send request to downstream"))); + } else { + rule_senders.lock().unwrap().insert(id, sender); + } + } + }); + } + struct Guard { - id: u32, + name: String, db: DbInstance, - relation: String, } impl Drop for Guard { fn drop(&mut self) { - info!("dropping changes SSE {}: {}", self.relation, self.id); - self.db.unregister_callback(self.id); + info!("dropping rules SSE {}", self.name); + let _ = self.db.unregister_fixed_rule(&self.name); } } - spawn_blocking(move || { - for data in recv { - sender.blocking_send(data).unwrap(); - } - }); let stream = async_stream::stream! { - info!("starting callback SSE {}: {}", name, id); - let _guard = Guard {id, db, relation: name}; - while let Some((op, new, old)) = receiver.recv().await { - let item = json!({"op": op.to_string(), "new_rows": new.into_json(), "old_rows": old.into_json()}); + if let Some(err) = errored { + let item = json!({"type": "register-error", "error": err.to_string()}); yield Ok(Event::default().json_data(item).unwrap()); + } else { + info!("starting rule SSE {}", name); + let _guard = Guard {db: st.db, name}; + while let Some((id, inputs, options)) = down_receiver.recv().await { + let item = json!({"type": "request", "id": id, "inputs": inputs, "options": options}); + yield Ok(Event::default().json_data(item).unwrap()); + } } }; Sse::new(stream).keep_alive(KeepAlive::default()) } async fn observe_changes( - State((db, _)): State, + State(st): State, Path(relation): Path, ) -> Sse>> { - let (id, recv) = db.register_callback(&relation, None); + let (id, recv) = st.db.register_callback(&relation, None); let (sender, mut receiver) = tokio::sync::mpsc::channel(1); struct Guard { id: u32, @@ -308,7 +466,7 @@ async fn observe_changes( }); let stream = async_stream::stream! { info!("starting changes SSE {}: {}", relation, id); - let _guard = Guard {id, db, relation}; + let _guard = Guard {id, db: st.db, relation}; while let Some((op, new, old)) = receiver.recv().await { let item = json!({"op": op.to_string(), "new_rows": new.into_json(), "old_rows": old.into_json()}); yield Ok(Event::default().json_data(item).unwrap());