Enable the ability to use an external PSK for sending early_data
[openssl.git] / ssl / statem / extensions_clnt.c
index b1c2eb0..86a1cab 100644 (file)
@@ -679,12 +679,41 @@ EXT_RETURN tls_construct_ctos_early_data(SSL *s, WPACKET *pkt,
                                          unsigned int context, X509 *x,
                                          size_t chainidx, int *al)
 {
+    const unsigned char *id;
+    size_t idlen;
+    SSL_SESSION *psksess = NULL;
+    const EVP_MD *handmd = NULL;
+
+    if (s->hello_retry_request)
+        handmd = ssl_handshake_md(s);
+
+    if (s->psk_use_session_cb != NULL
+            && !s->psk_use_session_cb(s, handmd, &id, &idlen, &psksess)) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_EARLY_DATA, SSL_R_BAD_PSK);
+        return EXT_RETURN_FAIL;
+    }
+
+    SSL_SESSION_free(s->psksession);
+    s->psksession = psksess;
+    if (psksess != NULL) {
+        OPENSSL_free(s->psksession_id);
+        s->psksession_id = OPENSSL_memdup(id, idlen);
+        if (s->psksession_id == NULL) {
+            SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_EARLY_DATA, ERR_R_INTERNAL_ERROR);
+            return EXT_RETURN_FAIL;
+        }
+        s->psksession_id_len = idlen;
+    }
+
     if (s->early_data_state != SSL_EARLY_DATA_CONNECTING
-            || s->session->ext.max_early_data == 0) {
+            || (s->session->ext.max_early_data == 0
+                && (psksess == NULL || psksess->ext.max_early_data == 0))) {
         s->max_early_data = 0;
         return EXT_RETURN_NOT_SENT;
     }
-    s->max_early_data = s->session->ext.max_early_data;
+    s->max_early_data = s->session->ext.max_early_data != 0 ?
+                        s->session->ext.max_early_data
+                        : psksess->ext.max_early_data;
 
     if (!WPACKET_put_bytes_u16(pkt, TLSEXT_TYPE_early_data)
             || !WPACKET_start_sub_packet_u16(pkt)
@@ -793,12 +822,10 @@ EXT_RETURN tls_construct_ctos_psk(SSL *s, WPACKET *pkt, unsigned int context,
 {
 #ifndef OPENSSL_NO_TLS1_3
     uint32_t now, agesec, agems = 0;
-    size_t reshashsize = 0, pskhashsize = 0, binderoffset, msglen, idlen = 0;
+    size_t reshashsize = 0, pskhashsize = 0, binderoffset, msglen;
     unsigned char *resbinder = NULL, *pskbinder = NULL, *msgstart = NULL;
-    const unsigned char *id = 0;
     const EVP_MD *handmd = NULL, *mdres = NULL, *mdpsk = NULL;
     EXT_RETURN ret = EXT_RETURN_FAIL;
-    SSL_SESSION *psksess = NULL;
     int dores = 0;
 
     s->session->ext.tick_identity = TLSEXT_PSK_BAD_IDENTITY;
@@ -814,18 +841,12 @@ EXT_RETURN tls_construct_ctos_psk(SSL *s, WPACKET *pkt, unsigned int context,
      * so don't add this extension.
      */
     if (s->session->ssl_version != TLS1_3_VERSION
-            || (s->session->ext.ticklen == 0 && s->psk_use_session_cb == NULL))
+            || (s->session->ext.ticklen == 0 && s->psksession == NULL))
         return EXT_RETURN_NOT_SENT;
 
     if (s->hello_retry_request)
         handmd = ssl_handshake_md(s);
 
-    if (s->psk_use_session_cb != NULL
-            && !s->psk_use_session_cb(s, handmd, &id, &idlen, &psksess)) {
-        SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, SSL_R_BAD_PSK);
-        goto err;
-    }
-
     if (s->session->ext.ticklen != 0) {
         /* Get the digest associated with the ciphersuite in the session */
         if (s->session->cipher == NULL) {
@@ -890,11 +911,11 @@ EXT_RETURN tls_construct_ctos_psk(SSL *s, WPACKET *pkt, unsigned int context,
     }
 
  dopsksess:
-    if (!dores && psksess == NULL)
+    if (!dores && s->psksession == NULL)
         return EXT_RETURN_NOT_SENT;
 
-    if (psksess != NULL) {
-        mdpsk = ssl_md(psksess->cipher->algorithm2);
+    if (s->psksession != NULL) {
+        mdpsk = ssl_md(s->psksession->cipher->algorithm2);
         if (mdpsk == NULL) {
             /*
              * Don't recognize this cipher so we can't use the session.
@@ -933,8 +954,9 @@ EXT_RETURN tls_construct_ctos_psk(SSL *s, WPACKET *pkt, unsigned int context,
         }
     }
 
-    if (psksess != NULL) {
-        if (!WPACKET_sub_memcpy_u16(pkt, id, idlen)
+    if (s->psksession != NULL) {
+        if (!WPACKET_sub_memcpy_u16(pkt, s->psksession_id,
+                                    s->psksession_id_len)
                 || !WPACKET_put_bytes_u32(pkt, 0)) {
             SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, ERR_R_INTERNAL_ERROR);
             goto err;
@@ -946,7 +968,7 @@ EXT_RETURN tls_construct_ctos_psk(SSL *s, WPACKET *pkt, unsigned int context,
             || !WPACKET_start_sub_packet_u16(pkt)
             || (dores
                 && !WPACKET_sub_allocate_bytes_u8(pkt, reshashsize, &resbinder))
-            || (psksess != NULL
+            || (s->psksession != NULL
                 && !WPACKET_sub_allocate_bytes_u8(pkt, pskhashsize, &pskbinder))
             || !WPACKET_close(pkt)
             || !WPACKET_close(pkt)
@@ -969,24 +991,20 @@ EXT_RETURN tls_construct_ctos_psk(SSL *s, WPACKET *pkt, unsigned int context,
         goto err;
     }
 
-    if (psksess != NULL
+    if (s->psksession != NULL
             && tls_psk_do_binder(s, mdpsk, msgstart, binderoffset, NULL,
-                                 pskbinder, psksess, 1, 1) != 1) {
+                                 pskbinder, s->psksession, 1, 1) != 1) {
         SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, ERR_R_INTERNAL_ERROR);
         goto err;
     }
 
     if (dores)
         s->session->ext.tick_identity = 0;
-    SSL_SESSION_free(s->psksession);
-    s->psksession = psksess;
-    if (psksess != NULL)
+    if (s->psksession != NULL)
         s->psksession->ext.tick_identity = (dores ? 1 : 0);
-    psksess = NULL;
 
     ret = EXT_RETURN_SENT;
  err:
-    SSL_SESSION_free(psksess);
     return ret;
 #else
     return 1;
@@ -1606,10 +1624,20 @@ int tls_parse_stoc_psk(SSL *s, PACKET *pkt, unsigned int context, X509 *x,
         return 0;
     }
 
+    /*
+     * If we used the external PSK for sending early_data then s->early_secret
+     * is already set up, so don't overwrite it. Otherwise we copy the
+     * early_secret across that we generated earlier.
+     */
+    if ((s->early_data_state != SSL_EARLY_DATA_WRITE_RETRY
+                && s->early_data_state != SSL_EARLY_DATA_FINISHED_WRITING)
+            || s->session->ext.max_early_data > 0
+            || s->psksession->ext.max_early_data == 0)
+        memcpy(s->early_secret, s->psksession->early_secret, EVP_MAX_MD_SIZE);
+
     SSL_SESSION_free(s->session);
     s->session = s->psksession;
     s->psksession = NULL;
-    memcpy(s->early_secret, s->session->early_secret, EVP_MAX_MD_SIZE);
     s->hit = 1;
 #endif