use axum for server; implement callback SSE

main
Ziyang Hu 2 years ago
parent b30ebf7b77
commit 0b0830d1ba

436
Cargo.lock generated

@ -17,12 +17,6 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "adler32"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234"
[[package]]
name = "ahash"
version = "0.7.6"
@ -55,21 +49,6 @@ dependencies = [
"memchr",
]
[[package]]
name = "alloc-no-stdlib"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3"
[[package]]
name = "alloc-stdlib"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece"
dependencies = [
"alloc-no-stdlib",
]
[[package]]
name = "android_system_properties"
version = "0.1.5"
@ -101,10 +80,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6"
[[package]]
name = "ascii"
version = "1.1.0"
name = "async-stream"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16"
checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e"
dependencies = [
"async-stream-impl",
"futures-core",
]
[[package]]
name = "async-stream-impl"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "async-trait"
@ -158,6 +152,68 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1304eab461cf02bd70b083ed8273388f9724c549b316ba3d1e213ce0e9e7fb7e"
dependencies = [
"async-trait",
"axum-core",
"bitflags",
"bytes",
"futures-util",
"http",
"http-body",
"hyper",
"itoa 1.0.5",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower-http",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f487e40dc9daee24d8a1779df88522f159a54a980f99cfbe43db0be0bd3444a8"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http",
"http-body",
"mime",
"rustversion",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-macros"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc7d7c3e69f305217e317a28172aab29f275667f2e1c15b87451e134fe27c7b1"
dependencies = [
"heck 0.4.0",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "backtrace"
version = "0.3.67"
@ -228,27 +284,6 @@ dependencies = [
"cmake",
]
[[package]]
name = "brotli"
version = "3.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1a0b1dbcc8ae29329621f8d4f0d835787c1c38bb1401979b49d13b0b305ff68"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
"brotli-decompressor",
]
[[package]]
name = "brotli-decompressor"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59ad2d4653bf5ca36ae797b1f4bb4dbddb60ce49ca4aed8a2ce4829f60425b80"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
]
[[package]]
name = "bstr"
version = "0.2.17"
@ -261,16 +296,6 @@ dependencies = [
"serde",
]
[[package]]
name = "buf_redux"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b953a6887648bb07a535631f2bc00fbdb2a2216f135552cb3f534ed136b9c07f"
dependencies = [
"memchr",
"safemem",
]
[[package]]
name = "bumpalo"
version = "3.11.1"
@ -369,7 +394,7 @@ dependencies = [
"js-sys",
"num-integer",
"num-traits",
"time 0.1.45",
"time",
"wasm-bindgen",
"winapi",
]
@ -396,12 +421,6 @@ dependencies = [
"phf_codegen",
]
[[package]]
name = "chunked_transfer"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cca491388666e04d7248af3f60f0c40cfb0991c72205595d7c396e3510207d1a"
[[package]]
name = "clang-sys"
version = "1.4.0"
@ -595,21 +614,27 @@ dependencies = [
name = "cozo-bin"
version = "0.5.0"
dependencies = [
"async-stream",
"axum",
"axum-macros",
"chrono",
"clap 4.0.32",
"cozo",
"crossbeam",
"ctrlc",
"env_logger",
"futures 0.3.25",
"itertools 0.10.5",
"log",
"miette",
"minreq",
"prettytable",
"rand 0.8.5",
"rouille",
"rustyline",
"serde",
"serde_derive",
"serde_json",
"tokio",
]
[[package]]
@ -856,16 +881,6 @@ dependencies = [
"syn",
]
[[package]]
name = "deflate"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c86f7e25f518f4b81808a2cf1c50996a61f5c2eb394b2393bd87f2a4780a432f"
dependencies = [
"adler32",
"gzip-header",
]
[[package]]
name = "delegate"
version = "0.8.0"
@ -1036,18 +1051,6 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "filetime"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e884668cd0c7480504233e951174ddc3b382f7c2666e3b7310b5c4e7b0c37f9"
dependencies = [
"cfg-if 1.0.0",
"libc",
"redox_syscall",
"windows-sys",
]
[[package]]
name = "fixedbitset"
version = "0.2.0"
@ -1358,15 +1361,6 @@ dependencies = [
"walkdir",
]
[[package]]
name = "gzip-header"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95cc527b92e6029a62960ad99aa8a6660faa4555fe5f731aab13aa6a921795a2"
dependencies = [
"crc32fast",
]
[[package]]
name = "h2"
version = "0.3.15"
@ -1456,6 +1450,12 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "http-range-header"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
[[package]]
name = "httparse"
version = "1.8.0"
@ -1781,6 +1781,12 @@ dependencies = [
"libc",
]
[[package]]
name = "matchit"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
[[package]]
name = "memchr"
version = "2.5.0"
@ -1857,16 +1863,6 @@ version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
[[package]]
name = "mime_guess"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "miniz_oxide"
version = "0.6.2"
@ -1907,24 +1903,6 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
[[package]]
name = "multipart"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00dec633863867f29cb39df64a397cdf4a6354708ddd7759f70c7fb51c5f9182"
dependencies = [
"buf_redux",
"httparse",
"log",
"mime",
"mime_guess",
"quick-error",
"rand 0.8.5",
"safemem",
"tempfile",
"twoway",
]
[[package]]
name = "nanorand"
version = "0.7.0"
@ -2128,15 +2106,6 @@ dependencies = [
"libc",
]
[[package]]
name = "num_threads"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44"
dependencies = [
"libc",
]
[[package]]
name = "object"
version = "0.30.2"
@ -2415,6 +2384,26 @@ dependencies = [
"siphasher",
]
[[package]]
name = "pin-project"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "pin-project-lite"
version = "0.2.9"
@ -2724,12 +2713,6 @@ dependencies = [
"syn",
]
[[package]]
name = "quick-error"
version = "1.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
[[package]]
name = "quote"
version = "1.0.23"
@ -2978,31 +2961,6 @@ dependencies = [
"rmp",
]
[[package]]
name = "rouille"
version = "3.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f86e4c51a773f953f02bbab5fd049f004bfd384341d62da2a079aff812ab176"
dependencies = [
"base64 0.13.1",
"brotli",
"chrono",
"deflate",
"filetime",
"multipart",
"num_cpus",
"percent-encoding",
"rand 0.8.5",
"serde",
"serde_derive",
"serde_json",
"sha1",
"threadpool",
"time 0.3.17",
"tiny_http",
"url",
]
[[package]]
name = "rustc-demangle"
version = "0.1.21"
@ -3076,12 +3034,6 @@ version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde"
[[package]]
name = "safemem"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef703b7cb59335eae2eb93ceb664c0eb7ea6bf567079d843e09420219668e072"
[[package]]
name = "same-file"
version = "1.0.6"
@ -3206,6 +3158,15 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341"
dependencies = [
"serde",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"
@ -3218,17 +3179,6 @@ dependencies = [
"serde",
]
[[package]]
name = "sha1"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
dependencies = [
"cfg-if 1.0.0",
"cpufeatures",
"digest",
]
[[package]]
name = "sha2"
version = "0.10.6"
@ -3246,6 +3196,15 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fdf1b9db47230893d76faad238fd6097fd6d6a9245cd7a4d90dbd639536bbd2"
[[package]]
name = "signal-hook-registry"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0"
dependencies = [
"libc",
]
[[package]]
name = "siphasher"
version = "0.3.10"
@ -3463,6 +3422,12 @@ dependencies = [
"syn",
]
[[package]]
name = "sync_wrapper"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8"
[[package]]
name = "target-lexicon"
version = "0.12.5"
@ -3550,15 +3515,6 @@ dependencies = [
"syn",
]
[[package]]
name = "threadpool"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa"
dependencies = [
"num_cpus",
]
[[package]]
name = "tikv-client"
version = "0.1.0"
@ -3687,36 +3643,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "time"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a561bf4617eebd33bca6434b988f39ed798e527f51a1e797d0ee4f61c0a38376"
dependencies = [
"libc",
"num_threads",
"serde",
"time-core",
]
[[package]]
name = "time-core"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd"
[[package]]
name = "tiny_http"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "389915df6413a2e74fb181895f933386023c71110878cd0825588928e64cdc82"
dependencies = [
"ascii",
"chunked_transfer",
"httpdate",
"log",
]
[[package]]
name = "tinyvec"
version = "1.6.0"
@ -3744,11 +3670,25 @@ dependencies = [
"memchr",
"mio",
"num_cpus",
"parking_lot 0.12.1",
"pin-project-lite",
"signal-hook-registry",
"socket2",
"tokio-macros",
"windows-sys",
]
[[package]]
name = "tokio-macros"
version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.0"
@ -3782,6 +3722,47 @@ dependencies = [
"serde",
]
[[package]]
name = "tower"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [
"futures-core",
"futures-util",
"pin-project",
"pin-project-lite",
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-http"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f873044bf02dd1e8239e9c1293ea39dad76dc594ec16185d0a1bf31d8dc8d858"
dependencies = [
"bitflags",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-range-header",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-layer"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
[[package]]
name = "tower-service"
version = "0.3.2"
@ -3795,6 +3776,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
dependencies = [
"cfg-if 1.0.0",
"log",
"pin-project-lite",
"tracing-core",
]
@ -3814,15 +3796,6 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
[[package]]
name = "twoway"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59b11b2b5241ba34be09c3cc85a36e56e48f9888862e19cedf23336d35316ed1"
dependencies = [
"memchr",
]
[[package]]
name = "typenum"
version = "1.16.0"
@ -3835,15 +3808,6 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81"
[[package]]
name = "unicase"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6"
dependencies = [
"version_check",
]
[[package]]
name = "unicode-bidi"
version = "0.3.8"

@ -44,7 +44,6 @@ storage-tikv = ["cozo/storage-tikv"]
[dependencies]
cozo = { version = "0.5.0", path = "../cozo-core", default-features = false }
clap = { version = "4.0.26", features = ["derive"] }
rouille = "3.5.0"
env_logger = "0.10.0"
log = "0.4.17"
rand = "0.8.5"
@ -57,3 +56,10 @@ rustyline = "10.0.0"
minreq = { version = "2.6.0", features = ["https-rustls"] }
miette = { version = "5.5.0", features = ["fancy"] }
ctrlc = "3.2.4"
axum = "0.6.2"
axum-macros = "0.3.1"
itertools = "0.10.5"
tokio = { version = "1.24.1", features = ["full"] }
async-stream = "0.3.3"
futures = "0.3.25"
crossbeam = "0.8.2"

@ -0,0 +1,8 @@
/*
* Copyright 2023, The Cozo Project Authors.
*
* This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
* If a copy of the MPL was not distributed with this file,
* You can obtain one at https://mozilla.org/MPL/2.0/.
*/

@ -0,0 +1,52 @@
<!DOCTYPE html>
<html lang="en">
<head>
<link rel="icon" href="data:;base64,iVBORw0KGgo=">
<title>Cozo database</title>
</head>
<body>
<p>Cozo API is running.</p>
<script>
let COZO_AUTH = '';
let LAST_RESP = null;
async function run(script, params) {
const resp = await fetch('/text-query', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-cozo-auth': COZO_AUTH
},
body: JSON.stringify({
script,
params: params || {}
})
});
if (resp.ok) {
const json_resp = await resp.json();
LAST_RESP = json_resp;
if (json_resp) {
json_resp.headers ||= [];
console.table(json_resp.rows.map(row => {
let ret = {};
for (let i = 0; i < row.length; ++i) {
ret[json_resp.headers[i] || `(${i})`] = row[i];
}
return ret
}))
}
} else {
console.error((await resp.json()).display)
}
}
console.log(
`Welcome to the Cozo Makeshift Javascript Console!
You can run your query like this:
await run("YOUR QUERY HERE", {param: value})
The global variables 'COZO_AUTH' and 'LAST_RESP' are available.`);
</script>
</body>
</html>

@ -6,25 +6,25 @@
* You can obtain one at https://mozilla.org/MPL/2.0/.
*/
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::fs;
use std::net::Ipv6Addr;
use std::process::exit;
use std::str::FromStr;
use clap::{Args, Parser, Subcommand};
use env_logger::Env;
use log::{error, info};
use rand::Rng;
use rouille::{router, try_or_400, Request, Response};
use serde_json::json;
use cozo::*;
use crate::repl::repl_main;
use crate::server::{server_main, ServerArgs};
// use rand::Rng;
// use serde_json::json;
// use cozo::*;
mod client;
mod repl;
mod server;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
@ -37,14 +37,15 @@ struct AppArgs {
#[derive(Subcommand)]
enum Commands {
///
Server(Server),
Client(Client),
Repl(Repl),
Restore(Restore),
Server(ServerArgs),
Client(ClientArgs),
Repl(ReplArgs),
Restore(RestoreArgs),
Stream(StreamArgs),
}
#[derive(Args, Debug)]
struct Repl {
struct ReplArgs {
/// Database engine, can be `mem`, `sqlite`, `rocksdb` and others.
#[clap(short, long, default_value_t = String::from("mem"))]
engine: String,
@ -59,7 +60,7 @@ struct Repl {
}
#[derive(Args, Debug)]
struct Client {
struct ClientArgs {
#[clap(default_value_t = String::from("http://127.0.0.1:9070"))]
address: String,
#[clap(short, long, default_value_t = String::from(""))]
@ -67,7 +68,10 @@ struct Client {
}
#[derive(Args, Debug)]
struct Restore {
struct StreamArgs {}
#[derive(Args, Debug)]
struct RestoreArgs {
/// Path of the backup file to restore from, must be a SQLite-backed backup file.
from: String,
/// Path of the database to restore into
@ -77,47 +81,18 @@ struct Restore {
engine: String,
}
#[derive(Args, Debug)]
struct Server {
/// Database engine, can be `mem`, `sqlite`, `rocksdb` and others.
#[clap(short, long, default_value_t = String::from("mem"))]
engine: String,
/// Path to the directory to store the database
#[clap(short, long, default_value_t = String::from("cozo.db"))]
path: String,
// Restore from the specified backup before starting the server
// #[clap(long)]
// restore: Option<String>,
/// Extra config in JSON format
#[clap(short, long, default_value_t = String::from("{}"))]
config: String,
// When on, start REPL instead of starting a webserver
// #[clap(short, long)]
// repl: bool,
/// Address to bind the service to
#[clap(short, long, default_value_t = String::from("127.0.0.1"))]
bind: String,
/// Port to use
#[clap(short = 'P', long, default_value_t = 9070)]
port: u16,
}
macro_rules! check_auth {
($request:expr, $auth_guard:expr) => {
match $request.header("x-cozo-auth") {
None => return Response::text("Unauthorized").with_status_code(401),
Some(code) => {
if $auth_guard != code {
return Response::text("Unauthorized").with_status_code(401);
}
}
}
};
}
// macro_rules! check_auth {
// ($request:expr, $auth_guard:expr) => {
// match $request.header("x-cozo-auth") {
// None => return Response::text("Unauthorized").with_status_code(401),
// Some(code) => {
// if $auth_guard != code {
// return Response::text("Unauthorized").with_status_code(401);
// }
// }
// }
// };
// }
fn main() {
let args = match AppArgs::parse().command {
@ -131,17 +106,10 @@ fn main() {
Commands::Restore(_) => {
todo!()
}
};
if args.bind != "127.0.0.1" {
eprintln!("{SECURITY_WARNING}");
Commands::Stream(_) => {
todo!()
}
let db = DbInstance::new(
args.engine.as_str(),
args.path.as_str(),
&args.config.clone(),
)
.unwrap();
};
// if let Some(restore_path) = &args.restore {
// db.restore_backup(restore_path).unwrap();
@ -169,233 +137,54 @@ fn main() {
// }
// } else {
env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();
server_main(args, db)
// }
}
fn server_main(args: Server, db: DbInstance) {
let conf_path = format!("{}.{}.cozo_auth", args.path, args.engine);
let auth_guard = match fs::read_to_string(&conf_path) {
Ok(s) => s.trim().to_string(),
Err(_) => {
let s = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(64)
.map(char::from)
.collect();
fs::write(&conf_path, &s).unwrap();
s
}
};
let addr = if Ipv6Addr::from_str(&args.bind).is_ok() {
format!("[{}]:{}", args.bind, args.port)
} else {
format!("{}:{}", args.bind, args.port)
};
println!(
"Database ({} backend) web API running at http://{}",
args.engine, addr
);
println!("The auth file is at {conf_path}");
rouille::start_server(addr, move |request| {
let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S%.6f");
let log_ok = |req: &Request, _resp: &Response, elap: std::time::Duration| {
info!("{} {} {} {:?}", now, req.method(), req.raw_url(), elap);
};
let log_err = |req: &Request, elap: std::time::Duration| {
error!(
"{} Handler panicked: {} {} {:?}",
now,
req.method(),
req.raw_url(),
elap
);
};
rouille::log_custom(request, log_ok, log_err, || {
router!(request,
(POST) (/text-query) => {
if !request.remote_addr().ip().is_loopback() {
check_auth!(request, auth_guard);
}
#[derive(serde_derive::Serialize, serde_derive::Deserialize)]
struct QueryPayload {
script: String,
params: BTreeMap<String, serde_json::Value>,
}
let payload: QueryPayload = try_or_400!(rouille::input::json_input(request));
let params = payload.params.into_iter().map(|(k, v)|
(k, DataValue::from(v))).collect();
let result = db.run_script_fold_err(&payload.script, params);
let response = Response::json(&result);
if let Some(serde_json::Value::Bool(true)) = result.get("ok") {
response
} else {
response.with_status_code(400)
}
},
(GET) (/export/{relations: String}) => {
if !request.remote_addr().ip().is_loopback() {
check_auth!(request, auth_guard);
}
let relations = relations.split(',').filter(|t| !t.is_empty());
let result = db.export_relations(relations);
match result {
Ok(s) => {
let ret = json!({"ok": true, "data": s});
Response::json(&ret)
}
Err(err) => {
let ret = json!({"ok": false, "message": err.to_string()});
Response::json(&ret).with_status_code(400)
}
}
},
(PUT) (/import) => {
if !request.remote_addr().ip().is_loopback() {
check_auth!(request, auth_guard);
}
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(server_main(args))
let payload: BTreeMap<String, NamedRows> = try_or_400!(rouille::input::json_input(request));
let result = db.import_relations(payload);
match result {
Ok(()) => {
let ret = json!({"ok": true});
Response::json(&ret)
}
Err(err) => {
let ret = json!({"ok": false, "message": err.to_string()});
Response::json(&ret).with_status_code(400)
}
}
},
(POST) (/backup) => {
if !request.remote_addr().ip().is_loopback() {
check_auth!(request, auth_guard);
}
#[derive(serde_derive::Serialize, serde_derive::Deserialize)]
struct BackupPayload {
path: String,
}
let payload: BackupPayload = try_or_400!(rouille::input::json_input(request));
let result = db.backup_db(payload.path);
match result {
Ok(()) => {
let ret = json!({"ok": true});
Response::json(&ret)
}
Err(err) => {
let ret = json!({"ok": false, "message": err.to_string()});
Response::json(&ret).with_status_code(400)
}
}
},
(POST) (/import-from-backup) => {
if !request.remote_addr().ip().is_loopback() {
check_auth!(request, auth_guard);
}
#[derive(serde_derive::Serialize, serde_derive::Deserialize)]
struct BackupImportPayload {
path: String,
relations: Vec<String>
}
let payload: BackupImportPayload = try_or_400!(rouille::input::json_input(request));
let result = db.import_from_backup(&payload.path, &payload.relations);
match result {
Ok(()) => {
let ret = json!({"ok": true});
Response::json(&ret)
}
Err(err) => {
let ret = json!({"ok": false, "message": err.to_string()});
Response::json(&ret).with_status_code(400)
}
}
},
(GET) (/) => {
Response::html(HTML_CONTENT)
},
_ => Response::empty_404()
)
})
});
}
const HTML_CONTENT: &str = r##"
<!DOCTYPE html>
<html lang="en">
<head>
<link rel="icon" href="data:;base64,iVBORw0KGgo=">
<title>Cozo database</title>
</head>
<body>
<p>Cozo HTTP server is running.</p>
<script>
let COZO_AUTH = '';
let LAST_RESP = null;
async function run(script, params) {
const resp = await fetch('/text-query', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-cozo-auth': COZO_AUTH
},
body: JSON.stringify({
script,
params: params || {}
})
});
if (resp.ok) {
const json_resp = await resp.json();
LAST_RESP = json_resp;
if (json_resp) {
json_resp.headers ||= [];
console.table(json_resp.rows.map(row => {
let ret = {};
for (let i = 0; i < row.length; ++i) {
ret[json_resp.headers[i] || `(${i})`] = row[i];
}
return ret
}))
}
} else {
console.error((await resp.json()).display)
}
// server_main(args, db)
// }
}
console.log(
`Welcome to the Cozo Makeshift Javascript Console!
You can run your query like this:
await run("YOUR QUERY HERE", {param: value})
The global variables 'COZO_AUTH' and 'LAST_RESP' are available.`);
</script>
</body>
</html>
"##;
const SECURITY_WARNING: &str = r#"
====================================================================================
!! SECURITY NOTICE, PLEASE READ !!
====================================================================================
You instructed Cozo to bind to a non-default address.
Cozo is designed to be accessed by trusted clients in a trusted network.
As a last defense against unauthorized access when everything else fails,
any requests from non-loopback addresses require the HTTP request header
`x-cozo-auth` to be set to the content of auth.txt in your database directory.
This is not a sufficient protection against attacks, and you must set up
proper authentication schemes, encryptions, etc. by firewalls and/or proxies.
====================================================================================
"#;
// fn server_main(args: Server, db: DbInstance) {
// let conf_path = format!("{}.{}.cozo_auth", args.path, args.engine);
// let auth_guard = match fs::read_to_string(&conf_path) {
// Ok(s) => s.trim().to_string(),
// Err(_) => {
// let s = rand::thread_rng()
// .sample_iter(&rand::distributions::Alphanumeric)
// .take(64)
// .map(char::from)
// .collect();
// fs::write(&conf_path, &s).unwrap();
// s
// }
// };
//
// let addr = if Ipv6Addr::from_str(&args.bind).is_ok() {
// format!("[{}]:{}", args.bind, args.port)
// } else {
// format!("{}:{}", args.bind, args.port)
// };
// println!(
// "Database ({} backend) web API running at http://{}",
// args.engine, addr
// );
// println!("The auth file is at {conf_path}");
// rouille::start_server(addr, move |request| {
// let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S%.6f");
// let log_ok = |req: &Request, _resp: &Response, elap: std::time::Duration| {
// info!("{} {} {} {:?}", now, req.method(), req.raw_url(), elap);
// };
// let log_err = |req: &Request, elap: std::time::Duration| {
// error!(
// "{} Handler panicked: {} {} {:?}",
// now,
// req.method(),
// req.raw_url(),
// elap
// );
// };
// });
// }

@ -0,0 +1,12 @@
====================================================================================
!! SECURITY NOTICE, PLEASE READ !!
====================================================================================
You instructed Cozo to bind to a non-default address.
Cozo is designed to be accessed by trusted clients in a trusted network.
As a last defense against unauthorized access when everything else fails,
any requests from non-loopback addresses require the HTTP request header
`x-cozo-auth` to be set to the content of auth.txt in your database directory.
This is not a sufficient protection against attacks, and you must set up
proper authentication schemes, encryptions, etc. by firewalls and/or proxies.
====================================================================================

@ -0,0 +1,348 @@
/*
* Copyright 2023, The Cozo Project Authors.
*
* This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
* If a copy of the MPL was not distributed with this file,
* You can obtain one at https://mozilla.org/MPL/2.0/.
*/
use std::collections::BTreeMap;
use std::convert::Infallible;
use std::net::{Ipv6Addr, SocketAddr};
use std::str::FromStr;
use std::sync::{mpsc, Arc, Mutex};
use axum::extract::{Path, Query, State};
use axum::http::StatusCode;
use axum::response::sse::{Event, KeepAlive};
use axum::response::{Html, Sse};
use axum::routing::{get, post, put};
use axum::{Json, Router};
use clap::Args;
use futures::stream::{self, Stream};
use itertools::Itertools;
use log::{info, warn};
use serde_json::json;
use tokio::task::spawn_blocking;
use cozo::{DataValue, DbInstance, NamedRows};
#[derive(Args, Debug)]
pub(crate) struct ServerArgs {
/// Database engine, can be `mem`, `sqlite`, `rocksdb` and others.
#[clap(short, long, default_value_t = String::from("mem"))]
engine: String,
/// Path to the directory to store the database
#[clap(short, long, default_value_t = String::from("cozo.db"))]
path: String,
// Restore from the specified backup before starting the server
// #[clap(long)]
// restore: Option<String>,
/// Extra config in JSON format
#[clap(short, long, default_value_t = String::from("{}"))]
config: String,
// When on, start REPL instead of starting a webserver
// #[clap(short, long)]
// repl: bool,
/// Address to bind the service to
#[clap(short, long, default_value_t = String::from("127.0.0.1"))]
bind: String,
/// Port to use
#[clap(short = 'P', long, default_value_t = 9070)]
port: u16,
}
type RuleCallbackStore = BTreeMap<usize, crossbeam::channel::Sender<miette::Result<NamedRows>>>;
type DbState = (DbInstance, Arc<Mutex<RuleCallbackStore>>);
pub(crate) async fn server_main(args: ServerArgs) {
let db = DbInstance::new(&args.engine, args.path, &args.config).unwrap();
let rule_channels: Arc<Mutex<RuleCallbackStore>> = Default::default();
let state = (db, rule_channels);
let app = Router::new()
.fallback(not_found)
.route("/", get(root))
.route("/text-query", post(text_query))
.route("/export/:relations", get(export_relations))
.route("/import", put(import_relations))
.route("/backup", post(backup))
.route("/import-from-backup", post(import_from_backup))
.route("/changes/:relations", get(observe_changes))
.route("/rules/:name", get(register_rule)) // sse + post
.route("/rules/:name/:id", post(rule_result))
.with_state(state);
let addr = if Ipv6Addr::from_str(&args.bind).is_ok() {
SocketAddr::from_str(&format!("[{}]:{}", args.bind, args.port)).unwrap()
} else {
SocketAddr::from_str(&format!("{}:{}", args.bind, args.port)).unwrap()
};
if args.bind != "127.0.0.1" {
warn!("{}", include_str!("./security.txt"));
}
info!(
"Starting Cozo ({}-backed) API at http://{}",
args.engine, addr
);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}
#[derive(serde_derive::Deserialize)]
struct QueryPayload {
script: String,
params: BTreeMap<String, serde_json::Value>,
}
async fn text_query(
State((db, _)): State<DbState>,
Json(payload): Json<QueryPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
let params = payload
.params
.into_iter()
.map(|(k, v)| (k, DataValue::from(v)))
.collect();
let result = spawn_blocking(move || db.run_script_fold_err(&payload.script, params)).await;
match result {
Ok(res) => wrap_json(res),
Err(err) => internal_error(err),
}
}
async fn export_relations(
State((db, _)): State<DbState>,
Path(relations): Path<String>,
) -> (StatusCode, Json<serde_json::Value>) {
let relations = relations
.split(',')
.filter_map(|t| {
if t.is_empty() {
None
} else {
Some(t.to_string())
}
})
.collect_vec();
let result = spawn_blocking(move || db.export_relations(relations.iter())).await;
match result {
Ok(Ok(s)) => {
let ret = json!({"ok": true, "data": s});
(StatusCode::OK, ret.into())
}
Ok(Err(err)) => {
let ret = json!({"ok": false, "message": err.to_string()});
(StatusCode::BAD_REQUEST, ret.into())
}
Err(err) => internal_error(err),
}
}
async fn import_relations(
State((db, _)): State<DbState>,
Json(payload): Json<serde_json::Value>,
) -> (StatusCode, Json<serde_json::Value>) {
let payload = match payload.as_object() {
None => {
return (
StatusCode::BAD_REQUEST,
json!({"ok": false, "message": "payload must be a JSON object"}).into(),
)
}
Some(pl) => {
let mut ret = BTreeMap::new();
for (k, v) in pl {
let nr = match NamedRows::from_json(v) {
Ok(p) => p,
Err(err) => {
return (
StatusCode::BAD_REQUEST,
json!({"ok": false, "message": err.to_string()}).into(),
)
}
};
ret.insert(k.to_string(), nr);
}
ret
}
};
let result = spawn_blocking(move || db.import_relations(payload)).await;
match result {
Ok(Ok(_)) => (StatusCode::OK, json!({"ok": true}).into()),
Ok(Err(err)) => {
let ret = json!({"ok": false, "message": err.to_string()});
(StatusCode::BAD_REQUEST, ret.into())
}
Err(err) => internal_error(err),
}
}
#[derive(serde_derive::Deserialize)]
struct BackupPayload {
path: String,
}
async fn backup(
State((db, _)): State<DbState>,
Json(payload): Json<BackupPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
let result = spawn_blocking(move || db.backup_db(payload.path)).await;
match result {
Ok(Ok(())) => {
let ret = json!({"ok": true});
(StatusCode::OK, ret.into())
}
Ok(Err(err)) => {
let ret = json!({"ok": false, "message": err.to_string()});
(StatusCode::BAD_REQUEST, ret.into())
}
Err(err) => internal_error(err),
}
}
#[derive(serde_derive::Deserialize)]
struct BackupImportPayload {
path: String,
relations: Vec<String>,
}
async fn import_from_backup(
State((db, _)): State<DbState>,
Json(payload): Json<BackupImportPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
let result =
spawn_blocking(move || db.import_from_backup(&payload.path, &payload.relations)).await;
match result {
Ok(Ok(())) => {
let ret = json!({"ok": true});
(StatusCode::OK, ret.into())
}
Ok(Err(err)) => {
let ret = json!({"ok": false, "message": err.to_string()});
(StatusCode::BAD_REQUEST, ret.into())
}
Err(err) => internal_error(err),
}
}
#[derive(serde_derive::Deserialize)]
struct RuleRegisterOptions {
arity: usize,
}
async fn rule_result(
State((store, _)): State<DbState>,
Path(name): Path<String>,
Path(id): Path<usize>,
) -> (StatusCode, Json<serde_json::Value>) {
todo!()
}
async fn register_rule(
State((db, cbs)): State<DbState>,
Path(name): Path<String>,
Query(rule_opts): Query<RuleRegisterOptions>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let (id, recv) = db.register_callback(&name, None);
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
struct Guard {
id: u32,
db: DbInstance,
relation: String,
}
impl Drop for Guard {
fn drop(&mut self) {
info!("dropping changes SSE {}: {}", self.relation, self.id);
self.db.unregister_callback(self.id);
}
}
spawn_blocking(move || {
for data in recv {
sender.blocking_send(data).unwrap();
}
});
let stream = async_stream::stream! {
info!("starting callback SSE {}: {}", name, id);
let _guard = Guard {id, db, relation: name};
while let Some((op, new, old)) = receiver.recv().await {
let item = json!({"op": op.to_string(), "new_rows": new.into_json(), "old_rows": old.into_json()});
yield Ok(Event::default().json_data(item).unwrap());
}
};
Sse::new(stream).keep_alive(KeepAlive::default())
}
async fn observe_changes(
State((db, _)): State<DbState>,
Path(relation): Path<String>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let (id, recv) = db.register_callback(&relation, None);
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
struct Guard {
id: u32,
db: DbInstance,
relation: String,
}
impl Drop for Guard {
fn drop(&mut self) {
info!("dropping changes SSE {}: {}", self.relation, self.id);
self.db.unregister_callback(self.id);
}
}
spawn_blocking(move || {
for data in recv {
sender.blocking_send(data).unwrap();
}
});
let stream = async_stream::stream! {
info!("starting changes SSE {}: {}", relation, id);
let _guard = Guard {id, db, relation};
while let Some((op, new, old)) = receiver.recv().await {
let item = json!({"op": op.to_string(), "new_rows": new.into_json(), "old_rows": old.into_json()});
yield Ok(Event::default().json_data(item).unwrap());
}
};
Sse::new(stream).keep_alive(KeepAlive::default())
}
async fn root() -> Html<&'static str> {
Html(include_str!("./index.html"))
}
fn internal_error<E>(err: E) -> (StatusCode, Json<serde_json::Value>)
where
E: std::error::Error,
{
(
StatusCode::INTERNAL_SERVER_ERROR,
json!({"ok": false, "message": err.to_string()}).into(),
)
}
fn wrap_json(json: serde_json::Value) -> (StatusCode, Json<serde_json::Value>) {
let code = if let Some(serde_json::Value::Bool(true)) = json.get("ok") {
StatusCode::OK
} else {
StatusCode::BAD_REQUEST
};
(code, json.into())
}
pub async fn not_found(uri: axum::http::Uri) -> (StatusCode, Json<serde_json::Value>) {
(
StatusCode::NOT_FOUND,
json!({"ok": false, "message": format!("No route {}", uri)}).into(),
)
}

@ -173,6 +173,39 @@ impl NamedRows {
"next": nxt,
})
}
/// Make named rows from JSON
pub fn from_json(value: &JsonValue) -> Result<Self> {
let headers = value
.get("headers")
.ok_or_else(|| miette!("NamedRows requires 'headers' field"))?;
let headers = headers
.as_array()
.ok_or_else(|| miette!("'headers' field must be an array"))?;
let headers = headers.iter().map(|h| -> Result<String> {
let h = h.as_str().ok_or_else(|| miette!("'headers' field must be an array of strings"))?;
Ok(h.to_string())
}).try_collect()?;
let rows = value
.get("rows")
.ok_or_else(|| miette!("NamedRows requires 'rows' field"))?;
let rows = rows
.as_array()
.ok_or_else(|| miette!("'rows' field must be an array"))?;
let rows = rows
.iter()
.map(|row| -> Result<Vec<DataValue>> {
let row = row
.as_array()
.ok_or_else(|| miette!("'rows' field must be an array of arrays"))?;
Ok(row.iter().map(|el| DataValue::from(el)).collect_vec())
})
.try_collect()?;
Ok(Self {
headers,
rows,
next: None,
})
}
}
const STATUS_STR: &str = "status";

Loading…
Cancel
Save