implement full API for server

main
Ziyang Hu 2 years ago
parent 110da3828b
commit aa4f73546b

@ -10,7 +10,9 @@ use std::collections::BTreeMap;
use std::convert::Infallible;
use std::net::{Ipv6Addr, SocketAddr};
use std::str::FromStr;
use std::sync::{mpsc, Arc, Mutex};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use axum::extract::{Path, Query, State};
use axum::http::StatusCode;
@ -19,13 +21,16 @@ use axum::response::{Html, Sse};
use axum::routing::{get, post, put};
use axum::{Json, Router};
use clap::Args;
use futures::stream::{self, Stream};
use futures::stream::Stream;
use itertools::Itertools;
use log::{info, warn};
use miette::miette;
use serde_json::json;
use tokio::task::spawn_blocking;
use cozo::{DataValue, DbInstance, NamedRows};
use cozo::{
format_error_as_json, DataValue, DbInstance, MultiTransaction, NamedRows, SimpleFixedRule,
};
#[derive(Args, Debug)]
pub(crate) struct ServerArgs {
@ -56,13 +61,24 @@ pub(crate) struct ServerArgs {
port: u16,
}
type RuleCallbackStore = BTreeMap<usize, crossbeam::channel::Sender<miette::Result<NamedRows>>>;
type DbState = (DbInstance, Arc<Mutex<RuleCallbackStore>>);
#[derive(Clone)]
struct DbState {
db: DbInstance,
rule_senders: Arc<Mutex<BTreeMap<u32, crossbeam::channel::Sender<miette::Result<NamedRows>>>>>,
rule_counter: Arc<AtomicU32>,
tx_counter: Arc<AtomicU32>,
txs: Arc<Mutex<BTreeMap<u32, Arc<MultiTransaction>>>>,
}
pub(crate) async fn server_main(args: ServerArgs) {
let db = DbInstance::new(&args.engine, args.path, &args.config).unwrap();
let rule_channels: Arc<Mutex<RuleCallbackStore>> = Default::default();
let state = (db, rule_channels);
let state = DbState {
db,
rule_senders: Default::default(),
rule_counter: Default::default(),
tx_counter: Default::default(),
txs: Default::default(),
};
let app = Router::new()
.fallback(not_found)
.route("/", get(root))
@ -72,8 +88,13 @@ pub(crate) async fn server_main(args: ServerArgs) {
.route("/backup", post(backup))
.route("/import-from-backup", post(import_from_backup))
.route("/changes/:relation", get(observe_changes))
.route("/rules/:name", get(register_rule)) // sse + post
.route("/rules/:name/:id", post(rule_result))
.route("/rules/:name", get(register_rule))
.route(
"/rule-result/:id",
post(post_rule_result).delete(post_rule_err),
) // +keep alive
.route("/transact", post(start_transact))
.route("/transact/:id", post(transact_query).put(finish_query)) // +keep alive
.with_state(state);
let addr = if Ipv6Addr::from_str(&args.bind).is_ok() {
SocketAddr::from_str(&format!("[{}]:{}", args.bind, args.port)).unwrap()
@ -96,6 +117,79 @@ pub(crate) async fn server_main(args: ServerArgs) {
.unwrap();
}
#[derive(serde_derive::Deserialize)]
struct StartTransactPayload {
write: bool,
}
async fn start_transact(
State(st): State<DbState>,
Query(payload): Query<StartTransactPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
let tx = st.db.multi_transaction(payload.write);
let id = st.tx_counter.fetch_add(1, Ordering::SeqCst);
st.txs.lock().unwrap().insert(id, Arc::new(tx));
(StatusCode::OK, json!({"ok": true, "id": id}).into())
}
async fn transact_query(
State(st): State<DbState>,
Path(id): Path<u32>,
Json(payload): Json<QueryPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
let tx = match st.txs.lock().unwrap().get(&id) {
None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
Some(tx) => tx.clone(),
};
let src = payload.script.clone();
let result = spawn_blocking(move || {
let params = payload
.params
.into_iter()
.map(|(k, v)| (k, DataValue::from(v)))
.collect();
let query = payload.script;
tx.run_script(&query, params)
})
.await;
match result {
Ok(Ok(res)) => (StatusCode::OK, res.into_json().into()),
Ok(Err(err)) => (
StatusCode::BAD_REQUEST,
format_error_as_json(err, Some(&src)).into(),
),
Err(err) => internal_error(err),
}
}
#[derive(serde_derive::Deserialize)]
struct FinishTransactPayload {
abort: bool,
}
async fn finish_query(
State(st): State<DbState>,
Path(id): Path<u32>,
Json(payload): Json<FinishTransactPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
let tx = match st.txs.lock().unwrap().remove(&id) {
None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
Some(tx) => tx,
};
let res = if payload.abort {
tx.abort()
} else {
tx.commit()
};
match res {
Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
Err(err) => (
StatusCode::BAD_REQUEST,
json!({"ok": false, "message": err.to_string()}).into(),
),
}
}
#[derive(serde_derive::Deserialize)]
struct QueryPayload {
script: String,
@ -103,7 +197,7 @@ struct QueryPayload {
}
async fn text_query(
State((db, _)): State<DbState>,
State(st): State<DbState>,
Json(payload): Json<QueryPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
let params = payload
@ -111,7 +205,7 @@ async fn text_query(
.into_iter()
.map(|(k, v)| (k, DataValue::from(v)))
.collect();
let result = spawn_blocking(move || db.run_script_fold_err(&payload.script, params)).await;
let result = spawn_blocking(move || st.db.run_script_fold_err(&payload.script, params)).await;
match result {
Ok(res) => wrap_json(res),
Err(err) => internal_error(err),
@ -119,7 +213,7 @@ async fn text_query(
}
async fn export_relations(
State((db, _)): State<DbState>,
State(st): State<DbState>,
Path(relations): Path<String>,
) -> (StatusCode, Json<serde_json::Value>) {
let relations = relations
@ -132,7 +226,7 @@ async fn export_relations(
}
})
.collect_vec();
let result = spawn_blocking(move || db.export_relations(relations.iter())).await;
let result = spawn_blocking(move || st.db.export_relations(relations.iter())).await;
match result {
Ok(Ok(s)) => {
let ret = json!({"ok": true, "data": s});
@ -147,7 +241,7 @@ async fn export_relations(
}
async fn import_relations(
State((db, _)): State<DbState>,
State(st): State<DbState>,
Json(payload): Json<serde_json::Value>,
) -> (StatusCode, Json<serde_json::Value>) {
let payload = match payload.as_object() {
@ -175,7 +269,7 @@ async fn import_relations(
}
};
let result = spawn_blocking(move || db.import_relations(payload)).await;
let result = spawn_blocking(move || st.db.import_relations(payload)).await;
match result {
Ok(Ok(_)) => (StatusCode::OK, json!({"ok": true}).into()),
Ok(Err(err)) => {
@ -191,10 +285,10 @@ struct BackupPayload {
}
async fn backup(
State((db, _)): State<DbState>,
State(st): State<DbState>,
Json(payload): Json<BackupPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
let result = spawn_blocking(move || db.backup_db(payload.path)).await;
let result = spawn_blocking(move || st.db.backup_db(payload.path)).await;
match result {
Ok(Ok(())) => {
@ -214,11 +308,11 @@ struct BackupImportPayload {
relations: Vec<String>,
}
async fn import_from_backup(
State((db, _)): State<DbState>,
State(st): State<DbState>,
Json(payload): Json<BackupImportPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
let result =
spawn_blocking(move || db.import_from_backup(&payload.path, &payload.relations)).await;
spawn_blocking(move || st.db.import_from_backup(&payload.path, &payload.relations)).await;
match result {
Ok(Ok(())) => {
@ -238,55 +332,119 @@ struct RuleRegisterOptions {
arity: usize,
}
async fn rule_result(
State((store, _)): State<DbState>,
Path(name): Path<String>,
Path(id): Path<usize>,
async fn post_rule_result(
State(st): State<DbState>,
Path(id): Path<u32>,
Json(res): Json<serde_json::Value>,
) -> (StatusCode, Json<serde_json::Value>) {
let res = match NamedRows::from_json(&res) {
Ok(res) => res,
Err(err) => {
if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
let _ = ch.send(Err(miette!("downstream posted malformed result")));
}
return (
StatusCode::BAD_REQUEST,
json!({"ok": false, "message": err.to_string()}).into(),
);
}
};
if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
match ch.send(Ok(res)) {
Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
json!({"ok": false, "message": err.to_string()}).into(),
),
}
} else {
(StatusCode::NOT_FOUND, json!({"ok": false}).into())
}
}
async fn post_rule_err(
State(st): State<DbState>,
Path(id): Path<u32>,
) -> (StatusCode, Json<serde_json::Value>) {
todo!()
if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
match ch.send(Err(miette!("downstream cancelled computation"))) {
Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
json!({"ok": false, "message": err.to_string()}).into(),
),
}
} else {
(StatusCode::NOT_FOUND, json!({"ok": false}).into())
}
}
async fn register_rule(
State((db, cbs)): State<DbState>,
State(st): State<DbState>,
Path(name): Path<String>,
Query(rule_opts): Query<RuleRegisterOptions>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let (id, recv) = db.register_callback(&name, None);
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
let (rule, task_receiver) = SimpleFixedRule::rule_with_channel(rule_opts.arity);
let (down_sender, mut down_receiver) = tokio::sync::mpsc::channel(1);
let mut errored = None;
if let Err(err) = st.db.register_fixed_rule(name.clone(), rule) {
errored = Some(err);
} else {
let rule_senders = st.rule_senders.clone();
let rule_counter = st.rule_counter.clone();
thread::spawn(move || {
for (inputs, options, sender) in task_receiver {
let id = rule_counter.fetch_add(1, Ordering::AcqRel);
let inputs: serde_json::Value =
inputs.into_iter().map(|ip| ip.into_json()).collect();
let options: serde_json::Value = options
.into_iter()
.map(|(k, v)| (k, serde_json::Value::from(v)))
.collect();
if down_sender.blocking_send((id, inputs, options)).is_err() {
let _ = sender.send(Err(miette!("cannot send request to downstream")));
} else {
rule_senders.lock().unwrap().insert(id, sender);
}
}
});
}
struct Guard {
id: u32,
name: String,
db: DbInstance,
relation: String,
}
impl Drop for Guard {
fn drop(&mut self) {
info!("dropping changes SSE {}: {}", self.relation, self.id);
self.db.unregister_callback(self.id);
info!("dropping rules SSE {}", self.name);
let _ = self.db.unregister_fixed_rule(&self.name);
}
}
spawn_blocking(move || {
for data in recv {
sender.blocking_send(data).unwrap();
}
});
let stream = async_stream::stream! {
info!("starting callback SSE {}: {}", name, id);
let _guard = Guard {id, db, relation: name};
while let Some((op, new, old)) = receiver.recv().await {
let item = json!({"op": op.to_string(), "new_rows": new.into_json(), "old_rows": old.into_json()});
if let Some(err) = errored {
let item = json!({"type": "register-error", "error": err.to_string()});
yield Ok(Event::default().json_data(item).unwrap());
} else {
info!("starting rule SSE {}", name);
let _guard = Guard {db: st.db, name};
while let Some((id, inputs, options)) = down_receiver.recv().await {
let item = json!({"type": "request", "id": id, "inputs": inputs, "options": options});
yield Ok(Event::default().json_data(item).unwrap());
}
}
};
Sse::new(stream).keep_alive(KeepAlive::default())
}
async fn observe_changes(
State((db, _)): State<DbState>,
State(st): State<DbState>,
Path(relation): Path<String>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let (id, recv) = db.register_callback(&relation, None);
let (id, recv) = st.db.register_callback(&relation, None);
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
struct Guard {
id: u32,
@ -308,7 +466,7 @@ async fn observe_changes(
});
let stream = async_stream::stream! {
info!("starting changes SSE {}: {}", relation, id);
let _guard = Guard {id, db, relation};
let _guard = Guard {id, db: st.db, relation};
while let Some((op, new, old)) = receiver.recv().await {
let item = json!({"op": op.to_string(), "new_rows": new.into_json(), "old_rows": old.into_json()});
yield Ok(Event::default().json_data(item).unwrap());

Loading…
Cancel
Save