Improvements to ws_peek_data (by @majklik)

Better handling of WS client frames, contributed in the comments of #212:
* Reject unmasked frames as per RFC 6455
* Avoid unnecessary data copy from/to evbuffer
* Remove conditions on has_mask

2 new tests cover this change:
* minimal ping-pong with masked client frame, unmasked response
* rejected unmasked client frame
master
Nicolas Favre-Felix 3 years ago
parent d28dd3ec80
commit 73f29055c1
No known key found for this signature in database
GPG Key ID: C04E7AA8B6F73372

@ -379,7 +379,7 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) {
/* parse frame and extract contents */ /* parse frame and extract contents */
size_t sz = evbuffer_get_length(ws->rbuf); size_t sz = evbuffer_get_length(ws->rbuf);
if(sz < 8) { if(sz < 2) {
return WS_READING; /* need more data */ return WS_READING; /* need more data */
} }
/* copy into "frame" to process it */ /* copy into "frame" to process it */
@ -387,7 +387,7 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) {
if(!frame) { if(!frame) {
return WS_ERROR; return WS_ERROR;
} }
int rem_ret = evbuffer_remove(ws->rbuf, frame, sz); int rem_ret = evbuffer_copyout(ws->rbuf, frame, sz); /* copy into frame but keep in rbuf */
if(rem_ret < 0) { if(rem_ret < 0) {
free(frame); free(frame);
return WS_ERROR; return WS_ERROR;
@ -397,32 +397,54 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) {
frame_type = frame[0] & 0x0F; /* lower 4 bits of first byte */ frame_type = frame[0] & 0x0F; /* lower 4 bits of first byte */
has_mask = frame[1] & 0x80 ? 1:0; has_mask = frame[1] & 0x80 ? 1:0;
if(!has_mask) {
/* a client MUST mask all frames that it sends to the server (RFC6455, 5.1. Overview) */
ws->close_after_events = 1;
const char close_code_reason[] = "\x03\xeaReceived a frame without a mask from the client (violates RFC6455, 5.1. Overview)."; /* 0x03,0xEA = 1002 - protocol error */
ws_frame_and_send_response(ws, WS_CONNECTION_CLOSE, close_code_reason, sizeof(close_code_reason)-1);
free(frame);
return WS_ERROR;
}
/* get payload length */ /* get payload length */
len = frame[1] & 0x7f; /* remove leftmost bit */ len = frame[1] & 0x7f; /* remove leftmost bit */
/* checking that the copyout frame contains the minimum data needed to determine the true length and mask in next step */
size_t min_sz = 6; /* 2 bytes for flags and opcode + 4 bytes for mask */
if(len == 126) {
min_sz += sizeof(uint16_t);
} else if(len == 127) {
min_sz += sizeof(uint64_t);
}
if(sz < min_sz) { /* not enough data */
free(frame);
return WS_READING;
}
/* determine payload size (RFC 6455, section 5.2) */
if(len <= 125) { /* data starts right after the mask */ if(len <= 125) { /* data starts right after the mask */
p = frame + 2 + (has_mask ? 4 : 0); p = frame + 6;
if(has_mask) memcpy(&mask, frame + 2, sizeof(mask)); memcpy(&mask, frame + 2, sizeof(mask));
} else if(len == 126) { } else if(len == 126) { /* size is stored in 16 bits after mask */
uint16_t sz16; uint16_t sz16;
memcpy(&sz16, frame + 2, sizeof(uint16_t)); memcpy(&sz16, frame + 2, sizeof(uint16_t));
len = ntohs(sz16); len = ntohs(sz16);
p = frame + 4 + (has_mask ? 4 : 0); p = frame + 6 + sizeof(uint16_t);
if(has_mask) memcpy(&mask, frame + 4, sizeof(mask)); memcpy(&mask, frame + 4, sizeof(mask));
} else if(len == 127) { } else if(len == 127) { /* size is stored in 64 bits after mask */
uint64_t sz64 = *((uint64_t*)(frame+2)); uint64_t sz64 = *((uint64_t*)(frame+2));
len = webdis_ntohll(sz64); len = webdis_ntohll(sz64);
p = frame + 10 + (has_mask ? 4 : 0); p = frame + 6 + sizeof(uint64_t);
if(has_mask) memcpy(&mask, frame + 10, sizeof(mask)); memcpy(&mask, frame + 10, sizeof(mask));
} else { } else {
free(frame); free(frame);
return WS_ERROR; return WS_ERROR;
} }
/* we now have the (possibly masked) data starting in p, and its length. */ /* we now have the masked data starting in p, and its length. */
if(len > sz - (p - frame)) { /* not enough data */ if(len > sz - (p - frame)) { /* not enough data */
int add_ret = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */
free(frame); free(frame);
return add_ret < 0 ? WS_ERROR : WS_READING; return WS_READING;
} }
int ev_copy = 0; int ev_copy = 0;
@ -435,18 +457,15 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) {
*out_msg = msg; /* attach for it to be freed by caller */ *out_msg = msg; /* attach for it to be freed by caller */
/* create new ws_msg object holding what we read */ /* create new ws_msg object holding what we read */
int add_ret = ws_msg_add(msg, p, len, has_mask ? mask : NULL); int add_ret = ws_msg_add(msg, p, len, mask);
if(!add_ret) { if(add_ret < 0) {
free(frame); free(frame);
return WS_ERROR; return WS_ERROR;
} }
size_t processed_sz = len + (p - frame); /* length of data + header bytes between frame start and payload */ size_t processed_sz = len + (p - frame); /* length of data + header bytes between frame start and payload */
msg->total_sz += processed_sz; msg->total_sz += processed_sz;
ev_copy = evbuffer_drain(ws->rbuf, processed_sz); /* remove processed data from evbuffer */
ev_copy = evbuffer_prepend(ws->rbuf, frame + len, sz - processed_sz); /* remove processed data */
} else { /* we're just peeking */
ev_copy = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */
} }
free(frame); free(frame);

@ -1,8 +1,10 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import abc import abc
from dataclasses import dataclass
import json import json
import os import os
import struct
import unittest import unittest
import uuid import uuid
from websocket import create_connection from websocket import create_connection
@ -194,37 +196,59 @@ class TestFrameSizes(TestWebdis):
self.exec('DEL', key) self.exec('DEL', key)
class TestConnectDisconnect(TestWebdis): @dataclass
class WebsocketHandshake:
b64_key: bytes
b64_hash: bytes
class TestRawConnection(TestWebdis):
"""Base class for WS tests using sockets"""
def create_random_handshake(self) -> WebsocketHandshake:
# Build WS handshake request
raw_key = os.urandom(16)
b64_key = base64.b64encode(raw_key)
magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
expected_raw = b64_key + magic
hash = hashlib.sha1(expected_raw).digest()
b64_hash = base64.b64encode(hash)
return WebsocketHandshake(b64_key, b64_hash)
def format_handshake_request(self, host: str, port: int, handshake: WebsocketHandshake) -> str:
return ("GET /.json HTTP/1.1\r\n" + \
"Host: %s:%d\r\n" + \
"Connection: Upgrade\r\n" + \
"Upgrade: WebSocket\r\n" + \
"Origin: http://%s:%d\r\n" + \
"Sec-WebSocket-Key: %s\r\n" + \
"\r\n") % (host, port, host, port, handshake.b64_key.decode('utf-8'))
def send_handshake_request(self, sock: socket.socket, request: str, handshake: WebsocketHandshake) -> None:
"""Send handshake and validate response"""
sock.send(request.encode('utf-8'))
response = sock.recv(1024)
lines = response.decode('utf-8').split('\r\n')
self.assertEqual(lines[0], 'HTTP/1.1 101 Switching Protocols')
self.assertTrue('Sec-WebSocket-Accept: %s' % handshake.b64_hash.decode('utf-8') in lines)
def connect_and_handshake(self, sock: socket.socket) -> None:
global host, port
sock.connect((host, port))
# establish WS connection with handshake
handshake = self.create_random_handshake()
request = self.format_handshake_request(host, port, handshake)
self.send_handshake_request(sock, request, handshake)
class TestConnectDisconnect(TestRawConnection):
"""Test for issue #209. A client connects, receives their handshake, disconnects""" """Test for issue #209. A client connects, receives their handshake, disconnects"""
def test_connect_handshake_disconnect(self): def test_connect_handshake_disconnect(self):
# Connect to Webdis # Connect to Webdis
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
global host, port self.connect_and_handshake(sock)
sock.connect((host, port))
# Build WS handshake request
raw_key = os.urandom(16)
b64_key = base64.b64encode(raw_key)
magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
expected_raw = b64_key + magic
hash = hashlib.sha1(expected_raw).digest()
b64_hash = base64.b64encode(hash)
request = ("GET /.json HTTP/1.1\r\n" + \
"Host: %s:%d\r\n" + \
"Connection: Upgrade\r\n" + \
"Upgrade: WebSocket\r\n" + \
"Origin: http://%s:%d\r\n" + \
"Sec-WebSocket-Key: %s\r\n" + \
"\r\n") % (host, port, host, port, b64_key.decode('utf-8'))
# Send handshake and validate response
sock.send(request.encode('utf-8'))
response = sock.recv(1024)
lines = response.decode('utf-8').split('\r\n')
self.assertEqual(lines[0], 'HTTP/1.1 101 Switching Protocols')
self.assertTrue('Sec-WebSocket-Accept: %s' % b64_hash.decode('utf-8') in lines)
# send FIN frame. Format: # send FIN frame. Format:
# 4 bits flags (0x8 = Fin: true) # 4 bits flags (0x8 = Fin: true)
@ -242,6 +266,42 @@ class TestConnectDisconnect(TestWebdis):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.connect((host, port)) sock.connect((host, port))
class TestMinimalPingPong(TestRawConnection):
"""Test that we receive a PONG response to a 6-byte PING, the smallest possible frame"""
def test_minimal_ping_pong(self):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
self.connect_and_handshake(sock)
# send PING frame. Format:
ping_frame = bytes([0b1000_1001, 0b1000_0000, 0x12, 0x34, 0x56, 0x78]) # 0x09 0x80: FIN=true, opcode=ping(9), mask=true, payload_length=0
sock.send(ping_frame)
# receive PONG frame.
out = sock.recv(4)
self.assertEqual(len(out), 2) # should be just 2 bytes, the FIN and pong opcode
self.assertEqual(out[0], 0b1000_1010) # FIN=true, opcode=pong(0xA)
self.assertEqual(out[1], 0x00) # mask=false, payload_length=0
class TestUnmaskedDataFrame(TestRawConnection):
"""Test that we reject unmasked frames"""
def test_unmasked_redis_ping(self):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
self.connect_and_handshake(sock)
# send unmasked frame querying Redis with a "PING" command. Format:
payload = b'["PING"]'
unmasked_frame = bytes([0b10000001, len(payload)]) + payload # 0x81 0x08: FIN=true, opcode=text_frame, mask=false, payload_length=8
sock.send(unmasked_frame)
# receive error frame.
out = sock.recv(200)
self.assertEqual(out[0:1], b'\x88') # 0x88: FIN=true, opcode=close, mask=false
self.assertEqual(out[1:2], bytes([len(out) - 2])) # payload length, minus the 2 bytes of the header
self.assertEqual(out[2:4], struct.pack('>h', 1002)) # error code 1002: protocol error
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

Loading…
Cancel
Save