Introduce a new early_data state in the state machine
[openssl.git] / ssl / statem / statem_clnt.c
index bc35a3ea25bb96da176b70c81b4aad3b6c549586..23a4d7663bfe11adffd8c5013e8a1279b70d15e2 100644 (file)
@@ -123,11 +123,6 @@ static int ossl_statem_client13_read_transition(SSL *s, int mt)
 {
     OSSL_STATEM *st = &s->statem;
 
-    /*
-     * TODO(TLS1.3): This is still based on the TLSv1.2 state machine. Over time
-     * we will update this to look more like real TLSv1.3
-     */
-
     /*
      * Note: There is no case for TLS_ST_CW_CLNT_HELLO, because we haven't
      * yet negotiated TLSv1.3 at that point so that is handled by
@@ -258,6 +253,22 @@ int ossl_statem_client_read_transition(SSL *s, int mt)
         }
         break;
 
+    case TLS_ST_CW_EARLY_DATA:
+        /*
+         * We've not actually selected TLSv1.3 yet, but we have sent early
+         * data. The only thing allowed now is a ServerHello or a
+         * HelloRetryRequest.
+         */
+        if (mt == SSL3_MT_SERVER_HELLO) {
+            st->hand_state = TLS_ST_CR_SRVR_HELLO;
+            return 1;
+        }
+        if (mt == SSL3_MT_HELLO_RETRY_REQUEST) {
+            st->hand_state = TLS_ST_CR_HELLO_RETRY_REQUEST;
+            return 1;
+        }
+        break;
+
     case TLS_ST_CR_SRVR_HELLO:
         if (s->hit) {
             if (s->ext.ticket_expected) {
@@ -449,7 +460,6 @@ static WRITE_TRAN ossl_statem_client13_write_transition(SSL *s)
     case TLS_ST_CR_SESSION_TICKET:
     case TLS_ST_CW_FINISHED:
         st->hand_state = TLS_ST_OK;
-        ossl_statem_set_in_init(s, 0);
         return WRITE_TRAN_CONTINUE;
 
     case TLS_ST_OK:
@@ -498,12 +508,23 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL *s)
         return WRITE_TRAN_CONTINUE;
 
     case TLS_ST_CW_CLNT_HELLO:
+        if (s->early_data_state == SSL_EARLY_DATA_CONNECTING) {
+            /*
+             * We are assuming this is a TLSv1.3 connection, although we haven't
+             * actually selected a version yet.
+             */
+            st->hand_state = TLS_ST_CW_EARLY_DATA;
+            return WRITE_TRAN_CONTINUE;
+        }
         /*
          * No transition at the end of writing because we don't know what
          * we will be sent
          */
         return WRITE_TRAN_FINISHED;
 
+    case TLS_ST_CW_EARLY_DATA:
+        return WRITE_TRAN_FINISHED;
+
     case DTLS_ST_CR_HELLO_VERIFY_REQUEST:
         st->hand_state = TLS_ST_CW_CLNT_HELLO;
         return WRITE_TRAN_CONTINUE;
@@ -546,7 +567,8 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL *s)
 
     case TLS_ST_CW_CHANGE:
 #if defined(OPENSSL_NO_NEXTPROTONEG)
-        st->hand_state = TLS_ST_CW_FINISHED;
+        st->
+        hand_state = TLS_ST_CW_FINISHED;
 #else
         if (!SSL_IS_DTLS(s) && s->s3->npn_seen)
             st->hand_state = TLS_ST_CW_NEXT_PROTO;
@@ -564,7 +586,6 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL *s)
     case TLS_ST_CW_FINISHED:
         if (s->hit) {
             st->hand_state = TLS_ST_OK;
-            ossl_statem_set_in_init(s, 0);
             return WRITE_TRAN_CONTINUE;
         } else {
             return WRITE_TRAN_FINISHED;
@@ -576,7 +597,6 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL *s)
             return WRITE_TRAN_CONTINUE;
         } else {
             st->hand_state = TLS_ST_OK;
-            ossl_statem_set_in_init(s, 0);
             return WRITE_TRAN_CONTINUE;
         }
 
@@ -594,7 +614,6 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL *s)
             return WRITE_TRAN_CONTINUE;
         }
         st->hand_state = TLS_ST_OK;
-        ossl_statem_set_in_init(s, 0);
         return WRITE_TRAN_CONTINUE;
     }
 }
@@ -639,6 +658,7 @@ WORK_STATE ossl_statem_client_pre_work(SSL *s, WORK_STATE wst)
         }
         break;
 
+    case TLS_ST_CW_EARLY_DATA:
     case TLS_ST_OK:
         return tls_finish_handshake(s, wst, 1);
     }
@@ -649,8 +669,6 @@ WORK_STATE ossl_statem_client_pre_work(SSL *s, WORK_STATE wst)
 /*
  * Perform any work that needs to be done after sending a message from the
  * client to the server.
-    case TLS_ST_SR_CERT_VRFY:
-        return SSL3_RT_MAX_PLAIN_LENGTH;
  */
 WORK_STATE ossl_statem_client_post_work(SSL *s, WORK_STATE wst)
 {
@@ -671,6 +689,18 @@ WORK_STATE ossl_statem_client_post_work(SSL *s, WORK_STATE wst)
             /* Treat the next message as the first packet */
             s->first_packet = 1;
         }
+
+        if (s->early_data_state == SSL_EARLY_DATA_CONNECTING
+                && s->max_early_data > 0) {
+            /*
+             * We haven't selected TLSv1.3 yet so we don't call the change
+             * cipher state function associated with the SSL_METHOD. Instead
+             * we call tls13_change_cipher_state() directly.
+             */
+            if (!tls13_change_cipher_state(s,
+                        SSL3_CC_EARLY | SSL3_CHANGE_CIPHER_CLIENT_WRITE))
+                return WORK_ERROR;
+        }
         break;
 
     case TLS_ST_CW_KEY_EXCH:
@@ -1001,9 +1031,6 @@ int tls_construct_client_hello(SSL *s, WPACKET *pkt)
     }
     /* else use the pre-loaded session */
 
-    /* This is a real handshake so make sure we clean it up at the end */
-    s->statem.cleanuphand = 1;
-
     p = s->s3->client_random;
 
     /*
@@ -1107,7 +1134,9 @@ int tls_construct_client_hello(SSL *s, WPACKET *pkt)
         return 0;
     }
 #ifndef OPENSSL_NO_COMP
-    if (ssl_allow_compression(s) && s->ctx->comp_methods) {
+    if (ssl_allow_compression(s)
+            && s->ctx->comp_methods
+            && (SSL_IS_DTLS(s) || s->s3->tmp.max_ver < TLS1_3_VERSION)) {
         int compnum = sk_SSL_COMP_num(s->ctx->comp_methods);
         for (i = 0; i < compnum; i++) {
             comp = sk_SSL_COMP_value(s->ctx->comp_methods, i);
@@ -2196,39 +2225,46 @@ MSG_PROCESS_RETURN tls_process_key_exchange(SSL *s, PACKET *pkt)
 MSG_PROCESS_RETURN tls_process_certificate_request(SSL *s, PACKET *pkt)
 {
     int ret = MSG_PROCESS_ERROR;
-    unsigned int list_len, ctype_num, i, name_len;
+    unsigned int i, name_len;
     X509_NAME *xn = NULL;
-    const unsigned char *data;
     const unsigned char *namestart, *namebytes;
     STACK_OF(X509_NAME) *ca_sk = NULL;
+    PACKET cadns;
 
     if ((ca_sk = sk_X509_NAME_new(ca_dn_cmp)) == NULL) {
         SSLerr(SSL_F_TLS_PROCESS_CERTIFICATE_REQUEST, ERR_R_MALLOC_FAILURE);
         goto err;
     }
 
-    /* get the certificate types */
-    if (!PACKET_get_1(pkt, &ctype_num)
-        || !PACKET_get_bytes(pkt, &data, ctype_num)) {
-        ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
-        SSLerr(SSL_F_TLS_PROCESS_CERTIFICATE_REQUEST, SSL_R_LENGTH_MISMATCH);
-        goto err;
-    }
-    OPENSSL_free(s->cert->ctypes);
-    s->cert->ctypes = NULL;
-    if (ctype_num > SSL3_CT_NUMBER) {
-        /* If we exceed static buffer copy all to cert structure */
-        s->cert->ctypes = OPENSSL_malloc(ctype_num);
-        if (s->cert->ctypes == NULL) {
-            SSLerr(SSL_F_TLS_PROCESS_CERTIFICATE_REQUEST, ERR_R_MALLOC_FAILURE);
+    if (SSL_IS_TLS13(s)) {
+        PACKET reqctx;
+
+        /* Free and zero certificate types: it is not present in TLS 1.3 */
+        OPENSSL_free(s->s3->tmp.ctype);
+        s->s3->tmp.ctype = NULL;
+        s->s3->tmp.ctype_len = 0;
+        /* TODO(TLS1.3) need to process request context, for now ignore */
+        if (!PACKET_get_length_prefixed_1(pkt, &reqctx)) {
+            SSLerr(SSL_F_TLS_PROCESS_CERTIFICATE_REQUEST,
+                   SSL_R_LENGTH_MISMATCH);
+            goto err;
+        }
+    } else {
+        PACKET ctypes;
+
+        /* get the certificate types */
+        if (!PACKET_get_length_prefixed_1(pkt, &ctypes)) {
+            ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
+            SSLerr(SSL_F_TLS_PROCESS_CERTIFICATE_REQUEST,
+                   SSL_R_LENGTH_MISMATCH);
+            goto err;
+        }
+
+        if (!PACKET_memdup(&ctypes, &s->s3->tmp.ctype, &s->s3->tmp.ctype_len)) {
+            SSLerr(SSL_F_TLS_PROCESS_CERTIFICATE_REQUEST, ERR_R_INTERNAL_ERROR);
             goto err;
         }
-        memcpy(s->cert->ctypes, data, ctype_num);
-        s->cert->ctype_num = ctype_num;
-        ctype_num = SSL3_CT_NUMBER;
     }
-    for (i = 0; i < ctype_num; i++)
-        s->s3->tmp.ctype[i] = data[i];
 
     if (SSL_USE_SIGALGS(s)) {
         PACKET sigalgs;
@@ -2257,16 +2293,15 @@ MSG_PROCESS_RETURN tls_process_certificate_request(SSL *s, PACKET *pkt)
     }
 
     /* get the CA RDNs */
-    if (!PACKET_get_net_2(pkt, &list_len)
-        || PACKET_remaining(pkt) != list_len) {
+    if (!PACKET_get_length_prefixed_2(pkt, &cadns)) {
         ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
         SSLerr(SSL_F_TLS_PROCESS_CERTIFICATE_REQUEST, SSL_R_LENGTH_MISMATCH);
         goto err;
     }
 
-    while (PACKET_remaining(pkt)) {
-        if (!PACKET_get_net_2(pkt, &name_len)
-            || !PACKET_get_bytes(pkt, &namebytes, name_len)) {
+    while (PACKET_remaining(&cadns)) {
+        if (!PACKET_get_net_2(&cadns, &name_len)
+            || !PACKET_get_bytes(&cadns, &namebytes, name_len)) {
             ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
             SSLerr(SSL_F_TLS_PROCESS_CERTIFICATE_REQUEST,
                    SSL_R_LENGTH_MISMATCH);
@@ -2294,10 +2329,26 @@ MSG_PROCESS_RETURN tls_process_certificate_request(SSL *s, PACKET *pkt)
         }
         xn = NULL;
     }
+    /* TODO(TLS1.3) need to parse and process extensions, for now ignore */
+    if (SSL_IS_TLS13(s)) {
+        PACKET reqexts;
+
+        if (!PACKET_get_length_prefixed_2(pkt, &reqexts)) {
+            ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
+            SSLerr(SSL_F_TLS_PROCESS_CERTIFICATE_REQUEST,
+                   SSL_R_EXT_LENGTH_MISMATCH);
+            goto err;
+        }
+    }
+
+    if (PACKET_remaining(pkt) != 0) {
+        ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
+        SSLerr(SSL_F_TLS_PROCESS_CERTIFICATE_REQUEST, SSL_R_LENGTH_MISMATCH);
+        goto err;
+    }
 
     /* we should setup a certificate to return.... */
     s->s3->tmp.cert_req = 1;
-    s->s3->tmp.ctype_num = ctype_num;
     sk_X509_NAME_pop_free(s->s3->tmp.ca_names, X509_NAME_free);
     s->s3->tmp.ca_names = ca_sk;
     ca_sk = NULL;
@@ -2732,12 +2783,6 @@ static int tls_construct_cke_rsa(SSL *s, WPACKET *pkt, int *al)
     }
     EVP_PKEY_CTX_free(pctx);
     pctx = NULL;
-# ifdef PKCS1_CHECK
-    if (s->options & SSL_OP_PKCS1_CHECK_1)
-        (*p)[1]++;
-    if (s->options & SSL_OP_PKCS1_CHECK_2)
-        tmp_buf[0] = 0x70;
-# endif
 
     /* Fix buf for TLS and beyond */
     if (s->version > SSL3_VERSION && !WPACKET_close(pkt)) {