Harmonise setting the header and closing construction
authorMatt Caswell <matt@openssl.org>
Fri, 30 Sep 2016 09:38:32 +0000 (10:38 +0100)
committerMatt Caswell <matt@openssl.org>
Mon, 3 Oct 2016 15:25:48 +0000 (16:25 +0100)
Ensure all message types work the same way including CCS so that the state
machine doesn't need to know about special cases. Put all the special logic
into ssl_set_handshake_header() and ssl_close_construct_packet().

Reviewed-by: Rich Salz <rsalz@openssl.org>
ssl/s3_lib.c
ssl/ssl_locl.h
ssl/statem/statem_clnt.c
ssl/statem/statem_dtls.c
ssl/statem/statem_lib.c
ssl/statem/statem_srvr.c

index 630c94d..d19b97a 100644 (file)
@@ -2779,6 +2779,10 @@ const SSL_CIPHER *ssl3_get_cipher(unsigned int u)
 
 int ssl3_set_handshake_header(SSL *s, WPACKET *pkt, int htype)
 {
+    /* No header in the event of a CCS */
+    if (htype == SSL3_MT_CHANGE_CIPHER_SPEC)
+        return 1;
+
     /* Set the content type and 3 bytes for the message len */
     if (!WPACKET_put_bytes_u8(pkt, htype)
             || !WPACKET_start_sub_packet_u24(pkt))
index 06cf6e6..8a7e1a9 100644 (file)
@@ -1586,7 +1586,7 @@ typedef struct ssl3_enc_method {
     /* Set the handshake header */
     int (*set_handshake_header) (SSL *s, WPACKET *pkt, int type);
     /* Close construction of the handshake message */
-    int (*close_construct_packet) (SSL *s, WPACKET *pkt);
+    int (*close_construct_packet) (SSL *s, WPACKET *pkt, int htype);
     /* Write out handshake message */
     int (*do_write) (SSL *s);
 } SSL3_ENC_METHOD;
@@ -1596,8 +1596,8 @@ typedef struct ssl3_enc_method {
         (((unsigned char *)s->init_buf->data) + s->method->ssl3_enc->hhlen)
 # define ssl_set_handshake_header(s, pkt, htype) \
         s->method->ssl3_enc->set_handshake_header((s), (pkt), (htype))
-# define ssl_close_construct_packet(s, pkt) \
-        s->method->ssl3_enc->close_construct_packet((s), (pkt))
+# define ssl_close_construct_packet(s, pkt, htype) \
+        s->method->ssl3_enc->close_construct_packet((s), (pkt), (htype))
 # define ssl_do_write(s)  s->method->ssl3_enc->do_write(s)
 
 /* Values for enc_flags */
@@ -1901,9 +1901,9 @@ __owur int ssl3_do_change_cipher_spec(SSL *ssl);
 __owur long ssl3_default_timeout(void);
 
 __owur int ssl3_set_handshake_header(SSL *s, WPACKET *pkt, int htype);
-__owur int tls_close_construct_packet(SSL *s, WPACKET *pkt);
+__owur int tls_close_construct_packet(SSL *s, WPACKET *pkt, int htype);
 __owur int dtls1_set_handshake_header(SSL *s, WPACKET *pkt, int htype);
-__owur int dtls1_close_construct_packet(SSL *s, WPACKET *pkt);
+__owur int dtls1_close_construct_packet(SSL *s, WPACKET *pkt, int htype);
 __owur int ssl3_handshake_write(SSL *s);
 
 __owur int ssl_allow_compression(SSL *s);
index 18eaf32..52c07ea 100644 (file)
@@ -516,69 +516,69 @@ int ossl_statem_client_construct_message(SSL *s, WPACKET *pkt)
     int (*confunc) (SSL *s, WPACKET *pkt) = NULL;
     int ret = 1, mt;
 
-    if (st->hand_state == TLS_ST_CW_CHANGE) {
-        /* Special case becase it is a different content type */
-        if (SSL_IS_DTLS(s))
-            return dtls_construct_change_cipher_spec(s, pkt);
+    switch (st->hand_state) {
+    default:
+        /* Shouldn't happen */
+        return 0;
 
-        return tls_construct_change_cipher_spec(s, pkt);
-    } else {
-        switch (st->hand_state) {
-        default:
-            /* Shouldn't happen */
-            return 0;
+    case TLS_ST_CW_CHANGE:
+        if (SSL_IS_DTLS(s))
+            confunc = dtls_construct_change_cipher_spec;
+        else
+            confunc = tls_construct_change_cipher_spec;
+        mt = SSL3_MT_CHANGE_CIPHER_SPEC;
+        break;
 
-        case TLS_ST_CW_CLNT_HELLO:
-            confunc = tls_construct_client_hello;
-            mt = SSL3_MT_CLIENT_HELLO;
-            break;
+    case TLS_ST_CW_CLNT_HELLO:
+        confunc = tls_construct_client_hello;
+        mt = SSL3_MT_CLIENT_HELLO;
+        break;
 
-        case TLS_ST_CW_CERT:
-            confunc = tls_construct_client_certificate;
-            mt = SSL3_MT_CERTIFICATE;
-            break;
+    case TLS_ST_CW_CERT:
+        confunc = tls_construct_client_certificate;
+        mt = SSL3_MT_CERTIFICATE;
+        break;
 
-        case TLS_ST_CW_KEY_EXCH:
-            confunc = tls_construct_client_key_exchange;
-            mt = SSL3_MT_CLIENT_KEY_EXCHANGE;
-            break;
+    case TLS_ST_CW_KEY_EXCH:
+        confunc = tls_construct_client_key_exchange;
+        mt = SSL3_MT_CLIENT_KEY_EXCHANGE;
+        break;
 
-        case TLS_ST_CW_CERT_VRFY:
-            confunc = tls_construct_client_verify;
-            mt = SSL3_MT_CERTIFICATE_VERIFY;
-            break;
+    case TLS_ST_CW_CERT_VRFY:
+        confunc = tls_construct_client_verify;
+        mt = SSL3_MT_CERTIFICATE_VERIFY;
+        break;
 
 #if !defined(OPENSSL_NO_NEXTPROTONEG)
-        case TLS_ST_CW_NEXT_PROTO:
-            confunc = tls_construct_next_proto;
-            mt = SSL3_MT_NEXT_PROTO;
-            break;
+    case TLS_ST_CW_NEXT_PROTO:
+        confunc = tls_construct_next_proto;
+        mt = SSL3_MT_NEXT_PROTO;
+        break;
 #endif
-        case TLS_ST_CW_FINISHED:
-            mt = SSL3_MT_FINISHED;
-            break;
-        }
+    case TLS_ST_CW_FINISHED:
+        mt = SSL3_MT_FINISHED;
+        break;
+    }
 
-        if (!ssl_set_handshake_header(s, pkt, mt)) {
-            SSLerr(SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE,
-                   ERR_R_INTERNAL_ERROR);
-            return 0;
-        }
+    if (!ssl_set_handshake_header(s, pkt, mt)) {
+        SSLerr(SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE,
+               ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
 
-        if (st->hand_state == TLS_ST_CW_FINISHED)
-            ret = tls_construct_finished(s, pkt,
-                                         s->method->
-                                         ssl3_enc->client_finished_label,
-                                         s->method->
-                                         ssl3_enc->client_finished_label_len);
-        else
-            ret = confunc(s, pkt);
+    if (st->hand_state == TLS_ST_CW_FINISHED)
+        ret = tls_construct_finished(s, pkt,
+                                     s->method->
+                                     ssl3_enc->client_finished_label,
+                                     s->method->
+                                     ssl3_enc->client_finished_label_len);
+    else
+        ret = confunc(s, pkt);
 
-        if (!ret || !ssl_close_construct_packet(s, pkt)) {
-            SSLerr(SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE,
-                   ERR_R_INTERNAL_ERROR);
-            return 0;
-        }
+    if (!ret || !ssl_close_construct_packet(s, pkt, mt)) {
+        SSLerr(SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE,
+               ERR_R_INTERNAL_ERROR);
+        return 0;
     }
     return 1;
 }
index cc016da..5b90c56 100644 (file)
@@ -874,41 +874,16 @@ static int dtls_get_reassembled_message(SSL *s, long *len)
  */
 int dtls_construct_change_cipher_spec(SSL *s, WPACKET *pkt)
 {
-    if (!WPACKET_put_bytes_u8(pkt, SSL3_MT_CCS)) {
-        SSLerr(SSL_F_DTLS_CONSTRUCT_CHANGE_CIPHER_SPEC, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-
-    s->d1->handshake_write_seq = s->d1->next_handshake_write_seq;
-    s->init_num = DTLS1_CCS_HEADER_LENGTH;
-
     if (s->version == DTLS1_BAD_VER) {
         s->d1->next_handshake_write_seq++;
 
         if (!WPACKET_put_bytes_u16(pkt, s->d1->handshake_write_seq)) {
             SSLerr(SSL_F_DTLS_CONSTRUCT_CHANGE_CIPHER_SPEC, ERR_R_INTERNAL_ERROR);
-            goto err;
+            ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
         }
-
-        s->init_num += 2;
-    }
-
-    s->init_off = 0;
-
-    dtls1_set_message_header_int(s, SSL3_MT_CCS, 0,
-                                 s->d1->handshake_write_seq, 0, 0);
-
-    /* buffer the message to handle re-xmits */
-    if (!dtls1_buffer_message(s, 1)) {
-        SSLerr(SSL_F_DTLS_CONSTRUCT_CHANGE_CIPHER_SPEC, ERR_R_INTERNAL_ERROR);
-        goto err    ;
     }
 
     return 1;
-
- err:
-    ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
-    return 0;
 }
 
 #ifndef OPENSSL_NO_SCTP
@@ -1206,35 +1181,48 @@ int dtls1_set_handshake_header(SSL *s, WPACKET *pkt, int htype)
 {
     unsigned char *header;
 
-    dtls1_set_message_header(s, htype, 0, 0, 0);
-
-    /*
-     * We allocate space at the start for the message header. This gets filled
-     * in later
-     */
-    if (!WPACKET_allocate_bytes(pkt, DTLS1_HM_HEADER_LENGTH, &header)
-            || !WPACKET_start_sub_packet(pkt))
-        return 0;
+    if (htype == SSL3_MT_CHANGE_CIPHER_SPEC) {
+        s->d1->handshake_write_seq = s->d1->next_handshake_write_seq;
+        dtls1_set_message_header_int(s, SSL3_MT_CCS, 0,
+                                     s->d1->handshake_write_seq, 0, 0);
+        if (!WPACKET_put_bytes_u8(pkt, SSL3_MT_CCS))
+            return 0;
+    } else {
+        dtls1_set_message_header(s, htype, 0, 0, 0);
+        /*
+         * We allocate space at the start for the message header. This gets
+         * filled in later
+         */
+        if (!WPACKET_allocate_bytes(pkt, DTLS1_HM_HEADER_LENGTH, &header)
+                || !WPACKET_start_sub_packet(pkt))
+            return 0;
+    }
 
     return 1;
 }
 
-int dtls1_close_construct_packet(SSL *s, WPACKET *pkt)
+int dtls1_close_construct_packet(SSL *s, WPACKET *pkt, int htype)
 {
     size_t msglen;
 
-    if (!WPACKET_close(pkt)
+    if ((htype != SSL3_MT_CHANGE_CIPHER_SPEC && !WPACKET_close(pkt))
             || !WPACKET_get_length(pkt, &msglen)
             || msglen > INT_MAX)
         return 0;
-    s->d1->w_msg_hdr.msg_len = msglen - DTLS1_HM_HEADER_LENGTH;
-    s->d1->w_msg_hdr.frag_len = msglen - DTLS1_HM_HEADER_LENGTH;
+
+    if (htype != SSL3_MT_CHANGE_CIPHER_SPEC) {
+        s->d1->w_msg_hdr.msg_len = msglen - DTLS1_HM_HEADER_LENGTH;
+        s->d1->w_msg_hdr.frag_len = msglen - DTLS1_HM_HEADER_LENGTH;
+    }
     s->init_num = (int)msglen;
     s->init_off = 0;
 
-    /* Buffer the message to handle re-xmits */
-    if (!dtls1_buffer_message(s, 0))
-        return 0;
+    if (htype != DTLS1_MT_HELLO_VERIFY_REQUEST) {
+        /* Buffer the message to handle re-xmits */
+        if (!dtls1_buffer_message(s, htype == SSL3_MT_CHANGE_CIPHER_SPEC
+                                     ? 1 : 0))
+            return 0;
+    }
 
     return 1;
 }
index cac18cc..fa0032b 100644 (file)
@@ -57,11 +57,11 @@ int ssl3_do_write(SSL *s, int type)
     return (0);
 }
 
-int tls_close_construct_packet(SSL *s, WPACKET *pkt)
+int tls_close_construct_packet(SSL *s, WPACKET *pkt, int htype)
 {
     size_t msglen;
 
-    if (!WPACKET_close(pkt)
+    if ((htype != SSL3_MT_CHANGE_CIPHER_SPEC && !WPACKET_close(pkt))
             || !WPACKET_get_length(pkt, &msglen)
             || msglen > INT_MAX)
         return 0;
@@ -260,9 +260,6 @@ int tls_construct_change_cipher_spec(SSL *s, WPACKET *pkt)
         return 0;
     }
 
-    s->init_num = 1;
-    s->init_off = 0;
-
     return 1;
 }
 
index 46bd5c7..78850a7 100644 (file)
@@ -625,87 +625,90 @@ int ossl_statem_server_construct_message(SSL *s, WPACKET *pkt)
     int (*confunc) (SSL *s, WPACKET *pkt) = NULL;
     int ret = 1, mt;
 
-    if (st->hand_state == TLS_ST_SW_CHANGE) {
-        /* Special case becase it is a different content type */
+    switch (st->hand_state) {
+    default:
+        /* Shouldn't happen */
+        return 0;
+
+    case TLS_ST_SW_CHANGE:
         if (SSL_IS_DTLS(s))
-            return dtls_construct_change_cipher_spec(s, pkt);
+            confunc = dtls_construct_change_cipher_spec;
+        else
+            confunc = tls_construct_change_cipher_spec;
+        mt = SSL3_MT_CHANGE_CIPHER_SPEC;
+        break;
 
-        return tls_construct_change_cipher_spec(s, pkt);
-    } else if (st->hand_state == DTLS_ST_SW_HELLO_VERIFY_REQUEST) {
-        /* Special case because we don't call ssl_close_construct_packet() */
-        return dtls_construct_hello_verify_request(s, pkt);
-    } else {
-        switch (st->hand_state) {
-        default:
-            /* Shouldn't happen */
-            return 0;
+    case DTLS_ST_SW_HELLO_VERIFY_REQUEST:
+        confunc = dtls_construct_hello_verify_request;
+        mt = DTLS1_MT_HELLO_VERIFY_REQUEST;
+        break;
 
-        case TLS_ST_SW_HELLO_REQ:
-            /* No construction function needed */
-            mt = SSL3_MT_HELLO_REQUEST;
-            break;
+    case TLS_ST_SW_HELLO_REQ:
+        /* No construction function needed */
+        mt = SSL3_MT_HELLO_REQUEST;
+        break;
 
-        case TLS_ST_SW_SRVR_HELLO:
-            confunc = tls_construct_server_hello;
-            mt = SSL3_MT_SERVER_HELLO;
-            break;
+    case TLS_ST_SW_SRVR_HELLO:
+        confunc = tls_construct_server_hello;
+        mt = SSL3_MT_SERVER_HELLO;
+        break;
 
-        case TLS_ST_SW_CERT:
-            confunc = tls_construct_server_certificate;
-            mt = SSL3_MT_CERTIFICATE;
-            break;
+    case TLS_ST_SW_CERT:
+        confunc = tls_construct_server_certificate;
+        mt = SSL3_MT_CERTIFICATE;
+        break;
 
-        case TLS_ST_SW_KEY_EXCH:
-            confunc = tls_construct_server_key_exchange;
-            mt = SSL3_MT_SERVER_KEY_EXCHANGE;
-            break;
+    case TLS_ST_SW_KEY_EXCH:
+        confunc = tls_construct_server_key_exchange;
+        mt = SSL3_MT_SERVER_KEY_EXCHANGE;
+        break;
 
-        case TLS_ST_SW_CERT_REQ:
-            confunc = tls_construct_certificate_request;
-            mt = SSL3_MT_CERTIFICATE_REQUEST;
-            break;
+    case TLS_ST_SW_CERT_REQ:
+        confunc = tls_construct_certificate_request;
+        mt = SSL3_MT_CERTIFICATE_REQUEST;
+        break;
 
-        case TLS_ST_SW_SRVR_DONE:
-            confunc = tls_construct_server_done;
-            mt = SSL3_MT_SERVER_DONE;
-            break;
+    case TLS_ST_SW_SRVR_DONE:
+        confunc = tls_construct_server_done;
+        mt = SSL3_MT_SERVER_DONE;
+        break;
 
-        case TLS_ST_SW_SESSION_TICKET:
-            confunc = tls_construct_new_session_ticket;
-            mt = SSL3_MT_NEWSESSION_TICKET;
-            break;
+    case TLS_ST_SW_SESSION_TICKET:
+        confunc = tls_construct_new_session_ticket;
+        mt = SSL3_MT_NEWSESSION_TICKET;
+        break;
 
-        case TLS_ST_SW_CERT_STATUS:
-            confunc = tls_construct_cert_status;
-            mt = SSL3_MT_CERTIFICATE_STATUS;
-            break;
+    case TLS_ST_SW_CERT_STATUS:
+        confunc = tls_construct_cert_status;
+        mt = SSL3_MT_CERTIFICATE_STATUS;
+        break;
 
-        case TLS_ST_SW_FINISHED:
-            mt = SSL3_MT_FINISHED;
-            break;
-        }
+    case TLS_ST_SW_FINISHED:
+        mt = SSL3_MT_FINISHED;
+        break;
+    }
 
-        if (!ssl_set_handshake_header(s, pkt, mt)) {
-            SSLerr(SSL_F_OSSL_STATEM_SERVER_CONSTRUCT_MESSAGE,
-                   ERR_R_INTERNAL_ERROR);
-            return 0;
-        }
+    if (!ssl_set_handshake_header(s, pkt, mt)) {
+        SSLerr(SSL_F_OSSL_STATEM_SERVER_CONSTRUCT_MESSAGE,
+               ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
 
-        if (st->hand_state == TLS_ST_SW_FINISHED)
-            ret = tls_construct_finished(s, pkt,
-                                         s->method->
-                                         ssl3_enc->server_finished_label,
-                                         s->method->
-                                         ssl3_enc->server_finished_label_len);
-        else if (confunc != NULL)
-            ret = confunc(s, pkt);
+    if (st->hand_state == TLS_ST_SW_FINISHED)
+        ret = tls_construct_finished(s, pkt,
+                                     s->method->
+                                     ssl3_enc->server_finished_label,
+                                     s->method->
+                                     ssl3_enc->server_finished_label_len);
+    else if (confunc != NULL)
+        ret = confunc(s, pkt);
 
-        if (!ret || !ssl_close_construct_packet(s, pkt)) {
-            SSLerr(SSL_F_OSSL_STATEM_SERVER_CONSTRUCT_MESSAGE,
-                   ERR_R_INTERNAL_ERROR);
-            return 0;
-        }
+    if (!ret || !ssl_close_construct_packet(s, pkt, mt)) {
+        SSLerr(SSL_F_OSSL_STATEM_SERVER_CONSTRUCT_MESSAGE,
+               ERR_R_INTERNAL_ERROR);
+        return 0;
     }
+
     return 1;
 }
 
@@ -881,8 +884,6 @@ int dtls_raw_hello_verify_request(WPACKET *pkt, unsigned char *cookie,
 
 int dtls_construct_hello_verify_request(SSL *s, WPACKET *pkt)
 {
-    size_t msglen;
-
     if (s->ctx->app_gen_cookie_cb == NULL ||
         s->ctx->app_gen_cookie_cb(s, s->d1->cookie,
                                   &(s->d1->cookie_len)) == 0 ||
@@ -892,27 +893,12 @@ int dtls_construct_hello_verify_request(SSL *s, WPACKET *pkt)
         return 0;
     }
 
-    if (!ssl_set_handshake_header(s, pkt,
-                                         DTLS1_MT_HELLO_VERIFY_REQUEST)
-            || !dtls_raw_hello_verify_request(pkt, s->d1->cookie,
-                                              s->d1->cookie_len)
-               /*
-                * We don't call close_construct_packet() because we don't want
-                * to buffer this message
-                */
-            || !WPACKET_close(pkt)
-            || !WPACKET_get_length(pkt, &msglen)
-            || !WPACKET_finish(pkt)) {
+    if (!dtls_raw_hello_verify_request(pkt, s->d1->cookie,
+                                              s->d1->cookie_len)) {
         SSLerr(SSL_F_DTLS_CONSTRUCT_HELLO_VERIFY_REQUEST, ERR_R_INTERNAL_ERROR);
         return 0;
     }
 
-    /* number of bytes to write */
-    s->d1->w_msg_hdr.msg_len = msglen - DTLS1_HM_HEADER_LENGTH;
-    s->d1->w_msg_hdr.frag_len = msglen - DTLS1_HM_HEADER_LENGTH;
-    s->init_num = (int)msglen;
-    s->init_off = 0;
-
     return 1;
 }
 
@@ -3002,8 +2988,7 @@ int tls_construct_new_session_ticket(SSL *s, WPACKET *pkt)
 
             /* Put timeout and length */
             if (!WPACKET_put_bytes_u32(pkt, 0)
-                    || !WPACKET_put_bytes_u16(pkt, 0)
-                    || !ssl_close_construct_packet(s, pkt)) {
+                    || !WPACKET_put_bytes_u16(pkt, 0)) {
                 SSLerr(SSL_F_TLS_CONSTRUCT_NEW_SESSION_TICKET,
                        ERR_R_INTERNAL_ERROR);
                 goto err;