|
|
|
@ -14,8 +14,9 @@ use std::sync::atomic::{AtomicU32, Ordering};
|
|
|
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
|
use std::thread;
|
|
|
|
|
|
|
|
|
|
use axum::body::{Body, BoxBody};
|
|
|
|
|
use axum::extract::{Path, Query, State};
|
|
|
|
|
use axum::http::StatusCode;
|
|
|
|
|
use axum::http::{Method, Request, Response, StatusCode};
|
|
|
|
|
use axum::response::sse::{Event, KeepAlive};
|
|
|
|
|
use axum::response::{Html, Sse};
|
|
|
|
|
use axum::routing::{get, post, put};
|
|
|
|
@ -23,10 +24,14 @@ use axum::{Json, Router};
|
|
|
|
|
use clap::Args;
|
|
|
|
|
use futures::stream::Stream;
|
|
|
|
|
use itertools::Itertools;
|
|
|
|
|
use log::{info, warn};
|
|
|
|
|
use log::{error, info, warn};
|
|
|
|
|
use miette::miette;
|
|
|
|
|
use rand::Rng;
|
|
|
|
|
use serde_json::json;
|
|
|
|
|
use tokio::task::spawn_blocking;
|
|
|
|
|
use tower_http::auth::RequireAuthorizationLayer;
|
|
|
|
|
use tower_http::compression::CompressionLayer;
|
|
|
|
|
use tower_http::cors::{Any, CorsLayer};
|
|
|
|
|
|
|
|
|
|
use cozo::{
|
|
|
|
|
format_error_as_json, DataValue, DbInstance, MultiTransaction, NamedRows, SimpleFixedRule,
|
|
|
|
@ -42,9 +47,10 @@ pub(crate) struct ServerArgs {
|
|
|
|
|
#[clap(short, long, default_value_t = String::from("cozo.db"))]
|
|
|
|
|
path: String,
|
|
|
|
|
|
|
|
|
|
// Restore from the specified backup before starting the server
|
|
|
|
|
// #[clap(long)]
|
|
|
|
|
// restore: Option<String>,
|
|
|
|
|
/// Restore from the specified backup before starting the server
|
|
|
|
|
#[clap(long)]
|
|
|
|
|
restore: Option<String>,
|
|
|
|
|
|
|
|
|
|
/// Extra config in JSON format
|
|
|
|
|
#[clap(short, long, default_value_t = String::from("{}"))]
|
|
|
|
|
config: String,
|
|
|
|
@ -71,7 +77,35 @@ struct DbState {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub(crate) async fn server_main(args: ServerArgs) {
|
|
|
|
|
let db = DbInstance::new(&args.engine, args.path, &args.config).unwrap();
|
|
|
|
|
let db = DbInstance::new(&args.engine, &args.path, &args.config).unwrap();
|
|
|
|
|
if let Some(p) = &args.restore {
|
|
|
|
|
if let Err(err) = db.restore_backup(p) {
|
|
|
|
|
error!("{}", err);
|
|
|
|
|
error!("Restore from backup failed, terminate");
|
|
|
|
|
panic!()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let skip_auth = args.bind == "127.0.0.1";
|
|
|
|
|
|
|
|
|
|
let conf_path = if skip_auth {"".to_string()} else { format!("{}.{}.cozo_auth", args.path, args.engine)};
|
|
|
|
|
let auth_guard = if skip_auth {
|
|
|
|
|
"".to_string()
|
|
|
|
|
} else {
|
|
|
|
|
match tokio::fs::read_to_string(&conf_path).await {
|
|
|
|
|
Ok(s) => s.trim().to_string(),
|
|
|
|
|
Err(_) => {
|
|
|
|
|
let s = rand::thread_rng()
|
|
|
|
|
.sample_iter(&rand::distributions::Alphanumeric)
|
|
|
|
|
.take(64)
|
|
|
|
|
.map(char::from)
|
|
|
|
|
.collect();
|
|
|
|
|
tokio::fs::write(&conf_path, &s).await.unwrap();
|
|
|
|
|
s
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let state = DbState {
|
|
|
|
|
db,
|
|
|
|
|
rule_senders: Default::default(),
|
|
|
|
@ -79,9 +113,11 @@ pub(crate) async fn server_main(args: ServerArgs) {
|
|
|
|
|
tx_counter: Default::default(),
|
|
|
|
|
txs: Default::default(),
|
|
|
|
|
};
|
|
|
|
|
let cors = CorsLayer::new()
|
|
|
|
|
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
|
|
|
|
|
.allow_origin(Any);
|
|
|
|
|
|
|
|
|
|
let app = Router::new()
|
|
|
|
|
.fallback(not_found)
|
|
|
|
|
.route("/", get(root))
|
|
|
|
|
.route("/text-query", post(text_query))
|
|
|
|
|
.route("/export/:relations", get(export_relations))
|
|
|
|
|
.route("/import", put(import_relations))
|
|
|
|
@ -94,8 +130,54 @@ pub(crate) async fn server_main(args: ServerArgs) {
|
|
|
|
|
post(post_rule_result).delete(post_rule_err),
|
|
|
|
|
) // +keep alive
|
|
|
|
|
.route("/transact", post(start_transact))
|
|
|
|
|
.route("/transact/:id", post(transact_query).put(finish_query)) // +keep alive
|
|
|
|
|
.with_state(state);
|
|
|
|
|
.route("/transact/:id", post(transact_query).put(finish_query))
|
|
|
|
|
.with_state(state)
|
|
|
|
|
.layer(RequireAuthorizationLayer::custom(
|
|
|
|
|
move |request: &mut Request<Body>| {
|
|
|
|
|
if skip_auth {
|
|
|
|
|
return Ok(());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let ok = match request.headers().get("x-cozo-auth") {
|
|
|
|
|
None => match request.uri().query() {
|
|
|
|
|
None => false,
|
|
|
|
|
Some(q_str) => {
|
|
|
|
|
let mut bingo = false;
|
|
|
|
|
for pair in q_str.split('&') {
|
|
|
|
|
if let Some((k, v)) = pair.split_once('=') {
|
|
|
|
|
if k == "auth" {
|
|
|
|
|
if v == &auth_guard {
|
|
|
|
|
bingo = true
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
bingo
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
Some(data) => match data.to_str() {
|
|
|
|
|
Ok(s) => s == &auth_guard,
|
|
|
|
|
Err(_) => false,
|
|
|
|
|
},
|
|
|
|
|
};
|
|
|
|
|
if ok {
|
|
|
|
|
Ok(())
|
|
|
|
|
} else {
|
|
|
|
|
let unauthorized_response = Response::builder()
|
|
|
|
|
.status(StatusCode::UNAUTHORIZED)
|
|
|
|
|
.body(BoxBody::default())
|
|
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
|
|
Err(unauthorized_response.into())
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
))
|
|
|
|
|
.fallback(not_found)
|
|
|
|
|
.route("/", get(root))
|
|
|
|
|
.layer(cors)
|
|
|
|
|
.layer(CompressionLayer::new());
|
|
|
|
|
|
|
|
|
|
let addr = if Ipv6Addr::from_str(&args.bind).is_ok() {
|
|
|
|
|
SocketAddr::from_str(&format!("[{}]:{}", args.bind, args.port)).unwrap()
|
|
|
|
|
} else {
|
|
|
|
@ -104,6 +186,7 @@ pub(crate) async fn server_main(args: ServerArgs) {
|
|
|
|
|
|
|
|
|
|
if args.bind != "127.0.0.1" {
|
|
|
|
|
warn!("{}", include_str!("./security.txt"));
|
|
|
|
|
info!("The auth token is in the file: {conf_path}");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
info!(
|
|
|
|
|