@ -92,5 +92,75 @@ class TestRaw(TestWebdis):
self . assertEqual ( get_response , f " $ { len ( value ) } \r \n { value } \r \n " )
@unittest.skipIf ( os . getenv ( ' PUBSUB ' ) != ' 1 ' , " pub-sub test fail due to invalid ordering " )
class TestPubSub ( unittest . TestCase ) :
def setUp ( self ) :
self . publisher = create_connection ( f ' ws:// { host } : { port } /.json ' )
self . subscriber = create_connection ( f ' ws:// { host } : { port } /.json ' )
def tearDown ( self ) :
self . publisher . close ( )
self . subscriber . close ( )
def serialize ( self , cmd , * args ) :
return json . dumps ( [ cmd ] + list ( args ) )
def deserialize ( self , response ) :
return json . loads ( response )
def test_publish_subscribe ( self ) :
channel_count = 2
message_count_per_channel = 8
channels = list ( str ( uuid . uuid4 ( ) ) for i in range ( channel_count ) )
# subscribe to all channels
sub_count = 0
for channel in channels :
self . subscriber . send ( self . serialize ( ' SUBSCRIBE ' , channel ) )
unsub_response = self . deserialize ( self . subscriber . recv ( ) )
sub_count + = 1
self . assertEqual ( unsub_response , { ' SUBSCRIBE ' : [ ' subscribe ' , channel , sub_count ] } )
# send messages to all channels
prefix = ' message- '
for i in range ( message_count_per_channel ) :
for channel in channels :
message = f ' { prefix } { i } '
self . publisher . send ( self . serialize ( ' PUBLISH ' , channel , message ) )
received_per_channel = dict ( ( channel , [ ] ) for channel in channels )
for j in range ( channel_count * message_count_per_channel ) :
received = self . deserialize ( self . subscriber . recv ( ) )
print ( ' received: ' , received )
# expected: {'SUBSCRIBE': ['message', $channel, $message]}
self . assertTrue ( received , ' SUBSCRIBE ' in received )
sub_contents = received [ ' SUBSCRIBE ' ]
self . assertEqual ( len ( sub_contents ) , 3 )
self . assertEqual ( sub_contents [ 0 ] , ' message ' ) # first element is the message type, here a push
channel = sub_contents [ 1 ]
self . assertTrue ( channel in channels ) # second is the channel
received_per_channel [ channel ] . append ( sub_contents [ 2 ] ) # third, add to list of messages received for this channel
# unsubscribe from all channels
subs_remaining = channel_count
for channel in channels :
self . subscriber . send ( self . serialize ( ' UNSUBSCRIBE ' , channel ) )
subs_remaining - = 1
unsub_response = self . deserialize ( self . subscriber . recv ( ) )
self . assertEqual ( unsub_response , { ' SUBSCRIBE ' : [ ' unsubscribe ' , channel , subs_remaining ] } )
# check that we received all messages
for channel in channels :
self . assertEqual ( len ( received_per_channel [ channel ] ) , message_count_per_channel )
# check that we received them *in order*
for i in range ( message_count_per_channel ) :
for channel in channels :
expected = f ' { prefix } { i } '
self . assertEqual ( received_per_channel [ channel ] [ i ] , expected ,
f ' In { channel } : expected at offset { i } was " { expected } " , actual was: " { received_per_channel [ channel ] [ i ] } " ' )
if __name__ == ' __main__ ' :
unittest . main ( )