use axum for server; implement callback SSE
parent
b30ebf7b77
commit
0b0830d1ba
@ -0,0 +1,8 @@
|
||||
/*
|
||||
* Copyright 2023, The Cozo Project Authors.
|
||||
*
|
||||
* This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
|
||||
* If a copy of the MPL was not distributed with this file,
|
||||
* You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
*/
|
||||
|
@ -0,0 +1,52 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<link rel="icon" href="data:;base64,iVBORw0KGgo=">
|
||||
<title>Cozo database</title>
|
||||
</head>
|
||||
<body>
|
||||
<p>Cozo API is running.</p>
|
||||
<script>
|
||||
let COZO_AUTH = '';
|
||||
let LAST_RESP = null;
|
||||
|
||||
async function run(script, params) {
|
||||
const resp = await fetch('/text-query', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-cozo-auth': COZO_AUTH
|
||||
},
|
||||
body: JSON.stringify({
|
||||
script,
|
||||
params: params || {}
|
||||
})
|
||||
});
|
||||
if (resp.ok) {
|
||||
const json_resp = await resp.json();
|
||||
LAST_RESP = json_resp;
|
||||
if (json_resp) {
|
||||
json_resp.headers ||= [];
|
||||
console.table(json_resp.rows.map(row => {
|
||||
let ret = {};
|
||||
for (let i = 0; i < row.length; ++i) {
|
||||
ret[json_resp.headers[i] || `(${i})`] = row[i];
|
||||
}
|
||||
return ret
|
||||
}))
|
||||
}
|
||||
} else {
|
||||
console.error((await resp.json()).display)
|
||||
}
|
||||
}
|
||||
|
||||
console.log(
|
||||
`Welcome to the Cozo Makeshift Javascript Console!
|
||||
You can run your query like this:
|
||||
|
||||
await run("YOUR QUERY HERE", {param: value})
|
||||
|
||||
The global variables 'COZO_AUTH' and 'LAST_RESP' are available.`);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
@ -0,0 +1,12 @@
|
||||
|
||||
====================================================================================
|
||||
!! SECURITY NOTICE, PLEASE READ !!
|
||||
====================================================================================
|
||||
You instructed Cozo to bind to a non-default address.
|
||||
Cozo is designed to be accessed by trusted clients in a trusted network.
|
||||
As a last defense against unauthorized access when everything else fails,
|
||||
any requests from non-loopback addresses require the HTTP request header
|
||||
`x-cozo-auth` to be set to the content of auth.txt in your database directory.
|
||||
This is not a sufficient protection against attacks, and you must set up
|
||||
proper authentication schemes, encryptions, etc. by firewalls and/or proxies.
|
||||
====================================================================================
|
@ -0,0 +1,348 @@
|
||||
/*
|
||||
* Copyright 2023, The Cozo Project Authors.
|
||||
*
|
||||
* This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
|
||||
* If a copy of the MPL was not distributed with this file,
|
||||
* You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
*/
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use std::convert::Infallible;
|
||||
use std::net::{Ipv6Addr, SocketAddr};
|
||||
use std::str::FromStr;
|
||||
use std::sync::{mpsc, Arc, Mutex};
|
||||
|
||||
use axum::extract::{Path, Query, State};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::sse::{Event, KeepAlive};
|
||||
use axum::response::{Html, Sse};
|
||||
use axum::routing::{get, post, put};
|
||||
use axum::{Json, Router};
|
||||
use clap::Args;
|
||||
use futures::stream::{self, Stream};
|
||||
use itertools::Itertools;
|
||||
use log::{info, warn};
|
||||
use serde_json::json;
|
||||
use tokio::task::spawn_blocking;
|
||||
|
||||
use cozo::{DataValue, DbInstance, NamedRows};
|
||||
|
||||
#[derive(Args, Debug)]
|
||||
pub(crate) struct ServerArgs {
|
||||
/// Database engine, can be `mem`, `sqlite`, `rocksdb` and others.
|
||||
#[clap(short, long, default_value_t = String::from("mem"))]
|
||||
engine: String,
|
||||
|
||||
/// Path to the directory to store the database
|
||||
#[clap(short, long, default_value_t = String::from("cozo.db"))]
|
||||
path: String,
|
||||
|
||||
// Restore from the specified backup before starting the server
|
||||
// #[clap(long)]
|
||||
// restore: Option<String>,
|
||||
/// Extra config in JSON format
|
||||
#[clap(short, long, default_value_t = String::from("{}"))]
|
||||
config: String,
|
||||
|
||||
// When on, start REPL instead of starting a webserver
|
||||
// #[clap(short, long)]
|
||||
// repl: bool,
|
||||
/// Address to bind the service to
|
||||
#[clap(short, long, default_value_t = String::from("127.0.0.1"))]
|
||||
bind: String,
|
||||
|
||||
/// Port to use
|
||||
#[clap(short = 'P', long, default_value_t = 9070)]
|
||||
port: u16,
|
||||
}
|
||||
|
||||
type RuleCallbackStore = BTreeMap<usize, crossbeam::channel::Sender<miette::Result<NamedRows>>>;
|
||||
type DbState = (DbInstance, Arc<Mutex<RuleCallbackStore>>);
|
||||
|
||||
pub(crate) async fn server_main(args: ServerArgs) {
|
||||
let db = DbInstance::new(&args.engine, args.path, &args.config).unwrap();
|
||||
let rule_channels: Arc<Mutex<RuleCallbackStore>> = Default::default();
|
||||
let state = (db, rule_channels);
|
||||
let app = Router::new()
|
||||
.fallback(not_found)
|
||||
.route("/", get(root))
|
||||
.route("/text-query", post(text_query))
|
||||
.route("/export/:relations", get(export_relations))
|
||||
.route("/import", put(import_relations))
|
||||
.route("/backup", post(backup))
|
||||
.route("/import-from-backup", post(import_from_backup))
|
||||
.route("/changes/:relations", get(observe_changes))
|
||||
.route("/rules/:name", get(register_rule)) // sse + post
|
||||
.route("/rules/:name/:id", post(rule_result))
|
||||
.with_state(state);
|
||||
let addr = if Ipv6Addr::from_str(&args.bind).is_ok() {
|
||||
SocketAddr::from_str(&format!("[{}]:{}", args.bind, args.port)).unwrap()
|
||||
} else {
|
||||
SocketAddr::from_str(&format!("{}:{}", args.bind, args.port)).unwrap()
|
||||
};
|
||||
|
||||
if args.bind != "127.0.0.1" {
|
||||
warn!("{}", include_str!("./security.txt"));
|
||||
}
|
||||
|
||||
info!(
|
||||
"Starting Cozo ({}-backed) API at http://{}",
|
||||
args.engine, addr
|
||||
);
|
||||
|
||||
axum::Server::bind(&addr)
|
||||
.serve(app.into_make_service())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[derive(serde_derive::Deserialize)]
|
||||
struct QueryPayload {
|
||||
script: String,
|
||||
params: BTreeMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
async fn text_query(
|
||||
State((db, _)): State<DbState>,
|
||||
Json(payload): Json<QueryPayload>,
|
||||
) -> (StatusCode, Json<serde_json::Value>) {
|
||||
let params = payload
|
||||
.params
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, DataValue::from(v)))
|
||||
.collect();
|
||||
let result = spawn_blocking(move || db.run_script_fold_err(&payload.script, params)).await;
|
||||
match result {
|
||||
Ok(res) => wrap_json(res),
|
||||
Err(err) => internal_error(err),
|
||||
}
|
||||
}
|
||||
|
||||
async fn export_relations(
|
||||
State((db, _)): State<DbState>,
|
||||
Path(relations): Path<String>,
|
||||
) -> (StatusCode, Json<serde_json::Value>) {
|
||||
let relations = relations
|
||||
.split(',')
|
||||
.filter_map(|t| {
|
||||
if t.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(t.to_string())
|
||||
}
|
||||
})
|
||||
.collect_vec();
|
||||
let result = spawn_blocking(move || db.export_relations(relations.iter())).await;
|
||||
match result {
|
||||
Ok(Ok(s)) => {
|
||||
let ret = json!({"ok": true, "data": s});
|
||||
(StatusCode::OK, ret.into())
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
let ret = json!({"ok": false, "message": err.to_string()});
|
||||
(StatusCode::BAD_REQUEST, ret.into())
|
||||
}
|
||||
Err(err) => internal_error(err),
|
||||
}
|
||||
}
|
||||
|
||||
async fn import_relations(
|
||||
State((db, _)): State<DbState>,
|
||||
Json(payload): Json<serde_json::Value>,
|
||||
) -> (StatusCode, Json<serde_json::Value>) {
|
||||
let payload = match payload.as_object() {
|
||||
None => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
json!({"ok": false, "message": "payload must be a JSON object"}).into(),
|
||||
)
|
||||
}
|
||||
Some(pl) => {
|
||||
let mut ret = BTreeMap::new();
|
||||
for (k, v) in pl {
|
||||
let nr = match NamedRows::from_json(v) {
|
||||
Ok(p) => p,
|
||||
Err(err) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
json!({"ok": false, "message": err.to_string()}).into(),
|
||||
)
|
||||
}
|
||||
};
|
||||
ret.insert(k.to_string(), nr);
|
||||
}
|
||||
ret
|
||||
}
|
||||
};
|
||||
|
||||
let result = spawn_blocking(move || db.import_relations(payload)).await;
|
||||
match result {
|
||||
Ok(Ok(_)) => (StatusCode::OK, json!({"ok": true}).into()),
|
||||
Ok(Err(err)) => {
|
||||
let ret = json!({"ok": false, "message": err.to_string()});
|
||||
(StatusCode::BAD_REQUEST, ret.into())
|
||||
}
|
||||
Err(err) => internal_error(err),
|
||||
}
|
||||
}
|
||||
#[derive(serde_derive::Deserialize)]
|
||||
struct BackupPayload {
|
||||
path: String,
|
||||
}
|
||||
|
||||
async fn backup(
|
||||
State((db, _)): State<DbState>,
|
||||
Json(payload): Json<BackupPayload>,
|
||||
) -> (StatusCode, Json<serde_json::Value>) {
|
||||
let result = spawn_blocking(move || db.backup_db(payload.path)).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => {
|
||||
let ret = json!({"ok": true});
|
||||
(StatusCode::OK, ret.into())
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
let ret = json!({"ok": false, "message": err.to_string()});
|
||||
(StatusCode::BAD_REQUEST, ret.into())
|
||||
}
|
||||
Err(err) => internal_error(err),
|
||||
}
|
||||
}
|
||||
#[derive(serde_derive::Deserialize)]
|
||||
struct BackupImportPayload {
|
||||
path: String,
|
||||
relations: Vec<String>,
|
||||
}
|
||||
async fn import_from_backup(
|
||||
State((db, _)): State<DbState>,
|
||||
Json(payload): Json<BackupImportPayload>,
|
||||
) -> (StatusCode, Json<serde_json::Value>) {
|
||||
let result =
|
||||
spawn_blocking(move || db.import_from_backup(&payload.path, &payload.relations)).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => {
|
||||
let ret = json!({"ok": true});
|
||||
(StatusCode::OK, ret.into())
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
let ret = json!({"ok": false, "message": err.to_string()});
|
||||
(StatusCode::BAD_REQUEST, ret.into())
|
||||
}
|
||||
Err(err) => internal_error(err),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde_derive::Deserialize)]
|
||||
struct RuleRegisterOptions {
|
||||
arity: usize,
|
||||
}
|
||||
|
||||
async fn rule_result(
|
||||
State((store, _)): State<DbState>,
|
||||
Path(name): Path<String>,
|
||||
Path(id): Path<usize>,
|
||||
) -> (StatusCode, Json<serde_json::Value>) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn register_rule(
|
||||
State((db, cbs)): State<DbState>,
|
||||
Path(name): Path<String>,
|
||||
Query(rule_opts): Query<RuleRegisterOptions>,
|
||||
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
|
||||
let (id, recv) = db.register_callback(&name, None);
|
||||
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
|
||||
struct Guard {
|
||||
id: u32,
|
||||
db: DbInstance,
|
||||
relation: String,
|
||||
}
|
||||
|
||||
impl Drop for Guard {
|
||||
fn drop(&mut self) {
|
||||
info!("dropping changes SSE {}: {}", self.relation, self.id);
|
||||
self.db.unregister_callback(self.id);
|
||||
}
|
||||
}
|
||||
|
||||
spawn_blocking(move || {
|
||||
for data in recv {
|
||||
sender.blocking_send(data).unwrap();
|
||||
}
|
||||
});
|
||||
let stream = async_stream::stream! {
|
||||
info!("starting callback SSE {}: {}", name, id);
|
||||
let _guard = Guard {id, db, relation: name};
|
||||
while let Some((op, new, old)) = receiver.recv().await {
|
||||
let item = json!({"op": op.to_string(), "new_rows": new.into_json(), "old_rows": old.into_json()});
|
||||
yield Ok(Event::default().json_data(item).unwrap());
|
||||
}
|
||||
};
|
||||
Sse::new(stream).keep_alive(KeepAlive::default())
|
||||
}
|
||||
|
||||
async fn observe_changes(
|
||||
State((db, _)): State<DbState>,
|
||||
Path(relation): Path<String>,
|
||||
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
|
||||
let (id, recv) = db.register_callback(&relation, None);
|
||||
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
|
||||
struct Guard {
|
||||
id: u32,
|
||||
db: DbInstance,
|
||||
relation: String,
|
||||
}
|
||||
|
||||
impl Drop for Guard {
|
||||
fn drop(&mut self) {
|
||||
info!("dropping changes SSE {}: {}", self.relation, self.id);
|
||||
self.db.unregister_callback(self.id);
|
||||
}
|
||||
}
|
||||
|
||||
spawn_blocking(move || {
|
||||
for data in recv {
|
||||
sender.blocking_send(data).unwrap();
|
||||
}
|
||||
});
|
||||
let stream = async_stream::stream! {
|
||||
info!("starting changes SSE {}: {}", relation, id);
|
||||
let _guard = Guard {id, db, relation};
|
||||
while let Some((op, new, old)) = receiver.recv().await {
|
||||
let item = json!({"op": op.to_string(), "new_rows": new.into_json(), "old_rows": old.into_json()});
|
||||
yield Ok(Event::default().json_data(item).unwrap());
|
||||
}
|
||||
};
|
||||
Sse::new(stream).keep_alive(KeepAlive::default())
|
||||
}
|
||||
|
||||
async fn root() -> Html<&'static str> {
|
||||
Html(include_str!("./index.html"))
|
||||
}
|
||||
|
||||
fn internal_error<E>(err: E) -> (StatusCode, Json<serde_json::Value>)
|
||||
where
|
||||
E: std::error::Error,
|
||||
{
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
json!({"ok": false, "message": err.to_string()}).into(),
|
||||
)
|
||||
}
|
||||
|
||||
fn wrap_json(json: serde_json::Value) -> (StatusCode, Json<serde_json::Value>) {
|
||||
let code = if let Some(serde_json::Value::Bool(true)) = json.get("ok") {
|
||||
StatusCode::OK
|
||||
} else {
|
||||
StatusCode::BAD_REQUEST
|
||||
};
|
||||
(code, json.into())
|
||||
}
|
||||
|
||||
pub async fn not_found(uri: axum::http::Uri) -> (StatusCode, Json<serde_json::Value>) {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
json!({"ok": false, "message": format!("No route {}", uri)}).into(),
|
||||
)
|
||||
}
|
Loading…
Reference in New Issue