Fix dtls_query_mtu so that it will always either complete with an mtu that is
[openssl.git] / ssl / d1_both.c
index fb524dafa09cf80df538cce6260a4ec08bc84767..808d4d14eb4aaa10902ed445e3de1694b9f214f8 100644 (file)
@@ -156,9 +156,8 @@ static unsigned char bitmask_start_values[] = {0xff, 0xfe, 0xfc, 0xf8, 0xf0, 0xe
 static unsigned char bitmask_end_values[]   = {0xff, 0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f};
 
 /* XDTLS:  figure out the right values */
-static unsigned int g_probable_mtu[] = {1500 - 28, 512 - 28, 256 - 28};
+static const unsigned int g_probable_mtu[] = {1500, 512, 256};
 
-static unsigned int dtls1_guess_mtu(unsigned int curr_mtu);
 static void dtls1_fix_message_header(SSL *s, unsigned long frag_off, 
        unsigned long frag_len);
 static unsigned char *dtls1_write_message_header(SSL *s,
@@ -211,8 +210,7 @@ dtls1_hm_fragment_new(unsigned long frag_len, int reassembly)
        return frag;
        }
 
-static void
-dtls1_hm_fragment_free(hm_fragment *frag)
+void dtls1_hm_fragment_free(hm_fragment *frag)
        {
 
        if (frag->msg_header.is_ccs)
@@ -225,6 +223,38 @@ dtls1_hm_fragment_free(hm_fragment *frag)
        OPENSSL_free(frag);
        }
 
+static int dtls1_query_mtu(SSL *s)
+{
+       if(s->d1->link_mtu)
+               {
+               s->d1->mtu = s->d1->link_mtu-BIO_dgram_get_mtu_overhead(SSL_get_wbio(s));
+               s->d1->link_mtu = 0;
+               }
+
+       /* AHA!  Figure out the MTU, and stick to the right size */
+       if (s->d1->mtu < dtls1_min_mtu(s))
+               {
+               if(!(SSL_get_options(s) & SSL_OP_NO_QUERY_MTU))
+                       {
+                       s->d1->mtu =
+                               BIO_ctrl(SSL_get_wbio(s), BIO_CTRL_DGRAM_QUERY_MTU, 0, NULL);
+
+                       /* I've seen the kernel return bogus numbers when it doesn't know
+                        * (initial write), so just make sure we have a reasonable number */
+                       if (s->d1->mtu < dtls1_min_mtu(s))
+                               {
+                               /* Set to min mtu */
+                               s->d1->mtu = dtls1_min_mtu(s);
+                               BIO_ctrl(SSL_get_wbio(s), BIO_CTRL_DGRAM_SET_MTU,
+                                       s->d1->mtu, NULL);
+                               }
+                       }
+               else
+                       return 0;
+               }
+       return 1;
+}
+
 /* send s->init_buf in records of type 'type' (SSL3_RT_HANDSHAKE or SSL3_RT_CHANGE_CIPHER_SPEC) */
 int dtls1_do_write(SSL *s, int type)
        {
@@ -232,22 +262,8 @@ int dtls1_do_write(SSL *s, int type)
        int curr_mtu;
        unsigned int len, frag_off, mac_size, blocksize;
 
-       /* AHA!  Figure out the MTU, and stick to the right size */
-       if (s->d1->mtu < dtls1_min_mtu() && !(SSL_get_options(s) & SSL_OP_NO_QUERY_MTU))
-               {
-               s->d1->mtu = 
-                       BIO_ctrl(SSL_get_wbio(s), BIO_CTRL_DGRAM_QUERY_MTU, 0, NULL);
-
-               /* I've seen the kernel return bogus numbers when it doesn't know
-                * (initial write), so just make sure we have a reasonable number */
-               if (s->d1->mtu < dtls1_min_mtu())
-                       {
-                       s->d1->mtu = 0;
-                       s->d1->mtu = dtls1_guess_mtu(s->d1->mtu);
-                       BIO_ctrl(SSL_get_wbio(s), BIO_CTRL_DGRAM_SET_MTU, 
-                               s->d1->mtu, NULL);
-                       }
-               }
+       if(!dtls1_query_mtu(s))
+               return -1;
 #if 0 
        mtu = s->d1->mtu;
 
@@ -271,7 +287,7 @@ int dtls1_do_write(SSL *s, int type)
                }
 #endif
 
-       OPENSSL_assert(s->d1->mtu >= dtls1_min_mtu());  /* should have something reasonable now */
+       OPENSSL_assert(s->d1->mtu >= dtls1_min_mtu(s));  /* should have something reasonable now */
 
        if ( s->init_off == 0  && type == SSL3_RT_HANDSHAKE)
                OPENSSL_assert(s->init_num == 
@@ -330,12 +346,18 @@ int dtls1_do_write(SSL *s, int type)
                                        len = s->init_num;
                                }
 
+                       if ( len < DTLS1_HM_HEADER_LENGTH )
+                               {
+                               /*
+                                * len is so small that we really can't do anything sensible
+                                * so fail
+                                */
+                               return -1;
+                               }
                        dtls1_fix_message_header(s, frag_off, 
                                len - DTLS1_HM_HEADER_LENGTH);
 
                        dtls1_write_message_header(s, (unsigned char *)&s->init_buf->data[s->init_off]);
-
-                       OPENSSL_assert(len >= DTLS1_HM_HEADER_LENGTH);
                        }
 
                ret=dtls1_write_bytes(s,type,&s->init_buf->data[s->init_off],
@@ -350,10 +372,19 @@ int dtls1_do_write(SSL *s, int type)
                         */
                        if ( BIO_ctrl(SSL_get_wbio(s),
                                BIO_CTRL_DGRAM_MTU_EXCEEDED, 0, NULL) > 0 )
-                               s->d1->mtu = BIO_ctrl(SSL_get_wbio(s),
-                                       BIO_CTRL_DGRAM_QUERY_MTU, 0, NULL);
+                               {
+                               if(!(SSL_get_options(s) & SSL_OP_NO_QUERY_MTU))
+                                       {
+                                       if(!dtls1_query_mtu(s))
+                                               return -1;
+                                       }
+                               else
+                                       return -1;
+                               }
                        else
+                               {
                                return(-1);
+                               }
                        }
                else
                        {
@@ -1283,28 +1314,20 @@ dtls1_write_message_header(SSL *s, unsigned char *p)
        return p;
        }
 
-unsigned int 
-dtls1_min_mtu(void)
+unsigned int
+dtls1_link_min_mtu(void)
        {
        return (g_probable_mtu[(sizeof(g_probable_mtu) / 
                sizeof(g_probable_mtu[0])) - 1]);
        }
 
-static unsigned int 
-dtls1_guess_mtu(unsigned int curr_mtu)
+unsigned int
+dtls1_min_mtu(SSL *s)
        {
-       unsigned int i;
-
-       if ( curr_mtu == 0 )
-               return g_probable_mtu[0] ;
-
-       for ( i = 0; i < sizeof(g_probable_mtu)/sizeof(g_probable_mtu[0]); i++)
-               if ( curr_mtu > g_probable_mtu[i])
-                       return g_probable_mtu[i];
-
-       return curr_mtu;
+       return dtls1_link_min_mtu()-BIO_dgram_get_mtu_overhead(SSL_get_wbio(s));
        }
 
+
 void
 dtls1_get_message_header(unsigned char *data, struct hm_header_st *msg_hdr)
        {
@@ -1363,6 +1386,9 @@ dtls1_process_heartbeat(SSL *s)
        /* Read type and payload length first */
        if (1 + 2 + 16 > s->s3->rrec.length)
                return 0; /* silently discard */
+       if (s->s3->rrec.length > SSL3_RT_MAX_PLAIN_LENGTH)
+               return 0; /* silently discard per RFC 6520 sec. 4 */
+
        hbtype = *p++;
        n2s(p, payload);
        if (1 + 2 + payload + 16 > s->s3->rrec.length)