diff --git a/README.md b/README.md index f2ab762..c46b088 100644 --- a/README.md +++ b/README.md @@ -317,6 +317,21 @@ Examples: ``` ACLs are interpreted in order, later authorizations superseding earlier ones if a client matches several. The special value "*" matches all commands. +## ACLs and Websocket clients + +These rules apply to WebSocket connections as well, although without support for HTTP Basic Auth filtering. IP filtering is supported. + +For JSON-based WebSocket clients, a rejected command will return this object (sent as a string in a binary frame): +```json +{"message": "Forbidden", "error": true, "http_status": 403} +``` +The `http_status` code is an indicator of how Webdis would have responded if the client had used HTTP instead of a WebSocket connection, since WebSocket messages do not inherently have a status code. + +For raw Redis protocol WebSocket clients, a rejected command will produce this error (sent as a string in a binary frame): +``` +-ERR Forbidden\r\n +``` + # Environment variables Environment variables can be used in `webdis.json` to read values from the environment instead of using constant values. diff --git a/src/cmd.h b/src/cmd.h index 823c3e3..6e9b4f4 100644 --- a/src/cmd.h +++ b/src/cmd.h @@ -14,6 +14,8 @@ struct worker; struct cmd; typedef void (*formatting_fun)(redisAsyncContext *, void *, void *); +typedef char* (*ws_error_fun)(int http_status, const char *msg, size_t msg_sz, size_t *out_sz); + typedef enum {CMD_SENT, CMD_PARAM_ERROR, CMD_ACL_FAIL, diff --git a/src/formats/json.c b/src/formats/json.c index 833e7b1..ba4a887 100644 --- a/src/formats/json.c +++ b/src/formats/json.c @@ -554,3 +554,23 @@ json_ws_extract(struct http_client *c, const char *p, size_t sz) { json_decref(j); return cmd; } + +/* Formats a WebSocket error message */ +char* json_ws_error(int http_status, const char *msg, size_t msg_sz, size_t *out_sz) { + + (void)msg_sz; /* unused */ + json_t *jroot = json_object(); + char *jstr; + + /* e.g. {"message": "Forbidden", "error": true, "http_status": 403} */ + /* Note: this is only an equivalent HTTP status code, we're sending a WS message not an HTTP response */ + json_object_set_new(jroot, "error", json_true()); + json_object_set_new(jroot, "message", json_string(msg)); + json_object_set_new(jroot, "http_status", json_integer(http_status)); + + jstr = json_string_output(jroot, NULL); + json_decref(jroot); + + *out_sz = strlen(jstr); + return jstr; +} diff --git a/src/formats/json.h b/src/formats/json.h index 911f8f5..8192f43 100644 --- a/src/formats/json.h +++ b/src/formats/json.h @@ -17,4 +17,7 @@ json_string_output(json_t *j, const char *jsonp); struct cmd * json_ws_extract(struct http_client *c, const char *p, size_t sz); +char* +json_ws_error(int http_status, const char *msg, size_t msg_sz, size_t *out_sz); + #endif diff --git a/src/formats/raw.c b/src/formats/raw.c index 77005e6..957e041 100644 --- a/src/formats/raw.c +++ b/src/formats/raw.c @@ -190,3 +190,22 @@ raw_wrap(const redisReply *r, size_t *sz) { } } + +/* Formats a WebSocket error message */ +char* raw_ws_error(int http_status, const char *msg, size_t msg_sz, size_t *out_sz) { + + (void)http_status; /* unused */ + char *ret, *p; + + /* e.g. "-ERR unknown command 'foo'\r\n" */ + *out_sz = 5 + msg_sz + 2; + p = ret = malloc(*out_sz); + + memcpy(p, "-ERR ", 5); + p += 5; + memcpy(p, msg, msg_sz); + p += msg_sz; + memcpy(p, "\r\n", 2); + + return ret; +} diff --git a/src/formats/raw.h b/src/formats/raw.h index 10321e3..03c41c7 100644 --- a/src/formats/raw.h +++ b/src/formats/raw.h @@ -13,4 +13,7 @@ raw_reply(redisAsyncContext *c, void *r, void *privdata); struct cmd * raw_ws_extract(struct http_client *c, const char *p, size_t sz); +char* +raw_ws_error(int http_status, const char *msg, size_t msg_sz, size_t *out_sz); + #endif diff --git a/src/websocket.c b/src/websocket.c index 498f6b3..85b1093 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -1,5 +1,6 @@ #include "sha1/sha1.h" #include +#include "acl.h" #include "websocket.h" #include "client.h" #include "cmd.h" @@ -255,6 +256,15 @@ ws_log_cmd(struct ws_client *ws, struct cmd *cmd) { slog(ws->http_client->s, WEBDIS_DEBUG, log_msg, p - log_msg); } +static void +ws_log_unauthorized(struct ws_client *ws) { + if(!slog_enabled(ws->http_client->s, WEBDIS_DEBUG)) { + return; + } + const char msg[] = "WS: 403"; + slog(ws->http_client->s, WEBDIS_DEBUG, msg, sizeof(msg)-1); +} + static int ws_execute(struct ws_client *ws, struct ws_msg *msg) { @@ -262,14 +272,17 @@ ws_execute(struct ws_client *ws, struct ws_msg *msg) { struct http_client *c = ws->http_client; struct cmd*(*fun_extract)(struct http_client *, const char *, size_t) = NULL; formatting_fun fun_reply = NULL; + ws_error_fun fun_error = NULL; if((c->path_sz == 1 && strncmp(c->path, "/", 1) == 0) || strncmp(c->path, "/.json", 6) == 0) { fun_extract = json_ws_extract; fun_reply = json_reply; + fun_error = json_ws_error; } else if(strncmp(c->path, "/.raw", 5) == 0) { fun_extract = raw_ws_extract; fun_reply = raw_reply; + fun_error = raw_ws_error; } if(fun_extract) { @@ -311,7 +324,17 @@ ws_execute(struct ws_client *ws, struct ws_msg *msg) { int is_subscribe = cmd_is_subscribe_args(cmd); int is_unsubscribe = cmd_is_unsubscribe_args(cmd); - if(ws->ran_subscribe && !is_subscribe && !is_unsubscribe) { /* disallow non-subscribe commands after a subscribe */ + /* check that the client is able to run this command */ + if(!acl_allow_command(cmd, c->s->cfg, c)) { + const char forbidden[] = "Forbidden"; + size_t error_sz; + char *error = fun_error(403, forbidden, sizeof(forbidden)-1, &error_sz); + ws_frame_and_send_response(ws, WS_BINARY_FRAME, error, error_sz); + free(error); + /* similar to HTTP: log command first and then rejection, both with "WS: " prefix */ + ws_log_cmd(ws, cmd); + ws_log_unauthorized(ws); + } else if(ws->ran_subscribe && !is_subscribe && !is_unsubscribe) { /* disallow non-subscribe commands after a subscribe */ char error_msg[] = "Command not allowed after subscribe"; ws_frame_and_send_response(ws, WS_BINARY_FRAME, error_msg, sizeof(error_msg)-1); } else { /* log and execute */ diff --git a/tests/websocket.html b/tests/websocket.html index 00cca12..6ce1962 100644 --- a/tests/websocket.html +++ b/tests/websocket.html @@ -121,6 +121,15 @@ function installBlock(title, type) { +
+
+
+
 
+
+
+
+
+
 
@@ -199,6 +208,12 @@ class Client { } }); + $(`${this.type}-btn-debug`).addEventListener('click', event => { + event.preventDefault(); + const serialized = this.serializer(['DEBUG', 'OBJECT', 'foo']); + this.send(serialized); + }); + $(`${this.type}-btn-clear`).addEventListener('click', event => { event.preventDefault(); $(`${this.type}-log`).innerText = ""; @@ -226,6 +241,7 @@ class Client { $(`${this.type}-btn-pub`).disabled = !connected || this.subscribed; $(`${this.type}-sub-channel`).disabled = !connected || this.subscribed; $(`${this.type}-btn-sub`).disabled = !connected || this.subscribed; + $(`${this.type}-btn-debug`).disabled = !connected || this.subscribed; $(`${this.type}-state`).innerText = `State: ${connected ? 'Connected' : 'Disconnected'}`; } diff --git a/tests/ws-tests.py b/tests/ws-tests.py index 43549b7..08d0d91 100755 --- a/tests/ws-tests.py +++ b/tests/ws-tests.py @@ -67,6 +67,11 @@ class TestJson(TestWebdis): def test_ping(self): self.assertEqual(self.exec('PING'), {'PING': [True, 'PONG']}) + def test_acl(self): + key, value = self.clean_key(), str(uuid.uuid4()) + self.assertEqual(self.exec('SET', key, value), {'SET': [True, 'OK']}) + self.assertEqual(self.exec('DEBUG', 'OBJECT', key), {'error': True, 'message': 'Forbidden', 'http_status': 403}) + def test_multiple_messages(self): key = self.clean_key() n = 100 @@ -92,6 +97,11 @@ class TestRaw(TestWebdis): def test_ping(self): self.assertEqual(self.exec('PING'), "+PONG\r\n") + def test_acl(self): + key, value = self.clean_key(), str(uuid.uuid4()) + self.assertEqual(self.exec('SET', key, value), "+OK\r\n") + self.assertEqual(self.exec('DEBUG', 'OBJECT', key), "-ERR Forbidden\r\n") + def test_get_set(self): key = self.clean_key() value = str(uuid.uuid4())