Merge pull request #199 from jessie-murray/ws-improvements

master
Nicolas Favre-Felix 3 years ago committed by GitHub
commit 0528287aa6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -286,15 +286,15 @@ http_client_read(struct http_client *c) {
if(ret <= 0) {
/* broken link, free buffer and client object */
/* disconnect pub/sub client if there is one. */
if(c->pub_sub && c->pub_sub->ac) {
struct cmd *cmd = c->pub_sub;
/* disconnect pub/sub or WS client if there is one. */
if(c->reused_cmd && c->reused_cmd->ac) {
struct cmd *cmd = c->reused_cmd;
/* disconnect from all channels */
redisAsyncDisconnect(c->pub_sub->ac);
// c->pub_sub might be already cleared by an event handler in redisAsyncDisconnect
redisAsyncDisconnect(c->reused_cmd->ac);
// c->reused_cmd might be already cleared by an event handler in redisAsyncDisconnect
cmd->ac = NULL;
c->pub_sub = NULL;
c->reused_cmd = NULL;
/* delete command object */
cmd_free(cmd);

@ -61,9 +61,9 @@ struct http_client {
char *separator; /* list separator for raw lists */
char *filename; /* content-disposition */
struct cmd *pub_sub;
struct cmd *reused_cmd;
struct ws_msg *frame; /* websocket frame */
struct ws_client *ws; /* websocket client */
};
struct http_client *

@ -6,7 +6,6 @@
#include "worker.h"
#include "http.h"
#include "server.h"
#include "slog.h"
#include "formats/json.h"
#include "formats/raw.h"
@ -22,11 +21,12 @@
#include <ctype.h>
struct cmd *
cmd_new(int count) {
cmd_new(struct http_client *client, int count) {
struct cmd *c = calloc(1, sizeof(struct cmd));
c->count = count;
c->http_client = client;
c->argv = calloc(count, sizeof(char*));
c->argv_len = calloc(count, sizeof(size_t));
@ -34,11 +34,21 @@ cmd_new(int count) {
return c;
}
void
cmd_free_argv(struct cmd *c) {
int i;
for(i = 0; i < c->count; ++i) {
free((char*)c->argv[i]);
}
free(c->argv);
free(c->argv_len);
}
void
cmd_free(struct cmd *c) {
int i;
if(!c) return;
free(c->jsonp);
@ -52,12 +62,7 @@ cmd_free(struct cmd *c) {
pool_free_context(c->ac);
}
for(i = 0; i < c->count; ++i) {
free((char*)c->argv[i]);
}
free(c->argv);
free(c->argv_len);
cmd_free_argv(c);
free(c);
}
@ -164,7 +169,7 @@ cmd_run(struct worker *w, struct http_client *client,
return CMD_PARAM_ERROR;
}
cmd = cmd_new(param_count);
cmd = cmd_new(client, param_count);
cmd->fd = client->fd;
cmd->database = w->s->cfg->database;
@ -224,7 +229,7 @@ cmd_run(struct worker *w, struct http_client *client,
cmd->ac = (redisAsyncContext*)pool_connect(w->pool, cmd->database, 0);
/* register with the client, used upon disconnection */
client->pub_sub = cmd;
client->reused_cmd = cmd;
cmd->pub_sub_client = client;
} else if(cmd->database != w->s->cfg->database) {
/* create a new connection to Redis for custom DBs */
@ -276,7 +281,7 @@ cmd_run(struct worker *w, struct http_client *client,
}
/* failed to find a suitable connection to Redis. */
cmd_free(cmd);
client->pub_sub = NULL;
client->reused_cmd = NULL;
return CMD_REDIS_UNAVAIL;
}
@ -370,10 +375,31 @@ cmd_select_format(struct http_client *client, struct cmd *cmd,
int
cmd_is_subscribe(struct cmd *cmd) {
if(cmd->count >= 1 && cmd->argv[0] &&
(strncasecmp(cmd->argv[0], "SUBSCRIBE", cmd->argv_len[0]) == 0 ||
strncasecmp(cmd->argv[0], "PSUBSCRIBE", cmd->argv_len[0]) == 0)) {
if(cmd->pub_sub_client || /* persistent command */
cmd_is_subscribe_args(cmd)) { /* checked with args */
return 1;
}
return 0;
}
int
cmd_is_subscribe_args(struct cmd *cmd) {
if(cmd->count >= 2 &&
((cmd->argv_len[0] == 9 && strncasecmp(cmd->argv[0], "subscribe", 9) == 0) ||
(cmd->argv_len[0] == 10 && strncasecmp(cmd->argv[0], "psubscribe", 10) == 0))) {
return 1;
}
return 0;
}
int
cmd_is_unsubscribe_args(struct cmd *cmd) {
if(cmd->count >= 2 &&
((cmd->argv_len[0] == 11 && strncasecmp(cmd->argv[0], "unsubscribe", 11) == 0) ||
(cmd->argv_len[0] == 12 && strncasecmp(cmd->argv[0], "punsubscribe", 12) == 0))) {
return 1;
}
return 0;
}

@ -43,6 +43,7 @@ struct cmd {
int http_version;
int database;
struct http_client *http_client;
struct http_client *pub_sub_client;
redisAsyncContext *ac;
struct worker *w;
@ -54,7 +55,10 @@ struct subscription {
};
struct cmd *
cmd_new(int count);
cmd_new(struct http_client *c, int count);
void
cmd_free_argv(struct cmd *c);
void
cmd_free(struct cmd *c);
@ -68,6 +72,12 @@ int
cmd_select_format(struct http_client *client, struct cmd *cmd,
const char *uri, size_t uri_len, formatting_fun *f_format);
int
cmd_is_subscribe_args(struct cmd *cmd);
int
cmd_is_unsubscribe_args(struct cmd *cmd);
int
cmd_is_subscribe(struct cmd *cmd);

@ -7,6 +7,8 @@
#include "md5/md5.h"
#include <string.h>
#include <unistd.h>
#include <pthread.h>
#include <ctype.h>
/* TODO: replace this with a faster hash function? */
char *etag_new(const char *p, size_t sz) {
@ -44,13 +46,16 @@ format_send_error(struct cmd *cmd, short code, const char *msg) {
resp->http_version = cmd->http_version;
http_response_set_keep_alive(resp, cmd->keep_alive);
http_response_write(resp, cmd->fd);
} else if(cmd->is_websocket && !cmd->http_client->ws->close_after_events) {
ws_frame_and_send_response(cmd->http_client->ws, WS_BINARY_FRAME, msg, strlen(msg));
}
/* for pub/sub, remove command from client */
if(cmd->pub_sub_client) {
cmd->pub_sub_client->pub_sub = NULL;
} else {
cmd_free(cmd);
if (!cmd->is_websocket) { /* don't free or detach persistent cmd */
if (cmd->pub_sub_client) { /* for pub/sub, remove command from client */
cmd->pub_sub_client->reused_cmd = NULL;
} else {
cmd_free(cmd);
}
}
}
@ -62,7 +67,8 @@ format_send_reply(struct cmd *cmd, const char *p, size_t sz, const char *content
struct http_response *resp;
if(cmd->is_websocket) {
ws_reply(cmd, p, sz);
ws_frame_and_send_response(cmd->http_client->ws, WS_BINARY_FRAME, p, sz);
/* If it's a subscribe command, there'll be more responses */
if(!cmd_is_subscribe(cmd))

@ -522,7 +522,7 @@ json_ws_extract(struct http_client *c, const char *p, size_t sz) {
}
/* create command and add args */
cmd = cmd_new(argc);
cmd = cmd_new(c, argc);
for(i = 0, cur = 0; i < json_array_size(j); ++i) {
json_t *jelem = json_array_get(j, i);
char *tmp;

@ -62,7 +62,7 @@ raw_ws_extract(struct http_client *c, const char *p, size_t sz) {
}
/* create cmd object */
cmd = cmd_new(reply->elements);
cmd = cmd_new(c, reply->elements);
for(i = 0; i < reply->elements; ++i) {
redisReply *ri = reply->element[i];

@ -16,6 +16,10 @@ http_response_init(struct worker *w, int code, const char *msg) {
/* create object */
struct http_response *r = calloc(1, sizeof(struct http_response));
if(!r) {
if(w && w->s) slog(w->s, WEBDIS_ERROR, "Failed to allocate http_response", 0);
return NULL;
}
r->code = code;
r->msg = msg;
@ -43,6 +47,24 @@ http_response_init(struct worker *w, int code, const char *msg) {
return r;
}
struct http_response *
http_response_init_with_buffer(struct worker *w, char *data, size_t data_sz, int keep_alive) {
struct http_response *r = calloc(1, sizeof(struct http_response));
if(!r) {
if(w && w->s) slog(w->s, WEBDIS_ERROR, "Failed to allocate http_response with buffer", 0);
return NULL;
}
r->w = w;
/* provide buffer directly */
r->out = data;
r->out_sz = data_sz;
r->sent = 0;
r->keep_alive = keep_alive;
return r;
}
void
http_response_set_header(struct http_response *r, const char *k, const char *v) {

@ -45,6 +45,9 @@ struct http_response {
struct http_response *
http_response_init(struct worker *w, int code, const char *msg);
struct http_response *
http_response_init_with_buffer(struct worker *w, char *data, size_t data_sz, int keep_alive);
void
http_response_set_header(struct http_response *r, const char *k, const char *v);

@ -30,7 +30,6 @@ pool_free_context(redisAsyncContext *ac) {
if (ac) {
redisAsyncDisconnect(ac);
redisAsyncFree(ac);
}
}
@ -96,6 +95,11 @@ pool_on_disconnect(const redisAsyncContext *ac, int status) {
struct pool *p = ac->data;
int i = 0;
if(p == NULL) { /* no need to clean anything here. */
return;
}
if (status != REDIS_OK) {
char format[] = "Error disconnecting: %s";
size_t msg_sz = sizeof(format) - 2 + ((ac && ac->errstr) ? strlen(ac->errstr) : 6);
@ -107,10 +111,6 @@ pool_on_disconnect(const redisAsyncContext *ac, int status) {
}
}
if(p == NULL) { /* no need to clean anything here. */
return;
}
/* remove from the pool */
for(i = 0; i < p->count; ++i) {
if(p->ac[i] == ac) {

@ -13,6 +13,10 @@
#include "server.h"
#include "conf.h"
#if SLOG_MSG_MAX_LEN < 64
#error "SLOG_MSG_MAX_LEN must be at least 64"
#endif
/**
* Initialize log writer.
*/
@ -95,8 +99,8 @@ slog_internal(struct server *s, log_level level,
time_t now;
struct tm now_tm, *lt_ret;
char time_buf[64];
char msg[124];
char line[256]; /* bounds are checked. */
char msg[1 + SLOG_MSG_MAX_LEN];
char line[2 * SLOG_MSG_MAX_LEN]; /* bounds are checked. */
int line_sz, ret;
if(!s->log.fd) return;

@ -1,6 +1,8 @@
#ifndef SLOG_H
#define SLOG_H
#define SLOG_MSG_MAX_LEN 124
typedef enum {
WEBDIS_ERROR = 0,
WEBDIS_WARNING,

@ -7,6 +7,8 @@
#include "pool.h"
#include "http.h"
#include "slog.h"
#include "server.h"
#include "conf.h"
/* message parsers */
#include "formats/json.h"
@ -19,23 +21,21 @@
#include <errno.h>
#include <sys/param.h>
static int
ws_schedule_write(struct ws_client *ws);
/**
* This code uses the WebSocket specification from RFC 6455.
* A copy is available at http://www.rfc-editor.org/rfc/rfc6455.txt
*/
#if __BIG_ENDIAN__
# define webdis_htonll(x) (x)
# define webdis_ntohll(x) (x)
#else
# define webdis_htonll(x) (((uint64_t)htonl((x) & 0xFFFFFFFF) << 32) | htonl((x) >> 32))
# define webdis_ntohll(x) (((uint64_t)ntohl((x) & 0xFFFFFFFF) << 32) | ntohl((x) >> 32))
#endif
/* custom 64-bit encoding functions to avoid portability issues */
#define webdis_ntohl64(p) \
((((uint64_t)((p)[0])) << 0) + (((uint64_t)((p)[1])) << 8) +\
(((uint64_t)((p)[2])) << 16) + (((uint64_t)((p)[3])) << 24) +\
(((uint64_t)((p)[4])) << 32) + (((uint64_t)((p)[5])) << 40) +\
(((uint64_t)((p)[6])) << 48) + (((uint64_t)((p)[7])) << 56))
#define webdis_htonl64(p) {\
(char)(((p & ((uint64_t)0xff << 0)) >> 0) & 0xff), (char)(((p & ((uint64_t)0xff << 8)) >> 8) & 0xff), \
(char)(((p & ((uint64_t)0xff << 16)) >> 16) & 0xff), (char)(((p & ((uint64_t)0xff << 24)) >> 24) & 0xff), \
(char)(((p & ((uint64_t)0xff << 32)) >> 32) & 0xff), (char)(((p & ((uint64_t)0xff << 40)) >> 40) & 0xff), \
(char)(((p & ((uint64_t)0xff << 48)) >> 48) & 0xff), (char)(((p & ((uint64_t)0xff << 56)) >> 56) & 0xff) }
static int
ws_compute_handshake(struct http_client *c, char *out, size_t *out_sz) {
@ -86,25 +86,75 @@ ws_compute_handshake(struct http_client *c, char *out, size_t *out_sz) {
return 0;
}
struct ws_client *
ws_client_new(struct http_client *http_client) {
int db_num = http_client->w->s->cfg->database;
struct ws_client *ws = calloc(1, sizeof(struct ws_client));
struct evbuffer *rbuf = evbuffer_new();
struct evbuffer *wbuf = evbuffer_new();
redisAsyncContext *ac = pool_connect(http_client->w->pool, db_num, 0);
if(!ws || !rbuf || !wbuf) {
slog(http_client->s, WEBDIS_ERROR, "Failed to allocate memory for WS client", 0);
if(ws) free(ws);
if(rbuf) evbuffer_free(rbuf);
if(wbuf) evbuffer_free(wbuf);
if(ac) redisAsyncFree(ac);
return NULL;
}
http_client->ws = ws;
ws->http_client = http_client;
ws->rbuf = rbuf;
ws->wbuf = wbuf;
ws->ac = ac;
return ws;
}
void
ws_client_free(struct ws_client *ws) {
/* mark WS client as closing to skip the Redis callback */
ws->close_after_events = 1;
pool_free_context(ws->ac); /* could trigger a cb via format_send_error */
struct http_client *c = ws->http_client;
if(c) {
close(c->fd);
c->ws = NULL; /* detach if needed */
}
evbuffer_free(ws->rbuf);
evbuffer_free(ws->wbuf);
if(ws->cmd) {
ws->cmd->ac = NULL; /* we've just free'd it */
cmd_free(ws->cmd);
}
free(ws);
if(c) http_client_free(c);
}
int
ws_handshake_reply(struct http_client *c) {
ws_handshake_reply(struct ws_client *ws) {
struct http_client *c = ws->http_client;
char sha1_handshake[40];
char *buffer = NULL, *p;
const char *origin = NULL, *host = NULL;
size_t origin_sz = 0, host_sz = 0, handshake_sz = 0, sz;
char template0[] = "HTTP/1.1 101 Switching Protocols\r\n"
char template_start[] = "HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Origin: "; /* %s */
char template1[] = "\r\n"
"Sec-WebSocket-Location: ws://"; /* %s%s */
char template2[] = "\r\n"
"Origin: http://"; /* %s */
char template3[] = "\r\n"
"Connection: Upgrade";
char template_accept[] = "\r\n" /* just after the start */
"Sec-WebSocket-Accept: "; /* %s */
char template4[] = "\r\n\r\n";
char template_sec_origin[] = "\r\n"
"Sec-WebSocket-Origin: "; /* %s (optional header) */
char template_loc[] = "\r\n"
"Sec-WebSocket-Location: ws://"; /* %s%s */
char template_end[] = "\r\n\r\n";
if((origin = client_get_header(c, "Origin"))) {
origin_sz = strlen(origin);
@ -116,7 +166,7 @@ ws_handshake_reply(struct http_client *c) {
}
/* need those headers */
if(!origin || !origin_sz || !host || !host_sz || !c->path || !c->path_sz) {
if(!host || !host_sz || !c->path || !c->path_sz) {
slog(c->s, WEBDIS_WARNING, "Missing headers for WS handshake", 0);
return -1;
}
@ -128,11 +178,11 @@ ws_handshake_reply(struct http_client *c) {
return -1;
}
sz = sizeof(template0)-1 + origin_sz
+ sizeof(template1)-1 + host_sz + c->path_sz
+ sizeof(template2)-1 + host_sz
+ sizeof(template3)-1 + handshake_sz
+ sizeof(template4)-1;
sz = sizeof(template_start)-1
+ sizeof(template_accept)-1 + handshake_sz
+ (origin && origin_sz ? (sizeof(template_sec_origin)-1 + origin_sz) : 0) /* optional origin */
+ sizeof(template_loc)-1 + host_sz + c->path_sz
+ sizeof(template_end)-1;
p = buffer = malloc(sz);
if(!p) {
@ -142,57 +192,74 @@ ws_handshake_reply(struct http_client *c) {
/* Concat all */
/* template0 */
memcpy(p, template0, sizeof(template0)-1);
p += sizeof(template0)-1;
memcpy(p, origin, origin_sz);
p += origin_sz;
/* template_start */
memcpy(p, template_start, sizeof(template_start)-1);
p += sizeof(template_start)-1;
/* template1 */
memcpy(p, template1, sizeof(template1)-1);
p += sizeof(template1)-1;
/* template_accept */
memcpy(p, template_accept, sizeof(template_accept)-1);
p += sizeof(template_accept)-1;
memcpy(p, &sha1_handshake[0], handshake_sz);
p += handshake_sz;
/* template_sec_origin */
if(origin && origin_sz) {
memcpy(p, template_sec_origin, sizeof(template_sec_origin)-1);
p += sizeof(template_sec_origin)-1;
memcpy(p, origin, origin_sz);
p += origin_sz;
}
/* template_loc */
memcpy(p, template_loc, sizeof(template_loc)-1);
p += sizeof(template_loc)-1;
memcpy(p, host, host_sz);
p += host_sz;
memcpy(p, c->path, c->path_sz);
p += c->path_sz;
/* template2 */
memcpy(p, template2, sizeof(template2)-1);
p += sizeof(template2)-1;
memcpy(p, host, host_sz);
p += host_sz;
/* template_end */
memcpy(p, template_end, sizeof(template_end)-1);
p += sizeof(template_end)-1;
/* template3 */
memcpy(p, template3, sizeof(template3)-1);
p += sizeof(template3)-1;
memcpy(p, &sha1_handshake[0], handshake_sz);
p += handshake_sz;
int add_ret = evbuffer_add(ws->wbuf, buffer, sz);
free(buffer);
if(add_ret < 0) {
slog(c->s, WEBDIS_ERROR, "Failed to add response for WS handshake", 0);
return -1;
}
/* template4 */
memcpy(p, template4, sizeof(template4)-1);
p += sizeof(template4)-1;
return ws_schedule_write(ws); /* will free buffer and response once sent */
}
/* build HTTP response object by hand, since we have the full response already */
struct http_response *r = calloc(1, sizeof(struct http_response));
if(!r) {
slog(c->s, WEBDIS_ERROR, "Failed to allocate response for WS handshake", 0);
free(buffer);
return -1;
static void
ws_log_cmd(struct ws_client *ws, struct cmd *cmd) {
char log_msg[SLOG_MSG_MAX_LEN];
char *p = log_msg, *eom = log_msg + sizeof(log_msg) - 1;
if(!slog_enabled(ws->http_client->s, WEBDIS_DEBUG)) {
return;
}
r->w = c->w;
r->keep_alive = 1;
r->out = buffer;
r->out_sz = sz;
r->sent = 0;
http_schedule_write(c->fd, r); /* will free buffer and response once sent */
return 0;
memset(log_msg, 0, sizeof(log_msg));
memcpy(p, "WS: ", 4); /* WS prefix */
p += 4;
for(int i = 0; p < eom && i < cmd->count; i++) {
*p++ = '/';
char *arg = cmd->argv[i];
size_t arg_sz = cmd->argv_len[i];
size_t copy_sz = arg_sz < (size_t)(eom - p) ? arg_sz : (size_t)(eom - p);
memcpy(p, arg, copy_sz);
p += copy_sz;
}
slog(ws->http_client->s, WEBDIS_DEBUG, log_msg, p - log_msg);
}
static int
ws_execute(struct http_client *c, const char *frame, size_t frame_len) {
ws_execute(struct ws_client *ws, struct ws_msg *msg) {
struct http_client *c = ws->http_client;
struct cmd*(*fun_extract)(struct http_client *, const char *, size_t) = NULL;
formatting_fun fun_reply = NULL;
@ -208,31 +275,50 @@ ws_execute(struct http_client *c, const char *frame, size_t frame_len) {
if(fun_extract) {
/* Parse websocket frame into a cmd object. */
struct cmd *cmd = fun_extract(c, frame, frame_len);
struct cmd *cmd = fun_extract(c, msg->payload, msg->payload_sz);
if(cmd) {
/* copy client info into cmd. */
cmd_setup(cmd, c);
cmd->is_websocket = 1;
if (c->pub_sub != NULL) {
/* This client already has its own connection
* to Redis due to a subscription; use it from
* now on. */
cmd->ac = c->pub_sub->ac;
} else if (cmd_is_subscribe(cmd)) {
/* New subscribe command; make new Redis context
* for this client */
cmd->ac = pool_connect(c->w->pool, cmd->database, 0);
c->pub_sub = cmd;
cmd->pub_sub_client = c;
if(ws->cmd != NULL) {
/* This client already has its own connection to Redis
from a previous command; use it from now on. */
/* free args for the previous cmd */
cmd_free_argv(ws->cmd);
/* copy args from what we just parsed to the persistent command */
ws->cmd->count = cmd->count;
ws->cmd->argv = cmd->argv;
ws->cmd->argv_len = cmd->argv_len;
ws->cmd->pub_sub_client = c; /* mark as persistent, otherwise the Redis context will be freed */
cmd->argv = NULL;
cmd->argv_len = NULL;
cmd->count = 0;
cmd_free(cmd);
cmd = ws->cmd; /* replace pointer since we're about to pass it to cmd_send */
} else {
/* get Redis connection from pool */
cmd->ac = (redisAsyncContext*)pool_get_context(c->w->pool);
/* copy client info into cmd. */
cmd_setup(cmd, c);
/* First WS command; use Redis context from WS client. */
cmd->ac = ws->ac;
ws->cmd = cmd;
cmd->pub_sub_client = c;
}
/* send it off */
cmd_send(cmd, fun_reply);
int is_subscribe = cmd_is_subscribe_args(cmd);
int is_unsubscribe = cmd_is_unsubscribe_args(cmd);
if(ws->ran_subscribe && !is_subscribe && !is_unsubscribe) { /* disallow non-subscribe commands after a subscribe */
char error_msg[] = "Command not allowed after subscribe";
ws_frame_and_send_response(ws, WS_BINARY_FRAME, error_msg, sizeof(error_msg)-1);
} else { /* log and execute */
ws_log_cmd(ws, cmd);
cmd_send(cmd, fun_reply);
ws->ran_subscribe = is_subscribe;
}
return 0;
}
@ -242,16 +328,24 @@ ws_execute(struct http_client *c, const char *frame, size_t frame_len) {
}
static struct ws_msg *
ws_msg_new() {
return calloc(1, sizeof(struct ws_msg));
ws_msg_new(enum ws_frame_type frame_type) {
struct ws_msg *msg = calloc(1, sizeof(struct ws_msg));
if(!msg) {
return NULL;
}
msg->type = frame_type;
return msg;
}
static void
static int
ws_msg_add(struct ws_msg *m, const char *p, size_t psz, const unsigned char *mask) {
/* add data to frame */
size_t i;
m->payload = realloc(m->payload, m->payload_sz + psz);
if(!m->payload) {
return -1;
}
memcpy(m->payload + m->payload_sz, p, psz);
/* apply mask */
@ -261,29 +355,46 @@ ws_msg_add(struct ws_msg *m, const char *p, size_t psz, const unsigned char *mas
/* save new size */
m->payload_sz += psz;
return 0;
}
static void
ws_msg_free(struct ws_msg **m) {
ws_msg_free(struct ws_msg *m) {
free((*m)->payload);
free(*m);
*m = NULL;
free(m->payload);
free(m);
}
/* checks to see if we have a complete message */
static enum ws_state
ws_parse_data(const char *frame, size_t sz, struct ws_msg **msg) {
ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) {
int has_mask;
uint64_t len;
const char *p;
char *frame;
unsigned char mask[4];
char fin_bit_set;
enum ws_frame_type frame_type;
/* parse frame and extract contents */
size_t sz = evbuffer_get_length(ws->rbuf);
if(sz < 8) {
return WS_READING;
return WS_READING; /* need more data */
}
/* copy into "frame" to process it */
frame = malloc(sz);
if(!frame) {
return WS_ERROR;
}
int rem_ret = evbuffer_remove(ws->rbuf, frame, sz);
if(rem_ret < 0) {
free(frame);
return WS_ERROR;
}
fin_bit_set = frame[0] & 0x80 ? 1 : 0;
frame_type = frame[0] & 0x0F; /* lower 4 bits of first byte */
has_mask = frame[1] & 0x80 ? 1:0;
/* get payload length */
@ -298,67 +409,113 @@ ws_parse_data(const char *frame, size_t sz, struct ws_msg **msg) {
p = frame + 4 + (has_mask ? 4 : 0);
if(has_mask) memcpy(&mask, frame + 4, sizeof(mask));
} else if(len == 127) {
len = webdis_ntohl64(frame+2);
uint64_t sz64 = *((uint64_t*)(frame+2));
len = webdis_ntohll(sz64);
p = frame + 10 + (has_mask ? 4 : 0);
if(has_mask) memcpy(&mask, frame + 10, sizeof(mask));
} else {
free(frame);
return WS_ERROR;
}
/* we now have the (possibly masked) data starting in p, and its length. */
if(len > sz - (p - frame)) { /* not enough data */
return WS_READING;
int add_ret = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */
free(frame);
return add_ret < 0 ? WS_ERROR : WS_READING;
}
if(!*msg)
*msg = ws_msg_new();
ws_msg_add(*msg, p, len, has_mask ? mask : NULL);
(*msg)->total_sz += len + (p - frame);
int ev_copy = 0;
if(out_msg) { /* we're extracting the message */
struct ws_msg *msg = ws_msg_new(frame_type);
if(!msg) {
free(frame);
return WS_ERROR;
}
*out_msg = msg; /* attach for it to be freed by caller */
/* create new ws_msg object holding what we read */
int add_ret = ws_msg_add(msg, p, len, has_mask ? mask : NULL);
if(!add_ret) {
free(frame);
return WS_ERROR;
}
if(frame[0] & 0x80) { /* FIN bit set */
size_t processed_sz = len + (p - frame); /* length of data + header bytes between frame start and payload */
msg->total_sz += processed_sz;
ev_copy = evbuffer_prepend(ws->rbuf, frame + len, sz - processed_sz); /* remove processed data */
} else { /* we're just peeking */
ev_copy = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */
}
free(frame);
if(ev_copy < 0) {
return WS_ERROR;
} else if(fin_bit_set) {
return WS_MSG_COMPLETE;
} else {
return WS_READING; /* need more data */
}
}
/**
* Process some data just received on the socket.
* Returns the number of messages processed in out_processed, if non-NULL.
*/
enum ws_state
ws_add_data(struct http_client *c) {
ws_process_read_data(struct ws_client *ws, unsigned int *out_processed) {
enum ws_state state;
if(out_processed) *out_processed = 0;
state = ws_parse_data(c->buffer, c->sz, &c->frame);
state = ws_peek_data(ws, NULL); /* check for complete message */
while(state == WS_MSG_COMPLETE) {
int ret = ws_execute(c, c->frame->payload, c->frame->payload_sz);
/* remove frame from client buffer */
http_client_remove_data(c, c->frame->total_sz);
int ret = 0;
struct ws_msg *msg = NULL;
ws_peek_data(ws, &msg); /* extract message */
if(msg && (msg->type == WS_TEXT_FRAME || msg->type == WS_BINARY_FRAME)) {
ret = ws_execute(ws, msg);
if(out_processed) (*out_processed)++;
} else if(msg && msg->type == WS_PING) { /* respond to ping */
ws_frame_and_send_response(ws, WS_PONG, msg->payload, msg->payload_sz);
} else if(msg && msg->type == WS_CONNECTION_CLOSE) { /* respond to close frame */
ws->close_after_events = 1;
ws_frame_and_send_response(ws, WS_CONNECTION_CLOSE, msg->payload, msg->payload_sz);
} else if(msg) {
char format[] = "Received unexpected WS frame type: 0x%x";
char error[(sizeof format)];
int error_len = snprintf(error, sizeof(error), format, msg->type);
slog(ws->http_client->s, WEBDIS_WARNING, error, error_len);
}
/* free frame and set back to NULL */
ws_msg_free(&c->frame);
/* free frame */
if(msg) ws_msg_free(msg);
if(ret != 0) {
/* can't process frame. */
slog(c->s, WEBDIS_WARNING, "ws_add_data: ws_execute failed", 0);
slog(ws->http_client->s, WEBDIS_DEBUG, "ws_process_read_data: ws_execute failed", 0);
return WS_ERROR;
}
state = ws_parse_data(c->buffer, c->sz, &c->frame);
state = ws_peek_data(ws, NULL);
}
return state;
}
int
ws_reply(struct cmd *cmd, const char *p, size_t sz) {
char *frame = malloc(sz + 8); /* create frame by prepending header */
ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type frame_type, const char *p, size_t sz) {
/* we can have as much as 14 bytes in the header:
* 1 byte for 4 flag bits + 4 frame type bits
* 1 byte for the payload length indicator
* 8 bytes for the size of the payload (at most)
* 4 bytes for the masking key (if present)
*/
char *frame = malloc(sz + 14); /* create frame by prepending header */
size_t frame_sz = 0;
struct http_response *r;
if (frame == NULL)
if(frame == NULL)
return -1;
/*
@ -368,40 +525,119 @@ ws_reply(struct cmd *cmd, const char *p, size_t sz) {
following 8 bytes interpreted as a 64-bit unsigned integer (the
most significant bit MUST be 0) are the payload length.
*/
frame[0] = '\x81';
frame[0] = 0x80 | frame_type; /* frame type + EOM bit */
if(sz <= 125) {
frame[1] = sz;
memcpy(frame + 2, p, sz);
frame_sz = sz + 2;
} else if (sz <= 65536) {
} else if(sz <= 65536) {
uint16_t sz16 = htons(sz);
frame[1] = 126;
memcpy(frame + 2, &sz16, 2);
memcpy(frame + 4, p, sz);
frame_sz = sz + 4;
} else { /* sz > 65536 */
char sz64[8] = webdis_htonl64(sz);
uint64_t sz_be = webdis_htonll(sz); /* big endian */
char sz64[8];
memcpy(sz64, &sz_be, 8);
frame[1] = 127;
memcpy(frame + 2, sz64, 8);
memcpy(frame + 10, p, sz);
frame_sz = sz + 10;
}
/* send WS frame */
r = http_response_init(cmd->w, 0, NULL);
if (r == NULL) {
free(frame);
slog(cmd->w->s, WEBDIS_ERROR, "Failed response allocation in ws_reply", 0);
/* mark as keep alive, otherwise we'll close the connection after the first reply */
int add_ret = evbuffer_add(ws->wbuf, frame, frame_sz);
free(frame); /* no longer needed once added to buffer */
if(add_ret < 0) {
slog(ws->http_client->w->s, WEBDIS_ERROR, "Failed response allocation in ws_frame_and_send_response", 0);
return -1;
}
/* mark as keep alive, otherwise we'll close the connection after the first reply */
r->keep_alive = 1;
/* send WS frame */
return ws_schedule_write(ws);
}
static void
ws_close_if_able(struct ws_client *ws) {
ws->close_after_events = 1; /* note that we're closing */
if(ws->scheduled_read || ws->scheduled_write) {
return; /* still waiting for these events to trigger */
}
ws_client_free(ws); /* will close the socket */
}
r->out = frame;
r->out_sz = frame_sz;
r->sent = 0;
http_schedule_write(cmd->fd, r);
static void
ws_can_read(int fd, short event, void *p) {
int ret;
struct ws_client *ws = p;
(void)event;
/* read pending data */
ws->scheduled_read = 0;
ret = evbuffer_read(ws->rbuf, fd, 4096);
if(ret <= 0) {
ws_client_free(ws); /* will close the socket */
} else if(ws->close_after_events) {
ws_close_if_able(ws);
} else {
enum ws_state state = ws_process_read_data(ws, NULL);
if(state == WS_READING) { /* need more data, schedule new read */
ws_monitor_input(ws);
} else if(state == WS_ERROR) {
ws_close_if_able(ws);
}
}
}
static void
ws_can_write(int fd, short event, void *p) {
int ret;
struct ws_client *ws = p;
(void)event;
ws->scheduled_write = 0;
/* send pending data */
ret = evbuffer_write_atmost(ws->wbuf, fd, 4096);
if(ret <= 0) {
ws_client_free(ws); /* will close the socket */
} else {
if(evbuffer_get_length(ws->wbuf) > 0) { /* more data to send */
ws_schedule_write(ws);
} else if(ws->close_after_events) { /* we're done! */
ws_close_if_able(ws);
} else {
/* check if we can read more data */
unsigned int processed = 0;
ws_process_read_data(ws, &processed); /* process any pending data we've already read */
ws_monitor_input(ws); /* let's read more from the client */
}
}
}
static int
ws_schedule_write(struct ws_client *ws) {
struct http_client *c = ws->http_client;
if(!ws->scheduled_write) {
ws->scheduled_write = 1;
return event_base_once(c->w->base, c->fd, EV_WRITE, ws_can_write, ws, NULL);
}
return 0;
}
int
ws_monitor_input(struct ws_client *ws) {
struct http_client *c = ws->http_client;
if(!ws->scheduled_read) {
ws->scheduled_read = 1;
return event_base_once(c->w->base, c->fd, EV_READ, ws_can_read, ws, NULL);
}
return 0;
}

@ -3,6 +3,8 @@
#include <stdlib.h>
#include <stdint.h>
#include <event.h>
#include <hiredis/async.h>
struct http_client;
struct cmd;
@ -12,19 +14,51 @@ enum ws_state {
WS_READING,
WS_MSG_COMPLETE};
enum ws_frame_type {
WS_TEXT_FRAME = 0,
WS_BINARY_FRAME = 1,
WS_CONNECTION_CLOSE = 8,
WS_PING = 9,
WS_PONG = 0xA,
WS_UNKNOWN_FRAME = -1};
struct ws_msg {
enum ws_frame_type type;
char *payload;
size_t payload_sz;
size_t total_sz;
};
struct ws_client {
struct http_client *http_client; /* parent */
int scheduled_read; /* set if we are scheduled to read WS data */
int scheduled_write; /* set if we are scheduled to send out WS data */
struct evbuffer *rbuf; /* read buffer for incoming data */
struct evbuffer *wbuf; /* write buffer for outgoing data */
redisAsyncContext *ac; /* dedicated connection to redis */
struct cmd *cmd; /* current command */
/* indicates that we'll close once we've flushed all
buffered data and read what we planned to read */
int close_after_events;
int ran_subscribe; /* set if we've run a (p)subscribe command */
};
struct ws_client *
ws_client_new(struct http_client *http_client);
void
ws_client_free(struct ws_client *ws);
int
ws_handshake_reply(struct ws_client *ws);
int
ws_handshake_reply(struct http_client *c);
ws_monitor_input(struct ws_client *ws);
enum ws_state
ws_add_data(struct http_client *c);
ws_process_read_data(struct ws_client *ws, unsigned int *out_processed);
int
ws_reply(struct cmd *cmd, const char *p, size_t sz);
ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type type, const char *p, size_t sz);
#endif

@ -13,7 +13,8 @@
#include <unistd.h>
#include <event.h>
#include <string.h>
#include <netinet/tcp.h>
#include "formats/common.h"
struct worker *
worker_new(struct server *s) {
@ -53,10 +54,7 @@ worker_can_read(int fd, short event, void *p) {
}
}
if(c->is_websocket) {
/* Got websocket data */
ws_add_data(c);
} else {
if(!c->is_websocket) {
/* run parser */
nparsed = http_client_execute(c);
@ -67,13 +65,32 @@ worker_can_read(int fd, short event, void *p) {
/* only close if requested *and* we've already read the request in full */
c->broken = 1;
} else if(c->is_websocket) {
/* we need to use the remaining (unparsed) data as the body. */
if(nparsed < ret) {
http_client_add_to_body(c, c->buffer + nparsed + 1, c->sz - nparsed - 1);
ws_handshake_reply(c);
} else {
/* Got websocket data */
c->ws = ws_client_new(c);
if(!c->ws) {
c->broken = 1;
} else {
free(c->buffer);
c->buffer = NULL;
c->sz = 0;
/* send response, and start managing fd from websocket.c */
int reply_ret = ws_handshake_reply(c->ws);
if(reply_ret < 0) {
c->ws->http_client = NULL; /* detach to prevent double free */
ws_client_free(c->ws);
c->broken = 1;
} else {
unsigned int processed = 0;
int process_ret = ws_process_read_data(c->ws, &processed);
if(process_ret == WS_ERROR) {
c->broken = 1; /* likely connection was closed */
}
}
}
/* clean up what remains in HTTP client */
free(c->buffer);
c->buffer = NULL;
c->sz = 0;
@ -87,10 +104,14 @@ worker_can_read(int fd, short event, void *p) {
}
if(c->broken) { /* terminate client */
if(c->is_websocket) { /* only close for WS since HTTP might use keep-alive */
close(c->fd);
}
http_client_free(c);
} else {
/* start monitoring input again */
worker_monitor_input(c);
} else { /* start monitoring input again */
if(!c->is_websocket) { /* all communication handled by WS code from now on */
worker_monitor_input(c);
}
}
}
@ -99,7 +120,6 @@ worker_can_read(int fd, short event, void *p) {
*/
void
worker_monitor_input(struct http_client *c) {
event_set(&c->ev, c->fd, EV_READ, worker_can_read, c);
event_base_set(c->w->base, &c->ev);
event_add(&c->ev, NULL);

@ -12,6 +12,9 @@ class BlockingSocket:
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.s.setblocking(True)
self.s.connect((HOST, PORT))
def __del__(self):
self.s.close()
def recv(self):
out = b""

@ -0,0 +1 @@
websocket-client>=1.1.0

@ -263,10 +263,13 @@ websocket_can_read(int fd, short event, void *ptr) {
/* http parser will return the offset at which the upgraded protocol begins,
which in our case is 1 under the total response size. */
if (wt->state == WS_SENT_HANDSHAKE || /* haven't encountered end of response yet */
(wt->parser.upgrade && nparsed != (int)avail_sz -1)) {
wt->debug("UPGRADE *and* we have some data left (nparsed=%d, avail_sz=%lu)\n", nparsed, avail_sz);
continue;
if (wt->state == WS_SENT_HANDSHAKE) { /* haven't encountered end of response yet */
if (wt->parser.upgrade && nparsed != (int)avail_sz) {
wt->debug("UPGRADE *and* we have some data left (state=%d, nparsed=%d, avail_sz=%lu)\n", wt->state, nparsed, avail_sz);
continue;
} else { /* we just haven't read the entire response yet */
wait_for_possible_read(wt);
}
} else if (wt->state == WS_RECEIVED_HANDSHAKE) { /* we have the full response */
evbuffer_drain(wt->rbuffer, evbuffer_get_length(wt->rbuffer));
}
@ -474,6 +477,8 @@ progress_thread_main(void *ptr) {
struct timespec ts_start;
clock_gettime(CLOCK_MONOTONIC, &ts_start);
long start_nanos = ts_start.tv_sec * 1e9 + ts_start.tv_nsec; /* time of monitoring start */
long last_print_nanos = start_nanos;
while(1) {
int sem_received = 0;
@ -501,7 +506,6 @@ progress_thread_main(void *ptr) {
sem_received = 1;
}
#endif
// nanosleep(&ts, NULL);
num_sleeps++;
int total_sent = 0, total_received = 0, any_broken = 0, num_complete = 0;
for (int i = 0; i < pt->worker_count; i++) {
@ -515,11 +519,15 @@ progress_thread_main(void *ptr) {
}
struct timespec ts_after_sleep;
clock_gettime(CLOCK_MONOTONIC, &ts_after_sleep);
long after_sleep_nanos = ts_after_sleep.tv_sec * 1e9 + ts_after_sleep.tv_nsec;
long total_nanos = after_sleep_nanos - start_nanos; /* total time spent so far */
fprintf(stderr, "After %0.2f sec: %'d messages sent, %'d received (%.02f%%). Instant rate: %'ld/sec, overall rate: %'ld/sec\n",
((float)((ts_after_sleep.tv_sec * 1e9 + ts_after_sleep.tv_nsec) - (ts_start.tv_sec * 1e9 + ts_start.tv_nsec))) / (float)1e9,
total_sent, total_received, 100.0f * (float)total_received / (float)(pt->worker_count * pt->msg_target),
lroundf((float)(total_received - last_received) / pt->interval_sec),
lroundf((float)total_received / ((float)num_sleeps) * pt->interval_sec));
lroundf((float)(total_received - last_received) / (((float)(after_sleep_nanos - last_print_nanos)) / 1e9f)),
lroundf((float)total_received / (((float)total_nanos) / 1e9f)));
last_print_nanos = after_sleep_nanos; /* time of last print */
if (sem_received || total_received == pt->msg_target * pt->worker_count || any_broken || num_complete == pt->worker_count) {
break;

@ -101,6 +101,26 @@ function installBlock(title, type) {
</fieldset>
</form>
<form class="pure-form">
<fieldset>
<div class="pure-g">
<div class="pure-u-1-3"><input disabled class="pure-u-23-24" type="text" placeholder="channel name" id="$type-pub-channel" value="channel-0" /></div>
<div class="pure-u-1-3"><input disabled class="pure-u-23-24" type="text" placeholder="message value" id="$type-pub-message" value="message-0" /></div>
<div class="pure-u-1-3"><button disabled type="submit" class="pure-u-23-24 pure-button pure-button-primary" id="$type-btn-pub">PUBLISH</button></div>
</div>
</fieldset>
</form>
<form class="pure-form">
<fieldset>
<div class="pure-g">
<div class="pure-u-1-3">&nbsp;</div>
<div class="pure-u-1-3"><input disabled class="pure-u-23-24" type="text" placeholder="channel" id="$type-sub-channel" value="channel-0" /></div>
<div class="pure-u-1-3"><button disabled type="submit" class="pure-u-23-24 pure-button pure-button-primary" id="$type-btn-sub">SUBSCRIBE</button></div>
</div>
</fieldset>
</form>
<div class="pure-g">
<div class="pure-u-2-3">&nbsp;</div>
<div class="pure-u-1-3"><button disabled type="submit" class="pure-u-23-24 pure-button pure-button-primary" id="$type-btn-clear">Clear logs</button></div>
@ -113,17 +133,19 @@ function installBlock(title, type) {
class Client {
constructor(type, pingSerializer, getSerializer, setSerializer) {
constructor(type, serializer) {
this.type = type;
this.pingSerializer = pingSerializer;
this.getSerializer = getSerializer;
this.setSerializer = setSerializer;
this.serializer = serializer;
this.ws = null;
this.connected = false;
this.subscribed = false;
this.logCount = 0;
$(`${this.type}-btn-connect`).addEventListener('click', event => {
event.preventDefault();
console.log('Connecting...');
this.ws = new WebSocket(`ws://${ host }:${ port }/.${ this.type }`);
this.ws = new WebSocket(`ws://${host}:${port}/.${this.type}`);
window.ws = this.ws;
this.ws.onopen = event => {
console.log('Connected');
this.setConnectedState(true);
@ -135,50 +157,83 @@ class Client {
};
this.ws.onclose = event => {
$(`${this.type}-btn-connect`).disabled = false;
$(`${this.type}-btn-connect`).innerText = 'Connect';
this.setConnectedState(false);
this.subscribed = false;
};
});
$(`${this.type}-btn-ping`).addEventListener('click', event => {
event.preventDefault();
const serialized = this.pingSerializer();
this.log("sent", serialized);
this.ws.send(serialized);
const serialized = this.serializer(['PING']);
this.send(serialized);
});
$(`${this.type}-btn-set`).addEventListener('click', event => {
event.preventDefault();
const serialized = this.setSerializer($(`${this.type}-set-key`).value, $(`${this.type}-set-value`).value);
this.log("sent", serialized);
this.ws.send(serialized);
const serialized = this.serializer(['SET', $(`${this.type}-set-key`).value, $(`${this.type}-set-value`).value]);
this.send(serialized);
});
$(`${this.type}-btn-get`).addEventListener('click', event => {
event.preventDefault();
const serialized = this.getSerializer($(`${this.type}-set-key`).value);
this.log("sent", serialized);
this.ws.send(serialized);
const serialized = this.serializer(['GET', $(`${this.type}-get-key`).value]);
this.send(serialized);
});
$(`${this.type}-btn-pub`).addEventListener('click', event => {
event.preventDefault();
const serialized = this.serializer(['PUBLISH', $(`${this.type}-pub-channel`).value, $(`${this.type}-pub-message`).value]);
this.send(serialized);
});
$(`${this.type}-btn-sub`).addEventListener('click', event => {
event.preventDefault();
const serialized = this.serializer(['SUBSCRIBE', $(`${this.type}-sub-channel`).value]);
try {
this.send(serialized);
this.subscribed = true;
this.setConnectedState(true);
} catch (e) {
console.log('Error sending: ', serialized, e);
}
});
$(`${this.type}-btn-clear`).addEventListener('click', event => {
event.preventDefault();
$(`${this.type}-log`).innerText = "";
this.logCount = 0;
this.updateLogButton();
});
}
send(serialized) {
this.log("sent", serialized);
this.ws.send(serialized);
}
setConnectedState(connected) {
this.connected = connected;
$(`${this.type}-btn-connect`).disabled = connected;
$(`${this.type}-btn-ping`).disabled = !connected;
$(`${this.type}-set-key`).disabled = !connected;
$(`${this.type}-set-value`).disabled = !connected;
$(`${this.type}-btn-set`).disabled = !connected;
$(`${this.type}-get-key`).disabled = !connected;
$(`${this.type}-btn-get`).disabled = !connected;
$(`${this.type}-btn-clear`).disabled = !connected;
$(`${this.type}-btn-ping`).disabled = !connected || this.subscribed;
$(`${this.type}-set-key`).disabled = !connected || this.subscribed;
$(`${this.type}-set-value`).disabled = !connected || this.subscribed;
$(`${this.type}-btn-set`).disabled = !connected || this.subscribed;
$(`${this.type}-get-key`).disabled = !connected || this.subscribed;
$(`${this.type}-btn-get`).disabled = !connected || this.subscribed;
$(`${this.type}-pub-channel`).disabled = !connected || this.subscribed;
$(`${this.type}-pub-message`).disabled = !connected || this.subscribed;
$(`${this.type}-btn-pub`).disabled = !connected || this.subscribed;
$(`${this.type}-sub-channel`).disabled = !connected || this.subscribed;
$(`${this.type}-btn-sub`).disabled = !connected || this.subscribed;
$(`${this.type}-state`).innerText = `State: ${connected ? 'Connected' : 'Disconnected'}`;
}
updateLogButton() {
$(`${this.type}-btn-clear`).disabled = this.logCount === 0;
}
log(dir, msg) {
const id = `${this.type}-log`;
@ -190,22 +245,26 @@ class Client {
contents.setAttribute("class", dir);
contents.innerHTML = msg;
$(id).appendChild(contents);
this.logCount++;
this.updateLogButton();
}
}
function serializeRaw(args) {
let raw = `*${args.length}\r\n`;
for (let i = 0; i < args.length; i++) {
raw += `$${args[i].length}\r\n`;
raw += `${args[i]}\r\n`;
}
return raw;
}
addEventListener("DOMContentLoaded", () => {
installBlock('JSON', 'json');
installBlock('Raw', 'raw');
const jsonClient = new Client('json',
() => JSON.stringify(['PING']),
(key) => JSON.stringify(['GET', key]),
(key, value) => JSON.stringify(['SET', key, value]));
const rawClient = new Client('raw',
() => '*1\r\n$4\r\nPING\r\n',
(key) => `*2\r\n$3\r\nGET\r\n$${key.length}\r\n${key}\r\n`,
(key, value) => `*3\r\n$3\r\nSET\r\n$${key.length}\r\n${key}\r\n$${value.length}\r\n${value}\r\n`);
const jsonClient = new Client('json', JSON.stringify);
const rawClient = new Client('raw', serializeRaw);
});
</script>

@ -0,0 +1,194 @@
#!/usr/bin/env python3
import abc
import json
import os
import unittest
import uuid
from websocket import create_connection
host = os.getenv('WEBDIS_HOST', '127.0.0.1')
port = int(os.getenv('WEBDIS_PORT', 7379))
def connect(format):
return create_connection(f'ws://{host}:{port}/.{format}')
class TestWebdis(unittest.TestCase):
def setUp(self) -> None:
self.ws = connect(self.format())
def tearDown(self) -> None:
self.ws.close()
def exec(self, cmd, *args):
self.ws.send(self.serialize(cmd, *args))
return self.deserialize(self.ws.recv())
def clean_key(self):
"""Returns a key that was just deleted"""
key = str(uuid.uuid4())
self.exec('DEL', key)
return key
@abc.abstractmethod
def format(self):
"""Returns the format to use (added after a dot to the WS URI)"""
return
@abc.abstractmethod
def serialize(self, cmd):
"""Serializes a command according to the format being tested"""
return
@abc.abstractmethod
def deserialize(self, response):
"""Deserializes a response according to the format being tested"""
return
class TestJson(TestWebdis):
def format(self):
return 'json'
def serialize(self, cmd, *args):
return json.dumps([cmd] + list(args))
def deserialize(self, response):
return json.loads(response)
def test_ping(self):
self.assertEqual(self.exec('PING'), {'PING': [True, 'PONG']})
def test_multiple_messages(self):
key = self.clean_key()
n = 100
for i in range(n):
lpush_response = self.exec('LPUSH', key, f'value-{i}')
self.assertEqual(lpush_response, {'LPUSH': i + 1})
self.assertEqual(self.exec('LLEN', key), {'LLEN': n})
class TestRaw(TestWebdis):
def format(self):
return 'raw'
def serialize(self, cmd, *args):
buffer = f"*{1 + len(args)}\r\n${len(cmd)}\r\n{cmd}\r\n"
for arg in args:
buffer += f"${len(arg)}\r\n{arg}\r\n"
return buffer
def deserialize(self, response):
return response # we'll just assert using the raw protocol
def test_ping(self):
self.assertEqual(self.exec('PING'), "+PONG\r\n")
def test_get_set(self):
key = self.clean_key()
value = str(uuid.uuid4())
not_found_response = self.exec('GET', key)
self.assertEqual(not_found_response, "$-1\r\n") # Redis protocol response for "not found"
set_response = self.exec('SET', key, value)
self.assertEqual(set_response, "+OK\r\n")
get_response = self.exec('GET', key)
self.assertEqual(get_response, f"${len(value)}\r\n{value}\r\n")
class TestPubSub(unittest.TestCase):
def setUp(self):
self.publisher = connect('json')
self.subscriber = connect('json')
def tearDown(self):
self.publisher.close()
self.subscriber.close()
def serialize(self, cmd, *args):
return json.dumps([cmd] + list(args))
def deserialize(self, response):
return json.loads(response)
def test_publish_subscribe(self):
channel_count = 2
message_count_per_channel = 8
channels = list(str(uuid.uuid4()) for i in range(channel_count))
# subscribe to all channels
sub_count = 0
for channel in channels:
self.subscriber.send(self.serialize('SUBSCRIBE', channel))
sub_response = self.deserialize(self.subscriber.recv())
sub_count += 1
self.assertEqual(sub_response, {'SUBSCRIBE': ['subscribe', channel, sub_count]})
# send messages to all channels
prefix = 'message-'
for i in range(message_count_per_channel):
for channel in channels:
message = f'{prefix}{i}'
self.publisher.send(self.serialize('PUBLISH', channel, message))
self.deserialize(self.publisher.recv())
received_per_channel = dict((channel, []) for channel in channels)
for j in range(channel_count * message_count_per_channel):
received = self.deserialize(self.subscriber.recv())
# expected: {'SUBSCRIBE': ['message', $channel, $message]}
self.assertTrue(received, 'SUBSCRIBE' in received)
sub_contents = received['SUBSCRIBE']
self.assertEqual(len(sub_contents), 3)
self.assertEqual(sub_contents[0], 'message') # first element is the message type, here a push
channel = sub_contents[1]
self.assertTrue(channel in channels) # second is the channel
received_per_channel[channel].append(
sub_contents[2]) # third, add to list of messages received for this channel
# unsubscribe from all channels
subs_remaining = channel_count
for channel in channels:
self.subscriber.send(self.serialize('UNSUBSCRIBE', channel))
subs_remaining -= 1
unsub_response = self.deserialize(self.subscriber.recv())
self.assertEqual(unsub_response, {'UNSUBSCRIBE': ['unsubscribe', channel, subs_remaining]})
# check that we received all messages
for channel in channels:
self.assertEqual(len(received_per_channel[channel]), message_count_per_channel)
# check that we received them *in order*
for i in range(message_count_per_channel):
for channel in channels:
expected = f'{prefix}{i}'
self.assertEqual(received_per_channel[channel][i], expected,
f'In {channel}: expected at offset {i} was "{expected}", actual was: "{received_per_channel[channel][i]}"')
class TestFrameSizes(TestWebdis):
def format(self):
return 'json'
def serialize(self, cmd, *args):
return json.dumps([cmd] + list(args))
def deserialize(self, response):
return json.loads(response)
def test_length_126(self):
self.validate_set_get('A' * 1024) # this will require 2 bytes to encode the length
def test_length_127(self):
self.validate_set_get('A' * (2 ** 18)) # this will require more than 2 bytes to encode the length (actually using 8)
def validate_set_get(self, value):
key = str(uuid.uuid4())
self.assertEqual(self.exec('SET', key, value), {'SET': [True, 'OK']})
self.assertEqual(self.exec('GET', key), {'GET': value})
self.exec('DEL', key)
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save