Improvements to ws_peek_data (by @majklik)

Better handling of WS client frames, contributed in the comments of #212:
* Reject unmasked frames as per RFC 6455
* Avoid unnecessary data copy from/to evbuffer
* Remove conditions on has_mask

2 new tests cover this change:
* minimal ping-pong with masked client frame, unmasked response
* rejected unmasked client frame
master
Nicolas Favre-Felix 3 years ago
parent d28dd3ec80
commit 73f29055c1
No known key found for this signature in database
GPG Key ID: C04E7AA8B6F73372

@ -379,7 +379,7 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) {
/* parse frame and extract contents */
size_t sz = evbuffer_get_length(ws->rbuf);
if(sz < 8) {
if(sz < 2) {
return WS_READING; /* need more data */
}
/* copy into "frame" to process it */
@ -387,7 +387,7 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) {
if(!frame) {
return WS_ERROR;
}
int rem_ret = evbuffer_remove(ws->rbuf, frame, sz);
int rem_ret = evbuffer_copyout(ws->rbuf, frame, sz); /* copy into frame but keep in rbuf */
if(rem_ret < 0) {
free(frame);
return WS_ERROR;
@ -397,32 +397,54 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) {
frame_type = frame[0] & 0x0F; /* lower 4 bits of first byte */
has_mask = frame[1] & 0x80 ? 1:0;
if(!has_mask) {
/* a client MUST mask all frames that it sends to the server (RFC6455, 5.1. Overview) */
ws->close_after_events = 1;
const char close_code_reason[] = "\x03\xeaReceived a frame without a mask from the client (violates RFC6455, 5.1. Overview)."; /* 0x03,0xEA = 1002 - protocol error */
ws_frame_and_send_response(ws, WS_CONNECTION_CLOSE, close_code_reason, sizeof(close_code_reason)-1);
free(frame);
return WS_ERROR;
}
/* get payload length */
len = frame[1] & 0x7f; /* remove leftmost bit */
/* checking that the copyout frame contains the minimum data needed to determine the true length and mask in next step */
size_t min_sz = 6; /* 2 bytes for flags and opcode + 4 bytes for mask */
if(len == 126) {
min_sz += sizeof(uint16_t);
} else if(len == 127) {
min_sz += sizeof(uint64_t);
}
if(sz < min_sz) { /* not enough data */
free(frame);
return WS_READING;
}
/* determine payload size (RFC 6455, section 5.2) */
if(len <= 125) { /* data starts right after the mask */
p = frame + 2 + (has_mask ? 4 : 0);
if(has_mask) memcpy(&mask, frame + 2, sizeof(mask));
} else if(len == 126) {
p = frame + 6;
memcpy(&mask, frame + 2, sizeof(mask));
} else if(len == 126) { /* size is stored in 16 bits after mask */
uint16_t sz16;
memcpy(&sz16, frame + 2, sizeof(uint16_t));
len = ntohs(sz16);
p = frame + 4 + (has_mask ? 4 : 0);
if(has_mask) memcpy(&mask, frame + 4, sizeof(mask));
} else if(len == 127) {
p = frame + 6 + sizeof(uint16_t);
memcpy(&mask, frame + 4, sizeof(mask));
} else if(len == 127) { /* size is stored in 64 bits after mask */
uint64_t sz64 = *((uint64_t*)(frame+2));
len = webdis_ntohll(sz64);
p = frame + 10 + (has_mask ? 4 : 0);
if(has_mask) memcpy(&mask, frame + 10, sizeof(mask));
p = frame + 6 + sizeof(uint64_t);
memcpy(&mask, frame + 10, sizeof(mask));
} else {
free(frame);
return WS_ERROR;
}
/* we now have the (possibly masked) data starting in p, and its length. */
/* we now have the masked data starting in p, and its length. */
if(len > sz - (p - frame)) { /* not enough data */
int add_ret = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */
free(frame);
return add_ret < 0 ? WS_ERROR : WS_READING;
return WS_READING;
}
int ev_copy = 0;
@ -435,18 +457,15 @@ ws_peek_data(struct ws_client *ws, struct ws_msg **out_msg) {
*out_msg = msg; /* attach for it to be freed by caller */
/* create new ws_msg object holding what we read */
int add_ret = ws_msg_add(msg, p, len, has_mask ? mask : NULL);
if(!add_ret) {
int add_ret = ws_msg_add(msg, p, len, mask);
if(add_ret < 0) {
free(frame);
return WS_ERROR;
}
size_t processed_sz = len + (p - frame); /* length of data + header bytes between frame start and payload */
msg->total_sz += processed_sz;
ev_copy = evbuffer_prepend(ws->rbuf, frame + len, sz - processed_sz); /* remove processed data */
} else { /* we're just peeking */
ev_copy = evbuffer_prepend(ws->rbuf, frame, sz); /* put the whole frame back */
ev_copy = evbuffer_drain(ws->rbuf, processed_sz); /* remove processed data from evbuffer */
}
free(frame);

@ -1,8 +1,10 @@
#!/usr/bin/env python3
import abc
from dataclasses import dataclass
import json
import os
import struct
import unittest
import uuid
from websocket import create_connection
@ -194,15 +196,16 @@ class TestFrameSizes(TestWebdis):
self.exec('DEL', key)
class TestConnectDisconnect(TestWebdis):
"""Test for issue #209. A client connects, receives their handshake, disconnects"""
@dataclass
class WebsocketHandshake:
b64_key: bytes
b64_hash: bytes
def test_connect_handshake_disconnect(self):
# Connect to Webdis
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
global host, port
sock.connect((host, port))
class TestRawConnection(TestWebdis):
"""Base class for WS tests using sockets"""
def create_random_handshake(self) -> WebsocketHandshake:
# Build WS handshake request
raw_key = os.urandom(16)
b64_key = base64.b64encode(raw_key)
@ -210,21 +213,42 @@ class TestConnectDisconnect(TestWebdis):
expected_raw = b64_key + magic
hash = hashlib.sha1(expected_raw).digest()
b64_hash = base64.b64encode(hash)
return WebsocketHandshake(b64_key, b64_hash)
request = ("GET /.json HTTP/1.1\r\n" + \
def format_handshake_request(self, host: str, port: int, handshake: WebsocketHandshake) -> str:
return ("GET /.json HTTP/1.1\r\n" + \
"Host: %s:%d\r\n" + \
"Connection: Upgrade\r\n" + \
"Upgrade: WebSocket\r\n" + \
"Origin: http://%s:%d\r\n" + \
"Sec-WebSocket-Key: %s\r\n" + \
"\r\n") % (host, port, host, port, b64_key.decode('utf-8'))
"\r\n") % (host, port, host, port, handshake.b64_key.decode('utf-8'))
# Send handshake and validate response
def send_handshake_request(self, sock: socket.socket, request: str, handshake: WebsocketHandshake) -> None:
"""Send handshake and validate response"""
sock.send(request.encode('utf-8'))
response = sock.recv(1024)
lines = response.decode('utf-8').split('\r\n')
self.assertEqual(lines[0], 'HTTP/1.1 101 Switching Protocols')
self.assertTrue('Sec-WebSocket-Accept: %s' % b64_hash.decode('utf-8') in lines)
self.assertTrue('Sec-WebSocket-Accept: %s' % handshake.b64_hash.decode('utf-8') in lines)
def connect_and_handshake(self, sock: socket.socket) -> None:
global host, port
sock.connect((host, port))
# establish WS connection with handshake
handshake = self.create_random_handshake()
request = self.format_handshake_request(host, port, handshake)
self.send_handshake_request(sock, request, handshake)
class TestConnectDisconnect(TestRawConnection):
"""Test for issue #209. A client connects, receives their handshake, disconnects"""
def test_connect_handshake_disconnect(self):
# Connect to Webdis
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
self.connect_and_handshake(sock)
# send FIN frame. Format:
# 4 bits flags (0x8 = Fin: true)
@ -242,6 +266,42 @@ class TestConnectDisconnect(TestWebdis):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.connect((host, port))
class TestMinimalPingPong(TestRawConnection):
"""Test that we receive a PONG response to a 6-byte PING, the smallest possible frame"""
def test_minimal_ping_pong(self):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
self.connect_and_handshake(sock)
# send PING frame. Format:
ping_frame = bytes([0b1000_1001, 0b1000_0000, 0x12, 0x34, 0x56, 0x78]) # 0x09 0x80: FIN=true, opcode=ping(9), mask=true, payload_length=0
sock.send(ping_frame)
# receive PONG frame.
out = sock.recv(4)
self.assertEqual(len(out), 2) # should be just 2 bytes, the FIN and pong opcode
self.assertEqual(out[0], 0b1000_1010) # FIN=true, opcode=pong(0xA)
self.assertEqual(out[1], 0x00) # mask=false, payload_length=0
class TestUnmaskedDataFrame(TestRawConnection):
"""Test that we reject unmasked frames"""
def test_unmasked_redis_ping(self):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
self.connect_and_handshake(sock)
# send unmasked frame querying Redis with a "PING" command. Format:
payload = b'["PING"]'
unmasked_frame = bytes([0b10000001, len(payload)]) + payload # 0x81 0x08: FIN=true, opcode=text_frame, mask=false, payload_length=8
sock.send(unmasked_frame)
# receive error frame.
out = sock.recv(200)
self.assertEqual(out[0:1], b'\x88') # 0x88: FIN=true, opcode=close, mask=false
self.assertEqual(out[1:2], bytes([len(out) - 2])) # payload length, minus the 2 bytes of the header
self.assertEqual(out[2:4], struct.pack('>h', 1002)) # error code 1002: protocol error
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save