@ -1,8 +1,10 @@
#!/usr/bin/env python3
import abc
from dataclasses import dataclass
import json
import os
import struct
import unittest
import uuid
from websocket import create_connection
@ -194,37 +196,59 @@ class TestFrameSizes(TestWebdis):
self . exec ( ' DEL ' , key )
class TestConnectDisconnect ( TestWebdis ) :
@dataclass
class WebsocketHandshake :
b64_key : bytes
b64_hash : bytes
class TestRawConnection ( TestWebdis ) :
""" Base class for WS tests using sockets """
def create_random_handshake ( self ) - > WebsocketHandshake :
# Build WS handshake request
raw_key = os . urandom ( 16 )
b64_key = base64 . b64encode ( raw_key )
magic = b ' 258EAFA5-E914-47DA-95CA-C5AB0DC85B11 '
expected_raw = b64_key + magic
hash = hashlib . sha1 ( expected_raw ) . digest ( )
b64_hash = base64 . b64encode ( hash )
return WebsocketHandshake ( b64_key , b64_hash )
def format_handshake_request ( self , host : str , port : int , handshake : WebsocketHandshake ) - > str :
return ( " GET /.json HTTP/1.1 \r \n " + \
" Host: %s : %d \r \n " + \
" Connection: Upgrade \r \n " + \
" Upgrade: WebSocket \r \n " + \
" Origin: http:// %s : %d \r \n " + \
" Sec-WebSocket-Key: %s \r \n " + \
" \r \n " ) % ( host , port , host , port , handshake . b64_key . decode ( ' utf-8 ' ) )
def send_handshake_request ( self , sock : socket . socket , request : str , handshake : WebsocketHandshake ) - > None :
""" Send handshake and validate response """
sock . send ( request . encode ( ' utf-8 ' ) )
response = sock . recv ( 1024 )
lines = response . decode ( ' utf-8 ' ) . split ( ' \r \n ' )
self . assertEqual ( lines [ 0 ] , ' HTTP/1.1 101 Switching Protocols ' )
self . assertTrue ( ' Sec-WebSocket-Accept: %s ' % handshake . b64_hash . decode ( ' utf-8 ' ) in lines )
def connect_and_handshake ( self , sock : socket . socket ) - > None :
global host , port
sock . connect ( ( host , port ) )
# establish WS connection with handshake
handshake = self . create_random_handshake ( )
request = self . format_handshake_request ( host , port , handshake )
self . send_handshake_request ( sock , request , handshake )
class TestConnectDisconnect ( TestRawConnection ) :
""" Test for issue #209. A client connects, receives their handshake, disconnects """
def test_connect_handshake_disconnect ( self ) :
# Connect to Webdis
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
global host , port
sock . connect ( ( host , port ) )
# Build WS handshake request
raw_key = os . urandom ( 16 )
b64_key = base64 . b64encode ( raw_key )
magic = b ' 258EAFA5-E914-47DA-95CA-C5AB0DC85B11 '
expected_raw = b64_key + magic
hash = hashlib . sha1 ( expected_raw ) . digest ( )
b64_hash = base64 . b64encode ( hash )
request = ( " GET /.json HTTP/1.1 \r \n " + \
" Host: %s : %d \r \n " + \
" Connection: Upgrade \r \n " + \
" Upgrade: WebSocket \r \n " + \
" Origin: http:// %s : %d \r \n " + \
" Sec-WebSocket-Key: %s \r \n " + \
" \r \n " ) % ( host , port , host , port , b64_key . decode ( ' utf-8 ' ) )
# Send handshake and validate response
sock . send ( request . encode ( ' utf-8 ' ) )
response = sock . recv ( 1024 )
lines = response . decode ( ' utf-8 ' ) . split ( ' \r \n ' )
self . assertEqual ( lines [ 0 ] , ' HTTP/1.1 101 Switching Protocols ' )
self . assertTrue ( ' Sec-WebSocket-Accept: %s ' % b64_hash . decode ( ' utf-8 ' ) in lines )
self . connect_and_handshake ( sock )
# send FIN frame. Format:
# 4 bits flags (0x8 = Fin: true)
@ -242,6 +266,42 @@ class TestConnectDisconnect(TestWebdis):
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
sock . connect ( ( host , port ) )
class TestMinimalPingPong ( TestRawConnection ) :
""" Test that we receive a PONG response to a 6-byte PING, the smallest possible frame """
def test_minimal_ping_pong ( self ) :
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
self . connect_and_handshake ( sock )
# send PING frame. Format:
ping_frame = bytes ( [ 0b1000_1001 , 0b1000_0000 , 0x12 , 0x34 , 0x56 , 0x78 ] ) # 0x09 0x80: FIN=true, opcode=ping(9), mask=true, payload_length=0
sock . send ( ping_frame )
# receive PONG frame.
out = sock . recv ( 4 )
self . assertEqual ( len ( out ) , 2 ) # should be just 2 bytes, the FIN and pong opcode
self . assertEqual ( out [ 0 ] , 0b1000_1010 ) # FIN=true, opcode=pong(0xA)
self . assertEqual ( out [ 1 ] , 0x00 ) # mask=false, payload_length=0
class TestUnmaskedDataFrame ( TestRawConnection ) :
""" Test that we reject unmasked frames """
def test_unmasked_redis_ping ( self ) :
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
self . connect_and_handshake ( sock )
# send unmasked frame querying Redis with a "PING" command. Format:
payload = b ' [ " PING " ] '
unmasked_frame = bytes ( [ 0b10000001 , len ( payload ) ] ) + payload # 0x81 0x08: FIN=true, opcode=text_frame, mask=false, payload_length=8
sock . send ( unmasked_frame )
# receive error frame.
out = sock . recv ( 200 )
self . assertEqual ( out [ 0 : 1 ] , b ' \x88 ' ) # 0x88: FIN=true, opcode=close, mask=false
self . assertEqual ( out [ 1 : 2 ] , bytes ( [ len ( out ) - 2 ] ) ) # payload length, minus the 2 bytes of the header
self . assertEqual ( out [ 2 : 4 ] , struct . pack ( ' >h ' , 1002 ) ) # error code 1002: protocol error
if __name__ == ' __main__ ' :
unittest . main ( )