Added an explicit yield (OP_SLEEP) to QUIC testing for cooperative threading.
[openssl.git] / ssl / tls13_enc.c
index 6d2f46441af6c64cc0841a7b851596da7b677e01..f6b4b9f4c21af9bf4ae576035be7a56f1d2af0d7 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016-2022 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2016-2023 The OpenSSL Project Authors. All Rights Reserved.
  *
  * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
@@ -247,8 +247,14 @@ int tls13_generate_master_secret(SSL_CONNECTION *s, unsigned char *out,
                                  size_t *secret_size)
 {
     const EVP_MD *md = ssl_handshake_md(s);
+    int md_size;
 
-    *secret_size = EVP_MD_get_size(md);
+    md_size = EVP_MD_get_size(md);
+    if (md_size <= 0) {
+        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
+    *secret_size = (size_t)md_size;
     /* Calls SSLfatal() if required */
     return tls13_generate_secret(s, md, prev, NULL, 0, out);
 }
@@ -317,10 +323,12 @@ int tls13_setup_key_block(SSL_CONNECTION *s)
 {
     const EVP_CIPHER *c;
     const EVP_MD *hash;
+    int mac_type = NID_undef;
+    size_t mac_secret_size = 0;
 
     s->session->cipher = s->s3.tmp.new_cipher;
     if (!ssl_cipher_get_evp(SSL_CONNECTION_GET_CTX(s), s->session, &c, &hash,
-                            NULL, NULL, NULL, 0)) {
+                            &mac_type, &mac_secret_size, NULL, 0)) {
         /* Error is already recorded */
         SSLfatal_alert(s, SSL_AD_INTERNAL_ERROR);
         return 0;
@@ -330,24 +338,27 @@ int tls13_setup_key_block(SSL_CONNECTION *s)
     s->s3.tmp.new_sym_enc = c;
     ssl_evp_md_free(s->s3.tmp.new_hash);
     s->s3.tmp.new_hash = hash;
+    s->s3.tmp.new_mac_pkey_type = mac_type;
+    s->s3.tmp.new_mac_secret_size = mac_secret_size;
 
     return 1;
 }
 
-static int derive_secret_key_and_iv(SSL_CONNECTION *s, int sending,
-                                    const EVP_MD *md,
+static int derive_secret_key_and_iv(SSL_CONNECTION *s, const EVP_MD *md,
                                     const EVP_CIPHER *ciph,
+                                    int mac_type,
+                                    const EVP_MD *mac_md,
                                     const unsigned char *insecret,
                                     const unsigned char *hash,
                                     const unsigned char *label,
                                     size_t labellen, unsigned char *secret,
                                     unsigned char *key, size_t *keylen,
-                                    unsigned char *iv, size_t *ivlen,
+                                    unsigned char **iv, size_t *ivlen,
                                     size_t *taglen)
 {
     int hashleni = EVP_MD_get_size(md);
     size_t hashlen;
-    int mode;
+    int mode, mac_mdleni;
 
     /* Ensure cast to size_t is safe */
     if (!ossl_assert(hashleni >= 0)) {
@@ -362,48 +373,71 @@ static int derive_secret_key_and_iv(SSL_CONNECTION *s, int sending,
         return 0;
     }
 
-    *keylen = EVP_CIPHER_get_key_length(ciph);
-
-    mode = EVP_CIPHER_get_mode(ciph);
-    if (mode == EVP_CIPH_CCM_MODE) {
-        uint32_t algenc;
-
-        *ivlen = EVP_CCM_TLS_IV_LEN;
-        if (s->s3.tmp.new_cipher != NULL) {
-            algenc = s->s3.tmp.new_cipher->algorithm_enc;
-        } else if (s->session->cipher != NULL) {
-            /* We've not selected a cipher yet - we must be doing early data */
-            algenc = s->session->cipher->algorithm_enc;
-        } else if (s->psksession != NULL && s->psksession->cipher != NULL) {
-            /* We must be doing early data with out-of-band PSK */
-            algenc = s->psksession->cipher->algorithm_enc;
-        } else {
-            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_EVP_LIB);
+    /* if ciph is NULL cipher, then use new_hash to calculate keylen */
+    if (EVP_CIPHER_is_a(ciph, "NULL")
+        && mac_md != NULL
+        && mac_type == NID_hmac) {
+        mac_mdleni = EVP_MD_get_size(mac_md);
+
+        if (mac_mdleni < 0) {
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
             return 0;
         }
-        if (algenc & (SSL_AES128CCM8 | SSL_AES256CCM8))
-            *taglen = EVP_CCM8_TLS_TAG_LEN;
-         else
-            *taglen = EVP_CCM_TLS_TAG_LEN;
+        *ivlen = *taglen = (size_t)mac_mdleni;
+        *keylen = s->s3.tmp.new_mac_secret_size;
     } else {
-        int iivlen;
 
-        if (mode == EVP_CIPH_GCM_MODE) {
-            *taglen = EVP_GCM_TLS_TAG_LEN;
+        *keylen = EVP_CIPHER_get_key_length(ciph);
+
+        mode = EVP_CIPHER_get_mode(ciph);
+        if (mode == EVP_CIPH_CCM_MODE) {
+            uint32_t algenc;
+
+            *ivlen = EVP_CCM_TLS_IV_LEN;
+            if (s->s3.tmp.new_cipher != NULL) {
+                algenc = s->s3.tmp.new_cipher->algorithm_enc;
+            } else if (s->session->cipher != NULL) {
+                /* We've not selected a cipher yet - we must be doing early data */
+                algenc = s->session->cipher->algorithm_enc;
+            } else if (s->psksession != NULL && s->psksession->cipher != NULL) {
+                /* We must be doing early data with out-of-band PSK */
+                algenc = s->psksession->cipher->algorithm_enc;
+            } else {
+                SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_EVP_LIB);
+                return 0;
+            }
+            if (algenc & (SSL_AES128CCM8 | SSL_AES256CCM8))
+                *taglen = EVP_CCM8_TLS_TAG_LEN;
+            else
+                *taglen = EVP_CCM_TLS_TAG_LEN;
         } else {
-            /* CHACHA20P-POLY1305 */
-            *taglen = EVP_CHACHAPOLY_TLS_TAG_LEN;
+            int iivlen;
+
+            if (mode == EVP_CIPH_GCM_MODE) {
+                *taglen = EVP_GCM_TLS_TAG_LEN;
+            } else {
+                /* CHACHA20P-POLY1305 */
+                *taglen = EVP_CHACHAPOLY_TLS_TAG_LEN;
+            }
+            iivlen = EVP_CIPHER_get_iv_length(ciph);
+            if (iivlen < 0) {
+                SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_EVP_LIB);
+                return 0;
+            }
+            *ivlen = iivlen;
         }
-        iivlen = EVP_CIPHER_get_iv_length(ciph);
-        if (iivlen < 0) {
-            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_EVP_LIB);
+    }
+
+    if (*ivlen > EVP_MAX_IV_LENGTH) {
+        *iv = OPENSSL_malloc(*ivlen);
+        if (*iv == NULL) {
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_MALLOC_FAILURE);
             return 0;
         }
-        *ivlen = iivlen;
     }
 
     if (!tls13_derive_key(s, md, secret, key, *keylen)
-            || !tls13_derive_iv(s, md, secret, iv, *ivlen)) {
+            || !tls13_derive_iv(s, md, secret, *iv, *ivlen)) {
         /* SSLfatal() already called */
         return 0;
     }
@@ -429,7 +463,8 @@ int tls13_change_cipher_state(SSL_CONNECTION *s, int which)
     static const unsigned char resumption_master_secret[] = "\x72\x65\x73\x20\x6D\x61\x73\x74\x65\x72";
     /* ASCII: "e exp master", in hex for EBCDIC compatibility */
     static const unsigned char early_exporter_master_secret[] = "\x65\x20\x65\x78\x70\x20\x6D\x61\x73\x74\x65\x72";
-    unsigned char iv[EVP_MAX_IV_LENGTH];
+    unsigned char iv_intern[EVP_MAX_IV_LENGTH];
+    unsigned char *iv = iv_intern;
     unsigned char key[EVP_MAX_KEY_LENGTH];
     unsigned char secret[EVP_MAX_MD_SIZE];
     unsigned char hashval[EVP_MAX_MD_SIZE];
@@ -437,21 +472,22 @@ int tls13_change_cipher_state(SSL_CONNECTION *s, int which)
     unsigned char *insecret;
     unsigned char *finsecret = NULL;
     const char *log_label = NULL;
-    size_t finsecretlen = 0;
+    int finsecretlen = 0;
     const unsigned char *label;
     size_t labellen, hashlen = 0;
     int ret = 0;
-    const EVP_MD *md = NULL;
+    const EVP_MD *md = NULL, *mac_md = NULL;
     const EVP_CIPHER *cipher = NULL;
+    int mac_pkey_type = NID_undef;
     SSL_CTX *sctx = SSL_CONNECTION_GET_CTX(s);
-    size_t keylen, ivlen, taglen;
+    size_t keylen, ivlen = EVP_MAX_IV_LENGTH, taglen;
     int level;
     int direction = (which & SSL3_CC_READ) != 0 ? OSSL_RECORD_DIRECTION_READ
                                                 : OSSL_RECORD_DIRECTION_WRITE;
 
     if (((which & SSL3_CC_CLIENT) && (which & SSL3_CC_WRITE))
             || ((which & SSL3_CC_SERVER) && (which & SSL3_CC_READ))) {
-        if (which & SSL3_CC_EARLY) {
+        if ((which & SSL3_CC_EARLY) != 0) {
             EVP_MD_CTX *mdctx = NULL;
             long handlen;
             void *hdata;
@@ -490,6 +526,23 @@ int tls13_change_cipher_state(SSL_CONNECTION *s, int which)
                 goto err;
             }
 
+            /*
+             * This ups the ref count on cipher so we better make sure we free
+             * it again
+             */
+            if (!ssl_cipher_get_evp_cipher(sctx, sslcipher, &cipher)) {
+                /* Error is already recorded */
+                SSLfatal_alert(s, SSL_AD_INTERNAL_ERROR);
+                goto err;
+            }
+
+            if (((EVP_CIPHER_flags(cipher) & EVP_CIPH_FLAG_AEAD_CIPHER) == 0)
+                && (!ssl_cipher_get_evp_md_mac(sctx, sslcipher, &mac_md,
+                                               &mac_pkey_type, NULL))) {
+                SSLfatal_alert(s, SSL_AD_INTERNAL_ERROR);
+                goto err;
+            }
+
             /*
              * We need to calculate the handshake digest using the digest from
              * the session. We haven't yet selected our ciphersuite so we can't
@@ -501,17 +554,6 @@ int tls13_change_cipher_state(SSL_CONNECTION *s, int which)
                 goto err;
             }
 
-            /*
-             * This ups the ref count on cipher so we better make sure we free
-             * it again
-             */
-            if (!ssl_cipher_get_evp_cipher(sctx, sslcipher, &cipher)) {
-                /* Error is already recorded */
-                SSLfatal_alert(s, SSL_AD_INTERNAL_ERROR);
-                EVP_MD_CTX_free(mdctx);
-                goto err;
-            }
-
             md = ssl_md(sctx, sslcipher->algorithm2);
             if (md == NULL || !EVP_DigestInit_ex(mdctx, md, NULL)
                     || !EVP_DigestUpdate(mdctx, hdata, handlen)
@@ -542,6 +584,10 @@ int tls13_change_cipher_state(SSL_CONNECTION *s, int which)
             insecret = s->handshake_secret;
             finsecret = s->client_finished_secret;
             finsecretlen = EVP_MD_get_size(ssl_handshake_md(s));
+            if (finsecretlen <= 0) {
+                SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+                goto err;
+            }
             label = client_handshake_traffic;
             labellen = sizeof(client_handshake_traffic) - 1;
             log_label = CLIENT_HANDSHAKE_LABEL;
@@ -574,6 +620,10 @@ int tls13_change_cipher_state(SSL_CONNECTION *s, int which)
             insecret = s->handshake_secret;
             finsecret = s->server_finished_secret;
             finsecretlen = EVP_MD_get_size(ssl_handshake_md(s));
+            if (finsecretlen <= 0) {
+                SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+                goto err;
+            }
             label = server_handshake_traffic;
             labellen = sizeof(server_handshake_traffic) - 1;
             log_label = SERVER_HANDSHAKE_LABEL;
@@ -585,9 +635,11 @@ int tls13_change_cipher_state(SSL_CONNECTION *s, int which)
         }
     }
 
-    if (!(which & SSL3_CC_EARLY)) {
+    if ((which & SSL3_CC_EARLY) == 0) {
         md = ssl_handshake_md(s);
         cipher = s->s3.tmp.new_sym_enc;
+        mac_md = s->s3.tmp.new_hash;
+        mac_pkey_type = s->s3.tmp.new_mac_pkey_type;
         if (!ssl3_digest_cached_records(s, 1)
                 || !ssl_handshake_hash(s, hashval, sizeof(hashval), &hashlen)) {
             /* SSLfatal() already called */;
@@ -624,9 +676,9 @@ int tls13_change_cipher_state(SSL_CONNECTION *s, int which)
     if (!ossl_assert(cipher != NULL))
         goto err;
 
-    if (!derive_secret_key_and_iv(s, which & SSL3_CC_WRITE, md, cipher,
+    if (!derive_secret_key_and_iv(s, md, cipher, mac_pkey_type, mac_md,
                                   insecret, hash, label, labellen, secret, key,
-                                  &keylen, iv, &ivlen, &taglen)) {
+                                  &keylen, &iv, &ivlen, &taglen)) {
         /* SSLfatal() already called */
         goto err;
     }
@@ -658,7 +710,7 @@ int tls13_change_cipher_state(SSL_CONNECTION *s, int which)
 
     if (finsecret != NULL
             && !tls13_derive_finishedkey(s, ssl_handshake_md(s), secret,
-                                         finsecret, finsecretlen)) {
+                                         finsecret, (size_t)finsecretlen)) {
         /* SSLfatal() already called */
         goto err;
     }
@@ -678,8 +730,9 @@ int tls13_change_cipher_state(SSL_CONNECTION *s, int which)
 
     if (!ssl_set_new_record_layer(s, s->version,
                                   direction,
-                                  level, key, keylen, iv, ivlen, NULL, 0,
-                                  cipher, taglen, NID_undef, NULL, NULL)) {
+                                  level, secret, hashlen, key, keylen, iv,
+                                  ivlen, NULL, 0, cipher, taglen,
+                                  mac_pkey_type, mac_md, NULL, md)) {
         /* SSLfatal already called */
         goto err;
     }
@@ -688,10 +741,14 @@ int tls13_change_cipher_state(SSL_CONNECTION *s, int which)
  err:
     if ((which & SSL3_CC_EARLY) != 0) {
         /* We up-refed this so now we need to down ref */
+        if ((EVP_CIPHER_flags(cipher) & EVP_CIPH_FLAG_AEAD_CIPHER) == 0)
+            ssl_evp_md_free(mac_md);
         ssl_evp_cipher_free(cipher);
     }
     OPENSSL_cleanse(key, sizeof(key));
     OPENSSL_cleanse(secret, sizeof(secret));
+    if (iv != iv_intern)
+        OPENSSL_free(iv);
     return ret;
 }
 
@@ -709,7 +766,8 @@ int tls13_update_key(SSL_CONNECTION *s, int sending)
     int ret = 0, l;
     int direction = sending ? OSSL_RECORD_DIRECTION_WRITE
                             : OSSL_RECORD_DIRECTION_READ;
-    unsigned char iv[EVP_MAX_IV_LENGTH];
+    unsigned char iv_intern[EVP_MAX_IV_LENGTH];
+    unsigned char *iv = iv_intern;
 
     if ((l = EVP_MD_get_size(md)) <= 0) {
         SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
@@ -722,11 +780,13 @@ int tls13_update_key(SSL_CONNECTION *s, int sending)
     else
         insecret = s->client_app_traffic_secret;
 
-    if (!derive_secret_key_and_iv(s, sending, md,
-                                  s->s3.tmp.new_sym_enc, insecret, NULL,
+    if (!derive_secret_key_and_iv(s, md,
+                                  s->s3.tmp.new_sym_enc,
+                                  s->s3.tmp.new_mac_pkey_type, s->s3.tmp.new_hash,
+                                  insecret, NULL,
                                   application_traffic,
                                   sizeof(application_traffic) - 1, secret, key,
-                                  &keylen, iv, &ivlen, &taglen)) {
+                                  &keylen, &iv, &ivlen, &taglen)) {
         /* SSLfatal() already called */
         goto err;
     }
@@ -736,9 +796,9 @@ int tls13_update_key(SSL_CONNECTION *s, int sending)
     if (!ssl_set_new_record_layer(s, s->version,
                             direction,
                             OSSL_RECORD_PROTECTION_LEVEL_APPLICATION,
-                            key, keylen, iv, ivlen, NULL, 0,
+                            insecret, hashlen, key, keylen, iv, ivlen, NULL, 0,
                             s->s3.tmp.new_sym_enc, taglen, NID_undef, NULL,
-                            NULL)) {
+                            NULL, md)) {
         /* SSLfatal already called */
         goto err;
     }
@@ -753,6 +813,8 @@ int tls13_update_key(SSL_CONNECTION *s, int sending)
  err:
     OPENSSL_cleanse(key, sizeof(key));
     OPENSSL_cleanse(secret, sizeof(secret));
+    if (iv != iv_intern)
+        OPENSSL_free(iv);
     return ret;
 }