diff --git a/src/websocket.c b/src/websocket.c index edcb9ff..57f47e2 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -29,18 +29,6 @@ ws_schedule_write(struct ws_client *ws); * A copy is available at http://www.rfc-editor.org/rfc/rfc6455.txt */ -/* custom 64-bit encoding functions to avoid portability issues */ -#define webdis_ntohl64(p) \ - ((((uint64_t)((p)[0])) << 0) + (((uint64_t)((p)[1])) << 8) +\ - (((uint64_t)((p)[2])) << 16) + (((uint64_t)((p)[3])) << 24) +\ - (((uint64_t)((p)[4])) << 32) + (((uint64_t)((p)[5])) << 40) +\ - (((uint64_t)((p)[6])) << 48) + (((uint64_t)((p)[7])) << 56)) - -#define webdis_htonl64(p) {\ - (char)(((p & ((uint64_t)0xff << 0)) >> 0) & 0xff), (char)(((p & ((uint64_t)0xff << 8)) >> 8) & 0xff), \ - (char)(((p & ((uint64_t)0xff << 16)) >> 16) & 0xff), (char)(((p & ((uint64_t)0xff << 24)) >> 24) & 0xff), \ - (char)(((p & ((uint64_t)0xff << 32)) >> 32) & 0xff), (char)(((p & ((uint64_t)0xff << 40)) >> 40) & 0xff), \ - (char)(((p & ((uint64_t)0xff << 48)) >> 48) & 0xff), (char)(((p & ((uint64_t)0xff << 56)) >> 56) & 0xff) } static int ws_compute_handshake(struct http_client *c, char *out, size_t *out_sz) { @@ -411,7 +399,7 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) { p = frame + 4 + (has_mask ? 4 : 0); if(has_mask) memcpy(&mask, frame + 4, sizeof(mask)); } else if(len == 127) { - len = webdis_ntohl64(frame+2); + len = ntohll(*(uint64_t*)(frame+2)); p = frame + 10 + (has_mask ? 4 : 0); if(has_mask) memcpy(&mask, frame + 10, sizeof(mask)); } else { @@ -532,7 +520,9 @@ ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type frame_type, memcpy(frame + 4, p, sz); frame_sz = sz + 4; } else { /* sz > 65536 */ - char sz64[8] = webdis_htonl64(sz); + uint64_t sz_be = htonll(sz); + char sz64[8]; + memcpy(sz64, &sz_be, 8); frame[1] = 127; memcpy(frame + 2, sz64, 8); memcpy(frame + 10, p, sz); @@ -577,7 +567,12 @@ ws_can_read(int fd, short event, void *p) { } else if(ws->close_after_events) { ws_close_if_able(ws); } else { - ws_process_read_data(ws, NULL); + enum ws_state state = ws_process_read_data(ws, NULL); + if(state == WS_READING) { /* need more data, schedule new read */ + ws_monitor_input(ws); + } else if(state == WS_ERROR) { + ws_close_if_able(ws); + } } } diff --git a/tests/ws-tests.py b/tests/ws-tests.py index 459ab7a..f463190 100755 --- a/tests/ws-tests.py +++ b/tests/ws-tests.py @@ -167,5 +167,29 @@ class TestPubSub(unittest.TestCase): f'In {channel}: expected at offset {i} was "{expected}", actual was: "{received_per_channel[channel][i]}"') +class TestFrameSizes(TestWebdis): + def format(self): + return 'json' + + def serialize(self, cmd, *args): + return json.dumps([cmd] + list(args)) + + def deserialize(self, response): + return json.loads(response) + + def test_length_126(self): + key = str(uuid.uuid4()) + value = 'A' * 1024 # this will require 2 bytes to encode the length + self.assertEqual(self.exec('SET', key, value), {'SET': [True, 'OK']}) + self.exec('DEL', key) + + def test_length_127(self): + key = str(uuid.uuid4()) + value = 'A' * (2 ** 18) # this will require more than 2 bytes to encode the length (actually using 8) + self.assertEqual(self.exec('SET', key, value), {'SET': [True, 'OK']}) + self.exec('DEL', key) + + + if __name__ == '__main__': unittest.main()