Various fixes required to allow SSL_write/SSL_read during early data
authorMatt Caswell <matt@openssl.org>
Mon, 27 Feb 2017 11:19:57 +0000 (11:19 +0000)
committerMatt Caswell <matt@openssl.org>
Thu, 2 Mar 2017 17:44:16 +0000 (17:44 +0000)
Reviewed-by: Rich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/2737)

ssl/ssl_lib.c
ssl/statem/statem.c
ssl/statem/statem_clnt.c
ssl/statem/statem_lib.c
ssl/statem/statem_locl.h
ssl/statem/statem_srvr.c
ssl/tls13_enc.c

index c244c3c..baeb3bb 100644 (file)
@@ -1650,7 +1650,6 @@ int SSL_read_early(SSL *s, void *buf, size_t num, size_t *readbytes)
             s->early_data_state = SSL_EARLY_DATA_FINISHED_READING;
         }
         *readbytes = 0;
             s->early_data_state = SSL_EARLY_DATA_FINISHED_READING;
         }
         *readbytes = 0;
-        ossl_statem_set_in_init(s, 1);
         return SSL_READ_EARLY_FINISH;
 
     default:
         return SSL_READ_EARLY_FINISH;
 
     default:
@@ -1661,7 +1660,8 @@ int SSL_read_early(SSL *s, void *buf, size_t num, size_t *readbytes)
 
 int ssl_end_of_early_data_seen(SSL *s)
 {
 
 int ssl_end_of_early_data_seen(SSL *s)
 {
-    if (s->early_data_state == SSL_EARLY_DATA_READING) {
+    if (s->early_data_state == SSL_EARLY_DATA_READING
+            || s->early_data_state == SSL_EARLY_DATA_READ_RETRY) {
         s->early_data_state = SSL_EARLY_DATA_FINISHED_READING;
         ossl_statem_finish_early_data(s);
         return 1;
         s->early_data_state = SSL_EARLY_DATA_FINISHED_READING;
         ossl_statem_finish_early_data(s);
         return 1;
@@ -3242,15 +3242,21 @@ int SSL_do_handshake(SSL *s)
         return -1;
     }
 
         return -1;
     }
 
-    if (s->early_data_state != SSL_EARLY_DATA_NONE
-            && s->early_data_state != SSL_EARLY_DATA_FINISHED_WRITING
-            && s->early_data_state != SSL_EARLY_DATA_FINISHED_READING
-            && s->early_data_state != SSL_EARLY_DATA_ACCEPTING
-            && s->early_data_state != SSL_EARLY_DATA_CONNECTING) {
-        SSLerr(SSL_F_SSL_WRITE_INTERNAL, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
-        return 0;
-    }
+    if (s->early_data_state == SSL_EARLY_DATA_WRITE_RETRY
+            || s->early_data_state == SSL_EARLY_DATA_READ_RETRY) {
+        /*
+         * We skip this if we were called via SSL_read_early() or
+         * SSL_write_early()
+         */
+        if (s->early_data_state == SSL_EARLY_DATA_WRITE_RETRY) {
+            int edfin;
 
 
+            edfin = SSL_write_early_finish(s);
+            if (edfin <= 0)
+                return edfin;
+        }
+        ossl_statem_set_in_init(s, 1);
+    }
 
     s->method->ssl_renegotiate_check(s, 0);
 
 
     s->method->ssl_renegotiate_check(s, 0);
 
index 50c4345..8a251ea 100644 (file)
@@ -161,7 +161,7 @@ int ossl_statem_skip_early_data(SSL *s)
         if (s->statem.hand_state != TLS_ST_SW_HELLO_RETRY_REQUEST)
             return 0;
     } else {
         if (s->statem.hand_state != TLS_ST_SW_HELLO_RETRY_REQUEST)
             return 0;
     } else {
-        if (s->statem.hand_state != TLS_ST_SW_FINISHED)
+        if (!s->server || s->statem.hand_state != TLS_ST_EARLY_DATA)
             return 0;
     }
 
             return 0;
     }
 
@@ -171,9 +171,14 @@ int ossl_statem_skip_early_data(SSL *s)
 void ossl_statem_check_finish_init(SSL *s, int send)
 {
     if (!s->server) {
 void ossl_statem_check_finish_init(SSL *s, int send)
 {
     if (!s->server) {
-        if ((send && s->statem.hand_state == TLS_ST_PENDING_EARLY_DATA_END)
+        if ((send && s->statem.hand_state == TLS_ST_PENDING_EARLY_DATA_END
+                  && s->early_data_state != SSL_EARLY_DATA_WRITING)
                 || (!send && s->statem.hand_state == TLS_ST_EARLY_DATA))
             ossl_statem_set_in_init(s, 1);
                 || (!send && s->statem.hand_state == TLS_ST_EARLY_DATA))
             ossl_statem_set_in_init(s, 1);
+    } else {
+        if (s->early_data_state == SSL_EARLY_DATA_FINISHED_READING
+                && s->statem.hand_state == TLS_ST_EARLY_DATA)
+            ossl_statem_set_in_init(s, 1);
     }
 }
 
     }
 }
 
@@ -339,9 +344,7 @@ static int state_machine(SSL *s, int server)
                 goto end;
             }
 
                 goto end;
             }
 
-        if ((SSL_IS_FIRST_HANDSHAKE(s)
-                    && s->early_data_state != SSL_EARLY_DATA_FINISHED_WRITING
-                    && s->early_data_state != SSL_EARLY_DATA_FINISHED_READING)
+        if ((SSL_in_before(s))
                 || s->renegotiate) {
             if (!tls_setup_handshake(s)) {
                 ossl_statem_set_error(s);
                 || s->renegotiate) {
             if (!tls_setup_handshake(s)) {
                 ossl_statem_set_error(s);
@@ -746,8 +749,17 @@ static SUB_STATE_RETURN write_state_machine(SSL *s)
             case WORK_FINISHED_STOP:
                 return SUB_STATE_END_HANDSHAKE;
             }
             case WORK_FINISHED_STOP:
                 return SUB_STATE_END_HANDSHAKE;
             }
+            if (!get_construct_message_f(s, &pkt, &confunc, &mt)) {
+                ossl_statem_set_error(s);
+                return SUB_STATE_ERROR;
+            }
+            if (mt == SSL3_MT_DUMMY) {
+                /* Skip construction and sending. This isn't a "real" state */
+                st->write_state = WRITE_STATE_POST_WORK;
+                st->write_state_work = WORK_MORE_A;
+                break;
+            }
             if (!WPACKET_init(&pkt, s->init_buf)
             if (!WPACKET_init(&pkt, s->init_buf)
-                    || !get_construct_message_f(s, &pkt, &confunc, &mt)
                     || !ssl_set_handshake_header(s, &pkt, mt)
                     || (confunc != NULL && !confunc(s, &pkt))
                     || !ssl_close_construct_packet(s, &pkt, mt)
                     || !ssl_set_handshake_header(s, &pkt, mt)
                     || (confunc != NULL && !confunc(s, &pkt))
                     || !ssl_close_construct_packet(s, &pkt, mt)
index 9a29ab5..b11cd19 100644 (file)
@@ -1513,8 +1513,6 @@ MSG_PROCESS_RETURN tls_process_server_hello(SSL *s, PACKET *pkt)
      */
     if (SSL_IS_TLS13(s)
             && (!s->method->ssl3_enc->setup_key_block(s)
      */
     if (SSL_IS_TLS13(s)
             && (!s->method->ssl3_enc->setup_key_block(s)
-                || !s->method->ssl3_enc->change_cipher_state(s,
-                    SSL3_CC_HANDSHAKE | SSL3_CHANGE_CIPHER_CLIENT_WRITE)
                 || !s->method->ssl3_enc->change_cipher_state(s,
                     SSL3_CC_HANDSHAKE | SSL3_CHANGE_CIPHER_CLIENT_READ))) {
         al = SSL_AD_INTERNAL_ERROR;
                 || !s->method->ssl3_enc->change_cipher_state(s,
                     SSL3_CC_HANDSHAKE | SSL3_CHANGE_CIPHER_CLIENT_READ))) {
         al = SSL_AD_INTERNAL_ERROR;
@@ -3272,11 +3270,22 @@ int tls_construct_client_certificate(SSL *s, WPACKET *pkt)
                                                           : s->cert->key,
                                 &al)) {
         SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_CERTIFICATE, ERR_R_INTERNAL_ERROR);
                                                           : s->cert->key,
                                 &al)) {
         SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_CERTIFICATE, ERR_R_INTERNAL_ERROR);
-        ssl3_send_alert(s, SSL3_AL_FATAL, al);
-        return 0;
+        goto err;
+    }
+
+    if (SSL_IS_TLS13(s)
+            && SSL_IS_FIRST_HANDSHAKE(s)
+            && (!s->method->ssl3_enc->change_cipher_state(s,
+                    SSL3_CC_HANDSHAKE | SSL3_CHANGE_CIPHER_CLIENT_WRITE))) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_CERTIFICATE,
+               SSL_R_CANNOT_CHANGE_CIPHER);
+        goto err;
     }
 
     return 1;
     }
 
     return 1;
+ err:
+    ssl3_send_alert(s, SSL3_AL_FATAL, al);
+    return 0;
 }
 
 #define has_bits(i,m)   (((i)&(m)) == (m))
 }
 
 #define has_bits(i,m)   (((i)&(m)) == (m))
index 595d7c1..32bcad4 100644 (file)
@@ -442,6 +442,23 @@ int tls_construct_finished(SSL *s, WPACKET *pkt)
     const char *sender;
     size_t slen;
 
     const char *sender;
     size_t slen;
 
+    /* This is a real handshake so make sure we clean it up at the end */
+    if (!s->server)
+        s->statem.cleanuphand = 1;
+
+    /*
+     * We only change the keys if we didn't already do this when we sent the
+     * client certificate
+     */
+    if (SSL_IS_TLS13(s)
+            && !s->server
+            && s->s3->tmp.cert_req == 0
+            && (!s->method->ssl3_enc->change_cipher_state(s,
+                    SSL3_CC_HANDSHAKE | SSL3_CHANGE_CIPHER_CLIENT_WRITE))) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_FINISHED, SSL_R_CANNOT_CHANGE_CIPHER);
+        goto err;
+    }
+
     if (s->server) {
         sender = s->method->ssl3_enc->server_finished_label;
         slen = s->method->ssl3_enc->server_finished_label_len;
     if (s->server) {
         sender = s->method->ssl3_enc->server_finished_label;
         slen = s->method->ssl3_enc->server_finished_label_len;
@@ -656,7 +673,8 @@ MSG_PROCESS_RETURN tls_process_finished(SSL *s, PACKET *pkt)
 
 
     /* This is a real handshake so make sure we clean it up at the end */
 
 
     /* This is a real handshake so make sure we clean it up at the end */
-    s->statem.cleanuphand = 1;
+    if (s->server)
+        s->statem.cleanuphand = 1;
 
     /* If this occurs, we have missed a message */
     if (!SSL_IS_TLS13(s) && !s->s3->change_cipher_spec) {
 
     /* If this occurs, we have missed a message */
     if (!SSL_IS_TLS13(s) && !s->s3->change_cipher_spec) {
index eb80b71..c52ce2b 100644 (file)
@@ -53,6 +53,9 @@
 #define EXT_TLS1_3_CERTIFICATE              0x0800
 #define EXT_TLS1_3_NEW_SESSION_TICKET       0x1000
 
 #define EXT_TLS1_3_CERTIFICATE              0x0800
 #define EXT_TLS1_3_NEW_SESSION_TICKET       0x1000
 
+/* Dummy message type */
+#define SSL3_MT_DUMMY   -1
+
 /* Message processing return codes */
 typedef enum {
     /* Something bad happened */
 /* Message processing return codes */
 typedef enum {
     /* Something bad happened */
index 9d15252..7414c19 100644 (file)
@@ -413,10 +413,6 @@ static WRITE_TRAN ossl_statem_server13_write_transition(SSL *s)
         return WRITE_TRAN_ERROR;
 
     case TLS_ST_OK:
         return WRITE_TRAN_ERROR;
 
     case TLS_ST_OK:
-        if (s->early_data_state == SSL_EARLY_DATA_FINISHED_READING) {
-            st->hand_state = TLS_ST_SW_FINISHED;
-            return WRITE_TRAN_FINISHED;
-        }
         if (s->key_update != SSL_KEY_UPDATE_NONE) {
             st->hand_state = TLS_ST_SW_KEY_UPDATE;
             return WRITE_TRAN_CONTINUE;
         if (s->key_update != SSL_KEY_UPDATE_NONE) {
             st->hand_state = TLS_ST_SW_KEY_UPDATE;
             return WRITE_TRAN_CONTINUE;
@@ -461,11 +457,8 @@ static WRITE_TRAN ossl_statem_server13_write_transition(SSL *s)
         return WRITE_TRAN_CONTINUE;
 
     case TLS_ST_SW_FINISHED:
         return WRITE_TRAN_CONTINUE;
 
     case TLS_ST_SW_FINISHED:
-        if (s->early_data_state == SSL_EARLY_DATA_ACCEPTING) {
-            st->hand_state = TLS_ST_EARLY_DATA;
-            return WRITE_TRAN_CONTINUE;
-        }
-        return WRITE_TRAN_FINISHED;
+        st->hand_state = TLS_ST_EARLY_DATA;
+        return WRITE_TRAN_CONTINUE;
 
     case TLS_ST_EARLY_DATA:
         return WRITE_TRAN_FINISHED;
 
     case TLS_ST_EARLY_DATA:
         return WRITE_TRAN_FINISHED;
@@ -708,6 +701,10 @@ WORK_STATE ossl_statem_server_pre_work(SSL *s, WORK_STATE wst)
         return WORK_FINISHED_CONTINUE;
 
     case TLS_ST_EARLY_DATA:
         return WORK_FINISHED_CONTINUE;
 
     case TLS_ST_EARLY_DATA:
+        if (s->early_data_state != SSL_EARLY_DATA_ACCEPTING)
+            return WORK_FINISHED_CONTINUE;
+        /* Fall through */
+
     case TLS_ST_OK:
         return tls_finish_handshake(s, wst, 1);
     }
     case TLS_ST_OK:
         return tls_finish_handshake(s, wst, 1);
     }
@@ -952,6 +949,11 @@ int ossl_statem_server_construct_message(SSL *s, WPACKET *pkt,
         *mt = SSL3_MT_FINISHED;
         break;
 
         *mt = SSL3_MT_FINISHED;
         break;
 
+    case TLS_ST_EARLY_DATA:
+        *confunc = NULL;
+        *mt = SSL3_MT_DUMMY;
+        break;
+
     case TLS_ST_SW_ENCRYPTED_EXTENSIONS:
         *confunc = tls_construct_encrypted_extensions;
         *mt = SSL3_MT_ENCRYPTED_EXTENSIONS;
     case TLS_ST_SW_ENCRYPTED_EXTENSIONS:
         *confunc = tls_construct_encrypted_extensions;
         *mt = SSL3_MT_ENCRYPTED_EXTENSIONS;
index db8de1d..47d23bd 100644 (file)
@@ -430,15 +430,15 @@ int tls13_change_cipher_state(SSL *s, int which)
             labellen = sizeof(client_handshake_traffic) - 1;
             log_label = CLIENT_HANDSHAKE_LABEL;
             /*
             labellen = sizeof(client_handshake_traffic) - 1;
             log_label = CLIENT_HANDSHAKE_LABEL;
             /*
-             * The hanshake hash used for the server read handshake traffic
-             * secret is the same as the hash for the server write handshake
-             * traffic secret. However, if we processed early data then we delay
-             * changing the server read cipher state until later, and the
-             * handshake hashes have moved on. Therefore we use the value saved
-             * earlier when we did the server write change cipher state.
+             * The hanshake hash used for the server read/client write handshake
+             * traffic secret is the same as the hash for the server
+             * write/client read handshake traffic secret. However, if we
+             * processed early data then we delay changing the server
+             * read/client write cipher state until later, and the handshake
+             * hashes have moved on. Therefore we use the value saved earlier
+             * when we did the server write/client read change cipher state.
              */
              */
-            if (s->server)
-                hash = s->handshake_traffic_hash;
+            hash = s->handshake_traffic_hash;
         } else {
             insecret = s->master_secret;
             label = client_application_traffic;
         } else {
             insecret = s->master_secret;
             label = client_application_traffic;
@@ -486,7 +486,7 @@ int tls13_change_cipher_state(SSL *s, int which)
     if (label == server_application_traffic)
         memcpy(s->server_finished_hash, hashval, hashlen);
 
     if (label == server_application_traffic)
         memcpy(s->server_finished_hash, hashval, hashlen);
 
-    if (s->server && label == server_handshake_traffic)
+    if (label == server_handshake_traffic)
         memcpy(s->handshake_traffic_hash, hashval, hashlen);
 
     if (label == client_application_traffic) {
         memcpy(s->handshake_traffic_hash, hashval, hashlen);
 
     if (label == client_application_traffic) {