restore server features

main
Ziyang Hu 1 year ago
parent 919510b0f4
commit 76fa6b70fc

@ -11,11 +11,13 @@ use std::convert::Infallible;
use std::net::{Ipv6Addr, SocketAddr}; use std::net::{Ipv6Addr, SocketAddr};
use std::str::FromStr; use std::str::FromStr;
// use std::sync::atomic::{AtomicU32, Ordering}; // use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc; use std::sync::{Arc, Mutex};
use std::sync::atomic::{AtomicU32, Ordering};
use std::thread;
// use std::thread; // use std::thread;
use axum::body::{boxed, Body, BoxBody}; use axum::body::{boxed, Body, BoxBody};
use axum::extract::{DefaultBodyLimit, Path, State}; use axum::extract::{DefaultBodyLimit, Path, Query, State};
use axum::http::{header, HeaderName, Method, Request, Response, StatusCode}; use axum::http::{header, HeaderName, Method, Request, Response, StatusCode};
use axum::response::sse::{Event, KeepAlive}; use axum::response::sse::{Event, KeepAlive};
use axum::response::{Html, Sse}; use axum::response::{Html, Sse};
@ -26,6 +28,7 @@ use futures::future::BoxFuture;
use futures::stream::Stream; use futures::stream::Stream;
use itertools::Itertools; use itertools::Itertools;
use log::{error, info, warn}; use log::{error, info, warn};
use miette::miette;
// use miette::miette; // use miette::miette;
use rand::Rng; use rand::Rng;
use serde_json::json; use serde_json::json;
@ -34,7 +37,7 @@ use tower_http::auth::{AsyncAuthorizeRequest, AsyncRequireAuthorizationLayer};
use tower_http::compression::CompressionLayer; use tower_http::compression::CompressionLayer;
use tower_http::cors::{Any, CorsLayer}; use tower_http::cors::{Any, CorsLayer};
use cozo::{DataValue, DbInstance, NamedRows, ScriptMutability}; use cozo::{DataValue, DbInstance, format_error_as_json, MultiTransaction, NamedRows, ScriptMutability, SimpleFixedRule};
#[derive(Args, Debug)] #[derive(Args, Debug)]
pub(crate) struct ServerArgs { pub(crate) struct ServerArgs {
@ -70,10 +73,10 @@ pub(crate) struct ServerArgs {
#[derive(Clone)] #[derive(Clone)]
struct DbState { struct DbState {
db: DbInstance, db: DbInstance,
// rule_senders: Arc<Mutex<BTreeMap<u32, crossbeam::channel::Sender<miette::Result<NamedRows>>>>>, rule_senders: Arc<Mutex<BTreeMap<u32, crossbeam::channel::Sender<miette::Result<NamedRows>>>>>,
// rule_counter: Arc<AtomicU32>, rule_counter: Arc<AtomicU32>,
// tx_counter: Arc<AtomicU32>, tx_counter: Arc<AtomicU32>,
// txs: Arc<Mutex<BTreeMap<u32, Arc<MultiTransaction>>>>, txs: Arc<Mutex<BTreeMap<u32, Arc<MultiTransaction>>>>,
} }
#[derive(Clone)] #[derive(Clone)]
@ -91,7 +94,7 @@ where
type ResponseBody = BoxBody; type ResponseBody = BoxBody;
type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>; type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
fn authorize<'a>(&'a mut self, mut request: Request<B>) -> Self::Future { fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
let skip_auth = self.skip_auth; let skip_auth = self.skip_auth;
let auth_guard = self.auth_guard.clone(); let auth_guard = self.auth_guard.clone();
let token_table = self.token_table.clone(); let token_table = self.token_table.clone();
@ -127,8 +130,7 @@ where
let (name, db) = tt.as_ref(); let (name, db) = tt.as_ref();
if let Some(auth_header) = request.headers().get("Authorization") { if let Some(auth_header) = request.headers().get("Authorization") {
if let Ok(auth_str) = auth_header.to_str() { if let Ok(auth_str) = auth_header.to_str() {
if auth_str.starts_with("Bearer ") { if let Some(token) = auth_str.strip_prefix("Bearer ") {
let token = &auth_str["Bearer ".len()..];
match db.run_script( match db.run_script(
&format!("?[mutable] := *{name} {{ token: $token, mutable }}"), &format!("?[mutable] := *{name} {{ token: $token, mutable }}"),
BTreeMap::from([(String::from("token"), DataValue::from(token))]), BTreeMap::from([(String::from("token"), DataValue::from(token))]),
@ -147,7 +149,7 @@ where
Err(err) => { Err(err) => {
eprintln!("Error: {}", err); eprintln!("Error: {}", err);
None None
}, }
} }
} else { } else {
None None
@ -232,10 +234,10 @@ pub(crate) async fn server_main(args: ServerArgs) {
let state = DbState { let state = DbState {
db, db,
// rule_senders: Default::default(), rule_senders: Default::default(),
// rule_counter: Default::default(), rule_counter: Default::default(),
// tx_counter: Default::default(), tx_counter: Default::default(),
// txs: Default::default(), txs: Default::default(),
}; };
let cors = CorsLayer::new() let cors = CorsLayer::new()
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]) .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
@ -249,13 +251,13 @@ pub(crate) async fn server_main(args: ServerArgs) {
.route("/backup", post(backup)) .route("/backup", post(backup))
.route("/import-from-backup", post(import_from_backup)) .route("/import-from-backup", post(import_from_backup))
.route("/changes/:relation", get(observe_changes)) .route("/changes/:relation", get(observe_changes))
// .route("/rules/:name", get(register_rule)) .route("/rules/:name", get(register_rule))
// .route( .route(
// "/rule-result/:id", "/rule-result/:id",
// post(post_rule_result).delete(post_rule_err), post(post_rule_result).delete(post_rule_err),
// ) // +keep alive ) // +keep alive
// .route("/transact", post(start_transact)) .route("/transact", post(start_transact))
// .route("/transact/:id", post(transact_query).put(finish_query)) .route("/transact/:id", post(transact_query).put(finish_query))
.with_state(state) .with_state(state)
.layer(AsyncRequireAuthorizationLayer::new(auth_obj)) .layer(AsyncRequireAuthorizationLayer::new(auth_obj))
.fallback(not_found) .fallback(not_found)
@ -288,76 +290,76 @@ pub(crate) async fn server_main(args: ServerArgs) {
#[derive(serde_derive::Deserialize)] #[derive(serde_derive::Deserialize)]
struct StartTransactPayload { struct StartTransactPayload {
// write: bool, write: bool,
} }
// async fn start_transact( async fn start_transact(
// State(st): State<DbState>, State(st): State<DbState>,
// Query(payload): Query<StartTransactPayload>, Query(payload): Query<StartTransactPayload>,
// ) -> (StatusCode, Json<serde_json::Value>) { ) -> (StatusCode, Json<serde_json::Value>) {
// let tx = st.db.multi_transaction(payload.write); let tx = st.db.multi_transaction(payload.write);
// let id = st.tx_counter.fetch_add(1, Ordering::SeqCst); let id = st.tx_counter.fetch_add(1, Ordering::SeqCst);
// st.txs.lock().unwrap().insert(id, Arc::new(tx)); st.txs.lock().unwrap().insert(id, Arc::new(tx));
// (StatusCode::OK, json!({"ok": true, "id": id}).into()) (StatusCode::OK, json!({"ok": true, "id": id}).into())
// } }
// async fn transact_query( async fn transact_query(
// State(st): State<DbState>, State(st): State<DbState>,
// Path(id): Path<u32>, Path(id): Path<u32>,
// Json(payload): Json<QueryPayload>, Json(payload): Json<QueryPayload>,
// ) -> (StatusCode, Json<serde_json::Value>) { ) -> (StatusCode, Json<serde_json::Value>) {
// let tx = match st.txs.lock().unwrap().get(&id) { let tx = match st.txs.lock().unwrap().get(&id) {
// None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()), None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
// Some(tx) => tx.clone(), Some(tx) => tx.clone(),
// }; };
// let src = payload.script.clone(); let src = payload.script.clone();
// let result = spawn_blocking(move || { let result = spawn_blocking(move || {
// let params = payload let params = payload
// .params .params
// .into_iter() .into_iter()
// .map(|(k, v)| (k, DataValue::from(v))) .map(|(k, v)| (k, DataValue::from(v)))
// .collect(); .collect();
// let query = payload.script; let query = payload.script;
// tx.run_script(&query, params) tx.run_script(&query, params)
// }) })
// .await; .await;
// match result { match result {
// Ok(Ok(res)) => (StatusCode::OK, res.into_json().into()), Ok(Ok(res)) => (StatusCode::OK, res.into_json().into()),
// Ok(Err(err)) => ( Ok(Err(err)) => (
// StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
// format_error_as_json(err, Some(&src)).into(), format_error_as_json(err, Some(&src)).into(),
// ), ),
// Err(err) => internal_error(err), Err(err) => internal_error(err),
// } }
// } }
// #[derive(serde_derive::Deserialize)] #[derive(serde_derive::Deserialize)]
// struct FinishTransactPayload { struct FinishTransactPayload {
// abort: bool, abort: bool,
// } }
// async fn finish_query( async fn finish_query(
// State(st): State<DbState>, State(st): State<DbState>,
// Path(id): Path<u32>, Path(id): Path<u32>,
// Json(payload): Json<FinishTransactPayload>, Json(payload): Json<FinishTransactPayload>,
// ) -> (StatusCode, Json<serde_json::Value>) { ) -> (StatusCode, Json<serde_json::Value>) {
// let tx = match st.txs.lock().unwrap().remove(&id) { let tx = match st.txs.lock().unwrap().remove(&id) {
// None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()), None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
// Some(tx) => tx, Some(tx) => tx,
// }; };
// let res = if payload.abort { let res = if payload.abort {
// tx.abort() tx.abort()
// } else { } else {
// tx.commit() tx.commit()
// }; };
// match res { match res {
// Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
// Err(err) => ( Err(err) => (
// StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
// json!({"ok": false, "message": err.to_string()}).into(), json!({"ok": false, "message": err.to_string()}).into(),
// ), ),
// } }
// } }
#[derive(serde_derive::Deserialize)] #[derive(serde_derive::Deserialize)]
struct QueryPayload { struct QueryPayload {
@ -436,7 +438,7 @@ async fn import_relations(
return ( return (
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
json!({"ok": false, "message": "payload must be a JSON object"}).into(), json!({"ok": false, "message": "payload must be a JSON object"}).into(),
) );
} }
Some(pl) => { Some(pl) => {
let mut ret = BTreeMap::new(); let mut ret = BTreeMap::new();
@ -447,7 +449,7 @@ async fn import_relations(
return ( return (
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
json!({"ok": false, "message": err.to_string()}).into(), json!({"ok": false, "message": err.to_string()}).into(),
) );
} }
}; };
ret.insert(k.to_string(), nr); ret.insert(k.to_string(), nr);
@ -466,6 +468,7 @@ async fn import_relations(
Err(err) => internal_error(err), Err(err) => internal_error(err),
} }
} }
#[derive(serde_derive::Deserialize)] #[derive(serde_derive::Deserialize)]
struct BackupPayload { struct BackupPayload {
path: String, path: String,
@ -489,11 +492,13 @@ async fn backup(
Err(err) => internal_error(err), Err(err) => internal_error(err),
} }
} }
#[derive(serde_derive::Deserialize)] #[derive(serde_derive::Deserialize)]
struct BackupImportPayload { struct BackupImportPayload {
path: String, path: String,
relations: Vec<String>, relations: Vec<String>,
} }
async fn import_from_backup( async fn import_from_backup(
State(st): State<DbState>, State(st): State<DbState>,
Json(payload): Json<BackupImportPayload>, Json(payload): Json<BackupImportPayload>,
@ -514,118 +519,118 @@ async fn import_from_backup(
} }
} }
// #[derive(serde_derive::Deserialize)] #[derive(serde_derive::Deserialize)]
// struct RuleRegisterOptions { struct RuleRegisterOptions {
// arity: usize, arity: usize,
// } }
// async fn post_rule_result( async fn post_rule_result(
// State(st): State<DbState>, State(st): State<DbState>,
// Path(id): Path<u32>, Path(id): Path<u32>,
// Json(res): Json<serde_json::Value>, Json(res): Json<serde_json::Value>,
// ) -> (StatusCode, Json<serde_json::Value>) { ) -> (StatusCode, Json<serde_json::Value>) {
// let res = match NamedRows::from_json(&res) { let res = match NamedRows::from_json(&res) {
// Ok(res) => res, Ok(res) => res,
// Err(err) => { Err(err) => {
// if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
// let _ = ch.send(Err(miette!("downstream posted malformed result"))); let _ = ch.send(Err(miette!("downstream posted malformed result")));
// } }
// return ( return (
// StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
// json!({"ok": false, "message": err.to_string()}).into(), json!({"ok": false, "message": err.to_string()}).into(),
// ); );
// } }
// }; };
// if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
// match ch.send(Ok(res)) { match ch.send(Ok(res)) {
// Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
// Err(err) => ( Err(err) => (
// StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
// json!({"ok": false, "message": err.to_string()}).into(), json!({"ok": false, "message": err.to_string()}).into(),
// ), ),
// } }
// } else { } else {
// (StatusCode::NOT_FOUND, json!({"ok": false}).into()) (StatusCode::NOT_FOUND, json!({"ok": false}).into())
// } }
// } }
// async fn post_rule_err( async fn post_rule_err(
// State(st): State<DbState>, State(st): State<DbState>,
// Path(id): Path<u32>, Path(id): Path<u32>,
// ) -> (StatusCode, Json<serde_json::Value>) { ) -> (StatusCode, Json<serde_json::Value>) {
// if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
// match ch.send(Err(miette!("downstream cancelled computation"))) { match ch.send(Err(miette!("downstream cancelled computation"))) {
// Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
// Err(err) => ( Err(err) => (
// StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
// json!({"ok": false, "message": err.to_string()}).into(), json!({"ok": false, "message": err.to_string()}).into(),
// ), ),
// } }
// } else { } else {
// (StatusCode::NOT_FOUND, json!({"ok": false}).into()) (StatusCode::NOT_FOUND, json!({"ok": false}).into())
// } }
// } }
// async fn register_rule( async fn register_rule(
// State(st): State<DbState>, State(st): State<DbState>,
// Path(name): Path<String>, Path(name): Path<String>,
// Query(rule_opts): Query<RuleRegisterOptions>, Query(rule_opts): Query<RuleRegisterOptions>,
// ) -> Sse<impl Stream<Item = Result<Event, Infallible>>> { ) -> Sse<impl Stream<Item=Result<Event, Infallible>>> {
// let (rule, task_receiver) = SimpleFixedRule::rule_with_channel(rule_opts.arity); let (rule, task_receiver) = SimpleFixedRule::rule_with_channel(rule_opts.arity);
// let (down_sender, mut down_receiver) = tokio::sync::mpsc::channel(1); let (down_sender, mut down_receiver) = tokio::sync::mpsc::channel(1);
// let mut errored = None; let mut errored = None;
//
// if let Err(err) = st.db.register_fixed_rule(name.clone(), rule) { if let Err(err) = st.db.register_fixed_rule(name.clone(), rule) {
// errored = Some(err); errored = Some(err);
// } else { } else {
// let rule_senders = st.rule_senders.clone(); let rule_senders = st.rule_senders.clone();
// let rule_counter = st.rule_counter.clone(); let rule_counter = st.rule_counter.clone();
//
// thread::spawn(move || { thread::spawn(move || {
// for (inputs, options, sender) in task_receiver { for (inputs, options, sender) in task_receiver {
// let id = rule_counter.fetch_add(1, Ordering::AcqRel); let id = rule_counter.fetch_add(1, Ordering::AcqRel);
// let inputs: serde_json::Value = let inputs: serde_json::Value =
// inputs.into_iter().map(|ip| ip.into_json()).collect(); inputs.into_iter().map(|ip| ip.into_json()).collect();
// let options: serde_json::Value = options let options: serde_json::Value = options
// .into_iter() .into_iter()
// .map(|(k, v)| (k, serde_json::Value::from(v))) .map(|(k, v)| (k, serde_json::Value::from(v)))
// .collect(); .collect();
// if down_sender.blocking_send((id, inputs, options)).is_err() { if down_sender.blocking_send((id, inputs, options)).is_err() {
// let _ = sender.send(Err(miette!("cannot send request to downstream"))); let _ = sender.send(Err(miette!("cannot send request to downstream")));
// } else { } else {
// rule_senders.lock().unwrap().insert(id, sender); rule_senders.lock().unwrap().insert(id, sender);
// } }
// } }
// }); });
// } }
//
// struct Guard { struct Guard {
// name: String, name: String,
// db: DbInstance, db: DbInstance,
// } }
//
// impl Drop for Guard { impl Drop for Guard {
// fn drop(&mut self) { fn drop(&mut self) {
// info!("dropping rules SSE {}", self.name); info!("dropping rules SSE {}", self.name);
// let _ = self.db.unregister_fixed_rule(&self.name); let _ = self.db.unregister_fixed_rule(&self.name);
// } }
// } }
//
// let stream = async_stream::stream! { let stream = async_stream::stream! {
// if let Some(err) = errored { if let Some(err) = errored {
// let item = json!({"type": "register-error", "error": err.to_string()}); let item = json!({"type": "register-error", "error": err.to_string()});
// yield Ok(Event::default().json_data(item).unwrap()); yield Ok(Event::default().json_data(item).unwrap());
// } else { } else {
// info!("starting rule SSE {}", name); info!("starting rule SSE {}", name);
// let _guard = Guard {db: st.db, name}; let _guard = Guard {db: st.db, name};
// while let Some((id, inputs, options)) = down_receiver.recv().await { while let Some((id, inputs, options)) = down_receiver.recv().await {
// let item = json!({"type": "request", "id": id, "inputs": inputs, "options": options}); let item = json!({"type": "request", "id": id, "inputs": inputs, "options": options});
// yield Ok(Event::default().json_data(item).unwrap()); yield Ok(Event::default().json_data(item).unwrap());
// } }
// } }
// }; };
// Sse::new(stream).keep_alive(KeepAlive::default()) Sse::new(stream).keep_alive(KeepAlive::default())
// } }
async fn observe_changes( async fn observe_changes(
State(st): State<DbState>, State(st): State<DbState>,

Loading…
Cancel
Save