diff --git a/tests/basic.py b/tests/basic.py index b833809..aeb4d10 100755 --- a/tests/basic.py +++ b/tests/basic.py @@ -1,5 +1,5 @@ #!/usr/bin/python3 -import urllib.request, urllib.error, urllib.parse, unittest, json, hashlib +import urllib.request, urllib.error, urllib.parse, unittest, json, hashlib, threading, uuid, time from functools import wraps try: import msgpack @@ -223,5 +223,62 @@ class TestDbSwitch(TestWebdis): f = self.query('GET/key.txt') self.assertTrue(f.read() == b"val0") + +class TestPubSub(TestWebdis): + + def test_pubsub_basic(self): + self.validate_pubsub(1) + + def test_pubsub_many_messages(self): + self.validate_pubsub(1000) + + def validate_pubsub(self, num_messages): + channel_name = str(uuid.uuid4()) + expected_messages = [str(uuid.uuid4()) for i in range(num_messages)] + + self.subscribed = False + subscriber = threading.Thread(target=self.subscriber_main, args=(channel_name,expected_messages)) + subscriber.start() + + # wait for the subscription to be confirmed + while not self.subscribed: + time.sleep(0.1) + + for msg in expected_messages: + pub_response = self.query('PUBLISH/' + channel_name + '/' + msg) + self.assertEqual('{"PUBLISH":1}', pub_response.read().decode('utf-8')) + subscriber.join() + + def subscriber_main(self, channel_name, expected_messages): + sub_response = self.query('SUBSCRIBE/' + channel_name) + + msg_index = 0 + buffer = '' + open_braces = 0 + while True: + cur = sub_response.read(1).decode('utf-8') + buffer += cur + if cur == '{': + open_braces += 1 + elif cur == '}': + open_braces -= 1 + + if open_braces == 0: # we have a complete JSON message + message = json.loads(buffer) + buffer = '' + if 'SUBSCRIBE' in message: + if message['SUBSCRIBE'] == ['subscribe', channel_name, 1]: # notify of successful subscription + self.subscribed = True + continue + elif message['SUBSCRIBE'] == ['message', channel_name, expected_messages[msg_index]]: # confirm current message + msg_index += 1 + if msg_index == len(expected_messages): + sub_response.close() + return + else: + continue + self.fail('Received an unexpected message: ' + buffer) + + if __name__ == '__main__': unittest.main()