Use the read and write buffers in DTLSv1_listen()
[openssl.git] / ssl / d1_lib.c
index 213fad5a8ded074730b8be42ca6da5180b83fee0..38adda3355a7a567474785a59b9fa56998c01b0e 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2005-2016 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2005-2017 The OpenSSL Project Authors. All Rights Reserved.
  *
  * Licensed under the OpenSSL license (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
@@ -7,20 +7,12 @@
  * https://www.openssl.org/source/license.html
  */
 
+#include "e_os.h"
 #include <stdio.h>
-#define USE_SOCKETS
 #include <openssl/objects.h>
 #include <openssl/rand.h>
 #include "ssl_locl.h"
 
-#if defined(OPENSSL_SYS_VMS)
-# include <sys/timeb.h>
-#elif defined(OPENSSL_SYS_VXWORKS)
-# include <sys/times.h>
-#elif !defined(OPENSSL_SYS_WIN32)
-# include <sys/time.h>
-#endif
-
 static void get_current_time(struct timeval *t);
 static int dtls1_handshake_write(SSL *s);
 static size_t dtls1_link_min_mtu(void);
@@ -81,10 +73,10 @@ int dtls1_new(SSL *s)
     }
 
     if (!ssl3_new(s))
-        return (0);
+        return 0;
     if ((d1 = OPENSSL_zalloc(sizeof(*d1))) == NULL) {
         ssl3_free(s);
-        return (0);
+        return 0;
     }
 
     d1->buffered_messages = pqueue_new();
@@ -102,12 +94,15 @@ int dtls1_new(SSL *s)
         pqueue_free(d1->sent_messages);
         OPENSSL_free(d1);
         ssl3_free(s);
-        return (0);
+        return 0;
     }
 
     s->d1 = d1;
-    s->method->ssl_clear(s);
-    return (1);
+
+    if (!s->method->ssl_clear(s))
+        return 0;
+
+    return 1;
 }
 
 static void dtls1_clear_queues(SSL *s)
@@ -156,7 +151,7 @@ void dtls1_free(SSL *s)
     s->d1 = NULL;
 }
 
-void dtls1_clear(SSL *s)
+int dtls1_clear(SSL *s)
 {
     pqueue *buffered_messages;
     pqueue *sent_messages;
@@ -166,6 +161,8 @@ void dtls1_clear(SSL *s)
     DTLS_RECORD_LAYER_clear(&s->rlayer);
 
     if (s->d1) {
+        DTLS_timer_cb timer_cb = s->d1->timer_cb;
+
         buffered_messages = s->d1->buffered_messages;
         sent_messages = s->d1->sent_messages;
         mtu = s->d1->mtu;
@@ -175,6 +172,9 @@ void dtls1_clear(SSL *s)
 
         memset(s->d1, 0, sizeof(*s->d1));
 
+        /* Restore the timer callback from previous state */
+        s->d1->timer_cb = timer_cb;
+
         if (s->server) {
             s->d1->cookie_len = sizeof(s->d1->cookie);
         }
@@ -188,7 +188,8 @@ void dtls1_clear(SSL *s)
         s->d1->sent_messages = sent_messages;
     }
 
-    ssl3_clear(s);
+    if (!ssl3_clear(s))
+        return 0;
 
     if (s->method->version == DTLS_ANY_VERSION)
         s->version = DTLS_MAX_VERSION;
@@ -198,6 +199,8 @@ void dtls1_clear(SSL *s)
 #endif
     else
         s->version = s->method->version;
+
+    return 1;
 }
 
 long dtls1_ctrl(SSL *s, int cmd, long larg, void *parg)
@@ -233,11 +236,13 @@ long dtls1_ctrl(SSL *s, int cmd, long larg, void *parg)
         ret = ssl3_ctrl(s, cmd, larg, parg);
         break;
     }
-    return (ret);
+    return ret;
 }
 
 void dtls1_start_timer(SSL *s)
 {
+    unsigned int sec, usec;
+
 #ifndef OPENSSL_NO_SCTP
     /* Disable timer for SCTP */
     if (BIO_dgram_is_sctp(SSL_get_wbio(s))) {
@@ -246,16 +251,34 @@ void dtls1_start_timer(SSL *s)
     }
 #endif
 
-    /* If timer is not set, initialize duration with 1 second */
+    /*
+     * If timer is not set, initialize duration with 1 second or
+     * a user-specified value if the timer callback is installed.
+     */
     if (s->d1->next_timeout.tv_sec == 0 && s->d1->next_timeout.tv_usec == 0) {
-        s->d1->timeout_duration = 1;
+
+        if (s->d1->timer_cb != NULL)
+            s->d1->timeout_duration_us = s->d1->timer_cb(s, 0);
+        else
+            s->d1->timeout_duration_us = 1000000;
     }
 
     /* Set timeout to current time */
     get_current_time(&(s->d1->next_timeout));
 
     /* Add duration to current time */
-    s->d1->next_timeout.tv_sec += s->d1->timeout_duration;
+
+    sec  = s->d1->timeout_duration_us / 1000000;
+    usec = s->d1->timeout_duration_us - (sec * 1000000);
+
+    s->d1->next_timeout.tv_sec  += sec;
+    s->d1->next_timeout.tv_usec += usec;
+
+    if (s->d1->next_timeout.tv_usec >= 1000000) {
+        s->d1->next_timeout.tv_sec++;
+        s->d1->next_timeout.tv_usec -= 1000000;
+    }
+
     BIO_ctrl(SSL_get_rbio(s), BIO_CTRL_DGRAM_SET_NEXT_TIMEOUT, 0,
              &(s->d1->next_timeout));
 }
@@ -320,9 +343,9 @@ int dtls1_is_timer_expired(SSL *s)
 
 void dtls1_double_timeout(SSL *s)
 {
-    s->d1->timeout_duration *= 2;
-    if (s->d1->timeout_duration > 60)
-        s->d1->timeout_duration = 60;
+    s->d1->timeout_duration_us *= 2;
+    if (s->d1->timeout_duration_us > 60000000)
+        s->d1->timeout_duration_us = 60000000;
     dtls1_start_timer(s);
 }
 
@@ -331,7 +354,7 @@ void dtls1_stop_timer(SSL *s)
     /* Reset everything */
     memset(&s->d1->timeout, 0, sizeof(s->d1->timeout));
     memset(&s->d1->next_timeout, 0, sizeof(s->d1->next_timeout));
-    s->d1->timeout_duration = 1;
+    s->d1->timeout_duration_us = 1000000;
     BIO_ctrl(SSL_get_rbio(s), BIO_CTRL_DGRAM_SET_NEXT_TIMEOUT, 0,
              &(s->d1->next_timeout));
     /* Clear retransmission buffer */
@@ -355,7 +378,8 @@ int dtls1_check_timeout_num(SSL *s)
 
     if (s->d1->timeout.num_alerts > DTLS1_TMO_ALERT_COUNT) {
         /* fail the connection, enough alerts have been sent */
-        SSLerr(SSL_F_DTLS1_CHECK_TIMEOUT_NUM, SSL_R_READ_TIMEOUT_EXPIRED);
+        SSLfatal(s, SSL_AD_NO_ALERT, SSL_F_DTLS1_CHECK_TIMEOUT_NUM,
+                 SSL_R_READ_TIMEOUT_EXPIRED);
         return -1;
     }
 
@@ -369,10 +393,15 @@ int dtls1_handle_timeout(SSL *s)
         return 0;
     }
 
-    dtls1_double_timeout(s);
+    if (s->d1->timer_cb != NULL)
+        s->d1->timeout_duration_us = s->d1->timer_cb(s, s->d1->timeout_duration_us);
+    else
+        dtls1_double_timeout(s);
 
-    if (dtls1_check_timeout_num(s) < 0)
+    if (dtls1_check_timeout_num(s) < 0) {
+        /* SSLfatal() already called */
         return -1;
+    }
 
     s->d1->timeout.read_timeouts++;
     if (s->d1->timeout.read_timeouts > DTLS1_TMO_READ_COUNT) {
@@ -380,6 +409,7 @@ int dtls1_handle_timeout(SSL *s)
     }
 
     dtls1_start_timer(s);
+    /* Calls SSLfatal() if required */
     return dtls1_retransmit_buffered_messages(s);
 }
 
@@ -404,11 +434,6 @@ static void get_current_time(struct timeval *t)
 # endif
     t->tv_sec = (long)(now.ul / 10000000);
     t->tv_usec = ((int)(now.ul % 10000000)) / 10;
-#elif defined(OPENSSL_SYS_VMS)
-    struct timeb tb;
-    ftime(&tb);
-    t->tv_sec = (long)tb.time;
-    t->tv_usec = (long)tb.millitm * 1000;
 #else
     gettimeofday(t, NULL);
 #endif
@@ -424,11 +449,10 @@ int DTLSv1_listen(SSL *s, BIO_ADDR *client)
     unsigned char cookie[DTLS1_COOKIE_LENGTH];
     unsigned char seq[SEQ_NUM_SIZE];
     const unsigned char *data;
-    unsigned char *buf;
+    unsigned char *buf, *wbuf;
     size_t fragoff, fraglen, msglen;
     unsigned int rectype, versmajor, msgseq, msgtype, clientvers, cookielen;
     BIO *rbio, *wbio;
-    BUF_MEM *bufm;
     BIO_ADDR *tmpclient = NULL;
     PACKET pkt, msgpkt, msgpayload, session, cookiepkt;
 
@@ -470,34 +494,19 @@ int DTLSv1_listen(SSL *s, BIO_ADDR *client)
         return -1;
     }
 
-    if (s->init_buf == NULL) {
-        if ((bufm = BUF_MEM_new()) == NULL) {
-            SSLerr(SSL_F_DTLSV1_LISTEN, ERR_R_MALLOC_FAILURE);
-            return -1;
-        }
-
-        if (!BUF_MEM_grow(bufm, SSL3_RT_MAX_PLAIN_LENGTH)) {
-            BUF_MEM_free(bufm);
-            SSLerr(SSL_F_DTLSV1_LISTEN, ERR_R_MALLOC_FAILURE);
-            return -1;
-        }
-        s->init_buf = bufm;
+    if (!ssl3_setup_buffers(s)) {
+        /* SSLerr already called */
+        return -1;
     }
-    buf = (unsigned char *)s->init_buf->data;
+    buf = RECORD_LAYER_get_rbuf(&s->rlayer)->buf;
+    wbuf = RECORD_LAYER_get_wbuf(&s->rlayer)[0].buf;
 
     do {
         /* Get a packet */
 
         clear_sys_error();
-        /*
-         * Technically a ClientHello could be SSL3_RT_MAX_PLAIN_LENGTH
-         * + DTLS1_RT_HEADER_LENGTH bytes long. Normally init_buf does not store
-         * the record header as well, but we do here. We've set up init_buf to
-         * be the standard size for simplicity. In practice we shouldn't ever
-         * receive a ClientHello as long as this. If we do it will get dropped
-         * in the record length check below.
-         */
-        n = BIO_read(rbio, buf, SSL3_RT_MAX_PLAIN_LENGTH);
+        n = BIO_read(rbio, buf, SSL3_RT_MAX_PLAIN_LENGTH
+                                + DTLS1_RT_HEADER_LENGTH);
 
         if (n <= 0) {
             if (BIO_should_retry(rbio)) {
@@ -707,7 +716,11 @@ int DTLSv1_listen(SSL *s, BIO_ADDR *client)
                                                                : s->version;
 
             /* Construct the record and message headers */
-            if (!WPACKET_init(&wpkt, s->init_buf)
+            if (!WPACKET_init_static_len(&wpkt,
+                                         wbuf,
+                                         SSL3_RT_MAX_PLAIN_LENGTH
+                                         + DTLS1_RT_HEADER_LENGTH,
+                                         0)
                     || !WPACKET_put_bytes_u8(&wpkt, SSL3_RT_HANDSHAKE)
                     || !WPACKET_put_bytes_u16(&wpkt, version)
                        /*
@@ -765,8 +778,8 @@ int DTLSv1_listen(SSL *s, BIO_ADDR *client)
              * plus one byte for the message content type. The source is the
              * last 3 bytes of the message header
              */
-            memcpy(&buf[DTLS1_RT_HEADER_LENGTH + 1],
-                   &buf[DTLS1_RT_HEADER_LENGTH + DTLS1_HM_HEADER_LENGTH - 3],
+            memcpy(&wbuf[DTLS1_RT_HEADER_LENGTH + 1],
+                   &wbuf[DTLS1_RT_HEADER_LENGTH + DTLS1_HM_HEADER_LENGTH - 3],
                    3);
 
             if (s->msg_callback)
@@ -790,7 +803,7 @@ int DTLSv1_listen(SSL *s, BIO_ADDR *client)
             tmpclient = NULL;
 
             /* TODO(size_t): convert this call */
-            if (BIO_write(wbio, buf, wreclen) < (int)wreclen) {
+            if (BIO_write(wbio, wbuf, wreclen) < (int)wreclen) {
                 if (BIO_should_retry(wbio)) {
                     /*
                      * Non-blocking IO...but we're stateless, so we're just
@@ -840,6 +853,7 @@ int DTLSv1_listen(SSL *s, BIO_ADDR *client)
     if (BIO_dgram_get_peer(rbio, client) <= 0)
         BIO_ADDR_clear(client);
 
+
     ret = 1;
     clearpkt = 0;
  end:
@@ -959,3 +973,8 @@ size_t DTLS_get_data_mtu(const SSL *s)
 
     return mtu;
 }
+
+void DTLS_set_timer_cb(SSL *s, DTLS_timer_cb cb)
+{
+    s->d1->timer_cb = cb;
+}