From 76fa6b70fc81db2af92dca5796b40dcda3086ec5 Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Wed, 14 Jun 2023 05:15:13 +0800 Subject: [PATCH] restore server features --- cozo-bin/src/server.rs | 425 +++++++++++++++++++++-------------------- 1 file changed, 215 insertions(+), 210 deletions(-) diff --git a/cozo-bin/src/server.rs b/cozo-bin/src/server.rs index 7cf67998..24859040 100644 --- a/cozo-bin/src/server.rs +++ b/cozo-bin/src/server.rs @@ -11,11 +11,13 @@ use std::convert::Infallible; use std::net::{Ipv6Addr, SocketAddr}; use std::str::FromStr; // use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::thread; // use std::thread; use axum::body::{boxed, Body, BoxBody}; -use axum::extract::{DefaultBodyLimit, Path, State}; +use axum::extract::{DefaultBodyLimit, Path, Query, State}; use axum::http::{header, HeaderName, Method, Request, Response, StatusCode}; use axum::response::sse::{Event, KeepAlive}; use axum::response::{Html, Sse}; @@ -26,6 +28,7 @@ use futures::future::BoxFuture; use futures::stream::Stream; use itertools::Itertools; use log::{error, info, warn}; +use miette::miette; // use miette::miette; use rand::Rng; use serde_json::json; @@ -34,7 +37,7 @@ use tower_http::auth::{AsyncAuthorizeRequest, AsyncRequireAuthorizationLayer}; use tower_http::compression::CompressionLayer; use tower_http::cors::{Any, CorsLayer}; -use cozo::{DataValue, DbInstance, NamedRows, ScriptMutability}; +use cozo::{DataValue, DbInstance, format_error_as_json, MultiTransaction, NamedRows, ScriptMutability, SimpleFixedRule}; #[derive(Args, Debug)] pub(crate) struct ServerArgs { @@ -70,10 +73,10 @@ pub(crate) struct ServerArgs { #[derive(Clone)] struct DbState { db: DbInstance, - // rule_senders: Arc>>>>, - // rule_counter: Arc, - // tx_counter: Arc, - // txs: Arc>>>, + rule_senders: Arc>>>>, + rule_counter: Arc, + tx_counter: Arc, + txs: Arc>>>, } #[derive(Clone)] @@ -84,14 +87,14 @@ struct MyAuth { } impl AsyncAuthorizeRequest for MyAuth -where - B: Send + Sync + 'static, + where + B: Send + Sync + 'static, { type RequestBody = B; type ResponseBody = BoxBody; type Future = BoxFuture<'static, Result, Response>>; - fn authorize<'a>(&'a mut self, mut request: Request) -> Self::Future { + fn authorize(&mut self, mut request: Request) -> Self::Future { let skip_auth = self.skip_auth; let auth_guard = self.auth_guard.clone(); let token_table = self.token_table.clone(); @@ -127,8 +130,7 @@ where let (name, db) = tt.as_ref(); if let Some(auth_header) = request.headers().get("Authorization") { if let Ok(auth_str) = auth_header.to_str() { - if auth_str.starts_with("Bearer ") { - let token = &auth_str["Bearer ".len()..]; + if let Some(token) = auth_str.strip_prefix("Bearer ") { match db.run_script( &format!("?[mutable] := *{name} {{ token: $token, mutable }}"), BTreeMap::from([(String::from("token"), DataValue::from(token))]), @@ -147,7 +149,7 @@ where Err(err) => { eprintln!("Error: {}", err); None - }, + } } } else { None @@ -232,10 +234,10 @@ pub(crate) async fn server_main(args: ServerArgs) { let state = DbState { db, - // rule_senders: Default::default(), - // rule_counter: Default::default(), - // tx_counter: Default::default(), - // txs: Default::default(), + rule_senders: Default::default(), + rule_counter: Default::default(), + tx_counter: Default::default(), + txs: Default::default(), }; let cors = CorsLayer::new() .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]) @@ -249,13 +251,13 @@ pub(crate) async fn server_main(args: ServerArgs) { .route("/backup", post(backup)) .route("/import-from-backup", post(import_from_backup)) .route("/changes/:relation", get(observe_changes)) - // .route("/rules/:name", get(register_rule)) - // .route( - // "/rule-result/:id", - // post(post_rule_result).delete(post_rule_err), - // ) // +keep alive - // .route("/transact", post(start_transact)) - // .route("/transact/:id", post(transact_query).put(finish_query)) + .route("/rules/:name", get(register_rule)) + .route( + "/rule-result/:id", + post(post_rule_result).delete(post_rule_err), + ) // +keep alive + .route("/transact", post(start_transact)) + .route("/transact/:id", post(transact_query).put(finish_query)) .with_state(state) .layer(AsyncRequireAuthorizationLayer::new(auth_obj)) .fallback(not_found) @@ -288,76 +290,76 @@ pub(crate) async fn server_main(args: ServerArgs) { #[derive(serde_derive::Deserialize)] struct StartTransactPayload { - // write: bool, + write: bool, } -// async fn start_transact( -// State(st): State, -// Query(payload): Query, -// ) -> (StatusCode, Json) { -// let tx = st.db.multi_transaction(payload.write); -// let id = st.tx_counter.fetch_add(1, Ordering::SeqCst); -// st.txs.lock().unwrap().insert(id, Arc::new(tx)); -// (StatusCode::OK, json!({"ok": true, "id": id}).into()) -// } - -// async fn transact_query( -// State(st): State, -// Path(id): Path, -// Json(payload): Json, -// ) -> (StatusCode, Json) { -// let tx = match st.txs.lock().unwrap().get(&id) { -// None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()), -// Some(tx) => tx.clone(), -// }; -// let src = payload.script.clone(); -// let result = spawn_blocking(move || { -// let params = payload -// .params -// .into_iter() -// .map(|(k, v)| (k, DataValue::from(v))) -// .collect(); -// let query = payload.script; -// tx.run_script(&query, params) -// }) -// .await; -// match result { -// Ok(Ok(res)) => (StatusCode::OK, res.into_json().into()), -// Ok(Err(err)) => ( -// StatusCode::BAD_REQUEST, -// format_error_as_json(err, Some(&src)).into(), -// ), -// Err(err) => internal_error(err), -// } -// } - -// #[derive(serde_derive::Deserialize)] -// struct FinishTransactPayload { -// abort: bool, -// } - -// async fn finish_query( -// State(st): State, -// Path(id): Path, -// Json(payload): Json, -// ) -> (StatusCode, Json) { -// let tx = match st.txs.lock().unwrap().remove(&id) { -// None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()), -// Some(tx) => tx, -// }; -// let res = if payload.abort { -// tx.abort() -// } else { -// tx.commit() -// }; -// match res { -// Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), -// Err(err) => ( -// StatusCode::BAD_REQUEST, -// json!({"ok": false, "message": err.to_string()}).into(), -// ), -// } -// } +async fn start_transact( + State(st): State, + Query(payload): Query, +) -> (StatusCode, Json) { + let tx = st.db.multi_transaction(payload.write); + let id = st.tx_counter.fetch_add(1, Ordering::SeqCst); + st.txs.lock().unwrap().insert(id, Arc::new(tx)); + (StatusCode::OK, json!({"ok": true, "id": id}).into()) +} + +async fn transact_query( + State(st): State, + Path(id): Path, + Json(payload): Json, +) -> (StatusCode, Json) { + let tx = match st.txs.lock().unwrap().get(&id) { + None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()), + Some(tx) => tx.clone(), + }; + let src = payload.script.clone(); + let result = spawn_blocking(move || { + let params = payload + .params + .into_iter() + .map(|(k, v)| (k, DataValue::from(v))) + .collect(); + let query = payload.script; + tx.run_script(&query, params) + }) + .await; + match result { + Ok(Ok(res)) => (StatusCode::OK, res.into_json().into()), + Ok(Err(err)) => ( + StatusCode::BAD_REQUEST, + format_error_as_json(err, Some(&src)).into(), + ), + Err(err) => internal_error(err), + } +} + +#[derive(serde_derive::Deserialize)] +struct FinishTransactPayload { + abort: bool, +} + +async fn finish_query( + State(st): State, + Path(id): Path, + Json(payload): Json, +) -> (StatusCode, Json) { + let tx = match st.txs.lock().unwrap().remove(&id) { + None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()), + Some(tx) => tx, + }; + let res = if payload.abort { + tx.abort() + } else { + tx.commit() + }; + match res { + Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), + Err(err) => ( + StatusCode::BAD_REQUEST, + json!({"ok": false, "message": err.to_string()}).into(), + ), + } +} #[derive(serde_derive::Deserialize)] struct QueryPayload { @@ -391,7 +393,7 @@ async fn text_query( }, ) }) - .await; + .await; match result { Ok(res) => wrap_json(res), Err(err) => internal_error(err), @@ -436,7 +438,7 @@ async fn import_relations( return ( StatusCode::BAD_REQUEST, json!({"ok": false, "message": "payload must be a JSON object"}).into(), - ) + ); } Some(pl) => { let mut ret = BTreeMap::new(); @@ -447,7 +449,7 @@ async fn import_relations( return ( StatusCode::BAD_REQUEST, json!({"ok": false, "message": err.to_string()}).into(), - ) + ); } }; ret.insert(k.to_string(), nr); @@ -466,6 +468,7 @@ async fn import_relations( Err(err) => internal_error(err), } } + #[derive(serde_derive::Deserialize)] struct BackupPayload { path: String, @@ -489,11 +492,13 @@ async fn backup( Err(err) => internal_error(err), } } + #[derive(serde_derive::Deserialize)] struct BackupImportPayload { path: String, relations: Vec, } + async fn import_from_backup( State(st): State, Json(payload): Json, @@ -514,123 +519,123 @@ async fn import_from_backup( } } -// #[derive(serde_derive::Deserialize)] -// struct RuleRegisterOptions { -// arity: usize, -// } - -// async fn post_rule_result( -// State(st): State, -// Path(id): Path, -// Json(res): Json, -// ) -> (StatusCode, Json) { -// let res = match NamedRows::from_json(&res) { -// Ok(res) => res, -// Err(err) => { -// if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { -// let _ = ch.send(Err(miette!("downstream posted malformed result"))); -// } -// return ( -// StatusCode::BAD_REQUEST, -// json!({"ok": false, "message": err.to_string()}).into(), -// ); -// } -// }; -// if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { -// match ch.send(Ok(res)) { -// Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), -// Err(err) => ( -// StatusCode::INTERNAL_SERVER_ERROR, -// json!({"ok": false, "message": err.to_string()}).into(), -// ), -// } -// } else { -// (StatusCode::NOT_FOUND, json!({"ok": false}).into()) -// } -// } - -// async fn post_rule_err( -// State(st): State, -// Path(id): Path, -// ) -> (StatusCode, Json) { -// if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { -// match ch.send(Err(miette!("downstream cancelled computation"))) { -// Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), -// Err(err) => ( -// StatusCode::INTERNAL_SERVER_ERROR, -// json!({"ok": false, "message": err.to_string()}).into(), -// ), -// } -// } else { -// (StatusCode::NOT_FOUND, json!({"ok": false}).into()) -// } -// } - -// async fn register_rule( -// State(st): State, -// Path(name): Path, -// Query(rule_opts): Query, -// ) -> Sse>> { -// let (rule, task_receiver) = SimpleFixedRule::rule_with_channel(rule_opts.arity); -// let (down_sender, mut down_receiver) = tokio::sync::mpsc::channel(1); -// let mut errored = None; -// -// if let Err(err) = st.db.register_fixed_rule(name.clone(), rule) { -// errored = Some(err); -// } else { -// let rule_senders = st.rule_senders.clone(); -// let rule_counter = st.rule_counter.clone(); -// -// thread::spawn(move || { -// for (inputs, options, sender) in task_receiver { -// let id = rule_counter.fetch_add(1, Ordering::AcqRel); -// let inputs: serde_json::Value = -// inputs.into_iter().map(|ip| ip.into_json()).collect(); -// let options: serde_json::Value = options -// .into_iter() -// .map(|(k, v)| (k, serde_json::Value::from(v))) -// .collect(); -// if down_sender.blocking_send((id, inputs, options)).is_err() { -// let _ = sender.send(Err(miette!("cannot send request to downstream"))); -// } else { -// rule_senders.lock().unwrap().insert(id, sender); -// } -// } -// }); -// } -// -// struct Guard { -// name: String, -// db: DbInstance, -// } -// -// impl Drop for Guard { -// fn drop(&mut self) { -// info!("dropping rules SSE {}", self.name); -// let _ = self.db.unregister_fixed_rule(&self.name); -// } -// } -// -// let stream = async_stream::stream! { -// if let Some(err) = errored { -// let item = json!({"type": "register-error", "error": err.to_string()}); -// yield Ok(Event::default().json_data(item).unwrap()); -// } else { -// info!("starting rule SSE {}", name); -// let _guard = Guard {db: st.db, name}; -// while let Some((id, inputs, options)) = down_receiver.recv().await { -// let item = json!({"type": "request", "id": id, "inputs": inputs, "options": options}); -// yield Ok(Event::default().json_data(item).unwrap()); -// } -// } -// }; -// Sse::new(stream).keep_alive(KeepAlive::default()) -// } +#[derive(serde_derive::Deserialize)] +struct RuleRegisterOptions { + arity: usize, +} + +async fn post_rule_result( + State(st): State, + Path(id): Path, + Json(res): Json, +) -> (StatusCode, Json) { + let res = match NamedRows::from_json(&res) { + Ok(res) => res, + Err(err) => { + if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { + let _ = ch.send(Err(miette!("downstream posted malformed result"))); + } + return ( + StatusCode::BAD_REQUEST, + json!({"ok": false, "message": err.to_string()}).into(), + ); + } + }; + if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { + match ch.send(Ok(res)) { + Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + json!({"ok": false, "message": err.to_string()}).into(), + ), + } + } else { + (StatusCode::NOT_FOUND, json!({"ok": false}).into()) + } +} + +async fn post_rule_err( + State(st): State, + Path(id): Path, +) -> (StatusCode, Json) { + if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { + match ch.send(Err(miette!("downstream cancelled computation"))) { + Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + json!({"ok": false, "message": err.to_string()}).into(), + ), + } + } else { + (StatusCode::NOT_FOUND, json!({"ok": false}).into()) + } +} + +async fn register_rule( + State(st): State, + Path(name): Path, + Query(rule_opts): Query, +) -> Sse>> { + let (rule, task_receiver) = SimpleFixedRule::rule_with_channel(rule_opts.arity); + let (down_sender, mut down_receiver) = tokio::sync::mpsc::channel(1); + let mut errored = None; + + if let Err(err) = st.db.register_fixed_rule(name.clone(), rule) { + errored = Some(err); + } else { + let rule_senders = st.rule_senders.clone(); + let rule_counter = st.rule_counter.clone(); + + thread::spawn(move || { + for (inputs, options, sender) in task_receiver { + let id = rule_counter.fetch_add(1, Ordering::AcqRel); + let inputs: serde_json::Value = + inputs.into_iter().map(|ip| ip.into_json()).collect(); + let options: serde_json::Value = options + .into_iter() + .map(|(k, v)| (k, serde_json::Value::from(v))) + .collect(); + if down_sender.blocking_send((id, inputs, options)).is_err() { + let _ = sender.send(Err(miette!("cannot send request to downstream"))); + } else { + rule_senders.lock().unwrap().insert(id, sender); + } + } + }); + } + + struct Guard { + name: String, + db: DbInstance, + } + + impl Drop for Guard { + fn drop(&mut self) { + info!("dropping rules SSE {}", self.name); + let _ = self.db.unregister_fixed_rule(&self.name); + } + } + + let stream = async_stream::stream! { + if let Some(err) = errored { + let item = json!({"type": "register-error", "error": err.to_string()}); + yield Ok(Event::default().json_data(item).unwrap()); + } else { + info!("starting rule SSE {}", name); + let _guard = Guard {db: st.db, name}; + while let Some((id, inputs, options)) = down_receiver.recv().await { + let item = json!({"type": "request", "id": id, "inputs": inputs, "options": options}); + yield Ok(Event::default().json_data(item).unwrap()); + } + } + }; + Sse::new(stream).keep_alive(KeepAlive::default()) +} async fn observe_changes( State(st): State, Path(relation): Path, -) -> Sse>> { +) -> Sse>> { let (id, recv) = st.db.register_callback(&relation, None); let (sender, mut receiver) = tokio::sync::mpsc::channel(1); struct Guard { @@ -667,8 +672,8 @@ async fn root() -> Html<&'static str> { } fn internal_error(err: E) -> (StatusCode, Json) -where - E: std::error::Error, + where + E: std::error::Error, { ( StatusCode::INTERNAL_SERVER_ERROR,