From 8fb5e8ef191c96da7d86ed30e15e24ece8261aa4 Mon Sep 17 00:00:00 2001
From: Deomid Ryabkov <rojer@cesanta.com>
Date: Tue, 22 Nov 2016 13:07:55 +0000
Subject: [PATCH] mbedTLS support for LWIP net_if

PUBLISHED_FROM=a733ba6e06887a448f96f92679f6f8adbe9c61f7
---
 mongoose.c | 148 +++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 103 insertions(+), 45 deletions(-)

diff --git a/mongoose.c b/mongoose.c
index a74a66d78..7d581b251 100644
--- a/mongoose.c
+++ b/mongoose.c
@@ -4256,7 +4256,8 @@ const char *mg_set_ssl(struct mg_connection *nc, const char *cert,
 #include <mbedtls/ssl.h>
 #include <mbedtls/x509_crt.h>
 
-static void mg_ssl_mbed_log(void *ctx, int level, const char *file, int line, const char *str) {
+static void mg_ssl_mbed_log(void *ctx, int level, const char *file, int line,
+                            const char *str) {
   enum cs_log_level cs_level;
   switch (level) {
     case 1:
@@ -4269,7 +4270,8 @@ static void mg_ssl_mbed_log(void *ctx, int level, const char *file, int line, co
     default:
       cs_level = LL_VERBOSE_DEBUG;
   }
-  LOG(cs_level, ("%p %s", ctx, str));
+  /* mbedTLS passes strings with \n at the end, strip it. */
+  LOG(cs_level, ("%p %.*s", ctx, (int) (strlen(str) - 1), str));
   (void) file;
   (void) line;
 }
@@ -4302,10 +4304,13 @@ enum mg_ssl_if_result mg_ssl_if_conn_accept(struct mg_connection *nc,
   return MG_SSL_OK;
 }
 
-static enum mg_ssl_if_result mg_use_cert(struct mg_ssl_if_ctx *ctx, const char *cert,
-                                         const char *key, const char **err_msg);
-static enum mg_ssl_if_result mg_use_ca_cert(struct mg_ssl_if_ctx *ctx, const char *cert);
-static enum mg_ssl_if_result mg_set_cipher_list(struct mg_ssl_if_ctx *ctx, const char *ciphers);
+static enum mg_ssl_if_result mg_use_cert(struct mg_ssl_if_ctx *ctx,
+                                         const char *cert, const char *key,
+                                         const char **err_msg);
+static enum mg_ssl_if_result mg_use_ca_cert(struct mg_ssl_if_ctx *ctx,
+                                            const char *cert);
+static enum mg_ssl_if_result mg_set_cipher_list(struct mg_ssl_if_ctx *ctx,
+                                                const char *ciphers);
 
 enum mg_ssl_if_result mg_ssl_if_conn_init(
     struct mg_connection *nc, const struct mg_ssl_if_conn_params *params,
@@ -4315,6 +4320,7 @@ enum mg_ssl_if_result mg_ssl_if_conn_init(
   DBG(("%p %s,%s,%s", nc, (params->cert ? params->cert : ""),
        (params->key ? params->key : ""),
        (params->ca_cert ? params->ca_cert : "")));
+
   if (ctx == NULL) {
     MG_SET_PTRPTR(err_msg, "Out of memory");
     return MG_SSL_ERROR;
@@ -4324,20 +4330,19 @@ enum mg_ssl_if_result mg_ssl_if_conn_init(
   mbedtls_ssl_config_init(ctx->conf);
   mbedtls_ssl_conf_dbg(ctx->conf, mg_ssl_mbed_log, nc);
   if (mbedtls_ssl_config_defaults(
-        ctx->conf,
-        (nc->flags & MG_F_LISTENING ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT),
-        MBEDTLS_SSL_TRANSPORT_STREAM,
-        MBEDTLS_SSL_PRESET_DEFAULT) != 0) {
+          ctx->conf, (nc->flags & MG_F_LISTENING ? MBEDTLS_SSL_IS_SERVER
+                                                 : MBEDTLS_SSL_IS_CLIENT),
+          MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT) != 0) {
     MG_SET_PTRPTR(err_msg, "Failed to init SSL config");
     return MG_SSL_ERROR;
   }
   /* TLS 1.2 and up */
-  mbedtls_ssl_conf_min_version(ctx->conf, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
+  mbedtls_ssl_conf_min_version(ctx->conf, MBEDTLS_SSL_MAJOR_VERSION_3,
+                               MBEDTLS_SSL_MINOR_VERSION_3);
   mbedtls_ssl_conf_rng(ctx->conf, mg_ssl_if_mbed_random, nc);
 
   if (params->cert != NULL &&
-      mg_use_cert(ctx, params->cert, params->key, err_msg) !=
-          MG_SSL_OK) {
+      mg_use_cert(ctx, params->cert, params->key, err_msg) != MG_SSL_OK) {
     return MG_SSL_ERROR;
   }
 
@@ -4353,8 +4358,6 @@ enum mg_ssl_if_result mg_ssl_if_conn_init(
 
   mg_set_cipher_list(ctx, NULL);
 
-  mbedtls_debug_set_threshold(4);
-
   if (!(nc->flags & MG_F_LISTENING)) {
     ctx->ssl = MG_CALLOC(1, sizeof(*ctx->ssl));
     mbedtls_ssl_init(ctx->ssl);
@@ -4369,7 +4372,11 @@ enum mg_ssl_if_result mg_ssl_if_conn_init(
   return MG_SSL_OK;
 }
 
-int ssl_socket_send(void *ctx, const unsigned char *buf, size_t len) {
+#if MG_NET_IF == MG_NET_IF_LWIP_LOW_LEVEL
+int ssl_socket_send(void *ctx, const unsigned char *buf, size_t len);
+int ssl_socket_recv(void *ctx, unsigned char *buf, size_t len);
+#else
+static int ssl_socket_send(void *ctx, const unsigned char *buf, size_t len) {
   struct mg_connection *nc = (struct mg_connection *) ctx;
   int n = (int) MG_SEND_FUNC(nc->sock, buf, len, 0);
   DBG(("%p %d -> %d", nc, (int) len, n));
@@ -4386,12 +4393,18 @@ static int ssl_socket_recv(void *ctx, unsigned char *buf, size_t len) {
   n = mg_get_errno();
   return ((n == EAGAIN || n == EINPROGRESS) ? MBEDTLS_ERR_SSL_WANT_READ : -1);
 }
+#endif
 
-static enum mg_ssl_if_result mg_ssl_if_mbed_err(struct mg_connection *nc, int ret) {
+static enum mg_ssl_if_result mg_ssl_if_mbed_err(struct mg_connection *nc,
+                                                int ret) {
   if (ret == MBEDTLS_ERR_SSL_WANT_READ) return MG_SSL_WANT_READ;
   if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) return MG_SSL_WANT_WRITE;
-  DBG(("%p SSL error: %d", nc, ret));
+  if (ret !=
+      MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { /* CLOSE_NOTIFY = Normal shutdown */
+    LOG(LL_ERROR, ("%p SSL error: %d", nc, ret));
+  }
   nc->err = ret;
+  nc->flags |= MG_F_CLOSE_IMMEDIATELY;
   return MG_SSL_ERROR;
 }
 
@@ -4452,7 +4465,8 @@ void mg_ssl_if_conn_free(struct mg_connection *nc) {
   MG_FREE(ctx);
 }
 
-static enum mg_ssl_if_result mg_use_ca_cert(struct mg_ssl_if_ctx *ctx, const char *ca_cert) {
+static enum mg_ssl_if_result mg_use_ca_cert(struct mg_ssl_if_ctx *ctx,
+                                            const char *ca_cert) {
   if (ca_cert == NULL || strcmp(ca_cert, "*") == 0) {
     return MG_SSL_OK;
   }
@@ -4466,8 +4480,8 @@ static enum mg_ssl_if_result mg_use_ca_cert(struct mg_ssl_if_ctx *ctx, const cha
   return MG_SSL_OK;
 }
 
-static enum mg_ssl_if_result mg_use_cert(struct mg_ssl_if_ctx *ctx, const char *cert,
-                                         const char *key,
+static enum mg_ssl_if_result mg_use_cert(struct mg_ssl_if_ctx *ctx,
+                                         const char *cert, const char *key,
                                          const char **err_msg) {
   if (key == NULL) key = cert;
   if (cert == NULL || cert[0] == '\0' || key == NULL || key[0] == '\0') {
@@ -4493,28 +4507,28 @@ static enum mg_ssl_if_result mg_use_cert(struct mg_ssl_if_ctx *ctx, const char *
 }
 
 static const int mg_s_cipher_list[] = {
-  MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
-  MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
-  MBEDTLS_TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
-  MBEDTLS_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
-  MBEDTLS_TLS_DHE_RSA_WITH_AES_256_GCM_SHA384,
-  MBEDTLS_TLS_DHE_RSA_WITH_AES_128_GCM_SHA256,
-  MBEDTLS_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
-  MBEDTLS_TLS_RSA_WITH_AES_256_GCM_SHA384,
-  MBEDTLS_TLS_RSA_WITH_AES_256_CBC_SHA256,
-  MBEDTLS_TLS_RSA_WITH_AES_128_GCM_SHA256,
-  MBEDTLS_TLS_RSA_WITH_AES_128_CBC_SHA256,
-  MBEDTLS_TLS_RSA_WITH_AES_256_CBC_SHA,
-  MBEDTLS_TLS_RSA_WITH_AES_128_CBC_SHA,
-  0
-};
+    MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+    MBEDTLS_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+    MBEDTLS_TLS_DHE_RSA_WITH_AES_128_GCM_SHA256,
+    MBEDTLS_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
+    MBEDTLS_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256,
+    MBEDTLS_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256,
+    MBEDTLS_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
+    MBEDTLS_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256,
+    MBEDTLS_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256,
+    MBEDTLS_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA,
+    MBEDTLS_TLS_RSA_WITH_AES_128_GCM_SHA256,
+    MBEDTLS_TLS_RSA_WITH_AES_128_CBC_SHA256,
+    MBEDTLS_TLS_RSA_WITH_AES_128_CBC_SHA, 0};
 
 /*
  * Ciphers can be specified as a colon-separated list of cipher suite names.
- * These can be found in https://github.com/ARMmbed/mbedtls/blob/development/library/ssl_ciphersuites.c#L267
+ * These can be found in
+ * https://github.com/ARMmbed/mbedtls/blob/development/library/ssl_ciphersuites.c#L267
  * E.g.: TLS-ECDHE-ECDSA-WITH-AES-128-GCM-SHA256:TLS-DHE-RSA-WITH-AES-256-CCM
  */
-static enum mg_ssl_if_result mg_set_cipher_list(struct mg_ssl_if_ctx *ctx, const char *ciphers) {
+static enum mg_ssl_if_result mg_set_cipher_list(struct mg_ssl_if_ctx *ctx,
+                                                const char *ciphers) {
   if (ciphers != NULL) {
     int ids[50], n = 0, l, id;
     const char *s = ciphers;
@@ -13692,7 +13706,8 @@ time_t mg_lwip_if_poll(struct mg_iface *iface, int timeout_ms) {
 #if MG_ENABLE_SSL
     if ((nc->flags & MG_F_SSL) && cs != NULL && cs->pcb.tcp != NULL &&
         cs->pcb.tcp->state == ESTABLISHED) {
-      if (((nc->flags & MG_F_WANT_WRITE) || nc->send_mbuf.len > 0) &&
+      if (((nc->flags & MG_F_WANT_WRITE) ||
+           (nc->send_mbuf.len > 0) && (nc->flags & MG_F_SSL_HANDSHAKE_DONE)) &&
           cs->pcb.tcp->snd_buf > 0) {
         /* Can write more. */
         if (nc->flags & MG_F_SSL_HANDSHAKE_DONE) {
@@ -13757,14 +13772,14 @@ uint32_t mg_lwip_get_poll_delay_ms(struct mg_mgr *mgr) {
 
 #endif /* MG_NET_IF == MG_NET_IF_LWIP_LOW_LEVEL */
 #ifdef MG_MODULE_LINES
-#line 1 "common/platforms/lwip/mg_lwip_ssl_krypton.c"
+#line 1 "common/platforms/lwip/mg_lwip_ssl_if.c"
 #endif
 /*
  * Copyright (c) 2014-2016 Cesanta Software Limited
  * All rights reserved
  */
 
-#if MG_NET_IF == MG_NET_IF_LWIP_LOW_LEVEL && MG_ENABLE_SSL
+#if MG_ENABLE_SSL && MG_NET_IF == MG_NET_IF_LWIP_LOW_LEVEL
 
 /* Amalgamated: #include "common/cs_dbg.h" */
 
@@ -13783,19 +13798,27 @@ uint32_t mg_lwip_get_poll_delay_ms(struct mg_mgr *mgr) {
 #define MG_LWIP_SSL_RECV_MBUF_LIMIT 3072
 #endif
 
+#ifndef MIN
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
 
 void mg_lwip_ssl_do_hs(struct mg_connection *nc) {
   struct mg_lwip_conn_state *cs = (struct mg_lwip_conn_state *) nc->sock;
   int server_side = (nc->listener != NULL);
-  enum mg_ssl_if_result res = mg_ssl_if_handshake(nc);
-  DBG(("%d %d %d", server_side, res));
+  enum mg_ssl_if_result res;
+  if (nc->flags & MG_F_CLOSE_IMMEDIATELY) return;
+  res = mg_ssl_if_handshake(nc);
+  DBG(("%p %d %d %d", nc, nc->flags, server_side, res));
   if (res != MG_SSL_OK) {
     if (res == MG_SSL_WANT_WRITE) {
       nc->flags |= MG_F_WANT_WRITE;
       cs->err = 0;
     } else if (res == MG_SSL_WANT_READ) {
-      /* Nothing, we are callback-driven. */
+      /*
+       * Nothing to do in particular, we are callback-driven.
+       * What we definitely do not need anymore is SSL reading (nothing left).
+       */
+      nc->flags &= ~MG_F_WANT_READ;
       cs->err = 0;
     } else {
       cs->err = res;
@@ -13915,9 +13938,44 @@ ssize_t kr_recv(int fd, void *buf, size_t len) {
   return len;
 }
 
+#elif MG_SSL_IF == MG_SSL_IF_MBEDTLS
+
+int ssl_socket_send(void *ctx, const unsigned char *buf, size_t len) {
+  struct mg_connection *nc = (struct mg_connection *) ctx;
+  struct mg_lwip_conn_state *cs = (struct mg_lwip_conn_state *) nc->sock;
+  int ret = mg_lwip_tcp_write(cs->nc, buf, len);
+  DBG(("%p mg_lwip_tcp_write %u = %d", cs->nc, len, ret));
+  if (ret == 0) ret = MBEDTLS_ERR_SSL_WANT_WRITE;
+  return ret;
+}
+
+int ssl_socket_recv(void *ctx, unsigned char *buf, size_t len) {
+  struct mg_connection *nc = (struct mg_connection *) ctx;
+  struct mg_lwip_conn_state *cs = (struct mg_lwip_conn_state *) nc->sock;
+  struct pbuf *seg = cs->rx_chain;
+  if (seg == NULL) {
+    DBG(("%u - nothing to read", len));
+    return MBEDTLS_ERR_SSL_WANT_READ;
+  }
+  size_t seg_len = (seg->len - cs->rx_offset);
+  DBG(("%u %u %u %u", len, cs->rx_chain->len, seg_len, cs->rx_chain->tot_len));
+  len = MIN(len, seg_len);
+  pbuf_copy_partial(seg, buf, len, cs->rx_offset);
+  cs->rx_offset += len;
+  /* TCP PCB may be NULL if connection has already been closed
+   * but we still have data to deliver to SSL. */
+  if (cs->pcb.tcp != NULL) tcp_recved(cs->pcb.tcp, len);
+  if (cs->rx_offset == cs->rx_chain->len) {
+    cs->rx_chain = pbuf_dechain(cs->rx_chain);
+    pbuf_free(seg);
+    cs->rx_offset = 0;
+  }
+  return len;
+}
+
 #endif
 
-#endif /* MG_NET_IF == MG_NET_IF_LWIP_LOW_LEVEL && MG_ENABLE_SSL */
+#endif /* MG_ENABLE_SSL && MG_NET_IF == MG_NET_IF_LWIP_LOW_LEVEL */
 #ifdef MG_MODULE_LINES
 #line 1 "common/platforms/wince/wince_libc.c"
 #endif
-- 
GitLab