Make sure we use the correct cipher when using the early_secret
[openssl.git] / ssl / tls13_enc.c
index 44d8ba9eb1bc2c4b2252f693d2a0baad2d644c4e..98a1d1ea17e96a84ea48f2e2eaea170fcb14b464 100644 (file)
@@ -9,6 +9,7 @@
 
 #include <stdlib.h>
 #include "ssl_locl.h"
+#include "internal/cryptlib.h"
 #include <openssl/evp.h>
 #include <openssl/kdf.h>
 
@@ -404,8 +405,26 @@ int tls13_change_cipher_state(SSL *s, int which)
                        SSL_R_BAD_HANDSHAKE_LENGTH);
                 goto err;
             }
+
+            if (s->early_data_state == SSL_EARLY_DATA_CONNECTING
+                    && s->max_early_data > 0
+                    && s->session->ext.max_early_data == 0) {
+                /*
+                 * If we are attempting to send early data, and we've decided to
+                 * actually do it but max_early_data in s->session is 0 then we
+                 * must be using an external PSK.
+                 */
+                if (!ossl_assert(s->psksession != NULL
+                        && s->max_early_data ==
+                           s->psksession->ext.max_early_data)) {
+                    SSLerr(SSL_F_TLS13_CHANGE_CIPHER_STATE,
+                           ERR_R_INTERNAL_ERROR);
+                    goto err;
+                }
+                sslcipher = SSL_SESSION_get0_cipher(s->psksession);
+            }
             if (sslcipher == NULL) {
-                SSLerr(SSL_F_TLS13_CHANGE_CIPHER_STATE, ERR_R_INTERNAL_ERROR);
+                SSLerr(SSL_F_TLS13_CHANGE_CIPHER_STATE, SSL_R_BAD_PSK);
                 goto err;
             }
 
@@ -607,10 +626,10 @@ int tls13_export_keying_material(SSL *s, unsigned char *out, size_t olen,
 {
     unsigned char exportsecret[EVP_MAX_MD_SIZE];
     static const unsigned char exporterlabel[] = "exporter";
-    unsigned char hash[EVP_MAX_MD_SIZE];
+    unsigned char hash[EVP_MAX_MD_SIZE], data[EVP_MAX_MD_SIZE];
     const EVP_MD *md = ssl_handshake_md(s);
     EVP_MD_CTX *ctx = EVP_MD_CTX_new();
-    unsigned int hashsize;
+    unsigned int hashsize, datalen;
     int ret = 0;
 
     if (ctx == NULL || !SSL_is_init_finished(s))
@@ -622,9 +641,11 @@ int tls13_export_keying_material(SSL *s, unsigned char *out, size_t olen,
     if (EVP_DigestInit_ex(ctx, md, NULL) <= 0
             || EVP_DigestUpdate(ctx, context, contextlen) <= 0
             || EVP_DigestFinal_ex(ctx, hash, &hashsize) <= 0
+            || EVP_DigestInit_ex(ctx, md, NULL) <= 0
+            || EVP_DigestFinal_ex(ctx, data, &datalen) <= 0
             || !tls13_hkdf_expand(s, md, s->exporter_master_secret,
-                                  (const unsigned char *)label, llen, NULL, 0,
-                                  exportsecret, hashsize)
+                                  (const unsigned char *)label, llen,
+                                  data, datalen, exportsecret, hashsize)
             || !tls13_hkdf_expand(s, md, exportsecret, exporterlabel,
                                   sizeof(exporterlabel) - 1, hash, hashsize,
                                   out, olen))