The first call to query the mtu in dtls1_do_write correctly checks that the
[openssl.git] / ssl / d1_both.c
index fb524dafa09cf80df538cce6260a4ec08bc84767..9a981e82ae3d7b9a0dc27e3c03c6c5492e4fd302 100644 (file)
@@ -156,7 +156,7 @@ static unsigned char bitmask_start_values[] = {0xff, 0xfe, 0xfc, 0xf8, 0xf0, 0xe
 static unsigned char bitmask_end_values[]   = {0xff, 0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f};
 
 /* XDTLS:  figure out the right values */
-static unsigned int g_probable_mtu[] = {1500 - 28, 512 - 28, 256 - 28};
+static const unsigned int g_probable_mtu[] = {1500 - 28, 512 - 28, 256 - 28};
 
 static unsigned int dtls1_guess_mtu(unsigned int curr_mtu);
 static void dtls1_fix_message_header(SSL *s, unsigned long frag_off, 
@@ -211,8 +211,7 @@ dtls1_hm_fragment_new(unsigned long frag_len, int reassembly)
        return frag;
        }
 
-static void
-dtls1_hm_fragment_free(hm_fragment *frag)
+void dtls1_hm_fragment_free(hm_fragment *frag)
        {
 
        if (frag->msg_header.is_ccs)
@@ -225,13 +224,8 @@ dtls1_hm_fragment_free(hm_fragment *frag)
        OPENSSL_free(frag);
        }
 
-/* send s->init_buf in records of type 'type' (SSL3_RT_HANDSHAKE or SSL3_RT_CHANGE_CIPHER_SPEC) */
-int dtls1_do_write(SSL *s, int type)
-       {
-       int ret;
-       int curr_mtu;
-       unsigned int len, frag_off, mac_size, blocksize;
-
+static void dtls1_query_mtu(SSL *s)
+{
        /* AHA!  Figure out the MTU, and stick to the right size */
        if (s->d1->mtu < dtls1_min_mtu() && !(SSL_get_options(s) & SSL_OP_NO_QUERY_MTU))
                {
@@ -248,6 +242,16 @@ int dtls1_do_write(SSL *s, int type)
                                s->d1->mtu, NULL);
                        }
                }
+}
+
+/* send s->init_buf in records of type 'type' (SSL3_RT_HANDSHAKE or SSL3_RT_CHANGE_CIPHER_SPEC) */
+int dtls1_do_write(SSL *s, int type)
+       {
+       int ret;
+       int curr_mtu;
+       unsigned int len, frag_off, mac_size, blocksize;
+
+       dtls1_query_mtu(s);
 #if 0 
        mtu = s->d1->mtu;
 
@@ -330,12 +334,18 @@ int dtls1_do_write(SSL *s, int type)
                                        len = s->init_num;
                                }
 
+                       if ( len < DTLS1_HM_HEADER_LENGTH )
+                               {
+                               /*
+                                * len is so small that we really can't do anything sensible
+                                * so fail
+                                */
+                               return -1;
+                               }
                        dtls1_fix_message_header(s, frag_off, 
                                len - DTLS1_HM_HEADER_LENGTH);
 
                        dtls1_write_message_header(s, (unsigned char *)&s->init_buf->data[s->init_off]);
-
-                       OPENSSL_assert(len >= DTLS1_HM_HEADER_LENGTH);
                        }
 
                ret=dtls1_write_bytes(s,type,&s->init_buf->data[s->init_off],
@@ -350,10 +360,16 @@ int dtls1_do_write(SSL *s, int type)
                         */
                        if ( BIO_ctrl(SSL_get_wbio(s),
                                BIO_CTRL_DGRAM_MTU_EXCEEDED, 0, NULL) > 0 )
-                               s->d1->mtu = BIO_ctrl(SSL_get_wbio(s),
-                                       BIO_CTRL_DGRAM_QUERY_MTU, 0, NULL);
+                               {
+                               if(!(SSL_get_options(s) & SSL_OP_NO_QUERY_MTU))
+                                       dtls1_query_mtu(s);
+                               else
+                                       return -1;
+                               }
                        else
+                               {
                                return(-1);
+                               }
                        }
                else
                        {
@@ -1363,6 +1379,9 @@ dtls1_process_heartbeat(SSL *s)
        /* Read type and payload length first */
        if (1 + 2 + 16 > s->s3->rrec.length)
                return 0; /* silently discard */
+       if (s->s3->rrec.length > SSL3_RT_MAX_PLAIN_LENGTH)
+               return 0; /* silently discard per RFC 6520 sec. 4 */
+
        hbtype = *p++;
        n2s(p, payload);
        if (1 + 2 + payload + 16 > s->s3->rrec.length)