From 11c07470bb57ebc786af19d306c7867961f94798 Mon Sep 17 00:00:00 2001
From: Sergey Lyubka <valenok@gmail.com>
Date: Sat, 7 Aug 2021 22:55:13 +0100
Subject: [PATCH] Refactor UDP handling

---
 mongoose.c       | 38 +++++++++++++-------------------------
 src/sntp.c       |  5 ++---
 src/sock.c       | 32 +++++++++++---------------------
 test/unit_test.c | 28 ++++++++++++++++++++++++++++
 4 files changed, 54 insertions(+), 49 deletions(-)

diff --git a/mongoose.c b/mongoose.c
index b7e0bcc7b..8103e51bb 100644
--- a/mongoose.c
+++ b/mongoose.c
@@ -2775,7 +2775,7 @@ int mg_sntp_parse(const unsigned char *buf, size_t len, struct timeval *tv) {
     LOG(LL_ERROR, ("%s", "server sent a kiss of death"));
   } else {
     uint32_t *data = (uint32_t *) &buf[40];
-    tv->tv_sec = (time_t)(mg_ntohl(data[0]) - SNTP_TIME_OFFSET);
+    tv->tv_sec = (time_t) (mg_ntohl(data[0]) - SNTP_TIME_OFFSET);
     tv->tv_usec = (suseconds_t) mg_ntohl(data[1]);
     s_sntmp_next = (unsigned long) (tv->tv_sec + SNTP_INTERVAL_SEC);
     res = 0;
@@ -2792,10 +2792,9 @@ static void sntp_cb(struct mg_connection *c, int ev, void *evd, void *fnd) {
                      (unsigned) tv.tv_usec, s_sntmp_next));
     }
     c->recv.len = 0;  // Clear receive buffer
-  } else if (ev == MG_EV_RESOLVE) {
+  } else if (ev == MG_EV_CONNECT) {
     mg_sntp_send(c, (unsigned long) time(NULL));
   } else if (ev == MG_EV_CLOSE) {
-    // mg_fn_del(c, sntp_cb);
   }
   (void) fnd;
   (void) evd;
@@ -2928,14 +2927,7 @@ static struct mg_connection *alloc_conn(struct mg_mgr *mgr, bool is_client,
 }
 
 static long mg_sock_send(struct mg_connection *c, const void *buf, size_t len) {
-  long n = 0;
-  if (c->is_udp) {
-    union usa usa;
-    socklen_t slen = tousa(&c->peer, &usa);
-    n = sendto(FD(c), (char *) buf, len, 0, &usa.sa, slen);
-  } else {
-    n = send(FD(c), (char *) buf, len, MSG_NONBLOCKING);
-  }
+  long n = send(FD(c), (char *) buf, len, MSG_NONBLOCKING);
   return n == 0 ? -1 : n < 0 && mg_sock_would_block() ? 0 : n;
 }
 
@@ -3142,25 +3134,21 @@ static void setsockopts(struct mg_connection *c) {
 void mg_connect_resolved(struct mg_connection *c) {
   char buf[40];
   int type = c->is_udp ? SOCK_DGRAM : SOCK_STREAM;
-  int rc, af = AF_INET;
-#if MG_ENABLE_IPV6
-  if (c->peer.is_ip6) af = AF_INET6;
-#endif
+  int rc, af = c->peer.is_ip6 ? AF_INET6 : AF_INET;
   mg_straddr(c, buf, sizeof(buf));
   c->fd = S2PTR(socket(af, type, 0));
   if (FD(c) == INVALID_SOCKET) {
     mg_error(c, "socket(): %d", MG_SOCK_ERRNO);
-    return;
-  }
-
-  mg_set_non_blocking_mode(FD(c));
-  mg_call(c, MG_EV_RESOLVE, NULL);
-  if (type == SOCK_STREAM) {
+  } else {
     union usa usa;
     socklen_t slen = tousa(&c->peer, &usa);
-    if ((rc = connect(FD(c), &usa.sa, slen)) == 0 || mg_sock_would_block()) {
-      setsockopts(c);
-      if (rc != 0) c->is_connecting = 1;
+    if (c->is_udp == 0) mg_set_non_blocking_mode(FD(c));
+    if (c->is_udp == 0) setsockopts(c);
+    mg_call(c, MG_EV_RESOLVE, NULL);
+    if ((rc = connect(FD(c), &usa.sa, slen)) == 0) {
+      mg_call(c, MG_EV_CONNECT, NULL);
+    } else if (mg_sock_would_block()) {
+      c->is_connecting = 1;
     } else {
       mg_error(c, "connect: %d", MG_SOCK_ERRNO);
     }
@@ -3224,7 +3212,6 @@ static bool mg_socketpair(SOCKET sp[2], union usa usa[2]) {
 
   (void) memset(&usa[0], 0, sizeof(usa[0]));
   usa[0].sin.sin_family = AF_INET;
-  // usa[0].sin.sin_addr.s_addr = mg_htonl(0x7f000001);  // 127.0.0.1
   *(uint32_t *) &usa->sin.sin_addr = mg_htonl(0x7f000001);  // 127.0.0.1
   usa[1] = usa[0];
 
@@ -3300,6 +3287,7 @@ struct mg_connection *mg_listen(struct mg_mgr *mgr, const char *url,
   struct mg_addr addr;
   SOCKET fd = mg_open_listener(url, &addr);
   if (fd == INVALID_SOCKET) {
+    LOG(LL_ERROR, ("Failed: %s, errno %d", url, MG_SOCK_ERRNO));
   } else if ((c = alloc_conn(mgr, 0, fd)) == NULL) {
     LOG(LL_ERROR, ("OOM %s", url));
     closesocket(fd);
diff --git a/src/sntp.c b/src/sntp.c
index 677928ee1..dc932fbf6 100644
--- a/src/sntp.c
+++ b/src/sntp.c
@@ -21,7 +21,7 @@ int mg_sntp_parse(const unsigned char *buf, size_t len, struct timeval *tv) {
     LOG(LL_ERROR, ("%s", "server sent a kiss of death"));
   } else {
     uint32_t *data = (uint32_t *) &buf[40];
-    tv->tv_sec = (time_t)(mg_ntohl(data[0]) - SNTP_TIME_OFFSET);
+    tv->tv_sec = (time_t) (mg_ntohl(data[0]) - SNTP_TIME_OFFSET);
     tv->tv_usec = (suseconds_t) mg_ntohl(data[1]);
     s_sntmp_next = (unsigned long) (tv->tv_sec + SNTP_INTERVAL_SEC);
     res = 0;
@@ -38,10 +38,9 @@ static void sntp_cb(struct mg_connection *c, int ev, void *evd, void *fnd) {
                      (unsigned) tv.tv_usec, s_sntmp_next));
     }
     c->recv.len = 0;  // Clear receive buffer
-  } else if (ev == MG_EV_RESOLVE) {
+  } else if (ev == MG_EV_CONNECT) {
     mg_sntp_send(c, (unsigned long) time(NULL));
   } else if (ev == MG_EV_CLOSE) {
-    // mg_fn_del(c, sntp_cb);
   }
   (void) fnd;
   (void) evd;
diff --git a/src/sock.c b/src/sock.c
index b820ec4fb..545d053c0 100644
--- a/src/sock.c
+++ b/src/sock.c
@@ -101,14 +101,7 @@ static struct mg_connection *alloc_conn(struct mg_mgr *mgr, bool is_client,
 }
 
 static long mg_sock_send(struct mg_connection *c, const void *buf, size_t len) {
-  long n = 0;
-  if (c->is_udp) {
-    union usa usa;
-    socklen_t slen = tousa(&c->peer, &usa);
-    n = sendto(FD(c), (char *) buf, len, 0, &usa.sa, slen);
-  } else {
-    n = send(FD(c), (char *) buf, len, MSG_NONBLOCKING);
-  }
+  long n = send(FD(c), (char *) buf, len, MSG_NONBLOCKING);
   return n == 0 ? -1 : n < 0 && mg_sock_would_block() ? 0 : n;
 }
 
@@ -315,25 +308,21 @@ static void setsockopts(struct mg_connection *c) {
 void mg_connect_resolved(struct mg_connection *c) {
   char buf[40];
   int type = c->is_udp ? SOCK_DGRAM : SOCK_STREAM;
-  int rc, af = AF_INET;
-#if MG_ENABLE_IPV6
-  if (c->peer.is_ip6) af = AF_INET6;
-#endif
+  int rc, af = c->peer.is_ip6 ? AF_INET6 : AF_INET;
   mg_straddr(c, buf, sizeof(buf));
   c->fd = S2PTR(socket(af, type, 0));
   if (FD(c) == INVALID_SOCKET) {
     mg_error(c, "socket(): %d", MG_SOCK_ERRNO);
-    return;
-  }
-
-  mg_set_non_blocking_mode(FD(c));
-  mg_call(c, MG_EV_RESOLVE, NULL);
-  if (type == SOCK_STREAM) {
+  } else {
     union usa usa;
     socklen_t slen = tousa(&c->peer, &usa);
-    if ((rc = connect(FD(c), &usa.sa, slen)) == 0 || mg_sock_would_block()) {
-      setsockopts(c);
-      if (rc != 0) c->is_connecting = 1;
+    if (c->is_udp == 0) mg_set_non_blocking_mode(FD(c));
+    if (c->is_udp == 0) setsockopts(c);
+    mg_call(c, MG_EV_RESOLVE, NULL);
+    if ((rc = connect(FD(c), &usa.sa, slen)) == 0) {
+      mg_call(c, MG_EV_CONNECT, NULL);
+    } else if (mg_sock_would_block()) {
+      c->is_connecting = 1;
     } else {
       mg_error(c, "connect: %d", MG_SOCK_ERRNO);
     }
@@ -472,6 +461,7 @@ struct mg_connection *mg_listen(struct mg_mgr *mgr, const char *url,
   struct mg_addr addr;
   SOCKET fd = mg_open_listener(url, &addr);
   if (fd == INVALID_SOCKET) {
+    LOG(LL_ERROR, ("Failed: %s, errno %d", url, MG_SOCK_ERRNO));
   } else if ((c = alloc_conn(mgr, 0, fd)) == NULL) {
     LOG(LL_ERROR, ("OOM %s", url));
     closesocket(fd);
diff --git a/test/unit_test.c b/test/unit_test.c
index 040c6232a..8262a05c8 100644
--- a/test/unit_test.c
+++ b/test/unit_test.c
@@ -1133,6 +1133,8 @@ static void test_dns(void) {
     mg_mgr_init(&mgr);
     mgr.dns4.url = "udp://127.0.0.1:12345";
     mgr.dnstimeout = 10;
+    LOG(LL_DEBUG, ("opening dummy DNS listener..."));
+    mg_listen(&mgr, mgr.dns4.url, NULL, NULL);  // Just discard our queries
     mg_http_connect(&mgr, "http://google.com", fn1, buf);
     for (i = 0; i < 50 && buf[0] == '\0'; i++) mg_mgr_poll(&mgr, 1);
     mg_mgr_free(&mgr);
@@ -1436,8 +1438,34 @@ static void test_pipe(void) {
   ASSERT(mgr.conns == NULL);
 }
 
+static void u1(struct mg_connection *c, int ev, void *ev_data, void *fn_data) {
+  if (ev == MG_EV_CONNECT) {
+    ((int *) fn_data)[0] += 1;
+    mg_send(c, "hi", 2);
+  } else if (ev == MG_EV_READ) {
+    ((int *) fn_data)[0] += 10;
+    mg_iobuf_free(&c->recv);
+  }
+  (void) ev_data;
+}
+
+static void test_udp(void) {
+  struct mg_mgr mgr;
+  const char *url = "udp://127.0.0.1:12353";
+  int i, done = 0;
+  mg_mgr_init(&mgr);
+  mg_listen(&mgr, url, u1, (void *) &done);
+  mg_connect(&mgr, url, u1, (void *) &done);
+  for (i = 0; i < 5; i++) mg_mgr_poll(&mgr, 1);
+  // LOG(LL_INFO, ("%d", done));
+  ASSERT(done == 11);
+  mg_mgr_free(&mgr);
+  ASSERT(mgr.conns == NULL);
+}
+
 int main(void) {
   mg_log_set("3");
+  test_udp();
   test_pipe();
   test_packed();
   test_crc32();
-- 
GitLab