Split out CKE construction PSK pre-amble and RSA into a separate function
authorMatt Caswell <matt@openssl.org>
Thu, 7 Jul 2016 13:42:27 +0000 (14:42 +0100)
committerMatt Caswell <matt@openssl.org>
Mon, 18 Jul 2016 22:05:14 +0000 (23:05 +0100)
The tls_construct_client_key_exchange() function is too long. This splits
out the construction of the PSK pre-amble into a separate function as well
as the RSA construction.

Reviewed-by: Richard Levitte <levitte@openssl.org>
ssl/statem/statem_clnt.c

index e52bbe3b1bce553bb0969e9506bda9fff63cc9a2..c2ecd68e0a8a2fe55b33ecbb8000578fe79f0f4a 100644 (file)
@@ -2012,177 +2012,205 @@ MSG_PROCESS_RETURN tls_process_server_done(SSL *s, PACKET *pkt)
         return MSG_PROCESS_FINISHED_READING;
 }
 
-int tls_construct_client_key_exchange(SSL *s)
+static int tls_construct_cke_psk_preamble(SSL *s, unsigned char **p,
+                                          size_t *pskhdrlen, int *al)
 {
-    unsigned char *p;
-    int n;
 #ifndef OPENSSL_NO_PSK
-    size_t pskhdrlen = 0;
-#endif
-    unsigned long alg_k;
-
-    alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
+    int ret = 0;
+    /*
+     * The callback needs PSK_MAX_IDENTITY_LEN + 1 bytes to return a
+     * \0-terminated identity. The last byte is for us for simulating
+     * strnlen.
+     */
+    char identity[PSK_MAX_IDENTITY_LEN + 1];
+    size_t identitylen = 0;
+    unsigned char psk[PSK_MAX_PSK_LEN];
+    unsigned char *tmppsk = NULL;
+    char *tmpidentity = NULL;
+    size_t psklen = 0;
+
+    if (s->psk_client_callback == NULL) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
+               SSL_R_PSK_NO_CLIENT_CB);
+        *al = SSL_AD_INTERNAL_ERROR;
+        goto err;
+    }
 
-    p = ssl_handshake_start(s);
+    memset(identity, 0, sizeof(identity));
 
+    psklen = s->psk_client_callback(s, s->session->psk_identity_hint,
+                                    identity, sizeof(identity) - 1,
+                                    psk, sizeof(psk));
 
-#ifndef OPENSSL_NO_PSK
-    if (alg_k & SSL_PSK) {
-        int psk_err = 1;
-        /*
-         * The callback needs PSK_MAX_IDENTITY_LEN + 1 bytes to return a
-         * \0-terminated identity. The last byte is for us for simulating
-         * strnlen.
-         */
-        char identity[PSK_MAX_IDENTITY_LEN + 1];
-        size_t identitylen;
-        unsigned char psk[PSK_MAX_PSK_LEN];
-        unsigned char *tmppsk;
-        char *tmpidentity;
-        size_t psklen;
-
-        if (s->psk_client_callback == NULL) {
-            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
-                   SSL_R_PSK_NO_CLIENT_CB);
-            goto err;
-        }
+    if (psklen > PSK_MAX_PSK_LEN) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
+               ERR_R_INTERNAL_ERROR);
+        *al = SSL_AD_HANDSHAKE_FAILURE;
+        goto err;
+    } else if (psklen == 0) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
+               SSL_R_PSK_IDENTITY_NOT_FOUND);
+        *al = SSL_AD_HANDSHAKE_FAILURE;
+        goto err;
+    }
 
-        memset(identity, 0, sizeof(identity));
+    identitylen = strlen(identity);
+    if (identitylen > PSK_MAX_IDENTITY_LEN) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
+               ERR_R_INTERNAL_ERROR);
+        *al = SSL_AD_HANDSHAKE_FAILURE;
+        goto err;
+    }
 
-        psklen = s->psk_client_callback(s, s->session->psk_identity_hint,
-                                        identity, sizeof(identity) - 1,
-                                        psk, sizeof(psk));
+    tmppsk = OPENSSL_memdup(psk, psklen);
+    tmpidentity = OPENSSL_strdup(identity);
+    if (tmppsk == NULL || tmpidentity == NULL) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
+        *al = SSL_AD_INTERNAL_ERROR;
+        goto err;
+    }
 
-        if (psklen > PSK_MAX_PSK_LEN) {
-            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
-                   ERR_R_INTERNAL_ERROR);
-            goto psk_err;
-        } else if (psklen == 0) {
-            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
-                   SSL_R_PSK_IDENTITY_NOT_FOUND);
-            goto psk_err;
-        }
+    OPENSSL_free(s->s3->tmp.psk);
+    s->s3->tmp.psk = tmppsk;
+    s->s3->tmp.psklen = psklen;
+    tmppsk = NULL;
+    OPENSSL_free(s->session->psk_identity);
+    s->session->psk_identity = tmpidentity;
+    tmpidentity = NULL;
+    s2n(identitylen, *p);
+    memcpy(*p, identity, identitylen);
+    *pskhdrlen = 2 + identitylen;
+    *p += identitylen;
 
-        identitylen = strlen(identity);
-        if (identitylen > PSK_MAX_IDENTITY_LEN) {
-            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
-                   ERR_R_INTERNAL_ERROR);
-            goto psk_err;
-        }
+    ret = 1;
 
-        tmppsk = OPENSSL_memdup(psk, psklen);
-        tmpidentity = OPENSSL_strdup(identity);
-        if (tmppsk == NULL || tmpidentity == NULL) {
-            OPENSSL_cleanse(identity, sizeof(identity));
-            OPENSSL_cleanse(psk, psklen);
-            OPENSSL_clear_free(tmppsk, psklen);
-            OPENSSL_clear_free(tmpidentity, identitylen);
-            goto memerr;
-        }
+ err:
+    OPENSSL_cleanse(psk, psklen);
+    OPENSSL_cleanse(identity, sizeof(identity));
+    OPENSSL_clear_free(tmppsk, psklen);
+    OPENSSL_clear_free(tmpidentity, identitylen);
 
-        OPENSSL_free(s->s3->tmp.psk);
-        s->s3->tmp.psk = tmppsk;
-        s->s3->tmp.psklen = psklen;
-        OPENSSL_free(s->session->psk_identity);
-        s->session->psk_identity = tmpidentity;
-        s2n(identitylen, p);
-        memcpy(p, identity, identitylen);
-        pskhdrlen = 2 + identitylen;
-        p += identitylen;
-        psk_err = 0;
-psk_err:
-        OPENSSL_cleanse(psk, psklen);
-        OPENSSL_cleanse(identity, sizeof(identity));
-        if (psk_err != 0) {
-            ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
-            goto err;
-        }
-    }
-    if (alg_k & SSL_kPSK) {
-        n = 0;
-    } else
+    return ret;
+#else
+    SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
+    *al = SSL_AD_INTERNAL_ERROR;
+    return 0;
 #endif
+}
 
-    /* Fool emacs indentation */
-    if (0) {
-    }
+static int tls_construct_cke_rsa(SSL *s, unsigned char **p, int *len, int *al)
+{
 #ifndef OPENSSL_NO_RSA
-    else if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) {
-        unsigned char *q;
-        EVP_PKEY *pkey = NULL;
-        EVP_PKEY_CTX *pctx = NULL;
-        size_t enclen;
-        unsigned char *pms = NULL;
-        size_t pmslen = 0;
+    unsigned char *q;
+    EVP_PKEY *pkey = NULL;
+    EVP_PKEY_CTX *pctx = NULL;
+    size_t enclen;
+    unsigned char *pms = NULL;
+    size_t pmslen = 0;
 
-        if (s->session->peer == NULL) {
-            /*
-             * We should always have a server certificate with SSL_kRSA.
-             */
-            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
-                   ERR_R_INTERNAL_ERROR);
-            goto err;
-        }
+    if (s->session->peer == NULL) {
+        /*
+         * We should always have a server certificate with SSL_kRSA.
+         */
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
+               ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
 
-        pkey = X509_get0_pubkey(s->session->peer);
-        if (EVP_PKEY_get0_RSA(pkey) == NULL) {
-            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
-                   ERR_R_INTERNAL_ERROR);
-            goto err;
-        }
+    pkey = X509_get0_pubkey(s->session->peer);
+    if (EVP_PKEY_get0_RSA(pkey) == NULL) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
+               ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
 
-        pmslen = SSL_MAX_MASTER_KEY_LENGTH;
-        pms = OPENSSL_malloc(pmslen);
-        if (pms == NULL)
-            goto memerr;
+    pmslen = SSL_MAX_MASTER_KEY_LENGTH;
+    pms = OPENSSL_malloc(pmslen);
+    if (pms == NULL) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
+               ERR_R_MALLOC_FAILURE);
+        *al = SSL_AD_INTERNAL_ERROR;
+        return 0;
+    }
 
-        pms[0] = s->client_version >> 8;
-        pms[1] = s->client_version & 0xff;
-        if (RAND_bytes(pms + 2, pmslen - 2) <= 0) {
-            OPENSSL_clear_free(pms, pmslen);
-            goto err;
-        }
+    pms[0] = s->client_version >> 8;
+    pms[1] = s->client_version & 0xff;
+    if (RAND_bytes(pms + 2, pmslen - 2) <= 0) {
+        goto err;
+    }
 
-        q = p;
-        /* Fix buf for TLS and beyond */
-        if (s->version > SSL3_VERSION)
-            p += 2;
-        pctx = EVP_PKEY_CTX_new(pkey, NULL);
-        if (pctx == NULL || EVP_PKEY_encrypt_init(pctx) <= 0
-            || EVP_PKEY_encrypt(pctx, NULL, &enclen, pms, pmslen) <= 0) {
-            OPENSSL_clear_free(pms, pmslen);
-            EVP_PKEY_CTX_free(pctx);
-            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
-                   ERR_R_EVP_LIB);
-            goto err;
-        }
-        if (EVP_PKEY_encrypt(pctx, p, &enclen, pms, pmslen) <= 0) {
-            OPENSSL_clear_free(pms, pmslen);
-            EVP_PKEY_CTX_free(pctx);
-            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
-                   SSL_R_BAD_RSA_ENCRYPT);
-            goto err;
-        }
-        n = enclen;
-        EVP_PKEY_CTX_free(pctx);
-        pctx = NULL;
+    q = *p;
+    /* Fix buf for TLS and beyond */
+    if (s->version > SSL3_VERSION)
+        *p += 2;
+    pctx = EVP_PKEY_CTX_new(pkey, NULL);
+    if (pctx == NULL || EVP_PKEY_encrypt_init(pctx) <= 0
+        || EVP_PKEY_encrypt(pctx, NULL, &enclen, pms, pmslen) <= 0) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
+               ERR_R_EVP_LIB);
+        goto err;
+    }
+    if (EVP_PKEY_encrypt(pctx, *p, &enclen, pms, pmslen) <= 0) {
+        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
+               SSL_R_BAD_RSA_ENCRYPT);
+        goto err;
+    }
+    *len = enclen;
+    EVP_PKEY_CTX_free(pctx);
+    pctx = NULL;
 # ifdef PKCS1_CHECK
-        if (s->options & SSL_OP_PKCS1_CHECK_1)
-            p[1]++;
-        if (s->options & SSL_OP_PKCS1_CHECK_2)
-            tmp_buf[0] = 0x70;
+    if (s->options & SSL_OP_PKCS1_CHECK_1)
+        (*p)[1]++;
+    if (s->options & SSL_OP_PKCS1_CHECK_2)
+        tmp_buf[0] = 0x70;
 # endif
 
-        /* Fix buf for TLS and beyond */
-        if (s->version > SSL3_VERSION) {
-            s2n(n, q);
-            n += 2;
-        }
-
-        s->s3->tmp.pms = pms;
-        s->s3->tmp.pmslen = pmslen;
+    /* Fix buf for TLS and beyond */
+    if (s->version > SSL3_VERSION) {
+        s2n(*len, q);
+        *len += 2;
     }
+
+    s->s3->tmp.pms = pms;
+    s->s3->tmp.pmslen = pmslen;
+
+    return 1;
+ err:
+    OPENSSL_clear_free(pms, pmslen);
+    EVP_PKEY_CTX_free(pctx);
+
+    return 0;
+#else
+    SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
+    *al = SSL_AD_INTERNAL_ERROR;
+    return 0;
 #endif
+}
+
+int tls_construct_client_key_exchange(SSL *s)
+{
+    unsigned char *p;
+    int n;
+    size_t pskhdrlen = 0;
+    unsigned long alg_k;
+    int al = -1;
+
+    alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
+
+    p = ssl_handshake_start(s);
+
+
+
+    if ((alg_k & SSL_PSK)
+            && !tls_construct_cke_psk_preamble(s, &p, &pskhdrlen, &al))
+        goto err;
+
+    if (alg_k & SSL_kPSK) {
+        n = 0;
+    } else if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) {
+        if (!tls_construct_cke_rsa(s, &p, &n, &al))
+            goto err;
+    }
 #ifndef OPENSSL_NO_DH
     else if (alg_k & (SSL_kDHE | SSL_kDHEPSK)) {
         DH *dh_clnt = NULL;
@@ -2421,9 +2449,7 @@ psk_err:
         goto err;
     }
 
-#ifndef OPENSSL_NO_PSK
     n += pskhdrlen;
-#endif
 
     if (!ssl_set_handshake_header(s, SSL3_MT_CLIENT_KEY_EXCHANGE, n)) {
         ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
@@ -2433,9 +2459,11 @@ psk_err:
 
     return 1;
  memerr:
-    ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
     SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
+    al = SSL_AD_INTERNAL_ERROR;
  err:
+    if (al != -1)
+        ssl3_send_alert(s, SSL3_AL_FATAL, al);
     OPENSSL_clear_free(s->s3->tmp.pms, s->s3->tmp.pmslen);
     s->s3->tmp.pms = NULL;
 #ifndef OPENSSL_NO_PSK