diff --git a/Cargo.lock b/Cargo.lock index 87edfd8e..26568895 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -49,6 +49,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "android_system_properties" version = "0.1.5" @@ -79,6 +94,20 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" +[[package]] +name = "async-compression" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "942c7cd7ae39e91bde4820d74132e9862e62c2f386c3aa90ccf55949f5bad63a" +dependencies = [ + "brotli", + "flate2", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", +] + [[package]] name = "async-stream" version = "0.3.3" @@ -284,6 +313,27 @@ dependencies = [ "cmake", ] +[[package]] +name = "brotli" +version = "3.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1a0b1dbcc8ae29329621f8d4f0d835787c1c38bb1401979b49d13b0b305ff68" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "2.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b6561fd3f895a11e8f72af2cb7d22e08366bebc2b6b57f7744c4bda27034744" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bstr" version = "0.2.17" @@ -369,7 +419,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f4aedb84272dbe89af497cf81375129abda4fc0a9e7c5d317498c15cc30c0d27" dependencies = [ - "nom", + "nom 5.1.2", ] [[package]] @@ -636,6 +686,7 @@ dependencies = [ "serde_derive", "serde_json", "tokio", + "tower-http", ] [[package]] @@ -1641,6 +1692,15 @@ version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30e22bd8629359895450b59ea7a776c850561b96a3b1d31321c1949d9e6c9146" +[[package]] +name = "iri-string" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f0f7638c1e223529f1bfdc48c8b133b9e0b434094d1d28473161ee48b235f78" +dependencies = [ + "nom 7.1.3", +] + [[package]] name = "is-terminal" version = "0.4.2" @@ -1918,6 +1978,22 @@ version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" +[[package]] +name = "mime_guess" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +dependencies = [ + "mime", + "unicase", +] + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.6.2" @@ -2065,6 +2141,16 @@ dependencies = [ "version_check", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "num" version = "0.4.0" @@ -3855,6 +3941,8 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f873044bf02dd1e8239e9c1293ea39dad76dc594ec16185d0a1bf31d8dc8d858" dependencies = [ + "async-compression", + "base64 0.13.1", "bitflags", "bytes", "futures-core", @@ -3862,10 +3950,19 @@ dependencies = [ "http", "http-body", "http-range-header", + "httpdate", + "iri-string", + "mime", + "mime_guess", + "percent-encoding", "pin-project-lite", + "tokio", + "tokio-util", "tower", "tower-layer", "tower-service", + "tracing", + "uuid", ] [[package]] @@ -3919,6 +4016,15 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81" +[[package]] +name = "unicase" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6" +dependencies = [ + "version_check", +] + [[package]] name = "unicode-bidi" version = "0.3.8" diff --git a/cozo-bin/Cargo.toml b/cozo-bin/Cargo.toml index 4310cc6a..d544023a 100644 --- a/cozo-bin/Cargo.toml +++ b/cozo-bin/Cargo.toml @@ -63,4 +63,5 @@ tokio = { version = "1.24.1", features = ["full"] } async-stream = "0.3.3" futures = "0.3.25" crossbeam = "0.8.2" -eventsource-client = "0.11.0" \ No newline at end of file +eventsource-client = "0.11.0" +tower-http = { version = "0.3.5", features = ["full"] } \ No newline at end of file diff --git a/cozo-bin/README.md b/cozo-bin/README.md index 75d28eea..4a0195aa 100644 --- a/cozo-bin/README.md +++ b/cozo-bin/README.md @@ -70,7 +70,7 @@ and a nicely-formatted diagnostic will be in `"display"` if available. > As a guard against users accidentally exposing sensitive data, > If you bind Cozo to non-loopback addresses, > Cozo will generate a token string and require all queries from non-loopback addresses -> to provide the token string in the HTTP header field x-cozo-auth. +> to provide the token string in the HTTP header field `x-cozo-auth`. > The warning printed when you start Cozo with a > non-default binding will tell you where to find the token string. > This “security measure” is not considered sufficient for any purpose diff --git a/cozo-bin/src/main.rs b/cozo-bin/src/main.rs index babf597f..ba1901b6 100644 --- a/cozo-bin/src/main.rs +++ b/cozo-bin/src/main.rs @@ -6,22 +6,16 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ -use std::fmt::Debug; -use std::fs; +extern crate core; + use std::process::exit; -use clap::{Args, Parser, Subcommand}; +use clap::{Parser, Subcommand}; use env_logger::Env; -use log::{error, info}; -use crate::repl::repl_main; +use crate::repl::{repl_main, ReplArgs}; use crate::server::{server_main, ServerArgs}; -// use rand::Rng; -// use serde_json::json; - -// use cozo::*; - mod client; mod repl; mod server; @@ -36,131 +30,37 @@ struct AppArgs { #[derive(Subcommand)] enum Commands { - /// Server(ServerArgs), - Client(ClientArgs), Repl(ReplArgs), - Restore(RestoreArgs), - Stream(StreamArgs), } -#[derive(Args, Debug)] -struct ReplArgs { - /// Database engine, can be `mem`, `sqlite`, `rocksdb` and others. - #[clap(short, long, default_value_t = String::from("mem"))] - engine: String, - - /// Path to the directory to store the database - #[clap(short, long, default_value_t = String::from("cozo.db"))] - path: String, - - /// Extra config in JSON format - #[clap(short, long, default_value_t = String::from("{}"))] - config: String, -} - -#[derive(Args, Debug)] -struct ClientArgs { - #[clap(default_value_t = String::from("http://127.0.0.1:9070"))] - address: String, - #[clap(short, long, default_value_t = String::from(""))] - auth: String, -} - -#[derive(Args, Debug)] -struct StreamArgs {} - -#[derive(Args, Debug)] -struct RestoreArgs { - /// Path of the backup file to restore from, must be a SQLite-backed backup file. - from: String, - /// Path of the database to restore into - to: String, - /// Database engine for the database to restore into, can be `mem`, `sqlite`, `rocksdb` and others. - #[clap(short, long)] - engine: String, -} - -// macro_rules! check_auth { -// ($request:expr, $auth_guard:expr) => { -// match $request.header("x-cozo-auth") { -// None => return Response::text("Unauthorized").with_status_code(401), -// Some(code) => { -// if $auth_guard != code { -// return Response::text("Unauthorized").with_status_code(401); -// } -// } -// } -// }; -// } - fn main() { - let args = match AppArgs::parse().command { - Commands::Server(s) => s, - Commands::Client(_) => { - todo!() - } - Commands::Repl(_) => { - todo!() + match AppArgs::parse().command { + Commands::Server(args) => { + env_logger::Builder::from_env(Env::default().default_filter_or("info")).init(); + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on(server_main(args)) } - Commands::Restore(_) => { - todo!() - } - Commands::Stream(_) => { - todo!() + Commands::Repl(args) => { + if let Err(e) = repl_main(args) { + eprintln!("{e}"); + exit(-1); + } } }; - // if let Some(restore_path) = &args.restore { - // db.restore_backup(restore_path).unwrap(); - // } - // if args.repl { - // let db_copy = db.clone(); - // ctrlc::set_handler(move || { - // let running = db_copy - // .run_script("::running", Default::default()) - // .expect("Cannot determine running queries"); - // for row in running.rows { - // let id = row.into_iter().next().unwrap(); - // eprintln!("Killing running query {id}"); - // db_copy - // .run_script("::kill $id", BTreeMap::from([("id".to_string(), id)])) - // .expect("Cannot kill process"); - // } - // }) - // .expect("Error setting Ctrl-C handler"); - // - // if let Err(e) = repl_main(db) { - // eprintln!("{e}"); - // exit(-1); - // } + // } else { - env_logger::Builder::from_env(Env::default().default_filter_or("info")).init(); - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap() - .block_on(server_main(args)) // server_main(args, db) // } } // fn server_main(args: Server, db: DbInstance) { -// let conf_path = format!("{}.{}.cozo_auth", args.path, args.engine); -// let auth_guard = match fs::read_to_string(&conf_path) { -// Ok(s) => s.trim().to_string(), -// Err(_) => { -// let s = rand::thread_rng() -// .sample_iter(&rand::distributions::Alphanumeric) -// .take(64) -// .map(char::from) -// .collect(); -// fs::write(&conf_path, &s).unwrap(); -// s -// } -// }; // // let addr = if Ipv6Addr::from_str(&args.bind).is_ok() { // format!("[{}]:{}", args.bind, args.port) diff --git a/cozo-bin/src/repl.rs b/cozo-bin/src/repl.rs index 6b0a4e5d..88203e04 100644 --- a/cozo-bin/src/repl.rs +++ b/cozo-bin/src/repl.rs @@ -13,6 +13,7 @@ use std::error::Error; use std::fs::File; use std::io::{Read, Write}; +use clap::Args; use miette::{bail, miette, IntoDiagnostic}; use serde_json::{json, Value}; @@ -57,7 +58,39 @@ impl rustyline::validate::Validator for Indented { } } -pub(crate) fn repl_main(db: DbInstance) -> Result<(), Box> { +#[derive(Args, Debug)] +pub(crate) struct ReplArgs { + /// Database engine, can be `mem`, `sqlite`, `rocksdb` and others. + #[clap(short, long, default_value_t = String::from("mem"))] + engine: String, + + /// Path to the directory to store the database + #[clap(short, long, default_value_t = String::from("cozo.db"))] + path: String, + + /// Extra config in JSON format + #[clap(short, long, default_value_t = String::from("{}"))] + config: String, +} + +pub(crate) fn repl_main(args: ReplArgs) -> Result<(), Box> { + let db = DbInstance::new(&args.engine, args.path, &args.config).unwrap(); + + let db_copy = db.clone(); + ctrlc::set_handler(move || { + let running = db_copy + .run_script("::running", Default::default()) + .expect("Cannot determine running queries"); + for row in running.rows { + let id = row.into_iter().next().unwrap(); + eprintln!("Killing running query {id}"); + db_copy + .run_script("::kill $id", BTreeMap::from([("id".to_string(), id)])) + .expect("Cannot kill process"); + } + }) + .expect("Error setting Ctrl-C handler"); + println!("Welcome to the Cozo REPL."); println!("Type a space followed by newline to enter multiline mode."); diff --git a/cozo-bin/src/security.txt b/cozo-bin/src/security.txt index a7c8860c..8ba79998 100644 --- a/cozo-bin/src/security.txt +++ b/cozo-bin/src/security.txt @@ -3,10 +3,15 @@ !! SECURITY NOTICE, PLEASE READ !! ==================================================================================== You instructed Cozo to bind to a non-default address. + Cozo is designed to be accessed by trusted clients in a trusted network. As a last defense against unauthorized access when everything else fails, -any requests from non-loopback addresses require the HTTP request header -`x-cozo-auth` to be set to the content of auth.txt in your database directory. +any requests now require the HTTP request header `x-cozo-auth` to be set, +or the query parameter `auth=` be set to the auth token. +The auth token is found in a file indicated below. + +This is required even if the request comes from localhost. + This is not a sufficient protection against attacks, and you must set up proper authentication schemes, encryptions, etc. by firewalls and/or proxies. ==================================================================================== diff --git a/cozo-bin/src/server.rs b/cozo-bin/src/server.rs index 21800066..02283e53 100644 --- a/cozo-bin/src/server.rs +++ b/cozo-bin/src/server.rs @@ -14,8 +14,9 @@ use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, Mutex}; use std::thread; +use axum::body::{Body, BoxBody}; use axum::extract::{Path, Query, State}; -use axum::http::StatusCode; +use axum::http::{Method, Request, Response, StatusCode}; use axum::response::sse::{Event, KeepAlive}; use axum::response::{Html, Sse}; use axum::routing::{get, post, put}; @@ -23,10 +24,14 @@ use axum::{Json, Router}; use clap::Args; use futures::stream::Stream; use itertools::Itertools; -use log::{info, warn}; +use log::{error, info, warn}; use miette::miette; +use rand::Rng; use serde_json::json; use tokio::task::spawn_blocking; +use tower_http::auth::RequireAuthorizationLayer; +use tower_http::compression::CompressionLayer; +use tower_http::cors::{Any, CorsLayer}; use cozo::{ format_error_as_json, DataValue, DbInstance, MultiTransaction, NamedRows, SimpleFixedRule, @@ -42,9 +47,10 @@ pub(crate) struct ServerArgs { #[clap(short, long, default_value_t = String::from("cozo.db"))] path: String, - // Restore from the specified backup before starting the server - // #[clap(long)] - // restore: Option, + /// Restore from the specified backup before starting the server + #[clap(long)] + restore: Option, + /// Extra config in JSON format #[clap(short, long, default_value_t = String::from("{}"))] config: String, @@ -71,7 +77,35 @@ struct DbState { } pub(crate) async fn server_main(args: ServerArgs) { - let db = DbInstance::new(&args.engine, args.path, &args.config).unwrap(); + let db = DbInstance::new(&args.engine, &args.path, &args.config).unwrap(); + if let Some(p) = &args.restore { + if let Err(err) = db.restore_backup(p) { + error!("{}", err); + error!("Restore from backup failed, terminate"); + panic!() + } + } + + let skip_auth = args.bind == "127.0.0.1"; + + let conf_path = if skip_auth {"".to_string()} else { format!("{}.{}.cozo_auth", args.path, args.engine)}; + let auth_guard = if skip_auth { + "".to_string() + } else { + match tokio::fs::read_to_string(&conf_path).await { + Ok(s) => s.trim().to_string(), + Err(_) => { + let s = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(64) + .map(char::from) + .collect(); + tokio::fs::write(&conf_path, &s).await.unwrap(); + s + } + } + }; + let state = DbState { db, rule_senders: Default::default(), @@ -79,9 +113,11 @@ pub(crate) async fn server_main(args: ServerArgs) { tx_counter: Default::default(), txs: Default::default(), }; + let cors = CorsLayer::new() + .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]) + .allow_origin(Any); + let app = Router::new() - .fallback(not_found) - .route("/", get(root)) .route("/text-query", post(text_query)) .route("/export/:relations", get(export_relations)) .route("/import", put(import_relations)) @@ -94,8 +130,54 @@ pub(crate) async fn server_main(args: ServerArgs) { post(post_rule_result).delete(post_rule_err), ) // +keep alive .route("/transact", post(start_transact)) - .route("/transact/:id", post(transact_query).put(finish_query)) // +keep alive - .with_state(state); + .route("/transact/:id", post(transact_query).put(finish_query)) + .with_state(state) + .layer(RequireAuthorizationLayer::custom( + move |request: &mut Request| { + if skip_auth { + return Ok(()); + } + + let ok = match request.headers().get("x-cozo-auth") { + None => match request.uri().query() { + None => false, + Some(q_str) => { + let mut bingo = false; + for pair in q_str.split('&') { + if let Some((k, v)) = pair.split_once('=') { + if k == "auth" { + if v == &auth_guard { + bingo = true + } + break; + } + } + } + bingo + } + }, + Some(data) => match data.to_str() { + Ok(s) => s == &auth_guard, + Err(_) => false, + }, + }; + if ok { + Ok(()) + } else { + let unauthorized_response = Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(BoxBody::default()) + .unwrap(); + + Err(unauthorized_response.into()) + } + }, + )) + .fallback(not_found) + .route("/", get(root)) + .layer(cors) + .layer(CompressionLayer::new()); + let addr = if Ipv6Addr::from_str(&args.bind).is_ok() { SocketAddr::from_str(&format!("[{}]:{}", args.bind, args.port)).unwrap() } else { @@ -104,6 +186,7 @@ pub(crate) async fn server_main(args: ServerArgs) { if args.bind != "127.0.0.1" { warn!("{}", include_str!("./security.txt")); + info!("The auth token is in the file: {conf_path}"); } info!(