@ -1,8 +1,10 @@
#!/usr/bin/env python3
#!/usr/bin/env python3
import abc
import abc
from dataclasses import dataclass
import json
import json
import os
import os
import struct
import unittest
import unittest
import uuid
import uuid
from websocket import create_connection
from websocket import create_connection
@ -194,37 +196,59 @@ class TestFrameSizes(TestWebdis):
self . exec ( ' DEL ' , key )
self . exec ( ' DEL ' , key )
class TestConnectDisconnect ( TestWebdis ) :
@dataclass
class WebsocketHandshake :
b64_key : bytes
b64_hash : bytes
class TestRawConnection ( TestWebdis ) :
""" Base class for WS tests using sockets """
def create_random_handshake ( self ) - > WebsocketHandshake :
# Build WS handshake request
raw_key = os . urandom ( 16 )
b64_key = base64 . b64encode ( raw_key )
magic = b ' 258EAFA5-E914-47DA-95CA-C5AB0DC85B11 '
expected_raw = b64_key + magic
hash = hashlib . sha1 ( expected_raw ) . digest ( )
b64_hash = base64 . b64encode ( hash )
return WebsocketHandshake ( b64_key , b64_hash )
def format_handshake_request ( self , host : str , port : int , handshake : WebsocketHandshake ) - > str :
return ( " GET /.json HTTP/1.1 \r \n " + \
" Host: %s : %d \r \n " + \
" Connection: Upgrade \r \n " + \
" Upgrade: WebSocket \r \n " + \
" Origin: http:// %s : %d \r \n " + \
" Sec-WebSocket-Key: %s \r \n " + \
" \r \n " ) % ( host , port , host , port , handshake . b64_key . decode ( ' utf-8 ' ) )
def send_handshake_request ( self , sock : socket . socket , request : str , handshake : WebsocketHandshake ) - > None :
""" Send handshake and validate response """
sock . send ( request . encode ( ' utf-8 ' ) )
response = sock . recv ( 1024 )
lines = response . decode ( ' utf-8 ' ) . split ( ' \r \n ' )
self . assertEqual ( lines [ 0 ] , ' HTTP/1.1 101 Switching Protocols ' )
self . assertTrue ( ' Sec-WebSocket-Accept: %s ' % handshake . b64_hash . decode ( ' utf-8 ' ) in lines )
def connect_and_handshake ( self , sock : socket . socket ) - > None :
global host , port
sock . connect ( ( host , port ) )
# establish WS connection with handshake
handshake = self . create_random_handshake ( )
request = self . format_handshake_request ( host , port , handshake )
self . send_handshake_request ( sock , request , handshake )
class TestConnectDisconnect ( TestRawConnection ) :
""" Test for issue #209. A client connects, receives their handshake, disconnects """
""" Test for issue #209. A client connects, receives their handshake, disconnects """
def test_connect_handshake_disconnect ( self ) :
def test_connect_handshake_disconnect ( self ) :
# Connect to Webdis
# Connect to Webdis
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
global host , port
self . connect_and_handshake ( sock )
sock . connect ( ( host , port ) )
# Build WS handshake request
raw_key = os . urandom ( 16 )
b64_key = base64 . b64encode ( raw_key )
magic = b ' 258EAFA5-E914-47DA-95CA-C5AB0DC85B11 '
expected_raw = b64_key + magic
hash = hashlib . sha1 ( expected_raw ) . digest ( )
b64_hash = base64 . b64encode ( hash )
request = ( " GET /.json HTTP/1.1 \r \n " + \
" Host: %s : %d \r \n " + \
" Connection: Upgrade \r \n " + \
" Upgrade: WebSocket \r \n " + \
" Origin: http:// %s : %d \r \n " + \
" Sec-WebSocket-Key: %s \r \n " + \
" \r \n " ) % ( host , port , host , port , b64_key . decode ( ' utf-8 ' ) )
# Send handshake and validate response
sock . send ( request . encode ( ' utf-8 ' ) )
response = sock . recv ( 1024 )
lines = response . decode ( ' utf-8 ' ) . split ( ' \r \n ' )
self . assertEqual ( lines [ 0 ] , ' HTTP/1.1 101 Switching Protocols ' )
self . assertTrue ( ' Sec-WebSocket-Accept: %s ' % b64_hash . decode ( ' utf-8 ' ) in lines )
# send FIN frame. Format:
# send FIN frame. Format:
# 4 bits flags (0x8 = Fin: true)
# 4 bits flags (0x8 = Fin: true)
@ -242,6 +266,42 @@ class TestConnectDisconnect(TestWebdis):
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
sock . connect ( ( host , port ) )
sock . connect ( ( host , port ) )
class TestMinimalPingPong ( TestRawConnection ) :
""" Test that we receive a PONG response to a 6-byte PING, the smallest possible frame """
def test_minimal_ping_pong ( self ) :
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
self . connect_and_handshake ( sock )
# send PING frame. Format:
ping_frame = bytes ( [ 0b1000_1001 , 0b1000_0000 , 0x12 , 0x34 , 0x56 , 0x78 ] ) # 0x09 0x80: FIN=true, opcode=ping(9), mask=true, payload_length=0
sock . send ( ping_frame )
# receive PONG frame.
out = sock . recv ( 4 )
self . assertEqual ( len ( out ) , 2 ) # should be just 2 bytes, the FIN and pong opcode
self . assertEqual ( out [ 0 ] , 0b1000_1010 ) # FIN=true, opcode=pong(0xA)
self . assertEqual ( out [ 1 ] , 0x00 ) # mask=false, payload_length=0
class TestUnmaskedDataFrame ( TestRawConnection ) :
""" Test that we reject unmasked frames """
def test_unmasked_redis_ping ( self ) :
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
self . connect_and_handshake ( sock )
# send unmasked frame querying Redis with a "PING" command. Format:
payload = b ' [ " PING " ] '
unmasked_frame = bytes ( [ 0b10000001 , len ( payload ) ] ) + payload # 0x81 0x08: FIN=true, opcode=text_frame, mask=false, payload_length=8
sock . send ( unmasked_frame )
# receive error frame.
out = sock . recv ( 200 )
self . assertEqual ( out [ 0 : 1 ] , b ' \x88 ' ) # 0x88: FIN=true, opcode=close, mask=false
self . assertEqual ( out [ 1 : 2 ] , bytes ( [ len ( out ) - 2 ] ) ) # payload length, minus the 2 bytes of the header
self . assertEqual ( out [ 2 : 4 ] , struct . pack ( ' >h ' , 1002 ) ) # error code 1002: protocol error
if __name__ == ' __main__ ' :
if __name__ == ' __main__ ' :
unittest . main ( )
unittest . main ( )