From 73f29055c1c1224c500885109a55cbfe77f65aae Mon Sep 17 00:00:00 2001 From: Nicolas Favre-Felix Date: Sat, 22 Jan 2022 13:54:03 -0800 Subject: [PATCH] Improvements to ws_peek_data (by @majklik) Better handling of WS client frames, contributed in the comments of #212: * Reject unmasked frames as per RFC 6455 * Avoid unnecessary data copy from/to evbuffer * Remove conditions on has_mask 2 new tests cover this change: * minimal ping-pong with masked client frame, unmasked response * rejected unmasked client frame --- src/websocket.c | 57 +++++++++++++++-------- tests/ws-tests.py | 112 +++++++++++++++++++++++++++++++++++----------- 2 files changed, 124 insertions(+), 45 deletions(-) diff --git a/src/websocket.c b/src/websocket.c index 5d3efe0..498f6b3 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -379,7 +379,7 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) { /* parse frame and extract contents */ size_t sz = evbuffer_get_length(ws->rbuf); - if(sz < 8) { + if(sz < 2) { return WS_READING; /* need more data */ } /* copy into "frame" to process it */ @@ -387,7 +387,7 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) { if(!frame) { return WS_ERROR; } - int rem_ret = evbuffer_remove(ws->rbuf, frame, sz); + int rem_ret = evbuffer_copyout(ws->rbuf, frame, sz); /* copy into frame but keep in rbuf */ if(rem_ret < 0) { free(frame); return WS_ERROR; @@ -397,32 +397,54 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) { frame_type = frame[0] & 0x0F; /* lower 4 bits of first byte */ has_mask = frame[1] & 0x80 ? 1:0; + if(!has_mask) { + /* a client MUST mask all frames that it sends to the server (RFC6455, 5.1. Overview) */ + ws->close_after_events = 1; + const char close_code_reason[] = "\x03\xeaReceived a frame without a mask from the client (violates RFC6455, 5.1. Overview)."; /* 0x03,0xEA = 1002 - protocol error */ + ws_frame_and_send_response(ws, WS_CONNECTION_CLOSE, close_code_reason, sizeof(close_code_reason)-1); + free(frame); + return WS_ERROR; + } + /* get payload length */ len = frame[1] & 0x7f; /* remove leftmost bit */ + + /* checking that the copyout frame contains the minimum data needed to determine the true length and mask in next step */ + size_t min_sz = 6; /* 2 bytes for flags and opcode + 4 bytes for mask */ + if(len == 126) { + min_sz += sizeof(uint16_t); + } else if(len == 127) { + min_sz += sizeof(uint64_t); + } + if(sz < min_sz) { /* not enough data */ + free(frame); + return WS_READING; + } + + /* determine payload size (RFC 6455, section 5.2) */ if(len <= 125) { /* data starts right after the mask */ - p = frame + 2 + (has_mask ? 4 : 0); - if(has_mask) memcpy(&mask, frame + 2, sizeof(mask)); - } else if(len == 126) { + p = frame + 6; + memcpy(&mask, frame + 2, sizeof(mask)); + } else if(len == 126) { /* size is stored in 16 bits after mask */ uint16_t sz16; memcpy(&sz16, frame + 2, sizeof(uint16_t)); len = ntohs(sz16); - p = frame + 4 + (has_mask ? 4 : 0); - if(has_mask) memcpy(&mask, frame + 4, sizeof(mask)); - } else if(len == 127) { + p = frame + 6 + sizeof(uint16_t); + memcpy(&mask, frame + 4, sizeof(mask)); + } else if(len == 127) { /* size is stored in 64 bits after mask */ uint64_t sz64 = *((uint64_t*)(frame+2)); len = webdis_ntohll(sz64); - p = frame + 10 + (has_mask ? 4 : 0); - if(has_mask) memcpy(&mask, frame + 10, sizeof(mask)); + p = frame + 6 + sizeof(uint64_t); + memcpy(&mask, frame + 10, sizeof(mask)); } else { free(frame); return WS_ERROR; } - /* we now have the (possibly masked) data starting in p, and its length. */ + /* we now have the masked data starting in p, and its length. */ if(len > sz - (p - frame)) { /* not enough data */ - int add_ret = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */ free(frame); - return add_ret < 0 ? WS_ERROR : WS_READING; + return WS_READING; } int ev_copy = 0; @@ -435,18 +457,15 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) { *out_msg = msg; /* attach for it to be freed by caller */ /* create new ws_msg object holding what we read */ - int add_ret = ws_msg_add(msg, p, len, has_mask ? mask : NULL); - if(!add_ret) { + int add_ret = ws_msg_add(msg, p, len, mask); + if(add_ret < 0) { free(frame); return WS_ERROR; } size_t processed_sz = len + (p - frame); /* length of data + header bytes between frame start and payload */ msg->total_sz += processed_sz; - - ev_copy = evbuffer_prepend(ws->rbuf, frame + len, sz - processed_sz); /* remove processed data */ - } else { /* we're just peeking */ - ev_copy = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */ + ev_copy = evbuffer_drain(ws->rbuf, processed_sz); /* remove processed data from evbuffer */ } free(frame); diff --git a/tests/ws-tests.py b/tests/ws-tests.py index d5fc929..43549b7 100755 --- a/tests/ws-tests.py +++ b/tests/ws-tests.py @@ -1,8 +1,10 @@ #!/usr/bin/env python3 import abc +from dataclasses import dataclass import json import os +import struct import unittest import uuid from websocket import create_connection @@ -194,37 +196,59 @@ class TestFrameSizes(TestWebdis): self.exec('DEL', key) -class TestConnectDisconnect(TestWebdis): +@dataclass +class WebsocketHandshake: + b64_key: bytes + b64_hash: bytes + + +class TestRawConnection(TestWebdis): + """Base class for WS tests using sockets""" + + def create_random_handshake(self) -> WebsocketHandshake: + # Build WS handshake request + raw_key = os.urandom(16) + b64_key = base64.b64encode(raw_key) + magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + expected_raw = b64_key + magic + hash = hashlib.sha1(expected_raw).digest() + b64_hash = base64.b64encode(hash) + return WebsocketHandshake(b64_key, b64_hash) + + def format_handshake_request(self, host: str, port: int, handshake: WebsocketHandshake) -> str: + return ("GET /.json HTTP/1.1\r\n" + \ + "Host: %s:%d\r\n" + \ + "Connection: Upgrade\r\n" + \ + "Upgrade: WebSocket\r\n" + \ + "Origin: http://%s:%d\r\n" + \ + "Sec-WebSocket-Key: %s\r\n" + \ + "\r\n") % (host, port, host, port, handshake.b64_key.decode('utf-8')) + + def send_handshake_request(self, sock: socket.socket, request: str, handshake: WebsocketHandshake) -> None: + """Send handshake and validate response""" + sock.send(request.encode('utf-8')) + response = sock.recv(1024) + lines = response.decode('utf-8').split('\r\n') + self.assertEqual(lines[0], 'HTTP/1.1 101 Switching Protocols') + self.assertTrue('Sec-WebSocket-Accept: %s' % handshake.b64_hash.decode('utf-8') in lines) + + def connect_and_handshake(self, sock: socket.socket) -> None: + global host, port + sock.connect((host, port)) + + # establish WS connection with handshake + handshake = self.create_random_handshake() + request = self.format_handshake_request(host, port, handshake) + self.send_handshake_request(sock, request, handshake) + + +class TestConnectDisconnect(TestRawConnection): """Test for issue #209. A client connects, receives their handshake, disconnects""" def test_connect_handshake_disconnect(self): # Connect to Webdis with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - global host, port - sock.connect((host, port)) - - # Build WS handshake request - raw_key = os.urandom(16) - b64_key = base64.b64encode(raw_key) - magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' - expected_raw = b64_key + magic - hash = hashlib.sha1(expected_raw).digest() - b64_hash = base64.b64encode(hash) - - request = ("GET /.json HTTP/1.1\r\n" + \ - "Host: %s:%d\r\n" + \ - "Connection: Upgrade\r\n" + \ - "Upgrade: WebSocket\r\n" + \ - "Origin: http://%s:%d\r\n" + \ - "Sec-WebSocket-Key: %s\r\n" + \ - "\r\n") % (host, port, host, port, b64_key.decode('utf-8')) - - # Send handshake and validate response - sock.send(request.encode('utf-8')) - response = sock.recv(1024) - lines = response.decode('utf-8').split('\r\n') - self.assertEqual(lines[0], 'HTTP/1.1 101 Switching Protocols') - self.assertTrue('Sec-WebSocket-Accept: %s' % b64_hash.decode('utf-8') in lines) + self.connect_and_handshake(sock) # send FIN frame. Format: # 4 bits flags (0x8 = Fin: true) @@ -242,6 +266,42 @@ class TestConnectDisconnect(TestWebdis): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.connect((host, port)) +class TestMinimalPingPong(TestRawConnection): + """Test that we receive a PONG response to a 6-byte PING, the smallest possible frame""" + + def test_minimal_ping_pong(self): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + self.connect_and_handshake(sock) + + # send PING frame. Format: + ping_frame = bytes([0b1000_1001, 0b1000_0000, 0x12, 0x34, 0x56, 0x78]) # 0x09 0x80: FIN=true, opcode=ping(9), mask=true, payload_length=0 + sock.send(ping_frame) + + # receive PONG frame. + out = sock.recv(4) + self.assertEqual(len(out), 2) # should be just 2 bytes, the FIN and pong opcode + self.assertEqual(out[0], 0b1000_1010) # FIN=true, opcode=pong(0xA) + self.assertEqual(out[1], 0x00) # mask=false, payload_length=0 + +class TestUnmaskedDataFrame(TestRawConnection): + """Test that we reject unmasked frames""" + + def test_unmasked_redis_ping(self): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + self.connect_and_handshake(sock) + + # send unmasked frame querying Redis with a "PING" command. Format: + payload = b'["PING"]' + unmasked_frame = bytes([0b10000001, len(payload)]) + payload # 0x81 0x08: FIN=true, opcode=text_frame, mask=false, payload_length=8 + sock.send(unmasked_frame) + + # receive error frame. + out = sock.recv(200) + self.assertEqual(out[0:1], b'\x88') # 0x88: FIN=true, opcode=close, mask=false + self.assertEqual(out[1:2], bytes([len(out) - 2])) # payload length, minus the 2 bytes of the header + self.assertEqual(out[2:4], struct.pack('>h', 1002)) # error code 1002: protocol error + + if __name__ == '__main__': unittest.main()