diff --git a/mongoose.c b/mongoose.c index 74a2540887940e7d4ad39ce7b4feb845dd37444f..9561409e1116a761949a93f2d0e9c457612c8867 100644 --- a/mongoose.c +++ b/mongoose.c @@ -4661,13 +4661,16 @@ void mg_ws_upgrade(struct mg_connection *c, struct mg_http_message *hm, const char *fmt, ...) { struct mg_str *wskey = mg_http_get_header(hm, "Sec-WebSocket-Key"); c->pfn = mg_ws_cb; - if (wskey != NULL) { + if (wskey == NULL) { + mg_http_reply(c, 426, "", "WS upgrade expected\n"); + c->is_draining = 1; + } else { va_list ap; va_start(ap, fmt); ws_handshake(c, wskey->ptr, wskey->len, fmt, ap); va_end(ap); + c->is_websocket = 1; } - c->is_websocket = 1; } size_t mg_ws_wrap(struct mg_connection *c, size_t len, int op) { diff --git a/src/ws.c b/src/ws.c index a1a9941499af0e0e28a74c4dc9f4da15d6d1d7db..28448339f557b9a32333077329bf1554f7532406 100644 --- a/src/ws.c +++ b/src/ws.c @@ -217,13 +217,16 @@ void mg_ws_upgrade(struct mg_connection *c, struct mg_http_message *hm, const char *fmt, ...) { struct mg_str *wskey = mg_http_get_header(hm, "Sec-WebSocket-Key"); c->pfn = mg_ws_cb; - if (wskey != NULL) { + if (wskey == NULL) { + mg_http_reply(c, 426, "", "WS upgrade expected\n"); + c->is_draining = 1; + } else { va_list ap; va_start(ap, fmt); ws_handshake(c, wskey->ptr, wskey->len, fmt, ap); va_end(ap); + c->is_websocket = 1; } - c->is_websocket = 1; } size_t mg_ws_wrap(struct mg_connection *c, size_t len, int op) { diff --git a/test/unit_test.c b/test/unit_test.c index 9f92db992c7ad10bfcdbef4595d721a08153f9ee..0e9cb4e985ecf306657d8e3789f6f945588ad6c1 100644 --- a/test/unit_test.c +++ b/test/unit_test.c @@ -383,34 +383,6 @@ static void eh1(struct mg_connection *c, int ev, void *ev_data, void *fn_data) { } } -static void wcb(struct mg_connection *c, int ev, void *ev_data, void *fn_data) { - if (ev == MG_EV_WS_OPEN) { - mg_ws_send(c, "boo", 3, WEBSOCKET_OP_BINARY); - mg_ws_send(c, "", 0, WEBSOCKET_OP_PING); - } else if (ev == MG_EV_WS_MSG) { - struct mg_ws_message *wm = (struct mg_ws_message *) ev_data; - ASSERT(mg_strstr(wm->data, mg_str("boo"))); - mg_ws_send(c, "", 0, WEBSOCKET_OP_CLOSE); // Ask server to close - *(int *) fn_data = 1; - } else if (ev == MG_EV_CLOSE) { - *(int *) fn_data = 2; - } -} - -static void test_ws(void) { - struct mg_mgr mgr; - int i, done = 0; - - mg_mgr_init(&mgr); - ASSERT(mg_http_listen(&mgr, "ws://LOCALHOST:12345", eh1, NULL) != NULL); - mg_ws_connect(&mgr, "ws://localhost:12345/ws", wcb, &done, "%s", ""); - for (i = 0; i < 20; i++) mg_mgr_poll(&mgr, 1); - ASSERT(done == 2); - - mg_mgr_free(&mgr); - ASSERT(mgr.conns == NULL); -} - struct fetch_data { char *buf; int code, closed; @@ -467,6 +439,39 @@ static int cmpbody(const char *buf, const char *str) { return strncmp(hm.body.ptr, str, hm.body.len); } +static void wcb(struct mg_connection *c, int ev, void *ev_data, void *fn_data) { + if (ev == MG_EV_WS_OPEN) { + mg_ws_send(c, "boo", 3, WEBSOCKET_OP_BINARY); + mg_ws_send(c, "", 0, WEBSOCKET_OP_PING); + } else if (ev == MG_EV_WS_MSG) { + struct mg_ws_message *wm = (struct mg_ws_message *) ev_data; + ASSERT(mg_strstr(wm->data, mg_str("boo"))); + mg_ws_send(c, "", 0, WEBSOCKET_OP_CLOSE); // Ask server to close + *(int *) fn_data = 1; + } else if (ev == MG_EV_CLOSE) { + *(int *) fn_data = 2; + } +} + +static void test_ws(void) { + char buf[FETCH_BUF_SIZE]; + const char *url = "ws://LOCALHOST:12343"; + struct mg_mgr mgr; + int i, done = 0; + + mg_mgr_init(&mgr); + ASSERT(mg_http_listen(&mgr, url, eh1, NULL) != NULL); + mg_ws_connect(&mgr, url, wcb, &done, "%s", ""); + for (i = 0; i < 20; i++) mg_mgr_poll(&mgr, 1); + ASSERT(done == 2); + + // Test that non-WS requests fail + ASSERT(fetch(&mgr, buf, url, "GET /ws HTTP/1.0\r\n\n") == 426); + + mg_mgr_free(&mgr); + ASSERT(mgr.conns == NULL); +} + static void eh9(struct mg_connection *c, int ev, void *ev_data, void *fn_data) { if (ev == MG_EV_ERROR) { ASSERT(!strcmp((char *) ev_data, "error connecting to 127.0.0.1:55117"));