token table; disable multi-table transactions and custom rules for remote

main
Ziyang Hu 1 year ago
parent 6b8438e823
commit cd840e687c

@ -2,11 +2,14 @@
[![server](https://img.shields.io/github/v/release/cozodb/cozo)](https://github.com/cozodb/cozo/releases) [![server](https://img.shields.io/github/v/release/cozodb/cozo)](https://github.com/cozodb/cozo/releases)
本文叙述的是如何安装设置 Cozo 的独立程序本身。有关如何使用 CozoDBCozoScript的信息见 [文档](https://docs.cozodb.org/zh_CN/latest/index.html) 。 本文叙述的是如何安装设置 Cozo 的独立程序本身。有关如何使用
CozoDBCozoScript的信息见 [文档](https://docs.cozodb.org/zh_CN/latest/index.html) 。
## 下载 ## 下载
独立服务的程序可以从 [GitHub 发布页](https://github.com/cozodb/cozo/releases) 或 [Gitee 发布页](https://gitee.com/cozodb/cozo/releases) 下载,其中名为 `cozo-*` 的是独立服务程序,名为 `cozo_all-*` 的独立程序同时支持更多地存储引擎,比如 [TiKV](https://tikv.org/)。 独立服务的程序可以从 [GitHub 发布页](https://github.com/cozodb/cozo/releases)
或 [Gitee 发布页](https://gitee.com/cozodb/cozo/releases) 下载,其中名为 `cozo-*` 的是独立服务程序,名为 `cozo_all-*`
的独立程序同时支持更多地存储引擎,比如 [TiKV](https://tikv.org/)。
## 启动服务程序 ## 启动服务程序
@ -39,29 +42,38 @@
## 查询 API ## 查询 API
查询通过向 API 发送 POST 请求来完成。默认的请求地址是 `http://127.0.0.1:9070/text-query` 。请求必须包含 JSON 格式的正文,具体内容如下: 查询通过向 API 发送 POST 请求来完成。默认的请求地址是 `http://127.0.0.1:9070/text-query` 。请求必须包含 JSON 格式的正文,具体内容如下:
```json ```json
{ {
"script": "<COZOSCRIPT QUERY STRING>", "script": "<COZOSCRIPT QUERY STRING>",
"params": {} "params": {}
} }
``` ```
`params` 给出了查询文本中可用的变量。例如,当 `params``{"num": 1}` 时,查询文本中可以以 `$num` 来代替常量 `1`。请善用此功能,而不是手动拼接查询字符串。
HTTP API 返回的结果永远是 JSON 格式的。如果请求成功,则返回结果的 `"ok"` 字段将为 `true`,且 `"rows"` 字段将含有查询结果的行,而 `"headers"` 将含有表头。如果查询报错,则 `"ok"` 字段将为 `false`,而错误信息会在 `"message"` 字段中,同时 `"display"` 字段会包含格式化好的友好的错误提示。 `params` 给出了查询文本中可用的变量。例如,当 `params``{"num": 1}` 时,查询文本中可以以 `$num` 来代替常量 `1`
。请善用此功能,而不是手动拼接查询字符串。
HTTP API 返回的结果永远是 JSON 格式的。如果请求成功,则返回结果的 `"ok"` 字段将为 `true`,且 `"rows"`
字段将含有查询结果的行,而 `"headers"` 将含有表头。如果查询报错,则 `"ok"` 字段将为 `false`,而错误信息会在 `"message"`
字段中,同时 `"display"` 字段会包含格式化好的友好的错误提示。
> Cozo 的设计,基于其在一个受信任的环境中运行,且其所有用户也是由受信任的这种假设。因此 Cozo 没有内置认证与复杂的安全机制。如果你需要远程访问 Cozo 服务,你必须自己设置防火墙、加密和代理等,用来保护服务器上资源的安全。 > Cozo 的设计,基于其在一个受信任的环境中运行,且其所有用户也是由受信任的这种假设。因此 Cozo 没有内置认证与复杂的安全机制。如果你需要远程访问
> > Cozo 服务,你必须自己设置防火墙、加密和代理等,用来保护服务器上资源的安全。
> 由于总是会有用户不小心将服务接口暴露于外网Cozo 有一个补救措施:如果允许了非回传地址访问 Cozo则必须在所有请求中以 HTTP 文件头 `x-cozo-auth` 的形式附上访问令牌。访问令牌的内容在启动服务的终端中有提示。注意这仅仅是一个补救措施,并不是特别可靠的安全机制,是为了尽量防止一些不由于小心而造成严重后果的悲剧。 >
> > 由于总是会有用户不小心将服务接口暴露于外网Cozo 有一个补救措施:如果允许了非回传地址访问 Cozo则必须在所有请求中以 HTTP
> 文件头 `x-cozo-auth` 的形式附上访问令牌。访问令牌的内容在启动服务的终端中有提示。注意这仅仅是一个补救措施,并不是特别可靠的安全机制,是为了尽量防止一些不由于小心而造成严重后果的悲剧。
>
> 在有些客户端里,部分 API 加入 HTTP 文件头可能会比较困难或者根本不可能。这时也可以在查询参数 `auth` 中传入令牌。 > 在有些客户端里,部分 API 加入 HTTP 文件头可能会比较困难或者根本不可能。这时也可以在查询参数 `auth` 中传入令牌。
## 所有 API ## 所有 API
* `POST /text-query`,见上。 * `POST /text-query`,见上。
* `GET /export/{relations: String}`,导出指定表中的数据,其中 `relations` 是以逗号分割的表名。 * `GET /export/{relations: String}`,导出指定表中的数据,其中 `relations` 是以逗号分割的表名。
* `PUT /import`,向数据库导入数据。所导入的数据应以在正文中以 `application/json` MIME 类型传入,具体格式与 `/export` 返回值中的 `data` 字段相同。 * `PUT /import`,向数据库导入数据。所导入的数据应以在正文中以 `application/json` MIME 类型传入,具体格式与 `/export`
返回值中的 `data` 字段相同。
* `POST /backup`,备份数据库,需要传入 JSON 正文 `{"path": <路径>}` * `POST /backup`,备份数据库,需要传入 JSON 正文 `{"path": <路径>}`
* `POST /import-from-backup`,将备份中指定存储表中的数据插入当前数据库中同名存储表。需要传入 JSON 正文 `{"path": <路径>, "relations": <表名数组>}`. * `POST /import-from-backup`,将备份中指定存储表中的数据插入当前数据库中同名存储表。需要传入 JSON
正文 `{"path": <路径>, "relations": <表名数组>}`.
* `GET /`,用浏览器打开这个地址,然后打开浏览器的调试工具,就可以使用一个简陋的 JS 客户端。 * `GET /`,用浏览器打开这个地址,然后打开浏览器的调试工具,就可以使用一个简陋的 JS 客户端。
> 注意 `import``import-from-backup` 接口不会激活任何触发器。 > 注意 `import``import-from-backup` 接口不会激活任何触发器。
@ -69,12 +81,8 @@ HTTP API 返回的结果永远是 JSON 格式的。如果请求成功,则返
以下为试验性的 API 以下为试验性的 API
* `GET(SSE) /changes/{relation: String}` 获取某个存储表的更新,基于 [SSE](https://developer.mozilla.org/zh-CN/docs/Web/API/Server-sent_events/Using_server-sent_events). * `GET(SSE) /changes/{relation: String}`
* `GET(SSE) /rules/{name: String}` 注册一个自定义的固定规则。查询参数 `arity` 是必须的。 获取某个存储表的更新,基于 [SSE](https://developer.mozilla.org/zh-CN/docs/Web/API/Server-sent_events/Using_server-sent_events).
* `POST /rule-result/{id}` 将固定规则的计算结果回传给服务器,配合上一个 API 使用。
* `POST /transact` 开始一个多语句的事务。返回的 ID 在下面几个 API 中使用。如果要进行写操作,则需要传入 `write=true` 查询参数。
* `POST /transact/{id}` 在多语句事务中进行查询。要求的正文与 `/text-query` 所要求的相同。
* `PUT /transact/{id}` 提交或放弃多语句事务。要求的正文是 JSON `{"abort": <bool>}`,传入真值则放弃,否则提交。如果不执行此查询则服务器会浪费系统资源。
## 编译 ## 编译

@ -38,24 +38,27 @@ You can use the following meta ops in the REPL:
* `%clear`: unset all parameters. * `%clear`: unset all parameters.
* `%params`: print all set parameters. * `%params`: print all set parameters.
* `%run <FILE>`: run the script contained in `<FILE>`. * `%run <FILE>`: run the script contained in `<FILE>`.
* `%import <FILE OR URL>`: import data in JSON format from the file or URL. * `%import <FILE OR URL>`: import data in JSON format from the file or URL.
* `%save <FILE>`: the result of the next successful query will be saved in JSON format in a file instead of printed on screen. If `<FILE>` is omitted, then the effect of any previous `%save` command is nullified. * `%save <FILE>`: the result of the next successful query will be saved in JSON format in a file instead of printed on
screen. If `<FILE>` is omitted, then the effect of any previous `%save` command is nullified.
* `%backup <FILE>`: the current database will be backed up into the file. * `%backup <FILE>`: the current database will be backed up into the file.
* `%restore <FILE>`: restore the data in the backup to the current database. The current database must be empty. * `%restore <FILE>`: restore the data in the backup to the current database. The current database must be empty.
## The query API ## The query API
Queries are run by sending HTTP POST requests to the server. Queries are run by sending HTTP POST requests to the server.
By default, the API endpoint is `http://127.0.0.1:9070/text-query`. By default, the API endpoint is `http://127.0.0.1:9070/text-query`.
A JSON body of the following form is expected: A JSON body of the following form is expected:
```json ```json
{ {
"script": "<COZOSCRIPT QUERY STRING>", "script": "<COZOSCRIPT QUERY STRING>",
"params": {} "params": {}
} }
``` ```
params should be an object of named parameters. For example, if params is `{"num": 1}`,
then `$num` can be used anywhere in your query string where an expression is expected. params should be an object of named parameters. For example, if params is `{"num": 1}`,
then `$num` can be used anywhere in your query string where an expression is expected.
Always use params instead of concatenating strings when you need parametrized queries. Always use params instead of concatenating strings when you need parametrized queries.
The HTTP API always responds in JSON. If a request is successful, then its `"ok"` field will be `true`, The HTTP API always responds in JSON. If a request is successful, then its `"ok"` field will be `true`,
@ -63,19 +66,19 @@ and the `"rows"` field will contain the data for the resulting relation, and `"h
the headers. If an error occurs, then `"ok"` will contain `false`, the error message will be in `"message"` the headers. If an error occurs, then `"ok"` will contain `false`, the error message will be in `"message"`
and a nicely-formatted diagnostic will be in `"display"` if available. and a nicely-formatted diagnostic will be in `"display"` if available.
> Cozo is designed to run in a trusted environment and be used by trusted clients. > Cozo is designed to run in a trusted environment and be used by trusted clients.
> It does not come with elaborate authentication and security features. > It does not come with elaborate authentication and security features.
> If you must access Cozo remotely, you are responsible for setting up firewalls, encryptions and proxies yourself. > If you must access Cozo remotely, you are responsible for setting up firewalls, encryptions and proxies yourself.
> >
> As a guard against users accidentally exposing sensitive data, > As a guard against users accidentally exposing sensitive data,
> If you bind Cozo to non-loopback addresses, > If you bind Cozo to non-loopback addresses,
> Cozo will generate a token string and require all queries > Cozo will generate a token string and require all queries
> to provide the token string in the HTTP header field `x-cozo-auth`. > to provide the token string in the HTTP header field `x-cozo-auth`.
> The warning printed when you start Cozo with a > The warning printed when you start Cozo with a
> non-default binding will tell you where to find the token string. > non-default binding will tell you where to find the token string.
> This “security measure” is not considered sufficient for any purpose > This “security measure” is not considered sufficient for any purpose
> and is only intended as a last defence against carelessness. > and is only intended as a last defence against carelessness.
> >
> In some environments, setting the header may be difficult or impossible > In some environments, setting the header may be difficult or impossible
> for some of the APIs. In this case you can pass the token in the query parameter `auth`. > for some of the APIs. In this case you can pass the token in the query parameter `auth`.
@ -84,27 +87,20 @@ and a nicely-formatted diagnostic will be in `"display"` if available.
* `POST /text-query`, described above. * `POST /text-query`, described above.
* `GET /export/{relations: String}`, where `relations` is a comma-separated list of relations to export. * `GET /export/{relations: String}`, where `relations` is a comma-separated list of relations to export.
* `PUT /import`, import data into the database. Data should be in `application/json` MIME type in the body, * `PUT /import`, import data into the database. Data should be in `application/json` MIME type in the body,
in the same format as returned in the `data` field in the `/export` API. in the same format as returned in the `data` field in the `/export` API.
* `POST /backup`, backup database, should supply a JSON body of the form `{"path": <PATH>}` * `POST /backup`, backup database, should supply a JSON body of the form `{"path": <PATH>}`
* `POST /import-from-backup`, import data into the database from a backup. Should supply a JSON body * `POST /import-from-backup`, import data into the database from a backup. Should supply a JSON body
of the form `{"path": <PATH>, "relations": <ARRAY OF RELATION NAMES>}`. of the form `{"path": <PATH>, "relations": <ARRAY OF RELATION NAMES>}`.
* `GET /`, if you open this in your browser and open your developer tools, you will be able to use * `GET /`, if you open this in your browser and open your developer tools, you will be able to use
a very simple client to query this database. a very simple client to query this database.
> For `import` and `import-from-backup`, triggers are _not_ run for the relations, if any exists. > For `import` and `import-from-backup`, triggers are _not_ run for the relations, if any exists.
If you need to activate triggers, use queries with parameters. > If you need to activate triggers, use queries with parameters.
The following are experimental: The following are experimental:
* `GET(SSE) /changes/{relation: String}` get changes when mutations are made against a relation, relies on [SSE](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events). * `GET(SSE) /changes/{relation: String}` get changes when mutations are made against a relation, relies
* `GET(SSE) /rules/{name: String}` register a custom fixed rule and receive requests for computation. on [SSE](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events).
Query parameter `arity` must also be present.
* `POST /rule-result/{id}` post results of custom fixed rule computation back to the server, used together with the last API.
* `POST /transact` start a multi-statement transaction, the ID returned is used in the following two APIs.
Need to set the `write=true` query parameter if mutations are present.
* `POST /transact/{id}` do queries inside a multi-statement transaction, JSON payload expected is the same as for `/text-query`.
* `PUT /transact/{id}` commit or abort a multi-statement transaction. JSON payload is of the form `{"abort": <bool>}`, pass `false` for commit and `true` for abort. If you forget to do this, a resource leak results, even for read-only transactions.
## Building ## Building

@ -10,23 +10,23 @@ use std::collections::BTreeMap;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::{Ipv6Addr, SocketAddr}; use std::net::{Ipv6Addr, SocketAddr};
use std::str::FromStr; use std::str::FromStr;
use std::sync::atomic::{AtomicU32, Ordering}; // use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex}; use std::sync::Arc;
use std::thread; // use std::thread;
use axum::body::{boxed, Body, BoxBody}; use axum::body::{boxed, Body, BoxBody};
use axum::extract::{DefaultBodyLimit, Path, Query, State}; use axum::extract::{DefaultBodyLimit, Path, State};
use axum::http::{header, HeaderName, Method, Request, Response, StatusCode}; use axum::http::{header, HeaderName, Method, Request, Response, StatusCode};
use axum::response::sse::{Event, KeepAlive}; use axum::response::sse::{Event, KeepAlive};
use axum::response::{Html, Sse}; use axum::response::{Html, Sse};
use axum::routing::{get, post, put}; use axum::routing::{get, post, put};
use axum::{Json, Router}; use axum::{Extension, Json, Router};
use clap::Args; use clap::Args;
use futures::future::BoxFuture; use futures::future::BoxFuture;
use futures::stream::Stream; use futures::stream::Stream;
use itertools::Itertools; use itertools::Itertools;
use log::{error, info, warn}; use log::{error, info, warn};
use miette::miette; // use miette::miette;
use rand::Rng; use rand::Rng;
use serde_json::json; use serde_json::json;
use tokio::task::spawn_blocking; use tokio::task::spawn_blocking;
@ -34,10 +34,7 @@ use tower_http::auth::{AsyncAuthorizeRequest, AsyncRequireAuthorizationLayer};
use tower_http::compression::CompressionLayer; use tower_http::compression::CompressionLayer;
use tower_http::cors::{Any, CorsLayer}; use tower_http::cors::{Any, CorsLayer};
use cozo::{ use cozo::{DataValue, DbInstance, NamedRows, ScriptMutability};
format_error_as_json, DataValue, DbInstance, MultiTransaction, NamedRows, ScriptMutability,
SimpleFixedRule,
};
#[derive(Args, Debug)] #[derive(Args, Debug)]
pub(crate) struct ServerArgs { pub(crate) struct ServerArgs {
@ -57,9 +54,6 @@ pub(crate) struct ServerArgs {
#[clap(short, long, default_value_t = String::from("{}"))] #[clap(short, long, default_value_t = String::from("{}"))]
config: String, config: String,
// When on, start REPL instead of starting a webserver
// #[clap(short, long)]
// repl: bool,
/// Address to bind the service to /// Address to bind the service to
#[clap(short, long, default_value_t = String::from("127.0.0.1"))] #[clap(short, long, default_value_t = String::from("127.0.0.1"))]
bind: String, bind: String,
@ -67,21 +61,26 @@ pub(crate) struct ServerArgs {
/// Port to use /// Port to use
#[clap(short = 'P', long, default_value_t = 9070)] #[clap(short = 'P', long, default_value_t = 9070)]
port: u16, port: u16,
/// When set, the content of the named table will be used as a token table
#[clap(long)]
token_table: Option<String>,
} }
#[derive(Clone)] #[derive(Clone)]
struct DbState { struct DbState {
db: DbInstance, db: DbInstance,
rule_senders: Arc<Mutex<BTreeMap<u32, crossbeam::channel::Sender<miette::Result<NamedRows>>>>>, // rule_senders: Arc<Mutex<BTreeMap<u32, crossbeam::channel::Sender<miette::Result<NamedRows>>>>>,
rule_counter: Arc<AtomicU32>, // rule_counter: Arc<AtomicU32>,
tx_counter: Arc<AtomicU32>, // tx_counter: Arc<AtomicU32>,
txs: Arc<Mutex<BTreeMap<u32, Arc<MultiTransaction>>>>, // txs: Arc<Mutex<BTreeMap<u32, Arc<MultiTransaction>>>>,
} }
#[derive(Clone)] #[derive(Clone)]
struct MyAuth { struct MyAuth {
skip_auth: bool, skip_auth: bool,
auth_guard: String, auth_guard: String,
token_table: Option<Arc<(String, DbInstance)>>,
} }
impl<B> AsyncAuthorizeRequest<B> for MyAuth impl<B> AsyncAuthorizeRequest<B> for MyAuth
@ -92,17 +91,18 @@ where
type ResponseBody = BoxBody; type ResponseBody = BoxBody;
type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>; type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
fn authorize(&mut self, request: Request<B>) -> Self::Future { fn authorize<'a>(&'a mut self, mut request: Request<B>) -> Self::Future {
let skip_auth = self.skip_auth; let skip_auth = self.skip_auth;
let auth_guard = self.auth_guard.clone(); let auth_guard = self.auth_guard.clone();
let token_table = self.token_table.clone();
Box::pin(async move { Box::pin(async move {
if skip_auth { if skip_auth {
request.extensions_mut().insert(ScriptMutability::Mutable);
return Ok(request); return Ok(request);
} }
let ok = match request.headers().get("x-cozo-auth") { let mutability = match request.headers().get("x-cozo-auth") {
None => match request.uri().query() { None => match request.uri().query() {
None => false,
Some(q_str) => { Some(q_str) => {
let mut bingo = false; let mut bingo = false;
for pair in q_str.split('&') { for pair in q_str.split('&') {
@ -115,15 +115,65 @@ where
} }
} }
} }
bingo if bingo {
Some(ScriptMutability::Mutable)
} else {
None
}
} }
None => match token_table {
None => None,
Some(tt) => {
let (name, db) = tt.as_ref();
if let Some(auth_header) = request.headers().get("Authorization") {
if let Ok(auth_str) = auth_header.to_str() {
if auth_str.starts_with("Bearer ") {
let token = &auth_str["Bearer ".len()..];
match db.run_script(
&format!("?[mutable] := *{name} {{ token: $token, mutable }}"),
BTreeMap::from([(String::from("token"), DataValue::from(token))]),
ScriptMutability::Immutable,
) {
Ok(rows) => match rows.rows.first() {
None => None,
Some(val) => {
if val[0].get_bool() == Some(true) {
Some(ScriptMutability::Mutable)
} else {
Some(ScriptMutability::Immutable)
}
}
},
Err(err) => {
eprintln!("Error: {}", err);
None
},
}
} else {
None
}
} else {
None
}
} else {
None
}
}
},
}, },
Some(data) => match data.to_str() { Some(data) => match data.to_str() {
Ok(s) => s == auth_guard.as_str(), Ok(s) => {
Err(_) => false, if s == auth_guard.as_str() {
Some(ScriptMutability::Mutable)
} else {
None
}
}
Err(_) => None,
}, },
}; };
if ok { if let Some(mutability) = mutability {
request.extensions_mut().insert(mutability);
Ok(request) Ok(request)
} else { } else {
let unauthorized_response = Response::builder() let unauthorized_response = Response::builder()
@ -174,12 +224,18 @@ pub(crate) async fn server_main(args: ServerArgs) {
} }
}; };
let auth_obj = MyAuth {
skip_auth,
auth_guard,
token_table: args.token_table.map(|t| Arc::new((t, db.clone()))),
};
let state = DbState { let state = DbState {
db, db,
rule_senders: Default::default(), // rule_senders: Default::default(),
rule_counter: Default::default(), // rule_counter: Default::default(),
tx_counter: Default::default(), // tx_counter: Default::default(),
txs: Default::default(), // txs: Default::default(),
}; };
let cors = CorsLayer::new() let cors = CorsLayer::new()
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]) .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
@ -193,18 +249,15 @@ pub(crate) async fn server_main(args: ServerArgs) {
.route("/backup", post(backup)) .route("/backup", post(backup))
.route("/import-from-backup", post(import_from_backup)) .route("/import-from-backup", post(import_from_backup))
.route("/changes/:relation", get(observe_changes)) .route("/changes/:relation", get(observe_changes))
.route("/rules/:name", get(register_rule)) // .route("/rules/:name", get(register_rule))
.route( // .route(
"/rule-result/:id", // "/rule-result/:id",
post(post_rule_result).delete(post_rule_err), // post(post_rule_result).delete(post_rule_err),
) // +keep alive // ) // +keep alive
.route("/transact", post(start_transact)) // .route("/transact", post(start_transact))
.route("/transact/:id", post(transact_query).put(finish_query)) // .route("/transact/:id", post(transact_query).put(finish_query))
.with_state(state) .with_state(state)
.layer(AsyncRequireAuthorizationLayer::new(MyAuth { .layer(AsyncRequireAuthorizationLayer::new(auth_obj))
skip_auth,
auth_guard,
}))
.fallback(not_found) .fallback(not_found)
.route("/", get(root)) .route("/", get(root))
.layer(cors) .layer(cors)
@ -235,76 +288,76 @@ pub(crate) async fn server_main(args: ServerArgs) {
#[derive(serde_derive::Deserialize)] #[derive(serde_derive::Deserialize)]
struct StartTransactPayload { struct StartTransactPayload {
write: bool, // write: bool,
}
async fn start_transact(
State(st): State<DbState>,
Query(payload): Query<StartTransactPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
let tx = st.db.multi_transaction(payload.write);
let id = st.tx_counter.fetch_add(1, Ordering::SeqCst);
st.txs.lock().unwrap().insert(id, Arc::new(tx));
(StatusCode::OK, json!({"ok": true, "id": id}).into())
}
async fn transact_query(
State(st): State<DbState>,
Path(id): Path<u32>,
Json(payload): Json<QueryPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
let tx = match st.txs.lock().unwrap().get(&id) {
None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
Some(tx) => tx.clone(),
};
let src = payload.script.clone();
let result = spawn_blocking(move || {
let params = payload
.params
.into_iter()
.map(|(k, v)| (k, DataValue::from(v)))
.collect();
let query = payload.script;
tx.run_script(&query, params)
})
.await;
match result {
Ok(Ok(res)) => (StatusCode::OK, res.into_json().into()),
Ok(Err(err)) => (
StatusCode::BAD_REQUEST,
format_error_as_json(err, Some(&src)).into(),
),
Err(err) => internal_error(err),
}
} }
#[derive(serde_derive::Deserialize)] // async fn start_transact(
struct FinishTransactPayload { // State(st): State<DbState>,
abort: bool, // Query(payload): Query<StartTransactPayload>,
} // ) -> (StatusCode, Json<serde_json::Value>) {
// let tx = st.db.multi_transaction(payload.write);
async fn finish_query( // let id = st.tx_counter.fetch_add(1, Ordering::SeqCst);
State(st): State<DbState>, // st.txs.lock().unwrap().insert(id, Arc::new(tx));
Path(id): Path<u32>, // (StatusCode::OK, json!({"ok": true, "id": id}).into())
Json(payload): Json<FinishTransactPayload>, // }
) -> (StatusCode, Json<serde_json::Value>) {
let tx = match st.txs.lock().unwrap().remove(&id) { // async fn transact_query(
None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()), // State(st): State<DbState>,
Some(tx) => tx, // Path(id): Path<u32>,
}; // Json(payload): Json<QueryPayload>,
let res = if payload.abort { // ) -> (StatusCode, Json<serde_json::Value>) {
tx.abort() // let tx = match st.txs.lock().unwrap().get(&id) {
} else { // None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
tx.commit() // Some(tx) => tx.clone(),
}; // };
match res { // let src = payload.script.clone();
Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), // let result = spawn_blocking(move || {
Err(err) => ( // let params = payload
StatusCode::BAD_REQUEST, // .params
json!({"ok": false, "message": err.to_string()}).into(), // .into_iter()
), // .map(|(k, v)| (k, DataValue::from(v)))
} // .collect();
} // let query = payload.script;
// tx.run_script(&query, params)
// })
// .await;
// match result {
// Ok(Ok(res)) => (StatusCode::OK, res.into_json().into()),
// Ok(Err(err)) => (
// StatusCode::BAD_REQUEST,
// format_error_as_json(err, Some(&src)).into(),
// ),
// Err(err) => internal_error(err),
// }
// }
// #[derive(serde_derive::Deserialize)]
// struct FinishTransactPayload {
// abort: bool,
// }
// async fn finish_query(
// State(st): State<DbState>,
// Path(id): Path<u32>,
// Json(payload): Json<FinishTransactPayload>,
// ) -> (StatusCode, Json<serde_json::Value>) {
// let tx = match st.txs.lock().unwrap().remove(&id) {
// None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
// Some(tx) => tx,
// };
// let res = if payload.abort {
// tx.abort()
// } else {
// tx.commit()
// };
// match res {
// Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
// Err(err) => (
// StatusCode::BAD_REQUEST,
// json!({"ok": false, "message": err.to_string()}).into(),
// ),
// }
// }
#[derive(serde_derive::Deserialize)] #[derive(serde_derive::Deserialize)]
struct QueryPayload { struct QueryPayload {
@ -314,6 +367,7 @@ struct QueryPayload {
} }
async fn text_query( async fn text_query(
Extension(mutability): Extension<ScriptMutability>,
State(st): State<DbState>, State(st): State<DbState>,
Json(payload): Json<QueryPayload>, Json(payload): Json<QueryPayload>,
) -> (StatusCode, Json<serde_json::Value>) { ) -> (StatusCode, Json<serde_json::Value>) {
@ -322,7 +376,10 @@ async fn text_query(
.into_iter() .into_iter()
.map(|(k, v)| (k, DataValue::from(v))) .map(|(k, v)| (k, DataValue::from(v)))
.collect(); .collect();
let immutable = payload.immutable.unwrap_or(false); let immutable = match mutability {
ScriptMutability::Mutable => payload.immutable.unwrap_or(false),
ScriptMutability::Immutable => true,
};
let result = spawn_blocking(move || { let result = spawn_blocking(move || {
st.db.run_script_fold_err( st.db.run_script_fold_err(
&payload.script, &payload.script,
@ -457,118 +514,118 @@ async fn import_from_backup(
} }
} }
#[derive(serde_derive::Deserialize)] // #[derive(serde_derive::Deserialize)]
struct RuleRegisterOptions { // struct RuleRegisterOptions {
arity: usize, // arity: usize,
} // }
async fn post_rule_result( // async fn post_rule_result(
State(st): State<DbState>, // State(st): State<DbState>,
Path(id): Path<u32>, // Path(id): Path<u32>,
Json(res): Json<serde_json::Value>, // Json(res): Json<serde_json::Value>,
) -> (StatusCode, Json<serde_json::Value>) { // ) -> (StatusCode, Json<serde_json::Value>) {
let res = match NamedRows::from_json(&res) { // let res = match NamedRows::from_json(&res) {
Ok(res) => res, // Ok(res) => res,
Err(err) => { // Err(err) => {
if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { // if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
let _ = ch.send(Err(miette!("downstream posted malformed result"))); // let _ = ch.send(Err(miette!("downstream posted malformed result")));
} // }
return ( // return (
StatusCode::BAD_REQUEST, // StatusCode::BAD_REQUEST,
json!({"ok": false, "message": err.to_string()}).into(), // json!({"ok": false, "message": err.to_string()}).into(),
); // );
} // }
}; // };
if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { // if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
match ch.send(Ok(res)) { // match ch.send(Ok(res)) {
Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), // Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
Err(err) => ( // Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR, // StatusCode::INTERNAL_SERVER_ERROR,
json!({"ok": false, "message": err.to_string()}).into(), // json!({"ok": false, "message": err.to_string()}).into(),
), // ),
} // }
} else { // } else {
(StatusCode::NOT_FOUND, json!({"ok": false}).into()) // (StatusCode::NOT_FOUND, json!({"ok": false}).into())
} // }
} // }
async fn post_rule_err( // async fn post_rule_err(
State(st): State<DbState>, // State(st): State<DbState>,
Path(id): Path<u32>, // Path(id): Path<u32>,
) -> (StatusCode, Json<serde_json::Value>) { // ) -> (StatusCode, Json<serde_json::Value>) {
if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) { // if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
match ch.send(Err(miette!("downstream cancelled computation"))) { // match ch.send(Err(miette!("downstream cancelled computation"))) {
Ok(_) => (StatusCode::OK, json!({"ok": true}).into()), // Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
Err(err) => ( // Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR, // StatusCode::INTERNAL_SERVER_ERROR,
json!({"ok": false, "message": err.to_string()}).into(), // json!({"ok": false, "message": err.to_string()}).into(),
), // ),
} // }
} else { // } else {
(StatusCode::NOT_FOUND, json!({"ok": false}).into()) // (StatusCode::NOT_FOUND, json!({"ok": false}).into())
} // }
} // }
async fn register_rule( // async fn register_rule(
State(st): State<DbState>, // State(st): State<DbState>,
Path(name): Path<String>, // Path(name): Path<String>,
Query(rule_opts): Query<RuleRegisterOptions>, // Query(rule_opts): Query<RuleRegisterOptions>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> { // ) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let (rule, task_receiver) = SimpleFixedRule::rule_with_channel(rule_opts.arity); // let (rule, task_receiver) = SimpleFixedRule::rule_with_channel(rule_opts.arity);
let (down_sender, mut down_receiver) = tokio::sync::mpsc::channel(1); // let (down_sender, mut down_receiver) = tokio::sync::mpsc::channel(1);
let mut errored = None; // let mut errored = None;
//
if let Err(err) = st.db.register_fixed_rule(name.clone(), rule) { // if let Err(err) = st.db.register_fixed_rule(name.clone(), rule) {
errored = Some(err); // errored = Some(err);
} else { // } else {
let rule_senders = st.rule_senders.clone(); // let rule_senders = st.rule_senders.clone();
let rule_counter = st.rule_counter.clone(); // let rule_counter = st.rule_counter.clone();
//
thread::spawn(move || { // thread::spawn(move || {
for (inputs, options, sender) in task_receiver { // for (inputs, options, sender) in task_receiver {
let id = rule_counter.fetch_add(1, Ordering::AcqRel); // let id = rule_counter.fetch_add(1, Ordering::AcqRel);
let inputs: serde_json::Value = // let inputs: serde_json::Value =
inputs.into_iter().map(|ip| ip.into_json()).collect(); // inputs.into_iter().map(|ip| ip.into_json()).collect();
let options: serde_json::Value = options // let options: serde_json::Value = options
.into_iter() // .into_iter()
.map(|(k, v)| (k, serde_json::Value::from(v))) // .map(|(k, v)| (k, serde_json::Value::from(v)))
.collect(); // .collect();
if down_sender.blocking_send((id, inputs, options)).is_err() { // if down_sender.blocking_send((id, inputs, options)).is_err() {
let _ = sender.send(Err(miette!("cannot send request to downstream"))); // let _ = sender.send(Err(miette!("cannot send request to downstream")));
} else { // } else {
rule_senders.lock().unwrap().insert(id, sender); // rule_senders.lock().unwrap().insert(id, sender);
} // }
} // }
}); // });
} // }
//
struct Guard { // struct Guard {
name: String, // name: String,
db: DbInstance, // db: DbInstance,
} // }
//
impl Drop for Guard { // impl Drop for Guard {
fn drop(&mut self) { // fn drop(&mut self) {
info!("dropping rules SSE {}", self.name); // info!("dropping rules SSE {}", self.name);
let _ = self.db.unregister_fixed_rule(&self.name); // let _ = self.db.unregister_fixed_rule(&self.name);
} // }
} // }
//
let stream = async_stream::stream! { // let stream = async_stream::stream! {
if let Some(err) = errored { // if let Some(err) = errored {
let item = json!({"type": "register-error", "error": err.to_string()}); // let item = json!({"type": "register-error", "error": err.to_string()});
yield Ok(Event::default().json_data(item).unwrap()); // yield Ok(Event::default().json_data(item).unwrap());
} else { // } else {
info!("starting rule SSE {}", name); // info!("starting rule SSE {}", name);
let _guard = Guard {db: st.db, name}; // let _guard = Guard {db: st.db, name};
while let Some((id, inputs, options)) = down_receiver.recv().await { // while let Some((id, inputs, options)) = down_receiver.recv().await {
let item = json!({"type": "request", "id": id, "inputs": inputs, "options": options}); // let item = json!({"type": "request", "id": id, "inputs": inputs, "options": options});
yield Ok(Event::default().json_data(item).unwrap()); // yield Ok(Event::default().json_data(item).unwrap());
} // }
} // }
}; // };
Sse::new(stream).keep_alive(KeepAlive::default()) // Sse::new(stream).keep_alive(KeepAlive::default())
} // }
async fn observe_changes( async fn observe_changes(
State(st): State<DbState>, State(st): State<DbState>,

Loading…
Cancel
Save