token table; disable multi-table transactions and custom rules for remote

main
Ziyang Hu 1 year ago
parent 6b8438e823
commit cd840e687c

@ -2,11 +2,14 @@
[![server](https://img.shields.io/github/v/release/cozodb/cozo)](https://github.com/cozodb/cozo/releases)
本文叙述的是如何安装设置 Cozo 的独立程序本身。有关如何使用 CozoDBCozoScript的信息见 [文档](https://docs.cozodb.org/zh_CN/latest/index.html) 。
本文叙述的是如何安装设置 Cozo 的独立程序本身。有关如何使用
CozoDBCozoScript的信息见 [文档](https://docs.cozodb.org/zh_CN/latest/index.html) 。
## 下载
独立服务的程序可以从 [GitHub 发布页](https://github.com/cozodb/cozo/releases) 或 [Gitee 发布页](https://gitee.com/cozodb/cozo/releases) 下载,其中名为 `cozo-*` 的是独立服务程序,名为 `cozo_all-*` 的独立程序同时支持更多地存储引擎,比如 [TiKV](https://tikv.org/)。
独立服务的程序可以从 [GitHub 发布页](https://github.com/cozodb/cozo/releases)
或 [Gitee 发布页](https://gitee.com/cozodb/cozo/releases) 下载,其中名为 `cozo-*` 的是独立服务程序,名为 `cozo_all-*`
的独立程序同时支持更多地存储引擎,比如 [TiKV](https://tikv.org/)。
## 启动服务程序
@ -39,19 +42,26 @@
## 查询 API
查询通过向 API 发送 POST 请求来完成。默认的请求地址是 `http://127.0.0.1:9070/text-query` 。请求必须包含 JSON 格式的正文,具体内容如下:
```json
{
"script": "<COZOSCRIPT QUERY STRING>",
"params": {}
"script": "<COZOSCRIPT QUERY STRING>",
"params": {}
}
```
`params` 给出了查询文本中可用的变量。例如,当 `params``{"num": 1}` 时,查询文本中可以以 `$num` 来代替常量 `1`。请善用此功能,而不是手动拼接查询字符串。
HTTP API 返回的结果永远是 JSON 格式的。如果请求成功,则返回结果的 `"ok"` 字段将为 `true`,且 `"rows"` 字段将含有查询结果的行,而 `"headers"` 将含有表头。如果查询报错,则 `"ok"` 字段将为 `false`,而错误信息会在 `"message"` 字段中,同时 `"display"` 字段会包含格式化好的友好的错误提示。
`params` 给出了查询文本中可用的变量。例如,当 `params``{"num": 1}` 时,查询文本中可以以 `$num` 来代替常量 `1`
。请善用此功能,而不是手动拼接查询字符串。
HTTP API 返回的结果永远是 JSON 格式的。如果请求成功,则返回结果的 `"ok"` 字段将为 `true`,且 `"rows"`
字段将含有查询结果的行,而 `"headers"` 将含有表头。如果查询报错,则 `"ok"` 字段将为 `false`,而错误信息会在 `"message"`
字段中,同时 `"display"` 字段会包含格式化好的友好的错误提示。
> Cozo 的设计,基于其在一个受信任的环境中运行,且其所有用户也是由受信任的这种假设。因此 Cozo 没有内置认证与复杂的安全机制。如果你需要远程访问 Cozo 服务,你必须自己设置防火墙、加密和代理等,用来保护服务器上资源的安全。
> Cozo 的设计,基于其在一个受信任的环境中运行,且其所有用户也是由受信任的这种假设。因此 Cozo 没有内置认证与复杂的安全机制。如果你需要远程访问
> Cozo 服务,你必须自己设置防火墙、加密和代理等,用来保护服务器上资源的安全。
>
> 由于总是会有用户不小心将服务接口暴露于外网Cozo 有一个补救措施:如果允许了非回传地址访问 Cozo则必须在所有请求中以 HTTP 文件头 `x-cozo-auth` 的形式附上访问令牌。访问令牌的内容在启动服务的终端中有提示。注意这仅仅是一个补救措施,并不是特别可靠的安全机制,是为了尽量防止一些不由于小心而造成严重后果的悲剧。
> 由于总是会有用户不小心将服务接口暴露于外网Cozo 有一个补救措施:如果允许了非回传地址访问 Cozo则必须在所有请求中以 HTTP
> 文件头 `x-cozo-auth` 的形式附上访问令牌。访问令牌的内容在启动服务的终端中有提示。注意这仅仅是一个补救措施,并不是特别可靠的安全机制,是为了尽量防止一些不由于小心而造成严重后果的悲剧。
>
> 在有些客户端里,部分 API 加入 HTTP 文件头可能会比较困难或者根本不可能。这时也可以在查询参数 `auth` 中传入令牌。
@ -59,9 +69,11 @@ HTTP API 返回的结果永远是 JSON 格式的。如果请求成功,则返
* `POST /text-query`,见上。
* `GET /export/{relations: String}`,导出指定表中的数据,其中 `relations` 是以逗号分割的表名。
* `PUT /import`,向数据库导入数据。所导入的数据应以在正文中以 `application/json` MIME 类型传入,具体格式与 `/export` 返回值中的 `data` 字段相同。
* `PUT /import`,向数据库导入数据。所导入的数据应以在正文中以 `application/json` MIME 类型传入,具体格式与 `/export`
返回值中的 `data` 字段相同。
* `POST /backup`,备份数据库,需要传入 JSON 正文 `{"path": <路径>}`
* `POST /import-from-backup`,将备份中指定存储表中的数据插入当前数据库中同名存储表。需要传入 JSON 正文 `{"path": <路径>, "relations": <表名数组>}`.
* `POST /import-from-backup`,将备份中指定存储表中的数据插入当前数据库中同名存储表。需要传入 JSON
正文 `{"path": <路径>, "relations": <表名数组>}`.
* `GET /`,用浏览器打开这个地址,然后打开浏览器的调试工具,就可以使用一个简陋的 JS 客户端。
> 注意 `import``import-from-backup` 接口不会激活任何触发器。
@ -69,12 +81,8 @@ HTTP API 返回的结果永远是 JSON 格式的。如果请求成功,则返
以下为试验性的 API
* `GET(SSE) /changes/{relation: String}` 获取某个存储表的更新,基于 [SSE](https://developer.mozilla.org/zh-CN/docs/Web/API/Server-sent_events/Using_server-sent_events).
* `GET(SSE) /rules/{name: String}` 注册一个自定义的固定规则。查询参数 `arity` 是必须的。
* `POST /rule-result/{id}` 将固定规则的计算结果回传给服务器,配合上一个 API 使用。
* `POST /transact` 开始一个多语句的事务。返回的 ID 在下面几个 API 中使用。如果要进行写操作,则需要传入 `write=true` 查询参数。
* `POST /transact/{id}` 在多语句事务中进行查询。要求的正文与 `/text-query` 所要求的相同。
* `PUT /transact/{id}` 提交或放弃多语句事务。要求的正文是 JSON `{"abort": <bool>}`,传入真值则放弃,否则提交。如果不执行此查询则服务器会浪费系统资源。
* `GET(SSE) /changes/{relation: String}`
获取某个存储表的更新,基于 [SSE](https://developer.mozilla.org/zh-CN/docs/Web/API/Server-sent_events/Using_server-sent_events).
## 编译

@ -39,7 +39,8 @@ You can use the following meta ops in the REPL:
* `%params`: print all set parameters.
* `%run <FILE>`: run the script contained in `<FILE>`.
* `%import <FILE OR URL>`: import data in JSON format from the file or URL.
* `%save <FILE>`: the result of the next successful query will be saved in JSON format in a file instead of printed on screen. If `<FILE>` is omitted, then the effect of any previous `%save` command is nullified.
* `%save <FILE>`: the result of the next successful query will be saved in JSON format in a file instead of printed on
screen. If `<FILE>` is omitted, then the effect of any previous `%save` command is nullified.
* `%backup <FILE>`: the current database will be backed up into the file.
* `%restore <FILE>`: restore the data in the backup to the current database. The current database must be empty.
@ -48,12 +49,14 @@ You can use the following meta ops in the REPL:
Queries are run by sending HTTP POST requests to the server.
By default, the API endpoint is `http://127.0.0.1:9070/text-query`.
A JSON body of the following form is expected:
```json
{
"script": "<COZOSCRIPT QUERY STRING>",
"params": {}
"script": "<COZOSCRIPT QUERY STRING>",
"params": {}
}
```
params should be an object of named parameters. For example, if params is `{"num": 1}`,
then `$num` can be used anywhere in your query string where an expression is expected.
Always use params instead of concatenating strings when you need parametrized queries.
@ -84,27 +87,20 @@ and a nicely-formatted diagnostic will be in `"display"` if available.
* `POST /text-query`, described above.
* `GET /export/{relations: String}`, where `relations` is a comma-separated list of relations to export.
* `PUT /import`, import data into the database. Data should be in `application/json` MIME type in the body,
in the same format as returned in the `data` field in the `/export` API.
in the same format as returned in the `data` field in the `/export` API.
* `POST /backup`, backup database, should supply a JSON body of the form `{"path": <PATH>}`
* `POST /import-from-backup`, import data into the database from a backup. Should supply a JSON body
of the form `{"path": <PATH>, "relations": <ARRAY OF RELATION NAMES>}`.
of the form `{"path": <PATH>, "relations": <ARRAY OF RELATION NAMES>}`.
* `GET /`, if you open this in your browser and open your developer tools, you will be able to use
a very simple client to query this database.
a very simple client to query this database.
> For `import` and `import-from-backup`, triggers are _not_ run for the relations, if any exists.
If you need to activate triggers, use queries with parameters.
> If you need to activate triggers, use queries with parameters.
The following are experimental:
* `GET(SSE) /changes/{relation: String}` get changes when mutations are made against a relation, relies on [SSE](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events).
* `GET(SSE) /rules/{name: String}` register a custom fixed rule and receive requests for computation.
Query parameter `arity` must also be present.
* `POST /rule-result/{id}` post results of custom fixed rule computation back to the server, used together with the last API.
* `POST /transact` start a multi-statement transaction, the ID returned is used in the following two APIs.
Need to set the `write=true` query parameter if mutations are present.
* `POST /transact/{id}` do queries inside a multi-statement transaction, JSON payload expected is the same as for `/text-query`.
* `PUT /transact/{id}` commit or abort a multi-statement transaction. JSON payload is of the form `{"abort": <bool>}`, pass `false` for commit and `true` for abort. If you forget to do this, a resource leak results, even for read-only transactions.
* `GET(SSE) /changes/{relation: String}` get changes when mutations are made against a relation, relies
on [SSE](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events).
## Building

@ -10,23 +10,23 @@ use std::collections::BTreeMap;
use std::convert::Infallible;
use std::net::{Ipv6Addr, SocketAddr};
use std::str::FromStr;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
// use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
// use std::thread;
use axum::body::{boxed, Body, BoxBody};
use axum::extract::{DefaultBodyLimit, Path, Query, State};
use axum::extract::{DefaultBodyLimit, Path, State};
use axum::http::{header, HeaderName, Method, Request, Response, StatusCode};
use axum::response::sse::{Event, KeepAlive};
use axum::response::{Html, Sse};
use axum::routing::{get, post, put};
use axum::{Json, Router};
use axum::{Extension, Json, Router};
use clap::Args;
use futures::future::BoxFuture;
use futures::stream::Stream;
use itertools::Itertools;
use log::{error, info, warn};
use miette::miette;
// use miette::miette;
use rand::Rng;
use serde_json::json;
use tokio::task::spawn_blocking;
@ -34,10 +34,7 @@ use tower_http::auth::{AsyncAuthorizeRequest, AsyncRequireAuthorizationLayer};
use tower_http::compression::CompressionLayer;
use tower_http::cors::{Any, CorsLayer};
use cozo::{
format_error_as_json, DataValue, DbInstance, MultiTransaction, NamedRows, ScriptMutability,
SimpleFixedRule,
};
use cozo::{DataValue, DbInstance, NamedRows, ScriptMutability};
#[derive(Args, Debug)]
pub(crate) struct ServerArgs {
@ -57,9 +54,6 @@ pub(crate) struct ServerArgs {
#[clap(short, long, default_value_t = String::from("{}"))]
config: String,
// When on, start REPL instead of starting a webserver
// #[clap(short, long)]
// repl: bool,
/// Address to bind the service to
#[clap(short, long, default_value_t = String::from("127.0.0.1"))]
bind: String,
@ -67,21 +61,26 @@ pub(crate) struct ServerArgs {
/// Port to use
#[clap(short = 'P', long, default_value_t = 9070)]
port: u16,
/// When set, the content of the named table will be used as a token table
#[clap(long)]
token_table: Option<String>,
}
#[derive(Clone)]
struct DbState {
db: DbInstance,
rule_senders: Arc<Mutex<BTreeMap<u32, crossbeam::channel::Sender<miette::Result<NamedRows>>>>>,
rule_counter: Arc<AtomicU32>,
tx_counter: Arc<AtomicU32>,
txs: Arc<Mutex<BTreeMap<u32, Arc<MultiTransaction>>>>,
// rule_senders: Arc<Mutex<BTreeMap<u32, crossbeam::channel::Sender<miette::Result<NamedRows>>>>>,
// rule_counter: Arc<AtomicU32>,
// tx_counter: Arc<AtomicU32>,
// txs: Arc<Mutex<BTreeMap<u32, Arc<MultiTransaction>>>>,
}
#[derive(Clone)]
struct MyAuth {
skip_auth: bool,
auth_guard: String,
token_table: Option<Arc<(String, DbInstance)>>,
}
impl<B> AsyncAuthorizeRequest<B> for MyAuth
@ -92,17 +91,18 @@ where
type ResponseBody = BoxBody;
type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
fn authorize(&mut self, request: Request<B>) -> Self::Future {
fn authorize<'a>(&'a mut self, mut request: Request<B>) -> Self::Future {
let skip_auth = self.skip_auth;
let auth_guard = self.auth_guard.clone();
let token_table = self.token_table.clone();
Box::pin(async move {
if skip_auth {
request.extensions_mut().insert(ScriptMutability::Mutable);
return Ok(request);
}
let ok = match request.headers().get("x-cozo-auth") {
let mutability = match request.headers().get("x-cozo-auth") {
None => match request.uri().query() {
None => false,
Some(q_str) => {
let mut bingo = false;
for pair in q_str.split('&') {
@ -115,15 +115,65 @@ where
}
}
}
bingo
if bingo {
Some(ScriptMutability::Mutable)
} else {
None
}
}
None => match token_table {
None => None,
Some(tt) => {
let (name, db) = tt.as_ref();
if let Some(auth_header) = request.headers().get("Authorization") {
if let Ok(auth_str) = auth_header.to_str() {
if auth_str.starts_with("Bearer ") {
let token = &auth_str["Bearer ".len()..];
match db.run_script(
&format!("?[mutable] := *{name} {{ token: $token, mutable }}"),
BTreeMap::from([(String::from("token"), DataValue::from(token))]),
ScriptMutability::Immutable,
) {
Ok(rows) => match rows.rows.first() {
None => None,
Some(val) => {
if val[0].get_bool() == Some(true) {
Some(ScriptMutability::Mutable)
} else {
Some(ScriptMutability::Immutable)
}
}
},
Err(err) => {
eprintln!("Error: {}", err);
None
},
}
} else {
None
}
} else {
None
}
} else {
None
}
}
},
},
Some(data) => match data.to_str() {
Ok(s) => s == auth_guard.as_str(),
Err(_) => false,
Ok(s) => {
if s == auth_guard.as_str() {
Some(ScriptMutability::Mutable)
} else {
None
}
}
Err(_) => None,
},
};
if ok {
if let Some(mutability) = mutability {
request.extensions_mut().insert(mutability);
Ok(request)
} else {
let unauthorized_response = Response::builder()
@ -174,12 +224,18 @@ pub(crate) async fn server_main(args: ServerArgs) {
}
};
let auth_obj = MyAuth {
skip_auth,
auth_guard,
token_table: args.token_table.map(|t| Arc::new((t, db.clone()))),
};
let state = DbState {
db,
rule_senders: Default::default(),
rule_counter: Default::default(),
tx_counter: Default::default(),
txs: Default::default(),
// rule_senders: Default::default(),
// rule_counter: Default::default(),
// tx_counter: Default::default(),
// txs: Default::default(),
};
let cors = CorsLayer::new()
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
@ -193,18 +249,15 @@ pub(crate) async fn server_main(args: ServerArgs) {
.route("/backup", post(backup))
.route("/import-from-backup", post(import_from_backup))
.route("/changes/:relation", get(observe_changes))
.route("/rules/:name", get(register_rule))
.route(
"/rule-result/:id",
post(post_rule_result).delete(post_rule_err),
) // +keep alive
.route("/transact", post(start_transact))
.route("/transact/:id", post(transact_query).put(finish_query))
// .route("/rules/:name", get(register_rule))
// .route(
// "/rule-result/:id",
// post(post_rule_result).delete(post_rule_err),
// ) // +keep alive
// .route("/transact", post(start_transact))
// .route("/transact/:id", post(transact_query).put(finish_query))
.with_state(state)
.layer(AsyncRequireAuthorizationLayer::new(MyAuth {
skip_auth,
auth_guard,
}))
.layer(AsyncRequireAuthorizationLayer::new(auth_obj))
.fallback(not_found)
.route("/", get(root))
.layer(cors)
@ -235,76 +288,76 @@ pub(crate) async fn server_main(args: ServerArgs) {
#[derive(serde_derive::Deserialize)]
struct StartTransactPayload {
write: bool,
}
async fn start_transact(
State(st): State<DbState>,
Query(payload): Query<StartTransactPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
let tx = st.db.multi_transaction(payload.write);
let id = st.tx_counter.fetch_add(1, Ordering::SeqCst);
st.txs.lock().unwrap().insert(id, Arc::new(tx));
(StatusCode::OK, json!({"ok": true, "id": id}).into())
}
async fn transact_query(
State(st): State<DbState>,
Path(id): Path<u32>,
Json(payload): Json<QueryPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
let tx = match st.txs.lock().unwrap().get(&id) {
None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
Some(tx) => tx.clone(),
};
let src = payload.script.clone();
let result = spawn_blocking(move || {
let params = payload
.params
.into_iter()
.map(|(k, v)| (k, DataValue::from(v)))
.collect();
let query = payload.script;
tx.run_script(&query, params)
})
.await;
match result {
Ok(Ok(res)) => (StatusCode::OK, res.into_json().into()),
Ok(Err(err)) => (
StatusCode::BAD_REQUEST,
format_error_as_json(err, Some(&src)).into(),
),
Err(err) => internal_error(err),
}
// write: bool,
}
#[derive(serde_derive::Deserialize)]
struct FinishTransactPayload {
abort: bool,
}
async fn finish_query(
State(st): State<DbState>,
Path(id): Path<u32>,
Json(payload): Json<FinishTransactPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
let tx = match st.txs.lock().unwrap().remove(&id) {
None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
Some(tx) => tx,
};
let res = if payload.abort {
tx.abort()
} else {
tx.commit()
};
match res {
Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
Err(err) => (
StatusCode::BAD_REQUEST,
json!({"ok": false, "message": err.to_string()}).into(),
),
}
}
// async fn start_transact(
// State(st): State<DbState>,
// Query(payload): Query<StartTransactPayload>,
// ) -> (StatusCode, Json<serde_json::Value>) {
// let tx = st.db.multi_transaction(payload.write);
// let id = st.tx_counter.fetch_add(1, Ordering::SeqCst);
// st.txs.lock().unwrap().insert(id, Arc::new(tx));
// (StatusCode::OK, json!({"ok": true, "id": id}).into())
// }
// async fn transact_query(
// State(st): State<DbState>,
// Path(id): Path<u32>,
// Json(payload): Json<QueryPayload>,
// ) -> (StatusCode, Json<serde_json::Value>) {
// let tx = match st.txs.lock().unwrap().get(&id) {
// None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
// Some(tx) => tx.clone(),
// };
// let src = payload.script.clone();
// let result = spawn_blocking(move || {
// let params = payload
// .params
// .into_iter()
// .map(|(k, v)| (k, DataValue::from(v)))
// .collect();
// let query = payload.script;
// tx.run_script(&query, params)
// })
// .await;
// match result {
// Ok(Ok(res)) => (StatusCode::OK, res.into_json().into()),
// Ok(Err(err)) => (
// StatusCode::BAD_REQUEST,
// format_error_as_json(err, Some(&src)).into(),
// ),
// Err(err) => internal_error(err),
// }
// }
// #[derive(serde_derive::Deserialize)]
// struct FinishTransactPayload {
// abort: bool,
// }
// async fn finish_query(
// State(st): State<DbState>,
// Path(id): Path<u32>,
// Json(payload): Json<FinishTransactPayload>,
// ) -> (StatusCode, Json<serde_json::Value>) {
// let tx = match st.txs.lock().unwrap().remove(&id) {
// None => return (StatusCode::NOT_FOUND, json!({"ok": false}).into()),
// Some(tx) => tx,
// };
// let res = if payload.abort {
// tx.abort()
// } else {
// tx.commit()
// };
// match res {
// Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
// Err(err) => (
// StatusCode::BAD_REQUEST,
// json!({"ok": false, "message": err.to_string()}).into(),
// ),
// }
// }
#[derive(serde_derive::Deserialize)]
struct QueryPayload {
@ -314,6 +367,7 @@ struct QueryPayload {
}
async fn text_query(
Extension(mutability): Extension<ScriptMutability>,
State(st): State<DbState>,
Json(payload): Json<QueryPayload>,
) -> (StatusCode, Json<serde_json::Value>) {
@ -322,7 +376,10 @@ async fn text_query(
.into_iter()
.map(|(k, v)| (k, DataValue::from(v)))
.collect();
let immutable = payload.immutable.unwrap_or(false);
let immutable = match mutability {
ScriptMutability::Mutable => payload.immutable.unwrap_or(false),
ScriptMutability::Immutable => true,
};
let result = spawn_blocking(move || {
st.db.run_script_fold_err(
&payload.script,
@ -457,118 +514,118 @@ async fn import_from_backup(
}
}
#[derive(serde_derive::Deserialize)]
struct RuleRegisterOptions {
arity: usize,
}
async fn post_rule_result(
State(st): State<DbState>,
Path(id): Path<u32>,
Json(res): Json<serde_json::Value>,
) -> (StatusCode, Json<serde_json::Value>) {
let res = match NamedRows::from_json(&res) {
Ok(res) => res,
Err(err) => {
if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
let _ = ch.send(Err(miette!("downstream posted malformed result")));
}
return (
StatusCode::BAD_REQUEST,
json!({"ok": false, "message": err.to_string()}).into(),
);
}
};
if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
match ch.send(Ok(res)) {
Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
json!({"ok": false, "message": err.to_string()}).into(),
),
}
} else {
(StatusCode::NOT_FOUND, json!({"ok": false}).into())
}
}
async fn post_rule_err(
State(st): State<DbState>,
Path(id): Path<u32>,
) -> (StatusCode, Json<serde_json::Value>) {
if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
match ch.send(Err(miette!("downstream cancelled computation"))) {
Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
json!({"ok": false, "message": err.to_string()}).into(),
),
}
} else {
(StatusCode::NOT_FOUND, json!({"ok": false}).into())
}
}
async fn register_rule(
State(st): State<DbState>,
Path(name): Path<String>,
Query(rule_opts): Query<RuleRegisterOptions>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let (rule, task_receiver) = SimpleFixedRule::rule_with_channel(rule_opts.arity);
let (down_sender, mut down_receiver) = tokio::sync::mpsc::channel(1);
let mut errored = None;
if let Err(err) = st.db.register_fixed_rule(name.clone(), rule) {
errored = Some(err);
} else {
let rule_senders = st.rule_senders.clone();
let rule_counter = st.rule_counter.clone();
thread::spawn(move || {
for (inputs, options, sender) in task_receiver {
let id = rule_counter.fetch_add(1, Ordering::AcqRel);
let inputs: serde_json::Value =
inputs.into_iter().map(|ip| ip.into_json()).collect();
let options: serde_json::Value = options
.into_iter()
.map(|(k, v)| (k, serde_json::Value::from(v)))
.collect();
if down_sender.blocking_send((id, inputs, options)).is_err() {
let _ = sender.send(Err(miette!("cannot send request to downstream")));
} else {
rule_senders.lock().unwrap().insert(id, sender);
}
}
});
}
struct Guard {
name: String,
db: DbInstance,
}
impl Drop for Guard {
fn drop(&mut self) {
info!("dropping rules SSE {}", self.name);
let _ = self.db.unregister_fixed_rule(&self.name);
}
}
let stream = async_stream::stream! {
if let Some(err) = errored {
let item = json!({"type": "register-error", "error": err.to_string()});
yield Ok(Event::default().json_data(item).unwrap());
} else {
info!("starting rule SSE {}", name);
let _guard = Guard {db: st.db, name};
while let Some((id, inputs, options)) = down_receiver.recv().await {
let item = json!({"type": "request", "id": id, "inputs": inputs, "options": options});
yield Ok(Event::default().json_data(item).unwrap());
}
}
};
Sse::new(stream).keep_alive(KeepAlive::default())
}
// #[derive(serde_derive::Deserialize)]
// struct RuleRegisterOptions {
// arity: usize,
// }
// async fn post_rule_result(
// State(st): State<DbState>,
// Path(id): Path<u32>,
// Json(res): Json<serde_json::Value>,
// ) -> (StatusCode, Json<serde_json::Value>) {
// let res = match NamedRows::from_json(&res) {
// Ok(res) => res,
// Err(err) => {
// if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
// let _ = ch.send(Err(miette!("downstream posted malformed result")));
// }
// return (
// StatusCode::BAD_REQUEST,
// json!({"ok": false, "message": err.to_string()}).into(),
// );
// }
// };
// if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
// match ch.send(Ok(res)) {
// Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
// Err(err) => (
// StatusCode::INTERNAL_SERVER_ERROR,
// json!({"ok": false, "message": err.to_string()}).into(),
// ),
// }
// } else {
// (StatusCode::NOT_FOUND, json!({"ok": false}).into())
// }
// }
// async fn post_rule_err(
// State(st): State<DbState>,
// Path(id): Path<u32>,
// ) -> (StatusCode, Json<serde_json::Value>) {
// if let Some(ch) = st.rule_senders.lock().unwrap().remove(&id) {
// match ch.send(Err(miette!("downstream cancelled computation"))) {
// Ok(_) => (StatusCode::OK, json!({"ok": true}).into()),
// Err(err) => (
// StatusCode::INTERNAL_SERVER_ERROR,
// json!({"ok": false, "message": err.to_string()}).into(),
// ),
// }
// } else {
// (StatusCode::NOT_FOUND, json!({"ok": false}).into())
// }
// }
// async fn register_rule(
// State(st): State<DbState>,
// Path(name): Path<String>,
// Query(rule_opts): Query<RuleRegisterOptions>,
// ) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
// let (rule, task_receiver) = SimpleFixedRule::rule_with_channel(rule_opts.arity);
// let (down_sender, mut down_receiver) = tokio::sync::mpsc::channel(1);
// let mut errored = None;
//
// if let Err(err) = st.db.register_fixed_rule(name.clone(), rule) {
// errored = Some(err);
// } else {
// let rule_senders = st.rule_senders.clone();
// let rule_counter = st.rule_counter.clone();
//
// thread::spawn(move || {
// for (inputs, options, sender) in task_receiver {
// let id = rule_counter.fetch_add(1, Ordering::AcqRel);
// let inputs: serde_json::Value =
// inputs.into_iter().map(|ip| ip.into_json()).collect();
// let options: serde_json::Value = options
// .into_iter()
// .map(|(k, v)| (k, serde_json::Value::from(v)))
// .collect();
// if down_sender.blocking_send((id, inputs, options)).is_err() {
// let _ = sender.send(Err(miette!("cannot send request to downstream")));
// } else {
// rule_senders.lock().unwrap().insert(id, sender);
// }
// }
// });
// }
//
// struct Guard {
// name: String,
// db: DbInstance,
// }
//
// impl Drop for Guard {
// fn drop(&mut self) {
// info!("dropping rules SSE {}", self.name);
// let _ = self.db.unregister_fixed_rule(&self.name);
// }
// }
//
// let stream = async_stream::stream! {
// if let Some(err) = errored {
// let item = json!({"type": "register-error", "error": err.to_string()});
// yield Ok(Event::default().json_data(item).unwrap());
// } else {
// info!("starting rule SSE {}", name);
// let _guard = Guard {db: st.db, name};
// while let Some((id, inputs, options)) = down_receiver.recv().await {
// let item = json!({"type": "request", "id": id, "inputs": inputs, "options": options});
// yield Ok(Event::default().json_data(item).unwrap());
// }
// }
// };
// Sse::new(stream).keep_alive(KeepAlive::default())
// }
async fn observe_changes(
State(st): State<DbState>,

Loading…
Cancel
Save