From 0b0830d1bad51987038583b45ddd4d05701460cd Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Tue, 17 Jan 2023 21:05:14 +0800 Subject: [PATCH] use axum for server; implement callback SSE --- Cargo.lock | 436 +++++++++++++++++------------------- cozo-bin/Cargo.toml | 10 +- cozo-bin/src/client.rs | 8 + cozo-bin/src/index.html | 52 +++++ cozo-bin/src/main.rs | 375 +++++++------------------------ cozo-bin/src/security.txt | 12 + cozo-bin/src/server.rs | 348 ++++++++++++++++++++++++++++ cozo-core/src/runtime/db.rs | 33 +++ 8 files changed, 743 insertions(+), 531 deletions(-) create mode 100644 cozo-bin/src/client.rs create mode 100644 cozo-bin/src/index.html create mode 100644 cozo-bin/src/security.txt create mode 100644 cozo-bin/src/server.rs diff --git a/Cargo.lock b/Cargo.lock index d5ee992c..3c61b9d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,12 +17,6 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" -[[package]] -name = "adler32" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" - [[package]] name = "ahash" version = "0.7.6" @@ -55,21 +49,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "alloc-no-stdlib" -version = "2.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" - -[[package]] -name = "alloc-stdlib" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" -dependencies = [ - "alloc-no-stdlib", -] - [[package]] name = "android_system_properties" version = "0.1.5" @@ -101,10 +80,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" [[package]] -name = "ascii" -version = "1.1.0" +name = "async-stream" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16" +checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e" +dependencies = [ + "async-stream-impl", + "futures-core", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "async-trait" @@ -158,6 +152,68 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "axum" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1304eab461cf02bd70b083ed8273388f9724c549b316ba3d1e213ce0e9e7fb7e" +dependencies = [ + "async-trait", + "axum-core", + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "itoa 1.0.5", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-http", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f487e40dc9daee24d8a1779df88522f159a54a980f99cfbe43db0be0bd3444a8" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-macros" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7d7c3e69f305217e317a28172aab29f275667f2e1c15b87451e134fe27c7b1" +dependencies = [ + "heck 0.4.0", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "backtrace" version = "0.3.67" @@ -228,27 +284,6 @@ dependencies = [ "cmake", ] -[[package]] -name = "brotli" -version = "3.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1a0b1dbcc8ae29329621f8d4f0d835787c1c38bb1401979b49d13b0b305ff68" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", - "brotli-decompressor", -] - -[[package]] -name = "brotli-decompressor" -version = "2.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59ad2d4653bf5ca36ae797b1f4bb4dbddb60ce49ca4aed8a2ce4829f60425b80" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", -] - [[package]] name = "bstr" version = "0.2.17" @@ -261,16 +296,6 @@ dependencies = [ "serde", ] -[[package]] -name = "buf_redux" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b953a6887648bb07a535631f2bc00fbdb2a2216f135552cb3f534ed136b9c07f" -dependencies = [ - "memchr", - "safemem", -] - [[package]] name = "bumpalo" version = "3.11.1" @@ -369,7 +394,7 @@ dependencies = [ "js-sys", "num-integer", "num-traits", - "time 0.1.45", + "time", "wasm-bindgen", "winapi", ] @@ -396,12 +421,6 @@ dependencies = [ "phf_codegen", ] -[[package]] -name = "chunked_transfer" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cca491388666e04d7248af3f60f0c40cfb0991c72205595d7c396e3510207d1a" - [[package]] name = "clang-sys" version = "1.4.0" @@ -595,21 +614,27 @@ dependencies = [ name = "cozo-bin" version = "0.5.0" dependencies = [ + "async-stream", + "axum", + "axum-macros", "chrono", "clap 4.0.32", "cozo", + "crossbeam", "ctrlc", "env_logger", + "futures 0.3.25", + "itertools 0.10.5", "log", "miette", "minreq", "prettytable", "rand 0.8.5", - "rouille", "rustyline", "serde", "serde_derive", "serde_json", + "tokio", ] [[package]] @@ -856,16 +881,6 @@ dependencies = [ "syn", ] -[[package]] -name = "deflate" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c86f7e25f518f4b81808a2cf1c50996a61f5c2eb394b2393bd87f2a4780a432f" -dependencies = [ - "adler32", - "gzip-header", -] - [[package]] name = "delegate" version = "0.8.0" @@ -1036,18 +1051,6 @@ dependencies = [ "windows-sys", ] -[[package]] -name = "filetime" -version = "0.2.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e884668cd0c7480504233e951174ddc3b382f7c2666e3b7310b5c4e7b0c37f9" -dependencies = [ - "cfg-if 1.0.0", - "libc", - "redox_syscall", - "windows-sys", -] - [[package]] name = "fixedbitset" version = "0.2.0" @@ -1358,15 +1361,6 @@ dependencies = [ "walkdir", ] -[[package]] -name = "gzip-header" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95cc527b92e6029a62960ad99aa8a6660faa4555fe5f731aab13aa6a921795a2" -dependencies = [ - "crc32fast", -] - [[package]] name = "h2" version = "0.3.15" @@ -1456,6 +1450,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29" + [[package]] name = "httparse" version = "1.8.0" @@ -1781,6 +1781,12 @@ dependencies = [ "libc", ] +[[package]] +name = "matchit" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40" + [[package]] name = "memchr" version = "2.5.0" @@ -1857,16 +1863,6 @@ version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" -[[package]] -name = "mime_guess" -version = "2.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" -dependencies = [ - "mime", - "unicase", -] - [[package]] name = "miniz_oxide" version = "0.6.2" @@ -1907,24 +1903,6 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" -[[package]] -name = "multipart" -version = "0.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00dec633863867f29cb39df64a397cdf4a6354708ddd7759f70c7fb51c5f9182" -dependencies = [ - "buf_redux", - "httparse", - "log", - "mime", - "mime_guess", - "quick-error", - "rand 0.8.5", - "safemem", - "tempfile", - "twoway", -] - [[package]] name = "nanorand" version = "0.7.0" @@ -2128,15 +2106,6 @@ dependencies = [ "libc", ] -[[package]] -name = "num_threads" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44" -dependencies = [ - "libc", -] - [[package]] name = "object" version = "0.30.2" @@ -2415,6 +2384,26 @@ dependencies = [ "siphasher", ] +[[package]] +name = "pin-project" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.9" @@ -2724,12 +2713,6 @@ dependencies = [ "syn", ] -[[package]] -name = "quick-error" -version = "1.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" - [[package]] name = "quote" version = "1.0.23" @@ -2978,31 +2961,6 @@ dependencies = [ "rmp", ] -[[package]] -name = "rouille" -version = "3.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f86e4c51a773f953f02bbab5fd049f004bfd384341d62da2a079aff812ab176" -dependencies = [ - "base64 0.13.1", - "brotli", - "chrono", - "deflate", - "filetime", - "multipart", - "num_cpus", - "percent-encoding", - "rand 0.8.5", - "serde", - "serde_derive", - "serde_json", - "sha1", - "threadpool", - "time 0.3.17", - "tiny_http", - "url", -] - [[package]] name = "rustc-demangle" version = "0.1.21" @@ -3076,12 +3034,6 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" -[[package]] -name = "safemem" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef703b7cb59335eae2eb93ceb664c0eb7ea6bf567079d843e09420219668e072" - [[package]] name = "same-file" version = "1.0.6" @@ -3206,6 +3158,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -3218,17 +3179,6 @@ dependencies = [ "serde", ] -[[package]] -name = "sha1" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" -dependencies = [ - "cfg-if 1.0.0", - "cpufeatures", - "digest", -] - [[package]] name = "sha2" version = "0.10.6" @@ -3246,6 +3196,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fdf1b9db47230893d76faad238fd6097fd6d6a9245cd7a4d90dbd639536bbd2" +[[package]] +name = "signal-hook-registry" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +dependencies = [ + "libc", +] + [[package]] name = "siphasher" version = "0.3.10" @@ -3463,6 +3422,12 @@ dependencies = [ "syn", ] +[[package]] +name = "sync_wrapper" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8" + [[package]] name = "target-lexicon" version = "0.12.5" @@ -3550,15 +3515,6 @@ dependencies = [ "syn", ] -[[package]] -name = "threadpool" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa" -dependencies = [ - "num_cpus", -] - [[package]] name = "tikv-client" version = "0.1.0" @@ -3687,36 +3643,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "time" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a561bf4617eebd33bca6434b988f39ed798e527f51a1e797d0ee4f61c0a38376" -dependencies = [ - "libc", - "num_threads", - "serde", - "time-core", -] - -[[package]] -name = "time-core" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd" - -[[package]] -name = "tiny_http" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "389915df6413a2e74fb181895f933386023c71110878cd0825588928e64cdc82" -dependencies = [ - "ascii", - "chunked_transfer", - "httpdate", - "log", -] - [[package]] name = "tinyvec" version = "1.6.0" @@ -3744,11 +3670,25 @@ dependencies = [ "memchr", "mio", "num_cpus", + "parking_lot 0.12.1", "pin-project-lite", + "signal-hook-registry", "socket2", + "tokio-macros", "windows-sys", ] +[[package]] +name = "tokio-macros" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tokio-native-tls" version = "0.3.0" @@ -3782,6 +3722,47 @@ dependencies = [ "serde", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f873044bf02dd1e8239e9c1293ea39dad76dc594ec16185d0a1bf31d8dc8d858" +dependencies = [ + "bitflags", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-range-header", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" + [[package]] name = "tower-service" version = "0.3.2" @@ -3795,6 +3776,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" dependencies = [ "cfg-if 1.0.0", + "log", "pin-project-lite", "tracing-core", ] @@ -3814,15 +3796,6 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" -[[package]] -name = "twoway" -version = "0.1.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59b11b2b5241ba34be09c3cc85a36e56e48f9888862e19cedf23336d35316ed1" -dependencies = [ - "memchr", -] - [[package]] name = "typenum" version = "1.16.0" @@ -3835,15 +3808,6 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81" -[[package]] -name = "unicase" -version = "2.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6" -dependencies = [ - "version_check", -] - [[package]] name = "unicode-bidi" version = "0.3.8" diff --git a/cozo-bin/Cargo.toml b/cozo-bin/Cargo.toml index 0bfa8d8c..4c528c3f 100644 --- a/cozo-bin/Cargo.toml +++ b/cozo-bin/Cargo.toml @@ -44,7 +44,6 @@ storage-tikv = ["cozo/storage-tikv"] [dependencies] cozo = { version = "0.5.0", path = "../cozo-core", default-features = false } clap = { version = "4.0.26", features = ["derive"] } -rouille = "3.5.0" env_logger = "0.10.0" log = "0.4.17" rand = "0.8.5" @@ -56,4 +55,11 @@ prettytable = "0.10.0" rustyline = "10.0.0" minreq = { version = "2.6.0", features = ["https-rustls"] } miette = { version = "5.5.0", features = ["fancy"] } -ctrlc = "3.2.4" \ No newline at end of file +ctrlc = "3.2.4" +axum = "0.6.2" +axum-macros = "0.3.1" +itertools = "0.10.5" +tokio = { version = "1.24.1", features = ["full"] } +async-stream = "0.3.3" +futures = "0.3.25" +crossbeam = "0.8.2" \ No newline at end of file diff --git a/cozo-bin/src/client.rs b/cozo-bin/src/client.rs new file mode 100644 index 00000000..6501c605 --- /dev/null +++ b/cozo-bin/src/client.rs @@ -0,0 +1,8 @@ +/* + * Copyright 2023, The Cozo Project Authors. + * + * This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. + * If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. + */ + diff --git a/cozo-bin/src/index.html b/cozo-bin/src/index.html new file mode 100644 index 00000000..ee88a6d4 --- /dev/null +++ b/cozo-bin/src/index.html @@ -0,0 +1,52 @@ + + + + + Cozo database + + +

Cozo API is running.

+ + + \ No newline at end of file diff --git a/cozo-bin/src/main.rs b/cozo-bin/src/main.rs index 9b610846..babf597f 100644 --- a/cozo-bin/src/main.rs +++ b/cozo-bin/src/main.rs @@ -6,25 +6,25 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ -use std::collections::BTreeMap; use std::fmt::Debug; use std::fs; -use std::net::Ipv6Addr; use std::process::exit; -use std::str::FromStr; use clap::{Args, Parser, Subcommand}; use env_logger::Env; use log::{error, info}; -use rand::Rng; -use rouille::{router, try_or_400, Request, Response}; -use serde_json::json; - -use cozo::*; use crate::repl::repl_main; +use crate::server::{server_main, ServerArgs}; + +// use rand::Rng; +// use serde_json::json; +// use cozo::*; + +mod client; mod repl; +mod server; #[derive(Parser)] #[command(author, version, about, long_about = None)] @@ -37,14 +37,15 @@ struct AppArgs { #[derive(Subcommand)] enum Commands { /// - Server(Server), - Client(Client), - Repl(Repl), - Restore(Restore), + Server(ServerArgs), + Client(ClientArgs), + Repl(ReplArgs), + Restore(RestoreArgs), + Stream(StreamArgs), } #[derive(Args, Debug)] -struct Repl { +struct ReplArgs { /// Database engine, can be `mem`, `sqlite`, `rocksdb` and others. #[clap(short, long, default_value_t = String::from("mem"))] engine: String, @@ -59,7 +60,7 @@ struct Repl { } #[derive(Args, Debug)] -struct Client { +struct ClientArgs { #[clap(default_value_t = String::from("http://127.0.0.1:9070"))] address: String, #[clap(short, long, default_value_t = String::from(""))] @@ -67,7 +68,10 @@ struct Client { } #[derive(Args, Debug)] -struct Restore { +struct StreamArgs {} + +#[derive(Args, Debug)] +struct RestoreArgs { /// Path of the backup file to restore from, must be a SQLite-backed backup file. from: String, /// Path of the database to restore into @@ -77,47 +81,18 @@ struct Restore { engine: String, } -#[derive(Args, Debug)] -struct Server { - /// Database engine, can be `mem`, `sqlite`, `rocksdb` and others. - #[clap(short, long, default_value_t = String::from("mem"))] - engine: String, - - /// Path to the directory to store the database - #[clap(short, long, default_value_t = String::from("cozo.db"))] - path: String, - - // Restore from the specified backup before starting the server - // #[clap(long)] - // restore: Option, - /// Extra config in JSON format - #[clap(short, long, default_value_t = String::from("{}"))] - config: String, - - // When on, start REPL instead of starting a webserver - // #[clap(short, long)] - // repl: bool, - /// Address to bind the service to - #[clap(short, long, default_value_t = String::from("127.0.0.1"))] - bind: String, - - /// Port to use - #[clap(short = 'P', long, default_value_t = 9070)] - port: u16, -} - -macro_rules! check_auth { - ($request:expr, $auth_guard:expr) => { - match $request.header("x-cozo-auth") { - None => return Response::text("Unauthorized").with_status_code(401), - Some(code) => { - if $auth_guard != code { - return Response::text("Unauthorized").with_status_code(401); - } - } - } - }; -} +// macro_rules! check_auth { +// ($request:expr, $auth_guard:expr) => { +// match $request.header("x-cozo-auth") { +// None => return Response::text("Unauthorized").with_status_code(401), +// Some(code) => { +// if $auth_guard != code { +// return Response::text("Unauthorized").with_status_code(401); +// } +// } +// } +// }; +// } fn main() { let args = match AppArgs::parse().command { @@ -131,17 +106,10 @@ fn main() { Commands::Restore(_) => { todo!() } + Commands::Stream(_) => { + todo!() + } }; - if args.bind != "127.0.0.1" { - eprintln!("{SECURITY_WARNING}"); - } - - let db = DbInstance::new( - args.engine.as_str(), - args.path.as_str(), - &args.config.clone(), - ) - .unwrap(); // if let Some(restore_path) = &args.restore { // db.restore_backup(restore_path).unwrap(); @@ -169,233 +137,54 @@ fn main() { // } // } else { env_logger::Builder::from_env(Env::default().default_filter_or("info")).init(); - server_main(args, db) - // } -} - -fn server_main(args: Server, db: DbInstance) { - let conf_path = format!("{}.{}.cozo_auth", args.path, args.engine); - let auth_guard = match fs::read_to_string(&conf_path) { - Ok(s) => s.trim().to_string(), - Err(_) => { - let s = rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(64) - .map(char::from) - .collect(); - fs::write(&conf_path, &s).unwrap(); - s - } - }; - - let addr = if Ipv6Addr::from_str(&args.bind).is_ok() { - format!("[{}]:{}", args.bind, args.port) - } else { - format!("{}:{}", args.bind, args.port) - }; - println!( - "Database ({} backend) web API running at http://{}", - args.engine, addr - ); - println!("The auth file is at {conf_path}"); - rouille::start_server(addr, move |request| { - let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S%.6f"); - let log_ok = |req: &Request, _resp: &Response, elap: std::time::Duration| { - info!("{} {} {} {:?}", now, req.method(), req.raw_url(), elap); - }; - let log_err = |req: &Request, elap: std::time::Duration| { - error!( - "{} Handler panicked: {} {} {:?}", - now, - req.method(), - req.raw_url(), - elap - ); - }; - rouille::log_custom(request, log_ok, log_err, || { - router!(request, - (POST) (/text-query) => { - if !request.remote_addr().ip().is_loopback() { - check_auth!(request, auth_guard); - } - - #[derive(serde_derive::Serialize, serde_derive::Deserialize)] - struct QueryPayload { - script: String, - params: BTreeMap, - } - - let payload: QueryPayload = try_or_400!(rouille::input::json_input(request)); - let params = payload.params.into_iter().map(|(k, v)| - (k, DataValue::from(v))).collect(); - let result = db.run_script_fold_err(&payload.script, params); - let response = Response::json(&result); - if let Some(serde_json::Value::Bool(true)) = result.get("ok") { - response - } else { - response.with_status_code(400) - } - }, - (GET) (/export/{relations: String}) => { - if !request.remote_addr().ip().is_loopback() { - check_auth!(request, auth_guard); - } - - let relations = relations.split(',').filter(|t| !t.is_empty()); - let result = db.export_relations(relations); - match result { - Ok(s) => { - let ret = json!({"ok": true, "data": s}); - Response::json(&ret) - } - Err(err) => { - let ret = json!({"ok": false, "message": err.to_string()}); - Response::json(&ret).with_status_code(400) - } - } - }, - (PUT) (/import) => { - if !request.remote_addr().ip().is_loopback() { - check_auth!(request, auth_guard); - } - - let payload: BTreeMap = try_or_400!(rouille::input::json_input(request)); - let result = db.import_relations(payload); - - match result { - Ok(()) => { - let ret = json!({"ok": true}); - Response::json(&ret) - } - Err(err) => { - let ret = json!({"ok": false, "message": err.to_string()}); - Response::json(&ret).with_status_code(400) - } - } - }, - (POST) (/backup) => { - if !request.remote_addr().ip().is_loopback() { - check_auth!(request, auth_guard); - } + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on(server_main(args)) - #[derive(serde_derive::Serialize, serde_derive::Deserialize)] - struct BackupPayload { - path: String, - } - - let payload: BackupPayload = try_or_400!(rouille::input::json_input(request)); - - let result = db.backup_db(payload.path); - - match result { - Ok(()) => { - let ret = json!({"ok": true}); - Response::json(&ret) - } - Err(err) => { - let ret = json!({"ok": false, "message": err.to_string()}); - Response::json(&ret).with_status_code(400) - } - } - }, - (POST) (/import-from-backup) => { - if !request.remote_addr().ip().is_loopback() { - check_auth!(request, auth_guard); - } - - #[derive(serde_derive::Serialize, serde_derive::Deserialize)] - struct BackupImportPayload { - path: String, - relations: Vec - } - - let payload: BackupImportPayload = try_or_400!(rouille::input::json_input(request)); - let result = db.import_from_backup(&payload.path, &payload.relations); - - match result { - Ok(()) => { - let ret = json!({"ok": true}); - Response::json(&ret) - } - Err(err) => { - let ret = json!({"ok": false, "message": err.to_string()}); - Response::json(&ret).with_status_code(400) - } - } - }, - (GET) (/) => { - Response::html(HTML_CONTENT) - }, - _ => Response::empty_404() - ) - }) - }); + // server_main(args, db) + // } } -const HTML_CONTENT: &str = r##" - - - - -Cozo database - - -

Cozo HTTP server is running.

- - - -"##; - -const SECURITY_WARNING: &str = r#" -==================================================================================== - !! SECURITY NOTICE, PLEASE READ !! -==================================================================================== -You instructed Cozo to bind to a non-default address. -Cozo is designed to be accessed by trusted clients in a trusted network. -As a last defense against unauthorized access when everything else fails, -any requests from non-loopback addresses require the HTTP request header -`x-cozo-auth` to be set to the content of auth.txt in your database directory. -This is not a sufficient protection against attacks, and you must set up -proper authentication schemes, encryptions, etc. by firewalls and/or proxies. -==================================================================================== -"#; +// fn server_main(args: Server, db: DbInstance) { +// let conf_path = format!("{}.{}.cozo_auth", args.path, args.engine); +// let auth_guard = match fs::read_to_string(&conf_path) { +// Ok(s) => s.trim().to_string(), +// Err(_) => { +// let s = rand::thread_rng() +// .sample_iter(&rand::distributions::Alphanumeric) +// .take(64) +// .map(char::from) +// .collect(); +// fs::write(&conf_path, &s).unwrap(); +// s +// } +// }; +// +// let addr = if Ipv6Addr::from_str(&args.bind).is_ok() { +// format!("[{}]:{}", args.bind, args.port) +// } else { +// format!("{}:{}", args.bind, args.port) +// }; +// println!( +// "Database ({} backend) web API running at http://{}", +// args.engine, addr +// ); +// println!("The auth file is at {conf_path}"); +// rouille::start_server(addr, move |request| { +// let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S%.6f"); +// let log_ok = |req: &Request, _resp: &Response, elap: std::time::Duration| { +// info!("{} {} {} {:?}", now, req.method(), req.raw_url(), elap); +// }; +// let log_err = |req: &Request, elap: std::time::Duration| { +// error!( +// "{} Handler panicked: {} {} {:?}", +// now, +// req.method(), +// req.raw_url(), +// elap +// ); +// }; +// }); +// } diff --git a/cozo-bin/src/security.txt b/cozo-bin/src/security.txt new file mode 100644 index 00000000..a7c8860c --- /dev/null +++ b/cozo-bin/src/security.txt @@ -0,0 +1,12 @@ + +==================================================================================== + !! SECURITY NOTICE, PLEASE READ !! +==================================================================================== +You instructed Cozo to bind to a non-default address. +Cozo is designed to be accessed by trusted clients in a trusted network. +As a last defense against unauthorized access when everything else fails, +any requests from non-loopback addresses require the HTTP request header +`x-cozo-auth` to be set to the content of auth.txt in your database directory. +This is not a sufficient protection against attacks, and you must set up +proper authentication schemes, encryptions, etc. by firewalls and/or proxies. +==================================================================================== diff --git a/cozo-bin/src/server.rs b/cozo-bin/src/server.rs new file mode 100644 index 00000000..a534f16f --- /dev/null +++ b/cozo-bin/src/server.rs @@ -0,0 +1,348 @@ +/* + * Copyright 2023, The Cozo Project Authors. + * + * This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. + * If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +use std::collections::BTreeMap; +use std::convert::Infallible; +use std::net::{Ipv6Addr, SocketAddr}; +use std::str::FromStr; +use std::sync::{mpsc, Arc, Mutex}; + +use axum::extract::{Path, Query, State}; +use axum::http::StatusCode; +use axum::response::sse::{Event, KeepAlive}; +use axum::response::{Html, Sse}; +use axum::routing::{get, post, put}; +use axum::{Json, Router}; +use clap::Args; +use futures::stream::{self, Stream}; +use itertools::Itertools; +use log::{info, warn}; +use serde_json::json; +use tokio::task::spawn_blocking; + +use cozo::{DataValue, DbInstance, NamedRows}; + +#[derive(Args, Debug)] +pub(crate) struct ServerArgs { + /// Database engine, can be `mem`, `sqlite`, `rocksdb` and others. + #[clap(short, long, default_value_t = String::from("mem"))] + engine: String, + + /// Path to the directory to store the database + #[clap(short, long, default_value_t = String::from("cozo.db"))] + path: String, + + // Restore from the specified backup before starting the server + // #[clap(long)] + // restore: Option, + /// Extra config in JSON format + #[clap(short, long, default_value_t = String::from("{}"))] + config: String, + + // When on, start REPL instead of starting a webserver + // #[clap(short, long)] + // repl: bool, + /// Address to bind the service to + #[clap(short, long, default_value_t = String::from("127.0.0.1"))] + bind: String, + + /// Port to use + #[clap(short = 'P', long, default_value_t = 9070)] + port: u16, +} + +type RuleCallbackStore = BTreeMap>>; +type DbState = (DbInstance, Arc>); + +pub(crate) async fn server_main(args: ServerArgs) { + let db = DbInstance::new(&args.engine, args.path, &args.config).unwrap(); + let rule_channels: Arc> = Default::default(); + let state = (db, rule_channels); + let app = Router::new() + .fallback(not_found) + .route("/", get(root)) + .route("/text-query", post(text_query)) + .route("/export/:relations", get(export_relations)) + .route("/import", put(import_relations)) + .route("/backup", post(backup)) + .route("/import-from-backup", post(import_from_backup)) + .route("/changes/:relations", get(observe_changes)) + .route("/rules/:name", get(register_rule)) // sse + post + .route("/rules/:name/:id", post(rule_result)) + .with_state(state); + let addr = if Ipv6Addr::from_str(&args.bind).is_ok() { + SocketAddr::from_str(&format!("[{}]:{}", args.bind, args.port)).unwrap() + } else { + SocketAddr::from_str(&format!("{}:{}", args.bind, args.port)).unwrap() + }; + + if args.bind != "127.0.0.1" { + warn!("{}", include_str!("./security.txt")); + } + + info!( + "Starting Cozo ({}-backed) API at http://{}", + args.engine, addr + ); + + axum::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); +} + +#[derive(serde_derive::Deserialize)] +struct QueryPayload { + script: String, + params: BTreeMap, +} + +async fn text_query( + State((db, _)): State, + Json(payload): Json, +) -> (StatusCode, Json) { + let params = payload + .params + .into_iter() + .map(|(k, v)| (k, DataValue::from(v))) + .collect(); + let result = spawn_blocking(move || db.run_script_fold_err(&payload.script, params)).await; + match result { + Ok(res) => wrap_json(res), + Err(err) => internal_error(err), + } +} + +async fn export_relations( + State((db, _)): State, + Path(relations): Path, +) -> (StatusCode, Json) { + let relations = relations + .split(',') + .filter_map(|t| { + if t.is_empty() { + None + } else { + Some(t.to_string()) + } + }) + .collect_vec(); + let result = spawn_blocking(move || db.export_relations(relations.iter())).await; + match result { + Ok(Ok(s)) => { + let ret = json!({"ok": true, "data": s}); + (StatusCode::OK, ret.into()) + } + Ok(Err(err)) => { + let ret = json!({"ok": false, "message": err.to_string()}); + (StatusCode::BAD_REQUEST, ret.into()) + } + Err(err) => internal_error(err), + } +} + +async fn import_relations( + State((db, _)): State, + Json(payload): Json, +) -> (StatusCode, Json) { + let payload = match payload.as_object() { + None => { + return ( + StatusCode::BAD_REQUEST, + json!({"ok": false, "message": "payload must be a JSON object"}).into(), + ) + } + Some(pl) => { + let mut ret = BTreeMap::new(); + for (k, v) in pl { + let nr = match NamedRows::from_json(v) { + Ok(p) => p, + Err(err) => { + return ( + StatusCode::BAD_REQUEST, + json!({"ok": false, "message": err.to_string()}).into(), + ) + } + }; + ret.insert(k.to_string(), nr); + } + ret + } + }; + + let result = spawn_blocking(move || db.import_relations(payload)).await; + match result { + Ok(Ok(_)) => (StatusCode::OK, json!({"ok": true}).into()), + Ok(Err(err)) => { + let ret = json!({"ok": false, "message": err.to_string()}); + (StatusCode::BAD_REQUEST, ret.into()) + } + Err(err) => internal_error(err), + } +} +#[derive(serde_derive::Deserialize)] +struct BackupPayload { + path: String, +} + +async fn backup( + State((db, _)): State, + Json(payload): Json, +) -> (StatusCode, Json) { + let result = spawn_blocking(move || db.backup_db(payload.path)).await; + + match result { + Ok(Ok(())) => { + let ret = json!({"ok": true}); + (StatusCode::OK, ret.into()) + } + Ok(Err(err)) => { + let ret = json!({"ok": false, "message": err.to_string()}); + (StatusCode::BAD_REQUEST, ret.into()) + } + Err(err) => internal_error(err), + } +} +#[derive(serde_derive::Deserialize)] +struct BackupImportPayload { + path: String, + relations: Vec, +} +async fn import_from_backup( + State((db, _)): State, + Json(payload): Json, +) -> (StatusCode, Json) { + let result = + spawn_blocking(move || db.import_from_backup(&payload.path, &payload.relations)).await; + + match result { + Ok(Ok(())) => { + let ret = json!({"ok": true}); + (StatusCode::OK, ret.into()) + } + Ok(Err(err)) => { + let ret = json!({"ok": false, "message": err.to_string()}); + (StatusCode::BAD_REQUEST, ret.into()) + } + Err(err) => internal_error(err), + } +} + +#[derive(serde_derive::Deserialize)] +struct RuleRegisterOptions { + arity: usize, +} + +async fn rule_result( + State((store, _)): State, + Path(name): Path, + Path(id): Path, +) -> (StatusCode, Json) { + todo!() +} + +async fn register_rule( + State((db, cbs)): State, + Path(name): Path, + Query(rule_opts): Query, +) -> Sse>> { + let (id, recv) = db.register_callback(&name, None); + let (sender, mut receiver) = tokio::sync::mpsc::channel(1); + struct Guard { + id: u32, + db: DbInstance, + relation: String, + } + + impl Drop for Guard { + fn drop(&mut self) { + info!("dropping changes SSE {}: {}", self.relation, self.id); + self.db.unregister_callback(self.id); + } + } + + spawn_blocking(move || { + for data in recv { + sender.blocking_send(data).unwrap(); + } + }); + let stream = async_stream::stream! { + info!("starting callback SSE {}: {}", name, id); + let _guard = Guard {id, db, relation: name}; + while let Some((op, new, old)) = receiver.recv().await { + let item = json!({"op": op.to_string(), "new_rows": new.into_json(), "old_rows": old.into_json()}); + yield Ok(Event::default().json_data(item).unwrap()); + } + }; + Sse::new(stream).keep_alive(KeepAlive::default()) +} + +async fn observe_changes( + State((db, _)): State, + Path(relation): Path, +) -> Sse>> { + let (id, recv) = db.register_callback(&relation, None); + let (sender, mut receiver) = tokio::sync::mpsc::channel(1); + struct Guard { + id: u32, + db: DbInstance, + relation: String, + } + + impl Drop for Guard { + fn drop(&mut self) { + info!("dropping changes SSE {}: {}", self.relation, self.id); + self.db.unregister_callback(self.id); + } + } + + spawn_blocking(move || { + for data in recv { + sender.blocking_send(data).unwrap(); + } + }); + let stream = async_stream::stream! { + info!("starting changes SSE {}: {}", relation, id); + let _guard = Guard {id, db, relation}; + while let Some((op, new, old)) = receiver.recv().await { + let item = json!({"op": op.to_string(), "new_rows": new.into_json(), "old_rows": old.into_json()}); + yield Ok(Event::default().json_data(item).unwrap()); + } + }; + Sse::new(stream).keep_alive(KeepAlive::default()) +} + +async fn root() -> Html<&'static str> { + Html(include_str!("./index.html")) +} + +fn internal_error(err: E) -> (StatusCode, Json) +where + E: std::error::Error, +{ + ( + StatusCode::INTERNAL_SERVER_ERROR, + json!({"ok": false, "message": err.to_string()}).into(), + ) +} + +fn wrap_json(json: serde_json::Value) -> (StatusCode, Json) { + let code = if let Some(serde_json::Value::Bool(true)) = json.get("ok") { + StatusCode::OK + } else { + StatusCode::BAD_REQUEST + }; + (code, json.into()) +} + +pub async fn not_found(uri: axum::http::Uri) -> (StatusCode, Json) { + ( + StatusCode::NOT_FOUND, + json!({"ok": false, "message": format!("No route {}", uri)}).into(), + ) +} diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index 81fdfa98..0cda3eb7 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -173,6 +173,39 @@ impl NamedRows { "next": nxt, }) } + /// Make named rows from JSON + pub fn from_json(value: &JsonValue) -> Result { + let headers = value + .get("headers") + .ok_or_else(|| miette!("NamedRows requires 'headers' field"))?; + let headers = headers + .as_array() + .ok_or_else(|| miette!("'headers' field must be an array"))?; + let headers = headers.iter().map(|h| -> Result { + let h = h.as_str().ok_or_else(|| miette!("'headers' field must be an array of strings"))?; + Ok(h.to_string()) + }).try_collect()?; + let rows = value + .get("rows") + .ok_or_else(|| miette!("NamedRows requires 'rows' field"))?; + let rows = rows + .as_array() + .ok_or_else(|| miette!("'rows' field must be an array"))?; + let rows = rows + .iter() + .map(|row| -> Result> { + let row = row + .as_array() + .ok_or_else(|| miette!("'rows' field must be an array of arrays"))?; + Ok(row.iter().map(|el| DataValue::from(el)).collect_vec()) + }) + .try_collect()?; + Ok(Self { + headers, + rows, + next: None, + }) + } } const STATUS_STR: &str = "status";