towards minimal server

main
Ziyang Hu 2 years ago
parent 1a4e9fda13
commit 1a072ab1ce

1192
Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -37,7 +37,7 @@ pest = "2.2.1"
pest_derive = "2.2.1" pest_derive = "2.2.1"
rayon = "1.5.3" rayon = "1.5.3"
nalgebra = "0.31.1" nalgebra = "0.31.1"
reqwest = { version = "0.11.11", features = ["blocking"] } minreq = { version = "2.6.0", features = ["https-native"] }
approx = "0.5.1" approx = "0.5.1"
unicode-normalization = "0.1.21" unicode-normalization = "0.1.21"
thiserror = "1.0.34" thiserror = "1.0.34"

@ -7,21 +7,15 @@ license = "AGPL-3.0-or-later"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
actix-web = "4.1.0"
clap = { version = "3.2.8", features = ["derive"] } clap = { version = "3.2.8", features = ["derive"] }
actix-cors = "0.6.1"
log = "0.4.16" log = "0.4.16"
miette = { version = "5.3.0", features = ["fancy"] } miette = { version = "5.3.0", features = ["fancy"] }
env_logger = "0.9.0" env_logger = "0.9.0"
serde = "1.0.144" serde = "1.0.144"
serde_json = "1.0.81" serde_json = "1.0.81"
serde_derive = "1.0.144" serde_derive = "1.0.144"
rust-argon2 = "1.0.0"
sha3 = "0.10.4"
rand = "0.8.5"
rpassword = "7.0.0" rpassword = "7.0.0"
webbrowser = "0.8.0" rouille = "3.5.0"
futures = "0.3.24"
cozo = { path = ".." } cozo = { path = ".." }
[build-dependencies] [build-dependencies]

@ -1,58 +1,14 @@
use std::collections::{BTreeMap, HashMap}; use std::collections::BTreeMap;
use std::env; use std::fmt::Debug;
use std::fmt::{Debug, Display, Formatter};
use std::io::{stdin, stdout, Write};
use std::str::from_utf8;
use std::sync::{Arc, RwLock};
use std::time::Instant; use std::time::Instant;
use actix_cors::Cors;
use actix_web::{App, HttpRequest, HttpResponse, HttpServer, post, Responder, ResponseError, web};
use actix_web::body::BoxBody;
use actix_web::http::header::HeaderName;
use actix_web::middleware::Logger;
use actix_web::rt::task::spawn_blocking;
use clap::Parser; use clap::Parser;
use env_logger::Env; use env_logger::Env;
use log::debug; use rouille::{router, try_or_400, Response};
use miette::{bail, IntoDiagnostic, miette};
use rand::Rng;
use serde_json::json; use serde_json::json;
use sha3::Digest;
use cozo::{Db, DbBuilder}; use cozo::{Db, DbBuilder};
type Result<T> = std::result::Result<T, RespError>;
/// Error type used by the `Result` alias in this file.
///
/// Wraps a `miette::Error` so that `Debug`, `Display`, `From<cozo::Error>` and
/// actix-web's `ResponseError` can be implemented for it.
struct RespError {
// The wrapped diagnostic; every trait impl on `RespError` delegates to it.
err: miette::Error,
}
/// `Debug` output is exactly the `Debug` rendering of the wrapped error.
impl Debug for RespError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let inner = &self.err;
        f.write_fmt(format_args!("{:?}", inner))
    }
}
/// `Display` output is exactly the `Display` rendering of the wrapped error.
impl Display for RespError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let inner = &self.err;
        f.write_fmt(format_args!("{}", inner))
    }
}
impl From<cozo::Error> for RespError {
fn from(err: cozo::Error) -> RespError {
RespError { err }
}
}
/// Turns the error into an HTTP response: a "Bad Request" response whose body
/// is the `Debug` rendering of the wrapped error.
impl ResponseError for RespError {
    fn error_response(&self) -> HttpResponse<BoxBody> {
        HttpResponse::BadRequest().body(format!("{:?}", self.err))
    }
}
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[clap(version, about, long_about = None)] #[clap(version, about, long_about = None)]
struct Args { struct Args {
@ -69,180 +25,7 @@ struct Args {
port: u16, port: u16,
} }
struct AppStateWithDb { fn main() {
db: Db,
pass_cache: Arc<RwLock<HashMap<String, Box<[u8]>>>>,
seed: Box<[u8]>,
}
const PASSWORD_KEY: &str = "WEB_USER_PASSWORD";
impl AppStateWithDb {
    /// Computes the digest stored in `pass_cache`: SHA3-256 over `seed`
    /// followed by the password bytes.
    fn cache_digest(seed: &[u8], password: &str) -> Box<[u8]> {
        let mut input = seed.to_vec();
        input.extend_from_slice(password.as_bytes());
        let digest: &[u8] = &sha3::Sha3_256::digest(&input);
        digest.into()
    }

    /// Authenticates a request from its `x-cozo-username` / `x-cozo-password`
    /// headers.
    ///
    /// If the user has an entry in `pass_cache`, the password is checked against
    /// the cached SHA3 digest; a mismatch evicts the entry and fails. Otherwise
    /// the argon2 hash stored under `[PASSWORD_KEY, username]` is verified on a
    /// blocking thread and, on success, the digest is cached.
    ///
    /// # Errors
    /// Fails with "not authenticated" when a header is missing, with
    /// "invalid password" when verification fails, and with the underlying
    /// error when a header is not valid text, the database access fails, the
    /// stored hash cannot be decoded, or the cache lock is poisoned.
    async fn verify_password(&self, req: &HttpRequest) -> miette::Result<()> {
        let username = req
            .headers()
            .get(&HeaderName::from_static("x-cozo-username"))
            .ok_or_else(|| miette!("not authenticated"))?
            .to_str()
            .into_diagnostic()?;
        let password = req
            .headers()
            .get(&HeaderName::from_static("x-cozo-password"))
            .ok_or_else(|| miette!("not authenticated"))?
            .to_str()
            .into_diagnostic()?;

        // Blocking `read()` instead of `try_read()`: a concurrent request holding
        // the lock must not turn into a spurious authentication error. The guard
        // is a temporary dropped at the end of this statement, so it is never
        // held across an `.await`.
        let existing = self
            .pass_cache
            .read()
            .map_err(|e| miette!(e.to_string()))?
            .get(username)
            .cloned();
        if let Some(stored) = existing {
            if stored == Self::cache_digest(&self.seed, password) {
                return Ok(());
            }
            // Cached digest did not match: drop the entry and reject.
            self.pass_cache
                .write()
                .map_err(|e| miette!(e.to_string()))?
                .remove(username);
            bail!("invalid password");
        }

        // Cache miss: argon2 verification is expensive, so run it off the
        // async executor. Everything moved into the closure must be owned.
        let pass_cache = self.pass_cache.clone();
        let seed = self.seed.clone();
        let db = self.db.new_session()?;
        let password = password.to_string();
        let username = username.to_string();
        spawn_blocking(move || -> miette::Result<()> {
            if let Some(hashed) = db.get_meta_kv(&[PASSWORD_KEY, &username])? {
                let hashed = from_utf8(&hashed).into_diagnostic()?;
                if argon2::verify_encoded(hashed, password.as_bytes()).into_diagnostic()? {
                    let digest = Self::cache_digest(&seed, &password);
                    pass_cache
                        .write()
                        .map_err(|e| miette!(e.to_string()))?
                        .insert(username, digest);
                    return Ok(());
                }
            }
            bail!("invalid password")
        })
        .await
        .into_diagnostic()?
    }

    /// Sets the password of `user` to `new_pass`, creating the entry if needed.
    ///
    /// The argon2 hash (fresh 32-byte random salt, parameters from
    /// `argon2config`) is written under `[PASSWORD_KEY, user]` on a blocking
    /// thread, then the user's cache entry is evicted.
    ///
    /// # Errors
    /// Fails if hashing, the database write, or the blocking task fails, or if
    /// the cache lock is poisoned.
    async fn reset_password(&self, user: &str, new_pass: &str) -> miette::Result<()> {
        let pass_cache = self.pass_cache.clone();
        let db = self.db.new_session()?;
        let username = user.to_string();
        let new_pass = new_pass.to_string();
        spawn_blocking(move || -> miette::Result<()> {
            let salt = rand::thread_rng().gen::<[u8; 32]>();
            let config = argon2config();
            let hash =
                argon2::hash_encoded(new_pass.as_bytes(), &salt, &config).into_diagnostic()?;
            db.put_meta_kv(&[PASSWORD_KEY, &username], hash.as_bytes())?;
            // Evict only after the new hash is stored. Evicting first would let a
            // concurrent login re-cache the old password while the slow hash is
            // being computed, leaving it valid after the reset.
            pass_cache
                .write()
                .map_err(|e| miette!(e.to_string()))?
                .remove(&username);
            Ok(())
        })
        .await
        .into_diagnostic()?
    }

    /// Deletes the stored password of `user` and evicts the cache entry.
    ///
    /// The database write runs on a blocking thread, matching
    /// `reset_password`.
    ///
    /// # Errors
    /// Fails if the database write or the blocking task fails, or if the cache
    /// lock is poisoned.
    async fn remove_user(&self, user: &str) -> miette::Result<()> {
        let pass_cache = self.pass_cache.clone();
        let db = self.db.new_session()?;
        let username = user.to_string();
        spawn_blocking(move || -> miette::Result<()> {
            db.remove_meta_kv(&[PASSWORD_KEY, &username])?;
            pass_cache
                .write()
                .map_err(|e| miette!(e.to_string()))?
                .remove(&username);
            Ok(())
        })
        .await
        .into_diagnostic()?
    }
}
/// Builds the argon2 parameters used for every password hash in this file:
/// the Argon2id variant with `mem_cost` 65536 and `time_cost` 10; all other
/// settings are the library defaults.
fn argon2config() -> argon2::Config<'static> {
    let defaults = argon2::Config::default();
    argon2::Config {
        time_cost: 10,
        mem_cost: 65536,
        variant: argon2::Variant::Argon2id,
        ..defaults
    }
}
/// JSON body accepted by the `/text-query` handler.
#[derive(serde_derive::Serialize, serde_derive::Deserialize)]
struct QueryPayload {
// Script text, passed as the first argument to `Db::run_script`.
script: String,
// Named parameters, passed as the second argument to `Db::run_script`.
params: BTreeMap<String, serde_json::Value>,
}
/// `POST /text-query`: authenticates the caller, runs the submitted script on
/// a blocking thread against a fresh session, and returns the result as JSON.
/// When the result is a JSON object, a `time_taken` field (elapsed
/// milliseconds) is added to it.
#[post("/text-query")]
async fn query(
    body: web::Json<QueryPayload>,
    data: web::Data<AppStateWithDb>,
    req: HttpRequest,
) -> Result<impl Responder> {
    data.verify_password(&req).await?;

    let session = data.db.new_session()?;
    let started = Instant::now();
    let joined = spawn_blocking(move || session.run_script(&body.script, &body.params)).await;
    // Outer `?` handles the join error, inner `?` the script error.
    let mut result = joined.map_err(|e| miette!(e))??;

    if let Some(fields) = result.as_object_mut() {
        fields.insert(
            "time_taken".to_string(),
            json!(started.elapsed().as_millis() as u64),
        );
    }
    Ok(HttpResponse::Ok().json(result))
}
/// JSON body accepted by the `/change-password` handler.
#[derive(serde_derive::Serialize, serde_derive::Deserialize)]
struct ChangePassPayload {
// New password for the user named in the request's `x-cozo-username` header.
new_pass: String,
}
/// `POST /change-password`: authenticates the caller, then replaces the
/// password of the user named in the `x-cozo-username` header with
/// `new_pass` from the body. Responds with `{"status": "OK"}`.
#[post("/change-password")]
async fn change_password(
    body: web::Json<ChangePassPayload>,
    data: web::Data<AppStateWithDb>,
    req: HttpRequest,
) -> Result<impl Responder> {
    data.verify_password(&req).await?;

    let name_header = req
        .headers()
        .get(&HeaderName::from_static("x-cozo-username"));
    let username = name_header
        .ok_or_else(|| miette!("not authenticated"))?
        .to_str()
        .map_err(|e| miette!(e))?;

    data.reset_password(username, &body.new_pass).await?;

    let reply = json!({"status": "OK"});
    Ok(HttpResponse::Ok().json(reply))
}
/// JSON body accepted by the `/assert-user` handler.
#[derive(serde_derive::Serialize, serde_derive::Deserialize)]
struct AssertUserPayload {
// User whose password is set.
username: String,
// Password to store for that user.
new_pass: String,
}
#[post("/assert-user")]
async fn assert_user(
body: web::Json<AssertUserPayload>,
data: web::Data<AppStateWithDb>,
req: HttpRequest,
) -> Result<impl Responder> {
data.verify_password(&req).await?;
data.reset_password(&body.username, &body.new_pass).await?;
Ok(HttpResponse::Ok().json(json!({"status": "OK"})))
}
/// JSON body accepted by the `/remove-user` handler.
#[derive(serde_derive::Serialize, serde_derive::Deserialize)]
struct RemoveUserPayload {
// User whose stored password is removed.
username: String,
}
/// `POST /remove-user`: authenticates the caller, then removes the user named
/// in the body. Responds with `{"status": "OK"}`.
#[post("/remove-user")]
async fn remove_user(
    body: web::Json<RemoveUserPayload>,
    data: web::Data<AppStateWithDb>,
    req: HttpRequest,
) -> Result<impl Responder> {
    data.verify_password(&req).await?;

    let target = &body.username;
    data.remove_user(target).await?;

    let reply = json!({"status": "OK"});
    Ok(HttpResponse::Ok().json(reply))
}
#[actix_web::main]
async fn main() -> std::io::Result<()> {
env_logger::Builder::from_env(Env::default().default_filter_or("info")).init(); env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();
let args = Args::parse(); let args = Args::parse();
@ -251,78 +34,34 @@ async fn main() -> std::io::Result<()> {
.create_if_missing(true); .create_if_missing(true);
let db = Db::build(builder).unwrap(); let db = Db::build(builder).unwrap();
let addr = format!("{}:{}", args.bind, args.port);
println!("Service running at http://{}", addr);
rouille::start_server(addr, move |request| {
router!(request,
(POST) (/text-query) => {
#[derive(serde_derive::Serialize, serde_derive::Deserialize)]
struct QueryPayload {
script: String,
params: BTreeMap<String, serde_json::Value>,
}
match db.meta_range_scan(&[PASSWORD_KEY]).next() { let payload: QueryPayload = try_or_400!(rouille::input::json_input(request));
None => { let start = Instant::now();
let (username, password) = match (
env::var("COZO_INITIAL_WEB_USERNAME"),
env::var("COZO_INITIAL_WEB_PASSWORD"),
) {
(Ok(username), Ok(password)) => (username, password),
_ => {
println!("Welcome to Cozo!");
println!();
println!(
"This is the first time you are running this database at {},",
args.path
);
println!("so let's create a username and password.");
loop {
println!();
print!("Enter a username: ");
let _ = stdout().flush();
let mut username = String::new();
stdin().read_line(&mut username).unwrap();
let username = username.trim().to_string();
if username.is_empty() {
continue;
}
let password = rpassword::prompt_password("Enter your password: ").unwrap();
let confpass = rpassword::prompt_password("Again to confirm it: ").unwrap();
if password.trim() != confpass.trim() { match db.run_script(&payload.script, &payload.params) {
println!("Password mismatch. Try again."); Ok(mut result) => {
continue; if let Some(obj) = result.as_object_mut() {
obj.insert(
"time_taken".to_string(),
json!(start.elapsed().as_millis() as u64),
);
} }
break (username, password.trim().to_string()); Response::json(&result)
} }
Err(e) => Response::text(format!("{:?}", e)).with_status_code(400),
} }
}; },
_ => Response::empty_404()
let salt = rand::thread_rng().gen::<[u8; 32]>(); )
let config = argon2config();
let hash = argon2::hash_encoded(password.trim().as_bytes(), &salt, &config).unwrap();
db.put_meta_kv(&[PASSWORD_KEY, &username], hash.as_bytes())
.unwrap();
}
Some(Err(err)) => panic!("{}", err),
Some(Ok((user, _))) => {
debug!("User {:?}", user[1]);
}
}
let app_state = web::Data::new(AppStateWithDb {
db,
pass_cache: Arc::new(Default::default()),
seed: Box::new(rand::thread_rng().gen::<[u8; 32]>()),
}); });
}
let addr = (&args.bind as &str, args.port);
println!("Service running at http://{}:{}", addr.0, addr.1);
HttpServer::new(move || {
App::new()
.app_data(app_state.clone())
.wrap(Cors::permissive())
.wrap(Logger::default())
.service(query)
.service(change_password)
.service(assert_user)
.service(remove_user)
})
.bind(addr)?
.run()
.await
}

@ -118,16 +118,16 @@ impl AlgoImpl for CsvReader {
} }
}), }),
ColType::Int => { ColType::Int => {
let f = op_to_float(&[dv]).unwrap_or(DataValue::Null); let f = op_to_float(&[dv]).unwrap_or(DataValue::Null);
match f.get_int() { match f.get_int() {
None => {if typ.nullable { None => {
out_tuple.0.push(DataValue::Null) if typ.nullable {
} else { out_tuple.0.push(DataValue::Null)
bail!("cannot convert {} to type {}", s, typ) } else {
}} bail!("cannot convert {} to type {}", s, typ)
Some(i) => { }
out_tuple.0.push(DataValue::from(i))
} }
Some(i) => out_tuple.0.push(DataValue::from(i)),
}; };
} }
_ => bail!("cannot convert {} to type {}", s, typ), _ => bail!("cannot convert {} to type {}", s, typ),
@ -149,9 +149,8 @@ impl AlgoImpl for CsvReader {
} }
} }
None => { None => {
let content = reqwest::blocking::get(&url as &str) let content = minreq::get(&url as &str)
.into_diagnostic()? .send()
.text()
.into_diagnostic()?; .into_diagnostic()?;
let mut rdr = rdr_builder.from_reader(content.as_bytes()); let mut rdr = rdr_builder.from_reader(content.as_bytes());
for record in rdr.records() { for record in rdr.records() {

@ -98,10 +98,8 @@ impl AlgoImpl for JsonReader {
} }
} }
None => { None => {
let content = reqwest::blocking::get(&url as &str) let content = minreq::get(&url as &str).send().into_diagnostic()?;
.into_diagnostic()? let content = content.as_str().into_diagnostic()?;
.text()
.into_diagnostic()?;
if json_lines { if json_lines {
for line in content.lines() { for line in content.lines() {
let line = line.trim(); let line = line.trim();

@ -14,13 +14,13 @@ use serde_json::json;
use smartstring::SmartString; use smartstring::SmartString;
use thiserror::Error; use thiserror::Error;
use cozorocks::{DbBuilder, DbIter, RocksDb}; use cozorocks::{DbBuilder, RocksDb};
use crate::data::json::JsonValue; use crate::data::json::JsonValue;
use crate::data::program::{InputProgram, QueryAssertion, RelationOp}; use crate::data::program::{InputProgram, QueryAssertion, RelationOp};
use crate::data::symb::Symbol; use crate::data::symb::Symbol;
use crate::data::tuple::{ use crate::data::tuple::{
compare_tuple_keys, rusty_scratch_cmp, EncodedTuple, Tuple, SCRATCH_DB_KEY_PREFIX_LEN, compare_tuple_keys, rusty_scratch_cmp, Tuple, SCRATCH_DB_KEY_PREFIX_LEN,
}; };
use crate::data::value::{DataValue, LARGEST_UTF_CHAR}; use crate::data::value::{DataValue, LARGEST_UTF_CHAR};
use crate::parse::sys::SysOp; use crate::parse::sys::SysOp;
@ -32,7 +32,6 @@ use crate::query::relation::{
}; };
use crate::runtime::relation::{RelationHandle, RelationId}; use crate::runtime::relation::{RelationHandle, RelationId};
use crate::runtime::transact::SessionTx; use crate::runtime::transact::SessionTx;
use crate::utils::swap_option_result;
struct RunningQueryHandle { struct RunningQueryHandle {
started_at: f64, started_at: f64,
@ -668,113 +667,6 @@ impl Db {
.collect_vec(); .collect_vec();
Ok(json!({"rows": res, "headers": ["id", "started_at"]})) Ok(json!({"rows": res, "headers": ["id", "started_at"]}))
} }
pub fn put_meta_kv(&self, k: &[&str], v: &[u8]) -> Result<()> {
let mut ks = vec![DataValue::Guard];
for el in k {
ks.push(DataValue::Str(SmartString::from(*el)));
}
let key = Tuple(ks).encode_as_key(RelationId::SYSTEM);
let mut vtx = self.db.transact().start();
vtx.put(&key, v)?;
vtx.commit()?;
Ok(())
}
pub fn remove_meta_kv(&self, k: &[&str]) -> Result<()> {
let mut ks = vec![DataValue::Guard];
for el in k {
ks.push(DataValue::Str(SmartString::from(*el)));
}
let key = Tuple(ks).encode_as_key(RelationId::SYSTEM);
let mut vtx = self.db.transact().start();
vtx.del(&key)?;
vtx.commit()?;
Ok(())
}
pub fn get_meta_kv(&self, k: &[&str]) -> Result<Option<Vec<u8>>> {
let mut ks = vec![DataValue::Guard];
for el in k {
ks.push(DataValue::Str(SmartString::from(*el)));
}
let key = Tuple(ks).encode_as_key(RelationId::SYSTEM);
let vtx = self.db.transact().start();
Ok(vtx.get(&key, false)?.map(|slice| slice.to_vec()))
}
/// Iterates over meta-store entries starting at the key built from `prefix`.
///
/// Each item is `(key components as strings, value bytes)`; the leading
/// `Guard` component of the stored key is skipped. An entry whose remaining
/// key components are not all strings yields a `CorruptKeyInMetaStoreError`.
// NOTE(review): the upper bound is built from `DataValue::Bot` alone, not from
// `prefix`, so it looks like the scan may continue past keys that start with
// `prefix` — confirm against the `DataValue` ordering.
pub fn meta_range_scan(
&self,
prefix: &[&str],
) -> impl Iterator<Item = Result<(Vec<String>, Vec<u8>)>> {
// Lower bound: [Guard, prefix[0], prefix[1], ...].
let mut lower_bound = Tuple(vec![DataValue::Guard]);
for p in prefix {
lower_bound.0.push(DataValue::Str(SmartString::from(*p)));
}
// Upper bound: the single-element tuple [Bot], encoded for the system relation.
let upper_data_bound = Tuple(vec![DataValue::Bot]);
let upper_bound = upper_data_bound.encode_as_key(RelationId::SYSTEM);
let mut it = self
.db
.transact()
.start()
.iterator()
.upper_bound(&upper_bound)
.start();
// Position the iterator at the lower bound before handing it out.
it.seek(&lower_bound.encode_as_key(RelationId::SYSTEM));
// Adapter that turns the positioned `DbIter` into a Rust `Iterator`.
struct CustomIter {
it: DbIter,
// False until the first call; the entry found by `seek` must be read
// before the underlying iterator is advanced.
started: bool,
upper_bound: Vec<u8>,
}
impl CustomIter {
// Advances (except on the first call) and decodes the current entry.
fn next_inner(&mut self) -> Result<Option<(Vec<String>, Vec<u8>)>> {
if self.started {
self.it.next()
} else {
self.started = true;
}
match self.it.pair()? {
None => Ok(None),
Some((k_slice, v_slice)) => {
// Stop unless the upper bound compares strictly greater than the key.
if compare_tuple_keys(&self.upper_bound, k_slice) != Greater {
return Ok(None);
}
#[derive(Debug, Error, Diagnostic)]
#[error("Encountered corrupt key in meta store")]
#[diagnostic(code(db::corrupt_meta_key))]
#[diagnostic(help("This is an internal error. Please file a bug."))]
struct CorruptKeyInMetaStoreError;
let encoded = EncodedTuple(k_slice).decode();
// Skip the first component and require the rest to be strings.
let ks: Vec<_> = encoded
.0
.into_iter()
.skip(1)
.map(|v| {
v.get_string()
.map(|s| s.to_string())
.ok_or(CorruptKeyInMetaStoreError)
})
.try_collect()?;
Ok(Some((ks, v_slice.to_vec())))
}
}
}
}
impl Iterator for CustomIter {
type Item = Result<(Vec<String>, Vec<u8>)>;
fn next(&mut self) -> Option<Self::Item> {
// Converts `Result<Option<_>>` into the `Option<Result<_>>` an iterator yields.
swap_option_result(self.next_inner())
}
}
CustomIter {
it,
started: false,
upper_bound,
}
}
pub fn list_relation(&self, name: &str) -> Result<JsonValue> { pub fn list_relation(&self, name: &str) -> Result<JsonValue> {
let tx = self.transact()?; let tx = self.transact()?;
let handle = tx.get_relation(name, false)?; let handle = tx.get_relation(name, false)?;

@ -106,8 +106,7 @@
"id": "81436147-dd67-4f0e-a9c0-d6c18c37495c", "id": "81436147-dd67-4f0e-a9c0-d6c18c37495c",
"metadata": {}, "metadata": {},
"source": [ "source": [
"%reload_ext pycozo.ipyext_direct\n", "%reload_ext pycozo.ipyext_direct"
"%cozo_auth YOUR_USERNAME"
] ]
}, },
{ {

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save