Don't set the handshake header in every message
[openssl.git] / ssl / statem / statem_clnt.c
index 5614d5afcb5d6755db7e34a21b7b29f394bce55e..18eaf3257fe228f9560a78a85c3e983159ec68e4 100644 (file)
@@ -513,41 +513,74 @@ WORK_STATE ossl_statem_client_post_work(SSL *s, WORK_STATE wst)
 int ossl_statem_client_construct_message(SSL *s, WPACKET *pkt)
 {
     OSSL_STATEM *st = &s->statem;
+    int (*confunc) (SSL *s, WPACKET *pkt) = NULL;
+    int ret = 1, mt;
 
-    switch (st->hand_state) {
-    default:
-        /* Shouldn't happen */
-        return 0;
+    if (st->hand_state == TLS_ST_CW_CHANGE) {
+        /* Special case becase it is a different content type */
+        if (SSL_IS_DTLS(s))
+            return dtls_construct_change_cipher_spec(s, pkt);
 
-    case TLS_ST_CW_CLNT_HELLO:
-        return tls_construct_client_hello(s, pkt);
+        return tls_construct_change_cipher_spec(s, pkt);
+    } else {
+        switch (st->hand_state) {
+        default:
+            /* Shouldn't happen */
+            return 0;
 
-    case TLS_ST_CW_CERT:
-        return tls_construct_client_certificate(s, pkt);
+        case TLS_ST_CW_CLNT_HELLO:
+            confunc = tls_construct_client_hello;
+            mt = SSL3_MT_CLIENT_HELLO;
+            break;
 
-    case TLS_ST_CW_KEY_EXCH:
-        return tls_construct_client_key_exchange(s, pkt);
+        case TLS_ST_CW_CERT:
+            confunc = tls_construct_client_certificate;
+            mt = SSL3_MT_CERTIFICATE;
+            break;
 
-    case TLS_ST_CW_CERT_VRFY:
-        return tls_construct_client_verify(s, pkt);
+        case TLS_ST_CW_KEY_EXCH:
+            confunc = tls_construct_client_key_exchange;
+            mt = SSL3_MT_CLIENT_KEY_EXCHANGE;
+            break;
 
-    case TLS_ST_CW_CHANGE:
-        if (SSL_IS_DTLS(s))
-            return dtls_construct_change_cipher_spec(s, pkt);
-        else
-            return tls_construct_change_cipher_spec(s, pkt);
+        case TLS_ST_CW_CERT_VRFY:
+            confunc = tls_construct_client_verify;
+            mt = SSL3_MT_CERTIFICATE_VERIFY;
+            break;
 
 #if !defined(OPENSSL_NO_NEXTPROTONEG)
-    case TLS_ST_CW_NEXT_PROTO:
-        return tls_construct_next_proto(s, pkt);
+        case TLS_ST_CW_NEXT_PROTO:
+            confunc = tls_construct_next_proto;
+            mt = SSL3_MT_NEXT_PROTO;
+            break;
 #endif
-    case TLS_ST_CW_FINISHED:
-        return tls_construct_finished(s, pkt,
-                                      s->method->
-                                      ssl3_enc->client_finished_label,
-                                      s->method->
-                                      ssl3_enc->client_finished_label_len);
+        case TLS_ST_CW_FINISHED:
+            mt = SSL3_MT_FINISHED;
+            break;
+        }
+
+        if (!ssl_set_handshake_header(s, pkt, mt)) {
+            SSLerr(SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE,
+                   ERR_R_INTERNAL_ERROR);
+            return 0;
+        }
+
+        if (st->hand_state == TLS_ST_CW_FINISHED)
+            ret = tls_construct_finished(s, pkt,
+                                         s->method->
+                                         ssl3_enc->client_finished_label,
+                                         s->method->
+                                         ssl3_enc->client_finished_label_len);
+        else
+            ret = confunc(s, pkt);
+
+        if (!ret || !ssl_close_construct_packet(s, pkt)) {
+            SSLerr(SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE,
+                   ERR_R_INTERNAL_ERROR);
+            return 0;
+        }
     }
+    return 1;
 }
 
 /*
@@ -736,12 +769,6 @@ int tls_construct_client_hello(SSL *s, WPACKET *pkt)
     if (i && ssl_fill_hello_random(s, 0, p, sizeof(s->s3->client_random)) <= 0)
         return 0;
 
-    if (!ssl_set_handshake_header(s, pkt, SSL3_MT_CLIENT_HELLO)) {
-        ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
-        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
-        return 0;
-    }
-
     /*-
      * version indicates the negotiated version: for example from
      * an SSLv2/v3 compatible client hello). The client_version
@@ -855,11 +882,6 @@ int tls_construct_client_hello(SSL *s, WPACKET *pkt)
         return 0;
     }
 
-    if (!ssl_close_construct_packet(s, pkt)) {
-        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
-        return 0;
-    }
-
     return 1;
 }
 
@@ -2455,12 +2477,6 @@ int tls_construct_client_key_exchange(SSL *s, WPACKET *pkt)
     unsigned long alg_k;
     int al = -1;
 
-    if (!ssl_set_handshake_header(s, pkt, SSL3_MT_CLIENT_KEY_EXCHANGE)) {
-        ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
-        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-
     alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
 
     if ((alg_k & SSL_PSK)
@@ -2488,12 +2504,6 @@ int tls_construct_client_key_exchange(SSL *s, WPACKET *pkt)
         goto err;
     }
 
-    if (!ssl_close_construct_packet(s, pkt)) {
-        ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
-        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-
     return 1;
  err:
     if (al != -1)
@@ -2582,11 +2592,6 @@ int tls_construct_client_verify(SSL *s, WPACKET *pkt)
     void *hdata;
     unsigned char *sig = NULL;
 
-    if (!ssl_set_handshake_header(s, pkt, SSL3_MT_CERTIFICATE_VERIFY)) {
-        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_VERIFY, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-
     mctx = EVP_MD_CTX_new();
     if (mctx == NULL) {
         SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_VERIFY, ERR_R_MALLOC_FAILURE);
@@ -2640,11 +2645,6 @@ int tls_construct_client_verify(SSL *s, WPACKET *pkt)
     if (!ssl3_digest_cached_records(s, 0))
         goto err;
 
-    if (!ssl_close_construct_packet(s, pkt)) {
-        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_VERIFY, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-
     OPENSSL_free(sig);
     EVP_MD_CTX_free(mctx);
     return 1;
@@ -2846,11 +2846,6 @@ int tls_construct_next_proto(SSL *s, WPACKET *pkt)
     size_t len, padding_len;
     unsigned char *padding = NULL;
 
-    if (!ssl_set_handshake_header(s, pkt, SSL3_MT_NEXT_PROTO)) {
-        SSLerr(SSL_F_TLS_CONSTRUCT_NEXT_PROTO, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-
     len = s->next_proto_negotiated_len;
     padding_len = 32 - ((len + 2) % 32);
 
@@ -2862,11 +2857,6 @@ int tls_construct_next_proto(SSL *s, WPACKET *pkt)
 
     memset(padding, 0, padding_len);
 
-    if (!ssl_close_construct_packet(s, pkt)) {
-        SSLerr(SSL_F_TLS_CONSTRUCT_NEXT_PROTO, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-
     return 1;
  err:
     ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);