Fix SSL_pending() for DTLS
[openssl.git] / ssl / d1_lib.c
index 6c594a268606d7fa6194ee866767f22e2ef51c37..f80851251fe2c40544459644dec8c87149bc8b44 100644 (file)
 #include <openssl/rand.h>
 #include "ssl_locl.h"
 
-#if defined(OPENSSL_SYS_VXWORKS)
-# include <sys/times.h>
-#elif !defined(OPENSSL_SYS_WIN32)
-# include <sys/time.h>
-#endif
-
 static void get_current_time(struct timeval *t);
 static int dtls1_handshake_write(SSL *s);
 static size_t dtls1_link_min_mtu(void);
@@ -167,6 +161,8 @@ int dtls1_clear(SSL *s)
     DTLS_RECORD_LAYER_clear(&s->rlayer);
 
     if (s->d1) {
+        DTLS_timer_cb timer_cb = s->d1->timer_cb;
+
         buffered_messages = s->d1->buffered_messages;
         sent_messages = s->d1->sent_messages;
         mtu = s->d1->mtu;
@@ -176,6 +172,9 @@ int dtls1_clear(SSL *s)
 
         memset(s->d1, 0, sizeof(*s->d1));
 
+        /* Restore the timer callback from previous state */
+        s->d1->timer_cb = timer_cb;
+
         if (s->server) {
             s->d1->cookie_len = sizeof(s->d1->cookie);
         }
@@ -237,11 +236,13 @@ long dtls1_ctrl(SSL *s, int cmd, long larg, void *parg)
         ret = ssl3_ctrl(s, cmd, larg, parg);
         break;
     }
-    return (ret);
+    return ret;
 }
 
 void dtls1_start_timer(SSL *s)
 {
+    unsigned int sec, usec;
+
 #ifndef OPENSSL_NO_SCTP
     /* Disable timer for SCTP */
     if (BIO_dgram_is_sctp(SSL_get_wbio(s))) {
@@ -250,16 +251,34 @@ void dtls1_start_timer(SSL *s)
     }
 #endif
 
-    /* If timer is not set, initialize duration with 1 second */
+    /*
+     * If timer is not set, initialize duration with 1 second or
+     * a user-specified value if the timer callback is installed.
+     */
     if (s->d1->next_timeout.tv_sec == 0 && s->d1->next_timeout.tv_usec == 0) {
-        s->d1->timeout_duration = 1;
+
+        if (s->d1->timer_cb != NULL)
+            s->d1->timeout_duration_us = s->d1->timer_cb(s, 0);
+        else
+            s->d1->timeout_duration_us = 1000000;
     }
 
     /* Set timeout to current time */
     get_current_time(&(s->d1->next_timeout));
 
     /* Add duration to current time */
-    s->d1->next_timeout.tv_sec += s->d1->timeout_duration;
+
+    sec  = s->d1->timeout_duration_us / 1000000;
+    usec = s->d1->timeout_duration_us - (sec * 1000000);
+
+    s->d1->next_timeout.tv_sec  += sec;
+    s->d1->next_timeout.tv_usec += usec;
+
+    if (s->d1->next_timeout.tv_usec >= 1000000) {
+        s->d1->next_timeout.tv_sec++;
+        s->d1->next_timeout.tv_usec -= 1000000;
+    }
+
     BIO_ctrl(SSL_get_rbio(s), BIO_CTRL_DGRAM_SET_NEXT_TIMEOUT, 0,
              &(s->d1->next_timeout));
 }
@@ -324,9 +343,9 @@ int dtls1_is_timer_expired(SSL *s)
 
 void dtls1_double_timeout(SSL *s)
 {
-    s->d1->timeout_duration *= 2;
-    if (s->d1->timeout_duration > 60)
-        s->d1->timeout_duration = 60;
+    s->d1->timeout_duration_us *= 2;
+    if (s->d1->timeout_duration_us > 60000000)
+        s->d1->timeout_duration_us = 60000000;
     dtls1_start_timer(s);
 }
 
@@ -335,7 +354,7 @@ void dtls1_stop_timer(SSL *s)
     /* Reset everything */
     memset(&s->d1->timeout, 0, sizeof(s->d1->timeout));
     memset(&s->d1->next_timeout, 0, sizeof(s->d1->next_timeout));
-    s->d1->timeout_duration = 1;
+    s->d1->timeout_duration_us = 1000000;
     BIO_ctrl(SSL_get_rbio(s), BIO_CTRL_DGRAM_SET_NEXT_TIMEOUT, 0,
              &(s->d1->next_timeout));
     /* Clear retransmission buffer */
@@ -359,7 +378,8 @@ int dtls1_check_timeout_num(SSL *s)
 
     if (s->d1->timeout.num_alerts > DTLS1_TMO_ALERT_COUNT) {
         /* fail the connection, enough alerts have been sent */
-        SSLerr(SSL_F_DTLS1_CHECK_TIMEOUT_NUM, SSL_R_READ_TIMEOUT_EXPIRED);
+        SSLfatal(s, SSL_AD_NO_ALERT, SSL_F_DTLS1_CHECK_TIMEOUT_NUM,
+                 SSL_R_READ_TIMEOUT_EXPIRED);
         return -1;
     }
 
@@ -373,10 +393,15 @@ int dtls1_handle_timeout(SSL *s)
         return 0;
     }
 
-    dtls1_double_timeout(s);
+    if (s->d1->timer_cb != NULL)
+        s->d1->timeout_duration_us = s->d1->timer_cb(s, s->d1->timeout_duration_us);
+    else
+        dtls1_double_timeout(s);
 
-    if (dtls1_check_timeout_num(s) < 0)
+    if (dtls1_check_timeout_num(s) < 0) {
+        /* SSLfatal() already called */
         return -1;
+    }
 
     s->d1->timeout.read_timeouts++;
     if (s->d1->timeout.read_timeouts > DTLS1_TMO_READ_COUNT) {
@@ -384,6 +409,7 @@ int dtls1_handle_timeout(SSL *s)
     }
 
     dtls1_start_timer(s);
+    /* Calls SSLfatal() if required */
     return dtls1_retransmit_buffered_messages(s);
 }
 
@@ -958,3 +984,8 @@ size_t DTLS_get_data_mtu(const SSL *s)
 
     return mtu;
 }
+
+void DTLS_set_timer_cb(SSL *s, DTLS_timer_cb cb)
+{
+    s->d1->timer_cb = cb;
+}