PR: 1731 and maybe 2197
[openssl.git] / ssl / d1_both.c
index d11d6d5888281165e4b5b12ed1223a9dab891347..0242f1e4da00f510dae5871e2e55724e62b2aad2 100644 (file)
@@ -177,7 +177,7 @@ int dtls1_do_write(SSL *s, int type)
        {
        int ret;
        int curr_mtu;
-       unsigned int len, frag_off;
+       unsigned int len, frag_off, mac_size, blocksize;
 
        /* AHA!  Figure out the MTU, and stick to the right size */
        if ( ! (SSL_get_options(s) & SSL_OP_NO_QUERY_MTU))
@@ -225,11 +225,22 @@ int dtls1_do_write(SSL *s, int type)
                OPENSSL_assert(s->init_num == 
                        (int)s->d1->w_msg_hdr.msg_len + DTLS1_HM_HEADER_LENGTH);
 
+       if (s->write_hash)
+               mac_size = EVP_MD_CTX_size(s->write_hash);
+       else
+               mac_size = 0;
+
+       if (s->enc_write_ctx && 
+               (EVP_CIPHER_mode( s->enc_write_ctx->cipher) & EVP_CIPH_CBC_MODE))
+               blocksize = 2 * EVP_CIPHER_block_size(s->enc_write_ctx->cipher);
+       else
+               blocksize = 0;
+
        frag_off = 0;
        while( s->init_num)
                {
                curr_mtu = s->d1->mtu - BIO_wpending(SSL_get_wbio(s)) - 
-                       DTLS1_RT_HEADER_LENGTH;
+                       DTLS1_RT_HEADER_LENGTH - mac_size - blocksize;
 
                if ( curr_mtu <= DTLS1_HM_HEADER_LENGTH)
                        {
@@ -237,7 +248,8 @@ int dtls1_do_write(SSL *s, int type)
                        ret = BIO_flush(SSL_get_wbio(s));
                        if ( ret <= 0)
                                return ret;
-                       curr_mtu = s->d1->mtu - DTLS1_RT_HEADER_LENGTH;
+                       curr_mtu = s->d1->mtu - DTLS1_RT_HEADER_LENGTH -
+                               mac_size - blocksize;
                        }
 
                if ( s->init_num > curr_mtu)
@@ -279,7 +291,7 @@ int dtls1_do_write(SSL *s, int type)
                         * retransmit 
                         */
                        if ( BIO_ctrl(SSL_get_wbio(s),
-                               BIO_CTRL_DGRAM_MTU_EXCEEDED, 0, NULL))
+                               BIO_CTRL_DGRAM_MTU_EXCEEDED, 0, NULL) > 0 )
                                s->d1->mtu = BIO_ctrl(SSL_get_wbio(s),
                                        BIO_CTRL_DGRAM_QUERY_MTU, 0, NULL);
                        else
@@ -752,6 +764,24 @@ int dtls1_send_finished(SSL *s, int a, int b, const char *sender, int slen)
                p+=i;
                l=i;
 
+       /* Copy the finished so we can use it for
+        * renegotiation checks
+        */
+       if(s->type == SSL_ST_CONNECT)
+               {
+               OPENSSL_assert(i <= EVP_MAX_MD_SIZE);
+               memcpy(s->s3->previous_client_finished, 
+                      s->s3->tmp.finish_md, i);
+               s->s3->previous_client_finished_len=i;
+               }
+       else
+               {
+               OPENSSL_assert(i <= EVP_MAX_MD_SIZE);
+               memcpy(s->s3->previous_server_finished, 
+                      s->s3->tmp.finish_md, i);
+               s->s3->previous_server_finished_len=i;
+               }
+
 #ifdef OPENSSL_SYS_WIN16
                /* MSVC 1.5 does not clear the top bytes of the word unless
                 * I do this.
@@ -856,6 +886,8 @@ unsigned long dtls1_output_cert_chain(SSL *s, X509 *x)
                        }
   
                X509_verify_cert(&xs_ctx);
+               /* Don't leave errors in the queue */
+               ERR_clear_error();
                for (i=0; i < sk_X509_num(xs_ctx.chain); i++)
                        {
                        x = sk_X509_value(xs_ctx.chain, i);
@@ -890,9 +922,6 @@ unsigned long dtls1_output_cert_chain(SSL *s, X509 *x)
 
 int dtls1_read_failed(SSL *s, int code)
        {
-       DTLS1_STATE *state;
-       int send_alert = 0;
-
        if ( code > 0)
                {
                fprintf( stderr, "invalid state reached %s:%d", __FILE__, __LINE__);
@@ -912,24 +941,6 @@ int dtls1_read_failed(SSL *s, int code)
                return code;
                }
 
-       dtls1_double_timeout(s);
-       state = s->d1;
-       state->timeout.num_alerts++;
-       if ( state->timeout.num_alerts > DTLS1_TMO_ALERT_COUNT)
-               {
-               /* fail the connection, enough alerts have been sent */
-               SSLerr(SSL_F_DTLS1_READ_FAILED,SSL_R_READ_TIMEOUT_EXPIRED);
-               return 0;
-               }
-
-       state->timeout.read_timeouts++;
-       if ( state->timeout.read_timeouts > DTLS1_TMO_READ_COUNT)
-               {
-               send_alert = 1;
-               state->timeout.read_timeouts = 1;
-               }
-
-
 #if 0 /* for now, each alert contains only one record number */
        item = pqueue_peek(state->rcvd_records);
        if ( item )
@@ -940,12 +951,12 @@ int dtls1_read_failed(SSL *s, int code)
 #endif
 
 #if 0  /* no more alert sending, just retransmit the last set of messages */
-               if ( send_alert)
-                       ssl3_send_alert(s,SSL3_AL_WARNING,
-                               DTLS1_AD_MISSING_HANDSHAKE_MESSAGE);
+       if ( state->timeout.read_timeouts >= DTLS1_TMO_READ_COUNT)
+               ssl3_send_alert(s,SSL3_AL_WARNING,
+                       DTLS1_AD_MISSING_HANDSHAKE_MESSAGE);
 #endif
 
-       return dtls1_retransmit_buffered_messages(s) ;
+       return dtls1_handle_timeout(s);
        }
 
 int