diff --git a/tests/ws-tests.py b/tests/ws-tests.py index cb038b2..d5fc929 100755 --- a/tests/ws-tests.py +++ b/tests/ws-tests.py @@ -6,6 +6,10 @@ import os import unittest import uuid from websocket import create_connection +import socket +import base64 +import hashlib +import time host = os.getenv('WEBDIS_HOST', '127.0.0.1') port = int(os.getenv('WEBDIS_PORT', 7379)) @@ -190,5 +194,54 @@ class TestFrameSizes(TestWebdis): self.exec('DEL', key) +class TestConnectDisconnect(TestWebdis): + """Test for issue #209. A client connects, receives their handshake, disconnects""" + + def test_connect_handshake_disconnect(self): + # Connect to Webdis + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + global host, port + sock.connect((host, port)) + + # Build WS handshake request + raw_key = os.urandom(16) + b64_key = base64.b64encode(raw_key) + magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + expected_raw = b64_key + magic + hash = hashlib.sha1(expected_raw).digest() + b64_hash = base64.b64encode(hash) + + request = ("GET /.json HTTP/1.1\r\n" + \ + "Host: %s:%d\r\n" + \ + "Connection: Upgrade\r\n" + \ + "Upgrade: WebSocket\r\n" + \ + "Origin: http://%s:%d\r\n" + \ + "Sec-WebSocket-Key: %s\r\n" + \ + "\r\n") % (host, port, host, port, b64_key.decode('utf-8')) + + # Send handshake and validate response + sock.send(request.encode('utf-8')) + response = sock.recv(1024) + lines = response.decode('utf-8').split('\r\n') + self.assertEqual(lines[0], 'HTTP/1.1 101 Switching Protocols') + self.assertTrue('Sec-WebSocket-Accept: %s' % b64_hash.decode('utf-8') in lines) + + # send FIN frame. Format: + # 4 bits flags (0x8 = Fin: true) + # 4 bits opcode (0x8 = close) + # 1 bit mask, 7 bits payload length -- here, 2 + fin_header = bytes([0b10001000, 0b10000010]) # 0x88, 0x82 + fin_payload = os.urandom(2) + mask = os.urandom(4) + fin_frame = fin_header + mask + bytes(fin_payload[0] ^ mask[0]) + bytes(fin_payload[1] ^ mask[1]) + sock.send(fin_frame) + + time.sleep(0.5) + + # now that we've disconnected, make sure Webdis is still alive + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.connect((host, port)) + + if __name__ == '__main__': unittest.main()