Enable the ability to use an external PSK for sending early_data
authorMatt Caswell <matt@openssl.org>
Wed, 5 Jul 2017 19:53:03 +0000 (20:53 +0100)
committerMatt Caswell <matt@openssl.org>
Thu, 31 Aug 2017 14:02:22 +0000 (15:02 +0100)
Reviewed-by: Ben Kaduk <kaduk@mit.edu>
(Merged from https://github.com/openssl/openssl/pull/3926)

apps/s_client.c
ssl/record/ssl3_record.c
ssl/record/ssl3_record_tls13.c
ssl/ssl_lib.c
ssl/ssl_locl.h
ssl/statem/extensions.c
ssl/statem/extensions_clnt.c
ssl/statem/extensions_srvr.c
ssl/tls13_enc.c

index 5a4a2f6..36da3b6 100644 (file)
@@ -2600,8 +2600,10 @@ int s_client_main(int argc, char **argv)
     }
 
     if (early_data_file != NULL
-            && SSL_get0_session(con) != NULL
-            && SSL_SESSION_get_max_early_data(SSL_get0_session(con)) > 0) {
+            && ((SSL_get0_session(con) != NULL
+                 && SSL_SESSION_get_max_early_data(SSL_get0_session(con)) > 0)
+                || (psksess != NULL
+                    && SSL_SESSION_get_max_early_data(psksess) > 0))) {
         BIO *edfile = BIO_new_file(early_data_file, "r");
         size_t readbytes, writtenbytes;
         int finish = 0;
index ae48504..fa7f5d9 100644 (file)
@@ -104,15 +104,24 @@ static int ssl3_record_app_data_waiting(SSL *s)
 int early_data_count_ok(SSL *s, size_t length, size_t overhead, int *al)
 {
     uint32_t max_early_data = s->max_early_data;
+    SSL_SESSION *sess = s->session;
 
     /*
      * If we are a client then we always use the max_early_data from the
-     * session. Otherwise we go with the lowest out of the max early data set in
-     * the session and the configured max_early_data.
+     * session/psksession. Otherwise we go with the lowest out of the max early
+     * data set in the session and the configured max_early_data.
      */
-    if (!s->server || (s->hit
-                       && s->session->ext.max_early_data < s->max_early_data))
-        max_early_data = s->session->ext.max_early_data;
+    if (!s->server && sess->ext.max_early_data == 0) {
+        if (!ossl_assert(s->psksession != NULL
+                         && s->psksession->ext.max_early_data > 0)) {
+            SSLerr(SSL_F_EARLY_DATA_COUNT_OK, ERR_R_INTERNAL_ERROR);
+            return 0;
+        }
+        sess = s->psksession;
+    }
+    if (!s->server
+            || (s->hit && sess->ext.max_early_data < s->max_early_data))
+        max_early_data = sess->ext.max_early_data;
 
     if (max_early_data == 0) {
         if (al != NULL)
index ec8f9f9..0c3fc6b 100644 (file)
@@ -58,7 +58,10 @@ int tls13_enc(SSL *s, SSL3_RECORD *recs, size_t n_recs, int sending)
 
     if (s->early_data_state == SSL_EARLY_DATA_WRITING
             || s->early_data_state == SSL_EARLY_DATA_WRITE_RETRY) {
-        alg_enc = s->session->cipher->algorithm_enc;
+        if (s->session != NULL && s->session->ext.max_early_data > 0)
+            alg_enc = s->session->cipher->algorithm_enc;
+        else
+            alg_enc = s->psksession->cipher->algorithm_enc;
     } else {
         /*
          * To get here we must have selected a ciphersuite - otherwise ctx would
index cac8820..70f4acf 100644 (file)
@@ -534,6 +534,9 @@ int SSL_clear(SSL *s)
     }
     SSL_SESSION_free(s->psksession);
     s->psksession = NULL;
+    OPENSSL_free(s->psksession_id);
+    s->psksession_id = NULL;
+    s->psksession_id_len = 0;
 
     s->error = 0;
     s->hit = 0;
@@ -1097,6 +1100,7 @@ void SSL_free(SSL *s)
         SSL_SESSION_free(s->session);
     }
     SSL_SESSION_free(s->psksession);
+    OPENSSL_free(s->psksession_id);
 
     clear_ciphers(s);
 
@@ -1910,8 +1914,8 @@ int SSL_write_early_data(SSL *s, const void *buf, size_t num, size_t *written)
     case SSL_EARLY_DATA_NONE:
         if (s->server
                 || !SSL_in_before(s)
-                || s->session == NULL
-                || s->session->ext.max_early_data == 0) {
+                || ((s->session == NULL || s->session->ext.max_early_data == 0)
+                     && (s->psk_use_session_cb == NULL))) {
             SSLerr(SSL_F_SSL_WRITE_EARLY_DATA,
                    ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
             return 0;
index 4896c35..7caec67 100644 (file)
@@ -1119,6 +1119,8 @@ struct ssl_st {
     SSL_SESSION *session;
     /* TLSv1.3 PSK session */
     SSL_SESSION *psksession;
+    unsigned char *psksession_id;
+    size_t psksession_id_len;
     /* Default generate session ID callback. */
     GEN_SESSION_CB generate_session_id;
     /* Used in SSL3 */
index c435405..3d830a7 100644 (file)
@@ -1206,6 +1206,13 @@ int tls_psk_do_binder(SSL *s, const EVP_MD *md, const unsigned char *msgstart,
     const char *label;
     size_t bindersize, labelsize, hashsize = EVP_MD_size(md);
     int ret = -1;
+    int usepskfored = 0;
+
+    if (external
+            && s->early_data_state == SSL_EARLY_DATA_CONNECTING
+            && s->session->ext.max_early_data == 0
+            && sess->ext.max_early_data > 0)
+        usepskfored = 1;
 
     if (external) {
         label = external_label;
@@ -1236,11 +1243,12 @@ int tls_psk_do_binder(SSL *s, const EVP_MD *md, const unsigned char *msgstart,
     /*
      * Generate the early_secret. On the server side we've selected a PSK to
      * resume with (internal or external) so we always do this. On the client
-     * side we do this for a non-external (i.e. resumption) PSK so that it
-     * is in place for sending early data. For client side external PSK we
+     * side we do this for a non-external (i.e. resumption) PSK or external PSK
+     * that will be used for early_data so that it is in place for sending early
+     * data. For client side external PSK not being used for early_data we
      * generate it but store it away for later use.
      */
-    if (s->server || !external)
+    if (s->server || !external || usepskfored)
         early_secret = (unsigned char *)s->early_secret;
     else
         early_secret = (unsigned char *)sess->early_secret;
index b1c2eb0..86a1cab 100644 (file)
@@ -679,12 +679,41 @@ EXT_RETURN tls_construct_ctos_early_data(SSL *s, WPACKET *pkt,
                                          unsigned int context, X509 *x,
                                          size_t chainidx, int *al)
 {
+    const unsigned char *id;
+    size_t idlen;
+    SSL_SESSION *psksess = NULL;
+    const EVP_MD *handmd = NULL;
+
+    if (s->hello_retry_request)
+        handmd = ssl_handshake_md(s);
+
+    if (s->psk_use_session_cb != NULL
+            && !s->psk_use_session_cb(s, handmd, &id, &idlen, &psksess)) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_EARLY_DATA, SSL_R_BAD_PSK);
+        return EXT_RETURN_FAIL;
+    }
+
+    SSL_SESSION_free(s->psksession);
+    s->psksession = psksess;
+    if (psksess != NULL) {
+        OPENSSL_free(s->psksession_id);
+        s->psksession_id = OPENSSL_memdup(id, idlen);
+        if (s->psksession_id == NULL) {
+            SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_EARLY_DATA, ERR_R_INTERNAL_ERROR);
+            return EXT_RETURN_FAIL;
+        }
+        s->psksession_id_len = idlen;
+    }
+
     if (s->early_data_state != SSL_EARLY_DATA_CONNECTING
-            || s->session->ext.max_early_data == 0) {
+            || (s->session->ext.max_early_data == 0
+                && (psksess == NULL || psksess->ext.max_early_data == 0))) {
         s->max_early_data = 0;
         return EXT_RETURN_NOT_SENT;
     }
-    s->max_early_data = s->session->ext.max_early_data;
+    s->max_early_data = s->session->ext.max_early_data != 0 ?
+                        s->session->ext.max_early_data
+                        : psksess->ext.max_early_data;
 
     if (!WPACKET_put_bytes_u16(pkt, TLSEXT_TYPE_early_data)
             || !WPACKET_start_sub_packet_u16(pkt)
@@ -793,12 +822,10 @@ EXT_RETURN tls_construct_ctos_psk(SSL *s, WPACKET *pkt, unsigned int context,
 {
 #ifndef OPENSSL_NO_TLS1_3
     uint32_t now, agesec, agems = 0;
-    size_t reshashsize = 0, pskhashsize = 0, binderoffset, msglen, idlen = 0;
+    size_t reshashsize = 0, pskhashsize = 0, binderoffset, msglen;
     unsigned char *resbinder = NULL, *pskbinder = NULL, *msgstart = NULL;
-    const unsigned char *id = 0;
     const EVP_MD *handmd = NULL, *mdres = NULL, *mdpsk = NULL;
     EXT_RETURN ret = EXT_RETURN_FAIL;
-    SSL_SESSION *psksess = NULL;
     int dores = 0;
 
     s->session->ext.tick_identity = TLSEXT_PSK_BAD_IDENTITY;
@@ -814,18 +841,12 @@ EXT_RETURN tls_construct_ctos_psk(SSL *s, WPACKET *pkt, unsigned int context,
      * so don't add this extension.
      */
     if (s->session->ssl_version != TLS1_3_VERSION
-            || (s->session->ext.ticklen == 0 && s->psk_use_session_cb == NULL))
+            || (s->session->ext.ticklen == 0 && s->psksession == NULL))
         return EXT_RETURN_NOT_SENT;
 
     if (s->hello_retry_request)
         handmd = ssl_handshake_md(s);
 
-    if (s->psk_use_session_cb != NULL
-            && !s->psk_use_session_cb(s, handmd, &id, &idlen, &psksess)) {
-        SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, SSL_R_BAD_PSK);
-        goto err;
-    }
-
     if (s->session->ext.ticklen != 0) {
         /* Get the digest associated with the ciphersuite in the session */
         if (s->session->cipher == NULL) {
@@ -890,11 +911,11 @@ EXT_RETURN tls_construct_ctos_psk(SSL *s, WPACKET *pkt, unsigned int context,
     }
 
  dopsksess:
-    if (!dores && psksess == NULL)
+    if (!dores && s->psksession == NULL)
         return EXT_RETURN_NOT_SENT;
 
-    if (psksess != NULL) {
-        mdpsk = ssl_md(psksess->cipher->algorithm2);
+    if (s->psksession != NULL) {
+        mdpsk = ssl_md(s->psksession->cipher->algorithm2);
         if (mdpsk == NULL) {
             /*
              * Don't recognize this cipher so we can't use the session.
@@ -933,8 +954,9 @@ EXT_RETURN tls_construct_ctos_psk(SSL *s, WPACKET *pkt, unsigned int context,
         }
     }
 
-    if (psksess != NULL) {
-        if (!WPACKET_sub_memcpy_u16(pkt, id, idlen)
+    if (s->psksession != NULL) {
+        if (!WPACKET_sub_memcpy_u16(pkt, s->psksession_id,
+                                    s->psksession_id_len)
                 || !WPACKET_put_bytes_u32(pkt, 0)) {
             SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, ERR_R_INTERNAL_ERROR);
             goto err;
@@ -946,7 +968,7 @@ EXT_RETURN tls_construct_ctos_psk(SSL *s, WPACKET *pkt, unsigned int context,
             || !WPACKET_start_sub_packet_u16(pkt)
             || (dores
                 && !WPACKET_sub_allocate_bytes_u8(pkt, reshashsize, &resbinder))
-            || (psksess != NULL
+            || (s->psksession != NULL
                 && !WPACKET_sub_allocate_bytes_u8(pkt, pskhashsize, &pskbinder))
             || !WPACKET_close(pkt)
             || !WPACKET_close(pkt)
@@ -969,24 +991,20 @@ EXT_RETURN tls_construct_ctos_psk(SSL *s, WPACKET *pkt, unsigned int context,
         goto err;
     }
 
-    if (psksess != NULL
+    if (s->psksession != NULL
             && tls_psk_do_binder(s, mdpsk, msgstart, binderoffset, NULL,
-                                 pskbinder, psksess, 1, 1) != 1) {
+                                 pskbinder, s->psksession, 1, 1) != 1) {
         SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, ERR_R_INTERNAL_ERROR);
         goto err;
     }
 
     if (dores)
         s->session->ext.tick_identity = 0;
-    SSL_SESSION_free(s->psksession);
-    s->psksession = psksess;
-    if (psksess != NULL)
+    if (s->psksession != NULL)
         s->psksession->ext.tick_identity = (dores ? 1 : 0);
-    psksess = NULL;
 
     ret = EXT_RETURN_SENT;
  err:
-    SSL_SESSION_free(psksess);
     return ret;
 #else
     return 1;
@@ -1606,10 +1624,20 @@ int tls_parse_stoc_psk(SSL *s, PACKET *pkt, unsigned int context, X509 *x,
         return 0;
     }
 
+    /*
+     * If we used the external PSK for sending early_data then s->early_secret
+     * is already set up, so don't overwrite it. Otherwise we copy the
+     * early_secret across that we generated earlier.
+     */
+    if ((s->early_data_state != SSL_EARLY_DATA_WRITE_RETRY
+                && s->early_data_state != SSL_EARLY_DATA_FINISHED_WRITING)
+            || s->session->ext.max_early_data > 0
+            || s->psksession->ext.max_early_data == 0)
+        memcpy(s->early_secret, s->psksession->early_secret, EVP_MAX_MD_SIZE);
+
     SSL_SESSION_free(s->session);
     s->session = s->psksession;
     s->psksession = NULL;
-    memcpy(s->early_secret, s->session->early_secret, EVP_MAX_MD_SIZE);
     s->hit = 1;
 #endif
 
index a70f53b..2363c42 100644 (file)
@@ -745,6 +745,7 @@ int tls_parse_ctos_psk(SSL *s, PACKET *pkt, unsigned int context, X509 *x,
             memcpy(sess->sid_ctx, s->sid_ctx, s->sid_ctx_length);
             sess->sid_ctx_length = s->sid_ctx_length;
             ext = 1;
+            s->ext.early_data_ok = 1;
         } else {
             uint32_t ticket_age = 0, now, agesec, agems;
             int ret = tls_decrypt_ticket(s, PACKET_data(&identity),
index ac5d06c..1a6ed98 100644 (file)
@@ -404,6 +404,9 @@ int tls13_change_cipher_state(SSL *s, int which)
                        SSL_R_BAD_HANDSHAKE_LENGTH);
                 goto err;
             }
+
+            if (sslcipher == NULL && s->psksession != NULL)
+                sslcipher = SSL_SESSION_get0_cipher(s->psksession);
             if (sslcipher == NULL) {
                 SSLerr(SSL_F_TLS13_CHANGE_CIPHER_STATE, ERR_R_INTERNAL_ERROR);
                 goto err;