Add server side support for creating the Hello Retry Request message
[openssl.git] / ssl / statem / extensions.c
index c9e7d30eaae150454dbe15a77938e91b20551ba9..fd050f0d2ebfc12ec4ef7c8d0e15d9a0ef64f1d3 100644 (file)
@@ -57,15 +57,17 @@ typedef struct extensions_definition_st {
      */
     int (*init)(SSL *s, unsigned int context);
     /* Parse extension sent from client to server */
-    int (*parse_ctos)(SSL *s, PACKET *pkt, X509 *x, size_t chainidx, int *al);
+    int (*parse_ctos)(SSL *s, PACKET *pkt, unsigned int context, X509 *x,
+                      size_t chainidx, int *al);
     /* Parse extension send from server to client */
-    int (*parse_stoc)(SSL *s, PACKET *pkt, X509 *x, size_t chainidx, int *al);
+    int (*parse_stoc)(SSL *s, PACKET *pkt, unsigned int context, X509 *x,
+                      size_t chainidx, int *al);
     /* Construct extension sent from server to client */
-    int (*construct_stoc)(SSL *s, WPACKET *pkt, X509 *x, size_t chainidx,
-                          int *al);
+    int (*construct_stoc)(SSL *s, WPACKET *pkt, unsigned int context, X509 *x,
+                          size_t chainidx, int *al);
     /* Construct extension sent from client to server */
-    int (*construct_ctos)(SSL *s, WPACKET *pkt, X509 *x, size_t chainidx,
-                          int *al);
+    int (*construct_ctos)(SSL *s, WPACKET *pkt, unsigned int context, X509 *x,
+                          size_t chainidx, int *al);
     /*
      * Finalise extension after parsing. Always called where an extensions was
      * initialised even if the extension was not present. |sent| is set to 1 if
@@ -237,7 +239,6 @@ static const EXTENSION_DEFINITION ext_defs[] = {
         NULL, NULL, NULL, tls_construct_ctos_supported_versions, NULL
     },
     {
-        /* Must be before key_share */
         TLSEXT_TYPE_psk_kex_modes,
         EXT_CLIENT_HELLO | EXT_TLS_IMPLEMENTATION_ONLY | EXT_TLS1_3_ONLY,
         init_psk_kex_modes, tls_parse_ctos_psk_kex_modes, NULL, NULL,
@@ -471,7 +472,8 @@ int tls_parse_extension(SSL *s, TLSEXT_INDEX idx, int context,
                         RAW_EXTENSION *exts, X509 *x, size_t chainidx, int *al)
 {
     RAW_EXTENSION *currext = &exts[idx];
-    int (*parser)(SSL *s, PACKET *pkt, X509 *x, size_t chainidx, int *al) = NULL;
+    int (*parser)(SSL *s, PACKET *pkt, unsigned int context, X509 *x,
+                  size_t chainidx, int *al) = NULL;
 
     /* Skip if the extension is not present */
     if (!currext->present)
@@ -500,7 +502,7 @@ int tls_parse_extension(SSL *s, TLSEXT_INDEX idx, int context,
         parser = s->server ? extdef->parse_ctos : extdef->parse_stoc;
 
         if (parser != NULL)
-            return parser(s, &currext->data, x, chainidx, al);
+            return parser(s, &currext->data, context, x, chainidx, al);
 
         /*
          * If the parser is NULL we fall through to the custom extension
@@ -634,8 +636,8 @@ int tls_construct_extensions(SSL *s, WPACKET *pkt, unsigned int context,
     }
 
     for (i = 0, thisexd = ext_defs; i < OSSL_NELEM(ext_defs); i++, thisexd++) {
-        int (*construct)(SSL *s, WPACKET *pkt, X509 *x, size_t chainidx,
-                         int *al);
+        int (*construct)(SSL *s, WPACKET *pkt, unsigned int context, X509 *x,
+                         size_t chainidx, int *al);
 
         /* Skip if not relevant for our context */
         if ((thisexd->context & context) == 0)
@@ -662,7 +664,7 @@ int tls_construct_extensions(SSL *s, WPACKET *pkt, unsigned int context,
                 || construct == NULL)
             continue;
 
-        if (!construct(s, pkt, x, chainidx, &tmpal))
+        if (!construct(s, pkt, context, x, chainidx, &tmpal))
             goto err;
     }
 
@@ -737,10 +739,10 @@ static int final_server_name(SSL *s, unsigned int context, int sent,
     if (s->ctx != NULL && s->ctx->ext.servername_cb != 0)
         ret = s->ctx->ext.servername_cb(s, &altmp,
                                         s->ctx->ext.servername_arg);
-    else if (s->initial_ctx != NULL
-             && s->initial_ctx->ext.servername_cb != 0)
-        ret = s->initial_ctx->ext.servername_cb(s, &altmp,
-                                       s->initial_ctx->ext.servername_arg);
+    else if (s->session_ctx != NULL
+             && s->session_ctx->ext.servername_cb != 0)
+        ret = s->session_ctx->ext.servername_cb(s, &altmp,
+                                       s->session_ctx->ext.servername_arg);
 
     switch (ret) {
     case SSL_TLSEXT_ERR_ALERT_FATAL:
@@ -977,12 +979,18 @@ static int final_key_share(SSL *s, unsigned int context, int sent, int *al)
             && (!s->hit
                 || (s->ext.psk_kex_mode & TLSEXT_KEX_MODE_FLAG_KE) == 0)) {
         /* No suitable share */
-        /* TODO(TLS1.3): Send a HelloRetryRequest */
+        if (s->server && s->hello_retry_request == 0 && sent) {
+            s->hello_retry_request = 1;
+            return 1;
+        }
+
+        /* Nothing left we can do - just fail */
         *al = SSL_AD_HANDSHAKE_FAILURE;
         SSLerr(SSL_F_FINAL_KEY_SHARE, SSL_R_NO_SUITABLE_KEY_SHARE);
         return 0;
     }
 
+    s->hello_retry_request = 0;
     /*
      * For a client side resumption with no key_share we need to generate
      * the handshake secret (otherwise this is done during key_share
@@ -1000,7 +1008,6 @@ static int final_key_share(SSL *s, unsigned int context, int sent, int *al)
 static int init_psk_kex_modes(SSL *s, unsigned int context)
 {
     s->ext.psk_kex_mode = TLSEXT_KEX_MODE_FLAG_NONE;
-
     return 1;
 }
 
@@ -1014,7 +1021,7 @@ int tls_psk_do_binder(SSL *s, const EVP_MD *md, const unsigned char *msgstart,
     unsigned char hash[EVP_MAX_MD_SIZE], binderkey[EVP_MAX_MD_SIZE];
     unsigned char finishedkey[EVP_MAX_MD_SIZE], tmpbinder[EVP_MAX_MD_SIZE];
     const char resumption_label[] = "resumption psk binder key";
-    size_t hashsize = EVP_MD_size(md), bindersize;
+    size_t bindersize, hashsize = EVP_MD_size(md);
     int ret = -1;
 
     /* Generate the early_secret */