Formatting only: make ws-tests.py PEP8 compliant.

master
Jessie Murray 3 years ago
parent 71223ae005
commit bb02c1dd04
No known key found for this signature in database
GPG Key ID: E7E4D57EDDA744C5

@ -10,160 +10,163 @@ from websocket import create_connection
host = os.getenv('WEBDIS_HOST', '127.0.0.1') host = os.getenv('WEBDIS_HOST', '127.0.0.1')
port = int(os.getenv('WEBDIS_PORT', 7379)) port = int(os.getenv('WEBDIS_PORT', 7379))
def connect(format): def connect(format):
return create_connection(f'ws://{host}:{port}/.{format}') return create_connection(f'ws://{host}:{port}/.{format}')
class TestWebdis(unittest.TestCase): class TestWebdis(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.ws = connect(self.format()) self.ws = connect(self.format())
def tearDown(self) -> None: def tearDown(self) -> None:
self.ws.close() self.ws.close()
def exec(self, cmd, *args): def exec(self, cmd, *args):
self.ws.send(self.serialize(cmd, *args)) self.ws.send(self.serialize(cmd, *args))
return self.deserialize(self.ws.recv()) return self.deserialize(self.ws.recv())
def clean_key(self): def clean_key(self):
"""Returns a key that was just deleted""" """Returns a key that was just deleted"""
key = str(uuid.uuid4()) key = str(uuid.uuid4())
self.exec('DEL', key) self.exec('DEL', key)
return key return key
@abc.abstractmethod @abc.abstractmethod
def format(self): def format(self):
"""Returns the format to use (added after a dot to the WS URI)""" """Returns the format to use (added after a dot to the WS URI)"""
return return
@abc.abstractmethod @abc.abstractmethod
def serialize(self, cmd): def serialize(self, cmd):
"""Serializes a command according to the format being tested""" """Serializes a command according to the format being tested"""
return return
@abc.abstractmethod @abc.abstractmethod
def deserialize(self, response): def deserialize(self, response):
"""Deserializes a response according to the format being tested""" """Deserializes a response according to the format being tested"""
return return
class TestJson(TestWebdis): class TestJson(TestWebdis):
def format(self): def format(self):
return 'json' return 'json'
def serialize(self, cmd, *args): def serialize(self, cmd, *args):
return json.dumps([cmd] + list(args)) return json.dumps([cmd] + list(args))
def deserialize(self, response): def deserialize(self, response):
return json.loads(response) return json.loads(response)
def test_ping(self): def test_ping(self):
self.assertEqual(self.exec('PING'), {'PING': [True, 'PONG']}) self.assertEqual(self.exec('PING'), {'PING': [True, 'PONG']})
def test_multiple_messages(self): def test_multiple_messages(self):
key = self.clean_key() key = self.clean_key()
n = 100 n = 100
for i in range(n): for i in range(n):
lpush_response = self.exec('LPUSH', key, f'value-{i}') lpush_response = self.exec('LPUSH', key, f'value-{i}')
self.assertEqual(lpush_response, {'LPUSH': i + 1}) self.assertEqual(lpush_response, {'LPUSH': i + 1})
self.assertEqual(self.exec('LLEN', key), {'LLEN': n}) self.assertEqual(self.exec('LLEN', key), {'LLEN': n})
class TestRaw(TestWebdis): class TestRaw(TestWebdis):
def format(self): def format(self):
return 'raw' return 'raw'
def serialize(self, cmd, *args): def serialize(self, cmd, *args):
buffer = f"*{1 + len(args)}\r\n${len(cmd)}\r\n{cmd}\r\n" buffer = f"*{1 + len(args)}\r\n${len(cmd)}\r\n{cmd}\r\n"
for arg in args: for arg in args:
buffer += f"${len(arg)}\r\n{arg}\r\n" buffer += f"${len(arg)}\r\n{arg}\r\n"
return buffer return buffer
def deserialize(self, response): def deserialize(self, response):
return response # we'll just assert using the raw protocol return response # we'll just assert using the raw protocol
def test_ping(self): def test_ping(self):
self.assertEqual(self.exec('PING'), "+PONG\r\n") self.assertEqual(self.exec('PING'), "+PONG\r\n")
def test_get_set(self): def test_get_set(self):
key = self.clean_key() key = self.clean_key()
value = str(uuid.uuid4()) value = str(uuid.uuid4())
not_found_response = self.exec('GET', key) not_found_response = self.exec('GET', key)
self.assertEqual(not_found_response, "$-1\r\n") # Redis protocol response for "not found" self.assertEqual(not_found_response, "$-1\r\n") # Redis protocol response for "not found"
set_response = self.exec('SET', key, value) set_response = self.exec('SET', key, value)
self.assertEqual(set_response, "+OK\r\n") self.assertEqual(set_response, "+OK\r\n")
get_response = self.exec('GET', key) get_response = self.exec('GET', key)
self.assertEqual(get_response, f"${len(value)}\r\n{value}\r\n") self.assertEqual(get_response, f"${len(value)}\r\n{value}\r\n")
@unittest.skipIf(os.getenv('PUBSUB') != '1', "pub-sub test fail due to invalid ordering") @unittest.skipIf(os.getenv('PUBSUB') != '1', "pub-sub test fail due to invalid ordering")
class TestPubSub(unittest.TestCase): class TestPubSub(unittest.TestCase):
def setUp(self): def setUp(self):
self.publisher = connect('json') self.publisher = connect('json')
self.subscriber = connect('json') self.subscriber = connect('json')
def tearDown(self): def tearDown(self):
self.publisher.close() self.publisher.close()
self.subscriber.close() self.subscriber.close()
def serialize(self, cmd, *args): def serialize(self, cmd, *args):
return json.dumps([cmd] + list(args)) return json.dumps([cmd] + list(args))
def deserialize(self, response): def deserialize(self, response):
return json.loads(response) return json.loads(response)
def test_publish_subscribe(self): def test_publish_subscribe(self):
channel_count = 2 channel_count = 2
message_count_per_channel = 8 message_count_per_channel = 8
channels = list(str(uuid.uuid4()) for i in range(channel_count)) channels = list(str(uuid.uuid4()) for i in range(channel_count))
# subscribe to all channels # subscribe to all channels
sub_count = 0 sub_count = 0
for channel in channels: for channel in channels:
self.subscriber.send(self.serialize('SUBSCRIBE', channel)) self.subscriber.send(self.serialize('SUBSCRIBE', channel))
sub_response = self.deserialize(self.subscriber.recv()) sub_response = self.deserialize(self.subscriber.recv())
sub_count += 1 sub_count += 1
self.assertEqual(sub_response, {'SUBSCRIBE': ['subscribe', channel, sub_count]}) self.assertEqual(sub_response, {'SUBSCRIBE': ['subscribe', channel, sub_count]})
# send messages to all channels # send messages to all channels
prefix = 'message-' prefix = 'message-'
for i in range(message_count_per_channel): for i in range(message_count_per_channel):
for channel in channels: for channel in channels:
message = f'{prefix}{i}' message = f'{prefix}{i}'
self.publisher.send(self.serialize('PUBLISH', channel, message)) self.publisher.send(self.serialize('PUBLISH', channel, message))
self.deserialize(self.publisher.recv()) self.deserialize(self.publisher.recv())
received_per_channel = dict((channel, []) for channel in channels) received_per_channel = dict((channel, []) for channel in channels)
for j in range(channel_count * message_count_per_channel): for j in range(channel_count * message_count_per_channel):
received = self.deserialize(self.subscriber.recv()) received = self.deserialize(self.subscriber.recv())
# expected: {'SUBSCRIBE': ['message', $channel, $message]} # expected: {'SUBSCRIBE': ['message', $channel, $message]}
self.assertTrue(received, 'SUBSCRIBE' in received) self.assertTrue(received, 'SUBSCRIBE' in received)
sub_contents = received['SUBSCRIBE'] sub_contents = received['SUBSCRIBE']
self.assertEqual(len(sub_contents), 3) self.assertEqual(len(sub_contents), 3)
self.assertEqual(sub_contents[0], 'message') # first element is the message type, here a push self.assertEqual(sub_contents[0], 'message') # first element is the message type, here a push
channel = sub_contents[1] channel = sub_contents[1]
self.assertTrue(channel in channels) # second is the channel self.assertTrue(channel in channels) # second is the channel
received_per_channel[channel].append(sub_contents[2]) # third, add to list of messages received for this channel received_per_channel[channel].append(
sub_contents[2]) # third, add to list of messages received for this channel
# unsubscribe from all channels
subs_remaining = channel_count # unsubscribe from all channels
for channel in channels: subs_remaining = channel_count
self.subscriber.send(self.serialize('UNSUBSCRIBE', channel)) for channel in channels:
subs_remaining -= 1 self.subscriber.send(self.serialize('UNSUBSCRIBE', channel))
unsub_response = self.deserialize(self.subscriber.recv()) subs_remaining -= 1
self.assertEqual(unsub_response, {'UNSUBSCRIBE': ['unsubscribe', channel, subs_remaining]}) unsub_response = self.deserialize(self.subscriber.recv())
self.assertEqual(unsub_response, {'UNSUBSCRIBE': ['unsubscribe', channel, subs_remaining]})
# check that we received all messages
for channel in channels: # check that we received all messages
self.assertEqual(len(received_per_channel[channel]), message_count_per_channel) for channel in channels:
self.assertEqual(len(received_per_channel[channel]), message_count_per_channel)
# check that we received them *in order*
for i in range(message_count_per_channel): # check that we received them *in order*
for channel in channels: for i in range(message_count_per_channel):
expected = f'{prefix}{i}' for channel in channels:
self.assertEqual(received_per_channel[channel][i], expected, expected = f'{prefix}{i}'
f'In {channel}: expected at offset {i} was "{expected}", actual was: "{received_per_channel[channel][i]}"') self.assertEqual(received_per_channel[channel][i], expected,
f'In {channel}: expected at offset {i} was "{expected}", actual was: "{received_per_channel[channel][i]}"')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

Loading…
Cancel
Save