From 974556defb639ac85d244dd78ab780eca9a9afa0 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sat, 24 Jul 2021 14:02:20 -0700 Subject: [PATCH 01/25] WS headers: change order, make origin optional 1. Origin and Sec-WebSocket-Origin are now optional: only return a matching Sec-WebSocket-Origin if one of the two headers was provided. 2. Change order of headers: return Sec-WebSocket-Accept immediately after Upgrade and Connection since some clients expect it there. --- src/websocket.c | 75 ++++++++++++++++++++++++------------------------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/src/websocket.c b/src/websocket.c index 53d48c5..1b51aff 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -94,17 +94,16 @@ ws_handshake_reply(struct http_client *c) { const char *origin = NULL, *host = NULL; size_t origin_sz = 0, host_sz = 0, handshake_sz = 0, sz; - char template0[] = "HTTP/1.1 101 Switching Protocols\r\n" + char template_start[] = "HTTP/1.1 101 Switching Protocols\r\n" "Upgrade: websocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Origin: "; /* %s */ - char template1[] = "\r\n" - "Sec-WebSocket-Location: ws://"; /* %s%s */ - char template2[] = "\r\n" - "Origin: http://"; /* %s */ - char template3[] = "\r\n" + "Connection: Upgrade"; + char template_accept[] = "\r\n" /* just after the start */ "Sec-WebSocket-Accept: "; /* %s */ - char template4[] = "\r\n\r\n"; + char template_sec_origin[] = "\r\n" + "Sec-WebSocket-Origin: "; /* %s (optional header) */ + char template_loc[] = "\r\n" + "Sec-WebSocket-Location: ws://"; /* %s%s */ + char template_end[] = "\r\n\r\n"; if((origin = client_get_header(c, "Origin"))) { origin_sz = strlen(origin); @@ -116,7 +115,7 @@ ws_handshake_reply(struct http_client *c) { } /* need those headers */ - if(!origin || !origin_sz || !host || !host_sz || !c->path || !c->path_sz) { + if(!host || !host_sz || !c->path || !c->path_sz) { slog(c->s, WEBDIS_WARNING, "Missing headers for WS handshake", 0); return -1; } @@ -128,11 +127,11 @@ ws_handshake_reply(struct http_client *c) { return -1; } - sz = sizeof(template0)-1 + origin_sz - + sizeof(template1)-1 + host_sz + c->path_sz - + sizeof(template2)-1 + host_sz - + sizeof(template3)-1 + handshake_sz - + sizeof(template4)-1; + sz = sizeof(template_start)-1 + + sizeof(template_accept)-1 + handshake_sz + + (origin && origin_sz ? (sizeof(template_sec_origin)-1 + origin_sz) : 0) /* optional origin */ + + sizeof(template_loc)-1 + host_sz + c->path_sz + + sizeof(template_end)-1; p = buffer = malloc(sz); if(!p) { @@ -142,35 +141,35 @@ ws_handshake_reply(struct http_client *c) { /* Concat all */ - /* template0 */ - memcpy(p, template0, sizeof(template0)-1); - p += sizeof(template0)-1; - memcpy(p, origin, origin_sz); - p += origin_sz; + /* template_start */ + memcpy(p, template_start, sizeof(template_start)-1); + p += sizeof(template_start)-1; + + /* template_accept */ + memcpy(p, template_accept, sizeof(template_accept)-1); + p += sizeof(template_accept)-1; + memcpy(p, &sha1_handshake[0], handshake_sz); + p += handshake_sz; + + /* template_sec_origin */ + if(origin && origin_sz) { + memcpy(p, template_sec_origin, sizeof(template_sec_origin)-1); + p += sizeof(template_sec_origin)-1; + memcpy(p, origin, origin_sz); + p += origin_sz; + } - /* template1 */ - memcpy(p, template1, sizeof(template1)-1); - p += sizeof(template1)-1; + /* template_loc */ + memcpy(p, template_loc, sizeof(template_loc)-1); + p += sizeof(template_loc)-1; memcpy(p, host, host_sz); p += host_sz; memcpy(p, c->path, c->path_sz); p += c->path_sz; - /* template2 */ - memcpy(p, template2, sizeof(template2)-1); - p += sizeof(template2)-1; - memcpy(p, host, host_sz); - p += host_sz; - - /* template3 */ - memcpy(p, template3, sizeof(template3)-1); - p += sizeof(template3)-1; - memcpy(p, &sha1_handshake[0], handshake_sz); - p += handshake_sz; - - /* template4 */ - memcpy(p, template4, sizeof(template4)-1); - p += sizeof(template4)-1; + /* template_end */ + memcpy(p, template_end, sizeof(template_end)-1); + p += sizeof(template_end)-1; /* build HTTP response object by hand, since we have the full response already */ struct http_response *r = calloc(1, sizeof(struct http_response)); From d0acdf030eaceeff78b23c1d718c7799e06cf4ca Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sat, 24 Jul 2021 14:05:14 -0700 Subject: [PATCH 02/25] Report WS disconnection at DEBUG to avoid log spam --- src/websocket.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websocket.c b/src/websocket.c index 1b51aff..3d6a1d1 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -343,7 +343,7 @@ ws_add_data(struct http_client *c) { if(ret != 0) { /* can't process frame. */ - slog(c->s, WEBDIS_WARNING, "ws_add_data: ws_execute failed", 0); + slog(c->s, WEBDIS_DEBUG, "ws_add_data: ws_execute failed", 0); return WS_ERROR; } state = ws_parse_data(c->buffer, c->sz, &c->frame); From 27ad2413d4942503f6c472a1eef5a14b26ec176a Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sat, 24 Jul 2021 14:05:28 -0700 Subject: [PATCH 03/25] Immediately close WS connection on error The implementation was waiting for the client, which leaves some hanging even after they called close(). This mirrors the behavior for HTTP connections in client.c where close() is called right before http_client_free. --- src/worker.c | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/worker.c b/src/worker.c index d5decef..876dd16 100644 --- a/src/worker.c +++ b/src/worker.c @@ -55,7 +55,10 @@ worker_can_read(int fd, short event, void *p) { if(c->is_websocket) { /* Got websocket data */ - ws_add_data(c); + int add_ret = ws_add_data(c); + if(add_ret == WS_ERROR) { + c->broken = 1; /* likely connection was closed */ + } } else { /* run parser */ nparsed = http_client_execute(c); @@ -87,6 +90,9 @@ worker_can_read(int fd, short event, void *p) { } if(c->broken) { /* terminate client */ + if(c->is_websocket) { /* only close for WS since HTTP might use keep-alive */ + close(c->fd); + } http_client_free(c); } else { /* start monitoring input again */ From 8d1b6c40f8184d7e3c3b9e0c3bc5647460b82cbb Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sat, 24 Jul 2021 14:19:07 -0700 Subject: [PATCH 04/25] Fix instant vs overall rate in websocket C test The overall rate was incorrect when using non-default intervals. Also simplify the code using plain nanosecond values, and fix the rate for the last event which was based on the number of sleeps instead of the actual time elapsed. --- tests/websocket.c | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/websocket.c b/tests/websocket.c index aa501ca..361f6bd 100644 --- a/tests/websocket.c +++ b/tests/websocket.c @@ -474,6 +474,8 @@ progress_thread_main(void *ptr) { struct timespec ts_start; clock_gettime(CLOCK_MONOTONIC, &ts_start); + long start_nanos = ts_start.tv_sec * 1e9 + ts_start.tv_nsec; /* time of monitoring start */ + long last_print_nanos = start_nanos; while(1) { int sem_received = 0; @@ -501,7 +503,6 @@ progress_thread_main(void *ptr) { sem_received = 1; } #endif - // nanosleep(&ts, NULL); num_sleeps++; int total_sent = 0, total_received = 0, any_broken = 0, num_complete = 0; for (int i = 0; i < pt->worker_count; i++) { @@ -515,11 +516,15 @@ progress_thread_main(void *ptr) { } struct timespec ts_after_sleep; clock_gettime(CLOCK_MONOTONIC, &ts_after_sleep); + long after_sleep_nanos = ts_after_sleep.tv_sec * 1e9 + ts_after_sleep.tv_nsec; + long total_nanos = after_sleep_nanos - start_nanos; /* total time spent so far */ + fprintf(stderr, "After %0.2f sec: %'d messages sent, %'d received (%.02f%%). Instant rate: %'ld/sec, overall rate: %'ld/sec\n", ((float)((ts_after_sleep.tv_sec * 1e9 + ts_after_sleep.tv_nsec) - (ts_start.tv_sec * 1e9 + ts_start.tv_nsec))) / (float)1e9, total_sent, total_received, 100.0f * (float)total_received / (float)(pt->worker_count * pt->msg_target), - lroundf((float)(total_received - last_received) / pt->interval_sec), - lroundf((float)total_received / ((float)num_sleeps) * pt->interval_sec)); + lroundf((float)(total_received - last_received) / (((float)(after_sleep_nanos - last_print_nanos)) / 1e9f)), + lroundf((float)total_received / (((float)total_nanos) / 1e9f))); + last_print_nanos = after_sleep_nanos; /* time of last print */ if (sem_received || total_received == pt->msg_target * pt->worker_count || any_broken || num_complete == pt->worker_count) { break; From f86bad3bc8050ecbcfae4ec6fc62b7661a15362d Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sat, 24 Jul 2021 14:56:59 -0700 Subject: [PATCH 05/25] Add Python-based WebSockets tests 1. Create WS tests with the same structure as basic.py 2. Add JSON tests 3. Add "raw" tests --- tests/requirements.txt | 1 + tests/ws-tests.py | 96 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 tests/requirements.txt create mode 100755 tests/ws-tests.py diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 0000000..66a496d --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1 @@ +websocket-client>=1.1.0 diff --git a/tests/ws-tests.py b/tests/ws-tests.py new file mode 100755 index 0000000..1c49b2c --- /dev/null +++ b/tests/ws-tests.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 + +import abc +import json +import os +import unittest +import uuid +from websocket import create_connection + +host = os.getenv('WEBDIS_HOST', '127.0.0.1') +port = int(os.getenv('WEBDIS_PORT', 7379)) + +class TestWebdis(unittest.TestCase): + def setUp(self) -> None: + self.ws = create_connection(f'ws://{host}:{port}/.{self.format()}') + + def tearDown(self) -> None: + self.ws.close() + + def exec(self, cmd, *args): + self.ws.send(self.serialize(cmd, *args)) + return self.deserialize(self.ws.recv()) + + def clean_key(self): + """Returns a key that was just deleted""" + key = str(uuid.uuid4()) + self.exec('DEL', key) + return key + + @abc.abstractmethod + def format(self): + """Returns the format to use (added after a dot to the WS URI)""" + return + + @abc.abstractmethod + def serialize(self, cmd): + """Serializes a command according to the format being tested""" + return + + @abc.abstractmethod + def deserialize(self, response): + """Deserializes a response according to the format being tested""" + return + + +class TestJson(TestWebdis): + def format(self): + return 'json' + + def serialize(self, cmd, *args): + return json.dumps([cmd] + list(args)) + + def deserialize(self, response): + return json.loads(response) + + def test_ping(self): + self.assertEqual(self.exec('PING'), {'PING': [True, 'PONG']}) + + def test_multiple_messages(self): + key = self.clean_key() + n = 100 + for i in range(n): + lpush_response = self.exec('LPUSH', key, f'value-{i}') + self.assertEqual(lpush_response, {'LPUSH': i + 1}) + self.assertEqual(self.exec('LLEN', key), {'LLEN': n}) + + +class TestRaw(TestWebdis): + def format(self): + return 'raw' + + def serialize(self, cmd, *args): + buffer = f"*{1 + len(args)}\r\n${len(cmd)}\r\n{cmd}\r\n" + for arg in args: + buffer += f"${len(arg)}\r\n{arg}\r\n" + return buffer + + def deserialize(self, response): + return response # we'll just assert using the raw protocol + + def test_ping(self): + self.assertEqual(self.exec('PING'), "+PONG\r\n") + + def test_get_set(self): + key = self.clean_key() + value = str(uuid.uuid4()) + not_found_response = self.exec('GET', key) + self.assertEqual(not_found_response, "$-1\r\n") # Redis protocol response for "not found" + set_response = self.exec('SET', key, value) + self.assertEqual(set_response, "+OK\r\n") + get_response = self.exec('GET', key) + self.assertEqual(get_response, f"${len(value)}\r\n{value}\r\n") + + +if __name__ == '__main__': + unittest.main() From 6383cd48dda75598a2727d2aa5d198313bf387c3 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sat, 24 Jul 2021 17:00:32 -0700 Subject: [PATCH 06/25] Add pub-sub test using WebSockets (disabled) Add new pub-sub test using WebSockets, disabled by default due to message ordering not matching what is expected. Enable the test with `PUBSUB=1 ./ws-tests.py` --- tests/ws-tests.py | 70 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/tests/ws-tests.py b/tests/ws-tests.py index 1c49b2c..6a267b5 100755 --- a/tests/ws-tests.py +++ b/tests/ws-tests.py @@ -92,5 +92,75 @@ class TestRaw(TestWebdis): self.assertEqual(get_response, f"${len(value)}\r\n{value}\r\n") +@unittest.skipIf(os.getenv('PUBSUB') != '1', "pub-sub test fail due to invalid ordering") +class TestPubSub(unittest.TestCase): + def setUp(self): + self.publisher = create_connection(f'ws://{host}:{port}/.json') + self.subscriber = create_connection(f'ws://{host}:{port}/.json') + + def tearDown(self): + self.publisher.close() + self.subscriber.close() + + def serialize(self, cmd, *args): + return json.dumps([cmd] + list(args)) + + def deserialize(self, response): + return json.loads(response) + + def test_publish_subscribe(self): + channel_count = 2 + message_count_per_channel = 8 + channels = list(str(uuid.uuid4()) for i in range(channel_count)) + + # subscribe to all channels + sub_count = 0 + for channel in channels: + self.subscriber.send(self.serialize('SUBSCRIBE', channel)) + unsub_response = self.deserialize(self.subscriber.recv()) + sub_count += 1 + self.assertEqual(unsub_response, {'SUBSCRIBE': ['subscribe', channel, sub_count]}) + + # send messages to all channels + prefix = 'message-' + for i in range(message_count_per_channel): + for channel in channels: + message = f'{prefix}{i}' + self.publisher.send(self.serialize('PUBLISH', channel, message)) + + received_per_channel = dict((channel, []) for channel in channels) + for j in range(channel_count * message_count_per_channel): + received = self.deserialize(self.subscriber.recv()) + print('received:', received) + # expected: {'SUBSCRIBE': ['message', $channel, $message]} + self.assertTrue(received, 'SUBSCRIBE' in received) + sub_contents = received['SUBSCRIBE'] + self.assertEqual(len(sub_contents), 3) + + self.assertEqual(sub_contents[0], 'message') # first element is the message type, here a push + channel = sub_contents[1] + self.assertTrue(channel in channels) # second is the channel + received_per_channel[channel].append(sub_contents[2]) # third, add to list of messages received for this channel + + # unsubscribe from all channels + subs_remaining = channel_count + for channel in channels: + self.subscriber.send(self.serialize('UNSUBSCRIBE', channel)) + subs_remaining -= 1 + unsub_response = self.deserialize(self.subscriber.recv()) + self.assertEqual(unsub_response, {'SUBSCRIBE': ['unsubscribe', channel, subs_remaining]}) + + # check that we received all messages + for channel in channels: + self.assertEqual(len(received_per_channel[channel]), message_count_per_channel) + + # check that we received them *in order* + for i in range(message_count_per_channel): + for channel in channels: + expected = f'{prefix}{i}' + self.assertEqual(received_per_channel[channel][i], expected, + f'In {channel}: expected at offset {i} was "{expected}", actual was: "{received_per_channel[channel][i]}"') + + if __name__ == '__main__': unittest.main() From 052458e876463fb29007562a3922b8fc20a727d1 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sat, 24 Jul 2021 20:06:17 -0700 Subject: [PATCH 07/25] Refactor WS code building raw http_response objects Both ws_handshake_reply and ws_reply build http_response objects without using the status code or headers, this code can be refactored to use a single method. --- src/http.c | 22 ++++++++++++++++++++++ src/http.h | 3 +++ src/websocket.c | 19 +++++-------------- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/src/http.c b/src/http.c index 66f8135..cd31ea0 100644 --- a/src/http.c +++ b/src/http.c @@ -16,6 +16,10 @@ http_response_init(struct worker *w, int code, const char *msg) { /* create object */ struct http_response *r = calloc(1, sizeof(struct http_response)); + if(!r) { + if(w && w->s) slog(w->s, WEBDIS_ERROR, "Failed to allocate http_response", 0); + return NULL; + } r->code = code; r->msg = msg; @@ -43,6 +47,24 @@ http_response_init(struct worker *w, int code, const char *msg) { return r; } +struct http_response * +http_response_init_with_buffer(struct worker *w, char *data, size_t data_sz, int keep_alive) { + + struct http_response *r = calloc(1, sizeof(struct http_response)); + if(!r) { + if(w && w->s) slog(w->s, WEBDIS_ERROR, "Failed to allocate http_response with buffer", 0); + return NULL; + } + r->w = w; + + /* provide buffer directly */ + r->out = data; + r->out_sz = data_sz; + r->sent = 0; + r->keep_alive = keep_alive; + return r; +} + void http_response_set_header(struct http_response *r, const char *k, const char *v) { diff --git a/src/http.h b/src/http.h index eac1ffe..8deb542 100644 --- a/src/http.h +++ b/src/http.h @@ -45,6 +45,9 @@ struct http_response { struct http_response * http_response_init(struct worker *w, int code, const char *msg); +struct http_response * +http_response_init_with_buffer(struct worker *w, char *data, size_t data_sz, int keep_alive); + void http_response_set_header(struct http_response *r, const char *k, const char *v); diff --git a/src/websocket.c b/src/websocket.c index 3d6a1d1..e0891fd 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -172,17 +172,13 @@ ws_handshake_reply(struct http_client *c) { p += sizeof(template_end)-1; /* build HTTP response object by hand, since we have the full response already */ - struct http_response *r = calloc(1, sizeof(struct http_response)); + struct http_response *r = http_response_init_with_buffer(c->w, buffer, sz, 1); if(!r) { slog(c->s, WEBDIS_ERROR, "Failed to allocate response for WS handshake", 0); free(buffer); return -1; } - r->w = c->w; - r->keep_alive = 1; - r->out = buffer; - r->out_sz = sz; - r->sent = 0; + http_schedule_write(c->fd, r); /* will free buffer and response once sent */ return 0; @@ -386,20 +382,15 @@ ws_reply(struct cmd *cmd, const char *p, size_t sz) { frame_sz = sz + 10; } - /* send WS frame */ - r = http_response_init(cmd->w, 0, NULL); + /* mark as keep alive, otherwise we'll close the connection after the first reply */ + r = http_response_init_with_buffer(cmd->w, frame, frame_sz, 1); if (r == NULL) { free(frame); slog(cmd->w->s, WEBDIS_ERROR, "Failed response allocation in ws_reply", 0); return -1; } - /* mark as keep alive, otherwise we'll close the connection after the first reply */ - r->keep_alive = 1; - - r->out = frame; - r->out_sz = frame_sz; - r->sent = 0; + /* send WS frame */ http_schedule_write(cmd->fd, r); return 0; From b98116abc84fde9ae069c4aeed4075acf53c372c Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sat, 24 Jul 2021 22:34:51 -0700 Subject: [PATCH 08/25] Persistent cmd for WS, write buffer for responses 1. Only HTTP-based pub-sub clients were re-using a cmd object, but WS clients were not. This led to the commands sent by a WS client to be processed out of order, just queued to Redis but with no guarantee that they would be de-queued from the event loop in the same order. This change attaches a permanent cmd object (with its associated Redis context) to WS clients just like pub-sub clients do. 2. WS responses are also no longer sent out of order, but added to a write buffer that is scheduled for writing as long as there is still some data left to send. This replaces the use of http_response which contained extra fields (headers, HTTP response) that were duplicated without ever being sent out. --- src/client.c | 12 +++---- src/client.h | 7 ++-- src/cmd.c | 12 ++++--- src/cmd.h | 3 +- src/formats/common.c | 4 +-- src/formats/json.c | 2 +- src/formats/raw.c | 2 +- src/websocket.c | 79 +++++++++++++++++++++++++++++++++----------- src/websocket.h | 2 +- 9 files changed, 86 insertions(+), 37 deletions(-) diff --git a/src/client.c b/src/client.c index 3e8eda4..c991895 100644 --- a/src/client.c +++ b/src/client.c @@ -286,15 +286,15 @@ http_client_read(struct http_client *c) { if(ret <= 0) { /* broken link, free buffer and client object */ - /* disconnect pub/sub client if there is one. */ - if(c->pub_sub && c->pub_sub->ac) { - struct cmd *cmd = c->pub_sub; + /* disconnect pub/sub or WS client if there is one. */ + if(c->self_cmd && c->self_cmd->ac) { + struct cmd *cmd = c->self_cmd; /* disconnect from all channels */ - redisAsyncDisconnect(c->pub_sub->ac); - // c->pub_sub might be already cleared by an event handler in redisAsyncDisconnect + redisAsyncDisconnect(c->self_cmd->ac); + // c->self_cmd might be already cleared by an event handler in redisAsyncDisconnect cmd->ac = NULL; - c->pub_sub = NULL; + c->self_cmd = NULL; /* delete command object */ cmd_free(cmd); diff --git a/src/client.h b/src/client.h index 6b38992..ae571fe 100644 --- a/src/client.h +++ b/src/client.h @@ -61,9 +61,12 @@ struct http_client { char *separator; /* list separator for raw lists */ char *filename; /* content-disposition */ - struct cmd *pub_sub; + struct cmd *self_cmd; - struct ws_msg *frame; /* websocket frame */ + struct ws_msg *frame; /* websocket frame (containing *received* data) */ + struct event ws_wev; /* websocket write event */ + struct evbuffer *ws_wbuf; /* write buffer for websocket responses */ + int ws_scheduled_write; /* whether we are already scheduled to send out WS data */ }; struct http_client * diff --git a/src/cmd.c b/src/cmd.c index 290df87..03a2821 100644 --- a/src/cmd.c +++ b/src/cmd.c @@ -22,11 +22,12 @@ #include struct cmd * -cmd_new(int count) { +cmd_new(struct http_client *client, int count) { struct cmd *c = calloc(1, sizeof(struct cmd)); c->count = count; + c->http_client = client; c->argv = calloc(count, sizeof(char*)); c->argv_len = calloc(count, sizeof(size_t)); @@ -164,7 +165,7 @@ cmd_run(struct worker *w, struct http_client *client, return CMD_PARAM_ERROR; } - cmd = cmd_new(param_count); + cmd = cmd_new(client, param_count); cmd->fd = client->fd; cmd->database = w->s->cfg->database; @@ -224,7 +225,7 @@ cmd_run(struct worker *w, struct http_client *client, cmd->ac = (redisAsyncContext*)pool_connect(w->pool, cmd->database, 0); /* register with the client, used upon disconnection */ - client->pub_sub = cmd; + client->self_cmd = cmd; cmd->pub_sub_client = client; } else if(cmd->database != w->s->cfg->database) { /* create a new connection to Redis for custom DBs */ @@ -276,7 +277,7 @@ cmd_run(struct worker *w, struct http_client *client, } /* failed to find a suitable connection to Redis. */ cmd_free(cmd); - client->pub_sub = NULL; + client->self_cmd = NULL; return CMD_REDIS_UNAVAIL; } @@ -370,6 +371,9 @@ cmd_select_format(struct http_client *client, struct cmd *cmd, int cmd_is_subscribe(struct cmd *cmd) { + if(cmd->pub_sub_client) { /* persistent command */ + return 1; + } if(cmd->count >= 1 && cmd->argv[0] && (strncasecmp(cmd->argv[0], "SUBSCRIBE", cmd->argv_len[0]) == 0 || strncasecmp(cmd->argv[0], "PSUBSCRIBE", cmd->argv_len[0]) == 0)) { diff --git a/src/cmd.h b/src/cmd.h index e216bde..196a732 100644 --- a/src/cmd.h +++ b/src/cmd.h @@ -43,6 +43,7 @@ struct cmd { int http_version; int database; + struct http_client *http_client; struct http_client *pub_sub_client; redisAsyncContext *ac; struct worker *w; @@ -54,7 +55,7 @@ struct subscription { }; struct cmd * -cmd_new(int count); +cmd_new(struct http_client *c, int count); void cmd_free(struct cmd *c); diff --git a/src/formats/common.c b/src/formats/common.c index 296b822..7842935 100644 --- a/src/formats/common.c +++ b/src/formats/common.c @@ -48,7 +48,7 @@ format_send_error(struct cmd *cmd, short code, const char *msg) { /* for pub/sub, remove command from client */ if(cmd->pub_sub_client) { - cmd->pub_sub_client->pub_sub = NULL; + cmd->pub_sub_client->self_cmd = NULL; } else { cmd_free(cmd); } @@ -62,7 +62,7 @@ format_send_reply(struct cmd *cmd, const char *p, size_t sz, const char *content struct http_response *resp; if(cmd->is_websocket) { - ws_reply(cmd, p, sz); + ws_frame_and_send_response(cmd, p, sz); /* If it's a subscribe command, there'll be more responses */ if(!cmd_is_subscribe(cmd)) diff --git a/src/formats/json.c b/src/formats/json.c index 006bb1f..0cbcb52 100644 --- a/src/formats/json.c +++ b/src/formats/json.c @@ -522,7 +522,7 @@ json_ws_extract(struct http_client *c, const char *p, size_t sz) { } /* create command and add args */ - cmd = cmd_new(argc); + cmd = cmd_new(c, argc); for(i = 0, cur = 0; i < json_array_size(j); ++i) { json_t *jelem = json_array_get(j, i); char *tmp; diff --git a/src/formats/raw.c b/src/formats/raw.c index 20c3927..7caff97 100644 --- a/src/formats/raw.c +++ b/src/formats/raw.c @@ -62,7 +62,7 @@ raw_ws_extract(struct http_client *c, const char *p, size_t sz) { } /* create cmd object */ - cmd = cmd_new(reply->elements); + cmd = cmd_new(c, reply->elements); for(i = 0; i < reply->elements; ++i) { redisReply *ri = reply->element[i]; diff --git a/src/websocket.c b/src/websocket.c index e0891fd..0bd337b 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -19,6 +19,9 @@ #include #include +static void +ws_schedule_write(struct http_client *c); + /** * This code uses the WebSocket specification from RFC 6455. * A copy is available at http://www.rfc-editor.org/rfc/rfc6455.txt @@ -171,15 +174,22 @@ ws_handshake_reply(struct http_client *c) { memcpy(p, template_end, sizeof(template_end)-1); p += sizeof(template_end)-1; - /* build HTTP response object by hand, since we have the full response already */ - struct http_response *r = http_response_init_with_buffer(c->w, buffer, sz, 1); - if(!r) { + /* create buffer that will hold data to send out */ + c->ws_wbuf = evbuffer_new(); + if(!c->ws_wbuf) { slog(c->s, WEBDIS_ERROR, "Failed to allocate response for WS handshake", 0); free(buffer); return -1; } - http_schedule_write(c->fd, r); /* will free buffer and response once sent */ + int add_ret = evbuffer_add(c->ws_wbuf, buffer, sz); + if(add_ret < 0) { + slog(c->s, WEBDIS_ERROR, "Failed to add response for WS handshake", 0); + free(buffer); + return -1; + } + + ws_schedule_write(c); /* will free buffer and response once sent */ return 0; } @@ -210,20 +220,17 @@ ws_execute(struct http_client *c, const char *frame, size_t frame_len) { cmd_setup(cmd, c); cmd->is_websocket = 1; - if (c->pub_sub != NULL) { + if (c->self_cmd != NULL) { /* This client already has its own connection - * to Redis due to a subscription; use it from + * to Redis from a previous command; use it from * now on. */ - cmd->ac = c->pub_sub->ac; - } else if (cmd_is_subscribe(cmd)) { - /* New subscribe command; make new Redis context + cmd->ac = c->self_cmd->ac; + } else { + /* First WS command; make new Redis context * for this client */ cmd->ac = pool_connect(c->w->pool, cmd->database, 0); - c->pub_sub = cmd; + c->self_cmd = cmd; cmd->pub_sub_client = c; - } else { - /* get Redis connection from pool */ - cmd->ac = (redisAsyncContext*)pool_get_context(c->w->pool); } /* send it off */ @@ -348,11 +355,10 @@ ws_add_data(struct http_client *c) { } int -ws_reply(struct cmd *cmd, const char *p, size_t sz) { +ws_frame_and_send_response(struct cmd *cmd, const char *p, size_t sz) { char *frame = malloc(sz + 8); /* create frame by prepending header */ size_t frame_sz = 0; - struct http_response *r; if (frame == NULL) return -1; @@ -383,15 +389,50 @@ ws_reply(struct cmd *cmd, const char *p, size_t sz) { } /* mark as keep alive, otherwise we'll close the connection after the first reply */ - r = http_response_init_with_buffer(cmd->w, frame, frame_sz, 1); - if (r == NULL) { + int add_ret = evbuffer_add(cmd->http_client->ws_wbuf, frame, frame_sz); + if (add_ret < 0) { free(frame); - slog(cmd->w->s, WEBDIS_ERROR, "Failed response allocation in ws_reply", 0); + slog(cmd->w->s, WEBDIS_ERROR, "Failed response allocation in ws_frame_and_send_response", 0); return -1; } /* send WS frame */ - http_schedule_write(cmd->fd, r); + ws_schedule_write(cmd->http_client); return 0; } + +static void +ws_can_write(int fd, short event, void *p) { + + int ret; + struct http_client *c = p; + (void)event; + + c->ws_scheduled_write = 0; + + /* send pending data */ + ret = evbuffer_write(c->ws_wbuf, fd); + + if(ret < 0) { + close(fd); + } else if(ret > 0 && evbuffer_get_length(c->ws_wbuf) > 0) { /* more data to send */ + ws_schedule_write(c); + } +} + +static void +ws_schedule_write(struct http_client *c) { + + if(c->ws_scheduled_write) { + return; + } + event_set(&c->ws_wev, c->fd, EV_WRITE, ws_can_write, c); + event_base_set(c->w->base, &c->ws_wev); + int ret = event_add(&c->ws_wev, NULL); + if(ret == 0) { + c->ws_scheduled_write = 1; + } else { /* could not schedule write */ + slog(c->w->s, WEBDIS_ERROR, "Could not schedule WS write", 0); + } +} diff --git a/src/websocket.h b/src/websocket.h index dc9d764..09886ca 100644 --- a/src/websocket.h +++ b/src/websocket.h @@ -25,6 +25,6 @@ enum ws_state ws_add_data(struct http_client *c); int -ws_reply(struct cmd *cmd, const char *p, size_t sz); +ws_frame_and_send_response(struct cmd *cmd, const char *p, size_t sz); #endif From e26d6358e7a216427a88952b52a0dc4302a739fa Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Tue, 27 Jul 2021 18:23:05 -0700 Subject: [PATCH 09/25] WS: Better reuse of the cmd struct for WS clients For WS clients, reuse a persistent cmd struct attached to the http_client object: take the cmd built from the WS frame, and copy it to the persistent cmd. --- src/cmd.c | 20 +++++++++++++------- src/cmd.h | 3 +++ src/websocket.c | 20 ++++++++++++++++---- 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/src/cmd.c b/src/cmd.c index 03a2821..beaa646 100644 --- a/src/cmd.c +++ b/src/cmd.c @@ -35,11 +35,22 @@ cmd_new(struct http_client *client, int count) { return c; } +void +cmd_free_argv(struct cmd *c) { + + int i; + fprintf(stderr, "%s: %p\n", __func__, c); + for(i = 0; i < c->count; ++i) { + free((char*)c->argv[i]); + } + + free(c->argv); + free(c->argv_len); +} void cmd_free(struct cmd *c) { - int i; if(!c) return; free(c->jsonp); @@ -53,12 +64,7 @@ cmd_free(struct cmd *c) { pool_free_context(c->ac); } - for(i = 0; i < c->count; ++i) { - free((char*)c->argv[i]); - } - - free(c->argv); - free(c->argv_len); + cmd_free_argv(c); free(c); } diff --git a/src/cmd.h b/src/cmd.h index 196a732..78b5bce 100644 --- a/src/cmd.h +++ b/src/cmd.h @@ -57,6 +57,9 @@ struct subscription { struct cmd * cmd_new(struct http_client *c, int count); +void +cmd_free_argv(struct cmd *c); + void cmd_free(struct cmd *c); diff --git a/src/websocket.c b/src/websocket.c index 0bd337b..740e7b0 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -216,16 +216,28 @@ ws_execute(struct http_client *c, const char *frame, size_t frame_len) { struct cmd *cmd = fun_extract(c, frame, frame_len); if(cmd) { - /* copy client info into cmd. */ - cmd_setup(cmd, c); cmd->is_websocket = 1; if (c->self_cmd != NULL) { /* This client already has its own connection * to Redis from a previous command; use it from * now on. */ - cmd->ac = c->self_cmd->ac; + /* free args for the previous cmd */ + cmd_free_argv(c->self_cmd); + /* copy args from what we just parsed to the persistent command */ + c->self_cmd->count = cmd->count; + c->self_cmd->argv = cmd->argv; + c->self_cmd->argv_len = cmd->argv_len; + cmd->argv = NULL; + cmd->argv_len = NULL; + cmd->count = 0; + cmd_free(cmd); + + cmd = c->self_cmd; /* replace pointer since we're about to pass it to cmd_send */ } else { + /* copy client info into cmd. */ + cmd_setup(cmd, c); + /* First WS command; make new Redis context * for this client */ cmd->ac = pool_connect(c->w->pool, cmd->database, 0); @@ -390,8 +402,8 @@ ws_frame_and_send_response(struct cmd *cmd, const char *p, size_t sz) { /* mark as keep alive, otherwise we'll close the connection after the first reply */ int add_ret = evbuffer_add(cmd->http_client->ws_wbuf, frame, frame_sz); + free(frame); /* no longer needed once added to buffer */ if (add_ret < 0) { - free(frame); slog(cmd->w->s, WEBDIS_ERROR, "Failed response allocation in ws_frame_and_send_response", 0); return -1; } From 6b090b4edeade1cd4d8a2315bd4b3056e6944026 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Fri, 30 Jul 2021 12:33:42 -0700 Subject: [PATCH 10/25] Large refactoring of WS code 1. Introduce ws_client struct 2. Handle all communications from websocket.c for WS clients 3. Always use a dedicated Redis connection for WS clients 4. Add rbuf & wbuf evbuffers for incoming & outgoing WS data 5. Use event_base_once to control R/W events 6. WS test: make sure to read complete HTTP response --- src/client.h | 5 +- src/cmd.c | 1 - src/formats/common.c | 5 +- src/pool.c | 2 +- src/websocket.c | 284 +++++++++++++++++++++++++++++++------------ src/websocket.h | 36 +++++- src/worker.c | 40 ++++-- tests/websocket.c | 11 +- tests/ws-tests.py | 8 +- 9 files changed, 282 insertions(+), 110 deletions(-) diff --git a/src/client.h b/src/client.h index ae571fe..32e4c74 100644 --- a/src/client.h +++ b/src/client.h @@ -63,10 +63,7 @@ struct http_client { struct cmd *self_cmd; - struct ws_msg *frame; /* websocket frame (containing *received* data) */ - struct event ws_wev; /* websocket write event */ - struct evbuffer *ws_wbuf; /* write buffer for websocket responses */ - int ws_scheduled_write; /* whether we are already scheduled to send out WS data */ + struct ws_client *ws; /* websocket client */ }; struct http_client * diff --git a/src/cmd.c b/src/cmd.c index beaa646..a206eac 100644 --- a/src/cmd.c +++ b/src/cmd.c @@ -39,7 +39,6 @@ void cmd_free_argv(struct cmd *c) { int i; - fprintf(stderr, "%s: %p\n", __func__, c); for(i = 0; i < c->count; ++i) { free((char*)c->argv[i]); } diff --git a/src/formats/common.c b/src/formats/common.c index 7842935..94f88e6 100644 --- a/src/formats/common.c +++ b/src/formats/common.c @@ -7,6 +7,8 @@ #include "md5/md5.h" #include #include +#include +#include /* TODO: replace this with a faster hash function? */ char *etag_new(const char *p, size_t sz) { @@ -62,7 +64,8 @@ format_send_reply(struct cmd *cmd, const char *p, size_t sz, const char *content struct http_response *resp; if(cmd->is_websocket) { - ws_frame_and_send_response(cmd, p, sz); + + ws_frame_and_send_response(cmd->http_client->ws, WS_BINARY_FRAME, p, sz); /* If it's a subscribe command, there'll be more responses */ if(!cmd_is_subscribe(cmd)) diff --git a/src/pool.c b/src/pool.c index 706006b..4836d27 100644 --- a/src/pool.c +++ b/src/pool.c @@ -2,6 +2,7 @@ #include "worker.h" #include "conf.h" #include "server.h" +#include "formats/common.h" #include #include @@ -30,7 +31,6 @@ pool_free_context(redisAsyncContext *ac) { if (ac) { redisAsyncDisconnect(ac); - redisAsyncFree(ac); } } diff --git a/src/websocket.c b/src/websocket.c index 740e7b0..c8ac76c 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -7,6 +7,9 @@ #include "pool.h" #include "http.h" #include "slog.h" +#include "server.h" +#include "conf.h" +#include "formats/common.h" /* message parsers */ #include "formats/json.h" @@ -18,9 +21,10 @@ #include #include #include +#include static void -ws_schedule_write(struct http_client *c); +ws_schedule_write(struct ws_client *ws); /** * This code uses the WebSocket specification from RFC 6455. @@ -89,9 +93,54 @@ ws_compute_handshake(struct http_client *c, char *out, size_t *out_sz) { return 0; } +struct ws_client * +ws_client_new(struct http_client *http_client) { + + int db_num = http_client->w->s->cfg->database; + struct ws_client *ws = calloc(1, sizeof(struct ws_client)); + struct evbuffer *rbuf = evbuffer_new(); + struct evbuffer *wbuf = evbuffer_new(); + redisAsyncContext *ac = pool_connect(http_client->w->pool, db_num, 0); + + if(!ws || !rbuf || !wbuf) { + slog(http_client->s, WEBDIS_ERROR, "Failed to allocate memory for WS client", 0); + if(ws) free(ws); + if(rbuf) evbuffer_free(rbuf); + if(wbuf) evbuffer_free(wbuf); + if(ac) redisAsyncFree(ac); + return NULL; + } + + http_client->ws = ws; + ws->http_client = http_client; + ws->rbuf = rbuf; + ws->wbuf = wbuf; + ws->ac = ac; + + return ws; +} + +static void +ws_client_free(struct ws_client *ws) { + + struct http_client *c = ws->http_client; + c->ws = NULL; /* detach */ + evbuffer_free(ws->rbuf); + evbuffer_free(ws->wbuf); + pool_free_context(ws->ac); + if(ws->cmd) { + ws->cmd->ac = NULL; /* we've just free'd it */ + cmd_free(ws->cmd); + } + free(ws); + http_client_free(c); +} + + int -ws_handshake_reply(struct http_client *c) { +ws_handshake_reply(struct ws_client *ws) { + struct http_client *c = ws->http_client; char sha1_handshake[40]; char *buffer = NULL, *p; const char *origin = NULL, *host = NULL; @@ -174,30 +223,23 @@ ws_handshake_reply(struct http_client *c) { memcpy(p, template_end, sizeof(template_end)-1); p += sizeof(template_end)-1; - /* create buffer that will hold data to send out */ - c->ws_wbuf = evbuffer_new(); - if(!c->ws_wbuf) { - slog(c->s, WEBDIS_ERROR, "Failed to allocate response for WS handshake", 0); - free(buffer); - return -1; - } - - int add_ret = evbuffer_add(c->ws_wbuf, buffer, sz); + int add_ret = evbuffer_add(ws->wbuf, buffer, sz); + free(buffer); if(add_ret < 0) { slog(c->s, WEBDIS_ERROR, "Failed to add response for WS handshake", 0); - free(buffer); return -1; } - ws_schedule_write(c); /* will free buffer and response once sent */ + ws_schedule_write(ws); /* will free buffer and response once sent */ return 0; } static int -ws_execute(struct http_client *c, const char *frame, size_t frame_len) { +ws_execute(struct ws_client *ws, struct ws_msg *msg) { + struct http_client *c = ws->http_client; struct cmd*(*fun_extract)(struct http_client *, const char *, size_t) = NULL; formatting_fun fun_reply = NULL; @@ -213,35 +255,36 @@ ws_execute(struct http_client *c, const char *frame, size_t frame_len) { if(fun_extract) { /* Parse websocket frame into a cmd object. */ - struct cmd *cmd = fun_extract(c, frame, frame_len); + struct cmd *cmd = fun_extract(c, msg->payload, msg->payload_sz); if(cmd) { cmd->is_websocket = 1; - if (c->self_cmd != NULL) { - /* This client already has its own connection - * to Redis from a previous command; use it from - * now on. */ + if(ws->cmd != NULL) { + /* This client already has its own connection to Redis + from a previous command; use it from now on. */ + /* free args for the previous cmd */ - cmd_free_argv(c->self_cmd); + cmd_free_argv(ws->cmd); /* copy args from what we just parsed to the persistent command */ - c->self_cmd->count = cmd->count; - c->self_cmd->argv = cmd->argv; - c->self_cmd->argv_len = cmd->argv_len; + ws->cmd->count = cmd->count; + ws->cmd->argv = cmd->argv; + ws->cmd->argv_len = cmd->argv_len; + ws->cmd->pub_sub_client = c; /* mark as persistent, otherwise the Redis context will be freed */ + cmd->argv = NULL; cmd->argv_len = NULL; cmd->count = 0; cmd_free(cmd); - cmd = c->self_cmd; /* replace pointer since we're about to pass it to cmd_send */ + cmd = ws->cmd; /* replace pointer since we're about to pass it to cmd_send */ } else { /* copy client info into cmd. */ cmd_setup(cmd, c); - /* First WS command; make new Redis context - * for this client */ - cmd->ac = pool_connect(c->w->pool, cmd->database, 0); - c->self_cmd = cmd; + /* First WS command; use Redis context from WS client. */ + cmd->ac = ws->ac; + ws->cmd = cmd; cmd->pub_sub_client = c; } @@ -256,8 +299,13 @@ ws_execute(struct http_client *c, const char *frame, size_t frame_len) { } static struct ws_msg * -ws_msg_new() { - return calloc(1, sizeof(struct ws_msg)); +ws_msg_new(enum ws_frame_type frame_type) { + struct ws_msg *msg = calloc(1, sizeof(struct ws_msg)); + if(!msg) { + return NULL; + } + msg->type = frame_type; + return msg; } static void @@ -278,26 +326,38 @@ ws_msg_add(struct ws_msg *m, const char *p, size_t psz, const unsigned char *mas } static void -ws_msg_free(struct ws_msg **m) { +ws_msg_free(struct ws_msg *m) { - free((*m)->payload); - free(*m); - *m = NULL; + free(m->payload); + free(m); } +/* checks to see if we have a complete message */ static enum ws_state -ws_parse_data(const char *frame, size_t sz, struct ws_msg **msg) { +ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) { int has_mask; uint64_t len; const char *p; + char *frame; unsigned char mask[4]; + char fin_bit_set; + enum ws_frame_type frame_type; /* parse frame and extract contents */ + size_t sz = evbuffer_get_length(ws->rbuf); if(sz < 8) { - return WS_READING; + return WS_READING; /* need more data */ } + /* copy into "frame" to process it */ + frame = malloc(sz); + if(!frame) { + return WS_ERROR; + } + evbuffer_remove(ws->rbuf, frame, sz); + fin_bit_set = frame[0] & 0x80 ? 1 : 0; + frame_type = frame[0] & 0x0F; /* lower 4 bits of first byte */ has_mask = frame[1] & 0x80 ? 1:0; /* get payload length */ @@ -316,62 +376,88 @@ ws_parse_data(const char *frame, size_t sz, struct ws_msg **msg) { p = frame + 10 + (has_mask ? 4 : 0); if(has_mask) memcpy(&mask, frame + 10, sizeof(mask)); } else { + free(frame); return WS_ERROR; } /* we now have the (possibly masked) data starting in p, and its length. */ if(len > sz - (p - frame)) { /* not enough data */ + evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */ + free(frame); return WS_READING; } - if(!*msg) - *msg = ws_msg_new(); - ws_msg_add(*msg, p, len, has_mask ? mask : NULL); - (*msg)->total_sz += len + (p - frame); + if(out_msg) { /* we're extracting the message */ + struct ws_msg *msg = ws_msg_new(frame_type); + ws_msg_add(msg, p, len, has_mask ? mask : NULL); + size_t processed_sz = len + (p - frame); /* length of data + header bytes between frame start and payload */ + msg->total_sz += processed_sz; + *out_msg = msg; - if(frame[0] & 0x80) { /* FIN bit set */ + evbuffer_prepend(ws->rbuf, frame + len, sz - processed_sz); /* remove processed data */ + } else { /* we're just peeking */ + evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */ + } + free(frame); + + if(fin_bit_set) { /* FIN bit set */ return WS_MSG_COMPLETE; } else { return WS_READING; /* need more data */ } } - /** * Process some data just received on the socket. + * Returns the number of messages processed in out_processed, if non-NULL. */ enum ws_state -ws_add_data(struct http_client *c) { +ws_process_read_data(struct ws_client *ws, unsigned int *out_processed) { enum ws_state state; + if(out_processed) *out_processed = 0; - state = ws_parse_data(c->buffer, c->sz, &c->frame); + state = ws_peek_data(ws, NULL); /* check for complete message */ while(state == WS_MSG_COMPLETE) { - int ret = ws_execute(c, c->frame->payload, c->frame->payload_sz); - - /* remove frame from client buffer */ - http_client_remove_data(c, c->frame->total_sz); + int ret = 0; + struct ws_msg *msg; + ws_peek_data(ws, &msg); /* extract message */ + + if(msg->type == WS_TEXT_FRAME || msg->type == WS_BINARY_FRAME) { + ret = ws_execute(ws, msg); + if(out_processed) (*out_processed)++; + } else if(msg->type == WS_PING) { /* respond to ping */ + ws_frame_and_send_response(ws, WS_PONG, msg->payload, msg->payload_sz); + } else if(msg->type == WS_CONNECTION_CLOSE) { /* respond to close frame */ + ws->close_after_events = 1; + ws_frame_and_send_response(ws, WS_CONNECTION_CLOSE, msg->payload, msg->payload_sz); + } else { + char format[] = "Received unexpected WS frame type: 0x%x"; + char error[(sizeof format)]; + snprintf(error, sizeof(error), format, msg->type); + slog(ws->http_client->s, WEBDIS_WARNING, error, 0); + } - /* free frame and set back to NULL */ - ws_msg_free(&c->frame); + /* free frame */ + ws_msg_free(msg); if(ret != 0) { /* can't process frame. */ - slog(c->s, WEBDIS_DEBUG, "ws_add_data: ws_execute failed", 0); + slog(ws->http_client->s, WEBDIS_DEBUG, "ws_process_read_data: ws_execute failed", 0); return WS_ERROR; } - state = ws_parse_data(c->buffer, c->sz, &c->frame); + state = ws_peek_data(ws, NULL); } return state; } int -ws_frame_and_send_response(struct cmd *cmd, const char *p, size_t sz) { +ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type frame_type, const char *p, size_t sz) { char *frame = malloc(sz + 8); /* create frame by prepending header */ size_t frame_sz = 0; - if (frame == NULL) + if(frame == NULL) return -1; /* @@ -381,12 +467,12 @@ ws_frame_and_send_response(struct cmd *cmd, const char *p, size_t sz) { following 8 bytes interpreted as a 64-bit unsigned integer (the most significant bit MUST be 0) are the payload length. */ - frame[0] = '\x81'; + frame[0] = 0x80 | frame_type; /* frame type + EOM bit */ if(sz <= 125) { frame[1] = sz; memcpy(frame + 2, p, sz); frame_sz = sz + 2; - } else if (sz <= 65536) { + } else if(sz <= 65536) { uint16_t sz16 = htons(sz); frame[1] = 126; memcpy(frame + 2, &sz16, 2); @@ -401,50 +487,90 @@ ws_frame_and_send_response(struct cmd *cmd, const char *p, size_t sz) { } /* mark as keep alive, otherwise we'll close the connection after the first reply */ - int add_ret = evbuffer_add(cmd->http_client->ws_wbuf, frame, frame_sz); + int add_ret = evbuffer_add(ws->wbuf, frame, frame_sz); free(frame); /* no longer needed once added to buffer */ - if (add_ret < 0) { - slog(cmd->w->s, WEBDIS_ERROR, "Failed response allocation in ws_frame_and_send_response", 0); + if(add_ret < 0) { + slog(ws->http_client->w->s, WEBDIS_ERROR, "Failed response allocation in ws_frame_and_send_response", 0); return -1; } /* send WS frame */ - ws_schedule_write(cmd->http_client); + ws_schedule_write(ws); return 0; } +static void +ws_close_if_able(struct ws_client *ws) { + if(ws->scheduled_read || ws->scheduled_write) { + return; /* still waiting for these events to trigger */ + } + ws_client_free(ws); /* will close the socket */ +} + +static void +ws_can_read(int fd, short event, void *p) { + + int ret; + struct ws_client *ws = p; + (void)event; + + /* read pending data */ + ws->scheduled_read = 0; + ret = evbuffer_read(ws->rbuf, fd, 4096); + + if(ret <= 0) { + ws_client_free(ws); /* will close the socket */ + } else if(ws->close_after_events) { + ws_close_if_able(ws); + } else { + ws_process_read_data(ws, NULL); + } +} + + static void ws_can_write(int fd, short event, void *p) { int ret; - struct http_client *c = p; + struct ws_client *ws = p; (void)event; - c->ws_scheduled_write = 0; + ws->scheduled_write = 0; /* send pending data */ - ret = evbuffer_write(c->ws_wbuf, fd); - - if(ret < 0) { - close(fd); - } else if(ret > 0 && evbuffer_get_length(c->ws_wbuf) > 0) { /* more data to send */ - ws_schedule_write(c); + ret = evbuffer_write_atmost(ws->wbuf, fd, 4096); + + if(ret <= 0) { + ws_client_free(ws); /* will close the socket */ + } else if(ret > 0) { + if(evbuffer_get_length(ws->wbuf) > 0) { /* more data to send */ + ws_schedule_write(ws); + } else if(ws->close_after_events) { /* we're done! */ + ws_close_if_able(ws); + } else { + /* check if we can read more data */ + unsigned int processed = 0; + ws_process_read_data(ws, &processed); /* process any pending data we've already read */ + ws_monitor_input(ws); /* let's read more from the client */ + } } } static void -ws_schedule_write(struct http_client *c) { - - if(c->ws_scheduled_write) { - return; +ws_schedule_write(struct ws_client *ws) { + struct http_client *c = ws->http_client; + if(!ws->scheduled_write) { + ws->scheduled_write = 1; + event_base_once(c->w->base, c->fd, EV_WRITE, ws_can_write, ws, NULL); } - event_set(&c->ws_wev, c->fd, EV_WRITE, ws_can_write, c); - event_base_set(c->w->base, &c->ws_wev); - int ret = event_add(&c->ws_wev, NULL); - if(ret == 0) { - c->ws_scheduled_write = 1; - } else { /* could not schedule write */ - slog(c->w->s, WEBDIS_ERROR, "Could not schedule WS write", 0); +} + +void +ws_monitor_input(struct ws_client *ws) { + struct http_client *c = ws->http_client; + if(!ws->scheduled_read) { + ws->scheduled_read = 1; + event_base_once(c->w->base, c->fd, EV_READ, ws_can_read, ws, NULL); } } diff --git a/src/websocket.h b/src/websocket.h index 09886ca..7376546 100644 --- a/src/websocket.h +++ b/src/websocket.h @@ -3,6 +3,8 @@ #include #include +#include +#include struct http_client; struct cmd; @@ -12,19 +14,47 @@ enum ws_state { WS_READING, WS_MSG_COMPLETE}; +enum ws_frame_type { + WS_TEXT_FRAME = 0, + WS_BINARY_FRAME = 1, + WS_CONNECTION_CLOSE = 8, + WS_PING = 9, + WS_PONG = 0xA, + WS_UNKNOWN_FRAME = -1}; + struct ws_msg { + enum ws_frame_type type; char *payload; size_t payload_sz; size_t total_sz; }; +struct ws_client { + struct http_client *http_client; /* parent */ + int scheduled_read; /* set if we are scheduled to read WS data */ + int scheduled_write; /* set if we are scheduled to send out WS data */ + struct evbuffer *rbuf; /* read buffer for incoming data */ + struct evbuffer *wbuf; /* write buffer for outgoing data */ + redisAsyncContext *ac; /* dedicated connection to redis */ + struct cmd *cmd; /* current command */ + /* indicates that we'll close once we've flushed all + buffered data and read what we planned to read */ + int close_after_events; +}; + +struct ws_client * +ws_client_new(struct http_client *http_client); + int -ws_handshake_reply(struct http_client *c); +ws_handshake_reply(struct ws_client *ws); + +void +ws_monitor_input(struct ws_client *ws); enum ws_state -ws_add_data(struct http_client *c); +ws_process_read_data(struct ws_client *ws, unsigned int *out_processed); int -ws_frame_and_send_response(struct cmd *cmd, const char *p, size_t sz); +ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type type, const char *p, size_t sz); #endif diff --git a/src/worker.c b/src/worker.c index 876dd16..eba6b07 100644 --- a/src/worker.c +++ b/src/worker.c @@ -13,7 +13,8 @@ #include #include #include - +#include +#include "formats/common.h" struct worker * worker_new(struct server *s) { @@ -54,11 +55,6 @@ worker_can_read(int fd, short event, void *p) { } if(c->is_websocket) { - /* Got websocket data */ - int add_ret = ws_add_data(c); - if(add_ret == WS_ERROR) { - c->broken = 1; /* likely connection was closed */ - } } else { /* run parser */ nparsed = http_client_execute(c); @@ -70,13 +66,28 @@ worker_can_read(int fd, short event, void *p) { /* only close if requested *and* we've already read the request in full */ c->broken = 1; } else if(c->is_websocket) { - /* we need to use the remaining (unparsed) data as the body. */ - if(nparsed < ret) { - http_client_add_to_body(c, c->buffer + nparsed + 1, c->sz - nparsed - 1); - ws_handshake_reply(c); - } else { + event_del(&c->ev); + + /* Got websocket data */ + c->ws = ws_client_new(c); + if(!c->ws) { c->broken = 1; + } else { + free(c->buffer); + c->buffer = NULL; + c->sz = 0; + + unsigned int processed = 0; + int process_ret = ws_process_read_data(c->ws, &processed); + if(process_ret == WS_ERROR) { + c->broken = 1; /* likely connection was closed */ + } + + /* send response, and start managing fd from websocket.c */ + ws_handshake_reply(c->ws); } + + /* clean up what remains in HTTP client */ free(c->buffer); c->buffer = NULL; c->sz = 0; @@ -96,7 +107,11 @@ worker_can_read(int fd, short event, void *p) { http_client_free(c); } else { /* start monitoring input again */ - worker_monitor_input(c); + if(c->is_websocket) { /* all communication handled by WS code from now on */ + // ws_monitor_input(c->ws); + } else { + worker_monitor_input(c); + } } } @@ -105,7 +120,6 @@ worker_can_read(int fd, short event, void *p) { */ void worker_monitor_input(struct http_client *c) { - event_set(&c->ev, c->fd, EV_READ, worker_can_read, c); event_base_set(c->w->base, &c->ev); event_add(&c->ev, NULL); diff --git a/tests/websocket.c b/tests/websocket.c index 361f6bd..e20f273 100644 --- a/tests/websocket.c +++ b/tests/websocket.c @@ -263,10 +263,13 @@ websocket_can_read(int fd, short event, void *ptr) { /* http parser will return the offset at which the upgraded protocol begins, which in our case is 1 under the total response size. */ - if (wt->state == WS_SENT_HANDSHAKE || /* haven't encountered end of response yet */ - (wt->parser.upgrade && nparsed != (int)avail_sz -1)) { - wt->debug("UPGRADE *and* we have some data left (nparsed=%d, avail_sz=%lu)\n", nparsed, avail_sz); - continue; + if (wt->state == WS_SENT_HANDSHAKE) { /* haven't encountered end of response yet */ + if (wt->parser.upgrade && nparsed != (int)avail_sz) { + wt->debug("UPGRADE *and* we have some data left (state=%d, nparsed=%d, avail_sz=%lu)\n", wt->state, nparsed, avail_sz); + continue; + } else { /* we just haven't read the entire response yet */ + wait_for_possible_read(wt); + } } else if (wt->state == WS_RECEIVED_HANDSHAKE) { /* we have the full response */ evbuffer_drain(wt->rbuffer, evbuffer_get_length(wt->rbuffer)); } diff --git a/tests/ws-tests.py b/tests/ws-tests.py index 6a267b5..dadd0f7 100755 --- a/tests/ws-tests.py +++ b/tests/ws-tests.py @@ -117,9 +117,9 @@ class TestPubSub(unittest.TestCase): sub_count = 0 for channel in channels: self.subscriber.send(self.serialize('SUBSCRIBE', channel)) - unsub_response = self.deserialize(self.subscriber.recv()) + sub_response = self.deserialize(self.subscriber.recv()) sub_count += 1 - self.assertEqual(unsub_response, {'SUBSCRIBE': ['subscribe', channel, sub_count]}) + self.assertEqual(sub_response, {'SUBSCRIBE': ['subscribe', channel, sub_count]}) # send messages to all channels prefix = 'message-' @@ -127,11 +127,11 @@ class TestPubSub(unittest.TestCase): for channel in channels: message = f'{prefix}{i}' self.publisher.send(self.serialize('PUBLISH', channel, message)) + self.deserialize(self.publisher.recv()) received_per_channel = dict((channel, []) for channel in channels) for j in range(channel_count * message_count_per_channel): received = self.deserialize(self.subscriber.recv()) - print('received:', received) # expected: {'SUBSCRIBE': ['message', $channel, $message]} self.assertTrue(received, 'SUBSCRIBE' in received) sub_contents = received['SUBSCRIBE'] @@ -148,7 +148,7 @@ class TestPubSub(unittest.TestCase): self.subscriber.send(self.serialize('UNSUBSCRIBE', channel)) subs_remaining -= 1 unsub_response = self.deserialize(self.subscriber.recv()) - self.assertEqual(unsub_response, {'SUBSCRIBE': ['unsubscribe', channel, subs_remaining]}) + self.assertEqual(unsub_response, {'UNSUBSCRIBE': ['unsubscribe', channel, subs_remaining]}) # check that we received all messages for channel in channels: From dedfc42c676421338e72bba83a6b4857fb9715cb Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sun, 1 Aug 2021 13:19:40 -0700 Subject: [PATCH 11/25] WS: Log commands WS client commands were not being logged, they are now with a "WS: " prefix. This is done at debug level like for HTTP commands. --- src/cmd.c | 1 - src/slog.c | 4 ++-- src/slog.h | 2 ++ src/websocket.c | 26 +++++++++++++++++++++++++- 4 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/cmd.c b/src/cmd.c index a206eac..df268a3 100644 --- a/src/cmd.c +++ b/src/cmd.c @@ -6,7 +6,6 @@ #include "worker.h" #include "http.h" #include "server.h" -#include "slog.h" #include "formats/json.h" #include "formats/raw.h" diff --git a/src/slog.c b/src/slog.c index 9bae415..c48489d 100644 --- a/src/slog.c +++ b/src/slog.c @@ -95,8 +95,8 @@ slog_internal(struct server *s, log_level level, time_t now; struct tm now_tm, *lt_ret; char time_buf[64]; - char msg[124]; - char line[256]; /* bounds are checked. */ + char msg[1 + SLOG_MSG_MAX_LEN]; + char line[2 * SLOG_MSG_MAX_LEN]; /* bounds are checked. */ int line_sz, ret; if(!s->log.fd) return; diff --git a/src/slog.h b/src/slog.h index 14d9c66..d4b47e8 100644 --- a/src/slog.h +++ b/src/slog.h @@ -1,6 +1,8 @@ #ifndef SLOG_H #define SLOG_H +#define SLOG_MSG_MAX_LEN 124 + typedef enum { WEBDIS_ERROR = 0, WEBDIS_WARNING, diff --git a/src/websocket.c b/src/websocket.c index c8ac76c..af31577 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -235,6 +235,29 @@ ws_handshake_reply(struct ws_client *ws) { return 0; } +static void +ws_log_cmd(struct ws_client *ws, struct cmd *cmd) { + char log_msg[SLOG_MSG_MAX_LEN]; + char *p = log_msg, *eom = log_msg + sizeof(log_msg) - 1; + if(!slog_enabled(ws->http_client->s, WEBDIS_DEBUG)) { + return; + } + + memset(log_msg, 0, sizeof(log_msg)); + memcpy(p, "WS: ", 4); /* WS prefix */ + p += 4; + + for(int i = 0; p < eom && i < cmd->count; i++) { + *p++ = '/'; + char *arg = cmd->argv[i]; + size_t arg_sz = cmd->argv_len[i]; + size_t copy_sz = arg_sz < (size_t)(eom - p) ? arg_sz : (size_t)(eom - p); + memcpy(p, arg, copy_sz); + p += copy_sz; + } + slog(ws->http_client->s, WEBDIS_DEBUG, log_msg, p - log_msg); +} + static int ws_execute(struct ws_client *ws, struct ws_msg *msg) { @@ -288,7 +311,8 @@ ws_execute(struct ws_client *ws, struct ws_msg *msg) { cmd->pub_sub_client = c; } - /* send it off */ + /* log and execute */ + ws_log_cmd(ws, cmd); cmd_send(cmd, fun_reply); return 0; From 583f6747b3fbea1e35cf973825f89221d48667b2 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sun, 1 Aug 2021 13:55:18 -0700 Subject: [PATCH 12/25] Avoid dereferencing NULL in pool_on_disconnect pool_on_disconnect was assuming a pool object was attached and logging using its server object. It also checked for NULL, but too late. --- src/pool.c | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/pool.c b/src/pool.c index 4836d27..741c7da 100644 --- a/src/pool.c +++ b/src/pool.c @@ -96,6 +96,11 @@ pool_on_disconnect(const redisAsyncContext *ac, int status) { struct pool *p = ac->data; int i = 0; + + if(p == NULL) { /* no need to clean anything here. */ + return; + } + if (status != REDIS_OK) { char format[] = "Error disconnecting: %s"; size_t msg_sz = sizeof(format) - 2 + ((ac && ac->errstr) ? strlen(ac->errstr) : 6); @@ -107,10 +112,6 @@ pool_on_disconnect(const redisAsyncContext *ac, int status) { } } - if(p == NULL) { /* no need to clean anything here. */ - return; - } - /* remove from the pool */ for(i = 0; i < p->count; ++i) { if(p->ac[i] == ac) { From 545d18d84da18d08a7da89bdd785a90c18c723fc Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sun, 1 Aug 2021 14:11:07 -0700 Subject: [PATCH 13/25] Send error messages to WS clients if triggered by Redis Also mark the WS client as closing before we close the Redis connection, to avoid its last error callback (if sent) trying to send out data while we're in the middle of freeing the client. --- src/formats/common.c | 4 +++- src/websocket.c | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/formats/common.c b/src/formats/common.c index 94f88e6..50e4780 100644 --- a/src/formats/common.c +++ b/src/formats/common.c @@ -46,12 +46,14 @@ format_send_error(struct cmd *cmd, short code, const char *msg) { resp->http_version = cmd->http_version; http_response_set_keep_alive(resp, cmd->keep_alive); http_response_write(resp, cmd->fd); + } else if(cmd->is_websocket && !cmd->http_client->ws->close_after_events) { + ws_frame_and_send_response(cmd->http_client->ws, WS_BINARY_FRAME, msg, strlen(msg)); } /* for pub/sub, remove command from client */ if(cmd->pub_sub_client) { cmd->pub_sub_client->self_cmd = NULL; - } else { + } else if (!cmd->is_websocket) { /* don't free persistent cmd */ cmd_free(cmd); } } diff --git a/src/websocket.c b/src/websocket.c index af31577..27358ea 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -123,11 +123,14 @@ ws_client_new(struct http_client *http_client) { static void ws_client_free(struct ws_client *ws) { + /* mark WS client as closing to skip the Redis callback */ + ws->close_after_events = 1; + pool_free_context(ws->ac); /* could trigger a cb via format_send_error */ + struct http_client *c = ws->http_client; c->ws = NULL; /* detach */ evbuffer_free(ws->rbuf); evbuffer_free(ws->wbuf); - pool_free_context(ws->ac); if(ws->cmd) { ws->cmd->ac = NULL; /* we've just free'd it */ cmd_free(ws->cmd); @@ -526,6 +529,7 @@ ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type frame_type, static void ws_close_if_able(struct ws_client *ws) { + ws->close_after_events = 1; /* note that we're closing */ if(ws->scheduled_read || ws->scheduled_write) { return; /* still waiting for these events to trigger */ } From d7703b97b300bc32d31232a53e13ee968c081e51 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sun, 1 Aug 2021 14:58:21 -0700 Subject: [PATCH 14/25] Add pub/sub test in HTML demo 1. Add publish button with channel and message 2. Add subscribe button with channel 3. Change "Clear logs" button to appear when logs are visible --- tests/websocket.html | 103 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 84 insertions(+), 19 deletions(-) diff --git a/tests/websocket.html b/tests/websocket.html index 0f06fc8..276bf84 100644 --- a/tests/websocket.html +++ b/tests/websocket.html @@ -101,6 +101,26 @@ function installBlock(title, type) { +
+
+
+
+
+
+
+
+
+ +
+
+
+
 
+
+
+
+
+
+
 
@@ -113,17 +133,23 @@ function installBlock(title, type) { class Client { - constructor(type, pingSerializer, getSerializer, setSerializer) { + constructor(type, pingSerializer, getSerializer, setSerializer, pubSerializer, subSerializer) { this.type = type; this.pingSerializer = pingSerializer; this.getSerializer = getSerializer; this.setSerializer = setSerializer; + this.pubSerializer = pubSerializer; + this.subSerializer = subSerializer; this.ws = null; + this.connected = false; + this.subscribed = false; + this.logCount = 0; $(`${this.type}-btn-connect`).addEventListener('click', event => { event.preventDefault(); console.log('Connecting...'); - this.ws = new WebSocket(`ws://${ host }:${ port }/.${ this.type }`); + this.ws = new WebSocket(`ws://${host}:${port}/.${this.type}`); + window.ws = this.ws; this.ws.onopen = event => { console.log('Connected'); this.setConnectedState(true); @@ -135,50 +161,83 @@ class Client { }; this.ws.onclose = event => { - $(`${this.type}-btn-connect`).disabled = false; + console.log('ON CLOSE') + $(`${this.type}-btn-connect`).innerText = 'Connect'; this.setConnectedState(false); + this.subscribed = false; }; }); $(`${this.type}-btn-ping`).addEventListener('click', event => { event.preventDefault(); const serialized = this.pingSerializer(); - this.log("sent", serialized); - this.ws.send(serialized); + this.send(serialized); }); $(`${this.type}-btn-set`).addEventListener('click', event => { event.preventDefault(); const serialized = this.setSerializer($(`${this.type}-set-key`).value, $(`${this.type}-set-value`).value); - this.log("sent", serialized); - this.ws.send(serialized); + this.send(serialized); }); $(`${this.type}-btn-get`).addEventListener('click', event => { event.preventDefault(); - const serialized = this.getSerializer($(`${this.type}-set-key`).value); - this.log("sent", serialized); - this.ws.send(serialized); + const serialized = this.getSerializer($(`${this.type}-get-key`).value); + this.send(serialized); + }); + + $(`${this.type}-btn-pub`).addEventListener('click', event => { + event.preventDefault(); + const serialized = this.pubSerializer($(`${this.type}-pub-channel`).value, $(`${this.type}-pub-message`).value); + this.send(serialized); + }); + + $(`${this.type}-btn-sub`).addEventListener('click', event => { + event.preventDefault(); + const serialized = this.subSerializer($(`${this.type}-sub-channel`).value); + try { + this.send(serialized); + this.subscribed = true; + this.setConnectedState(true); + } catch (e) { + } }); $(`${this.type}-btn-clear`).addEventListener('click', event => { event.preventDefault(); $(`${this.type}-log`).innerText = ""; + this.logCount = 0; + this.updateLogButton(); }); } + send(serialized) { + this.log("sent", serialized); + this.ws.send(serialized); + } + setConnectedState(connected) { + this.connected = connected; $(`${this.type}-btn-connect`).disabled = connected; - $(`${this.type}-btn-ping`).disabled = !connected; - $(`${this.type}-set-key`).disabled = !connected; - $(`${this.type}-set-value`).disabled = !connected; - $(`${this.type}-btn-set`).disabled = !connected; - $(`${this.type}-get-key`).disabled = !connected; - $(`${this.type}-btn-get`).disabled = !connected; - $(`${this.type}-btn-clear`).disabled = !connected; + $(`${this.type}-btn-ping`).disabled = !connected || this.subscribed; + $(`${this.type}-set-key`).disabled = !connected || this.subscribed; + $(`${this.type}-set-value`).disabled = !connected || this.subscribed; + $(`${this.type}-btn-set`).disabled = !connected || this.subscribed; + $(`${this.type}-get-key`).disabled = !connected || this.subscribed; + $(`${this.type}-btn-get`).disabled = !connected || this.subscribed; + $(`${this.type}-pub-channel`).disabled = !connected || this.subscribed; + $(`${this.type}-pub-message`).disabled = !connected || this.subscribed; + $(`${this.type}-btn-pub`).disabled = !connected || this.subscribed; + $(`${this.type}-sub-channel`).disabled = !connected || this.subscribed; + $(`${this.type}-btn-sub`).disabled = !connected || this.subscribed; + $(`${this.type}-state`).innerText = `State: ${connected ? 'Connected' : 'Disconnected'}`; } + updateLogButton() { + $(`${this.type}-btn-clear`).disabled = this.logCount === 0; + } + log(dir, msg) { const id = `${this.type}-log`; @@ -190,6 +249,8 @@ class Client { contents.setAttribute("class", dir); contents.innerHTML = msg; $(id).appendChild(contents); + this.logCount++; + this.updateLogButton(); } } @@ -200,12 +261,16 @@ addEventListener("DOMContentLoaded", () => { const jsonClient = new Client('json', () => JSON.stringify(['PING']), (key) => JSON.stringify(['GET', key]), - (key, value) => JSON.stringify(['SET', key, value])); + (key, value) => JSON.stringify(['SET', key, value]), + (channel, message) => JSON.stringify(['PUBLISH', channel, message]), + (channel) => JSON.stringify(['SUBSCRIBE', channel])); const rawClient = new Client('raw', () => '*1\r\n$4\r\nPING\r\n', (key) => `*2\r\n$3\r\nGET\r\n$${key.length}\r\n${key}\r\n`, - (key, value) => `*3\r\n$3\r\nSET\r\n$${key.length}\r\n${key}\r\n$${value.length}\r\n${value}\r\n`); + (key, value) => `*3\r\n$3\r\nSET\r\n$${key.length}\r\n${key}\r\n$${value.length}\r\n${value}\r\n`, + (channel, message) => `*3\r\n$7\r\nPUBLISH\r\n$${channel.length}\r\n${channel}\r\n$${message.length}\r\n${message}\r\n`, + (channel) => `*2\r\n$9\r\nSUBSCRIBE\r\n$${channel.length}\r\n${channel}\r\n`); }); From 67490fb825745dde37a81b340c9701eb6c9e50c7 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sun, 1 Aug 2021 22:18:53 -0700 Subject: [PATCH 15/25] Address review comments --- src/formats/common.c | 11 +++--- src/pool.c | 1 - src/slog.c | 4 ++ src/websocket.c | 87 +++++++++++++++++++++++++++----------------- src/websocket.h | 5 ++- src/worker.c | 23 +++++++----- 6 files changed, 80 insertions(+), 51 deletions(-) diff --git a/src/formats/common.c b/src/formats/common.c index 50e4780..5304b09 100644 --- a/src/formats/common.c +++ b/src/formats/common.c @@ -50,11 +50,12 @@ format_send_error(struct cmd *cmd, short code, const char *msg) { ws_frame_and_send_response(cmd->http_client->ws, WS_BINARY_FRAME, msg, strlen(msg)); } - /* for pub/sub, remove command from client */ - if(cmd->pub_sub_client) { - cmd->pub_sub_client->self_cmd = NULL; - } else if (!cmd->is_websocket) { /* don't free persistent cmd */ - cmd_free(cmd); + if (!cmd->is_websocket) { /* don't free or detach persistent cmd */ + if (cmd->pub_sub_client) { /* for pub/sub, remove command from client */ + cmd->pub_sub_client->self_cmd = NULL; + } else { + cmd_free(cmd); + } } } diff --git a/src/pool.c b/src/pool.c index 741c7da..f3e7b42 100644 --- a/src/pool.c +++ b/src/pool.c @@ -2,7 +2,6 @@ #include "worker.h" #include "conf.h" #include "server.h" -#include "formats/common.h" #include #include diff --git a/src/slog.c b/src/slog.c index c48489d..bdb600d 100644 --- a/src/slog.c +++ b/src/slog.c @@ -13,6 +13,10 @@ #include "server.h" #include "conf.h" +#if SLOG_MSG_MAX_LEN < 64 +#error "SLOG_MSG_MAX_LEN must be at least 64" +#endif + /** * Initialize log writer. */ diff --git a/src/websocket.c b/src/websocket.c index 27358ea..f8633ec 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -9,7 +9,6 @@ #include "slog.h" #include "server.h" #include "conf.h" -#include "formats/common.h" /* message parsers */ #include "formats/json.h" @@ -21,9 +20,8 @@ #include #include #include -#include -static void +static int ws_schedule_write(struct ws_client *ws); /** @@ -120,7 +118,7 @@ ws_client_new(struct http_client *http_client) { return ws; } -static void +void ws_client_free(struct ws_client *ws) { /* mark WS client as closing to skip the Redis callback */ @@ -128,7 +126,7 @@ ws_client_free(struct ws_client *ws) { pool_free_context(ws->ac); /* could trigger a cb via format_send_error */ struct http_client *c = ws->http_client; - c->ws = NULL; /* detach */ + if(c) c->ws = NULL; /* detach if needed */ evbuffer_free(ws->rbuf); evbuffer_free(ws->wbuf); if(ws->cmd) { @@ -136,7 +134,7 @@ ws_client_free(struct ws_client *ws) { cmd_free(ws->cmd); } free(ws); - http_client_free(c); + if(c) http_client_free(c); } @@ -233,9 +231,7 @@ ws_handshake_reply(struct ws_client *ws) { return -1; } - ws_schedule_write(ws); /* will free buffer and response once sent */ - - return 0; + return ws_schedule_write(ws); /* will free buffer and response once sent */ } static void @@ -335,12 +331,15 @@ ws_msg_new(enum ws_frame_type frame_type) { return msg; } -static void +static int ws_msg_add(struct ws_msg *m, const char *p, size_t psz, const unsigned char *mask) { /* add data to frame */ size_t i; m->payload = realloc(m->payload, m->payload_sz + psz); + if(!m->payload) { + return -1; + } memcpy(m->payload + m->payload_sz, p, psz); /* apply mask */ @@ -350,6 +349,7 @@ ws_msg_add(struct ws_msg *m, const char *p, size_t psz, const unsigned char *mas /* save new size */ m->payload_sz += psz; + return 0; } static void @@ -381,7 +381,11 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) { if(!frame) { return WS_ERROR; } - evbuffer_remove(ws->rbuf, frame, sz); + int rem_ret = evbuffer_remove(ws->rbuf, frame, sz); + if(rem_ret < 0) { + free(frame); + return WS_ERROR; + } fin_bit_set = frame[0] & 0x80 ? 1 : 0; frame_type = frame[0] & 0x0F; /* lower 4 bits of first byte */ @@ -409,25 +413,39 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) { /* we now have the (possibly masked) data starting in p, and its length. */ if(len > sz - (p - frame)) { /* not enough data */ - evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */ + int add_ret = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */ free(frame); - return WS_READING; + return add_ret < 0 ? WS_ERROR : WS_READING; } + int ev_copy = 0; if(out_msg) { /* we're extracting the message */ struct ws_msg *msg = ws_msg_new(frame_type); - ws_msg_add(msg, p, len, has_mask ? mask : NULL); + if(!msg) { + free(frame); + return WS_ERROR; + } + *out_msg = msg; /* attach for it to be freed by caller */ + + /* create new ws_msg object holding what we read */ + int add_ret = ws_msg_add(msg, p, len, has_mask ? mask : NULL); + if(!add_ret) { + free(frame); + return WS_ERROR; + } + size_t processed_sz = len + (p - frame); /* length of data + header bytes between frame start and payload */ msg->total_sz += processed_sz; - *out_msg = msg; - evbuffer_prepend(ws->rbuf, frame + len, sz - processed_sz); /* remove processed data */ + ev_copy = evbuffer_prepend(ws->rbuf, frame + len, sz - processed_sz); /* remove processed data */ } else { /* we're just peeking */ - evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */ + ev_copy = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */ } free(frame); - if(fin_bit_set) { /* FIN bit set */ + if(ev_copy < 0) { + return WS_ERROR; + } else if(fin_bit_set) { return WS_MSG_COMPLETE; } else { return WS_READING; /* need more data */ @@ -448,26 +466,26 @@ ws_process_read_data(struct ws_client *ws, unsigned int *out_processed) { while(state == WS_MSG_COMPLETE) { int ret = 0; - struct ws_msg *msg; + struct ws_msg *msg = NULL; ws_peek_data(ws, &msg); /* extract message */ - if(msg->type == WS_TEXT_FRAME || msg->type == WS_BINARY_FRAME) { + if(msg && (msg->type == WS_TEXT_FRAME || msg->type == WS_BINARY_FRAME)) { ret = ws_execute(ws, msg); if(out_processed) (*out_processed)++; - } else if(msg->type == WS_PING) { /* respond to ping */ + } else if(msg && msg->type == WS_PING) { /* respond to ping */ ws_frame_and_send_response(ws, WS_PONG, msg->payload, msg->payload_sz); - } else if(msg->type == WS_CONNECTION_CLOSE) { /* respond to close frame */ + } else if(msg && msg->type == WS_CONNECTION_CLOSE) { /* respond to close frame */ ws->close_after_events = 1; ws_frame_and_send_response(ws, WS_CONNECTION_CLOSE, msg->payload, msg->payload_sz); - } else { + } else if(msg) { char format[] = "Received unexpected WS frame type: 0x%x"; char error[(sizeof format)]; - snprintf(error, sizeof(error), format, msg->type); - slog(ws->http_client->s, WEBDIS_WARNING, error, 0); + int error_len = snprintf(error, sizeof(error), format, msg->type); + slog(ws->http_client->s, WEBDIS_WARNING, error, error_len); } /* free frame */ - ws_msg_free(msg); + if(msg) ws_msg_free(msg); if(ret != 0) { /* can't process frame. */ @@ -522,13 +540,12 @@ ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type frame_type, } /* send WS frame */ - ws_schedule_write(ws); - - return 0; + return ws_schedule_write(ws); } static void ws_close_if_able(struct ws_client *ws) { + ws->close_after_events = 1; /* note that we're closing */ if(ws->scheduled_read || ws->scheduled_write) { return; /* still waiting for these events to trigger */ @@ -571,7 +588,7 @@ ws_can_write(int fd, short event, void *p) { if(ret <= 0) { ws_client_free(ws); /* will close the socket */ - } else if(ret > 0) { + } else { if(evbuffer_get_length(ws->wbuf) > 0) { /* more data to send */ ws_schedule_write(ws); } else if(ws->close_after_events) { /* we're done! */ @@ -585,20 +602,22 @@ ws_can_write(int fd, short event, void *p) { } } -static void +static int ws_schedule_write(struct ws_client *ws) { struct http_client *c = ws->http_client; if(!ws->scheduled_write) { ws->scheduled_write = 1; - event_base_once(c->w->base, c->fd, EV_WRITE, ws_can_write, ws, NULL); + return event_base_once(c->w->base, c->fd, EV_WRITE, ws_can_write, ws, NULL); } + return 0; } -void +int ws_monitor_input(struct ws_client *ws) { struct http_client *c = ws->http_client; if(!ws->scheduled_read) { ws->scheduled_read = 1; - event_base_once(c->w->base, c->fd, EV_READ, ws_can_read, ws, NULL); + return event_base_once(c->w->base, c->fd, EV_READ, ws_can_read, ws, NULL); } + return 0; } diff --git a/src/websocket.h b/src/websocket.h index 7376546..f15af74 100644 --- a/src/websocket.h +++ b/src/websocket.h @@ -45,10 +45,13 @@ struct ws_client { struct ws_client * ws_client_new(struct http_client *http_client); +void +ws_client_free(struct ws_client *ws); + int ws_handshake_reply(struct ws_client *ws); -void +int ws_monitor_input(struct ws_client *ws); enum ws_state diff --git a/src/worker.c b/src/worker.c index eba6b07..d440546 100644 --- a/src/worker.c +++ b/src/worker.c @@ -54,8 +54,7 @@ worker_can_read(int fd, short event, void *p) { } } - if(c->is_websocket) { - } else { + if(!c->is_websocket) { /* run parser */ nparsed = http_client_execute(c); @@ -66,7 +65,6 @@ worker_can_read(int fd, short event, void *p) { /* only close if requested *and* we've already read the request in full */ c->broken = 1; } else if(c->is_websocket) { - event_del(&c->ev); /* Got websocket data */ c->ws = ws_client_new(c); @@ -77,14 +75,19 @@ worker_can_read(int fd, short event, void *p) { c->buffer = NULL; c->sz = 0; - unsigned int processed = 0; - int process_ret = ws_process_read_data(c->ws, &processed); - if(process_ret == WS_ERROR) { - c->broken = 1; /* likely connection was closed */ - } - /* send response, and start managing fd from websocket.c */ - ws_handshake_reply(c->ws); + int reply_ret = ws_handshake_reply(c->ws); + if(reply_ret < 0) { + c->ws->http_client = NULL; /* detach to prevent double free */ + ws_client_free(c->ws); + c->broken = 1; + } else { + unsigned int processed = 0; + int process_ret = ws_process_read_data(c->ws, &processed); + if(process_ret == WS_ERROR) { + c->broken = 1; /* likely connection was closed */ + } + } } /* clean up what remains in HTTP client */ From b65c05a985c9116d87e62a1244911da1226892c8 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sun, 1 Aug 2021 22:45:38 -0700 Subject: [PATCH 16/25] Reject unauthorized commands after SUBSCRIBE Redis docs mention that a subscribed client cannot send a non-subscription-related command. --- src/cmd.c | 30 ++++++++++++++++++++++++------ src/cmd.h | 6 ++++++ src/websocket.c | 14 +++++++++++--- src/websocket.h | 1 + src/worker.c | 7 ++----- 5 files changed, 44 insertions(+), 14 deletions(-) diff --git a/src/cmd.c b/src/cmd.c index df268a3..f064518 100644 --- a/src/cmd.c +++ b/src/cmd.c @@ -375,13 +375,31 @@ cmd_select_format(struct http_client *client, struct cmd *cmd, int cmd_is_subscribe(struct cmd *cmd) { - if(cmd->pub_sub_client) { /* persistent command */ - return 1; - } - if(cmd->count >= 1 && cmd->argv[0] && - (strncasecmp(cmd->argv[0], "SUBSCRIBE", cmd->argv_len[0]) == 0 || - strncasecmp(cmd->argv[0], "PSUBSCRIBE", cmd->argv_len[0]) == 0)) { + if(cmd->pub_sub_client || /* persistent command */ + cmd_is_subscribe_args(cmd)) { /* checked with args */ return 1; } return 0; } + +int +cmd_is_subscribe_args(struct cmd *cmd) { + + if(cmd->count >= 2 && + ((cmd->argv_len[0] == 9 && strncasecmp(cmd->argv[0], "subscribe", 9) == 0) || + (cmd->argv_len[0] == 10 && strncasecmp(cmd->argv[0], "psubscribe", 10) == 0))) { + return 1; + } + return 0; +} + +int +cmd_is_unsubscribe_args(struct cmd *cmd) { + + if(cmd->count >= 2 && + ((cmd->argv_len[0] == 11 && strncasecmp(cmd->argv[0], "unsubscribe", 11) == 0) || + (cmd->argv_len[0] == 12 && strncasecmp(cmd->argv[0], "punsubscribe", 12) == 0))) { + return 1; + } + return 0; +} diff --git a/src/cmd.h b/src/cmd.h index 78b5bce..823c3e3 100644 --- a/src/cmd.h +++ b/src/cmd.h @@ -72,6 +72,12 @@ int cmd_select_format(struct http_client *client, struct cmd *cmd, const char *uri, size_t uri_len, formatting_fun *f_format); +int +cmd_is_subscribe_args(struct cmd *cmd); + +int +cmd_is_unsubscribe_args(struct cmd *cmd); + int cmd_is_subscribe(struct cmd *cmd); diff --git a/src/websocket.c b/src/websocket.c index f8633ec..edcb9ff 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -310,9 +310,17 @@ ws_execute(struct ws_client *ws, struct ws_msg *msg) { cmd->pub_sub_client = c; } - /* log and execute */ - ws_log_cmd(ws, cmd); - cmd_send(cmd, fun_reply); + int is_subscribe = cmd_is_subscribe_args(cmd); + int is_unsubscribe = cmd_is_unsubscribe_args(cmd); + + if(ws->ran_subscribe && !is_subscribe && !is_unsubscribe) { /* disallow non-subscribe commands after a subscribe */ + char error_msg[] = "Command not allowed after subscribe"; + ws_frame_and_send_response(ws, WS_BINARY_FRAME, error_msg, sizeof(error_msg)-1); + } else { /* log and execute */ + ws_log_cmd(ws, cmd); + cmd_send(cmd, fun_reply); + ws->ran_subscribe = is_subscribe; + } return 0; } diff --git a/src/websocket.h b/src/websocket.h index f15af74..8b176a4 100644 --- a/src/websocket.h +++ b/src/websocket.h @@ -40,6 +40,7 @@ struct ws_client { /* indicates that we'll close once we've flushed all buffered data and read what we planned to read */ int close_after_events; + int ran_subscribe; /* set if we've run a (p)subscribe command */ }; struct ws_client * diff --git a/src/worker.c b/src/worker.c index d440546..5d06cf6 100644 --- a/src/worker.c +++ b/src/worker.c @@ -108,11 +108,8 @@ worker_can_read(int fd, short event, void *p) { close(c->fd); } http_client_free(c); - } else { - /* start monitoring input again */ - if(c->is_websocket) { /* all communication handled by WS code from now on */ - // ws_monitor_input(c->ws); - } else { + } else { /* start monitoring input again */ + if(!c->is_websocket) { /* all communication handled by WS code from now on */ worker_monitor_input(c); } } From d48353cec3c7a2ca5acbb681460c5591d9875262 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Sun, 1 Aug 2021 22:56:04 -0700 Subject: [PATCH 17/25] Rename self_cmd to reused_cmd --- src/client.c | 10 +++++----- src/client.h | 2 +- src/cmd.c | 4 ++-- src/formats/common.c | 2 +- tests/websocket.html | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/client.c b/src/client.c index c991895..ed6f3f3 100644 --- a/src/client.c +++ b/src/client.c @@ -287,14 +287,14 @@ http_client_read(struct http_client *c) { /* broken link, free buffer and client object */ /* disconnect pub/sub or WS client if there is one. */ - if(c->self_cmd && c->self_cmd->ac) { - struct cmd *cmd = c->self_cmd; + if(c->reused_cmd && c->reused_cmd->ac) { + struct cmd *cmd = c->reused_cmd; /* disconnect from all channels */ - redisAsyncDisconnect(c->self_cmd->ac); - // c->self_cmd might be already cleared by an event handler in redisAsyncDisconnect + redisAsyncDisconnect(c->reused_cmd->ac); + // c->reused_cmd might be already cleared by an event handler in redisAsyncDisconnect cmd->ac = NULL; - c->self_cmd = NULL; + c->reused_cmd = NULL; /* delete command object */ cmd_free(cmd); diff --git a/src/client.h b/src/client.h index 32e4c74..c355bac 100644 --- a/src/client.h +++ b/src/client.h @@ -61,7 +61,7 @@ struct http_client { char *separator; /* list separator for raw lists */ char *filename; /* content-disposition */ - struct cmd *self_cmd; + struct cmd *reused_cmd; struct ws_client *ws; /* websocket client */ }; diff --git a/src/cmd.c b/src/cmd.c index f064518..6fcc528 100644 --- a/src/cmd.c +++ b/src/cmd.c @@ -229,7 +229,7 @@ cmd_run(struct worker *w, struct http_client *client, cmd->ac = (redisAsyncContext*)pool_connect(w->pool, cmd->database, 0); /* register with the client, used upon disconnection */ - client->self_cmd = cmd; + client->reused_cmd = cmd; cmd->pub_sub_client = client; } else if(cmd->database != w->s->cfg->database) { /* create a new connection to Redis for custom DBs */ @@ -281,7 +281,7 @@ cmd_run(struct worker *w, struct http_client *client, } /* failed to find a suitable connection to Redis. */ cmd_free(cmd); - client->self_cmd = NULL; + client->reused_cmd = NULL; return CMD_REDIS_UNAVAIL; } diff --git a/src/formats/common.c b/src/formats/common.c index 5304b09..1932143 100644 --- a/src/formats/common.c +++ b/src/formats/common.c @@ -52,7 +52,7 @@ format_send_error(struct cmd *cmd, short code, const char *msg) { if (!cmd->is_websocket) { /* don't free or detach persistent cmd */ if (cmd->pub_sub_client) { /* for pub/sub, remove command from client */ - cmd->pub_sub_client->self_cmd = NULL; + cmd->pub_sub_client->reused_cmd = NULL; } else { cmd_free(cmd); } diff --git a/tests/websocket.html b/tests/websocket.html index 276bf84..f63bc3b 100644 --- a/tests/websocket.html +++ b/tests/websocket.html @@ -161,7 +161,6 @@ class Client { }; this.ws.onclose = event => { - console.log('ON CLOSE') $(`${this.type}-btn-connect`).innerText = 'Connect'; this.setConnectedState(false); this.subscribed = false; @@ -200,6 +199,7 @@ class Client { this.subscribed = true; this.setConnectedState(true); } catch (e) { + console.log('Error sending: ', serialized, e); } }); From 71223ae005e83e09f1da114fd345f30b883724c5 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Mon, 2 Aug 2021 06:45:19 -0700 Subject: [PATCH 18/25] Address review comments (tests) --- tests/websocket.html | 42 ++++++++++++++++++------------------------ tests/ws-tests.py | 9 ++++++--- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/tests/websocket.html b/tests/websocket.html index f63bc3b..00cca12 100644 --- a/tests/websocket.html +++ b/tests/websocket.html @@ -133,13 +133,9 @@ function installBlock(title, type) { class Client { - constructor(type, pingSerializer, getSerializer, setSerializer, pubSerializer, subSerializer) { + constructor(type, serializer) { this.type = type; - this.pingSerializer = pingSerializer; - this.getSerializer = getSerializer; - this.setSerializer = setSerializer; - this.pubSerializer = pubSerializer; - this.subSerializer = subSerializer; + this.serializer = serializer; this.ws = null; this.connected = false; this.subscribed = false; @@ -169,31 +165,31 @@ class Client { $(`${this.type}-btn-ping`).addEventListener('click', event => { event.preventDefault(); - const serialized = this.pingSerializer(); + const serialized = this.serializer(['PING']); this.send(serialized); }); $(`${this.type}-btn-set`).addEventListener('click', event => { event.preventDefault(); - const serialized = this.setSerializer($(`${this.type}-set-key`).value, $(`${this.type}-set-value`).value); + const serialized = this.serializer(['SET', $(`${this.type}-set-key`).value, $(`${this.type}-set-value`).value]); this.send(serialized); }); $(`${this.type}-btn-get`).addEventListener('click', event => { event.preventDefault(); - const serialized = this.getSerializer($(`${this.type}-get-key`).value); + const serialized = this.serializer(['GET', $(`${this.type}-get-key`).value]); this.send(serialized); }); $(`${this.type}-btn-pub`).addEventListener('click', event => { event.preventDefault(); - const serialized = this.pubSerializer($(`${this.type}-pub-channel`).value, $(`${this.type}-pub-message`).value); + const serialized = this.serializer(['PUBLISH', $(`${this.type}-pub-channel`).value, $(`${this.type}-pub-message`).value]); this.send(serialized); }); $(`${this.type}-btn-sub`).addEventListener('click', event => { event.preventDefault(); - const serialized = this.subSerializer($(`${this.type}-sub-channel`).value); + const serialized = this.serializer(['SUBSCRIBE', $(`${this.type}-sub-channel`).value]); try { this.send(serialized); this.subscribed = true; @@ -254,23 +250,21 @@ class Client { } } +function serializeRaw(args) { + let raw = `*${args.length}\r\n`; + for (let i = 0; i < args.length; i++) { + raw += `$${args[i].length}\r\n`; + raw += `${args[i]}\r\n`; + } + return raw; +} + addEventListener("DOMContentLoaded", () => { installBlock('JSON', 'json'); installBlock('Raw', 'raw'); - const jsonClient = new Client('json', - () => JSON.stringify(['PING']), - (key) => JSON.stringify(['GET', key]), - (key, value) => JSON.stringify(['SET', key, value]), - (channel, message) => JSON.stringify(['PUBLISH', channel, message]), - (channel) => JSON.stringify(['SUBSCRIBE', channel])); - - const rawClient = new Client('raw', - () => '*1\r\n$4\r\nPING\r\n', - (key) => `*2\r\n$3\r\nGET\r\n$${key.length}\r\n${key}\r\n`, - (key, value) => `*3\r\n$3\r\nSET\r\n$${key.length}\r\n${key}\r\n$${value.length}\r\n${value}\r\n`, - (channel, message) => `*3\r\n$7\r\nPUBLISH\r\n$${channel.length}\r\n${channel}\r\n$${message.length}\r\n${message}\r\n`, - (channel) => `*2\r\n$9\r\nSUBSCRIBE\r\n$${channel.length}\r\n${channel}\r\n`); + const jsonClient = new Client('json', JSON.stringify); + const rawClient = new Client('raw', serializeRaw); }); diff --git a/tests/ws-tests.py b/tests/ws-tests.py index dadd0f7..5e6c62c 100755 --- a/tests/ws-tests.py +++ b/tests/ws-tests.py @@ -10,9 +10,12 @@ from websocket import create_connection host = os.getenv('WEBDIS_HOST', '127.0.0.1') port = int(os.getenv('WEBDIS_PORT', 7379)) +def connect(format): + return create_connection(f'ws://{host}:{port}/.{format}') + class TestWebdis(unittest.TestCase): def setUp(self) -> None: - self.ws = create_connection(f'ws://{host}:{port}/.{self.format()}') + self.ws = connect(self.format()) def tearDown(self) -> None: self.ws.close() @@ -95,8 +98,8 @@ class TestRaw(TestWebdis): @unittest.skipIf(os.getenv('PUBSUB') != '1', "pub-sub test fail due to invalid ordering") class TestPubSub(unittest.TestCase): def setUp(self): - self.publisher = create_connection(f'ws://{host}:{port}/.json') - self.subscriber = create_connection(f'ws://{host}:{port}/.json') + self.publisher = connect('json') + self.subscriber = connect('json') def tearDown(self): self.publisher.close() From bb02c1dd04004f48576d5eec589523682685e10c Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Mon, 2 Aug 2021 06:48:06 -0700 Subject: [PATCH 19/25] Formatting only: make ws-tests.py PEP8 compliant. --- tests/ws-tests.py | 269 +++++++++++++++++++++++----------------------- 1 file changed, 136 insertions(+), 133 deletions(-) diff --git a/tests/ws-tests.py b/tests/ws-tests.py index 5e6c62c..85ef909 100755 --- a/tests/ws-tests.py +++ b/tests/ws-tests.py @@ -10,160 +10,163 @@ from websocket import create_connection host = os.getenv('WEBDIS_HOST', '127.0.0.1') port = int(os.getenv('WEBDIS_PORT', 7379)) + def connect(format): - return create_connection(f'ws://{host}:{port}/.{format}') + return create_connection(f'ws://{host}:{port}/.{format}') + class TestWebdis(unittest.TestCase): - def setUp(self) -> None: - self.ws = connect(self.format()) - - def tearDown(self) -> None: - self.ws.close() - - def exec(self, cmd, *args): - self.ws.send(self.serialize(cmd, *args)) - return self.deserialize(self.ws.recv()) - - def clean_key(self): - """Returns a key that was just deleted""" - key = str(uuid.uuid4()) - self.exec('DEL', key) - return key - - @abc.abstractmethod - def format(self): - """Returns the format to use (added after a dot to the WS URI)""" - return - - @abc.abstractmethod - def serialize(self, cmd): - """Serializes a command according to the format being tested""" - return - - @abc.abstractmethod - def deserialize(self, response): - """Deserializes a response according to the format being tested""" - return + def setUp(self) -> None: + self.ws = connect(self.format()) + + def tearDown(self) -> None: + self.ws.close() + + def exec(self, cmd, *args): + self.ws.send(self.serialize(cmd, *args)) + return self.deserialize(self.ws.recv()) + + def clean_key(self): + """Returns a key that was just deleted""" + key = str(uuid.uuid4()) + self.exec('DEL', key) + return key + + @abc.abstractmethod + def format(self): + """Returns the format to use (added after a dot to the WS URI)""" + return + + @abc.abstractmethod + def serialize(self, cmd): + """Serializes a command according to the format being tested""" + return + + @abc.abstractmethod + def deserialize(self, response): + """Deserializes a response according to the format being tested""" + return class TestJson(TestWebdis): - def format(self): - return 'json' + def format(self): + return 'json' - def serialize(self, cmd, *args): - return json.dumps([cmd] + list(args)) + def serialize(self, cmd, *args): + return json.dumps([cmd] + list(args)) - def deserialize(self, response): - return json.loads(response) + def deserialize(self, response): + return json.loads(response) - def test_ping(self): - self.assertEqual(self.exec('PING'), {'PING': [True, 'PONG']}) + def test_ping(self): + self.assertEqual(self.exec('PING'), {'PING': [True, 'PONG']}) - def test_multiple_messages(self): - key = self.clean_key() - n = 100 - for i in range(n): - lpush_response = self.exec('LPUSH', key, f'value-{i}') - self.assertEqual(lpush_response, {'LPUSH': i + 1}) - self.assertEqual(self.exec('LLEN', key), {'LLEN': n}) + def test_multiple_messages(self): + key = self.clean_key() + n = 100 + for i in range(n): + lpush_response = self.exec('LPUSH', key, f'value-{i}') + self.assertEqual(lpush_response, {'LPUSH': i + 1}) + self.assertEqual(self.exec('LLEN', key), {'LLEN': n}) class TestRaw(TestWebdis): - def format(self): - return 'raw' + def format(self): + return 'raw' - def serialize(self, cmd, *args): - buffer = f"*{1 + len(args)}\r\n${len(cmd)}\r\n{cmd}\r\n" - for arg in args: - buffer += f"${len(arg)}\r\n{arg}\r\n" - return buffer + def serialize(self, cmd, *args): + buffer = f"*{1 + len(args)}\r\n${len(cmd)}\r\n{cmd}\r\n" + for arg in args: + buffer += f"${len(arg)}\r\n{arg}\r\n" + return buffer - def deserialize(self, response): - return response # we'll just assert using the raw protocol + def deserialize(self, response): + return response # we'll just assert using the raw protocol - def test_ping(self): - self.assertEqual(self.exec('PING'), "+PONG\r\n") + def test_ping(self): + self.assertEqual(self.exec('PING'), "+PONG\r\n") - def test_get_set(self): - key = self.clean_key() - value = str(uuid.uuid4()) - not_found_response = self.exec('GET', key) - self.assertEqual(not_found_response, "$-1\r\n") # Redis protocol response for "not found" - set_response = self.exec('SET', key, value) - self.assertEqual(set_response, "+OK\r\n") - get_response = self.exec('GET', key) - self.assertEqual(get_response, f"${len(value)}\r\n{value}\r\n") + def test_get_set(self): + key = self.clean_key() + value = str(uuid.uuid4()) + not_found_response = self.exec('GET', key) + self.assertEqual(not_found_response, "$-1\r\n") # Redis protocol response for "not found" + set_response = self.exec('SET', key, value) + self.assertEqual(set_response, "+OK\r\n") + get_response = self.exec('GET', key) + self.assertEqual(get_response, f"${len(value)}\r\n{value}\r\n") @unittest.skipIf(os.getenv('PUBSUB') != '1', "pub-sub test fail due to invalid ordering") class TestPubSub(unittest.TestCase): - def setUp(self): - self.publisher = connect('json') - self.subscriber = connect('json') - - def tearDown(self): - self.publisher.close() - self.subscriber.close() - - def serialize(self, cmd, *args): - return json.dumps([cmd] + list(args)) - - def deserialize(self, response): - return json.loads(response) - - def test_publish_subscribe(self): - channel_count = 2 - message_count_per_channel = 8 - channels = list(str(uuid.uuid4()) for i in range(channel_count)) - - # subscribe to all channels - sub_count = 0 - for channel in channels: - self.subscriber.send(self.serialize('SUBSCRIBE', channel)) - sub_response = self.deserialize(self.subscriber.recv()) - sub_count += 1 - self.assertEqual(sub_response, {'SUBSCRIBE': ['subscribe', channel, sub_count]}) - - # send messages to all channels - prefix = 'message-' - for i in range(message_count_per_channel): - for channel in channels: - message = f'{prefix}{i}' - self.publisher.send(self.serialize('PUBLISH', channel, message)) - self.deserialize(self.publisher.recv()) - - received_per_channel = dict((channel, []) for channel in channels) - for j in range(channel_count * message_count_per_channel): - received = self.deserialize(self.subscriber.recv()) - # expected: {'SUBSCRIBE': ['message', $channel, $message]} - self.assertTrue(received, 'SUBSCRIBE' in received) - sub_contents = received['SUBSCRIBE'] - self.assertEqual(len(sub_contents), 3) - - self.assertEqual(sub_contents[0], 'message') # first element is the message type, here a push - channel = sub_contents[1] - self.assertTrue(channel in channels) # second is the channel - received_per_channel[channel].append(sub_contents[2]) # third, add to list of messages received for this channel - - # unsubscribe from all channels - subs_remaining = channel_count - for channel in channels: - self.subscriber.send(self.serialize('UNSUBSCRIBE', channel)) - subs_remaining -= 1 - unsub_response = self.deserialize(self.subscriber.recv()) - self.assertEqual(unsub_response, {'UNSUBSCRIBE': ['unsubscribe', channel, subs_remaining]}) - - # check that we received all messages - for channel in channels: - self.assertEqual(len(received_per_channel[channel]), message_count_per_channel) - - # check that we received them *in order* - for i in range(message_count_per_channel): - for channel in channels: - expected = f'{prefix}{i}' - self.assertEqual(received_per_channel[channel][i], expected, - f'In {channel}: expected at offset {i} was "{expected}", actual was: "{received_per_channel[channel][i]}"') + def setUp(self): + self.publisher = connect('json') + self.subscriber = connect('json') + + def tearDown(self): + self.publisher.close() + self.subscriber.close() + + def serialize(self, cmd, *args): + return json.dumps([cmd] + list(args)) + + def deserialize(self, response): + return json.loads(response) + + def test_publish_subscribe(self): + channel_count = 2 + message_count_per_channel = 8 + channels = list(str(uuid.uuid4()) for i in range(channel_count)) + + # subscribe to all channels + sub_count = 0 + for channel in channels: + self.subscriber.send(self.serialize('SUBSCRIBE', channel)) + sub_response = self.deserialize(self.subscriber.recv()) + sub_count += 1 + self.assertEqual(sub_response, {'SUBSCRIBE': ['subscribe', channel, sub_count]}) + + # send messages to all channels + prefix = 'message-' + for i in range(message_count_per_channel): + for channel in channels: + message = f'{prefix}{i}' + self.publisher.send(self.serialize('PUBLISH', channel, message)) + self.deserialize(self.publisher.recv()) + + received_per_channel = dict((channel, []) for channel in channels) + for j in range(channel_count * message_count_per_channel): + received = self.deserialize(self.subscriber.recv()) + # expected: {'SUBSCRIBE': ['message', $channel, $message]} + self.assertTrue(received, 'SUBSCRIBE' in received) + sub_contents = received['SUBSCRIBE'] + self.assertEqual(len(sub_contents), 3) + + self.assertEqual(sub_contents[0], 'message') # first element is the message type, here a push + channel = sub_contents[1] + self.assertTrue(channel in channels) # second is the channel + received_per_channel[channel].append( + sub_contents[2]) # third, add to list of messages received for this channel + + # unsubscribe from all channels + subs_remaining = channel_count + for channel in channels: + self.subscriber.send(self.serialize('UNSUBSCRIBE', channel)) + subs_remaining -= 1 + unsub_response = self.deserialize(self.subscriber.recv()) + self.assertEqual(unsub_response, {'UNSUBSCRIBE': ['unsubscribe', channel, subs_remaining]}) + + # check that we received all messages + for channel in channels: + self.assertEqual(len(received_per_channel[channel]), message_count_per_channel) + + # check that we received them *in order* + for i in range(message_count_per_channel): + for channel in channels: + expected = f'{prefix}{i}' + self.assertEqual(received_per_channel[channel][i], expected, + f'In {channel}: expected at offset {i} was "{expected}", actual was: "{received_per_channel[channel][i]}"') if __name__ == '__main__': - unittest.main() + unittest.main() From 1cbffb63c988d34eb24e9009c4b7c7c0f32fb447 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Mon, 2 Aug 2021 06:53:14 -0700 Subject: [PATCH 20/25] Re-enable pub/sub test in ws-tests --- tests/ws-tests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/ws-tests.py b/tests/ws-tests.py index 85ef909..459ab7a 100755 --- a/tests/ws-tests.py +++ b/tests/ws-tests.py @@ -97,7 +97,6 @@ class TestRaw(TestWebdis): self.assertEqual(get_response, f"${len(value)}\r\n{value}\r\n") -@unittest.skipIf(os.getenv('PUBSUB') != '1', "pub-sub test fail due to invalid ordering") class TestPubSub(unittest.TestCase): def setUp(self): self.publisher = connect('json') From 7f09680c862d874d276f3641e45f6f12f935d6b6 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Mon, 2 Aug 2021 06:58:43 -0700 Subject: [PATCH 21/25] Fix ResourceWarning in limits.py --- tests/limits.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/limits.py b/tests/limits.py index 0e2c2d5..82b09c9 100755 --- a/tests/limits.py +++ b/tests/limits.py @@ -12,6 +12,9 @@ class BlockingSocket: self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.s.setblocking(True) self.s.connect((HOST, PORT)) + + def __del__(self): + self.s.close() def recv(self): out = b"" From e213af322606b751a41699aba0296a6e18109ec4 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Mon, 2 Aug 2021 07:54:32 -0700 Subject: [PATCH 22/25] Fix for WebSocket payload length using 8 bytes The 8-byte conversion macros were incorrect, and could be replaced with standard methods instead. This also adds a test to cover this case. --- src/websocket.c | 25 ++++++++++--------------- tests/ws-tests.py | 24 ++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/src/websocket.c b/src/websocket.c index edcb9ff..57f47e2 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -29,18 +29,6 @@ ws_schedule_write(struct ws_client *ws); * A copy is available at http://www.rfc-editor.org/rfc/rfc6455.txt */ -/* custom 64-bit encoding functions to avoid portability issues */ -#define webdis_ntohl64(p) \ - ((((uint64_t)((p)[0])) << 0) + (((uint64_t)((p)[1])) << 8) +\ - (((uint64_t)((p)[2])) << 16) + (((uint64_t)((p)[3])) << 24) +\ - (((uint64_t)((p)[4])) << 32) + (((uint64_t)((p)[5])) << 40) +\ - (((uint64_t)((p)[6])) << 48) + (((uint64_t)((p)[7])) << 56)) - -#define webdis_htonl64(p) {\ - (char)(((p & ((uint64_t)0xff << 0)) >> 0) & 0xff), (char)(((p & ((uint64_t)0xff << 8)) >> 8) & 0xff), \ - (char)(((p & ((uint64_t)0xff << 16)) >> 16) & 0xff), (char)(((p & ((uint64_t)0xff << 24)) >> 24) & 0xff), \ - (char)(((p & ((uint64_t)0xff << 32)) >> 32) & 0xff), (char)(((p & ((uint64_t)0xff << 40)) >> 40) & 0xff), \ - (char)(((p & ((uint64_t)0xff << 48)) >> 48) & 0xff), (char)(((p & ((uint64_t)0xff << 56)) >> 56) & 0xff) } static int ws_compute_handshake(struct http_client *c, char *out, size_t *out_sz) { @@ -411,7 +399,7 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) { p = frame + 4 + (has_mask ? 4 : 0); if(has_mask) memcpy(&mask, frame + 4, sizeof(mask)); } else if(len == 127) { - len = webdis_ntohl64(frame+2); + len = ntohll(*(uint64_t*)(frame+2)); p = frame + 10 + (has_mask ? 4 : 0); if(has_mask) memcpy(&mask, frame + 10, sizeof(mask)); } else { @@ -532,7 +520,9 @@ ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type frame_type, memcpy(frame + 4, p, sz); frame_sz = sz + 4; } else { /* sz > 65536 */ - char sz64[8] = webdis_htonl64(sz); + uint64_t sz_be = htonll(sz); + char sz64[8]; + memcpy(sz64, &sz_be, 8); frame[1] = 127; memcpy(frame + 2, sz64, 8); memcpy(frame + 10, p, sz); @@ -577,7 +567,12 @@ ws_can_read(int fd, short event, void *p) { } else if(ws->close_after_events) { ws_close_if_able(ws); } else { - ws_process_read_data(ws, NULL); + enum ws_state state = ws_process_read_data(ws, NULL); + if(state == WS_READING) { /* need more data, schedule new read */ + ws_monitor_input(ws); + } else if(state == WS_ERROR) { + ws_close_if_able(ws); + } } } diff --git a/tests/ws-tests.py b/tests/ws-tests.py index 459ab7a..f463190 100755 --- a/tests/ws-tests.py +++ b/tests/ws-tests.py @@ -167,5 +167,29 @@ class TestPubSub(unittest.TestCase): f'In {channel}: expected at offset {i} was "{expected}", actual was: "{received_per_channel[channel][i]}"') +class TestFrameSizes(TestWebdis): + def format(self): + return 'json' + + def serialize(self, cmd, *args): + return json.dumps([cmd] + list(args)) + + def deserialize(self, response): + return json.loads(response) + + def test_length_126(self): + key = str(uuid.uuid4()) + value = 'A' * 1024 # this will require 2 bytes to encode the length + self.assertEqual(self.exec('SET', key, value), {'SET': [True, 'OK']}) + self.exec('DEL', key) + + def test_length_127(self): + key = str(uuid.uuid4()) + value = 'A' * (2 ** 18) # this will require more than 2 bytes to encode the length (actually using 8) + self.assertEqual(self.exec('SET', key, value), {'SET': [True, 'OK']}) + self.exec('DEL', key) + + + if __name__ == '__main__': unittest.main() From 3be189b527f23bdf2d499d7ea50dfa1a75caf452 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Mon, 2 Aug 2021 09:47:46 -0700 Subject: [PATCH 23/25] Use macro for htonll/ntohll for portability I realized that these functions are not standard, so this is bringing macros back for the 64-bit transforms. --- src/websocket.c | 12 ++++++++++-- tests/ws-tests.py | 11 +++++------ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/websocket.c b/src/websocket.c index 57f47e2..2d85ce8 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -28,6 +28,13 @@ ws_schedule_write(struct ws_client *ws); * This code uses the WebSocket specification from RFC 6455. * A copy is available at http://www.rfc-editor.org/rfc/rfc6455.txt */ +#if __BIG_ENDIAN__ +# define webdis_htonll(x) (x) +# define webdis_ntohll(x) (x) +#else +# define webdis_htonll(x) (((uint64_t)htonl((x) & 0xFFFFFFFF) << 32) | htonl((x) >> 32)) +# define webdis_ntohll(x) (((uint64_t)ntohl((x) & 0xFFFFFFFF) << 32) | ntohl((x) >> 32)) +#endif static int ws_compute_handshake(struct http_client *c, char *out, size_t *out_sz) { @@ -399,7 +406,8 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) { p = frame + 4 + (has_mask ? 4 : 0); if(has_mask) memcpy(&mask, frame + 4, sizeof(mask)); } else if(len == 127) { - len = ntohll(*(uint64_t*)(frame+2)); + uint64_t sz64 = *((uint64_t*)(frame+2)); + len = webdis_ntohll(sz64); p = frame + 10 + (has_mask ? 4 : 0); if(has_mask) memcpy(&mask, frame + 10, sizeof(mask)); } else { @@ -520,7 +528,7 @@ ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type frame_type, memcpy(frame + 4, p, sz); frame_sz = sz + 4; } else { /* sz > 65536 */ - uint64_t sz_be = htonll(sz); + uint64_t sz_be = webdis_htonll(sz); /* big endian */ char sz64[8]; memcpy(sz64, &sz_be, 8); frame[1] = 127; diff --git a/tests/ws-tests.py b/tests/ws-tests.py index f463190..cb038b2 100755 --- a/tests/ws-tests.py +++ b/tests/ws-tests.py @@ -178,18 +178,17 @@ class TestFrameSizes(TestWebdis): return json.loads(response) def test_length_126(self): - key = str(uuid.uuid4()) - value = 'A' * 1024 # this will require 2 bytes to encode the length - self.assertEqual(self.exec('SET', key, value), {'SET': [True, 'OK']}) - self.exec('DEL', key) + self.validate_set_get('A' * 1024) # this will require 2 bytes to encode the length def test_length_127(self): + self.validate_set_get('A' * (2 ** 18)) # this will require more than 2 bytes to encode the length (actually using 8) + + def validate_set_get(self, value): key = str(uuid.uuid4()) - value = 'A' * (2 ** 18) # this will require more than 2 bytes to encode the length (actually using 8) self.assertEqual(self.exec('SET', key, value), {'SET': [True, 'OK']}) + self.assertEqual(self.exec('GET', key), {'GET': value}) self.exec('DEL', key) - if __name__ == '__main__': unittest.main() From 33b2923b3a230fa5aa1c69a254416794dded5f6d Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Mon, 2 Aug 2021 10:44:17 -0700 Subject: [PATCH 24/25] Make sure to reserve enough space for large frames Add description of header and increase header size from 8 to 14 bytes. --- src/websocket.c | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/websocket.c b/src/websocket.c index 2d85ce8..ad8d489 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -504,7 +504,13 @@ ws_process_read_data(struct ws_client *ws, unsigned int *out_processed) { int ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type frame_type, const char *p, size_t sz) { - char *frame = malloc(sz + 8); /* create frame by prepending header */ + /* we can have as much as 14 bytes in the header: + * 1 byte for 4 flag bits + 4 frame type bits + * 1 byte for the payload length indicator + * 8 bytes for the size of the payload (at most) + * 4 bytes for the masking key (if present) + */ + char *frame = malloc(sz + 14); /* create frame by prepending header */ size_t frame_sz = 0; if(frame == NULL) return -1; From a8612e846e3034d242ac2cc009029caa3f82a4d7 Mon Sep 17 00:00:00 2001 From: Jessie Murray Date: Mon, 2 Aug 2021 20:50:12 -0700 Subject: [PATCH 25/25] Close fd if needed in ws_client_free --- src/websocket.c | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/websocket.c b/src/websocket.c index ad8d489..e307de2 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -121,7 +121,10 @@ ws_client_free(struct ws_client *ws) { pool_free_context(ws->ac); /* could trigger a cb via format_send_error */ struct http_client *c = ws->http_client; - if(c) c->ws = NULL; /* detach if needed */ + if(c) { + close(c->fd); + c->ws = NULL; /* detach if needed */ + } evbuffer_free(ws->rbuf); evbuffer_free(ws->wbuf); if(ws->cmd) {