auth token for server

main
Ziyang Hu 2 years ago
parent aa4f73546b
commit 2abdda43f9

108
Cargo.lock generated

@ -49,6 +49,21 @@ dependencies = [
"memchr",
]
[[package]]
name = "alloc-no-stdlib"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3"
[[package]]
name = "alloc-stdlib"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece"
dependencies = [
"alloc-no-stdlib",
]
[[package]]
name = "android_system_properties"
version = "0.1.5"
@ -79,6 +94,20 @@ version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6"
[[package]]
name = "async-compression"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "942c7cd7ae39e91bde4820d74132e9862e62c2f386c3aa90ccf55949f5bad63a"
dependencies = [
"brotli",
"flate2",
"futures-core",
"memchr",
"pin-project-lite",
"tokio",
]
[[package]]
name = "async-stream"
version = "0.3.3"
@ -284,6 +313,27 @@ dependencies = [
"cmake",
]
[[package]]
name = "brotli"
version = "3.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1a0b1dbcc8ae29329621f8d4f0d835787c1c38bb1401979b49d13b0b305ff68"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
"brotli-decompressor",
]
[[package]]
name = "brotli-decompressor"
version = "2.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b6561fd3f895a11e8f72af2cb7d22e08366bebc2b6b57f7744c4bda27034744"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
]
[[package]]
name = "bstr"
version = "0.2.17"
@ -369,7 +419,7 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4aedb84272dbe89af497cf81375129abda4fc0a9e7c5d317498c15cc30c0d27"
dependencies = [
"nom",
"nom 5.1.2",
]
[[package]]
@ -636,6 +686,7 @@ dependencies = [
"serde_derive",
"serde_json",
"tokio",
"tower-http",
]
[[package]]
@ -1641,6 +1692,15 @@ version = "2.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30e22bd8629359895450b59ea7a776c850561b96a3b1d31321c1949d9e6c9146"
[[package]]
name = "iri-string"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f0f7638c1e223529f1bfdc48c8b133b9e0b434094d1d28473161ee48b235f78"
dependencies = [
"nom 7.1.3",
]
[[package]]
name = "is-terminal"
version = "0.4.2"
@ -1918,6 +1978,22 @@ version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
[[package]]
name = "mime_guess"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.6.2"
@ -2065,6 +2141,16 @@ dependencies = [
"version_check",
]
[[package]]
name = "nom"
version = "7.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
dependencies = [
"memchr",
"minimal-lexical",
]
[[package]]
name = "num"
version = "0.4.0"
@ -3855,6 +3941,8 @@ version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f873044bf02dd1e8239e9c1293ea39dad76dc594ec16185d0a1bf31d8dc8d858"
dependencies = [
"async-compression",
"base64 0.13.1",
"bitflags",
"bytes",
"futures-core",
@ -3862,10 +3950,19 @@ dependencies = [
"http",
"http-body",
"http-range-header",
"httpdate",
"iri-string",
"mime",
"mime_guess",
"percent-encoding",
"pin-project-lite",
"tokio",
"tokio-util",
"tower",
"tower-layer",
"tower-service",
"tracing",
"uuid",
]
[[package]]
@ -3919,6 +4016,15 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81"
[[package]]
name = "unicase"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6"
dependencies = [
"version_check",
]
[[package]]
name = "unicode-bidi"
version = "0.3.8"

@ -64,3 +64,4 @@ async-stream = "0.3.3"
futures = "0.3.25"
crossbeam = "0.8.2"
eventsource-client = "0.11.0"
tower-http = { version = "0.3.5", features = ["full"] }

@ -70,7 +70,7 @@ and a nicely-formatted diagnostic will be in `"display"` if available.
> As a guard against users accidentally exposing sensitive data,
> If you bind Cozo to non-loopback addresses,
> Cozo will generate a token string and require all queries from non-loopback addresses
> to provide the token string in the HTTP header field x-cozo-auth.
> to provide the token string in the HTTP header field `x-cozo-auth`.
> The warning printed when you start Cozo with a
> non-default binding will tell you where to find the token string.
> This “security measure” is not considered sufficient for any purpose

@ -6,22 +6,16 @@
* You can obtain one at https://mozilla.org/MPL/2.0/.
*/
use std::fmt::Debug;
use std::fs;
extern crate core;
use std::process::exit;
use clap::{Args, Parser, Subcommand};
use clap::{Parser, Subcommand};
use env_logger::Env;
use log::{error, info};
use crate::repl::repl_main;
use crate::repl::{repl_main, ReplArgs};
use crate::server::{server_main, ServerArgs};
// use rand::Rng;
// use serde_json::json;
// use cozo::*;
mod client;
mod repl;
mod server;
@ -36,131 +30,37 @@ struct AppArgs {
#[derive(Subcommand)]
enum Commands {
///
Server(ServerArgs),
Client(ClientArgs),
Repl(ReplArgs),
Restore(RestoreArgs),
Stream(StreamArgs),
}
#[derive(Args, Debug)]
struct ReplArgs {
/// Database engine, can be `mem`, `sqlite`, `rocksdb` and others.
#[clap(short, long, default_value_t = String::from("mem"))]
engine: String,
/// Path to the directory to store the database
#[clap(short, long, default_value_t = String::from("cozo.db"))]
path: String,
/// Extra config in JSON format
#[clap(short, long, default_value_t = String::from("{}"))]
config: String,
}
#[derive(Args, Debug)]
struct ClientArgs {
#[clap(default_value_t = String::from("http://127.0.0.1:9070"))]
address: String,
#[clap(short, long, default_value_t = String::from(""))]
auth: String,
}
#[derive(Args, Debug)]
struct StreamArgs {}
#[derive(Args, Debug)]
struct RestoreArgs {
/// Path of the backup file to restore from, must be a SQLite-backed backup file.
from: String,
/// Path of the database to restore into
to: String,
/// Database engine for the database to restore into, can be `mem`, `sqlite`, `rocksdb` and others.
#[clap(short, long)]
engine: String,
}
// macro_rules! check_auth {
// ($request:expr, $auth_guard:expr) => {
// match $request.header("x-cozo-auth") {
// None => return Response::text("Unauthorized").with_status_code(401),
// Some(code) => {
// if $auth_guard != code {
// return Response::text("Unauthorized").with_status_code(401);
// }
// }
// }
// };
// }
fn main() {
let args = match AppArgs::parse().command {
Commands::Server(s) => s,
Commands::Client(_) => {
todo!()
}
Commands::Repl(_) => {
todo!()
match AppArgs::parse().command {
Commands::Server(args) => {
env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(server_main(args))
}
Commands::Restore(_) => {
todo!()
Commands::Repl(args) => {
if let Err(e) = repl_main(args) {
eprintln!("{e}");
exit(-1);
}
Commands::Stream(_) => {
todo!()
}
};
// if let Some(restore_path) = &args.restore {
// db.restore_backup(restore_path).unwrap();
// }
// if args.repl {
// let db_copy = db.clone();
// ctrlc::set_handler(move || {
// let running = db_copy
// .run_script("::running", Default::default())
// .expect("Cannot determine running queries");
// for row in running.rows {
// let id = row.into_iter().next().unwrap();
// eprintln!("Killing running query {id}");
// db_copy
// .run_script("::kill $id", BTreeMap::from([("id".to_string(), id)]))
// .expect("Cannot kill process");
// }
// })
// .expect("Error setting Ctrl-C handler");
//
// if let Err(e) = repl_main(db) {
// eprintln!("{e}");
// exit(-1);
// }
// } else {
env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(server_main(args))
// server_main(args, db)
// }
}
// fn server_main(args: Server, db: DbInstance) {
// let conf_path = format!("{}.{}.cozo_auth", args.path, args.engine);
// let auth_guard = match fs::read_to_string(&conf_path) {
// Ok(s) => s.trim().to_string(),
// Err(_) => {
// let s = rand::thread_rng()
// .sample_iter(&rand::distributions::Alphanumeric)
// .take(64)
// .map(char::from)
// .collect();
// fs::write(&conf_path, &s).unwrap();
// s
// }
// };
//
// let addr = if Ipv6Addr::from_str(&args.bind).is_ok() {
// format!("[{}]:{}", args.bind, args.port)

@ -13,6 +13,7 @@ use std::error::Error;
use std::fs::File;
use std::io::{Read, Write};
use clap::Args;
use miette::{bail, miette, IntoDiagnostic};
use serde_json::{json, Value};
@ -57,7 +58,39 @@ impl rustyline::validate::Validator for Indented {
}
}
pub(crate) fn repl_main(db: DbInstance) -> Result<(), Box<dyn Error>> {
#[derive(Args, Debug)]
pub(crate) struct ReplArgs {
/// Database engine, can be `mem`, `sqlite`, `rocksdb` and others.
#[clap(short, long, default_value_t = String::from("mem"))]
engine: String,
/// Path to the directory to store the database
#[clap(short, long, default_value_t = String::from("cozo.db"))]
path: String,
/// Extra config in JSON format
#[clap(short, long, default_value_t = String::from("{}"))]
config: String,
}
pub(crate) fn repl_main(args: ReplArgs) -> Result<(), Box<dyn Error>> {
let db = DbInstance::new(&args.engine, args.path, &args.config).unwrap();
let db_copy = db.clone();
ctrlc::set_handler(move || {
let running = db_copy
.run_script("::running", Default::default())
.expect("Cannot determine running queries");
for row in running.rows {
let id = row.into_iter().next().unwrap();
eprintln!("Killing running query {id}");
db_copy
.run_script("::kill $id", BTreeMap::from([("id".to_string(), id)]))
.expect("Cannot kill process");
}
})
.expect("Error setting Ctrl-C handler");
println!("Welcome to the Cozo REPL.");
println!("Type a space followed by newline to enter multiline mode.");

@ -3,10 +3,15 @@
!! SECURITY NOTICE, PLEASE READ !!
====================================================================================
You instructed Cozo to bind to a non-default address.
Cozo is designed to be accessed by trusted clients in a trusted network.
As a last defense against unauthorized access when everything else fails,
any requests from non-loopback addresses require the HTTP request header
`x-cozo-auth` to be set to the content of auth.txt in your database directory.
any requests now require the HTTP request header `x-cozo-auth` to be set,
or the query parameter `auth=<AUTH_STR>` be set to the auth token.
The auth token is found in a file indicated below.
This is required even if the request comes from localhost.
This is not a sufficient protection against attacks, and you must set up
proper authentication schemes, encryptions, etc. by firewalls and/or proxies.
====================================================================================

@ -14,8 +14,9 @@ use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use axum::body::{Body, BoxBody};
use axum::extract::{Path, Query, State};
use axum::http::StatusCode;
use axum::http::{Method, Request, Response, StatusCode};
use axum::response::sse::{Event, KeepAlive};
use axum::response::{Html, Sse};
use axum::routing::{get, post, put};
@ -23,10 +24,14 @@ use axum::{Json, Router};
use clap::Args;
use futures::stream::Stream;
use itertools::Itertools;
use log::{info, warn};
use log::{error, info, warn};
use miette::miette;
use rand::Rng;
use serde_json::json;
use tokio::task::spawn_blocking;
use tower_http::auth::RequireAuthorizationLayer;
use tower_http::compression::CompressionLayer;
use tower_http::cors::{Any, CorsLayer};
use cozo::{
format_error_as_json, DataValue, DbInstance, MultiTransaction, NamedRows, SimpleFixedRule,
@ -42,9 +47,10 @@ pub(crate) struct ServerArgs {
#[clap(short, long, default_value_t = String::from("cozo.db"))]
path: String,
// Restore from the specified backup before starting the server
// #[clap(long)]
// restore: Option<String>,
/// Restore from the specified backup before starting the server
#[clap(long)]
restore: Option<String>,
/// Extra config in JSON format
#[clap(short, long, default_value_t = String::from("{}"))]
config: String,
@ -71,7 +77,35 @@ struct DbState {
}
pub(crate) async fn server_main(args: ServerArgs) {
let db = DbInstance::new(&args.engine, args.path, &args.config).unwrap();
let db = DbInstance::new(&args.engine, &args.path, &args.config).unwrap();
if let Some(p) = &args.restore {
if let Err(err) = db.restore_backup(p) {
error!("{}", err);
error!("Restore from backup failed, terminate");
panic!()
}
}
let skip_auth = args.bind == "127.0.0.1";
let conf_path = if skip_auth {"".to_string()} else { format!("{}.{}.cozo_auth", args.path, args.engine)};
let auth_guard = if skip_auth {
"".to_string()
} else {
match tokio::fs::read_to_string(&conf_path).await {
Ok(s) => s.trim().to_string(),
Err(_) => {
let s = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(64)
.map(char::from)
.collect();
tokio::fs::write(&conf_path, &s).await.unwrap();
s
}
}
};
let state = DbState {
db,
rule_senders: Default::default(),
@ -79,9 +113,11 @@ pub(crate) async fn server_main(args: ServerArgs) {
tx_counter: Default::default(),
txs: Default::default(),
};
let cors = CorsLayer::new()
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
.allow_origin(Any);
let app = Router::new()
.fallback(not_found)
.route("/", get(root))
.route("/text-query", post(text_query))
.route("/export/:relations", get(export_relations))
.route("/import", put(import_relations))
@ -94,8 +130,54 @@ pub(crate) async fn server_main(args: ServerArgs) {
post(post_rule_result).delete(post_rule_err),
) // +keep alive
.route("/transact", post(start_transact))
.route("/transact/:id", post(transact_query).put(finish_query)) // +keep alive
.with_state(state);
.route("/transact/:id", post(transact_query).put(finish_query))
.with_state(state)
.layer(RequireAuthorizationLayer::custom(
move |request: &mut Request<Body>| {
if skip_auth {
return Ok(());
}
let ok = match request.headers().get("x-cozo-auth") {
None => match request.uri().query() {
None => false,
Some(q_str) => {
let mut bingo = false;
for pair in q_str.split('&') {
if let Some((k, v)) = pair.split_once('=') {
if k == "auth" {
if v == &auth_guard {
bingo = true
}
break;
}
}
}
bingo
}
},
Some(data) => match data.to_str() {
Ok(s) => s == &auth_guard,
Err(_) => false,
},
};
if ok {
Ok(())
} else {
let unauthorized_response = Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(BoxBody::default())
.unwrap();
Err(unauthorized_response.into())
}
},
))
.fallback(not_found)
.route("/", get(root))
.layer(cors)
.layer(CompressionLayer::new());
let addr = if Ipv6Addr::from_str(&args.bind).is_ok() {
SocketAddr::from_str(&format!("[{}]:{}", args.bind, args.port)).unwrap()
} else {
@ -104,6 +186,7 @@ pub(crate) async fn server_main(args: ServerArgs) {
if args.bind != "127.0.0.1" {
warn!("{}", include_str!("./security.txt"));
info!("The auth token is in the file: {conf_path}");
}
info!(

Loading…
Cancel
Save