Add WPACKET_sub_memcpy() function
[openssl.git] / ssl / statem / statem_clnt.c
index 8f250cdc133097943b380d2a80410582a4a177be..b6895f5762dcf067a54c9dd4149ae069a0c37447 100644 (file)
@@ -63,7 +63,7 @@ static ossl_inline int cert_req_allowed(SSL *s);
 static int key_exchange_expected(SSL *s);
 static int ca_dn_cmp(const X509_NAME *const *a, const X509_NAME *const *b);
 static int ssl_cipher_list_to_bytes(SSL *s, STACK_OF(SSL_CIPHER) *sk,
-                                    unsigned char *p);
+                                    WPACKET *pkt);
 
 /*
  * Is a CertificateRequest message allowed at the moment or not?
@@ -689,19 +689,22 @@ WORK_STATE ossl_statem_client_post_process_message(SSL *s, WORK_STATE wst)
 
 int tls_construct_client_hello(SSL *s)
 {
-    unsigned char *buf;
-    unsigned char *p, *d;
+    unsigned char *p;
     int i;
     int protverr;
-    unsigned long l;
-    int al = 0;
+    int al = SSL_AD_HANDSHAKE_FAILURE;
 #ifndef OPENSSL_NO_COMP
-    int j;
     SSL_COMP *comp;
 #endif
     SSL_SESSION *sess = s->session;
+    WPACKET pkt;
 
-    buf = (unsigned char *)s->init_buf->data;
+    if (!WPACKET_init(&pkt, s->init_buf)
+            || !WPACKET_set_max_size(&pkt, SSL3_RT_MAX_PLAIN_LENGTH)) {
+        /* Should not happen */
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
 
     /* Work out what SSL/TLS/DTLS version to use */
     protverr = ssl_set_client_hello_version(s);
@@ -743,8 +746,11 @@ int tls_construct_client_hello(SSL *s)
     if (i && ssl_fill_hello_random(s, 0, p, sizeof(s->s3->client_random)) <= 0)
         goto err;
 
-    /* Do the message type and length last */
-    d = p = ssl_handshake_start(s);
+    if (!ssl_set_handshake_header2(s, &pkt, SSL3_MT_CLIENT_HELLO)) {
+        ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
 
     /*-
      * version indicates the negotiated version: for example from
@@ -776,90 +782,91 @@ int tls_construct_client_hello(SSL *s)
      * client_version in client hello and not resetting it to
      * the negotiated version.
      */
-    *(p++) = s->client_version >> 8;
-    *(p++) = s->client_version & 0xff;
-
-    /* Random stuff */
-    memcpy(p, s->s3->client_random, SSL3_RANDOM_SIZE);
-    p += SSL3_RANDOM_SIZE;
+    if (!WPACKET_put_bytes(&pkt, s->client_version, 2)
+            || !WPACKET_memcpy(&pkt, s->s3->client_random, SSL3_RANDOM_SIZE)) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
 
     /* Session ID */
     if (s->new_session)
         i = 0;
     else
         i = s->session->session_id_length;
-    *(p++) = i;
-    if (i != 0) {
-        if (i > (int)sizeof(s->session->session_id)) {
-            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
-            goto err;
-        }
-        memcpy(p, s->session->session_id, i);
-        p += i;
+    if (i > (int)sizeof(s->session->session_id)
+            || !WPACKET_start_sub_packet_len(&pkt, 1)
+            || (i != 0 && !WPACKET_memcpy(&pkt, s->session->session_id, i))
+            || !WPACKET_close(&pkt)) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
+        goto err;
     }
 
     /* cookie stuff for DTLS */
     if (SSL_IS_DTLS(s)) {
-        if (s->d1->cookie_len > sizeof(s->d1->cookie)) {
+        if (s->d1->cookie_len > sizeof(s->d1->cookie)
+                || !WPACKET_sub_memcpy(&pkt, s->d1->cookie, s->d1->cookie_len,
+                                       1)) {
             SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
             goto err;
         }
-        *(p++) = s->d1->cookie_len;
-        memcpy(p, s->d1->cookie, s->d1->cookie_len);
-        p += s->d1->cookie_len;
     }
 
     /* Ciphers supported */
-    i = ssl_cipher_list_to_bytes(s, SSL_get_ciphers(s), &(p[2]));
-    if (i == 0) {
-        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, SSL_R_NO_CIPHERS_AVAILABLE);
+    if (!WPACKET_start_sub_packet_len(&pkt, 2)) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
+    /* ssl_cipher_list_to_bytes() raises SSLerr if appropriate */
+    if (!ssl_cipher_list_to_bytes(s, SSL_get_ciphers(s), &pkt))
+        goto err;
+    if (!WPACKET_close(&pkt)) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
         goto err;
     }
-#ifdef OPENSSL_MAX_TLS1_2_CIPHER_LENGTH
-    /*
-     * Some servers hang if client hello > 256 bytes as hack workaround
-     * chop number of supported ciphers to keep it well below this if we
-     * use TLS v1.2
-     */
-    if (TLS1_get_version(s) >= TLS1_2_VERSION
-        && i > OPENSSL_MAX_TLS1_2_CIPHER_LENGTH)
-        i = OPENSSL_MAX_TLS1_2_CIPHER_LENGTH & ~1;
-#endif
-    s2n(i, p);
-    p += i;
 
     /* COMPRESSION */
-#ifdef OPENSSL_NO_COMP
-    *(p++) = 1;
-#else
-
-    if (!ssl_allow_compression(s) || !s->ctx->comp_methods)
-        j = 0;
-    else
-        j = sk_SSL_COMP_num(s->ctx->comp_methods);
-    *(p++) = 1 + j;
-    for (i = 0; i < j; i++) {
-        comp = sk_SSL_COMP_value(s->ctx->comp_methods, i);
-        *(p++) = comp->id;
+    if (!WPACKET_start_sub_packet_len(&pkt, 1)) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
+#ifndef OPENSSL_NO_COMP
+    if (ssl_allow_compression(s) && s->ctx->comp_methods) {
+        int compnum = sk_SSL_COMP_num(s->ctx->comp_methods);
+        for (i = 0; i < compnum; i++) {
+            comp = sk_SSL_COMP_value(s->ctx->comp_methods, i);
+            if (!WPACKET_put_bytes(&pkt, comp->id, 1)) {
+                SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
+                goto err;
+            }
+        }
     }
 #endif
-    *(p++) = 0;                 /* Add the NULL method */
+    /* Add the NULL method */
+    if (!WPACKET_put_bytes(&pkt, 0, 1) || !WPACKET_close(&pkt)) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
 
     /* TLS extensions */
     if (ssl_prepare_clienthello_tlsext(s) <= 0) {
         SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, SSL_R_CLIENTHELLO_TLSEXT);
         goto err;
     }
-    if ((p =
-         ssl_add_clienthello_tlsext(s, p, buf + SSL3_RT_MAX_PLAIN_LENGTH,
-                                    &al)) == NULL) {
+    if (!WPACKET_start_sub_packet_len(&pkt, 2)
+               /*
+                * If extensions are of zero length then we don't even add the
+                * extensions length bytes
+                */
+            || !WPACKET_set_flags(&pkt,
+                                  OPENSSL_WPACKET_FLAGS_ABANDON_ON_ZERO_LENGTH)
+            || !ssl_add_clienthello_tlsext(s, &pkt, &al)
+            || !WPACKET_close(&pkt)) {
         ssl3_send_alert(s, SSL3_AL_FATAL, al);
         SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
         goto err;
     }
 
-    l = p - d;
-    if (!ssl_set_handshake_header(s, SSL3_MT_CLIENT_HELLO, l)) {
+    if (!WPACKET_close(&pkt) || !ssl_close_construct_packet(s, &pkt)) {
         ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
         SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
         goto err;
@@ -2368,7 +2375,7 @@ static int tls_construct_cke_gost(SSL *s, unsigned char **p, int *len, int *al)
     if (pms == NULL) {
         *al = SSL_AD_INTERNAL_ERROR;
         SSLerr(SSL_F_TLS_CONSTRUCT_CKE_GOST, ERR_R_MALLOC_FAILURE);
-        return 0;
+        goto err;
     }
 
     if (EVP_PKEY_encrypt_init(pkey_ctx) <= 0
@@ -2909,47 +2916,79 @@ int ssl_do_client_cert_cb(SSL *s, X509 **px509, EVP_PKEY **ppkey)
     return i;
 }
 
-int ssl_cipher_list_to_bytes(SSL *s, STACK_OF(SSL_CIPHER) *sk, unsigned char *p)
+int ssl_cipher_list_to_bytes(SSL *s, STACK_OF(SSL_CIPHER) *sk, WPACKET *pkt)
 {
-    int i, j = 0;
-    const SSL_CIPHER *c;
-    unsigned char *q;
+    int i;
+    size_t totlen = 0, len, maxlen;
     int empty_reneg_info_scsv = !s->renegotiate;
     /* Set disabled masks for this session */
     ssl_set_client_disabled(s);
 
     if (sk == NULL)
         return (0);
-    q = p;
 
-    for (i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
+#ifdef OPENSSL_MAX_TLS1_2_CIPHER_LENGTH
+# if OPENSSL_MAX_TLS1_2_CIPHER_LENGTH < 6
+#  error Max cipher length too short
+# endif
+    /*
+     * Some servers hang if client hello > 256 bytes as hack workaround
+     * chop number of supported ciphers to keep it well below this if we
+     * use TLS v1.2
+     */
+    if (TLS1_get_version(s) >= TLS1_2_VERSION)
+        maxlen = OPENSSL_MAX_TLS1_2_CIPHER_LENGTH & ~1;
+    else
+#endif
+        /* Maximum length that can be stored in 2 bytes. Length must be even */
+        maxlen = 0xfffe;
+
+    if (empty_reneg_info_scsv)
+        maxlen -= 2;
+    if (s->mode & SSL_MODE_SEND_FALLBACK_SCSV)
+        maxlen -= 2;
+
+    for (i = 0; i < sk_SSL_CIPHER_num(sk) && totlen < maxlen; i++) {
+        const SSL_CIPHER *c;
+
         c = sk_SSL_CIPHER_value(sk, i);
         /* Skip disabled ciphers */
         if (ssl_cipher_disabled(s, c, SSL_SECOP_CIPHER_SUPPORTED))
             continue;
-        j = s->method->put_cipher_by_char(c, p);
-        p += j;
+
+        if (!s->method->put_cipher_by_char(c, pkt, &len)) {
+            SSLerr(SSL_F_SSL_CIPHER_LIST_TO_BYTES, ERR_R_INTERNAL_ERROR);
+            return 0;
+        }
+
+        totlen += len;
     }
-    /*
-     * If p == q, no ciphers; caller indicates an error. Otherwise, add
-     * applicable SCSVs.
-     */
-    if (p != q) {
+
+    if (totlen == 0) {
+        SSLerr(SSL_F_SSL_CIPHER_LIST_TO_BYTES, SSL_R_NO_CIPHERS_AVAILABLE);
+        return 0;
+    }
+
+    if (totlen != 0) {
         if (empty_reneg_info_scsv) {
             static SSL_CIPHER scsv = {
                 0, NULL, SSL3_CK_SCSV, 0, 0, 0, 0, 0, 0, 0, 0, 0
             };
-            j = s->method->put_cipher_by_char(&scsv, p);
-            p += j;
+            if (!s->method->put_cipher_by_char(&scsv, pkt, &len)) {
+                SSLerr(SSL_F_SSL_CIPHER_LIST_TO_BYTES, ERR_R_INTERNAL_ERROR);
+                return 0;
+            }
         }
         if (s->mode & SSL_MODE_SEND_FALLBACK_SCSV) {
             static SSL_CIPHER scsv = {
                 0, NULL, SSL3_CK_FALLBACK_SCSV, 0, 0, 0, 0, 0, 0, 0, 0, 0
             };
-            j = s->method->put_cipher_by_char(&scsv, p);
-            p += j;
+            if (!s->method->put_cipher_by_char(&scsv, pkt, &len)) {
+                SSLerr(SSL_F_SSL_CIPHER_LIST_TO_BYTES, ERR_R_INTERNAL_ERROR);
+                return 0;
+            }
         }
     }
 
-    return (p - q);
+    return 1;
 }