Merge early_data_info extension into early_data
[openssl.git] / ssl / statem / statem_clnt.c
index 4f4409300e879713db280ec887f93d733131bf92..32b7300e459426df6a0a7e67d2197cef38115b47 100644 (file)
@@ -123,11 +123,6 @@ static int ossl_statem_client13_read_transition(SSL *s, int mt)
 {
     OSSL_STATEM *st = &s->statem;
 
-    /*
-     * TODO(TLS1.3): This is still based on the TLSv1.2 state machine. Over time
-     * we will update this to look more like real TLSv1.3
-     */
-
     /*
      * Note: There is no case for TLS_ST_CW_CLNT_HELLO, because we haven't
      * yet negotiated TLSv1.3 at that point so that is handled by
@@ -196,11 +191,6 @@ static int ossl_statem_client13_read_transition(SSL *s, int mt)
         break;
 
     case TLS_ST_OK:
-        if (s->early_data_state == SSL_EARLY_DATA_FINISHED_WRITING
-                && mt == SSL3_MT_SERVER_HELLO) {
-            st->hand_state = TLS_ST_CR_SRVR_HELLO;
-            return 1;
-        }
         if (mt == SSL3_MT_NEWSESSION_TICKET) {
             st->hand_state = TLS_ST_CR_SESSION_TICKET;
             return 1;
@@ -263,6 +253,22 @@ int ossl_statem_client_read_transition(SSL *s, int mt)
         }
         break;
 
+    case TLS_ST_EARLY_DATA:
+        /*
+         * We've not actually selected TLSv1.3 yet, but we have sent early
+         * data. The only thing allowed now is a ServerHello or a
+         * HelloRetryRequest.
+         */
+        if (mt == SSL3_MT_SERVER_HELLO) {
+            st->hand_state = TLS_ST_CR_SRVR_HELLO;
+            return 1;
+        }
+        if (mt == SSL3_MT_HELLO_RETRY_REQUEST) {
+            st->hand_state = TLS_ST_CR_HELLO_RETRY_REQUEST;
+            return 1;
+        }
+        break;
+
     case TLS_ST_CR_SRVR_HELLO:
         if (s->hit) {
             if (s->ext.ticket_expected) {
@@ -387,21 +393,7 @@ int ossl_statem_client_read_transition(SSL *s, int mt)
         break;
 
     case TLS_ST_OK:
-        if (s->early_data_state == SSL_EARLY_DATA_FINISHED_WRITING) {
-            /*
-             * We've not actually selected TLSv1.3 yet, but we have sent early
-             * data. The only thing allowed now is a ServerHello or a
-             * HelloRetryRequest.
-             */
-            if (mt == SSL3_MT_SERVER_HELLO) {
-                st->hand_state = TLS_ST_CR_SRVR_HELLO;
-                return 1;
-            }
-            if (mt == SSL3_MT_HELLO_RETRY_REQUEST) {
-                st->hand_state = TLS_ST_CR_HELLO_RETRY_REQUEST;
-                return 1;
-            }
-        } else if (mt == SSL3_MT_HELLO_REQUEST) {
+        if (mt == SSL3_MT_HELLO_REQUEST) {
             st->hand_state = TLS_ST_CR_HELLO_REQ;
             return 1;
         }
@@ -443,6 +435,22 @@ static WRITE_TRAN ossl_statem_client13_write_transition(SSL *s)
         return WRITE_TRAN_CONTINUE;
 
     case TLS_ST_CR_FINISHED:
+        if (s->early_data_state == SSL_EARLY_DATA_WRITE_RETRY
+                || s->early_data_state == SSL_EARLY_DATA_FINISHED_WRITING)
+            st->hand_state = TLS_ST_PENDING_EARLY_DATA_END;
+        else
+            st->hand_state = (s->s3->tmp.cert_req != 0) ? TLS_ST_CW_CERT
+                                                        : TLS_ST_CW_FINISHED;
+        return WRITE_TRAN_CONTINUE;
+
+    case TLS_ST_PENDING_EARLY_DATA_END:
+        if (s->ext.early_data == SSL_EARLY_DATA_ACCEPTED) {
+            st->hand_state = TLS_ST_CW_END_OF_EARLY_DATA;
+            return WRITE_TRAN_CONTINUE;
+        }
+        /* Fall through */
+
+    case TLS_ST_CW_END_OF_EARLY_DATA:
         st->hand_state = (s->s3->tmp.cert_req != 0) ? TLS_ST_CW_CERT
                                                     : TLS_ST_CW_FINISHED;
         return WRITE_TRAN_CONTINUE;
@@ -468,7 +476,6 @@ static WRITE_TRAN ossl_statem_client13_write_transition(SSL *s)
     case TLS_ST_CR_SESSION_TICKET:
     case TLS_ST_CW_FINISHED:
         st->hand_state = TLS_ST_OK;
-        ossl_statem_set_in_init(s, 0);
         return WRITE_TRAN_CONTINUE;
 
     case TLS_ST_OK:
@@ -504,13 +511,6 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL *s)
         return WRITE_TRAN_ERROR;
 
     case TLS_ST_OK:
-        if (s->early_data_state == SSL_EARLY_DATA_FINISHED_WRITING) {
-            /*
-             * We are assuming this is a TLSv1.3 connection, although we haven't
-             * actually selected a version yet.
-             */
-            return WRITE_TRAN_FINISHED;
-        }
         if (!s->renegotiate) {
             /*
              * We haven't requested a renegotiation ourselves so we must have
@@ -529,8 +529,7 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL *s)
              * We are assuming this is a TLSv1.3 connection, although we haven't
              * actually selected a version yet.
              */
-            st->hand_state = TLS_ST_OK;
-            ossl_statem_set_in_init(s, 0);
+            st->hand_state = TLS_ST_EARLY_DATA;
             return WRITE_TRAN_CONTINUE;
         }
         /*
@@ -539,6 +538,9 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL *s)
          */
         return WRITE_TRAN_FINISHED;
 
+    case TLS_ST_EARLY_DATA:
+        return WRITE_TRAN_FINISHED;
+
     case DTLS_ST_CR_HELLO_VERIFY_REQUEST:
         st->hand_state = TLS_ST_CW_CLNT_HELLO;
         return WRITE_TRAN_CONTINUE;
@@ -581,7 +583,8 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL *s)
 
     case TLS_ST_CW_CHANGE:
 #if defined(OPENSSL_NO_NEXTPROTONEG)
-        st->hand_state = TLS_ST_CW_FINISHED;
+        st->
+        hand_state = TLS_ST_CW_FINISHED;
 #else
         if (!SSL_IS_DTLS(s) && s->s3->npn_seen)
             st->hand_state = TLS_ST_CW_NEXT_PROTO;
@@ -599,7 +602,6 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL *s)
     case TLS_ST_CW_FINISHED:
         if (s->hit) {
             st->hand_state = TLS_ST_OK;
-            ossl_statem_set_in_init(s, 0);
             return WRITE_TRAN_CONTINUE;
         } else {
             return WRITE_TRAN_FINISHED;
@@ -611,7 +613,6 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL *s)
             return WRITE_TRAN_CONTINUE;
         } else {
             st->hand_state = TLS_ST_OK;
-            ossl_statem_set_in_init(s, 0);
             return WRITE_TRAN_CONTINUE;
         }
 
@@ -629,7 +630,6 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL *s)
             return WRITE_TRAN_CONTINUE;
         }
         st->hand_state = TLS_ST_OK;
-        ossl_statem_set_in_init(s, 0);
         return WRITE_TRAN_CONTINUE;
     }
 }
@@ -674,6 +674,18 @@ WORK_STATE ossl_statem_client_pre_work(SSL *s, WORK_STATE wst)
         }
         break;
 
+    case TLS_ST_PENDING_EARLY_DATA_END:
+        /*
+         * If we've been called by SSL_do_handshake()/SSL_write(), or we did not
+         * attempt to write early data before calling SSL_read() then we press
+         * on with the handshake. Otherwise we pause here.
+         */
+        if (s->early_data_state == SSL_EARLY_DATA_FINISHED_WRITING
+                || s->early_data_state == SSL_EARLY_DATA_NONE)
+            return WORK_FINISHED_CONTINUE;
+        /* Fall through */
+
+    case TLS_ST_EARLY_DATA:
     case TLS_ST_OK:
         return tls_finish_handshake(s, wst, 1);
     }
@@ -718,6 +730,15 @@ WORK_STATE ossl_statem_client_post_work(SSL *s, WORK_STATE wst)
         }
         break;
 
+    case TLS_ST_CW_END_OF_EARLY_DATA:
+        /*
+         * We set the enc_write_ctx back to NULL because we may end up writing
+         * in cleartext again if we get a HelloRetryRequest from the server.
+         */
+        EVP_CIPHER_CTX_free(s->enc_write_ctx);
+        s->enc_write_ctx = NULL;
+        break;
+
     case TLS_ST_CW_KEY_EXCH:
         if (tls_client_key_exchange_post_work(s) == 0)
             return WORK_ERROR;
@@ -819,6 +840,16 @@ int ossl_statem_client_construct_message(SSL *s, WPACKET *pkt,
         *mt = SSL3_MT_CLIENT_HELLO;
         break;
 
+    case TLS_ST_CW_END_OF_EARLY_DATA:
+        *confunc = tls_construct_end_of_early_data;
+        *mt = SSL3_MT_END_OF_EARLY_DATA;
+        break;
+
+    case TLS_ST_PENDING_EARLY_DATA_END:
+        *confunc = NULL;
+        *mt = SSL3_MT_DUMMY;
+        break;
+
     case TLS_ST_CW_CERT:
         *confunc = tls_construct_client_certificate;
         *mt = SSL3_MT_CERTIFICATE;
@@ -1243,6 +1274,16 @@ MSG_PROCESS_RETURN tls_process_server_hello(SSL *s, PACKET *pkt)
         goto f_err;
     }
 
+    /*
+     * In TLSv1.3 a ServerHello message signals a key change so the end of the
+     * message must be on a record boundary.
+     */
+    if (SSL_IS_TLS13(s) && RECORD_LAYER_processed_read_pending(&s->rlayer)) {
+        al = SSL_AD_UNEXPECTED_MESSAGE;
+        SSLerr(SSL_F_TLS_PROCESS_SERVER_HELLO, SSL_R_NOT_ON_RECORD_BOUNDARY);
+        goto f_err;
+    }
+
     /* load the server hello data */
     /* load the server random */
     if (!PACKET_copy_bytes(pkt, s->s3->server_random, SSL3_RANDOM_SIZE)) {
@@ -1519,8 +1560,6 @@ MSG_PROCESS_RETURN tls_process_server_hello(SSL *s, PACKET *pkt)
      */
     if (SSL_IS_TLS13(s)
             && (!s->method->ssl3_enc->setup_key_block(s)
-                || !s->method->ssl3_enc->change_cipher_state(s,
-                    SSL3_CC_HANDSHAKE | SSL3_CHANGE_CIPHER_CLIENT_WRITE)
                 || !s->method->ssl3_enc->change_cipher_state(s,
                     SSL3_CC_HANDSHAKE | SSL3_CHANGE_CIPHER_CLIENT_READ))) {
         al = SSL_AD_INTERNAL_ERROR;
@@ -3278,11 +3317,22 @@ int tls_construct_client_certificate(SSL *s, WPACKET *pkt)
                                                           : s->cert->key,
                                 &al)) {
         SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_CERTIFICATE, ERR_R_INTERNAL_ERROR);
-        ssl3_send_alert(s, SSL3_AL_FATAL, al);
-        return 0;
+        goto err;
+    }
+
+    if (SSL_IS_TLS13(s)
+            && SSL_IS_FIRST_HANDSHAKE(s)
+            && (!s->method->ssl3_enc->change_cipher_state(s,
+                    SSL3_CC_HANDSHAKE | SSL3_CHANGE_CIPHER_CLIENT_WRITE))) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_CERTIFICATE,
+               SSL_R_CANNOT_CHANGE_CIPHER);
+        goto err;
     }
 
     return 1;
+ err:
+    ssl3_send_alert(s, SSL3_AL_FATAL, al);
+    return 0;
 }
 
 #define has_bits(i,m)   (((i)&(m)) == (m))
@@ -3530,3 +3580,16 @@ int ssl_cipher_list_to_bytes(SSL *s, STACK_OF(SSL_CIPHER) *sk, WPACKET *pkt)
 
     return 1;
 }
+
+int tls_construct_end_of_early_data(SSL *s, WPACKET *pkt)
+{
+    if (s->early_data_state != SSL_EARLY_DATA_WRITE_RETRY
+            && s->early_data_state != SSL_EARLY_DATA_FINISHED_WRITING) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_END_OF_EARLY_DATA,
+               ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        return 0;
+    }
+
+    s->early_data_state = SSL_EARLY_DATA_FINISHED_WRITING;
+    return 1;
+}