diff --git a/src/client.c b/src/client.c index 3e8eda4..ed6f3f3 100644 --- a/src/client.c +++ b/src/client.c @@ -286,15 +286,15 @@ http_client_read(struct http_client *c) { if(ret <= 0) { /* broken link, free buffer and client object */ - /* disconnect pub/sub client if there is one. */ - if(c->pub_sub && c->pub_sub->ac) { - struct cmd *cmd = c->pub_sub; + /* disconnect pub/sub or WS client if there is one. */ + if(c->reused_cmd && c->reused_cmd->ac) { + struct cmd *cmd = c->reused_cmd; /* disconnect from all channels */ - redisAsyncDisconnect(c->pub_sub->ac); - // c->pub_sub might be already cleared by an event handler in redisAsyncDisconnect + redisAsyncDisconnect(c->reused_cmd->ac); + // c->reused_cmd might be already cleared by an event handler in redisAsyncDisconnect cmd->ac = NULL; - c->pub_sub = NULL; + c->reused_cmd = NULL; /* delete command object */ cmd_free(cmd); diff --git a/src/client.h b/src/client.h index 6b38992..c355bac 100644 --- a/src/client.h +++ b/src/client.h @@ -61,9 +61,9 @@ struct http_client { char *separator; /* list separator for raw lists */ char *filename; /* content-disposition */ - struct cmd *pub_sub; + struct cmd *reused_cmd; - struct ws_msg *frame; /* websocket frame */ + struct ws_client *ws; /* websocket client */ }; struct http_client * diff --git a/src/cmd.c b/src/cmd.c index 290df87..6fcc528 100644 --- a/src/cmd.c +++ b/src/cmd.c @@ -6,7 +6,6 @@ #include "worker.h" #include "http.h" #include "server.h" -#include "slog.h" #include "formats/json.h" #include "formats/raw.h" @@ -22,11 +21,12 @@ #include struct cmd * -cmd_new(int count) { +cmd_new(struct http_client *client, int count) { struct cmd *c = calloc(1, sizeof(struct cmd)); c->count = count; + c->http_client = client; c->argv = calloc(count, sizeof(char*)); c->argv_len = calloc(count, sizeof(size_t)); @@ -34,11 +34,21 @@ cmd_new(int count) { return c; } +void +cmd_free_argv(struct cmd *c) { + + int i; + for(i = 0; i < c->count; ++i) { + free((char*)c->argv[i]); + } + + free(c->argv); + free(c->argv_len); +} void cmd_free(struct cmd *c) { - int i; if(!c) return; free(c->jsonp); @@ -52,12 +62,7 @@ cmd_free(struct cmd *c) { pool_free_context(c->ac); } - for(i = 0; i < c->count; ++i) { - free((char*)c->argv[i]); - } - - free(c->argv); - free(c->argv_len); + cmd_free_argv(c); free(c); } @@ -164,7 +169,7 @@ cmd_run(struct worker *w, struct http_client *client, return CMD_PARAM_ERROR; } - cmd = cmd_new(param_count); + cmd = cmd_new(client, param_count); cmd->fd = client->fd; cmd->database = w->s->cfg->database; @@ -224,7 +229,7 @@ cmd_run(struct worker *w, struct http_client *client, cmd->ac = (redisAsyncContext*)pool_connect(w->pool, cmd->database, 0); /* register with the client, used upon disconnection */ - client->pub_sub = cmd; + client->reused_cmd = cmd; cmd->pub_sub_client = client; } else if(cmd->database != w->s->cfg->database) { /* create a new connection to Redis for custom DBs */ @@ -276,7 +281,7 @@ cmd_run(struct worker *w, struct http_client *client, } /* failed to find a suitable connection to Redis. */ cmd_free(cmd); - client->pub_sub = NULL; + client->reused_cmd = NULL; return CMD_REDIS_UNAVAIL; } @@ -370,10 +375,31 @@ cmd_select_format(struct http_client *client, struct cmd *cmd, int cmd_is_subscribe(struct cmd *cmd) { - if(cmd->count >= 1 && cmd->argv[0] && - (strncasecmp(cmd->argv[0], "SUBSCRIBE", cmd->argv_len[0]) == 0 || - strncasecmp(cmd->argv[0], "PSUBSCRIBE", cmd->argv_len[0]) == 0)) { + if(cmd->pub_sub_client || /* persistent command */ + cmd_is_subscribe_args(cmd)) { /* checked with args */ return 1; } return 0; } + +int +cmd_is_subscribe_args(struct cmd *cmd) { + + if(cmd->count >= 2 && + ((cmd->argv_len[0] == 9 && strncasecmp(cmd->argv[0], "subscribe", 9) == 0) || + (cmd->argv_len[0] == 10 && strncasecmp(cmd->argv[0], "psubscribe", 10) == 0))) { + return 1; + } + return 0; +} + +int +cmd_is_unsubscribe_args(struct cmd *cmd) { + + if(cmd->count >= 2 && + ((cmd->argv_len[0] == 11 && strncasecmp(cmd->argv[0], "unsubscribe", 11) == 0) || + (cmd->argv_len[0] == 12 && strncasecmp(cmd->argv[0], "punsubscribe", 12) == 0))) { + return 1; + } + return 0; +} diff --git a/src/cmd.h b/src/cmd.h index e216bde..823c3e3 100644 --- a/src/cmd.h +++ b/src/cmd.h @@ -43,6 +43,7 @@ struct cmd { int http_version; int database; + struct http_client *http_client; struct http_client *pub_sub_client; redisAsyncContext *ac; struct worker *w; @@ -54,7 +55,10 @@ struct subscription { }; struct cmd * -cmd_new(int count); +cmd_new(struct http_client *c, int count); + +void +cmd_free_argv(struct cmd *c); void cmd_free(struct cmd *c); @@ -68,6 +72,12 @@ int cmd_select_format(struct http_client *client, struct cmd *cmd, const char *uri, size_t uri_len, formatting_fun *f_format); +int +cmd_is_subscribe_args(struct cmd *cmd); + +int +cmd_is_unsubscribe_args(struct cmd *cmd); + int cmd_is_subscribe(struct cmd *cmd); diff --git a/src/formats/common.c b/src/formats/common.c index 296b822..1932143 100644 --- a/src/formats/common.c +++ b/src/formats/common.c @@ -7,6 +7,8 @@ #include "md5/md5.h" #include #include +#include +#include /* TODO: replace this with a faster hash function? */ char *etag_new(const char *p, size_t sz) { @@ -44,13 +46,16 @@ format_send_error(struct cmd *cmd, short code, const char *msg) { resp->http_version = cmd->http_version; http_response_set_keep_alive(resp, cmd->keep_alive); http_response_write(resp, cmd->fd); + } else if(cmd->is_websocket && !cmd->http_client->ws->close_after_events) { + ws_frame_and_send_response(cmd->http_client->ws, WS_BINARY_FRAME, msg, strlen(msg)); } - /* for pub/sub, remove command from client */ - if(cmd->pub_sub_client) { - cmd->pub_sub_client->pub_sub = NULL; - } else { - cmd_free(cmd); + if (!cmd->is_websocket) { /* don't free or detach persistent cmd */ + if (cmd->pub_sub_client) { /* for pub/sub, remove command from client */ + cmd->pub_sub_client->reused_cmd = NULL; + } else { + cmd_free(cmd); + } } } @@ -62,7 +67,8 @@ format_send_reply(struct cmd *cmd, const char *p, size_t sz, const char *content struct http_response *resp; if(cmd->is_websocket) { - ws_reply(cmd, p, sz); + + ws_frame_and_send_response(cmd->http_client->ws, WS_BINARY_FRAME, p, sz); /* If it's a subscribe command, there'll be more responses */ if(!cmd_is_subscribe(cmd)) diff --git a/src/formats/json.c b/src/formats/json.c index 006bb1f..0cbcb52 100644 --- a/src/formats/json.c +++ b/src/formats/json.c @@ -522,7 +522,7 @@ json_ws_extract(struct http_client *c, const char *p, size_t sz) { } /* create command and add args */ - cmd = cmd_new(argc); + cmd = cmd_new(c, argc); for(i = 0, cur = 0; i < json_array_size(j); ++i) { json_t *jelem = json_array_get(j, i); char *tmp; diff --git a/src/formats/raw.c b/src/formats/raw.c index 20c3927..7caff97 100644 --- a/src/formats/raw.c +++ b/src/formats/raw.c @@ -62,7 +62,7 @@ raw_ws_extract(struct http_client *c, const char *p, size_t sz) { } /* create cmd object */ - cmd = cmd_new(reply->elements); + cmd = cmd_new(c, reply->elements); for(i = 0; i < reply->elements; ++i) { redisReply *ri = reply->element[i]; diff --git a/src/http.c b/src/http.c index 66f8135..cd31ea0 100644 --- a/src/http.c +++ b/src/http.c @@ -16,6 +16,10 @@ http_response_init(struct worker *w, int code, const char *msg) { /* create object */ struct http_response *r = calloc(1, sizeof(struct http_response)); + if(!r) { + if(w && w->s) slog(w->s, WEBDIS_ERROR, "Failed to allocate http_response", 0); + return NULL; + } r->code = code; r->msg = msg; @@ -43,6 +47,24 @@ http_response_init(struct worker *w, int code, const char *msg) { return r; } +struct http_response * +http_response_init_with_buffer(struct worker *w, char *data, size_t data_sz, int keep_alive) { + + struct http_response *r = calloc(1, sizeof(struct http_response)); + if(!r) { + if(w && w->s) slog(w->s, WEBDIS_ERROR, "Failed to allocate http_response with buffer", 0); + return NULL; + } + r->w = w; + + /* provide buffer directly */ + r->out = data; + r->out_sz = data_sz; + r->sent = 0; + r->keep_alive = keep_alive; + return r; +} + void http_response_set_header(struct http_response *r, const char *k, const char *v) { diff --git a/src/http.h b/src/http.h index eac1ffe..8deb542 100644 --- a/src/http.h +++ b/src/http.h @@ -45,6 +45,9 @@ struct http_response { struct http_response * http_response_init(struct worker *w, int code, const char *msg); +struct http_response * +http_response_init_with_buffer(struct worker *w, char *data, size_t data_sz, int keep_alive); + void http_response_set_header(struct http_response *r, const char *k, const char *v); diff --git a/src/pool.c b/src/pool.c index 706006b..f3e7b42 100644 --- a/src/pool.c +++ b/src/pool.c @@ -30,7 +30,6 @@ pool_free_context(redisAsyncContext *ac) { if (ac) { redisAsyncDisconnect(ac); - redisAsyncFree(ac); } } @@ -96,6 +95,11 @@ pool_on_disconnect(const redisAsyncContext *ac, int status) { struct pool *p = ac->data; int i = 0; + + if(p == NULL) { /* no need to clean anything here. */ + return; + } + if (status != REDIS_OK) { char format[] = "Error disconnecting: %s"; size_t msg_sz = sizeof(format) - 2 + ((ac && ac->errstr) ? strlen(ac->errstr) : 6); @@ -107,10 +111,6 @@ pool_on_disconnect(const redisAsyncContext *ac, int status) { } } - if(p == NULL) { /* no need to clean anything here. */ - return; - } - /* remove from the pool */ for(i = 0; i < p->count; ++i) { if(p->ac[i] == ac) { diff --git a/src/slog.c b/src/slog.c index 9bae415..bdb600d 100644 --- a/src/slog.c +++ b/src/slog.c @@ -13,6 +13,10 @@ #include "server.h" #include "conf.h" +#if SLOG_MSG_MAX_LEN < 64 +#error "SLOG_MSG_MAX_LEN must be at least 64" +#endif + /** * Initialize log writer. */ @@ -95,8 +99,8 @@ slog_internal(struct server *s, log_level level, time_t now; struct tm now_tm, *lt_ret; char time_buf[64]; - char msg[124]; - char line[256]; /* bounds are checked. */ + char msg[1 + SLOG_MSG_MAX_LEN]; + char line[2 * SLOG_MSG_MAX_LEN]; /* bounds are checked. */ int line_sz, ret; if(!s->log.fd) return; diff --git a/src/slog.h b/src/slog.h index 14d9c66..d4b47e8 100644 --- a/src/slog.h +++ b/src/slog.h @@ -1,6 +1,8 @@ #ifndef SLOG_H #define SLOG_H +#define SLOG_MSG_MAX_LEN 124 + typedef enum { WEBDIS_ERROR = 0, WEBDIS_WARNING, diff --git a/src/websocket.c b/src/websocket.c index 53d48c5..e307de2 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -7,6 +7,8 @@ #include "pool.h" #include "http.h" #include "slog.h" +#include "server.h" +#include "conf.h" /* message parsers */ #include "formats/json.h" @@ -19,23 +21,21 @@ #include #include +static int +ws_schedule_write(struct ws_client *ws); + /** * This code uses the WebSocket specification from RFC 6455. * A copy is available at http://www.rfc-editor.org/rfc/rfc6455.txt */ +#if __BIG_ENDIAN__ +# define webdis_htonll(x) (x) +# define webdis_ntohll(x) (x) +#else +# define webdis_htonll(x) (((uint64_t)htonl((x) & 0xFFFFFFFF) << 32) | htonl((x) >> 32)) +# define webdis_ntohll(x) (((uint64_t)ntohl((x) & 0xFFFFFFFF) << 32) | ntohl((x) >> 32)) +#endif -/* custom 64-bit encoding functions to avoid portability issues */ -#define webdis_ntohl64(p) \ - ((((uint64_t)((p)[0])) << 0) + (((uint64_t)((p)[1])) << 8) +\ - (((uint64_t)((p)[2])) << 16) + (((uint64_t)((p)[3])) << 24) +\ - (((uint64_t)((p)[4])) << 32) + (((uint64_t)((p)[5])) << 40) +\ - (((uint64_t)((p)[6])) << 48) + (((uint64_t)((p)[7])) << 56)) - -#define webdis_htonl64(p) {\ - (char)(((p & ((uint64_t)0xff << 0)) >> 0) & 0xff), (char)(((p & ((uint64_t)0xff << 8)) >> 8) & 0xff), \ - (char)(((p & ((uint64_t)0xff << 16)) >> 16) & 0xff), (char)(((p & ((uint64_t)0xff << 24)) >> 24) & 0xff), \ - (char)(((p & ((uint64_t)0xff << 32)) >> 32) & 0xff), (char)(((p & ((uint64_t)0xff << 40)) >> 40) & 0xff), \ - (char)(((p & ((uint64_t)0xff << 48)) >> 48) & 0xff), (char)(((p & ((uint64_t)0xff << 56)) >> 56) & 0xff) } static int ws_compute_handshake(struct http_client *c, char *out, size_t *out_sz) { @@ -86,25 +86,75 @@ ws_compute_handshake(struct http_client *c, char *out, size_t *out_sz) { return 0; } +struct ws_client * +ws_client_new(struct http_client *http_client) { + + int db_num = http_client->w->s->cfg->database; + struct ws_client *ws = calloc(1, sizeof(struct ws_client)); + struct evbuffer *rbuf = evbuffer_new(); + struct evbuffer *wbuf = evbuffer_new(); + redisAsyncContext *ac = pool_connect(http_client->w->pool, db_num, 0); + + if(!ws || !rbuf || !wbuf) { + slog(http_client->s, WEBDIS_ERROR, "Failed to allocate memory for WS client", 0); + if(ws) free(ws); + if(rbuf) evbuffer_free(rbuf); + if(wbuf) evbuffer_free(wbuf); + if(ac) redisAsyncFree(ac); + return NULL; + } + + http_client->ws = ws; + ws->http_client = http_client; + ws->rbuf = rbuf; + ws->wbuf = wbuf; + ws->ac = ac; + + return ws; +} + +void +ws_client_free(struct ws_client *ws) { + + /* mark WS client as closing to skip the Redis callback */ + ws->close_after_events = 1; + pool_free_context(ws->ac); /* could trigger a cb via format_send_error */ + + struct http_client *c = ws->http_client; + if(c) { + close(c->fd); + c->ws = NULL; /* detach if needed */ + } + evbuffer_free(ws->rbuf); + evbuffer_free(ws->wbuf); + if(ws->cmd) { + ws->cmd->ac = NULL; /* we've just free'd it */ + cmd_free(ws->cmd); + } + free(ws); + if(c) http_client_free(c); +} + + int -ws_handshake_reply(struct http_client *c) { +ws_handshake_reply(struct ws_client *ws) { + struct http_client *c = ws->http_client; char sha1_handshake[40]; char *buffer = NULL, *p; const char *origin = NULL, *host = NULL; size_t origin_sz = 0, host_sz = 0, handshake_sz = 0, sz; - char template0[] = "HTTP/1.1 101 Switching Protocols\r\n" + char template_start[] = "HTTP/1.1 101 Switching Protocols\r\n" "Upgrade: websocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Origin: "; /* %s */ - char template1[] = "\r\n" - "Sec-WebSocket-Location: ws://"; /* %s%s */ - char template2[] = "\r\n" - "Origin: http://"; /* %s */ - char template3[] = "\r\n" + "Connection: Upgrade"; + char template_accept[] = "\r\n" /* just after the start */ "Sec-WebSocket-Accept: "; /* %s */ - char template4[] = "\r\n\r\n"; + char template_sec_origin[] = "\r\n" + "Sec-WebSocket-Origin: "; /* %s (optional header) */ + char template_loc[] = "\r\n" + "Sec-WebSocket-Location: ws://"; /* %s%s */ + char template_end[] = "\r\n\r\n"; if((origin = client_get_header(c, "Origin"))) { origin_sz = strlen(origin); @@ -116,7 +166,7 @@ ws_handshake_reply(struct http_client *c) { } /* need those headers */ - if(!origin || !origin_sz || !host || !host_sz || !c->path || !c->path_sz) { + if(!host || !host_sz || !c->path || !c->path_sz) { slog(c->s, WEBDIS_WARNING, "Missing headers for WS handshake", 0); return -1; } @@ -128,11 +178,11 @@ ws_handshake_reply(struct http_client *c) { return -1; } - sz = sizeof(template0)-1 + origin_sz - + sizeof(template1)-1 + host_sz + c->path_sz - + sizeof(template2)-1 + host_sz - + sizeof(template3)-1 + handshake_sz - + sizeof(template4)-1; + sz = sizeof(template_start)-1 + + sizeof(template_accept)-1 + handshake_sz + + (origin && origin_sz ? (sizeof(template_sec_origin)-1 + origin_sz) : 0) /* optional origin */ + + sizeof(template_loc)-1 + host_sz + c->path_sz + + sizeof(template_end)-1; p = buffer = malloc(sz); if(!p) { @@ -142,57 +192,74 @@ ws_handshake_reply(struct http_client *c) { /* Concat all */ - /* template0 */ - memcpy(p, template0, sizeof(template0)-1); - p += sizeof(template0)-1; - memcpy(p, origin, origin_sz); - p += origin_sz; + /* template_start */ + memcpy(p, template_start, sizeof(template_start)-1); + p += sizeof(template_start)-1; - /* template1 */ - memcpy(p, template1, sizeof(template1)-1); - p += sizeof(template1)-1; + /* template_accept */ + memcpy(p, template_accept, sizeof(template_accept)-1); + p += sizeof(template_accept)-1; + memcpy(p, &sha1_handshake[0], handshake_sz); + p += handshake_sz; + + /* template_sec_origin */ + if(origin && origin_sz) { + memcpy(p, template_sec_origin, sizeof(template_sec_origin)-1); + p += sizeof(template_sec_origin)-1; + memcpy(p, origin, origin_sz); + p += origin_sz; + } + + /* template_loc */ + memcpy(p, template_loc, sizeof(template_loc)-1); + p += sizeof(template_loc)-1; memcpy(p, host, host_sz); p += host_sz; memcpy(p, c->path, c->path_sz); p += c->path_sz; - /* template2 */ - memcpy(p, template2, sizeof(template2)-1); - p += sizeof(template2)-1; - memcpy(p, host, host_sz); - p += host_sz; + /* template_end */ + memcpy(p, template_end, sizeof(template_end)-1); + p += sizeof(template_end)-1; - /* template3 */ - memcpy(p, template3, sizeof(template3)-1); - p += sizeof(template3)-1; - memcpy(p, &sha1_handshake[0], handshake_sz); - p += handshake_sz; + int add_ret = evbuffer_add(ws->wbuf, buffer, sz); + free(buffer); + if(add_ret < 0) { + slog(c->s, WEBDIS_ERROR, "Failed to add response for WS handshake", 0); + return -1; + } - /* template4 */ - memcpy(p, template4, sizeof(template4)-1); - p += sizeof(template4)-1; + return ws_schedule_write(ws); /* will free buffer and response once sent */ +} - /* build HTTP response object by hand, since we have the full response already */ - struct http_response *r = calloc(1, sizeof(struct http_response)); - if(!r) { - slog(c->s, WEBDIS_ERROR, "Failed to allocate response for WS handshake", 0); - free(buffer); - return -1; +static void +ws_log_cmd(struct ws_client *ws, struct cmd *cmd) { + char log_msg[SLOG_MSG_MAX_LEN]; + char *p = log_msg, *eom = log_msg + sizeof(log_msg) - 1; + if(!slog_enabled(ws->http_client->s, WEBDIS_DEBUG)) { + return; } - r->w = c->w; - r->keep_alive = 1; - r->out = buffer; - r->out_sz = sz; - r->sent = 0; - http_schedule_write(c->fd, r); /* will free buffer and response once sent */ - return 0; + memset(log_msg, 0, sizeof(log_msg)); + memcpy(p, "WS: ", 4); /* WS prefix */ + p += 4; + + for(int i = 0; p < eom && i < cmd->count; i++) { + *p++ = '/'; + char *arg = cmd->argv[i]; + size_t arg_sz = cmd->argv_len[i]; + size_t copy_sz = arg_sz < (size_t)(eom - p) ? arg_sz : (size_t)(eom - p); + memcpy(p, arg, copy_sz); + p += copy_sz; + } + slog(ws->http_client->s, WEBDIS_DEBUG, log_msg, p - log_msg); } static int -ws_execute(struct http_client *c, const char *frame, size_t frame_len) { +ws_execute(struct ws_client *ws, struct ws_msg *msg) { + struct http_client *c = ws->http_client; struct cmd*(*fun_extract)(struct http_client *, const char *, size_t) = NULL; formatting_fun fun_reply = NULL; @@ -208,31 +275,50 @@ ws_execute(struct http_client *c, const char *frame, size_t frame_len) { if(fun_extract) { /* Parse websocket frame into a cmd object. */ - struct cmd *cmd = fun_extract(c, frame, frame_len); + struct cmd *cmd = fun_extract(c, msg->payload, msg->payload_sz); if(cmd) { - /* copy client info into cmd. */ - cmd_setup(cmd, c); cmd->is_websocket = 1; - if (c->pub_sub != NULL) { - /* This client already has its own connection - * to Redis due to a subscription; use it from - * now on. */ - cmd->ac = c->pub_sub->ac; - } else if (cmd_is_subscribe(cmd)) { - /* New subscribe command; make new Redis context - * for this client */ - cmd->ac = pool_connect(c->w->pool, cmd->database, 0); - c->pub_sub = cmd; - cmd->pub_sub_client = c; + if(ws->cmd != NULL) { + /* This client already has its own connection to Redis + from a previous command; use it from now on. */ + + /* free args for the previous cmd */ + cmd_free_argv(ws->cmd); + /* copy args from what we just parsed to the persistent command */ + ws->cmd->count = cmd->count; + ws->cmd->argv = cmd->argv; + ws->cmd->argv_len = cmd->argv_len; + ws->cmd->pub_sub_client = c; /* mark as persistent, otherwise the Redis context will be freed */ + + cmd->argv = NULL; + cmd->argv_len = NULL; + cmd->count = 0; + cmd_free(cmd); + + cmd = ws->cmd; /* replace pointer since we're about to pass it to cmd_send */ } else { - /* get Redis connection from pool */ - cmd->ac = (redisAsyncContext*)pool_get_context(c->w->pool); + /* copy client info into cmd. */ + cmd_setup(cmd, c); + + /* First WS command; use Redis context from WS client. */ + cmd->ac = ws->ac; + ws->cmd = cmd; + cmd->pub_sub_client = c; } - /* send it off */ - cmd_send(cmd, fun_reply); + int is_subscribe = cmd_is_subscribe_args(cmd); + int is_unsubscribe = cmd_is_unsubscribe_args(cmd); + + if(ws->ran_subscribe && !is_subscribe && !is_unsubscribe) { /* disallow non-subscribe commands after a subscribe */ + char error_msg[] = "Command not allowed after subscribe"; + ws_frame_and_send_response(ws, WS_BINARY_FRAME, error_msg, sizeof(error_msg)-1); + } else { /* log and execute */ + ws_log_cmd(ws, cmd); + cmd_send(cmd, fun_reply); + ws->ran_subscribe = is_subscribe; + } return 0; } @@ -242,16 +328,24 @@ ws_execute(struct http_client *c, const char *frame, size_t frame_len) { } static struct ws_msg * -ws_msg_new() { - return calloc(1, sizeof(struct ws_msg)); +ws_msg_new(enum ws_frame_type frame_type) { + struct ws_msg *msg = calloc(1, sizeof(struct ws_msg)); + if(!msg) { + return NULL; + } + msg->type = frame_type; + return msg; } -static void +static int ws_msg_add(struct ws_msg *m, const char *p, size_t psz, const unsigned char *mask) { /* add data to frame */ size_t i; m->payload = realloc(m->payload, m->payload_sz + psz); + if(!m->payload) { + return -1; + } memcpy(m->payload + m->payload_sz, p, psz); /* apply mask */ @@ -261,29 +355,46 @@ ws_msg_add(struct ws_msg *m, const char *p, size_t psz, const unsigned char *mas /* save new size */ m->payload_sz += psz; + return 0; } static void -ws_msg_free(struct ws_msg **m) { +ws_msg_free(struct ws_msg *m) { - free((*m)->payload); - free(*m); - *m = NULL; + free(m->payload); + free(m); } +/* checks to see if we have a complete message */ static enum ws_state -ws_parse_data(const char *frame, size_t sz, struct ws_msg **msg) { +ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) { int has_mask; uint64_t len; const char *p; + char *frame; unsigned char mask[4]; + char fin_bit_set; + enum ws_frame_type frame_type; /* parse frame and extract contents */ + size_t sz = evbuffer_get_length(ws->rbuf); if(sz < 8) { - return WS_READING; + return WS_READING; /* need more data */ + } + /* copy into "frame" to process it */ + frame = malloc(sz); + if(!frame) { + return WS_ERROR; + } + int rem_ret = evbuffer_remove(ws->rbuf, frame, sz); + if(rem_ret < 0) { + free(frame); + return WS_ERROR; } + fin_bit_set = frame[0] & 0x80 ? 1 : 0; + frame_type = frame[0] & 0x0F; /* lower 4 bits of first byte */ has_mask = frame[1] & 0x80 ? 1:0; /* get payload length */ @@ -298,67 +409,113 @@ ws_parse_data(const char *frame, size_t sz, struct ws_msg **msg) { p = frame + 4 + (has_mask ? 4 : 0); if(has_mask) memcpy(&mask, frame + 4, sizeof(mask)); } else if(len == 127) { - len = webdis_ntohl64(frame+2); + uint64_t sz64 = *((uint64_t*)(frame+2)); + len = webdis_ntohll(sz64); p = frame + 10 + (has_mask ? 4 : 0); if(has_mask) memcpy(&mask, frame + 10, sizeof(mask)); } else { + free(frame); return WS_ERROR; } /* we now have the (possibly masked) data starting in p, and its length. */ if(len > sz - (p - frame)) { /* not enough data */ - return WS_READING; + int add_ret = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */ + free(frame); + return add_ret < 0 ? WS_ERROR : WS_READING; } - if(!*msg) - *msg = ws_msg_new(); - ws_msg_add(*msg, p, len, has_mask ? mask : NULL); - (*msg)->total_sz += len + (p - frame); + int ev_copy = 0; + if(out_msg) { /* we're extracting the message */ + struct ws_msg *msg = ws_msg_new(frame_type); + if(!msg) { + free(frame); + return WS_ERROR; + } + *out_msg = msg; /* attach for it to be freed by caller */ + + /* create new ws_msg object holding what we read */ + int add_ret = ws_msg_add(msg, p, len, has_mask ? mask : NULL); + if(!add_ret) { + free(frame); + return WS_ERROR; + } - if(frame[0] & 0x80) { /* FIN bit set */ + size_t processed_sz = len + (p - frame); /* length of data + header bytes between frame start and payload */ + msg->total_sz += processed_sz; + + ev_copy = evbuffer_prepend(ws->rbuf, frame + len, sz - processed_sz); /* remove processed data */ + } else { /* we're just peeking */ + ev_copy = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */ + } + free(frame); + + if(ev_copy < 0) { + return WS_ERROR; + } else if(fin_bit_set) { return WS_MSG_COMPLETE; } else { return WS_READING; /* need more data */ } } - /** * Process some data just received on the socket. + * Returns the number of messages processed in out_processed, if non-NULL. */ enum ws_state -ws_add_data(struct http_client *c) { +ws_process_read_data(struct ws_client *ws, unsigned int *out_processed) { enum ws_state state; + if(out_processed) *out_processed = 0; - state = ws_parse_data(c->buffer, c->sz, &c->frame); + state = ws_peek_data(ws, NULL); /* check for complete message */ while(state == WS_MSG_COMPLETE) { - int ret = ws_execute(c, c->frame->payload, c->frame->payload_sz); - - /* remove frame from client buffer */ - http_client_remove_data(c, c->frame->total_sz); + int ret = 0; + struct ws_msg *msg = NULL; + ws_peek_data(ws, &msg); /* extract message */ + + if(msg && (msg->type == WS_TEXT_FRAME || msg->type == WS_BINARY_FRAME)) { + ret = ws_execute(ws, msg); + if(out_processed) (*out_processed)++; + } else if(msg && msg->type == WS_PING) { /* respond to ping */ + ws_frame_and_send_response(ws, WS_PONG, msg->payload, msg->payload_sz); + } else if(msg && msg->type == WS_CONNECTION_CLOSE) { /* respond to close frame */ + ws->close_after_events = 1; + ws_frame_and_send_response(ws, WS_CONNECTION_CLOSE, msg->payload, msg->payload_sz); + } else if(msg) { + char format[] = "Received unexpected WS frame type: 0x%x"; + char error[(sizeof format)]; + int error_len = snprintf(error, sizeof(error), format, msg->type); + slog(ws->http_client->s, WEBDIS_WARNING, error, error_len); + } - /* free frame and set back to NULL */ - ws_msg_free(&c->frame); + /* free frame */ + if(msg) ws_msg_free(msg); if(ret != 0) { /* can't process frame. */ - slog(c->s, WEBDIS_WARNING, "ws_add_data: ws_execute failed", 0); + slog(ws->http_client->s, WEBDIS_DEBUG, "ws_process_read_data: ws_execute failed", 0); return WS_ERROR; } - state = ws_parse_data(c->buffer, c->sz, &c->frame); + state = ws_peek_data(ws, NULL); } return state; } int -ws_reply(struct cmd *cmd, const char *p, size_t sz) { - - char *frame = malloc(sz + 8); /* create frame by prepending header */ +ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type frame_type, const char *p, size_t sz) { + + /* we can have as much as 14 bytes in the header: + * 1 byte for 4 flag bits + 4 frame type bits + * 1 byte for the payload length indicator + * 8 bytes for the size of the payload (at most) + * 4 bytes for the masking key (if present) + */ + char *frame = malloc(sz + 14); /* create frame by prepending header */ size_t frame_sz = 0; - struct http_response *r; - if (frame == NULL) + if(frame == NULL) return -1; /* @@ -368,40 +525,119 @@ ws_reply(struct cmd *cmd, const char *p, size_t sz) { following 8 bytes interpreted as a 64-bit unsigned integer (the most significant bit MUST be 0) are the payload length. */ - frame[0] = '\x81'; + frame[0] = 0x80 | frame_type; /* frame type + EOM bit */ if(sz <= 125) { frame[1] = sz; memcpy(frame + 2, p, sz); frame_sz = sz + 2; - } else if (sz <= 65536) { + } else if(sz <= 65536) { uint16_t sz16 = htons(sz); frame[1] = 126; memcpy(frame + 2, &sz16, 2); memcpy(frame + 4, p, sz); frame_sz = sz + 4; } else { /* sz > 65536 */ - char sz64[8] = webdis_htonl64(sz); + uint64_t sz_be = webdis_htonll(sz); /* big endian */ + char sz64[8]; + memcpy(sz64, &sz_be, 8); frame[1] = 127; memcpy(frame + 2, sz64, 8); memcpy(frame + 10, p, sz); frame_sz = sz + 10; } - /* send WS frame */ - r = http_response_init(cmd->w, 0, NULL); - if (r == NULL) { - free(frame); - slog(cmd->w->s, WEBDIS_ERROR, "Failed response allocation in ws_reply", 0); + /* mark as keep alive, otherwise we'll close the connection after the first reply */ + int add_ret = evbuffer_add(ws->wbuf, frame, frame_sz); + free(frame); /* no longer needed once added to buffer */ + if(add_ret < 0) { + slog(ws->http_client->w->s, WEBDIS_ERROR, "Failed response allocation in ws_frame_and_send_response", 0); return -1; } - /* mark as keep alive, otherwise we'll close the connection after the first reply */ - r->keep_alive = 1; + /* send WS frame */ + return ws_schedule_write(ws); +} + +static void +ws_close_if_able(struct ws_client *ws) { + + ws->close_after_events = 1; /* note that we're closing */ + if(ws->scheduled_read || ws->scheduled_write) { + return; /* still waiting for these events to trigger */ + } + ws_client_free(ws); /* will close the socket */ +} - r->out = frame; - r->out_sz = frame_sz; - r->sent = 0; - http_schedule_write(cmd->fd, r); +static void +ws_can_read(int fd, short event, void *p) { + + int ret; + struct ws_client *ws = p; + (void)event; + + /* read pending data */ + ws->scheduled_read = 0; + ret = evbuffer_read(ws->rbuf, fd, 4096); + + if(ret <= 0) { + ws_client_free(ws); /* will close the socket */ + } else if(ws->close_after_events) { + ws_close_if_able(ws); + } else { + enum ws_state state = ws_process_read_data(ws, NULL); + if(state == WS_READING) { /* need more data, schedule new read */ + ws_monitor_input(ws); + } else if(state == WS_ERROR) { + ws_close_if_able(ws); + } + } +} + + +static void +ws_can_write(int fd, short event, void *p) { + + int ret; + struct ws_client *ws = p; + (void)event; + + ws->scheduled_write = 0; + /* send pending data */ + ret = evbuffer_write_atmost(ws->wbuf, fd, 4096); + + if(ret <= 0) { + ws_client_free(ws); /* will close the socket */ + } else { + if(evbuffer_get_length(ws->wbuf) > 0) { /* more data to send */ + ws_schedule_write(ws); + } else if(ws->close_after_events) { /* we're done! */ + ws_close_if_able(ws); + } else { + /* check if we can read more data */ + unsigned int processed = 0; + ws_process_read_data(ws, &processed); /* process any pending data we've already read */ + ws_monitor_input(ws); /* let's read more from the client */ + } + } +} + +static int +ws_schedule_write(struct ws_client *ws) { + struct http_client *c = ws->http_client; + if(!ws->scheduled_write) { + ws->scheduled_write = 1; + return event_base_once(c->w->base, c->fd, EV_WRITE, ws_can_write, ws, NULL); + } + return 0; +} + +int +ws_monitor_input(struct ws_client *ws) { + struct http_client *c = ws->http_client; + if(!ws->scheduled_read) { + ws->scheduled_read = 1; + return event_base_once(c->w->base, c->fd, EV_READ, ws_can_read, ws, NULL); + } return 0; } diff --git a/src/websocket.h b/src/websocket.h index dc9d764..8b176a4 100644 --- a/src/websocket.h +++ b/src/websocket.h @@ -3,6 +3,8 @@ #include #include +#include +#include struct http_client; struct cmd; @@ -12,19 +14,51 @@ enum ws_state { WS_READING, WS_MSG_COMPLETE}; +enum ws_frame_type { + WS_TEXT_FRAME = 0, + WS_BINARY_FRAME = 1, + WS_CONNECTION_CLOSE = 8, + WS_PING = 9, + WS_PONG = 0xA, + WS_UNKNOWN_FRAME = -1}; + struct ws_msg { + enum ws_frame_type type; char *payload; size_t payload_sz; size_t total_sz; }; +struct ws_client { + struct http_client *http_client; /* parent */ + int scheduled_read; /* set if we are scheduled to read WS data */ + int scheduled_write; /* set if we are scheduled to send out WS data */ + struct evbuffer *rbuf; /* read buffer for incoming data */ + struct evbuffer *wbuf; /* write buffer for outgoing data */ + redisAsyncContext *ac; /* dedicated connection to redis */ + struct cmd *cmd; /* current command */ + /* indicates that we'll close once we've flushed all + buffered data and read what we planned to read */ + int close_after_events; + int ran_subscribe; /* set if we've run a (p)subscribe command */ +}; + +struct ws_client * +ws_client_new(struct http_client *http_client); + +void +ws_client_free(struct ws_client *ws); + +int +ws_handshake_reply(struct ws_client *ws); + int -ws_handshake_reply(struct http_client *c); +ws_monitor_input(struct ws_client *ws); enum ws_state -ws_add_data(struct http_client *c); +ws_process_read_data(struct ws_client *ws, unsigned int *out_processed); int -ws_reply(struct cmd *cmd, const char *p, size_t sz); +ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type type, const char *p, size_t sz); #endif diff --git a/src/worker.c b/src/worker.c index d5decef..5d06cf6 100644 --- a/src/worker.c +++ b/src/worker.c @@ -13,7 +13,8 @@ #include #include #include - +#include +#include "formats/common.h" struct worker * worker_new(struct server *s) { @@ -53,10 +54,7 @@ worker_can_read(int fd, short event, void *p) { } } - if(c->is_websocket) { - /* Got websocket data */ - ws_add_data(c); - } else { + if(!c->is_websocket) { /* run parser */ nparsed = http_client_execute(c); @@ -67,13 +65,32 @@ worker_can_read(int fd, short event, void *p) { /* only close if requested *and* we've already read the request in full */ c->broken = 1; } else if(c->is_websocket) { - /* we need to use the remaining (unparsed) data as the body. */ - if(nparsed < ret) { - http_client_add_to_body(c, c->buffer + nparsed + 1, c->sz - nparsed - 1); - ws_handshake_reply(c); - } else { + + /* Got websocket data */ + c->ws = ws_client_new(c); + if(!c->ws) { c->broken = 1; + } else { + free(c->buffer); + c->buffer = NULL; + c->sz = 0; + + /* send response, and start managing fd from websocket.c */ + int reply_ret = ws_handshake_reply(c->ws); + if(reply_ret < 0) { + c->ws->http_client = NULL; /* detach to prevent double free */ + ws_client_free(c->ws); + c->broken = 1; + } else { + unsigned int processed = 0; + int process_ret = ws_process_read_data(c->ws, &processed); + if(process_ret == WS_ERROR) { + c->broken = 1; /* likely connection was closed */ + } + } } + + /* clean up what remains in HTTP client */ free(c->buffer); c->buffer = NULL; c->sz = 0; @@ -87,10 +104,14 @@ worker_can_read(int fd, short event, void *p) { } if(c->broken) { /* terminate client */ + if(c->is_websocket) { /* only close for WS since HTTP might use keep-alive */ + close(c->fd); + } http_client_free(c); - } else { - /* start monitoring input again */ - worker_monitor_input(c); + } else { /* start monitoring input again */ + if(!c->is_websocket) { /* all communication handled by WS code from now on */ + worker_monitor_input(c); + } } } @@ -99,7 +120,6 @@ worker_can_read(int fd, short event, void *p) { */ void worker_monitor_input(struct http_client *c) { - event_set(&c->ev, c->fd, EV_READ, worker_can_read, c); event_base_set(c->w->base, &c->ev); event_add(&c->ev, NULL); diff --git a/tests/limits.py b/tests/limits.py index 0e2c2d5..82b09c9 100755 --- a/tests/limits.py +++ b/tests/limits.py @@ -12,6 +12,9 @@ class BlockingSocket: self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.s.setblocking(True) self.s.connect((HOST, PORT)) + + def __del__(self): + self.s.close() def recv(self): out = b"" diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 0000000..66a496d --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1 @@ +websocket-client>=1.1.0 diff --git a/tests/websocket.c b/tests/websocket.c index aa501ca..e20f273 100644 --- a/tests/websocket.c +++ b/tests/websocket.c @@ -263,10 +263,13 @@ websocket_can_read(int fd, short event, void *ptr) { /* http parser will return the offset at which the upgraded protocol begins, which in our case is 1 under the total response size. */ - if (wt->state == WS_SENT_HANDSHAKE || /* haven't encountered end of response yet */ - (wt->parser.upgrade && nparsed != (int)avail_sz -1)) { - wt->debug("UPGRADE *and* we have some data left (nparsed=%d, avail_sz=%lu)\n", nparsed, avail_sz); - continue; + if (wt->state == WS_SENT_HANDSHAKE) { /* haven't encountered end of response yet */ + if (wt->parser.upgrade && nparsed != (int)avail_sz) { + wt->debug("UPGRADE *and* we have some data left (state=%d, nparsed=%d, avail_sz=%lu)\n", wt->state, nparsed, avail_sz); + continue; + } else { /* we just haven't read the entire response yet */ + wait_for_possible_read(wt); + } } else if (wt->state == WS_RECEIVED_HANDSHAKE) { /* we have the full response */ evbuffer_drain(wt->rbuffer, evbuffer_get_length(wt->rbuffer)); } @@ -474,6 +477,8 @@ progress_thread_main(void *ptr) { struct timespec ts_start; clock_gettime(CLOCK_MONOTONIC, &ts_start); + long start_nanos = ts_start.tv_sec * 1e9 + ts_start.tv_nsec; /* time of monitoring start */ + long last_print_nanos = start_nanos; while(1) { int sem_received = 0; @@ -501,7 +506,6 @@ progress_thread_main(void *ptr) { sem_received = 1; } #endif - // nanosleep(&ts, NULL); num_sleeps++; int total_sent = 0, total_received = 0, any_broken = 0, num_complete = 0; for (int i = 0; i < pt->worker_count; i++) { @@ -515,11 +519,15 @@ progress_thread_main(void *ptr) { } struct timespec ts_after_sleep; clock_gettime(CLOCK_MONOTONIC, &ts_after_sleep); + long after_sleep_nanos = ts_after_sleep.tv_sec * 1e9 + ts_after_sleep.tv_nsec; + long total_nanos = after_sleep_nanos - start_nanos; /* total time spent so far */ + fprintf(stderr, "After %0.2f sec: %'d messages sent, %'d received (%.02f%%). Instant rate: %'ld/sec, overall rate: %'ld/sec\n", ((float)((ts_after_sleep.tv_sec * 1e9 + ts_after_sleep.tv_nsec) - (ts_start.tv_sec * 1e9 + ts_start.tv_nsec))) / (float)1e9, total_sent, total_received, 100.0f * (float)total_received / (float)(pt->worker_count * pt->msg_target), - lroundf((float)(total_received - last_received) / pt->interval_sec), - lroundf((float)total_received / ((float)num_sleeps) * pt->interval_sec)); + lroundf((float)(total_received - last_received) / (((float)(after_sleep_nanos - last_print_nanos)) / 1e9f)), + lroundf((float)total_received / (((float)total_nanos) / 1e9f))); + last_print_nanos = after_sleep_nanos; /* time of last print */ if (sem_received || total_received == pt->msg_target * pt->worker_count || any_broken || num_complete == pt->worker_count) { break; diff --git a/tests/websocket.html b/tests/websocket.html index 0f06fc8..00cca12 100644 --- a/tests/websocket.html +++ b/tests/websocket.html @@ -101,6 +101,26 @@ function installBlock(title, type) { +
+
+
+
+
+
+
+
+
+ +
+
+
+
 
+
+
+
+
+
+
 
@@ -113,17 +133,19 @@ function installBlock(title, type) { class Client { - constructor(type, pingSerializer, getSerializer, setSerializer) { + constructor(type, serializer) { this.type = type; - this.pingSerializer = pingSerializer; - this.getSerializer = getSerializer; - this.setSerializer = setSerializer; + this.serializer = serializer; this.ws = null; + this.connected = false; + this.subscribed = false; + this.logCount = 0; $(`${this.type}-btn-connect`).addEventListener('click', event => { event.preventDefault(); console.log('Connecting...'); - this.ws = new WebSocket(`ws://${ host }:${ port }/.${ this.type }`); + this.ws = new WebSocket(`ws://${host}:${port}/.${this.type}`); + window.ws = this.ws; this.ws.onopen = event => { console.log('Connected'); this.setConnectedState(true); @@ -135,50 +157,83 @@ class Client { }; this.ws.onclose = event => { - $(`${this.type}-btn-connect`).disabled = false; + $(`${this.type}-btn-connect`).innerText = 'Connect'; this.setConnectedState(false); + this.subscribed = false; }; }); $(`${this.type}-btn-ping`).addEventListener('click', event => { event.preventDefault(); - const serialized = this.pingSerializer(); - this.log("sent", serialized); - this.ws.send(serialized); + const serialized = this.serializer(['PING']); + this.send(serialized); }); $(`${this.type}-btn-set`).addEventListener('click', event => { event.preventDefault(); - const serialized = this.setSerializer($(`${this.type}-set-key`).value, $(`${this.type}-set-value`).value); - this.log("sent", serialized); - this.ws.send(serialized); + const serialized = this.serializer(['SET', $(`${this.type}-set-key`).value, $(`${this.type}-set-value`).value]); + this.send(serialized); }); $(`${this.type}-btn-get`).addEventListener('click', event => { event.preventDefault(); - const serialized = this.getSerializer($(`${this.type}-set-key`).value); - this.log("sent", serialized); - this.ws.send(serialized); + const serialized = this.serializer(['GET', $(`${this.type}-get-key`).value]); + this.send(serialized); + }); + + $(`${this.type}-btn-pub`).addEventListener('click', event => { + event.preventDefault(); + const serialized = this.serializer(['PUBLISH', $(`${this.type}-pub-channel`).value, $(`${this.type}-pub-message`).value]); + this.send(serialized); + }); + + $(`${this.type}-btn-sub`).addEventListener('click', event => { + event.preventDefault(); + const serialized = this.serializer(['SUBSCRIBE', $(`${this.type}-sub-channel`).value]); + try { + this.send(serialized); + this.subscribed = true; + this.setConnectedState(true); + } catch (e) { + console.log('Error sending: ', serialized, e); + } }); $(`${this.type}-btn-clear`).addEventListener('click', event => { event.preventDefault(); $(`${this.type}-log`).innerText = ""; + this.logCount = 0; + this.updateLogButton(); }); } + send(serialized) { + this.log("sent", serialized); + this.ws.send(serialized); + } + setConnectedState(connected) { + this.connected = connected; $(`${this.type}-btn-connect`).disabled = connected; - $(`${this.type}-btn-ping`).disabled = !connected; - $(`${this.type}-set-key`).disabled = !connected; - $(`${this.type}-set-value`).disabled = !connected; - $(`${this.type}-btn-set`).disabled = !connected; - $(`${this.type}-get-key`).disabled = !connected; - $(`${this.type}-btn-get`).disabled = !connected; - $(`${this.type}-btn-clear`).disabled = !connected; + $(`${this.type}-btn-ping`).disabled = !connected || this.subscribed; + $(`${this.type}-set-key`).disabled = !connected || this.subscribed; + $(`${this.type}-set-value`).disabled = !connected || this.subscribed; + $(`${this.type}-btn-set`).disabled = !connected || this.subscribed; + $(`${this.type}-get-key`).disabled = !connected || this.subscribed; + $(`${this.type}-btn-get`).disabled = !connected || this.subscribed; + $(`${this.type}-pub-channel`).disabled = !connected || this.subscribed; + $(`${this.type}-pub-message`).disabled = !connected || this.subscribed; + $(`${this.type}-btn-pub`).disabled = !connected || this.subscribed; + $(`${this.type}-sub-channel`).disabled = !connected || this.subscribed; + $(`${this.type}-btn-sub`).disabled = !connected || this.subscribed; + $(`${this.type}-state`).innerText = `State: ${connected ? 'Connected' : 'Disconnected'}`; } + updateLogButton() { + $(`${this.type}-btn-clear`).disabled = this.logCount === 0; + } + log(dir, msg) { const id = `${this.type}-log`; @@ -190,22 +245,26 @@ class Client { contents.setAttribute("class", dir); contents.innerHTML = msg; $(id).appendChild(contents); + this.logCount++; + this.updateLogButton(); } } +function serializeRaw(args) { + let raw = `*${args.length}\r\n`; + for (let i = 0; i < args.length; i++) { + raw += `$${args[i].length}\r\n`; + raw += `${args[i]}\r\n`; + } + return raw; +} + addEventListener("DOMContentLoaded", () => { installBlock('JSON', 'json'); installBlock('Raw', 'raw'); - const jsonClient = new Client('json', - () => JSON.stringify(['PING']), - (key) => JSON.stringify(['GET', key]), - (key, value) => JSON.stringify(['SET', key, value])); - - const rawClient = new Client('raw', - () => '*1\r\n$4\r\nPING\r\n', - (key) => `*2\r\n$3\r\nGET\r\n$${key.length}\r\n${key}\r\n`, - (key, value) => `*3\r\n$3\r\nSET\r\n$${key.length}\r\n${key}\r\n$${value.length}\r\n${value}\r\n`); + const jsonClient = new Client('json', JSON.stringify); + const rawClient = new Client('raw', serializeRaw); }); diff --git a/tests/ws-tests.py b/tests/ws-tests.py new file mode 100755 index 0000000..cb038b2 --- /dev/null +++ b/tests/ws-tests.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 + +import abc +import json +import os +import unittest +import uuid +from websocket import create_connection + +host = os.getenv('WEBDIS_HOST', '127.0.0.1') +port = int(os.getenv('WEBDIS_PORT', 7379)) + + +def connect(format): + return create_connection(f'ws://{host}:{port}/.{format}') + + +class TestWebdis(unittest.TestCase): + def setUp(self) -> None: + self.ws = connect(self.format()) + + def tearDown(self) -> None: + self.ws.close() + + def exec(self, cmd, *args): + self.ws.send(self.serialize(cmd, *args)) + return self.deserialize(self.ws.recv()) + + def clean_key(self): + """Returns a key that was just deleted""" + key = str(uuid.uuid4()) + self.exec('DEL', key) + return key + + @abc.abstractmethod + def format(self): + """Returns the format to use (added after a dot to the WS URI)""" + return + + @abc.abstractmethod + def serialize(self, cmd): + """Serializes a command according to the format being tested""" + return + + @abc.abstractmethod + def deserialize(self, response): + """Deserializes a response according to the format being tested""" + return + + +class TestJson(TestWebdis): + def format(self): + return 'json' + + def serialize(self, cmd, *args): + return json.dumps([cmd] + list(args)) + + def deserialize(self, response): + return json.loads(response) + + def test_ping(self): + self.assertEqual(self.exec('PING'), {'PING': [True, 'PONG']}) + + def test_multiple_messages(self): + key = self.clean_key() + n = 100 + for i in range(n): + lpush_response = self.exec('LPUSH', key, f'value-{i}') + self.assertEqual(lpush_response, {'LPUSH': i + 1}) + self.assertEqual(self.exec('LLEN', key), {'LLEN': n}) + + +class TestRaw(TestWebdis): + def format(self): + return 'raw' + + def serialize(self, cmd, *args): + buffer = f"*{1 + len(args)}\r\n${len(cmd)}\r\n{cmd}\r\n" + for arg in args: + buffer += f"${len(arg)}\r\n{arg}\r\n" + return buffer + + def deserialize(self, response): + return response # we'll just assert using the raw protocol + + def test_ping(self): + self.assertEqual(self.exec('PING'), "+PONG\r\n") + + def test_get_set(self): + key = self.clean_key() + value = str(uuid.uuid4()) + not_found_response = self.exec('GET', key) + self.assertEqual(not_found_response, "$-1\r\n") # Redis protocol response for "not found" + set_response = self.exec('SET', key, value) + self.assertEqual(set_response, "+OK\r\n") + get_response = self.exec('GET', key) + self.assertEqual(get_response, f"${len(value)}\r\n{value}\r\n") + + +class TestPubSub(unittest.TestCase): + def setUp(self): + self.publisher = connect('json') + self.subscriber = connect('json') + + def tearDown(self): + self.publisher.close() + self.subscriber.close() + + def serialize(self, cmd, *args): + return json.dumps([cmd] + list(args)) + + def deserialize(self, response): + return json.loads(response) + + def test_publish_subscribe(self): + channel_count = 2 + message_count_per_channel = 8 + channels = list(str(uuid.uuid4()) for i in range(channel_count)) + + # subscribe to all channels + sub_count = 0 + for channel in channels: + self.subscriber.send(self.serialize('SUBSCRIBE', channel)) + sub_response = self.deserialize(self.subscriber.recv()) + sub_count += 1 + self.assertEqual(sub_response, {'SUBSCRIBE': ['subscribe', channel, sub_count]}) + + # send messages to all channels + prefix = 'message-' + for i in range(message_count_per_channel): + for channel in channels: + message = f'{prefix}{i}' + self.publisher.send(self.serialize('PUBLISH', channel, message)) + self.deserialize(self.publisher.recv()) + + received_per_channel = dict((channel, []) for channel in channels) + for j in range(channel_count * message_count_per_channel): + received = self.deserialize(self.subscriber.recv()) + # expected: {'SUBSCRIBE': ['message', $channel, $message]} + self.assertTrue(received, 'SUBSCRIBE' in received) + sub_contents = received['SUBSCRIBE'] + self.assertEqual(len(sub_contents), 3) + + self.assertEqual(sub_contents[0], 'message') # first element is the message type, here a push + channel = sub_contents[1] + self.assertTrue(channel in channels) # second is the channel + received_per_channel[channel].append( + sub_contents[2]) # third, add to list of messages received for this channel + + # unsubscribe from all channels + subs_remaining = channel_count + for channel in channels: + self.subscriber.send(self.serialize('UNSUBSCRIBE', channel)) + subs_remaining -= 1 + unsub_response = self.deserialize(self.subscriber.recv()) + self.assertEqual(unsub_response, {'UNSUBSCRIBE': ['unsubscribe', channel, subs_remaining]}) + + # check that we received all messages + for channel in channels: + self.assertEqual(len(received_per_channel[channel]), message_count_per_channel) + + # check that we received them *in order* + for i in range(message_count_per_channel): + for channel in channels: + expected = f'{prefix}{i}' + self.assertEqual(received_per_channel[channel][i], expected, + f'In {channel}: expected at offset {i} was "{expected}", actual was: "{received_per_channel[channel][i]}"') + + +class TestFrameSizes(TestWebdis): + def format(self): + return 'json' + + def serialize(self, cmd, *args): + return json.dumps([cmd] + list(args)) + + def deserialize(self, response): + return json.loads(response) + + def test_length_126(self): + self.validate_set_get('A' * 1024) # this will require 2 bytes to encode the length + + def test_length_127(self): + self.validate_set_get('A' * (2 ** 18)) # this will require more than 2 bytes to encode the length (actually using 8) + + def validate_set_get(self, value): + key = str(uuid.uuid4()) + self.assertEqual(self.exec('SET', key, value), {'SET': [True, 'OK']}) + self.assertEqual(self.exec('GET', key), {'GET': value}) + self.exec('DEL', key) + + +if __name__ == '__main__': + unittest.main()