Delay flush until after CCS with early_data
[openssl.git] / ssl / statem / statem_clnt.c
index af9e1dcd7d261494e7b6e8f557afd1251dbea599..6313b31a0809b401678a41b62f8baee5ed9cb218 100644 (file)
@@ -387,7 +387,7 @@ static WRITE_TRAN ossl_statem_client13_write_transition(SSL *s)
                 || s->early_data_state == SSL_EARLY_DATA_FINISHED_WRITING)
             st->hand_state = TLS_ST_PENDING_EARLY_DATA_END;
         else if ((s->options & SSL_OP_ENABLE_MIDDLEBOX_COMPAT) != 0
-                    && !s->hello_retry_request)
+                 && s->hello_retry_request == SSL_HRR_NONE)
             st->hand_state = TLS_ST_CW_CHANGE;
         else
             st->hand_state = (s->s3->tmp.cert_req != 0) ? TLS_ST_CW_CERT
@@ -679,27 +679,30 @@ WORK_STATE ossl_statem_client_post_work(SSL *s, WORK_STATE wst)
         break;
 
     case TLS_ST_CW_CLNT_HELLO:
-        if (wst == WORK_MORE_A && statem_flush(s) != 1)
-            return WORK_MORE_A;
-
-        if (SSL_IS_DTLS(s)) {
-            /* Treat the next message as the first packet */
-            s->first_packet = 1;
-        }
-
         if (s->early_data_state == SSL_EARLY_DATA_CONNECTING
-                && s->max_early_data > 0
-                && (s->options & SSL_OP_ENABLE_MIDDLEBOX_COMPAT) == 0) {
+                && s->max_early_data > 0) {
             /*
              * We haven't selected TLSv1.3 yet so we don't call the change
              * cipher state function associated with the SSL_METHOD. Instead
              * we call tls13_change_cipher_state() directly.
              */
-            if (!tls13_change_cipher_state(s,
-                        SSL3_CC_EARLY | SSL3_CHANGE_CIPHER_CLIENT_WRITE)) {
-                /* SSLfatal() already called */
-                return WORK_ERROR;
+            if ((s->options & SSL_OP_ENABLE_MIDDLEBOX_COMPAT) == 0) {
+                if (!statem_flush(s))
+                    return WORK_MORE_A;
+                if (!tls13_change_cipher_state(s,
+                            SSL3_CC_EARLY | SSL3_CHANGE_CIPHER_CLIENT_WRITE)) {
+                    /* SSLfatal() already called */
+                    return WORK_ERROR;
+                }
             }
+            /* else we're in compat mode so we delay flushing until after CCS */
+        } else if (!statem_flush(s)) {
+            return WORK_MORE_A;
+        }
+
+        if (SSL_IS_DTLS(s)) {
+            /* Treat the next message as the first packet */
+            s->first_packet = 1;
         }
         break;
 
@@ -724,6 +727,8 @@ WORK_STATE ossl_statem_client_post_work(SSL *s, WORK_STATE wst)
             break;
         if (s->early_data_state == SSL_EARLY_DATA_CONNECTING
                     && s->max_early_data > 0) {
+            if (statem_flush(s) != 1)
+                return WORK_MORE_A;
             /*
              * We haven't selected TLSv1.3 yet so we don't call the change
              * cipher state function associated with the SSL_METHOD. Instead
@@ -1055,7 +1060,8 @@ int tls_construct_client_hello(SSL *s, WPACKET *pkt)
     if (sess == NULL
             || !ssl_version_supported(s, sess->ssl_version)
             || !SSL_SESSION_is_resumable(sess)) {
-        if (!s->hello_retry_request && !ssl_get_new_session(s, 0)) {
+        if (s->hello_retry_request == SSL_HRR_NONE
+                && !ssl_get_new_session(s, 0)) {
             /* SSLfatal() already called */
             return 0;
         }
@@ -1078,7 +1084,7 @@ int tls_construct_client_hello(SSL *s, WPACKET *pkt)
             }
         }
     } else {
-        i = s->hello_retry_request == 0;
+        i = (s->hello_retry_request == SSL_HRR_NONE);
     }
 
     if (i && ssl_fill_hello_random(s, 0, p, sizeof(s->s3->client_random),
@@ -1136,7 +1142,7 @@ int tls_construct_client_hello(SSL *s, WPACKET *pkt)
             sess_id_len = sizeof(s->tmp_session_id);
             s->tmp_session_id_len = sess_id_len;
             session_id = s->tmp_session_id;
-            if (!s->hello_retry_request
+            if (s->hello_retry_request == SSL_HRR_NONE
                     && ssl_randbytes(s, s->tmp_session_id,
                                      sess_id_len) <= 0) {
                 SSLfatal(s, SSL_AD_INTERNAL_ERROR,
@@ -1360,7 +1366,8 @@ MSG_PROCESS_RETURN tls_process_server_hello(SSL *s, PACKET *pkt)
             && sversion == TLS1_2_VERSION
             && PACKET_remaining(pkt) >= SSL3_RANDOM_SIZE
             && memcmp(hrrrandom, PACKET_data(pkt), SSL3_RANDOM_SIZE) == 0) {
-        s->hello_retry_request = hrr = 1;
+        s->hello_retry_request = SSL_HRR_PENDING;
+        hrr = 1;
         if (!PACKET_forward(pkt, SSL3_RANDOM_SIZE)) {
             SSLfatal(s, SSL_AD_DECODE_ERROR, SSL_F_TLS_PROCESS_SERVER_HELLO,
                      SSL_R_LENGTH_MISMATCH);