check EC tmp key matches preferences
[openssl.git] / ssl / s3_clnt.c
index c51f3d0b0f48f9fad52f0eb71a248deeb0c68073..11ffabb460c335791c00483defad9326c4479717 100644 (file)
@@ -837,6 +837,7 @@ int ssl3_get_server_hello(SSL *s)
        {
        STACK_OF(SSL_CIPHER) *sk;
        const SSL_CIPHER *c;
+       CERT *ct = s->cert;
        unsigned char *p,*d;
        int i,al=SSL_AD_INTERNAL_ERROR,ok;
        unsigned int j;
@@ -959,9 +960,12 @@ int ssl3_get_server_hello(SSL *s)
                SSLerr(SSL_F_SSL3_GET_SERVER_HELLO,SSL_R_UNKNOWN_CIPHER_RETURNED);
                goto f_err;
                }
-       /* TLS v1.2 only ciphersuites require v1.2 or later */
-       if ((c->algorithm_ssl & SSL_TLSV1_2) && 
-               (TLS1_get_version(s) < TLS1_2_VERSION))
+       /* If it is a disabled cipher we didn't send it in client hello,
+        * so return an error.
+        */
+       if (c->algorithm_ssl & ct->mask_ssl ||
+               c->algorithm_mkey & ct->mask_k ||
+               c->algorithm_auth & ct->mask_a)
                {
                al=SSL_AD_ILLEGAL_PARAMETER;
                SSLerr(SSL_F_SSL3_GET_SERVER_HELLO,SSL_R_WRONG_CIPHER_RETURNED);
@@ -1643,9 +1647,17 @@ int ssl3_get_key_exchange(SSL *s)
                 * and the ECParameters in this case is just three bytes.
                 */
                param_len=3;
-               if ((param_len > n) ||
-                   (*p != NAMED_CURVE_TYPE) || 
-                   ((curve_nid = tls1_ec_curve_id2nid(*(p + 2))) == 0)) 
+               /* Check curve is one of our prefrences, if not server has
+                * sent an invalid curve.
+                */
+               if (!tls1_check_curve(s, p, param_len))
+                       {
+                       al=SSL_AD_DECODE_ERROR;
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_WRONG_CURVE);
+                       goto f_err;
+                       }
+
+               if ((curve_nid = tls1_ec_curve_id2nid(*(p + 2))) == 0) 
                        {
                        al=SSL_AD_INTERNAL_ERROR;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_UNABLE_TO_FIND_ECDH_PARAMETERS);
@@ -1936,11 +1948,22 @@ int ssl3_get_certificate_request(SSL *s)
 
        /* get the certificate types */
        ctype_num= *(p++);
+       if (s->cert->ctypes)
+               {
+               OPENSSL_free(s->cert->ctypes);
+               s->cert->ctypes = NULL;
+               }
        if (ctype_num > SSL3_CT_NUMBER)
+               {
+               /* If we exceed static buffer copy all to cert structure */
+               s->cert->ctypes = OPENSSL_malloc(ctype_num);
+               memcpy(s->cert->ctypes, p, ctype_num);
+               s->cert->ctype_num = (size_t)ctype_num;
                ctype_num=SSL3_CT_NUMBER;
+               }
        for (i=0; i<ctype_num; i++)
                s->s3->tmp.ctype[i]= p[i];
-       p+=ctype_num;
+       p+=p[-1];
        if (TLS1_get_version(s) >= TLS1_2_VERSION)
                {
                n2s(p, llen);
@@ -3180,6 +3203,13 @@ int ssl3_send_client_certificate(SSL *s)
 
        if (s->state == SSL3_ST_CW_CERT_A)
                {
+               /* Let cert callback update client certificates if required */
+               if (s->cert->cert_cb
+                       && s->cert->cert_cb(s, s->cert->cert_cb_arg) <= 0)
+                       {
+                       ssl3_send_alert(s,SSL3_AL_FATAL,SSL_AD_INTERNAL_ERROR);
+                       return 0;
+                       }
                if (ssl3_check_client_certificate(s))
                        s->state=SSL3_ST_CW_CERT_C;
                else