diff --git a/src/websocket.c b/src/websocket.c index 57f47e2..2d85ce8 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -28,6 +28,13 @@ ws_schedule_write(struct ws_client *ws); * This code uses the WebSocket specification from RFC 6455. * A copy is available at http://www.rfc-editor.org/rfc/rfc6455.txt */ +#if __BIG_ENDIAN__ +# define webdis_htonll(x) (x) +# define webdis_ntohll(x) (x) +#else +# define webdis_htonll(x) (((uint64_t)htonl((x) & 0xFFFFFFFF) << 32) | htonl((x) >> 32)) +# define webdis_ntohll(x) (((uint64_t)ntohl((x) & 0xFFFFFFFF) << 32) | ntohl((x) >> 32)) +#endif static int ws_compute_handshake(struct http_client *c, char *out, size_t *out_sz) { @@ -399,7 +406,8 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) { p = frame + 4 + (has_mask ? 4 : 0); if(has_mask) memcpy(&mask, frame + 4, sizeof(mask)); } else if(len == 127) { - len = ntohll(*(uint64_t*)(frame+2)); + uint64_t sz64 = *((uint64_t*)(frame+2)); + len = webdis_ntohll(sz64); p = frame + 10 + (has_mask ? 4 : 0); if(has_mask) memcpy(&mask, frame + 10, sizeof(mask)); } else { @@ -520,7 +528,7 @@ ws_frame_and_send_response(struct ws_client *ws, enum ws_frame_type frame_type, memcpy(frame + 4, p, sz); frame_sz = sz + 4; } else { /* sz > 65536 */ - uint64_t sz_be = htonll(sz); + uint64_t sz_be = webdis_htonll(sz); /* big endian */ char sz64[8]; memcpy(sz64, &sz_be, 8); frame[1] = 127; diff --git a/tests/ws-tests.py b/tests/ws-tests.py index f463190..cb038b2 100755 --- a/tests/ws-tests.py +++ b/tests/ws-tests.py @@ -178,18 +178,17 @@ class TestFrameSizes(TestWebdis): return json.loads(response) def test_length_126(self): - key = str(uuid.uuid4()) - value = 'A' * 1024 # this will require 2 bytes to encode the length - self.assertEqual(self.exec('SET', key, value), {'SET': [True, 'OK']}) - self.exec('DEL', key) + self.validate_set_get('A' * 1024) # this will require 2 bytes to encode the length def test_length_127(self): + self.validate_set_get('A' * (2 ** 18)) # this will require more than 2 bytes to encode the length (actually using 8) + + def validate_set_get(self, value): key = str(uuid.uuid4()) - value = 'A' * (2 ** 18) # this will require more than 2 bytes to encode the length (actually using 8) self.assertEqual(self.exec('SET', key, value), {'SET': [True, 'OK']}) + self.assertEqual(self.exec('GET', key), {'GET': value}) self.exec('DEL', key) - if __name__ == '__main__': unittest.main()