Fixed error introduced in commit f2be92b94dad3c6cbdf79d99a324804094cf1617
[openssl.git] / ssl / s3_clnt.c
index 4e3cc2ef7a5ac4c1a835e02667800e984214f6b7..3e89e5204d7bab86c74724a4e1b0167b468d1007 100644 (file)
@@ -327,9 +327,9 @@ int ssl3_connect(SSL *s)
                                break;
                                }
 #endif
-                       /* Check if it is anon DH/ECDH */
+                       /* Check if it is anon DH/ECDH, SRP auth */
                        /* or PSK */
-                       if (!(s->s3->tmp.new_cipher->algorithm_auth & SSL_aNULL) &&
+                       if (!(s->s3->tmp.new_cipher->algorithm_auth & (SSL_aNULL|SSL_aSRP)) &&
                            !(s->s3->tmp.new_cipher->algorithm_mkey & SSL_kPSK))
                                {
                                ret=ssl3_get_server_certificate(s);
@@ -1363,8 +1363,8 @@ int ssl3_get_key_exchange(SSL *s)
 #endif
        EVP_MD_CTX md_ctx;
        unsigned char *param,*p;
-       int al,i,j,param_len,ok;
-       long n,alg_k,alg_a;
+       int al,j,ok;
+       long i,param_len,n,alg_k,alg_a;
        EVP_PKEY *pkey=NULL;
        const EVP_MD *md = NULL;
 #ifndef OPENSSL_NO_RSA
@@ -1440,36 +1440,48 @@ int ssl3_get_key_exchange(SSL *s)
                s->session->sess_cert=ssl_sess_cert_new();
                }
 
+       /* Total length of the parameters including the length prefix */
        param_len=0;
+
        alg_k=s->s3->tmp.new_cipher->algorithm_mkey;
        alg_a=s->s3->tmp.new_cipher->algorithm_auth;
        EVP_MD_CTX_init(&md_ctx);
 
+       al=SSL_AD_DECODE_ERROR;
+
 #ifndef OPENSSL_NO_PSK
        if (alg_k & SSL_kPSK)
                {
                char tmp_id_hint[PSK_MAX_IDENTITY_LEN+1];
 
-               al=SSL_AD_HANDSHAKE_FAILURE;
+               param_len = 2;
+               if (param_len > n)
+                       {
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
+                               SSL_R_LENGTH_TOO_SHORT);
+                       goto f_err;
+                       }
                n2s(p,i);
-               param_len=i+2;
+
                /* Store PSK identity hint for later use, hint is used
                 * in ssl3_send_client_key_exchange.  Assume that the
                 * maximum length of a PSK identity hint can be as
                 * long as the maximum length of a PSK identity. */
                if (i > PSK_MAX_IDENTITY_LEN)
                        {
+                       al=SSL_AD_HANDSHAKE_FAILURE;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
                                SSL_R_DATA_LENGTH_TOO_LONG);
                        goto f_err;
                        }
-               if (param_len > n)
+               if (i > n - param_len)
                        {
-                       al=SSL_AD_DECODE_ERROR;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
                                SSL_R_BAD_PSK_IDENTITY_HINT_LENGTH);
                        goto f_err;
                        }
+               param_len += i;
+
                /* If received PSK identity hint contains NULL
                 * characters, the hint is truncated from the first
                 * NULL. p may not be ending with NULL, so create a
@@ -1481,6 +1493,7 @@ int ssl3_get_key_exchange(SSL *s)
                s->ctx->psk_identity_hint = BUF_strdup(tmp_id_hint);
                if (s->ctx->psk_identity_hint == NULL)
                        {
+                       al=SSL_AD_HANDSHAKE_FAILURE;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
                        goto f_err;
                        }          
@@ -1493,14 +1506,22 @@ int ssl3_get_key_exchange(SSL *s)
 #ifndef OPENSSL_NO_SRP
        if (alg_k & SSL_kSRP)
                {
-               n2s(p,i);
-               param_len=i+2;
+               param_len = 2;
                if (param_len > n)
                        {
-                       al=SSL_AD_DECODE_ERROR;
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
+                               SSL_R_LENGTH_TOO_SHORT);
+                       goto f_err;
+                       }
+               n2s(p,i);
+
+               if (i > n - param_len)
+                       {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_BAD_SRP_N_LENGTH);
                        goto f_err;
                        }
+               param_len += i;
+
                if (!(s->srp_ctx.N=BN_bin2bn(p,i,NULL)))
                        {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_BN_LIB);
@@ -1508,14 +1529,24 @@ int ssl3_get_key_exchange(SSL *s)
                        }
                p+=i;
 
+
+               if (2 > n - param_len)
+                       {
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
+                               SSL_R_LENGTH_TOO_SHORT);
+                       goto f_err;
+                       }
+               param_len += 2;
+
                n2s(p,i);
-               param_len+=i+2;
-               if (param_len > n)
+
+               if (i > n - param_len)
                        {
-                       al=SSL_AD_DECODE_ERROR;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_BAD_SRP_G_LENGTH);
                        goto f_err;
                        }
+               param_len += i;
+
                if (!(s->srp_ctx.g=BN_bin2bn(p,i,NULL)))
                        {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_BN_LIB);
@@ -1523,15 +1554,25 @@ int ssl3_get_key_exchange(SSL *s)
                        }
                p+=i;
 
+
+               if (1 > n - param_len)
+                       {
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
+                               SSL_R_LENGTH_TOO_SHORT);
+                       goto f_err;
+                       }
+               param_len += 1;
+
                i = (unsigned int)(p[0]);
                p++;
-               param_len+=i+1;
-               if (param_len > n)
+
+               if (i > n - param_len)
                        {
-                       al=SSL_AD_DECODE_ERROR;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_BAD_SRP_S_LENGTH);
                        goto f_err;
                        }
+               param_len += i;
+
                if (!(s->srp_ctx.s=BN_bin2bn(p,i,NULL)))
                        {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_BN_LIB);
@@ -1539,14 +1580,23 @@ int ssl3_get_key_exchange(SSL *s)
                        }
                p+=i;
 
+               if (2 > n - param_len)
+                       {
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
+                               SSL_R_LENGTH_TOO_SHORT);
+                       goto f_err;
+                       }
+               param_len += 2;
+
                n2s(p,i);
-               param_len+=i+2;
-               if (param_len > n)
+
+               if (i > n - param_len)
                        {
-                       al=SSL_AD_DECODE_ERROR;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_BAD_SRP_B_LENGTH);
                        goto f_err;
                        }
+               param_len += i;
+
                if (!(s->srp_ctx.B=BN_bin2bn(p,i,NULL)))
                        {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_BN_LIB);
@@ -1555,6 +1605,12 @@ int ssl3_get_key_exchange(SSL *s)
                p+=i;
                n-=param_len;
 
+               if (!srp_verify_server_param(s, &al))
+                       {
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_BAD_SRP_PARAMETERS);
+                       goto f_err;
+                       }
+
 /* We must check if there is a certificate */
 #ifndef OPENSSL_NO_RSA
                if (alg_a & SSL_aRSA)
@@ -1578,14 +1634,23 @@ int ssl3_get_key_exchange(SSL *s)
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_MALLOC_FAILURE);
                        goto err;
                        }
-               n2s(p,i);
-               param_len=i+2;
+
+               param_len = 2;
                if (param_len > n)
                        {
-                       al=SSL_AD_DECODE_ERROR;
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
+                               SSL_R_LENGTH_TOO_SHORT);
+                       goto f_err;
+                       }
+               n2s(p,i);
+
+               if (i > n - param_len)
+                       {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_BAD_RSA_MODULUS_LENGTH);
                        goto f_err;
                        }
+               param_len += i;
+
                if (!(rsa->n=BN_bin2bn(p,i,rsa->n)))
                        {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_BN_LIB);
@@ -1593,14 +1658,23 @@ int ssl3_get_key_exchange(SSL *s)
                        }
                p+=i;
 
+               if (2 > n - param_len)
+                       {
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
+                               SSL_R_LENGTH_TOO_SHORT);
+                       goto f_err;
+                       }
+               param_len += 2;
+
                n2s(p,i);
-               param_len+=i+2;
-               if (param_len > n)
+
+               if (i > n - param_len)
                        {
-                       al=SSL_AD_DECODE_ERROR;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_BAD_RSA_E_LENGTH);
                        goto f_err;
                        }
+               param_len += i;
+
                if (!(rsa->e=BN_bin2bn(p,i,rsa->e)))
                        {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_BN_LIB);
@@ -1632,14 +1706,23 @@ int ssl3_get_key_exchange(SSL *s)
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_DH_LIB);
                        goto err;
                        }
-               n2s(p,i);
-               param_len=i+2;
+
+               param_len = 2;
                if (param_len > n)
                        {
-                       al=SSL_AD_DECODE_ERROR;
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
+                               SSL_R_LENGTH_TOO_SHORT);
+                       goto f_err;
+                       }
+               n2s(p,i);
+
+               if (i > n - param_len)
+                       {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_BAD_DH_P_LENGTH);
                        goto f_err;
                        }
+               param_len += i;
+
                if (!(dh->p=BN_bin2bn(p,i,NULL)))
                        {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_BN_LIB);
@@ -1647,14 +1730,23 @@ int ssl3_get_key_exchange(SSL *s)
                        }
                p+=i;
 
+               if (2 > n - param_len)
+                       {
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
+                               SSL_R_LENGTH_TOO_SHORT);
+                       goto f_err;
+                       }
+               param_len += 2;
+
                n2s(p,i);
-               param_len+=i+2;
-               if (param_len > n)
+
+               if (i > n - param_len)
                        {
-                       al=SSL_AD_DECODE_ERROR;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_BAD_DH_G_LENGTH);
                        goto f_err;
                        }
+               param_len += i;
+
                if (!(dh->g=BN_bin2bn(p,i,NULL)))
                        {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_BN_LIB);
@@ -1662,14 +1754,23 @@ int ssl3_get_key_exchange(SSL *s)
                        }
                p+=i;
 
+               if (2 > n - param_len)
+                       {
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
+                               SSL_R_LENGTH_TOO_SHORT);
+                       goto f_err;
+                       }
+               param_len += 2;
+
                n2s(p,i);
-               param_len+=i+2;
-               if (param_len > n)
+
+               if (i > n - param_len)
                        {
-                       al=SSL_AD_DECODE_ERROR;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_BAD_DH_PUB_KEY_LENGTH);
                        goto f_err;
                        }
+               param_len += i;
+
                if (!(dh->pub_key=BN_bin2bn(p,i,NULL)))
                        {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_BN_LIB);
@@ -1721,15 +1822,21 @@ int ssl3_get_key_exchange(SSL *s)
                 */
 
                /* XXX: For now we only support named (not generic) curves
-                * and the ECParameters in this case is just three bytes.
+                * and the ECParameters in this case is just three bytes. We
+                * also need one byte for the length of the encoded point
                 */
-               param_len=3;
-               /* Check curve is one of our prefrences, if not server has
-                * sent an invalid curve.
+               param_len=4;
+               if (param_len > n)
+                       {
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
+                               SSL_R_LENGTH_TOO_SHORT);
+                       goto f_err;
+                       }
+               /* Check curve is one of our preferences, if not server has
+                * sent an invalid curve. ECParameters is 3 bytes.
                 */
-               if (!tls1_check_curve(s, p, param_len))
+               if (!tls1_check_curve(s, p, 3))
                        {
-                       al=SSL_AD_DECODE_ERROR;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_WRONG_CURVE);
                        goto f_err;
                        }
@@ -1776,15 +1883,15 @@ int ssl3_get_key_exchange(SSL *s)
 
                encoded_pt_len = *p;  /* length of encoded point */
                p+=1;
-               param_len += (1 + encoded_pt_len);
-               if ((param_len > n) ||
+
+               if ((encoded_pt_len > n - param_len) ||
                    (EC_POINT_oct2point(group, srvr_ecpoint, 
                        p, encoded_pt_len, bn_ctx) == 0))
                        {
-                       al=SSL_AD_DECODE_ERROR;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_BAD_ECPOINT);
                        goto f_err;
                        }
+               param_len += encoded_pt_len;
 
                n-=param_len;
                p+=encoded_pt_len;
@@ -1827,12 +1934,18 @@ int ssl3_get_key_exchange(SSL *s)
                {
                if (SSL_USE_SIGALGS(s))
                        {
-                       int rv = tls12_check_peer_sigalg(&md, s, p, pkey);
+                       int rv;
+                       if (2 > n)
+                               {
+                               SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
+                                       SSL_R_LENGTH_TOO_SHORT);
+                               goto f_err;
+                               }
+                       rv = tls12_check_peer_sigalg(&md, s, p, pkey);
                        if (rv == -1)
                                goto err;
                        else if (rv == 0)
                                {
-                               al = SSL_AD_DECODE_ERROR;
                                goto f_err;
                                }
 #ifdef SSL_DEBUG
@@ -1843,15 +1956,21 @@ fprintf(stderr, "USING TLSv1.2 HASH %s\n", EVP_MD_name(md));
                        }
                else
                        md = EVP_sha1();
-                       
+
+               if (2 > n)
+                       {
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,
+                               SSL_R_LENGTH_TOO_SHORT);
+                       goto f_err;
+                       }
                n2s(p,i);
                n-=2;
                j=EVP_PKEY_size(pkey);
 
+               /* Check signature length. If n is 0 then signature is empty */
                if ((i != n) || (n > j) || (n <= 0))
                        {
                        /* wrong packet length */
-                       al=SSL_AD_DECODE_ERROR;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_WRONG_SIGNATURE_LENGTH);
                        goto f_err;
                        }
@@ -1860,6 +1979,7 @@ fprintf(stderr, "USING TLSv1.2 HASH %s\n", EVP_MD_name(md));
                if (pkey->type == EVP_PKEY_RSA && !SSL_USE_SIGALGS(s))
                        {
                        int num;
+                       unsigned int size;
 
                        j=0;
                        q=md_buf;
@@ -1872,9 +1992,9 @@ fprintf(stderr, "USING TLSv1.2 HASH %s\n", EVP_MD_name(md));
                                EVP_DigestUpdate(&md_ctx,&(s->s3->client_random[0]),SSL3_RANDOM_SIZE);
                                EVP_DigestUpdate(&md_ctx,&(s->s3->server_random[0]),SSL3_RANDOM_SIZE);
                                EVP_DigestUpdate(&md_ctx,param,param_len);
-                               EVP_DigestFinal_ex(&md_ctx,q,(unsigned int *)&i);
-                               q+=i;
-                               j+=i;
+                               EVP_DigestFinal_ex(&md_ctx,q,&size);
+                               q+=size;
+                               j+=size;
                                }
                        i=RSA_verify(NID_md5_sha1, md_buf, j, p, n,
                                                                pkey->pkey.rsa);
@@ -1910,8 +2030,8 @@ fprintf(stderr, "USING TLSv1.2 HASH %s\n", EVP_MD_name(md));
                }
        else
                {
-               /* aNULL or kPSK do not need public keys */
-               if (!(alg_a & SSL_aNULL) && !(alg_k & SSL_kPSK))
+               /* aNULL, aSRP or kPSK do not need public keys */
+               if (!(alg_a & (SSL_aNULL|SSL_aSRP)) && !(alg_k & SSL_kPSK))
                        {
                        /* Might be wrong key type, check it */
                        if (ssl3_check_cert_and_algorithm(s))
@@ -1922,7 +2042,6 @@ fprintf(stderr, "USING TLSv1.2 HASH %s\n", EVP_MD_name(md));
                /* still data left over */
                if (n != 0)
                        {
-                       al=SSL_AD_DECODE_ERROR;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_EXTRA_DATA_IN_MESSAGE);
                        goto f_err;
                        }
@@ -2362,6 +2481,13 @@ int ssl3_send_client_key_exchange(SSL *s)
                        RSA *rsa;
                        unsigned char tmp_buf[SSL_MAX_MASTER_KEY_LENGTH];
 
+                       if (s->session->sess_cert == NULL)
+                               {
+                               /* We should always have a server certificate with SSL_kRSA. */
+                               SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,ERR_R_INTERNAL_ERROR);
+                               goto err;
+                               }
+
                        if (s->session->sess_cert->peer_rsa_tmp != NULL)
                                rsa=s->session->sess_cert->peer_rsa_tmp;
                        else
@@ -2971,7 +3097,11 @@ int ssl3_send_client_key_exchange(SSL *s)
 #ifndef OPENSSL_NO_PSK
                else if (alg_k & SSL_kPSK)
                        {
-                       char identity[PSK_MAX_IDENTITY_LEN];
+                       /* The callback needs PSK_MAX_IDENTITY_LEN + 1 bytes
+                        * to return a \0-terminated identity. The last byte
+                        * is for us for simulating strnlen. */
+                       char identity[PSK_MAX_IDENTITY_LEN + 2];
+                       size_t identity_len;
                        unsigned char *t = NULL;
                        unsigned char psk_or_pre_ms[PSK_MAX_PSK_LEN*2+4];
                        unsigned int pre_ms_len = 0, psk_len = 0;
@@ -2985,8 +3115,9 @@ int ssl3_send_client_key_exchange(SSL *s)
                                goto err;
                                }
 
+                       memset(identity, 0, sizeof(identity));
                        psk_len = s->psk_client_callback(s, s->ctx->psk_identity_hint,
-                               identity, PSK_MAX_IDENTITY_LEN,
+                               identity, sizeof(identity) - 1,
                                psk_or_pre_ms, sizeof(psk_or_pre_ms));
                        if (psk_len > PSK_MAX_PSK_LEN)
                                {
@@ -3000,7 +3131,14 @@ int ssl3_send_client_key_exchange(SSL *s)
                                        SSL_R_PSK_IDENTITY_NOT_FOUND);
                                goto psk_err;
                                }
-
+                       identity[PSK_MAX_IDENTITY_LEN + 1] = '\0';
+                       identity_len = strlen(identity);
+                       if (identity_len > PSK_MAX_IDENTITY_LEN)
+                               {
+                               SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
+                                       ERR_R_INTERNAL_ERROR);
+                               goto psk_err;
+                               }
                        /* create PSK pre_master_secret */
                        pre_ms_len = 2+psk_len+2+psk_len;
                        t = psk_or_pre_ms;
@@ -3034,14 +3172,13 @@ int ssl3_send_client_key_exchange(SSL *s)
                        s->session->master_key_length =
                                s->method->ssl3_enc->generate_master_secret(s,
                                        s->session->master_key,
-                                       psk_or_pre_ms, pre_ms_len); 
-                       n = strlen(identity);
-                       s2n(n, p);
-                       memcpy(p, identity, n);
-                       n+=2;
+                                       psk_or_pre_ms, pre_ms_len);
+                       s2n(identity_len, p);
+                       memcpy(p, identity, identity_len);
+                       n = 2 + identity_len;
                        psk_err = 0;
                psk_err:
-                       OPENSSL_cleanse(identity, PSK_MAX_IDENTITY_LEN);
+                       OPENSSL_cleanse(identity, sizeof(identity));
                        OPENSSL_cleanse(psk_or_pre_ms, sizeof(psk_or_pre_ms));
                        if (psk_err != 0)
                                {