Large refactoring of WS code

1. Introduce ws_client struct
2. Handle all communications from websocket.c for WS clients
3. Always use a dedicated Redis connection for WS clients
4. Add rbuf & wbuf evbuffers for incoming & outgoing WS data
5. Use event_base_once to control R/W events
6. WS test: make sure to read complete HTTP response
master
Jessie Murray 3 years ago
parent e26d6358e7
commit 6b090b4ede
No known key found for this signature in database
GPG Key ID: E7E4D57EDDA744C5

@ -63,10 +63,7 @@ struct http_client {
struct cmd *self_cmd;
struct ws_msg *frame; /* websocket frame (containing *received* data) */
struct event ws_wev; /* websocket write event */
struct evbuffer *ws_wbuf; /* write buffer for websocket responses */
int ws_scheduled_write; /* whether we are already scheduled to send out WS data */
struct ws_client *ws; /* websocket client */
};
struct http_client *

@ -39,7 +39,6 @@ void
cmd_free_argv(struct cmd *c) {
int i;
fprintf(stderr, "%s: %p\n", __func__, c);
for(i = 0; i < c->count; ++i) {
free((char*)c->argv[i]);
}

@ -7,6 +7,8 @@
#include "md5/md5.h"
#include <string.h>
#include <unistd.h>
#include <pthread.h>
#include <ctype.h>
/* TODO: replace this with a faster hash function? */
char *etag_new(const char *p, size_t sz) {
@ -62,7 +64,8 @@ format_send_reply(struct cmd *cmd, const char *p, size_t sz, const char *content
struct http_response *resp;
if(cmd->is_websocket) {
ws_frame_and_send_response(cmd, p, sz);
ws_frame_and_send_response(cmd->http_client->ws, WS_BINARY_FRAME, p, sz);
/* If it's a subscribe command, there'll be more responses */
if(!cmd_is_subscribe(cmd))

@ -2,6 +2,7 @@
#include "worker.h"
#include "conf.h"
#include "server.h"
#include "formats/common.h"
#include <stdlib.h>
#include <string.h>
@ -30,7 +31,6 @@ pool_free_context(redisAsyncContext *ac) {
if (ac) {
redisAsyncDisconnect(ac);
redisAsyncFree(ac);
}
}

@ -7,6 +7,9 @@
#include "pool.h"
#include "http.h"
#include "slog.h"
#include "server.h"
#include "conf.h"
#include "formats/common.h"
/* message parsers */
#include "formats/json.h"
@ -18,9 +21,10 @@
#include <unistd.h>
#include <errno.h>
#include <sys/param.h>
#include <inttypes.h>
static void
ws_schedule_write(struct http_client *c);
ws_schedule_write(struct ws_client *ws);
/**
* This code uses the WebSocket specification from RFC 6455.
@ -89,9 +93,54 @@ ws_compute_handshake(struct http_client *c, char *out, size_t *out_sz) {
return 0;
}
struct ws_client *
ws_client_new(struct http_client *http_client) {
int db_num = http_client->w->s->cfg->database;
struct ws_client *ws = calloc(1, sizeof(struct ws_client));
struct evbuffer *rbuf = evbuffer_new();
struct evbuffer *wbuf = evbuffer_new();
redisAsyncContext *ac = pool_connect(http_client->w->pool, db_num, 0);
if(!ws || !rbuf || !wbuf) {
slog(http_client->s, WEBDIS_ERROR, "Failed to allocate memory for WS client", 0);
if(ws) free(ws);
if(rbuf) evbuffer_free(rbuf);
if(wbuf) evbuffer_free(wbuf);
if(ac) redisAsyncFree(ac);
return NULL;
}
http_client->ws = ws;
ws->http_client = http_client;
ws->rbuf = rbuf;
ws->wbuf = wbuf;
ws->ac = ac;
return ws;
}
static void
ws_client_free(struct ws_client *ws) {
struct http_client *c = ws->http_client;
c->ws = NULL; /* detach */
evbuffer_free(ws->rbuf);
evbuffer_free(ws->wbuf);
pool_free_context(ws->ac);
if(ws->cmd) {
ws->cmd->ac = NULL; /* we've just free'd it */
cmd_free(ws->cmd);
}
free(ws);
http_client_free(c);
}
int
ws_handshake_reply(struct http_client *c) {
ws_handshake_reply(struct ws_client *ws) {
struct http_client *c = ws->http_client;
char sha1_handshake[40];
char *buffer = NULL, *p;
const char *origin = NULL, *host = NULL;
@ -174,30 +223,23 @@ ws_handshake_reply(struct http_client *c) {
memcpy(p, template_end, sizeof(template_end)-1);
p += sizeof(template_end)-1;
/* create buffer that will hold data to send out */
c->ws_wbuf = evbuffer_new();
if(!c->ws_wbuf) {
slog(c->s, WEBDIS_ERROR, "Failed to allocate response for WS handshake", 0);
int add_ret = evbuffer_add(ws->wbuf, buffer, sz);
free(buffer);
return -1;
}
int add_ret = evbuffer_add(c->ws_wbuf, buffer, sz);
if(add_ret < 0) {
slog(c->s, WEBDIS_ERROR, "Failed to add response for WS handshake", 0);
free(buffer);
return -1;
}
ws_schedule_write(c); /* will free buffer and response once sent */
ws_schedule_write(ws); /* will free buffer and response once sent */
return 0;
}
static int
ws_execute(struct http_client *c, const char *frame, size_t frame_len) {
ws_execute(struct ws_client *ws, struct ws_msg *msg) {
struct http_client *c = ws->http_client;
struct cmd*(*fun_extract)(struct http_client *, const char *, size_t) = NULL;
formatting_fun fun_reply = NULL;
@ -213,35 +255,36 @@ ws_execute(struct http_client *c, const char *frame, size_t frame_len) {
if(fun_extract) {
/* Parse websocket frame into a cmd object. */
struct cmd *cmd = fun_extract(c, frame, frame_len);
struct cmd *cmd = fun_extract(c, msg->payload, msg->payload_sz);
if(cmd) {
cmd->is_websocket = 1;
if (c->self_cmd != NULL) {
/* This client already has its own connection
* to Redis from a previous command; use it from
* now on. */
if(ws->cmd != NULL) {
/* This client already has its own connection to Redis
from a previous command; use it from now on. */
/* free args for the previous cmd */
cmd_free_argv(c->self_cmd);
cmd_free_argv(ws->cmd);
/* copy args from what we just parsed to the persistent command */
c->self_cmd->count = cmd->count;
c->self_cmd->argv = cmd->argv;
c->self_cmd->argv_len = cmd->argv_len;
ws->cmd->count = cmd->count;
ws->cmd->argv = cmd->argv;
ws->cmd->argv_len = cmd->argv_len;
ws->cmd->pub_sub_client = c; /* mark as persistent, otherwise the Redis context will be freed */
cmd->argv = NULL;
cmd->argv_len = NULL;
cmd->count = 0;
cmd_free(cmd);
cmd = c->self_cmd; /* replace pointer since we're about to pass it to cmd_send */
cmd = ws->cmd; /* replace pointer since we're about to pass it to cmd_send */
} else {
/* copy client info into cmd. */
cmd_setup(cmd, c);
/* First WS command; make new Redis context
* for this client */
cmd->ac = pool_connect(c->w->pool, cmd->database, 0);
c->self_cmd = cmd;
/* First WS command; use Redis context from WS client. */
cmd->ac = ws->ac;
ws->cmd = cmd;
cmd->pub_sub_client = c;
}
@ -256,8 +299,13 @@ ws_execute(struct http_client *c, const char *frame, size_t frame_len) {
}
static struct ws_msg *
ws_msg_new() {
return calloc(1, sizeof(struct ws_msg));
ws_msg_new(enum ws_frame_type frame_type) {
struct ws_msg *msg = calloc(1, sizeof(struct ws_msg));
if(!msg) {
return NULL;
}
msg->type = frame_type;
return msg;
}
static void
@ -278,26 +326,38 @@ ws_msg_add(struct ws_msg *m, const char *p, size_t psz, const unsigned char *mas
}
static void
ws_msg_free(struct ws_msg **m) {
ws_msg_free(struct ws_msg *m) {
free((*m)->payload);
free(*m);
*m = NULL;
free(m->payload);
free(m);
}
/* checks to see if we have a complete message */
static enum ws_state
ws_parse_data(const char *frame, size_t sz, struct ws_msg **msg) {
ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) {
int has_mask;
uint64_t len;
const char *p;
char *frame;
unsigned char mask[4];
char fin_bit_set;
enum ws_frame_type frame_type;
/* parse frame and extract contents */
size_t sz = evbuffer_get_length(ws->rbuf);
if(sz < 8) {
return WS_READING;
return WS_READING; /* need more data */
}
/* copy into "frame" to process it */
frame = malloc(sz);
if(!frame) {
return WS_ERROR;
}
evbuffer_remove(ws->rbuf, frame, sz);
fin_bit_set = frame[0] & 0x80 ? 1 : 0;
frame_type = frame[0] & 0x0F; /* lower 4 bits of first byte */
has_mask = frame[1] & 0x80 ? 1:0;
/* get payload length */
@ -316,58 +376,84 @@ ws_parse_data(const char *frame, size_t sz, struct ws_msg **msg) {
p = frame + 10 + (has_mask ? 4 : 0);
if(has_mask) memcpy(&mask, frame + 10, sizeof(mask));
} else {
free(frame);
return WS_ERROR;
}
/* we now have the (possibly masked) data starting in p, and its length. */
if(len > sz - (p - frame)) { /* not enough data */
evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */
free(frame);
return WS_READING;
}
if(!*msg)
*msg = ws_msg_new();
ws_msg_add(*msg, p, len, has_mask ? mask : NULL);
(*msg)->total_sz += len + (p - frame);
if(out_msg) { /* we're extracting the message */
struct ws_msg *msg = ws_msg_new(frame_type);
ws_msg_add(msg, p, len, has_mask ? mask : NULL);
size_t processed_sz = len + (p - frame); /* length of data + header bytes between frame start and payload */
msg->total_sz += processed_sz;
*out_msg = msg;
evbuffer_prepend(ws->rbuf, frame + len, sz - processed_sz); /* remove processed data */
} else { /* we're just peeking */
evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */
}
free(frame);
if(frame[0] & 0x80) { /* FIN bit set */
if(fin_bit_set) { /* FIN bit set */
return WS_MSG_COMPLETE;
} else {
return WS_READING; /* need more data */
}
}
/**
* Process some data just received on the socket.
* Returns the number of messages processed in out_processed, if non-NULL.
*/
enum ws_state
ws_add_data(struct http_client *c) {
ws_process_read_data(struct ws_client *ws, unsigned int *out_processed) {
enum ws_state state;
if(out_processed) *out_processed = 0;
state = ws_parse_data(c->buffer, c->sz, &c->frame);
state = ws_peek_data(ws, NULL); /* check for complete message */
while(state == WS_MSG_COMPLETE) {
int ret = ws_execute(c, c->frame->payload, c->frame->payload_sz);
/* remove frame from client buffer */
http_client_remove_data(c, c->frame->total_sz);
int ret = 0;
struct ws_msg *msg;
ws_peek_data(ws, &msg); /* extract message */
if(msg->type == WS_TEXT_FRAME || msg->type == WS_BINARY_FRAME) {
ret = ws_execute(ws, msg);
if(out_processed) (*out_processed)++;
} else if(msg->type == WS_PING) { /* respond to ping */
ws_frame_and_send_response(ws, WS_PONG, msg->payload, msg->payload_sz);
} else if(msg->type == WS_CONNECTION_CLOSE) { /* respond to close frame */
ws->close_after_events = 1;
ws_frame_and_send_response(ws, WS_CONNECTION_CLOSE, msg->payload, msg->payload_sz);
} else {
char format[] = "Received unexpected WS frame type: 0x%x";
char error[(sizeof format)];
snprintf(error, sizeof(error), format, msg->type);
slog(ws->http_client->s, WEBDIS_WARNING, error, 0);
}
/* free frame and set back to NULL */
ws_msg_free(&c->frame);
/* free frame */
ws_msg_free(msg);
if(ret != 0) {
/* can't process frame. */
slog(c->s, WEBDIS_DEBUG, "ws_add_data: ws_execute failed", 0);
slog(ws->http_client->s, WEBDIS_DEBUG, "ws_process_read_data: ws_execute failed", 0);
return WS_ERROR;
}
state = ws_parse_data(c->buffer, c->sz, &c->frame);
state = ws_peek_data(ws, NULL);
}
return state;
}
int
ws_frame_and_send_response(struct cmd *cmd, const char *p, size_t sz) {
ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type frame_type, const char *p, size_t sz) {
char *frame = malloc(sz + 8); /* create frame by prepending header */
size_t frame_sz = 0;
@ -381,7 +467,7 @@ ws_frame_and_send_response(struct cmd *cmd, const char *p, size_t sz) {
following 8 bytes interpreted as a 64-bit unsigned integer (the
most significant bit MUST be 0) are the payload length.
*/
frame[0] = '\x81';
frame[0] = 0x80 | frame_type; /* frame type + EOM bit */
if(sz <= 125) {
frame[1] = sz;
memcpy(frame + 2, p, sz);
@ -401,50 +487,90 @@ ws_frame_and_send_response(struct cmd *cmd, const char *p, size_t sz) {
}
/* mark as keep alive, otherwise we'll close the connection after the first reply */
int add_ret = evbuffer_add(cmd->http_client->ws_wbuf, frame, frame_sz);
int add_ret = evbuffer_add(ws->wbuf, frame, frame_sz);
free(frame); /* no longer needed once added to buffer */
if(add_ret < 0) {
slog(cmd->w->s, WEBDIS_ERROR, "Failed response allocation in ws_frame_and_send_response", 0);
slog(ws->http_client->w->s, WEBDIS_ERROR, "Failed response allocation in ws_frame_and_send_response", 0);
return -1;
}
/* send WS frame */
ws_schedule_write(cmd->http_client);
ws_schedule_write(ws);
return 0;
}
static void
ws_close_if_able(struct ws_client *ws) {
if(ws->scheduled_read || ws->scheduled_write) {
return; /* still waiting for these events to trigger */
}
ws_client_free(ws); /* will close the socket */
}
static void
ws_can_read(int fd, short event, void *p) {
int ret;
struct ws_client *ws = p;
(void)event;
/* read pending data */
ws->scheduled_read = 0;
ret = evbuffer_read(ws->rbuf, fd, 4096);
if(ret <= 0) {
ws_client_free(ws); /* will close the socket */
} else if(ws->close_after_events) {
ws_close_if_able(ws);
} else {
ws_process_read_data(ws, NULL);
}
}
static void
ws_can_write(int fd, short event, void *p) {
int ret;
struct http_client *c = p;
struct ws_client *ws = p;
(void)event;
c->ws_scheduled_write = 0;
ws->scheduled_write = 0;
/* send pending data */
ret = evbuffer_write(c->ws_wbuf, fd);
if(ret < 0) {
close(fd);
} else if(ret > 0 && evbuffer_get_length(c->ws_wbuf) > 0) { /* more data to send */
ws_schedule_write(c);
ret = evbuffer_write_atmost(ws->wbuf, fd, 4096);
if(ret <= 0) {
ws_client_free(ws); /* will close the socket */
} else if(ret > 0) {
if(evbuffer_get_length(ws->wbuf) > 0) { /* more data to send */
ws_schedule_write(ws);
} else if(ws->close_after_events) { /* we're done! */
ws_close_if_able(ws);
} else {
/* check if we can read more data */
unsigned int processed = 0;
ws_process_read_data(ws, &processed); /* process any pending data we've already read */
ws_monitor_input(ws); /* let's read more from the client */
}
}
}
static void
ws_schedule_write(struct http_client *c) {
if(c->ws_scheduled_write) {
return;
}
event_set(&c->ws_wev, c->fd, EV_WRITE, ws_can_write, c);
event_base_set(c->w->base, &c->ws_wev);
int ret = event_add(&c->ws_wev, NULL);
if(ret == 0) {
c->ws_scheduled_write = 1;
} else { /* could not schedule write */
slog(c->w->s, WEBDIS_ERROR, "Could not schedule WS write", 0);
ws_schedule_write(struct ws_client *ws) {
struct http_client *c = ws->http_client;
if(!ws->scheduled_write) {
ws->scheduled_write = 1;
event_base_once(c->w->base, c->fd, EV_WRITE, ws_can_write, ws, NULL);
}
}
void
ws_monitor_input(struct ws_client *ws) {
struct http_client *c = ws->http_client;
if(!ws->scheduled_read) {
ws->scheduled_read = 1;
event_base_once(c->w->base, c->fd, EV_READ, ws_can_read, ws, NULL);
}
}

@ -3,6 +3,8 @@
#include <stdlib.h>
#include <stdint.h>
#include <event.h>
#include <hiredis/async.h>
struct http_client;
struct cmd;
@ -12,19 +14,47 @@ enum ws_state {
WS_READING,
WS_MSG_COMPLETE};
enum ws_frame_type {
WS_TEXT_FRAME = 0,
WS_BINARY_FRAME = 1,
WS_CONNECTION_CLOSE = 8,
WS_PING = 9,
WS_PONG = 0xA,
WS_UNKNOWN_FRAME = -1};
struct ws_msg {
enum ws_frame_type type;
char *payload;
size_t payload_sz;
size_t total_sz;
};
struct ws_client {
struct http_client *http_client; /* parent */
int scheduled_read; /* set if we are scheduled to read WS data */
int scheduled_write; /* set if we are scheduled to send out WS data */
struct evbuffer *rbuf; /* read buffer for incoming data */
struct evbuffer *wbuf; /* write buffer for outgoing data */
redisAsyncContext *ac; /* dedicated connection to redis */
struct cmd *cmd; /* current command */
/* indicates that we'll close once we've flushed all
buffered data and read what we planned to read */
int close_after_events;
};
struct ws_client *
ws_client_new(struct http_client *http_client);
int
ws_handshake_reply(struct http_client *c);
ws_handshake_reply(struct ws_client *ws);
void
ws_monitor_input(struct ws_client *ws);
enum ws_state
ws_add_data(struct http_client *c);
ws_process_read_data(struct ws_client *ws, unsigned int *out_processed);
int
ws_frame_and_send_response(struct cmd *cmd, const char *p, size_t sz);
ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type type, const char *p, size_t sz);
#endif

@ -13,7 +13,8 @@
#include <unistd.h>
#include <event.h>
#include <string.h>
#include <netinet/tcp.h>
#include "formats/common.h"
struct worker *
worker_new(struct server *s) {
@ -54,11 +55,6 @@ worker_can_read(int fd, short event, void *p) {
}
if(c->is_websocket) {
/* Got websocket data */
int add_ret = ws_add_data(c);
if(add_ret == WS_ERROR) {
c->broken = 1; /* likely connection was closed */
}
} else {
/* run parser */
nparsed = http_client_execute(c);
@ -70,13 +66,28 @@ worker_can_read(int fd, short event, void *p) {
/* only close if requested *and* we've already read the request in full */
c->broken = 1;
} else if(c->is_websocket) {
/* we need to use the remaining (unparsed) data as the body. */
if(nparsed < ret) {
http_client_add_to_body(c, c->buffer + nparsed + 1, c->sz - nparsed - 1);
ws_handshake_reply(c);
} else {
event_del(&c->ev);
/* Got websocket data */
c->ws = ws_client_new(c);
if(!c->ws) {
c->broken = 1;
} else {
free(c->buffer);
c->buffer = NULL;
c->sz = 0;
unsigned int processed = 0;
int process_ret = ws_process_read_data(c->ws, &processed);
if(process_ret == WS_ERROR) {
c->broken = 1; /* likely connection was closed */
}
/* send response, and start managing fd from websocket.c */
ws_handshake_reply(c->ws);
}
/* clean up what remains in HTTP client */
free(c->buffer);
c->buffer = NULL;
c->sz = 0;
@ -96,16 +107,19 @@ worker_can_read(int fd, short event, void *p) {
http_client_free(c);
} else {
/* start monitoring input again */
if(c->is_websocket) { /* all communication handled by WS code from now on */
// ws_monitor_input(c->ws);
} else {
worker_monitor_input(c);
}
}
}
/**
* Monitor client FD for possible reads.
*/
void
worker_monitor_input(struct http_client *c) {
event_set(&c->ev, c->fd, EV_READ, worker_can_read, c);
event_base_set(c->w->base, &c->ev);
event_add(&c->ev, NULL);

@ -263,10 +263,13 @@ websocket_can_read(int fd, short event, void *ptr) {
/* http parser will return the offset at which the upgraded protocol begins,
which in our case is 1 under the total response size. */
if (wt->state == WS_SENT_HANDSHAKE || /* haven't encountered end of response yet */
(wt->parser.upgrade && nparsed != (int)avail_sz -1)) {
wt->debug("UPGRADE *and* we have some data left (nparsed=%d, avail_sz=%lu)\n", nparsed, avail_sz);
if (wt->state == WS_SENT_HANDSHAKE) { /* haven't encountered end of response yet */
if (wt->parser.upgrade && nparsed != (int)avail_sz) {
wt->debug("UPGRADE *and* we have some data left (state=%d, nparsed=%d, avail_sz=%lu)\n", wt->state, nparsed, avail_sz);
continue;
} else { /* we just haven't read the entire response yet */
wait_for_possible_read(wt);
}
} else if (wt->state == WS_RECEIVED_HANDSHAKE) { /* we have the full response */
evbuffer_drain(wt->rbuffer, evbuffer_get_length(wt->rbuffer));
}

@ -117,9 +117,9 @@ class TestPubSub(unittest.TestCase):
sub_count = 0
for channel in channels:
self.subscriber.send(self.serialize('SUBSCRIBE', channel))
unsub_response = self.deserialize(self.subscriber.recv())
sub_response = self.deserialize(self.subscriber.recv())
sub_count += 1
self.assertEqual(unsub_response, {'SUBSCRIBE': ['subscribe', channel, sub_count]})
self.assertEqual(sub_response, {'SUBSCRIBE': ['subscribe', channel, sub_count]})
# send messages to all channels
prefix = 'message-'
@ -127,11 +127,11 @@ class TestPubSub(unittest.TestCase):
for channel in channels:
message = f'{prefix}{i}'
self.publisher.send(self.serialize('PUBLISH', channel, message))
self.deserialize(self.publisher.recv())
received_per_channel = dict((channel, []) for channel in channels)
for j in range(channel_count * message_count_per_channel):
received = self.deserialize(self.subscriber.recv())
print('received:', received)
# expected: {'SUBSCRIBE': ['message', $channel, $message]}
self.assertTrue(received, 'SUBSCRIBE' in received)
sub_contents = received['SUBSCRIBE']
@ -148,7 +148,7 @@ class TestPubSub(unittest.TestCase):
self.subscriber.send(self.serialize('UNSUBSCRIBE', channel))
subs_remaining -= 1
unsub_response = self.deserialize(self.subscriber.recv())
self.assertEqual(unsub_response, {'SUBSCRIBE': ['unsubscribe', channel, subs_remaining]})
self.assertEqual(unsub_response, {'UNSUBSCRIBE': ['unsubscribe', channel, subs_remaining]})
# check that we received all messages
for channel in channels:

Loading…
Cancel
Save