Fix SSL_set_tlsext_debug_callback/-tlsextdebug
[openssl.git] / ssl / statem / extensions.c
index fd76337564024f16f49efb615f2f6bfa8db42af7..f62b1fe65f9e3c56c7e4e4286e8396b103403453 100644 (file)
@@ -91,7 +91,7 @@ typedef struct extensions_definition_st {
 
 /*
  * Definitions of all built-in extensions. NOTE: Changes in the number or order
- * of these extensions should be mirrored with equivalent changes to the 
+ * of these extensions should be mirrored with equivalent changes to the
  * indexes ( TLSEXT_IDX_* ) defined in ssl_locl.h.
  * Each extension has an initialiser, a client and
  * server side parser and a finaliser. The initialiser is called (if the
@@ -462,6 +462,7 @@ int tls_collect_extensions(SSL *s, PACKET *packet, unsigned int context,
         return 0;
     }
 
+    i = 0;
     while (PACKET_remaining(&extensions) > 0) {
         unsigned int type, idx;
         PACKET extension;
@@ -518,6 +519,12 @@ int tls_collect_extensions(SSL *s, PACKET *packet, unsigned int context,
             thisex->data = extension;
             thisex->present = 1;
             thisex->type = type;
+            thisex->received_order = i++;
+            if (s->ext.debug_cb)
+                s->ext.debug_cb(s, !s->server, thisex->type,
+                                PACKET_data(&thisex->data),
+                                PACKET_remaining(&thisex->data),
+                                s->ext.debug_arg);
         }
     }
 
@@ -569,12 +576,6 @@ int tls_parse_extension(SSL *s, TLSEXT_INDEX idx, int context,
     if (!currext->present)
         return 1;
 
-    if (s->ext.debug_cb)
-        s->ext.debug_cb(s, !s->server, currext->type,
-                        PACKET_data(&currext->data),
-                        PACKET_remaining(&currext->data),
-                        s->ext.debug_arg);
-
     /* Skip if we've already parsed this extension */
     if (currext->parsed)
         return 1;
@@ -1081,7 +1082,7 @@ static int init_srtp(SSL *s, unsigned int context)
 
 static int final_sig_algs(SSL *s, unsigned int context, int sent, int *al)
 {
-    if (!sent && SSL_IS_TLS13(s)) {
+    if (!sent && SSL_IS_TLS13(s) && !s->hit) {
         *al = TLS13_AD_MISSING_EXTENSION;
         SSLerr(SSL_F_FINAL_SIG_ALGS, SSL_R_MISSING_SIGALGS_EXTENSION);
         return 0;
@@ -1116,7 +1117,7 @@ static int final_key_share(SSL *s, unsigned int context, int sent, int *al)
             && (!s->hit
                 || (s->ext.psk_kex_mode & TLSEXT_KEX_MODE_FLAG_KE) == 0)) {
         /* Nothing left we can do - just fail */
-        *al = SSL_AD_HANDSHAKE_FAILURE;
+        *al = SSL_AD_MISSING_EXTENSION;
         SSLerr(SSL_F_FINAL_KEY_SHARE, SSL_R_NO_SUITABLE_KEY_SHARE);
         return 0;
     }
@@ -1225,21 +1226,60 @@ static int init_psk_kex_modes(SSL *s, unsigned int context)
 
 int tls_psk_do_binder(SSL *s, const EVP_MD *md, const unsigned char *msgstart,
                       size_t binderoffset, const unsigned char *binderin,
-                      unsigned char *binderout,
-                      SSL_SESSION *sess, int sign)
+                      unsigned char *binderout, SSL_SESSION *sess, int sign,
+                      int external)
 {
     EVP_PKEY *mackey = NULL;
     EVP_MD_CTX *mctx = NULL;
     unsigned char hash[EVP_MAX_MD_SIZE], binderkey[EVP_MAX_MD_SIZE];
     unsigned char finishedkey[EVP_MAX_MD_SIZE], tmpbinder[EVP_MAX_MD_SIZE];
+    unsigned char tmppsk[EVP_MAX_MD_SIZE];
+    unsigned char *early_secret, *psk;
     const char resumption_label[] = "res binder";
-    size_t bindersize, hashsize = EVP_MD_size(md);
+    const char external_label[] = "ext binder";
+    const char nonce_label[] = "resumption";
+    const char *label;
+    size_t bindersize, labelsize, hashsize = EVP_MD_size(md);
     int ret = -1;
 
-    /* Generate the early_secret */
-    if (!tls13_generate_secret(s, md, NULL, sess->master_key,
-                               sess->master_key_length,
-                               (unsigned char *)&s->early_secret)) {
+    if (external) {
+        label = external_label;
+        labelsize = sizeof(external_label) - 1;
+    } else {
+        label = resumption_label;
+        labelsize = sizeof(resumption_label) - 1;
+    }
+
+    if (sess->master_key_length != hashsize) {
+        SSLerr(SSL_F_TLS_PSK_DO_BINDER, SSL_R_BAD_PSK);
+        goto err;
+    }
+
+    if (external) {
+        psk = sess->master_key;
+    } else {
+        psk = tmppsk;
+        if (!tls13_hkdf_expand(s, md, sess->master_key,
+                               (const unsigned char *)nonce_label,
+                               sizeof(nonce_label) - 1, sess->ext.tick_nonce,
+                               sess->ext.tick_nonce_len, psk, hashsize)) {
+            SSLerr(SSL_F_TLS_PSK_DO_BINDER, ERR_R_INTERNAL_ERROR);
+            goto err;
+        }
+    }
+
+    /*
+     * Generate the early_secret. On the server side we've selected a PSK to
+     * resume with (internal or external) so we always do this. On the client
+     * side we do this for a non-external (i.e. resumption) PSK so that it
+     * is in place for sending early data. For client side external PSK we
+     * generate it but store it away for later use.
+     */
+    if (s->server || !external)
+        early_secret = (unsigned char *)s->early_secret;
+    else
+        early_secret = (unsigned char *)sess->early_secret;
+    if (!tls13_generate_secret(s, md, NULL, psk, hashsize, early_secret)) {
         SSLerr(SSL_F_TLS_PSK_DO_BINDER, ERR_R_INTERNAL_ERROR);
         goto err;
     }
@@ -1257,10 +1297,8 @@ int tls_psk_do_binder(SSL *s, const EVP_MD *md, const unsigned char *msgstart,
     }
 
     /* Generate the binder key */
-    if (!tls13_hkdf_expand(s, md, s->early_secret,
-                           (unsigned char *)resumption_label,
-                           sizeof(resumption_label) - 1, hash, binderkey,
-                           hashsize)) {
+    if (!tls13_hkdf_expand(s, md, early_secret, (unsigned char *)label,
+                           labelsize, hash, hashsize, binderkey, hashsize)) {
         SSLerr(SSL_F_TLS_PSK_DO_BINDER, ERR_R_INTERNAL_ERROR);
         goto err;
     }