towards minimal server

main
Ziyang Hu 2 years ago
parent 1a4e9fda13
commit 1a072ab1ce

1192
Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -37,7 +37,7 @@ pest = "2.2.1"
pest_derive = "2.2.1"
rayon = "1.5.3"
nalgebra = "0.31.1"
reqwest = { version = "0.11.11", features = ["blocking"] }
minreq = { version = "2.6.0", features = ["https-native"] }
approx = "0.5.1"
unicode-normalization = "0.1.21"
thiserror = "1.0.34"

@ -7,21 +7,15 @@ license = "AGPL-3.0-or-later"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
actix-web = "4.1.0"
clap = { version = "3.2.8", features = ["derive"] }
actix-cors = "0.6.1"
log = "0.4.16"
miette = { version = "5.3.0", features = ["fancy"] }
env_logger = "0.9.0"
serde = "1.0.144"
serde_json = "1.0.81"
serde_derive = "1.0.144"
rust-argon2 = "1.0.0"
sha3 = "0.10.4"
rand = "0.8.5"
rpassword = "7.0.0"
webbrowser = "0.8.0"
futures = "0.3.24"
rouille = "3.5.0"
cozo = { path = ".." }
[build-dependencies]

@ -1,58 +1,14 @@
use std::collections::{BTreeMap, HashMap};
use std::env;
use std::fmt::{Debug, Display, Formatter};
use std::io::{stdin, stdout, Write};
use std::str::from_utf8;
use std::sync::{Arc, RwLock};
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::time::Instant;
use actix_cors::Cors;
use actix_web::{App, HttpRequest, HttpResponse, HttpServer, post, Responder, ResponseError, web};
use actix_web::body::BoxBody;
use actix_web::http::header::HeaderName;
use actix_web::middleware::Logger;
use actix_web::rt::task::spawn_blocking;
use clap::Parser;
use env_logger::Env;
use log::debug;
use miette::{bail, IntoDiagnostic, miette};
use rand::Rng;
use rouille::{router, try_or_400, Response};
use serde_json::json;
use sha3::Digest;
use cozo::{Db, DbBuilder};
/// Result alias used by the HTTP handlers in this file; the error side is `RespError`.
type Result<T> = std::result::Result<T, RespError>;
/// Thin wrapper around a `miette::Error` so that the traits the web framework
/// needs can be implemented on a type local to this crate.
struct RespError {
/// The wrapped error; every trait impl on `RespError` formats or forwards this value.
err: miette::Error,
}
// Debug rendering is the `{:?}` rendering of the wrapped error, nothing more.
impl Debug for RespError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Formatting through a fresh `format_args!` (rather than delegating to
        // the inner `fmt`) keeps the outer formatter's flags from leaking in.
        f.write_fmt(format_args!("{:?}", self.err))
    }
}
// Display rendering is the `{}` rendering of the wrapped error.
impl Display for RespError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // A fresh `format_args!` is used so outer formatter flags are not forwarded.
        f.write_fmt(format_args!("{}", self.err))
    }
}
impl From<cozo::Error> for RespError {
fn from(err: cozo::Error) -> RespError {
RespError { err }
}
}
// Every handler error is reported as HTTP 400 with the error's `{:?}`
// rendering as the response body.
impl ResponseError for RespError {
    fn error_response(&self) -> HttpResponse<BoxBody> {
        let mut builder = HttpResponse::BadRequest();
        builder.body(format!("{:?}", self.err))
    }
}
#[derive(Parser, Debug)]
#[clap(version, about, long_about = None)]
struct Args {
@ -69,180 +25,7 @@ struct Args {
port: u16,
}
/// State shared by the HTTP handlers: the database plus an in-memory cache of
/// already-verified passwords.
struct AppStateWithDb {
/// Database handle; the code in this file calls `new_session()` on it before
/// handing work to a blocking task.
db: Db,
/// Username -> SHA3-256 digest of `seed` followed by the password bytes.
/// An entry is inserted after a successful argon2 verification and consulted
/// first on later requests.
pass_cache: Arc<RwLock<HashMap<String, Box<[u8]>>>>,
/// Bytes prepended to the password before the cached digest is computed.
seed: Box<[u8]>,
}
/// First component of the meta-store key under which a user's encoded argon2
/// hash is stored; the username is the second component.
const PASSWORD_KEY: &str = "WEB_USER_PASSWORD";
impl AppStateWithDb {
    /// Authenticates a request from its `x-cozo-username` / `x-cozo-password`
    /// headers.
    ///
    /// The in-memory cache is checked first. A cached user with a matching
    /// digest is accepted immediately; a cached user with a non-matching
    /// digest has the cache entry evicted and is rejected. Users that are not
    /// cached are verified against the argon2 hash in the meta store on a
    /// blocking thread and, on success, added to the cache.
    ///
    /// # Errors
    /// Fails with "not authenticated" when either header is missing, with
    /// "invalid password" when verification fails, and with the underlying
    /// error when a header is not valid text, the lock is poisoned, the
    /// database access fails or the blocking task cannot be joined.
    async fn verify_password(&self, req: &HttpRequest) -> miette::Result<()> {
        let username = req
            .headers()
            .get(&HeaderName::from_static("x-cozo-username"))
            .ok_or_else(|| miette!("not authenticated"))?
            .to_str()
            .into_diagnostic()?;
        let password = req
            .headers()
            .get(&HeaderName::from_static("x-cozo-password"))
            .ok_or_else(|| miette!("not authenticated"))?
            .to_str()
            .into_diagnostic()?;

        // Blocking `read()` instead of `try_read()`: the non-blocking variant
        // returns an error whenever another request happens to hold the write
        // lock, which turned ordinary contention into authentication failures.
        // The guard is a temporary and is released at the end of the statement.
        let existing = self
            .pass_cache
            .read()
            .map_err(|e| miette!(e.to_string()))?
            .get(username)
            .cloned();
        if let Some(stored) = existing {
            let mut seed = self.seed.to_vec();
            seed.extend_from_slice(password.as_bytes());
            let digest: &[u8] = &sha3::Sha3_256::digest(&seed);
            if *stored == *digest {
                return Ok(());
            } else {
                // Drop the cached entry so the next attempt is checked against
                // the stored argon2 hash again.
                self.pass_cache
                    .write()
                    .map_err(|e| miette!(e.to_string()))?
                    .remove(username);
                bail!("invalid password")
            }
        }

        // Cache miss: argon2 verification is expensive, so run it (and the
        // database read) off the async executor.
        let pass_cache = self.pass_cache.clone();
        let mut seed = self.seed.to_vec();
        let db = self.db.new_session()?;
        let password = password.to_string();
        let username = username.to_string();
        spawn_blocking(move || -> miette::Result<()> {
            if let Some(hashed) = db.get_meta_kv(&[PASSWORD_KEY, &username])? {
                let hashed = from_utf8(&hashed).into_diagnostic()?;
                if argon2::verify_encoded(hashed, password.as_bytes()).into_diagnostic()? {
                    seed.extend_from_slice(password.as_bytes());
                    let easy_digest: &[u8] = &sha3::Sha3_256::digest(&seed);
                    pass_cache
                        .write()
                        .map_err(|e| miette!(e.to_string()))?
                        .insert(username, easy_digest.into());
                    return Ok(());
                }
            }
            bail!("invalid password")
        })
        .await
        .into_diagnostic()?
    }

    /// Stores a freshly salted argon2 hash of `new_pass` for `user` and evicts
    /// the user's cached digest.
    ///
    /// The hash is written before the cache entry is removed. Evicting first
    /// left the whole (slow) hashing step as a window in which a login with
    /// the old password could re-populate the cache, after which the old
    /// password kept being accepted from the cache. Evicting last narrows
    /// that window to verifications already in flight when the hash is
    /// replaced.
    ///
    /// # Errors
    /// Propagates hashing, database, lock-poisoning and task-join failures.
    async fn reset_password(&self, user: &str, new_pass: &str) -> miette::Result<()> {
        let pass_cache = self.pass_cache.clone();
        let db = self.db.new_session()?;
        let username = user.to_string();
        let new_pass = new_pass.to_string();
        spawn_blocking(move || -> miette::Result<()> {
            let salt = rand::thread_rng().gen::<[u8; 32]>();
            let config = argon2config();
            let hash =
                argon2::hash_encoded(new_pass.as_bytes(), &salt, &config).into_diagnostic()?;
            db.put_meta_kv(&[PASSWORD_KEY, &username], hash.as_bytes())?;
            pass_cache
                .write()
                .map_err(|e| miette!(e.to_string()))?
                .remove(&username);
            Ok(())
        })
        .await
        .into_diagnostic()?
    }

    /// Deletes the stored hash for `user` and evicts the user's cached digest.
    ///
    /// The stored hash is deleted before the cache entry is removed, for the
    /// same reason as in `reset_password`: a concurrent login must not be
    /// able to re-cache credentials that are about to disappear.
    ///
    /// # Errors
    /// Propagates database and lock-poisoning failures.
    async fn remove_user(&self, user: &str) -> miette::Result<()> {
        self.db.remove_meta_kv(&[PASSWORD_KEY, user])?;
        self.pass_cache
            .write()
            .map_err(|e| miette!(e.to_string()))?
            .remove(user);
        Ok(())
    }
}
/// Builds the argon2 parameters used when hashing passwords in this file:
/// the Argon2id variant with explicit memory and time costs, and the
/// library defaults for every other field.
fn argon2config() -> argon2::Config<'static> {
    const MEM_COST: u32 = 65536;
    const TIME_COST: u32 = 10;

    argon2::Config {
        time_cost: TIME_COST,
        mem_cost: MEM_COST,
        variant: argon2::Variant::Argon2id,
        ..Default::default()
    }
}
/// JSON body accepted by the `/text-query` endpoint.
#[derive(serde_derive::Serialize, serde_derive::Deserialize)]
struct QueryPayload {
/// Script text passed as the first argument to `run_script`.
script: String,
/// Named parameters passed alongside the script to `run_script`.
/// NOTE(review): there is no `#[serde(default)]`, so a body without
/// `params` is presumably rejected at deserialization - confirm intended.
params: BTreeMap<String, serde_json::Value>,
}
/// `POST /text-query`: authenticates the caller, runs the submitted script on
/// a blocking thread and returns its JSON result. When the result is a JSON
/// object, a `time_taken` field (milliseconds since just before the script
/// was dispatched) is added to it.
#[post("/text-query")]
async fn query(
    body: web::Json<QueryPayload>,
    data: web::Data<AppStateWithDb>,
    req: HttpRequest,
) -> Result<impl Responder> {
    data.verify_password(&req).await?;
    let session = data.db.new_session()?;
    let started = Instant::now();
    // First `?` handles a failed join of the blocking task, second `?` the
    // script's own error.
    let mut outcome = spawn_blocking(move || session.run_script(&body.script, &body.params))
        .await
        .map_err(|e| miette!(e))??;
    if let Some(fields) = outcome.as_object_mut() {
        let elapsed_ms = started.elapsed().as_millis() as u64;
        fields.insert("time_taken".to_string(), json!(elapsed_ms));
    }
    Ok(HttpResponse::Ok().json(outcome))
}
/// JSON body accepted by the `/change-password` endpoint.
#[derive(serde_derive::Serialize, serde_derive::Deserialize)]
struct ChangePassPayload {
/// Replacement password for the authenticated user.
new_pass: String,
}
/// `POST /change-password`: authenticates the caller, then replaces the
/// password of the user named in the `x-cozo-username` header with the one
/// from the request body.
#[post("/change-password")]
async fn change_password(
    body: web::Json<ChangePassPayload>,
    data: web::Data<AppStateWithDb>,
    req: HttpRequest,
) -> Result<impl Responder> {
    data.verify_password(&req).await?;

    let header_name = HeaderName::from_static("x-cozo-username");
    let header_value = req
        .headers()
        .get(&header_name)
        .ok_or_else(|| miette!("not authenticated"))?;
    let username = header_value.to_str().map_err(|e| miette!(e))?;

    data.reset_password(username, &body.new_pass).await?;

    let reply = json!({"status": "OK"});
    Ok(HttpResponse::Ok().json(reply))
}
/// JSON body accepted by the `/assert-user` endpoint.
#[derive(serde_derive::Serialize, serde_derive::Deserialize)]
struct AssertUserPayload {
/// User whose password is to be set.
username: String,
/// Password to store for that user.
new_pass: String,
}
#[post("/assert-user")]
async fn assert_user(
body: web::Json<AssertUserPayload>,
data: web::Data<AppStateWithDb>,
req: HttpRequest,
) -> Result<impl Responder> {
data.verify_password(&req).await?;
data.reset_password(&body.username, &body.new_pass).await?;
Ok(HttpResponse::Ok().json(json!({"status": "OK"})))
}
/// JSON body accepted by the `/remove-user` endpoint.
#[derive(serde_derive::Serialize, serde_derive::Deserialize)]
struct RemoveUserPayload {
/// User whose stored password entry is to be removed.
username: String,
}
#[post("/remove-user")]
async fn remove_user(
body: web::Json<RemoveUserPayload>,
data: web::Data<AppStateWithDb>,
req: HttpRequest,
) -> Result<impl Responder> {
data.verify_password(&req).await?;
data.remove_user(&body.username).await?;
Ok(HttpResponse::Ok().json(json!({"status": "OK"})))
}
#[actix_web::main]
async fn main() -> std::io::Result<()> {
fn main() {
env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();
let args = Args::parse();
@ -251,78 +34,34 @@ async fn main() -> std::io::Result<()> {
.create_if_missing(true);
let db = Db::build(builder).unwrap();
match db.meta_range_scan(&[PASSWORD_KEY]).next() {
None => {
let (username, password) = match (
env::var("COZO_INITIAL_WEB_USERNAME"),
env::var("COZO_INITIAL_WEB_PASSWORD"),
) {
(Ok(username), Ok(password)) => (username, password),
_ => {
println!("Welcome to Cozo!");
println!();
println!(
"This is the first time you are running this database at {},",
args.path
);
println!("so let's create a username and password.");
loop {
println!();
print!("Enter a username: ");
let _ = stdout().flush();
let mut username = String::new();
stdin().read_line(&mut username).unwrap();
let username = username.trim().to_string();
if username.is_empty() {
continue;
let addr = format!("{}:{}", args.bind, args.port);
println!("Service running at http://{}", addr);
rouille::start_server(addr, move |request| {
router!(request,
(POST) (/text-query) => {
#[derive(serde_derive::Serialize, serde_derive::Deserialize)]
struct QueryPayload {
script: String,
params: BTreeMap<String, serde_json::Value>,
}
let password = rpassword::prompt_password("Enter your password: ").unwrap();
let confpass = rpassword::prompt_password("Again to confirm it: ").unwrap();
if password.trim() != confpass.trim() {
println!("Password mismatch. Try again.");
continue;
}
break (username, password.trim().to_string());
}
}
};
let payload: QueryPayload = try_or_400!(rouille::input::json_input(request));
let start = Instant::now();
let salt = rand::thread_rng().gen::<[u8; 32]>();
let config = argon2config();
let hash = argon2::hash_encoded(password.trim().as_bytes(), &salt, &config).unwrap();
db.put_meta_kv(&[PASSWORD_KEY, &username], hash.as_bytes())
.unwrap();
match db.run_script(&payload.script, &payload.params) {
Ok(mut result) => {
if let Some(obj) = result.as_object_mut() {
obj.insert(
"time_taken".to_string(),
json!(start.elapsed().as_millis() as u64),
);
}
Some(Err(err)) => panic!("{}", err),
Some(Ok((user, _))) => {
debug!("User {:?}", user[1]);
Response::json(&result)
}
Err(e) => Response::text(format!("{:?}", e)).with_status_code(400),
}
let app_state = web::Data::new(AppStateWithDb {
db,
pass_cache: Arc::new(Default::default()),
seed: Box::new(rand::thread_rng().gen::<[u8; 32]>()),
},
_ => Response::empty_404()
)
});
let addr = (&args.bind as &str, args.port);
println!("Service running at http://{}:{}", addr.0, addr.1);
HttpServer::new(move || {
App::new()
.app_data(app_state.clone())
.wrap(Cors::permissive())
.wrap(Logger::default())
.service(query)
.service(change_password)
.service(assert_user)
.service(remove_user)
})
.bind(addr)?
.run()
.await
}

@ -120,14 +120,14 @@ impl AlgoImpl for CsvReader {
ColType::Int => {
let f = op_to_float(&[dv]).unwrap_or(DataValue::Null);
match f.get_int() {
None => {if typ.nullable {
None => {
if typ.nullable {
out_tuple.0.push(DataValue::Null)
} else {
bail!("cannot convert {} to type {}", s, typ)
}}
Some(i) => {
out_tuple.0.push(DataValue::from(i))
}
}
Some(i) => out_tuple.0.push(DataValue::from(i)),
};
}
_ => bail!("cannot convert {} to type {}", s, typ),
@ -149,9 +149,8 @@ impl AlgoImpl for CsvReader {
}
}
None => {
let content = reqwest::blocking::get(&url as &str)
.into_diagnostic()?
.text()
let content = minreq::get(&url as &str)
.send()
.into_diagnostic()?;
let mut rdr = rdr_builder.from_reader(content.as_bytes());
for record in rdr.records() {

@ -98,10 +98,8 @@ impl AlgoImpl for JsonReader {
}
}
None => {
let content = reqwest::blocking::get(&url as &str)
.into_diagnostic()?
.text()
.into_diagnostic()?;
let content = minreq::get(&url as &str).send().into_diagnostic()?;
let content = content.as_str().into_diagnostic()?;
if json_lines {
for line in content.lines() {
let line = line.trim();

@ -14,13 +14,13 @@ use serde_json::json;
use smartstring::SmartString;
use thiserror::Error;
use cozorocks::{DbBuilder, DbIter, RocksDb};
use cozorocks::{DbBuilder, RocksDb};
use crate::data::json::JsonValue;
use crate::data::program::{InputProgram, QueryAssertion, RelationOp};
use crate::data::symb::Symbol;
use crate::data::tuple::{
compare_tuple_keys, rusty_scratch_cmp, EncodedTuple, Tuple, SCRATCH_DB_KEY_PREFIX_LEN,
compare_tuple_keys, rusty_scratch_cmp, Tuple, SCRATCH_DB_KEY_PREFIX_LEN,
};
use crate::data::value::{DataValue, LARGEST_UTF_CHAR};
use crate::parse::sys::SysOp;
@ -32,7 +32,6 @@ use crate::query::relation::{
};
use crate::runtime::relation::{RelationHandle, RelationId};
use crate::runtime::transact::SessionTx;
use crate::utils::swap_option_result;
struct RunningQueryHandle {
started_at: f64,
@ -668,113 +667,6 @@ impl Db {
.collect_vec();
Ok(json!({"rows": res, "headers": ["id", "started_at"]}))
}
/// Writes `v` under the meta-store key made of a `Guard` marker followed by
/// the string components in `k`, committing in its own transaction.
pub fn put_meta_kv(&self, k: &[&str], v: &[u8]) -> Result<()> {
    let components = std::iter::once(DataValue::Guard)
        .chain(k.iter().map(|part| DataValue::Str(SmartString::from(*part))))
        .collect::<Vec<_>>();
    let encoded_key = Tuple(components).encode_as_key(RelationId::SYSTEM);

    let mut tx = self.db.transact().start();
    tx.put(&encoded_key, v)?;
    tx.commit()?;
    Ok(())
}
/// Deletes the meta-store entry whose key is a `Guard` marker followed by the
/// string components in `k`, committing in its own transaction.
pub fn remove_meta_kv(&self, k: &[&str]) -> Result<()> {
    let components = std::iter::once(DataValue::Guard)
        .chain(k.iter().map(|part| DataValue::Str(SmartString::from(*part))))
        .collect::<Vec<_>>();
    let encoded_key = Tuple(components).encode_as_key(RelationId::SYSTEM);

    let mut tx = self.db.transact().start();
    tx.del(&encoded_key)?;
    tx.commit()?;
    Ok(())
}
/// Looks up the meta-store entry whose key is a `Guard` marker followed by the
/// string components in `k`, returning a copy of its bytes or `None` when
/// the lookup finds nothing.
pub fn get_meta_kv(&self, k: &[&str]) -> Result<Option<Vec<u8>>> {
    let components = std::iter::once(DataValue::Guard)
        .chain(k.iter().map(|part| DataValue::Str(SmartString::from(*part))))
        .collect::<Vec<_>>();
    let encoded_key = Tuple(components).encode_as_key(RelationId::SYSTEM);

    let tx = self.db.transact().start();
    let found = tx.get(&encoded_key, false)?;
    Ok(found.map(|bytes| bytes.to_vec()))
}
/// Iterates over meta-store entries starting at the key built from `prefix`.
///
/// Each item is either the decoded key (its string components, without the
/// leading element) together with a copy of the value bytes, or an error
/// from the underlying iterator or from decoding a key that contains a
/// non-string component.
///
/// NOTE(review): the upper bound is the single-element tuple `[Bot]` and does
/// not include `prefix`, so the scan appears to continue past keys that share
/// the prefix until that bound is reached - confirm whether callers rely on
/// prefix-only results.
pub fn meta_range_scan(
&self,
prefix: &[&str],
) -> impl Iterator<Item = Result<(Vec<String>, Vec<u8>)>> {
// Lower bound: a `Guard` marker followed by the prefix components, the same
// key layout the `*_meta_kv` methods build.
let mut lower_bound = Tuple(vec![DataValue::Guard]);
for p in prefix {
lower_bound.0.push(DataValue::Str(SmartString::from(*p)));
}
let upper_data_bound = Tuple(vec![DataValue::Bot]);
let upper_bound = upper_data_bound.encode_as_key(RelationId::SYSTEM);
let mut it = self
.db
.transact()
.start()
.iterator()
.upper_bound(&upper_bound)
.start();
// Position the iterator at the first key that is not below the lower bound.
it.seek(&lower_bound.encode_as_key(RelationId::SYSTEM));
// Adapter that turns the positioned database iterator into a Rust `Iterator`.
struct CustomIter {
it: DbIter,
// False until the first item has been requested; the initial `seek` has
// already positioned the iterator, so the first call must not advance it.
started: bool,
upper_bound: Vec<u8>,
}
impl CustomIter {
fn next_inner(&mut self) -> Result<Option<(Vec<String>, Vec<u8>)>> {
if self.started {
self.it.next()
} else {
self.started = true;
}
match self.it.pair()? {
None => Ok(None),
Some((k_slice, v_slice)) => {
// Stop as soon as the current key is no longer strictly below the upper bound.
if compare_tuple_keys(&self.upper_bound, k_slice) != Greater {
return Ok(None);
}
#[derive(Debug, Error, Diagnostic)]
#[error("Encountered corrupt key in meta store")]
#[diagnostic(code(db::corrupt_meta_key))]
#[diagnostic(help("This is an internal error. Please file a bug."))]
struct CorruptKeyInMetaStoreError;
let encoded = EncodedTuple(k_slice).decode();
// Drop the first element (the `Guard` marker for keys written by the
// `*_meta_kv` methods) and require every remaining component to be a string.
let ks: Vec<_> = encoded
.0
.into_iter()
.skip(1)
.map(|v| {
v.get_string()
.map(|s| s.to_string())
.ok_or(CorruptKeyInMetaStoreError)
})
.try_collect()?;
Ok(Some((ks, v_slice.to_vec())))
}
}
}
}
impl Iterator for CustomIter {
type Item = Result<(Vec<String>, Vec<u8>)>;
fn next(&mut self) -> Option<Self::Item> {
// Presumably converts `Result<Option<T>>` into `Option<Result<T>>`; the
// helper is defined outside this view - verify.
swap_option_result(self.next_inner())
}
}
CustomIter {
it,
started: false,
upper_bound,
}
}
pub fn list_relation(&self, name: &str) -> Result<JsonValue> {
let tx = self.transact()?;
let handle = tx.get_relation(name, false)?;

@ -106,8 +106,7 @@
"id": "81436147-dd67-4f0e-a9c0-d6c18c37495c",
"metadata": {},
"source": [
"%reload_ext pycozo.ipyext_direct\n",
"%cozo_auth YOUR_USERNAME"
"%reload_ext pycozo.ipyext_direct"
]
},
{

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save