diff --git a/src/formats/common.c b/src/formats/common.c index 50e4780..5304b09 100644 --- a/src/formats/common.c +++ b/src/formats/common.c @@ -50,11 +50,12 @@ format_send_error(struct cmd *cmd, short code, const char *msg) { ws_frame_and_send_response(cmd->http_client->ws, WS_BINARY_FRAME, msg, strlen(msg)); } - /* for pub/sub, remove command from client */ - if(cmd->pub_sub_client) { - cmd->pub_sub_client->self_cmd = NULL; - } else if (!cmd->is_websocket) { /* don't free persistent cmd */ - cmd_free(cmd); + if (!cmd->is_websocket) { /* don't free or detach persistent cmd */ + if (cmd->pub_sub_client) { /* for pub/sub, remove command from client */ + cmd->pub_sub_client->self_cmd = NULL; + } else { + cmd_free(cmd); + } } } diff --git a/src/pool.c b/src/pool.c index 741c7da..f3e7b42 100644 --- a/src/pool.c +++ b/src/pool.c @@ -2,7 +2,6 @@ #include "worker.h" #include "conf.h" #include "server.h" -#include "formats/common.h" #include #include diff --git a/src/slog.c b/src/slog.c index c48489d..bdb600d 100644 --- a/src/slog.c +++ b/src/slog.c @@ -13,6 +13,10 @@ #include "server.h" #include "conf.h" +#if SLOG_MSG_MAX_LEN < 64 +#error "SLOG_MSG_MAX_LEN must be at least 64" +#endif + /** * Initialize log writer. */ diff --git a/src/websocket.c b/src/websocket.c index 27358ea..f8633ec 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -9,7 +9,6 @@ #include "slog.h" #include "server.h" #include "conf.h" -#include "formats/common.h" /* message parsers */ #include "formats/json.h" @@ -21,9 +20,8 @@ #include #include #include -#include -static void +static int ws_schedule_write(struct ws_client *ws); /** @@ -120,7 +118,7 @@ ws_client_new(struct http_client *http_client) { return ws; } -static void +void ws_client_free(struct ws_client *ws) { /* mark WS client as closing to skip the Redis callback */ @@ -128,7 +126,7 @@ ws_client_free(struct ws_client *ws) { pool_free_context(ws->ac); /* could trigger a cb via format_send_error */ struct http_client *c = ws->http_client; - c->ws = NULL; /* detach */ + if(c) c->ws = NULL; /* detach if needed */ evbuffer_free(ws->rbuf); evbuffer_free(ws->wbuf); if(ws->cmd) { @@ -136,7 +134,7 @@ ws_client_free(struct ws_client *ws) { cmd_free(ws->cmd); } free(ws); - http_client_free(c); + if(c) http_client_free(c); } @@ -233,9 +231,7 @@ ws_handshake_reply(struct ws_client *ws) { return -1; } - ws_schedule_write(ws); /* will free buffer and response once sent */ - - return 0; + return ws_schedule_write(ws); /* will free buffer and response once sent */ } static void @@ -335,12 +331,15 @@ ws_msg_new(enum ws_frame_type frame_type) { return msg; } -static void +static int ws_msg_add(struct ws_msg *m, const char *p, size_t psz, const unsigned char *mask) { /* add data to frame */ size_t i; m->payload = realloc(m->payload, m->payload_sz + psz); + if(!m->payload) { + return -1; + } memcpy(m->payload + m->payload_sz, p, psz); /* apply mask */ @@ -350,6 +349,7 @@ ws_msg_add(struct ws_msg *m, const char *p, size_t psz, const unsigned char *mas /* save new size */ m->payload_sz += psz; + return 0; } static void @@ -381,7 +381,11 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) { if(!frame) { return WS_ERROR; } - evbuffer_remove(ws->rbuf, frame, sz); + int rem_ret = evbuffer_remove(ws->rbuf, frame, sz); + if(rem_ret < 0) { + free(frame); + return WS_ERROR; + } fin_bit_set = frame[0] & 0x80 ? 1 : 0; frame_type = frame[0] & 0x0F; /* lower 4 bits of first byte */ @@ -409,25 +413,39 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) { /* we now have the (possibly masked) data starting in p, and its length. */ if(len > sz - (p - frame)) { /* not enough data */ - evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */ + int add_ret = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */ free(frame); - return WS_READING; + return add_ret < 0 ? WS_ERROR : WS_READING; } + int ev_copy = 0; if(out_msg) { /* we're extracting the message */ struct ws_msg *msg = ws_msg_new(frame_type); - ws_msg_add(msg, p, len, has_mask ? mask : NULL); + if(!msg) { + free(frame); + return WS_ERROR; + } + *out_msg = msg; /* attach for it to be freed by caller */ + + /* create new ws_msg object holding what we read */ + int add_ret = ws_msg_add(msg, p, len, has_mask ? mask : NULL); + if(!add_ret) { + free(frame); + return WS_ERROR; + } + size_t processed_sz = len + (p - frame); /* length of data + header bytes between frame start and payload */ msg->total_sz += processed_sz; - *out_msg = msg; - evbuffer_prepend(ws->rbuf, frame + len, sz - processed_sz); /* remove processed data */ + ev_copy = evbuffer_prepend(ws->rbuf, frame + len, sz - processed_sz); /* remove processed data */ } else { /* we're just peeking */ - evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */ + ev_copy = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */ } free(frame); - if(fin_bit_set) { /* FIN bit set */ + if(ev_copy < 0) { + return WS_ERROR; + } else if(fin_bit_set) { return WS_MSG_COMPLETE; } else { return WS_READING; /* need more data */ @@ -448,26 +466,26 @@ ws_process_read_data(struct ws_client *ws, unsigned int *out_processed) { while(state == WS_MSG_COMPLETE) { int ret = 0; - struct ws_msg *msg; + struct ws_msg *msg = NULL; ws_peek_data(ws, &msg); /* extract message */ - if(msg->type == WS_TEXT_FRAME || msg->type == WS_BINARY_FRAME) { + if(msg && (msg->type == WS_TEXT_FRAME || msg->type == WS_BINARY_FRAME)) { ret = ws_execute(ws, msg); if(out_processed) (*out_processed)++; - } else if(msg->type == WS_PING) { /* respond to ping */ + } else if(msg && msg->type == WS_PING) { /* respond to ping */ ws_frame_and_send_response(ws, WS_PONG, msg->payload, msg->payload_sz); - } else if(msg->type == WS_CONNECTION_CLOSE) { /* respond to close frame */ + } else if(msg && msg->type == WS_CONNECTION_CLOSE) { /* respond to close frame */ ws->close_after_events = 1; ws_frame_and_send_response(ws, WS_CONNECTION_CLOSE, msg->payload, msg->payload_sz); - } else { + } else if(msg) { char format[] = "Received unexpected WS frame type: 0x%x"; char error[(sizeof format)]; - snprintf(error, sizeof(error), format, msg->type); - slog(ws->http_client->s, WEBDIS_WARNING, error, 0); + int error_len = snprintf(error, sizeof(error), format, msg->type); + slog(ws->http_client->s, WEBDIS_WARNING, error, error_len); } /* free frame */ - ws_msg_free(msg); + if(msg) ws_msg_free(msg); if(ret != 0) { /* can't process frame. */ @@ -522,13 +540,12 @@ ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type frame_type, } /* send WS frame */ - ws_schedule_write(ws); - - return 0; + return ws_schedule_write(ws); } static void ws_close_if_able(struct ws_client *ws) { + ws->close_after_events = 1; /* note that we're closing */ if(ws->scheduled_read || ws->scheduled_write) { return; /* still waiting for these events to trigger */ @@ -571,7 +588,7 @@ ws_can_write(int fd, short event, void *p) { if(ret <= 0) { ws_client_free(ws); /* will close the socket */ - } else if(ret > 0) { + } else { if(evbuffer_get_length(ws->wbuf) > 0) { /* more data to send */ ws_schedule_write(ws); } else if(ws->close_after_events) { /* we're done! */ @@ -585,20 +602,22 @@ ws_can_write(int fd, short event, void *p) { } } -static void +static int ws_schedule_write(struct ws_client *ws) { struct http_client *c = ws->http_client; if(!ws->scheduled_write) { ws->scheduled_write = 1; - event_base_once(c->w->base, c->fd, EV_WRITE, ws_can_write, ws, NULL); + return event_base_once(c->w->base, c->fd, EV_WRITE, ws_can_write, ws, NULL); } + return 0; } -void +int ws_monitor_input(struct ws_client *ws) { struct http_client *c = ws->http_client; if(!ws->scheduled_read) { ws->scheduled_read = 1; - event_base_once(c->w->base, c->fd, EV_READ, ws_can_read, ws, NULL); + return event_base_once(c->w->base, c->fd, EV_READ, ws_can_read, ws, NULL); } + return 0; } diff --git a/src/websocket.h b/src/websocket.h index 7376546..f15af74 100644 --- a/src/websocket.h +++ b/src/websocket.h @@ -45,10 +45,13 @@ struct ws_client { struct ws_client * ws_client_new(struct http_client *http_client); +void +ws_client_free(struct ws_client *ws); + int ws_handshake_reply(struct ws_client *ws); -void +int ws_monitor_input(struct ws_client *ws); enum ws_state diff --git a/src/worker.c b/src/worker.c index eba6b07..d440546 100644 --- a/src/worker.c +++ b/src/worker.c @@ -54,8 +54,7 @@ worker_can_read(int fd, short event, void *p) { } } - if(c->is_websocket) { - } else { + if(!c->is_websocket) { /* run parser */ nparsed = http_client_execute(c); @@ -66,7 +65,6 @@ worker_can_read(int fd, short event, void *p) { /* only close if requested *and* we've already read the request in full */ c->broken = 1; } else if(c->is_websocket) { - event_del(&c->ev); /* Got websocket data */ c->ws = ws_client_new(c); @@ -77,14 +75,19 @@ worker_can_read(int fd, short event, void *p) { c->buffer = NULL; c->sz = 0; - unsigned int processed = 0; - int process_ret = ws_process_read_data(c->ws, &processed); - if(process_ret == WS_ERROR) { - c->broken = 1; /* likely connection was closed */ - } - /* send response, and start managing fd from websocket.c */ - ws_handshake_reply(c->ws); + int reply_ret = ws_handshake_reply(c->ws); + if(reply_ret < 0) { + c->ws->http_client = NULL; /* detach to prevent double free */ + ws_client_free(c->ws); + c->broken = 1; + } else { + unsigned int processed = 0; + int process_ret = ws_process_read_data(c->ws, &processed); + if(process_ret == WS_ERROR) { + c->broken = 1; /* likely connection was closed */ + } + } } /* clean up what remains in HTTP client */