diff --git a/src/cmd.c b/src/cmd.c index df268a3..f064518 100644 --- a/src/cmd.c +++ b/src/cmd.c @@ -375,13 +375,31 @@ cmd_select_format(struct http_client *client, struct cmd *cmd, int cmd_is_subscribe(struct cmd *cmd) { - if(cmd->pub_sub_client) { /* persistent command */ - return 1; - } - if(cmd->count >= 1 && cmd->argv[0] && - (strncasecmp(cmd->argv[0], "SUBSCRIBE", cmd->argv_len[0]) == 0 || - strncasecmp(cmd->argv[0], "PSUBSCRIBE", cmd->argv_len[0]) == 0)) { + if(cmd->pub_sub_client || /* persistent command */ + cmd_is_subscribe_args(cmd)) { /* checked with args */ return 1; } return 0; } + +int +cmd_is_subscribe_args(struct cmd *cmd) { + + if(cmd->count >= 2 && + ((cmd->argv_len[0] == 9 && strncasecmp(cmd->argv[0], "subscribe", 9) == 0) || + (cmd->argv_len[0] == 10 && strncasecmp(cmd->argv[0], "psubscribe", 10) == 0))) { + return 1; + } + return 0; +} + +int +cmd_is_unsubscribe_args(struct cmd *cmd) { + + if(cmd->count >= 2 && + ((cmd->argv_len[0] == 11 && strncasecmp(cmd->argv[0], "unsubscribe", 11) == 0) || + (cmd->argv_len[0] == 12 && strncasecmp(cmd->argv[0], "punsubscribe", 12) == 0))) { + return 1; + } + return 0; +} diff --git a/src/cmd.h b/src/cmd.h index 78b5bce..823c3e3 100644 --- a/src/cmd.h +++ b/src/cmd.h @@ -72,6 +72,12 @@ int cmd_select_format(struct http_client *client, struct cmd *cmd, const char *uri, size_t uri_len, formatting_fun *f_format); +int +cmd_is_subscribe_args(struct cmd *cmd); + +int +cmd_is_unsubscribe_args(struct cmd *cmd); + int cmd_is_subscribe(struct cmd *cmd); diff --git a/src/websocket.c b/src/websocket.c index f8633ec..edcb9ff 100644 --- a/src/websocket.c +++ b/src/websocket.c @@ -310,9 +310,17 @@ ws_execute(struct ws_client *ws, struct ws_msg *msg) { cmd->pub_sub_client = c; } - /* log and execute */ - ws_log_cmd(ws, cmd); - cmd_send(cmd, fun_reply); + int is_subscribe = cmd_is_subscribe_args(cmd); + int is_unsubscribe = cmd_is_unsubscribe_args(cmd); + + if(ws->ran_subscribe && !is_subscribe && !is_unsubscribe) { /* disallow non-subscribe commands after a subscribe */ + char error_msg[] = "Command not allowed after subscribe"; + ws_frame_and_send_response(ws, WS_BINARY_FRAME, error_msg, sizeof(error_msg)-1); + } else { /* log and execute */ + ws_log_cmd(ws, cmd); + cmd_send(cmd, fun_reply); + ws->ran_subscribe = is_subscribe; + } return 0; } diff --git a/src/websocket.h b/src/websocket.h index f15af74..8b176a4 100644 --- a/src/websocket.h +++ b/src/websocket.h @@ -40,6 +40,7 @@ struct ws_client { /* indicates that we'll close once we've flushed all buffered data and read what we planned to read */ int close_after_events; + int ran_subscribe; /* set if we've run a (p)subscribe command */ }; struct ws_client * diff --git a/src/worker.c b/src/worker.c index d440546..5d06cf6 100644 --- a/src/worker.c +++ b/src/worker.c @@ -108,11 +108,8 @@ worker_can_read(int fd, short event, void *p) { close(c->fd); } http_client_free(c); - } else { - /* start monitoring input again */ - if(c->is_websocket) { /* all communication handled by WS code from now on */ - // ws_monitor_input(c->ws); - } else { + } else { /* start monitoring input again */ + if(!c->is_websocket) { /* all communication handled by WS code from now on */ worker_monitor_input(c); } }