Suite B support for DTLS 1.2
[openssl.git] / ssl / s3_clnt.c
index 81e45a758ef1326933e2337f399114e66879fb4b..018a9f590ce87eab3c24c722f0fb5f83609948ef 100644 (file)
@@ -694,16 +694,73 @@ int ssl3_client_hello(SSL *s)
                        if (!ssl_get_new_session(s,0))
                                goto err;
                        }
+               if (s->method->version == DTLS_ANY_VERSION)
+                       {
+                       /* Determine which DTLS version to use */
+                       int options = s->options;
+                       /* If DTLS 1.2 disabled correct the version number */
+                       if (options & SSL_OP_NO_DTLSv1_2)
+                               {
+                               if (tls1_suiteb(s))
+                                       {
+                                       SSLerr(SSL_F_SSL3_CLIENT_HELLO, SSL_R_ONLY_DTLS_1_2_ALLOWED_IN_SUITEB_MODE);
+                                       goto err;
+                                       }
+                               /* Disabling all versions is silly: return an
+                                * error.
+                                */
+                               if (options & SSL_OP_NO_DTLSv1)
+                                       {
+                                       SSLerr(SSL_F_SSL3_CLIENT_HELLO,SSL_R_WRONG_SSL_VERSION);
+                                       goto err;
+                                       }
+                               /* Update method so we don't use any DTLS 1.2
+                                * features.
+                                */
+                               s->method = DTLSv1_client_method();
+                               s->version = DTLS1_VERSION;
+                               }
+                       else
+                               {
+                               /* We only support one version: update method */
+                               if (options & SSL_OP_NO_DTLSv1)
+                                       s->method = DTLSv1_2_client_method();
+                               s->version = DTLS1_2_VERSION;
+                               }
+                       s->client_version = s->version;
+                       }
                /* else use the pre-loaded session */
 
                p=s->s3->client_random;
-               Time=(unsigned long)time(NULL);                 /* Time */
-               l2n(Time,p);
-               if (RAND_pseudo_bytes(p,SSL3_RANDOM_SIZE-4) <= 0)
-                       goto err;
+
+               /* for DTLS if client_random is initialized, reuse it, we are
+                * required to use same upon reply to HelloVerify */
+               if (SSL_IS_DTLS(s))
+                       {
+                       size_t idx;
+                       i = 1;
+                       for (idx=0; idx < sizeof(s->s3->client_random); idx++)
+                               {
+                               if (p[idx])
+                                       {
+                                       i = 0;
+                                       break;
+                                       }
+                               }
+                       }
+               else 
+                       i = 1;
+
+               if (i)
+                       {
+                       Time=(unsigned long)time(NULL); /* Time */
+                       l2n(Time,p);
+                       RAND_pseudo_bytes(p,sizeof(s->s3->client_random)-4);
+                                       
+                       }
 
                /* Do the message type and length last */
-               d=p= &(buf[4]);
+               d=p= ssl_handshake_start(s);
 
                /* version indicates the negotiated version: for example from
                 * an SSLv2/v3 compatible client hello). The client_version
@@ -764,6 +821,19 @@ int ssl3_client_hello(SSL *s)
                        p+=i;
                        }
                
+               /* cookie stuff for DTLS */
+               if (SSL_IS_DTLS(s))
+                       {
+                       if ( s->d1->cookie_len > sizeof(s->d1->cookie))
+                               {
+                               SSLerr(SSL_F_SSL3_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
+                               goto err;
+                               }
+                       *(p++) = s->d1->cookie_len;
+                       memcpy(p, s->d1->cookie, s->d1->cookie_len);
+                       p += s->d1->cookie_len;
+                       }
+               
                /* Ciphers supported */
                i=ssl_cipher_list_to_bytes(s,SSL_get_ciphers(s),&(p[2]),0);
                if (i == 0)
@@ -816,19 +886,13 @@ int ssl3_client_hello(SSL *s)
                        }
 #endif
                
-               l=(p-d);
-               d=buf;
-               *(d++)=SSL3_MT_CLIENT_HELLO;
-               l2n3(l,d);
-
+               l= p-d;
+               ssl_set_handshake_header(s, SSL3_MT_CLIENT_HELLO, l);
                s->state=SSL3_ST_CW_CLNT_HELLO_B;
-               /* number of bytes to write */
-               s->init_num=p-buf;
-               s->init_off=0;
                }
 
        /* SSL3_ST_CW_CLNT_HELLO_B */
-       return(ssl3_do_write(s,SSL3_RT_HANDSHAKE));
+       return ssl_do_write(s);
 err:
        return(-1);
        }
@@ -845,6 +909,11 @@ int ssl3_get_server_hello(SSL *s)
 #ifndef OPENSSL_NO_COMP
        SSL_COMP *comp;
 #endif
+       /* Hello verify request and/or server hello version may not
+        * match so set first packet if we're negotiating version.
+        */
+       if (s->method->version == DTLS_ANY_VERSION)
+               s->first_packet = 1;
 
        n=s->method->ssl_get_message(s,
                SSL3_ST_CR_SRVR_HELLO_A,
@@ -855,8 +924,9 @@ int ssl3_get_server_hello(SSL *s)
 
        if (!ok) return((int)n);
 
-       if ( SSL_version(s) == DTLS1_VERSION || SSL_version(s) == DTLS1_BAD_VER)
+       if (SSL_IS_DTLS(s))
                {
+               s->first_packet = 0;
                if ( s->s3->tmp.message_type == DTLS1_MT_HELLO_VERIFY_REQUEST)
                        {
                        if ( s->d1->send_cookie == 0)
@@ -881,6 +951,33 @@ int ssl3_get_server_hello(SSL *s)
                }
 
        d=p=(unsigned char *)s->init_msg;
+       if (s->method->version == DTLS_ANY_VERSION)
+               {
+               /* Work out correct protocol version to use */
+               int hversion = (p[0] << 8)|p[1];
+               int options = s->options;
+               if (hversion == DTLS1_2_VERSION
+                       && !(options & SSL_OP_NO_DTLSv1_2))
+                       s->method = DTLSv1_2_client_method();
+               else if (tls1_suiteb(s))
+                       {
+                       SSLerr(SSL_F_SSL3_GET_SERVER_HELLO, SSL_R_ONLY_DTLS_1_2_ALLOWED_IN_SUITEB_MODE);
+                       s->version = hversion;
+                       al = SSL_AD_PROTOCOL_VERSION;
+                       goto f_err;
+                       }
+               else if (hversion == DTLS1_VERSION
+                       && !(options & SSL_OP_NO_DTLSv1))
+                       s->method = DTLSv1_client_method();
+               else
+                       {
+                       SSLerr(SSL_F_SSL3_GET_SERVER_HELLO,SSL_R_WRONG_SSL_VERSION);
+                       s->version = hversion;
+                       al = SSL_AD_PROTOCOL_VERSION;
+                       goto f_err;
+                       }
+               s->version = s->client_version = s->method->version;
+               }
 
        if ((p[0] != (s->version>>8)) || (p[1] != (s->version&0xff)))
                {
@@ -1002,10 +1099,10 @@ int ssl3_get_server_hello(SSL *s)
                        }
                }
        s->s3->tmp.new_cipher=c;
-       /* Don't digest cached records if TLS v1.2: we may need them for
+       /* Don't digest cached records if no sigalgs: we may need them for
         * client authentication.
         */
-       if (TLS1_get_version(s) < TLS1_2_VERSION && !ssl3_digest_cached_records(s))
+       if (!SSL_USE_SIGALGS(s) && !ssl3_digest_cached_records(s))
                goto f_err;
        /* lets get the compression algorithm */
        /* COMPRESSION */
@@ -1225,6 +1322,15 @@ int ssl3_get_server_certificate(SSL *s)
 
        if (need_cert)
                {
+               int exp_idx = ssl_cipher_get_cert_index(s->s3->tmp.new_cipher);
+               if (exp_idx >= 0 && i != exp_idx)
+                       {
+                       x=NULL;
+                       al=SSL_AD_ILLEGAL_PARAMETER;
+                       SSLerr(SSL_F_SSL3_GET_SERVER_CERTIFICATE,
+                               SSL_R_WRONG_CERTIFICATE_TYPE);
+                       goto f_err;
+                       }
                sc->peer_cert_type=i;
                CRYPTO_add(&x->references,1,CRYPTO_LOCK_X509);
                /* Why would the following ever happen?
@@ -1748,7 +1854,7 @@ int ssl3_get_key_exchange(SSL *s)
        /* if it was signed, check the signature */
        if (pkey != NULL)
                {
-               if (TLS1_get_version(s) >= TLS1_2_VERSION)
+               if (SSL_USE_SIGALGS(s))
                        {
                        int rv = tls12_check_peer_sigalg(&md, s, p, pkey);
                        if (rv == -1)
@@ -1780,7 +1886,7 @@ fprintf(stderr, "USING TLSv1.2 HASH %s\n", EVP_MD_name(md));
                        }
 
 #ifndef OPENSSL_NO_RSA
-               if (pkey->type == EVP_PKEY_RSA && TLS1_get_version(s) < TLS1_2_VERSION)
+               if (pkey->type == EVP_PKEY_RSA && !SSL_USE_SIGALGS(s))
                        {
                        int num;
 
@@ -1954,7 +2060,7 @@ int ssl3_get_certificate_request(SSL *s)
        for (i=0; i<ctype_num; i++)
                s->s3->tmp.ctype[i]= p[i];
        p+=p[-1];
-       if (TLS1_get_version(s) >= TLS1_2_VERSION)
+       if (SSL_USE_SIGALGS(s))
                {
                n2s(p, llen);
                /* Check we have enough room for signature algorithms and
@@ -2252,7 +2358,7 @@ int ssl3_get_server_done(SSL *s)
 
 int ssl3_send_client_key_exchange(SSL *s)
        {
-       unsigned char *p,*d;
+       unsigned char *p;
        int n;
        unsigned long alg_k;
 #ifndef OPENSSL_NO_RSA
@@ -2273,8 +2379,7 @@ int ssl3_send_client_key_exchange(SSL *s)
 
        if (s->state == SSL3_ST_CW_KEY_EXCH_A)
                {
-               d=(unsigned char *)s->init_buf->data;
-               p= &(d[4]);
+               p = ssl_handshake_start(s);
 
                alg_k=s->s3->tmp.new_cipher->algorithm_mkey;
 
@@ -2975,18 +3080,13 @@ int ssl3_send_client_key_exchange(SSL *s)
                            ERR_R_INTERNAL_ERROR);
                        goto err;
                        }
-               
-               *(d++)=SSL3_MT_CLIENT_KEY_EXCHANGE;
-               l2n3(n,d);
 
+               ssl_set_handshake_header(s, SSL3_MT_CLIENT_KEY_EXCHANGE, n);
                s->state=SSL3_ST_CW_KEY_EXCH_B;
-               /* number of bytes to write */
-               s->init_num=n+4;
-               s->init_off=0;
                }
 
        /* SSL3_ST_CW_KEY_EXCH_B */
-       return(ssl3_do_write(s,SSL3_RT_HANDSHAKE));
+       return ssl_do_write(s);
 err:
 #ifndef OPENSSL_NO_ECDH
        BN_CTX_free(bn_ctx);
@@ -3000,7 +3100,7 @@ err:
 
 int ssl3_send_client_verify(SSL *s)
        {
-       unsigned char *p,*d;
+       unsigned char *p;
        unsigned char data[MD5_DIGEST_LENGTH+SHA_DIGEST_LENGTH];
        EVP_PKEY *pkey;
        EVP_PKEY_CTX *pctx=NULL;
@@ -3013,15 +3113,14 @@ int ssl3_send_client_verify(SSL *s)
 
        if (s->state == SSL3_ST_CW_CERT_VRFY_A)
                {
-               d=(unsigned char *)s->init_buf->data;
-               p= &(d[4]);
+               p= ssl_handshake_start(s);
                pkey=s->cert->key->privatekey;
 /* Create context from key and test if sha1 is allowed as digest */
                pctx = EVP_PKEY_CTX_new(pkey,NULL);
                EVP_PKEY_sign_init(pctx);
                if (EVP_PKEY_CTX_set_signature_md(pctx, EVP_sha1())>0)
                        {
-                       if (TLS1_get_version(s) < TLS1_2_VERSION)
+                       if (!SSL_USE_SIGALGS(s))
                                s->method->ssl3_enc->cert_verify_mac(s,
                                                NID_sha1,
                                                &(data[MD5_DIGEST_LENGTH]));
@@ -3033,7 +3132,7 @@ int ssl3_send_client_verify(SSL *s)
                /* For TLS v1.2 send signature algorithm and signature
                 * using agreed digest and cached handshake records.
                 */
-               if (TLS1_get_version(s) >= TLS1_2_VERSION)
+               if (SSL_USE_SIGALGS(s))
                        {
                        long hdatalen = 0;
                        void *hdata;
@@ -3140,16 +3239,12 @@ int ssl3_send_client_verify(SSL *s)
                        SSLerr(SSL_F_SSL3_SEND_CLIENT_VERIFY,ERR_R_INTERNAL_ERROR);
                        goto err;
                }
-               *(d++)=SSL3_MT_CERTIFICATE_VERIFY;
-               l2n3(n,d);
-
+               ssl_set_handshake_header(s, SSL3_MT_CERTIFICATE_VERIFY, n);
                s->state=SSL3_ST_CW_CERT_VRFY_B;
-               s->init_num=(int)n+4;
-               s->init_off=0;
                }
        EVP_MD_CTX_cleanup(&mctx);
        EVP_PKEY_CTX_free(pctx);
-       return(ssl3_do_write(s,SSL3_RT_HANDSHAKE));
+       return ssl_do_write(s);
 err:
        EVP_MD_CTX_cleanup(&mctx);
        EVP_PKEY_CTX_free(pctx);
@@ -3167,7 +3262,7 @@ static int ssl3_check_client_certificate(SSL *s)
        if (!s->cert || !s->cert->key->x509 || !s->cert->key->privatekey)
                return 0;
        /* If no suitable signature algorithm can't use certificate */
-       if (TLS1_get_version(s) >= TLS1_2_VERSION && !s->cert->key->digest)
+       if (SSL_USE_SIGALGS(s) && !s->cert->key->digest)
                return 0;
        /* If strict mode check suitability of chain before using it.
         * This also adjusts suite B digest if necessary.
@@ -3206,7 +3301,6 @@ int ssl3_send_client_certificate(SSL *s)
        X509 *x509=NULL;
        EVP_PKEY *pkey=NULL;
        int i;
-       unsigned long l;
 
        if (s->state == SSL3_ST_CW_CERT_A)
                {
@@ -3275,13 +3369,11 @@ int ssl3_send_client_certificate(SSL *s)
        if (s->state == SSL3_ST_CW_CERT_C)
                {
                s->state=SSL3_ST_CW_CERT_D;
-               l=ssl3_output_cert_chain(s,
+               ssl3_output_cert_chain(s,
                        (s->s3->tmp.cert_req == 2)?NULL:s->cert->key);
-               s->init_num=(int)l;
-               s->init_off=0;
                }
        /* SSL3_ST_CW_CERT_D */
-       return(ssl3_do_write(s,SSL3_RT_HANDSHAKE));
+       return ssl_do_write(s);
        }
 
 #define has_bits(i,m)  (((i)&(m)) == (m))
@@ -3381,14 +3473,14 @@ int ssl3_check_cert_and_algorithm(SSL *s)
                SSLerr(SSL_F_SSL3_CHECK_CERT_AND_ALGORITHM,SSL_R_MISSING_DH_KEY);
                goto f_err;
                }
-       else if ((alg_k & SSL_kDHr) && (TLS1_get_version(s) < TLS1_2_VERSION) &&
+       else if ((alg_k & SSL_kDHr) && !SSL_USE_SIGALGS(s) &&
                !has_bits(i,EVP_PK_DH|EVP_PKS_RSA))
                {
                SSLerr(SSL_F_SSL3_CHECK_CERT_AND_ALGORITHM,SSL_R_MISSING_DH_RSA_CERT);
                goto f_err;
                }
 #ifndef OPENSSL_NO_DSA
-       else if ((alg_k & SSL_kDHd) && (TLS1_get_version(s) < TLS1_2_VERSION) &&
+       else if ((alg_k & SSL_kDHd) && !SSL_USE_SIGALGS(s) &&
                !has_bits(i,EVP_PK_DH|EVP_PKS_DSA))
                {
                SSLerr(SSL_F_SSL3_CHECK_CERT_AND_ALGORITHM,SSL_R_MISSING_DH_DSA_CERT);