@ -1,8 +1,10 @@
#!/usr/bin/env python3
#!/usr/bin/env python3
import abc
import abc
from dataclasses import dataclass
import json
import json
import os
import os
import struct
import unittest
import unittest
import uuid
import uuid
from websocket import create_connection
from websocket import create_connection
@ -194,15 +196,16 @@ class TestFrameSizes(TestWebdis):
self . exec ( ' DEL ' , key )
self . exec ( ' DEL ' , key )
class TestConnectDisconnect ( TestWebdis ) :
@dataclass
""" Test for issue #209. A client connects, receives their handshake, disconnects """
class WebsocketHandshake :
b64_key : bytes
b64_hash : bytes
def test_connect_handshake_disconnect ( self ) :
# Connect to Webdis
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
global host , port
sock . connect ( ( host , port ) )
class TestRawConnection ( TestWebdis ) :
""" Base class for WS tests using sockets """
def create_random_handshake ( self ) - > WebsocketHandshake :
# Build WS handshake request
# Build WS handshake request
raw_key = os . urandom ( 16 )
raw_key = os . urandom ( 16 )
b64_key = base64 . b64encode ( raw_key )
b64_key = base64 . b64encode ( raw_key )
@ -210,21 +213,42 @@ class TestConnectDisconnect(TestWebdis):
expected_raw = b64_key + magic
expected_raw = b64_key + magic
hash = hashlib . sha1 ( expected_raw ) . digest ( )
hash = hashlib . sha1 ( expected_raw ) . digest ( )
b64_hash = base64 . b64encode ( hash )
b64_hash = base64 . b64encode ( hash )
return WebsocketHandshake ( b64_key , b64_hash )
request = ( " GET /.json HTTP/1.1 \r \n " + \
def format_handshake_request ( self , host : str , port : int , handshake : WebsocketHandshake ) - > str :
return ( " GET /.json HTTP/1.1 \r \n " + \
" Host: %s : %d \r \n " + \
" Host: %s : %d \r \n " + \
" Connection: Upgrade \r \n " + \
" Connection: Upgrade \r \n " + \
" Upgrade: WebSocket \r \n " + \
" Upgrade: WebSocket \r \n " + \
" Origin: http:// %s : %d \r \n " + \
" Origin: http:// %s : %d \r \n " + \
" Sec-WebSocket-Key: %s \r \n " + \
" Sec-WebSocket-Key: %s \r \n " + \
" \r \n " ) % ( host , port , host , port , b64_key . decode ( ' utf-8 ' ) )
" \r \n " ) % ( host , port , host , port , handshake . b64_key . decode ( ' utf-8 ' ) )
# Send handshake and validate response
def send_handshake_request ( self , sock : socket . socket , request : str , handshake : WebsocketHandshake ) - > None :
""" Send handshake and validate response """
sock . send ( request . encode ( ' utf-8 ' ) )
sock . send ( request . encode ( ' utf-8 ' ) )
response = sock . recv ( 1024 )
response = sock . recv ( 1024 )
lines = response . decode ( ' utf-8 ' ) . split ( ' \r \n ' )
lines = response . decode ( ' utf-8 ' ) . split ( ' \r \n ' )
self . assertEqual ( lines [ 0 ] , ' HTTP/1.1 101 Switching Protocols ' )
self . assertEqual ( lines [ 0 ] , ' HTTP/1.1 101 Switching Protocols ' )
self . assertTrue ( ' Sec-WebSocket-Accept: %s ' % b64_hash . decode ( ' utf-8 ' ) in lines )
self . assertTrue ( ' Sec-WebSocket-Accept: %s ' % handshake . b64_hash . decode ( ' utf-8 ' ) in lines )
def connect_and_handshake ( self , sock : socket . socket ) - > None :
global host , port
sock . connect ( ( host , port ) )
# establish WS connection with handshake
handshake = self . create_random_handshake ( )
request = self . format_handshake_request ( host , port , handshake )
self . send_handshake_request ( sock , request , handshake )
class TestConnectDisconnect ( TestRawConnection ) :
""" Test for issue #209. A client connects, receives their handshake, disconnects """
def test_connect_handshake_disconnect ( self ) :
# Connect to Webdis
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
self . connect_and_handshake ( sock )
# send FIN frame. Format:
# send FIN frame. Format:
# 4 bits flags (0x8 = Fin: true)
# 4 bits flags (0x8 = Fin: true)
@ -242,6 +266,42 @@ class TestConnectDisconnect(TestWebdis):
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
sock . connect ( ( host , port ) )
sock . connect ( ( host , port ) )
class TestMinimalPingPong ( TestRawConnection ) :
""" Test that we receive a PONG response to a 6-byte PING, the smallest possible frame """
def test_minimal_ping_pong ( self ) :
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
self . connect_and_handshake ( sock )
# send PING frame. Format:
ping_frame = bytes ( [ 0b1000_1001 , 0b1000_0000 , 0x12 , 0x34 , 0x56 , 0x78 ] ) # 0x09 0x80: FIN=true, opcode=ping(9), mask=true, payload_length=0
sock . send ( ping_frame )
# receive PONG frame.
out = sock . recv ( 4 )
self . assertEqual ( len ( out ) , 2 ) # should be just 2 bytes, the FIN and pong opcode
self . assertEqual ( out [ 0 ] , 0b1000_1010 ) # FIN=true, opcode=pong(0xA)
self . assertEqual ( out [ 1 ] , 0x00 ) # mask=false, payload_length=0
class TestUnmaskedDataFrame ( TestRawConnection ) :
""" Test that we reject unmasked frames """
def test_unmasked_redis_ping ( self ) :
with socket . socket ( socket . AF_INET , socket . SOCK_STREAM ) as sock :
self . connect_and_handshake ( sock )
# send unmasked frame querying Redis with a "PING" command. Format:
payload = b ' [ " PING " ] '
unmasked_frame = bytes ( [ 0b10000001 , len ( payload ) ] ) + payload # 0x81 0x08: FIN=true, opcode=text_frame, mask=false, payload_length=8
sock . send ( unmasked_frame )
# receive error frame.
out = sock . recv ( 200 )
self . assertEqual ( out [ 0 : 1 ] , b ' \x88 ' ) # 0x88: FIN=true, opcode=close, mask=false
self . assertEqual ( out [ 1 : 2 ] , bytes ( [ len ( out ) - 2 ] ) ) # payload length, minus the 2 bytes of the header
self . assertEqual ( out [ 2 : 4 ] , struct . pack ( ' >h ' , 1002 ) ) # error code 1002: protocol error
if __name__ == ' __main__ ' :
if __name__ == ' __main__ ' :
unittest . main ( )
unittest . main ( )