I'm reversing this change, as it seems the error is somewhere else.
[openssl.git] / ssl / s3_clnt.c
index b27a1deaa77bb9e44b65ea02b8b81184326ecb98..4f67cda042b2b7dff034e5d673a73bc881b91d38 100644 (file)
 #include <openssl/objects.h>
 #include <openssl/evp.h>
 #include <openssl/md5.h>
-#include "cryptlib.h"
+#ifndef OPENSSL_NO_DH
+#include <openssl/dh.h>
+#endif
+#include <openssl/bn.h>
 
 static SSL_METHOD *ssl3_get_client_method(int ver);
-static int ssl3_client_hello(SSL *s);
-static int ssl3_get_server_hello(SSL *s);
-static int ssl3_get_certificate_request(SSL *s);
 static int ca_dn_cmp(const X509_NAME * const *a,const X509_NAME * const *b);
-static int ssl3_get_server_done(SSL *s);
-static int ssl3_send_client_verify(SSL *s);
-static int ssl3_send_client_certificate(SSL *s);
-static int ssl3_send_client_key_exchange(SSL *s);
-static int ssl3_get_key_exchange(SSL *s);
-static int ssl3_get_server_certificate(SSL *s);
-static int ssl3_check_cert_and_algorithm(SSL *s);
 
 #ifndef OPENSSL_NO_ECDH
 static int curve_id2nid(int curve_id);
@@ -538,7 +531,7 @@ end:
        }
 
 
-static int ssl3_client_hello(SSL *s)
+int ssl3_client_hello(SSL *s)
        {
        unsigned char *buf;
        unsigned char *p,*d;
@@ -561,7 +554,8 @@ static int ssl3_client_hello(SSL *s)
                p=s->s3->client_random;
                Time=time(NULL);                        /* Time */
                l2n(Time,p);
-               RAND_pseudo_bytes(p,SSL3_RANDOM_SIZE-sizeof(Time));
+               if (RAND_pseudo_bytes(p,SSL3_RANDOM_SIZE-4) <= 0)
+                       goto err;
 
                /* Do the message type and length last */
                d=p= &(buf[4]);
@@ -582,7 +576,7 @@ static int ssl3_client_hello(SSL *s)
                *(p++)=i;
                if (i != 0)
                        {
-                       if (i > sizeof s->session->session_id)
+                       if (i > (int)sizeof(s->session->session_id))
                                {
                                SSLerr(SSL_F_SSL3_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
                                goto err;
@@ -592,7 +586,7 @@ static int ssl3_client_hello(SSL *s)
                        }
                
                /* Ciphers supported */
-               i=ssl_cipher_list_to_bytes(s,SSL_get_ciphers(s),&(p[2]));
+               i=ssl_cipher_list_to_bytes(s,SSL_get_ciphers(s),&(p[2]),0);
                if (i == 0)
                        {
                        SSLerr(SSL_F_SSL3_CLIENT_HELLO,SSL_R_NO_CIPHERS_AVAILABLE);
@@ -631,7 +625,7 @@ err:
        return(-1);
        }
 
-static int ssl3_get_server_hello(SSL *s)
+int ssl3_get_server_hello(SSL *s)
        {
        STACK_OF(SSL_CIPHER) *sk;
        SSL_CIPHER *c;
@@ -641,14 +635,40 @@ static int ssl3_get_server_hello(SSL *s)
        long n;
        SSL_COMP *comp;
 
-       n=ssl3_get_message(s,
+       n=s->method->ssl_get_message(s,
                SSL3_ST_CR_SRVR_HELLO_A,
                SSL3_ST_CR_SRVR_HELLO_B,
-               SSL3_MT_SERVER_HELLO,
+               -1,
                300, /* ?? */
                &ok);
 
        if (!ok) return((int)n);
+
+       if ( SSL_version(s) == DTLS1_VERSION)
+               {
+               if ( s->s3->tmp.message_type == DTLS1_MT_HELLO_VERIFY_REQUEST)
+                       {
+                       if ( s->d1->send_cookie == 0)
+                               {
+                               s->s3->tmp.reuse_message = 1;
+                               return 1;
+                               }
+                       else /* already sent a cookie */
+                               {
+                               al=SSL_AD_UNEXPECTED_MESSAGE;
+                               SSLerr(SSL_F_SSL3_GET_SERVER_HELLO,SSL_R_BAD_MESSAGE_TYPE);
+                               goto f_err;
+                               }
+                       }
+               }
+       
+       if ( s->s3->tmp.message_type != SSL3_MT_SERVER_HELLO)
+               {
+               al=SSL_AD_UNEXPECTED_MESSAGE;
+               SSLerr(SSL_F_SSL3_GET_SERVER_HELLO,SSL_R_BAD_MESSAGE_TYPE);
+               goto f_err;
+               }
+
        d=p=(unsigned char *)s->init_msg;
 
        if ((p[0] != (s->version>>8)) || (p[1] != (s->version&0xff)))
@@ -776,18 +796,19 @@ err:
        return(-1);
        }
 
-static int ssl3_get_server_certificate(SSL *s)
+int ssl3_get_server_certificate(SSL *s)
        {
        int al,i,ok,ret= -1;
        unsigned long n,nc,llen,l;
        X509 *x=NULL;
-       unsigned char *p,*d,*q;
+       const unsigned char *q,*p;
+       unsigned char *d;
        STACK_OF(X509) *sk=NULL;
        SESS_CERT *sc;
        EVP_PKEY *pkey=NULL;
        int need_cert = 1; /* VRS: 0=> will allow null cert if auth == KRB5 */
 
-       n=ssl3_get_message(s,
+       n=s->method->ssl_get_message(s,
                SSL3_ST_CR_CERT_A,
                SSL3_ST_CR_CERT_B,
                -1,
@@ -808,7 +829,7 @@ static int ssl3_get_server_certificate(SSL *s)
                SSLerr(SSL_F_SSL3_GET_SERVER_CERTIFICATE,SSL_R_BAD_MESSAGE_TYPE);
                goto f_err;
                }
-       d=p=(unsigned char *)s->init_msg;
+       p=d=(unsigned char *)s->init_msg;
 
        if ((sk=sk_X509_new_null()) == NULL)
                {
@@ -959,7 +980,7 @@ err:
        return(ret);
        }
 
-static int ssl3_get_key_exchange(SSL *s)
+int ssl3_get_key_exchange(SSL *s)
        {
 #ifndef OPENSSL_NO_RSA
        unsigned char *q,md_buf[EVP_MAX_MD_SIZE*2];
@@ -985,7 +1006,7 @@ static int ssl3_get_key_exchange(SSL *s)
 
        /* use same message size as in ssl3_get_certificate_request()
         * as ServerKeyExchange message may be skipped */
-       n=ssl3_get_message(s,
+       n=s->method->ssl_get_message(s,
                SSL3_ST_CR_KEY_EXCH_A,
                SSL3_ST_CR_KEY_EXCH_B,
                -1,
@@ -1170,6 +1191,9 @@ static int ssl3_get_key_exchange(SSL *s)
 #ifndef OPENSSL_NO_ECDH
        else if (alg & SSL_kECDHE)
                {
+               EC_GROUP *ngroup;
+               const EC_GROUP *group;
+
                if ((ecdh=EC_KEY_new()) == NULL)
                        {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_MALLOC_FAILURE);
@@ -1195,14 +1219,23 @@ static int ssl3_get_key_exchange(SSL *s)
                        goto f_err;
                        }
 
-               if (!(ecdh->group=EC_GROUP_new_by_nid(curve_nid)))
+               ngroup = EC_GROUP_new_by_curve_name(curve_nid);
+               if (ngroup == NULL)
+                       {
+                       SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_EC_LIB);
+                       goto err;
+                       }
+               if (EC_KEY_set_group(ecdh, ngroup) == 0)
                        {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_EC_LIB);
                        goto err;
                        }
+               EC_GROUP_free(ngroup);
+
+               group = EC_KEY_get0_group(ecdh);
 
                if (SSL_C_IS_EXPORT(s->s3->tmp.new_cipher) &&
-                   (EC_GROUP_get_degree(ecdh->group) > 163))
+                   (EC_GROUP_get_degree(group) > 163))
                        {
                        al=SSL_AD_EXPORT_RESTRICTION;
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,SSL_R_ECGROUP_TOO_LARGE_FOR_CIPHER);
@@ -1212,7 +1245,7 @@ static int ssl3_get_key_exchange(SSL *s)
                p+=2;
 
                /* Next, get the encoded ECPoint */
-               if (((srvr_ecpoint = EC_POINT_new(ecdh->group)) == NULL) ||
+               if (((srvr_ecpoint = EC_POINT_new(group)) == NULL) ||
                    ((bn_ctx = BN_CTX_new()) == NULL))
                        {
                        SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE,ERR_R_MALLOC_FAILURE);
@@ -1223,7 +1256,7 @@ static int ssl3_get_key_exchange(SSL *s)
                p+=1;
                param_len += (1 + encoded_pt_len);
                if ((param_len > n) ||
-                   (EC_POINT_oct2point(ecdh->group, srvr_ecpoint, 
+                   (EC_POINT_oct2point(group, srvr_ecpoint, 
                        p, encoded_pt_len, bn_ctx) == 0))
                        {
                        al=SSL_AD_DECODE_ERROR;
@@ -1248,10 +1281,11 @@ static int ssl3_get_key_exchange(SSL *s)
                        pkey=X509_get_pubkey(s->session->sess_cert->peer_pkeys[SSL_PKEY_ECC].x509);
 #endif
                /* else anonymous ECDH, so no certificate or pkey. */
-               ecdh->pub_key = srvr_ecpoint;
+               EC_KEY_set_public_key(ecdh, srvr_ecpoint);
                s->session->sess_cert->peer_ecdh_tmp=ecdh;
                ecdh=NULL;
                BN_CTX_free(bn_ctx);
+               EC_POINT_free(srvr_ecpoint);
                srvr_ecpoint = NULL;
                }
        else if (alg & SSL_kECDH)
@@ -1403,16 +1437,17 @@ err:
        return(-1);
        }
 
-static int ssl3_get_certificate_request(SSL *s)
+int ssl3_get_certificate_request(SSL *s)
        {
        int ok,ret=0;
        unsigned long n,nc,l;
        unsigned int llen,ctype_num,i;
        X509_NAME *xn=NULL;
-       unsigned char *p,*d,*q;
+       const unsigned char *p,*q;
+       unsigned char *d;
        STACK_OF(X509_NAME) *ca_sk=NULL;
 
-       n=ssl3_get_message(s,
+       n=s->method->ssl_get_message(s,
                SSL3_ST_CR_CERT_REQ_A,
                SSL3_ST_CR_CERT_REQ_B,
                -1,
@@ -1448,7 +1483,7 @@ static int ssl3_get_certificate_request(SSL *s)
                        }
                }
 
-       d=p=(unsigned char *)s->init_msg;
+       p=d=(unsigned char *)s->init_msg;
 
        if ((ca_sk=sk_X509_NAME_new(ca_dn_cmp)) == NULL)
                {
@@ -1550,12 +1585,12 @@ static int ca_dn_cmp(const X509_NAME * const *a, const X509_NAME * const *b)
        return(X509_NAME_cmp(*a,*b));
        }
 
-static int ssl3_get_server_done(SSL *s)
+int ssl3_get_server_done(SSL *s)
        {
        int ok,ret=0;
        long n;
 
-       n=ssl3_get_message(s,
+       n=s->method->ssl_get_message(s,
                SSL3_ST_CR_SRVR_DONE_A,
                SSL3_ST_CR_SRVR_DONE_B,
                SSL3_MT_SERVER_DONE,
@@ -1575,19 +1610,23 @@ static int ssl3_get_server_done(SSL *s)
        }
 
 
+#ifndef OPENSSL_NO_ECDH
 static const int KDF1_SHA1_len = 20;
-static void *KDF1_SHA1(void *in, size_t inlen, void *out, size_t outlen)
+static void *KDF1_SHA1(const void *in, size_t inlen, void *out, size_t *outlen)
        {
 #ifndef OPENSSL_NO_SHA
-       if (outlen != SHA_DIGEST_LENGTH)
+       if (*outlen < SHA_DIGEST_LENGTH)
                return NULL;
+       else
+               *outlen = SHA_DIGEST_LENGTH;
        return SHA1(in, inlen, out);
 #else
        return NULL;
-#endif
+#endif /* OPENSSL_NO_SHA */
        }
+#endif /* OPENSSL_NO_ECDH */
 
-static int ssl3_send_client_key_exchange(SSL *s)
+int ssl3_send_client_key_exchange(SSL *s)
        {
        unsigned char *p,*d;
        int n;
@@ -1601,7 +1640,7 @@ static int ssl3_send_client_key_exchange(SSL *s)
 #endif /* OPENSSL_NO_KRB5 */
 #ifndef OPENSSL_NO_ECDH
        EC_KEY *clnt_ecdh = NULL;
-       EC_POINT *srvr_ecpoint = NULL;
+       const EC_POINT *srvr_ecpoint = NULL;
        EVP_PKEY *srvr_pub_pkey = NULL;
        unsigned char *encodedPoint = NULL;
        int encoded_pt_len = 0;
@@ -1868,7 +1907,8 @@ static int ssl3_send_client_key_exchange(SSL *s)
 #ifndef OPENSSL_NO_ECDH 
                else if ((l & SSL_kECDH) || (l & SSL_kECDHE))
                        {
-                       EC_GROUP *srvr_group = NULL;
+                       const EC_GROUP *srvr_group = NULL;
+                       EC_KEY *tkey;
                        int ecdh_clnt_cert = 0;
                        int field_size = 0;
 
@@ -1902,10 +1942,7 @@ static int ssl3_send_client_key_exchange(SSL *s)
 
                        if (s->session->sess_cert->peer_ecdh_tmp != NULL)
                                {
-                               srvr_group = s->session->sess_cert-> \
-                                   peer_ecdh_tmp->group;
-                               srvr_ecpoint = s->session->sess_cert-> \
-                                   peer_ecdh_tmp->pub_key;
+                               tkey = s->session->sess_cert->peer_ecdh_tmp;
                                }
                        else
                                {
@@ -1914,18 +1951,19 @@ static int ssl3_send_client_key_exchange(SSL *s)
                                    sess_cert->peer_pkeys[SSL_PKEY_ECC].x509);
                                if ((srvr_pub_pkey == NULL) ||
                                    (srvr_pub_pkey->type != EVP_PKEY_EC) ||
-                                   (srvr_pub_pkey->pkey.eckey == NULL))
+                                   (srvr_pub_pkey->pkey.ec == NULL))
                                        {
                                        SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
                                            ERR_R_INTERNAL_ERROR);
                                        goto err;
                                        }
 
-                               srvr_group = srvr_pub_pkey->pkey.eckey->group;
-                               srvr_ecpoint = 
-                                   srvr_pub_pkey->pkey.eckey->pub_key;
+                               tkey = srvr_pub_pkey->pkey.ec;
                                }
 
+                       srvr_group   = EC_KEY_get0_group(tkey);
+                       srvr_ecpoint = EC_KEY_get0_public_key(tkey);
+
                        if ((srvr_group == NULL) || (srvr_ecpoint == NULL))
                                {
                                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
@@ -1939,15 +1977,30 @@ static int ssl3_send_client_key_exchange(SSL *s)
                                goto err;
                                }
 
-                       clnt_ecdh->group = srvr_group;
+                       if (!EC_KEY_set_group(clnt_ecdh, srvr_group))
+                               {
+                               SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,ERR_R_EC_LIB);
+                               goto err;
+                               }
                        if (ecdh_clnt_cert) 
                                { 
                                /* Reuse key info from our certificate
                                 * We only need our private key to perform
                                 * the ECDH computation.
                                 */
-                               clnt_ecdh->priv_key = BN_dup(s->cert->key-> \
-                                   privatekey->pkey.eckey->priv_key);
+                               const BIGNUM *priv_key;
+                               tkey = s->cert->key->privatekey->pkey.ec;
+                               priv_key = EC_KEY_get0_private_key(tkey);
+                               if (priv_key == NULL)
+                                       {
+                                       SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,ERR_R_MALLOC_FAILURE);
+                                       goto err;
+                                       }
+                               if (!EC_KEY_set_private_key(clnt_ecdh, priv_key))
+                                       {
+                                       SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,ERR_R_EC_LIB);
+                                       goto err;
+                                       }
                                }
                        else 
                                {
@@ -1963,7 +2016,7 @@ static int ssl3_send_client_key_exchange(SSL *s)
                         * make sure to clear it out afterwards
                         */
 
-                       field_size = EC_GROUP_get_degree(clnt_ecdh->group);
+                       field_size = EC_GROUP_get_degree(srvr_group);
                        if (field_size <= 0)
                                {
                                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, 
@@ -2004,8 +2057,8 @@ static int ssl3_send_client_key_exchange(SSL *s)
                                 * allocate memory accordingly.
                                 */
                                encoded_pt_len = 
-                                   EC_POINT_point2oct(clnt_ecdh->group, 
-                                       clnt_ecdh->pub_key
+                                   EC_POINT_point2oct(srvr_group, 
+                                       EC_KEY_get0_public_key(clnt_ecdh)
                                        POINT_CONVERSION_UNCOMPRESSED, 
                                        NULL, 0, NULL);
 
@@ -2021,8 +2074,8 @@ static int ssl3_send_client_key_exchange(SSL *s)
                                        }
 
                                /* Encode the public key */
-                               n = EC_POINT_point2oct(clnt_ecdh->group, 
-                                   clnt_ecdh->pub_key
+                               n = EC_POINT_point2oct(srvr_group, 
+                                   EC_KEY_get0_public_key(clnt_ecdh)
                                    POINT_CONVERSION_UNCOMPRESSED, 
                                    encodedPoint, encoded_pt_len, bn_ctx);
 
@@ -2039,11 +2092,7 @@ static int ssl3_send_client_key_exchange(SSL *s)
                        BN_CTX_free(bn_ctx);
                        if (encodedPoint != NULL) OPENSSL_free(encodedPoint);
                        if (clnt_ecdh != NULL) 
-                               {
-                                /* group is shared */
-                                clnt_ecdh->group = NULL; 
                                 EC_KEY_free(clnt_ecdh);
-                               }
                        EVP_PKEY_free(srvr_pub_pkey);
                        }
 #endif /* !OPENSSL_NO_ECDH */
@@ -2072,17 +2121,13 @@ err:
        BN_CTX_free(bn_ctx);
        if (encodedPoint != NULL) OPENSSL_free(encodedPoint);
        if (clnt_ecdh != NULL) 
-               {
-               /* group is shared */
-               clnt_ecdh->group = NULL; 
                EC_KEY_free(clnt_ecdh);
-               }
        EVP_PKEY_free(srvr_pub_pkey);
 #endif
        return(-1);
        }
 
-static int ssl3_send_client_verify(SSL *s)
+int ssl3_send_client_verify(SSL *s)
        {
        unsigned char *p,*d;
        unsigned char data[MD5_DIGEST_LENGTH+SHA_DIGEST_LENGTH];
@@ -2091,7 +2136,7 @@ static int ssl3_send_client_verify(SSL *s)
        unsigned u=0;
 #endif
        unsigned long n;
-#ifndef OPENSSL_NO_DSA
+#if !defined(OPENSSL_NO_DSA) || !defined(OPENSSL_NO_ECDSA)
        int j;
 #endif
 
@@ -2143,7 +2188,7 @@ static int ssl3_send_client_verify(SSL *s)
                        if (!ECDSA_sign(pkey->save_type,
                                &(data[MD5_DIGEST_LENGTH]),
                                SHA_DIGEST_LENGTH,&(p[2]),
-                               (unsigned int *)&j,pkey->pkey.eckey))
+                               (unsigned int *)&j,pkey->pkey.ec))
                                {
                                SSLerr(SSL_F_SSL3_SEND_CLIENT_VERIFY,
                                    ERR_R_ECDSA_LIB);
@@ -2170,7 +2215,7 @@ err:
        return(-1);
        }
 
-static int ssl3_send_client_certificate(SSL *s)
+int ssl3_send_client_certificate(SSL *s)
        {
        X509 *x509=NULL;
        EVP_PKEY *pkey=NULL;
@@ -2249,7 +2294,7 @@ static int ssl3_send_client_certificate(SSL *s)
 
 #define has_bits(i,m)  (((i)&(m)) == (m))
 
-static int ssl3_check_cert_and_algorithm(SSL *s)
+int ssl3_check_cert_and_algorithm(SSL *s)
        {
        int i,idx;
        long algs;
@@ -2354,7 +2399,7 @@ static int ssl3_check_cert_and_algorithm(SSL *s)
                if (algs & SSL_kRSA)
                        {
                        if (rsa == NULL
-                           || RSA_size(rsa) > SSL_C_EXPORT_PKEYLENGTH(s->s3->tmp.new_cipher))
+                           || RSA_size(rsa)*8 > SSL_C_EXPORT_PKEYLENGTH(s->s3->tmp.new_cipher))
                                {
                                SSLerr(SSL_F_SSL3_CHECK_CERT_AND_ALGORITHM,SSL_R_MISSING_EXPORT_TMP_RSA_KEY);
                                goto f_err;
@@ -2366,7 +2411,7 @@ static int ssl3_check_cert_and_algorithm(SSL *s)
                        if (algs & (SSL_kEDH|SSL_kDHr|SSL_kDHd))
                            {
                            if (dh == NULL
-                               || DH_size(dh) > SSL_C_EXPORT_PKEYLENGTH(s->s3->tmp.new_cipher))
+                               || DH_size(dh)*8 > SSL_C_EXPORT_PKEYLENGTH(s->s3->tmp.new_cipher))
                                {
                                SSLerr(SSL_F_SSL3_CHECK_CERT_AND_ALGORITHM,SSL_R_MISSING_EXPORT_TMP_DH_KEY);
                                goto f_err;