|
|
@ -11,11 +11,13 @@ use std::convert::Infallible;
|
|
|
|
use std::net::{Ipv6Addr, SocketAddr};
|
|
|
|
use std::net::{Ipv6Addr, SocketAddr};
|
|
|
|
use std::str::FromStr;
|
|
|
|
use std::str::FromStr;
|
|
|
|
// use std::sync::atomic::{AtomicU32, Ordering};
|
|
|
|
// use std::sync::atomic::{AtomicU32, Ordering};
|
|
|
|
use std::sync::Arc;
|
|
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
|
|
|
|
use std::sync::atomic::{AtomicU32, Ordering};
|
|
|
|
|
|
|
|
use std::thread;
|
|
|
|
// use std::thread;
|
|
|
|
// use std::thread;
|
|
|
|
|
|
|
|
|
|
|
|
use axum::body::{boxed, Body, BoxBody};
|
|
|
|
use axum::body::{boxed, Body, BoxBody};
|
|
|
|
use axum::extract::{DefaultBodyLimit, Path, State};
|
|
|
|
use axum::extract::{DefaultBodyLimit, Path, Query, State};
|
|
|
|
use axum::http::{header, HeaderName, Method, Request, Response, StatusCode};
|
|
|
|
use axum::http::{header, HeaderName, Method, Request, Response, StatusCode};
|
|
|
|
use axum::response::sse::{Event, KeepAlive};
|
|
|
|
use axum::response::sse::{Event, KeepAlive};
|
|
|
|
use axum::response::{Html, Sse};
|
|
|
|
use axum::response::{Html, Sse};
|
|
|
@ -26,6 +28,7 @@ use futures::future::BoxFuture;
|
|
|
|
use futures::stream::Stream;
|
|
|
|
use futures::stream::Stream;
|
|
|
|
use itertools::Itertools;
|
|
|
|
use itertools::Itertools;
|
|
|
|
use log::{error, info, warn};
|
|
|
|
use log::{error, info, warn};
|
|
|
|
|
|
|
|
use miette::miette;
|
|
|
|
// use miette::miette;
|
|
|
|
// use miette::miette;
|
|
|
|
use rand::Rng;
|
|
|
|
use rand::Rng;
|
|
|
|
use serde_json::json;
|
|
|
|
use serde_json::json;
|
|
|
@ -34,7 +37,7 @@ use tower_http::auth::{AsyncAuthorizeRequest, AsyncRequireAuthorizationLayer};
|
|
|
|
use tower_http::compression::CompressionLayer;
|
|
|
|
use tower_http::compression::CompressionLayer;
|
|
|
|
use tower_http::cors::{Any, CorsLayer};
|
|
|
|
use tower_http::cors::{Any, CorsLayer};
|
|
|
|
|
|
|
|
|
|
|
|
use cozo::{DataValue, DbInstance, NamedRows, ScriptMutability};
|
|
|
|
use cozo::{DataValue, DbInstance, format_error_as_json, MultiTransaction, NamedRows, ScriptMutability, SimpleFixedRule};
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Args, Debug)]
|
|
|
|
#[derive(Args, Debug)]
|
|
|
|
pub(crate) struct ServerArgs {
|
|
|
|
pub(crate) struct ServerArgs {
|
|
|
@ -70,10 +73,10 @@ pub(crate) struct ServerArgs {
|
|
|
|
#[derive(Clone)]
|
|
|
|
#[derive(Clone)]
|
|
|
|
struct DbState {
|
|
|
|
struct DbState {
|
|
|
|
db: DbInstance,
|
|
|
|
db: DbInstance,
|
|
|
|
// rule_senders: Arc<Mutex<BTreeMap<u32, crossbeam::channel::Sender<miette::Result<NamedRows>>>>>,
|
|
|
|
rule_senders: Arc<Mutex<BTreeMap<u32, crossbeam::channel::Sender<miette::Result<NamedRows>>>>>,
|
|
|
|
// rule_counter: Arc<AtomicU32>,
|
|
|
|
rule_counter: Arc<AtomicU32>,
|
|
|
|
// tx_counter: Arc<AtomicU32>,
|
|
|
|
tx_counter: Arc<AtomicU32>,
|
|
|
|
// txs: Arc<Mutex<BTreeMap<u32, Arc<MultiTransaction>>>>,
|
|
|
|
txs: Arc<Mutex<BTreeMap<u32, Arc<MultiTransaction>>>>,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Clone)]
|
|
|
|
#[derive(Clone)]
|
|
|
@ -91,7 +94,7 @@ where
|
|
|
|
type ResponseBody = BoxBody;
|
|
|
|
type ResponseBody = BoxBody;
|
|
|
|
type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
|
|
|
|
type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
|
|
|
|
|
|
|
|
|
|
|
|
fn authorize<'a>(&'a mut self, mut request: Request<B>) -> Self::Future {
|
|
|
|
fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
|
|
|
|
let skip_auth = self.skip_auth;
|
|
|
|
let skip_auth = self.skip_auth;
|
|
|
|
let auth_guard = self.auth_guard.clone();
|
|
|
|
let auth_guard = self.auth_guard.clone();
|
|
|
|
let token_table = self.token_table.clone();
|
|
|
|
let token_table = self.token_table.clone();
|
|
|
@ -127,8 +130,7 @@ where
|
|
|
|
let (name, db) = tt.as_ref();
|
|
|
|
let (name, db) = tt.as_ref();
|
|
|
|
if let Some(auth_header) = request.headers().get("Authorization") {
|
|
|
|
if let Some(auth_header) = request.headers().get("Authorization") {
|
|
|
|
if let Ok(auth_str) = auth_header.to_str() {
|
|
|
|
if let Ok(auth_str) = auth_header.to_str() {
|
|
|
|
if auth_str.starts_with("Bearer ") {
|
|
|
|
if let Some(token) = auth_str.strip_prefix("Bearer ") {
|
|
|
|
let token = &auth_str["Bearer ".len()..];
|
|
|
|
|
|
|
|
match db.run_script(
|
|
|
|
match db.run_script(
|
|
|
|
&format!("?[mutable] := *{name} {{ token: $token, mutable }}"),
|
|
|
|
&format!("?[mutable] := *{name} {{ token: $token, mutable }}"),
|
|
|
|
BTreeMap::from([(String::from("token"), DataValue::from(token))]),
|
|
|
|
BTreeMap::from([(String::from("token"), DataValue::from(token))]),
|
|
|
@ -147,7 +149,7 @@ where
|
|
|
|
Err(err) => {
|
|
|
|
Err(err) => {
|
|
|
|
eprintln!("Error: {}", err);
|
|
|
|
eprintln!("Error: {}", err);
|
|
|
|
None
|
|
|
|
None
|
|
|
|
},
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
None
|
|
|
|
None
|
|
|
@ -232,10 +234,10 @@ pub(crate) async fn server_main(args: ServerArgs) {
|
|
|
|
|
|
|
|
|
|
|
|
let state = DbState {
|
|
|
|
let state = DbState {
|
|
|
|
db,
|
|
|
|
db,
|
|
|
|
// rule_senders: Default::default(),
|
|
|
|
rule_senders: Default::default(),
|
|
|
|
// rule_counter: Default::default(),
|
|
|
|
rule_counter: Default::default(),
|
|
|
|
// tx_counter: Default::default(),
|
|
|
|
tx_counter: Default::default(),
|
|
|
|
// txs: Default::default(),
|
|
|
|
txs: Default::default(),
|
|
|
|
};
|
|
|
|
};
|
|
|
|
let cors = CorsLayer::new()
|
|
|
|
let cors = CorsLayer::new()
|
|
|
|
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
|
|
|
|
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
|
|
|
@ -249,13 +251,13 @@ pub(crate) async fn server_main(args: ServerArgs) {
|
|
|
|
.route("/backup", post(backup))
|
|
|
|
.route("/backup", post(backup))
|
|
|
|
.route("/import-from-backup", post(import_from_backup))
|
|
|
|
.route("/import-from-backup", post(import_from_backup))
|
|
|
|
.route("/changes/:relation", get(observe_changes))
|
|
|
|
.route("/changes/:relation", get(observe_changes))
|
|
|
|
// .route("/rules/:name", get(register_rule))
|
|
|
|
.route("/rules/:name", get(register_rule))
|
|
|
|
// .route(
|
|
|
|
.route(
|
|
|
|
// "/rule-result/:id",
|
|
|
|
"/rule-result/:id",
|
|
|
|
// post(post_rule_result).delete(post_rule_err),
|
|
|
|
post(post_rule_result).delete(post_rule_err),
|
|
|
|
// ) // +keep alive
|
|
|
|
) // +keep alive
|
|
|
|
// .route("/transact", post(start_transact))
|
|
|
|
.route("/transact", post(start_transact))
|
|
|
|
// .route("/transact/:id", post(transact_query).put(finish_query))
|
|
|
|
.route("/transact/:id", post(transact_query).put(finish_query))
|
|
|
|
.with_state(state)
|
|
|
|
.with_state(state)
|
|
|
|
.layer(AsyncRequireAuthorizationLayer::new(auth_obj))
|
|
|
|
.layer(AsyncRequireAuthorizationLayer::new(auth_obj))
|
|
|
|
.fallback(not_found)
|
|
|
|
.fallback(not_found)
|
|
|
@ -288,76 +290,76 @@ pub(crate) async fn server_main(args: ServerArgs) {
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(serde_derive::Deserialize)]
|
|
|
|
#[derive(serde_derive::Deserialize)]
|
|
|
|
struct StartTransactPayload {
|
|
|
|
struct StartTransactPayload {
|
|
|
|
// write: bool,
|
|
|
|
write: bool,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// async fn start_transact(
|
|
|
|
async fn start_transact(
|
|
|
|
// State(st): State<DbState>,
|
|
|
|
State(st): State<DbState>,
|
|
|
|
// Query(payload): Query<StartTransactPayload>,
|
|
|
|
Query(payload): Query<StartTransactPayload>,
|
|
|
|
// ) -> (StatusCode, Json<serde_json::Value>) {
|
|
|
|
) -> (StatusCode, Json<serde_json::Value>) {
|
|
|
|
// let tx = st.db.multi_transaction(payload.write);
|
|
|
|
let tx = st.db.multi_transaction(payload.write);
|
|
|
|
// let id = st.tx_counter.fetch_add(1, Ordering::SeqCst);
|
|
|
|
let id = st.tx_counter.fetch_add(1, Ordering::SeqCst);
|
|
|
|
// st.txs.lock().unwrap().insert(id, Arc::new(tx));
|
|
|
|
st.txs.lock().unwrap().insert(id, Arc::new(tx));
|
|
|
|
// (StatusCode::OK, json!({"ok": true, "id": id}).into())
|
|
|
|
(StatusCode::OK, json!({"ok": true, "id": id}).into())
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// async fn transact_query(
|
|
|
|
async fn transact_query(
|
|
|
|
// State(st): State<DbState>,
|
|
|
|
State(st): State<DbState>,
|
|
|
|
// Path(id): Path<u32>,
|
|
|
|
Path(id): Path<u32>,
|
|
|
|
// Json(payload): Json<QueryPayload>,
|
|
|
|
Json(payload): Json<QueryPayload>,
|
|
|
|
// ) -> (StatusCode, Json<serde_json::Value>) {
|
|
|
|
) -> (StatusCode, Json<serde_json::Value>) {
|
|
|
|
// let tx = match st.txs.lock().unwrap().get(&id) {
|
|
|
|
let tx = match st.txs.lock().unwrap().get(&id) {
|
|
|
|
// None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
|
|
|
|
None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
|
|
|
|
// Some(tx) => tx.clone(),
|
|
|
|
Some(tx) => tx.clone(),
|
|
|
|
// };
|
|
|
|
};
|
|
|
|
// let src = payload.script.clone();
|
|
|
|
let src = payload.script.clone();
|
|
|
|
// let result = spawn_blocking(move || {
|
|
|
|
let result = spawn_blocking(move || {
|
|
|
|
// let params = payload
|
|
|
|
let params = payload
|
|
|
|
// .params
|
|
|
|
.params
|
|
|
|
// .into_iter()
|
|
|
|
.into_iter()
|
|
|
|
// .map(|(k, v)| (k, DataValue::from(v)))
|
|
|
|
.map(|(k, v)| (k, DataValue::from(v)))
|
|
|
|
// .collect();
|
|
|
|
.collect();
|
|
|
|
// let query = payload.script;
|
|
|
|
let query = payload.script;
|
|
|
|
// tx.run_script(&query, params)
|
|
|
|
tx.run_script(&query, params)
|
|
|
|
// })
|
|
|
|
})
|
|
|
|
// .await;
|
|
|
|
.await;
|
|
|
|
// match result {
|
|
|
|
match result {
|
|
|
|
// Ok(Ok(res)) => (StatusCode::OK, res.into_json().into()),
|
|
|
|
Ok(Ok(res)) => (StatusCode::OK, res.into_json().into()),
|
|
|
|
// Ok(Err(err)) => (
|
|
|
|
Ok(Err(err)) => (
|
|
|
|
// StatusCode::BAD_REQUEST,
|
|
|
|
StatusCode::BAD_REQUEST,
|
|
|
|
// format_error_as_json(err, Some(&src)).into(),
|
|
|
|
format_error_as_json(err, Some(&src)).into(),
|
|
|
|
// ),
|
|
|
|
),
|
|
|
|
// Err(err) => internal_error(err),
|
|
|
|
Err(err) => internal_error(err),
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// #[derive(serde_derive::Deserialize)]
|
|
|
|
#[derive(serde_derive::Deserialize)]
|
|
|
|
// struct FinishTransactPayload {
|
|
|
|
struct FinishTransactPayload {
|
|
|
|
// abort: bool,
|
|
|
|
abort: bool,
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// async fn finish_query(
|
|
|
|
async fn finish_query(
|
|
|
|
// State(st): State<DbState>,
|
|
|
|
State(st): State<DbState>,
|
|
|
|
// Path(id): Path<u32>,
|
|
|
|
Path(id): Path<u32>,
|
|
|
|
// Json(payload): Json<FinishTransactPayload>,
|
|
|
|
Json(payload): Json<FinishTransactPayload>,
|
|
|
|
// ) -> (StatusCode, Json<serde_json::Value>) {
|
|
|
|
) -> (StatusCode, Json<serde_json::Value>) {
|
|
|
|
// let tx = match st.txs.lock().unwrap().remove(&id) {
|
|
|
|
let tx = match st.txs.lock().unwrap().remove(&id) {
|
|
|
|
// None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
|
|
|
|
None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
|
|
|
|
// Some(tx) => tx,
|
|
|
|
Some(tx) => tx,
|
|
|
|
// };
|
|
|
|
};
|
|
|
|
// let res = if payload.abort {
|
|
|
|
let res = if payload.abort {
|
|
|
|
// tx.abort()
|
|
|
|
tx.abort()
|
|
|
|
// } else {
|
|
|
|
} else {
|
|
|
|
// tx.commit()
|
|
|
|
tx.commit()
|
|
|
|
// };
|
|
|
|
};
|
|
|
|
// match res {
|
|
|
|
match res {
|
|
|
|
// Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
|
|
|
|
Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
|
|
|
|
// Err(err) => (
|
|
|
|
Err(err) => (
|
|
|
|
// StatusCode::BAD_REQUEST,
|
|
|
|
StatusCode::BAD_REQUEST,
|
|
|
|
// json!({"ok": false, "message": err.to_string()}).into(),
|
|
|
|
json!({"ok": false, "message": err.to_string()}).into(),
|
|
|
|
// ),
|
|
|
|
),
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(serde_derive::Deserialize)]
|
|
|
|
#[derive(serde_derive::Deserialize)]
|
|
|
|
struct QueryPayload {
|
|
|
|
struct QueryPayload {
|
|
|
@ -436,7 +438,7 @@ async fn import_relations(
|
|
|
|
return (
|
|
|
|
return (
|
|
|
|
StatusCode::BAD_REQUEST,
|
|
|
|
StatusCode::BAD_REQUEST,
|
|
|
|
json!({"ok": false, "message": "payload must be a JSON object"}).into(),
|
|
|
|
json!({"ok": false, "message": "payload must be a JSON object"}).into(),
|
|
|
|
)
|
|
|
|
);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Some(pl) => {
|
|
|
|
Some(pl) => {
|
|
|
|
let mut ret = BTreeMap::new();
|
|
|
|
let mut ret = BTreeMap::new();
|
|
|
@ -447,7 +449,7 @@ async fn import_relations(
|
|
|
|
return (
|
|
|
|
return (
|
|
|
|
StatusCode::BAD_REQUEST,
|
|
|
|
StatusCode::BAD_REQUEST,
|
|
|
|
json!({"ok": false, "message": err.to_string()}).into(),
|
|
|
|
json!({"ok": false, "message": err.to_string()}).into(),
|
|
|
|
)
|
|
|
|
);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
ret.insert(k.to_string(), nr);
|
|
|
|
ret.insert(k.to_string(), nr);
|
|
|
@ -466,6 +468,7 @@ async fn import_relations(
|
|
|
|
Err(err) => internal_error(err),
|
|
|
|
Err(err) => internal_error(err),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(serde_derive::Deserialize)]
|
|
|
|
#[derive(serde_derive::Deserialize)]
|
|
|
|
struct BackupPayload {
|
|
|
|
struct BackupPayload {
|
|
|
|
path: String,
|
|
|
|
path: String,
|
|
|
@ -489,11 +492,13 @@ async fn backup(
|
|
|
|
Err(err) => internal_error(err),
|
|
|
|
Err(err) => internal_error(err),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(serde_derive::Deserialize)]
|
|
|
|
#[derive(serde_derive::Deserialize)]
|
|
|
|
struct BackupImportPayload {
|
|
|
|
struct BackupImportPayload {
|
|
|
|
path: String,
|
|
|
|
path: String,
|
|
|
|
relations: Vec<String>,
|
|
|
|
relations: Vec<String>,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
async fn import_from_backup(
|
|
|
|
async fn import_from_backup(
|
|
|
|
State(st): State<DbState>,
|
|
|
|
State(st): State<DbState>,
|
|
|
|
Json(payload): Json<BackupImportPayload>,
|
|
|
|
Json(payload): Json<BackupImportPayload>,
|
|
|
@ -514,118 +519,118 @@ async fn import_from_backup(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// #[derive(serde_derive::Deserialize)]
|
|
|
|
#[derive(serde_derive::Deserialize)]
|
|
|
|
// struct RuleRegisterOptions {
|
|
|
|
struct RuleRegisterOptions {
|
|
|
|
// arity: usize,
|
|
|
|
arity: usize,
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// async fn post_rule_result(
|
|
|
|
async fn post_rule_result(
|
|
|
|
// State(st): State<DbState>,
|
|
|
|
State(st): State<DbState>,
|
|
|
|
// Path(id): Path<u32>,
|
|
|
|
Path(id): Path<u32>,
|
|
|
|
// Json(res): Json<serde_json::Value>,
|
|
|
|
Json(res): Json<serde_json::Value>,
|
|
|
|
// ) -> (StatusCode, Json<serde_json::Value>) {
|
|
|
|
) -> (StatusCode, Json<serde_json::Value>) {
|
|
|
|
// let res = match NamedRows::from_json(&res) {
|
|
|
|
let res = match NamedRows::from_json(&res) {
|
|
|
|
// Ok(res) => res,
|
|
|
|
Ok(res) => res,
|
|
|
|
// Err(err) => {
|
|
|
|
Err(err) => {
|
|
|
|
// if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
|
|
|
|
if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
|
|
|
|
// let _ = ch.send(Err(miette!("downstream posted malformed result")));
|
|
|
|
let _ = ch.send(Err(miette!("downstream posted malformed result")));
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
// return (
|
|
|
|
return (
|
|
|
|
// StatusCode::BAD_REQUEST,
|
|
|
|
StatusCode::BAD_REQUEST,
|
|
|
|
// json!({"ok": false, "message": err.to_string()}).into(),
|
|
|
|
json!({"ok": false, "message": err.to_string()}).into(),
|
|
|
|
// );
|
|
|
|
);
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
// };
|
|
|
|
};
|
|
|
|
// if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
|
|
|
|
if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
|
|
|
|
// match ch.send(Ok(res)) {
|
|
|
|
match ch.send(Ok(res)) {
|
|
|
|
// Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
|
|
|
|
Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
|
|
|
|
// Err(err) => (
|
|
|
|
Err(err) => (
|
|
|
|
// StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
// json!({"ok": false, "message": err.to_string()}).into(),
|
|
|
|
json!({"ok": false, "message": err.to_string()}).into(),
|
|
|
|
// ),
|
|
|
|
),
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
// } else {
|
|
|
|
} else {
|
|
|
|
// (StatusCode::NOT_FOUND, json!({"ok": false}).into())
|
|
|
|
(StatusCode::NOT_FOUND, json!({"ok": false}).into())
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// async fn post_rule_err(
|
|
|
|
async fn post_rule_err(
|
|
|
|
// State(st): State<DbState>,
|
|
|
|
State(st): State<DbState>,
|
|
|
|
// Path(id): Path<u32>,
|
|
|
|
Path(id): Path<u32>,
|
|
|
|
// ) -> (StatusCode, Json<serde_json::Value>) {
|
|
|
|
) -> (StatusCode, Json<serde_json::Value>) {
|
|
|
|
// if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
|
|
|
|
if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
|
|
|
|
// match ch.send(Err(miette!("downstream cancelled computation"))) {
|
|
|
|
match ch.send(Err(miette!("downstream cancelled computation"))) {
|
|
|
|
// Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
|
|
|
|
Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
|
|
|
|
// Err(err) => (
|
|
|
|
Err(err) => (
|
|
|
|
// StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
// json!({"ok": false, "message": err.to_string()}).into(),
|
|
|
|
json!({"ok": false, "message": err.to_string()}).into(),
|
|
|
|
// ),
|
|
|
|
),
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
// } else {
|
|
|
|
} else {
|
|
|
|
// (StatusCode::NOT_FOUND, json!({"ok": false}).into())
|
|
|
|
(StatusCode::NOT_FOUND, json!({"ok": false}).into())
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// async fn register_rule(
|
|
|
|
async fn register_rule(
|
|
|
|
// State(st): State<DbState>,
|
|
|
|
State(st): State<DbState>,
|
|
|
|
// Path(name): Path<String>,
|
|
|
|
Path(name): Path<String>,
|
|
|
|
// Query(rule_opts): Query<RuleRegisterOptions>,
|
|
|
|
Query(rule_opts): Query<RuleRegisterOptions>,
|
|
|
|
// ) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
|
|
|
|
) -> Sse<impl Stream<Item=Result<Event, Infallible>>> {
|
|
|
|
// let (rule, task_receiver) = SimpleFixedRule::rule_with_channel(rule_opts.arity);
|
|
|
|
let (rule, task_receiver) = SimpleFixedRule::rule_with_channel(rule_opts.arity);
|
|
|
|
// let (down_sender, mut down_receiver) = tokio::sync::mpsc::channel(1);
|
|
|
|
let (down_sender, mut down_receiver) = tokio::sync::mpsc::channel(1);
|
|
|
|
// let mut errored = None;
|
|
|
|
let mut errored = None;
|
|
|
|
//
|
|
|
|
|
|
|
|
// if let Err(err) = st.db.register_fixed_rule(name.clone(), rule) {
|
|
|
|
if let Err(err) = st.db.register_fixed_rule(name.clone(), rule) {
|
|
|
|
// errored = Some(err);
|
|
|
|
errored = Some(err);
|
|
|
|
// } else {
|
|
|
|
} else {
|
|
|
|
// let rule_senders = st.rule_senders.clone();
|
|
|
|
let rule_senders = st.rule_senders.clone();
|
|
|
|
// let rule_counter = st.rule_counter.clone();
|
|
|
|
let rule_counter = st.rule_counter.clone();
|
|
|
|
//
|
|
|
|
|
|
|
|
// thread::spawn(move || {
|
|
|
|
thread::spawn(move || {
|
|
|
|
// for (inputs, options, sender) in task_receiver {
|
|
|
|
for (inputs, options, sender) in task_receiver {
|
|
|
|
// let id = rule_counter.fetch_add(1, Ordering::AcqRel);
|
|
|
|
let id = rule_counter.fetch_add(1, Ordering::AcqRel);
|
|
|
|
// let inputs: serde_json::Value =
|
|
|
|
let inputs: serde_json::Value =
|
|
|
|
// inputs.into_iter().map(|ip| ip.into_json()).collect();
|
|
|
|
inputs.into_iter().map(|ip| ip.into_json()).collect();
|
|
|
|
// let options: serde_json::Value = options
|
|
|
|
let options: serde_json::Value = options
|
|
|
|
// .into_iter()
|
|
|
|
.into_iter()
|
|
|
|
// .map(|(k, v)| (k, serde_json::Value::from(v)))
|
|
|
|
.map(|(k, v)| (k, serde_json::Value::from(v)))
|
|
|
|
// .collect();
|
|
|
|
.collect();
|
|
|
|
// if down_sender.blocking_send((id, inputs, options)).is_err() {
|
|
|
|
if down_sender.blocking_send((id, inputs, options)).is_err() {
|
|
|
|
// let _ = sender.send(Err(miette!("cannot send request to downstream")));
|
|
|
|
let _ = sender.send(Err(miette!("cannot send request to downstream")));
|
|
|
|
// } else {
|
|
|
|
} else {
|
|
|
|
// rule_senders.lock().unwrap().insert(id, sender);
|
|
|
|
rule_senders.lock().unwrap().insert(id, sender);
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
// });
|
|
|
|
});
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
//
|
|
|
|
|
|
|
|
// struct Guard {
|
|
|
|
struct Guard {
|
|
|
|
// name: String,
|
|
|
|
name: String,
|
|
|
|
// db: DbInstance,
|
|
|
|
db: DbInstance,
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
//
|
|
|
|
|
|
|
|
// impl Drop for Guard {
|
|
|
|
impl Drop for Guard {
|
|
|
|
// fn drop(&mut self) {
|
|
|
|
fn drop(&mut self) {
|
|
|
|
// info!("dropping rules SSE {}", self.name);
|
|
|
|
info!("dropping rules SSE {}", self.name);
|
|
|
|
// let _ = self.db.unregister_fixed_rule(&self.name);
|
|
|
|
let _ = self.db.unregister_fixed_rule(&self.name);
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
//
|
|
|
|
|
|
|
|
// let stream = async_stream::stream! {
|
|
|
|
let stream = async_stream::stream! {
|
|
|
|
// if let Some(err) = errored {
|
|
|
|
if let Some(err) = errored {
|
|
|
|
// let item = json!({"type": "register-error", "error": err.to_string()});
|
|
|
|
let item = json!({"type": "register-error", "error": err.to_string()});
|
|
|
|
// yield Ok(Event::default().json_data(item).unwrap());
|
|
|
|
yield Ok(Event::default().json_data(item).unwrap());
|
|
|
|
// } else {
|
|
|
|
} else {
|
|
|
|
// info!("starting rule SSE {}", name);
|
|
|
|
info!("starting rule SSE {}", name);
|
|
|
|
// let _guard = Guard {db: st.db, name};
|
|
|
|
let _guard = Guard {db: st.db, name};
|
|
|
|
// while let Some((id, inputs, options)) = down_receiver.recv().await {
|
|
|
|
while let Some((id, inputs, options)) = down_receiver.recv().await {
|
|
|
|
// let item = json!({"type": "request", "id": id, "inputs": inputs, "options": options});
|
|
|
|
let item = json!({"type": "request", "id": id, "inputs": inputs, "options": options});
|
|
|
|
// yield Ok(Event::default().json_data(item).unwrap());
|
|
|
|
yield Ok(Event::default().json_data(item).unwrap());
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
// };
|
|
|
|
};
|
|
|
|
// Sse::new(stream).keep_alive(KeepAlive::default())
|
|
|
|
Sse::new(stream).keep_alive(KeepAlive::default())
|
|
|
|
// }
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
async fn observe_changes(
|
|
|
|
async fn observe_changes(
|
|
|
|
State(st): State<DbState>,
|
|
|
|
State(st): State<DbState>,
|
|
|
|