Make the anti-replay feature optional
[openssl.git] / ssl / d1_lib.c
index d839e1ab72187076c25fa4c2a6b0cc120a1845fa..f80851251fe2c40544459644dec8c87149bc8b44 100644 (file)
@@ -161,6 +161,8 @@ int dtls1_clear(SSL *s)
     DTLS_RECORD_LAYER_clear(&s->rlayer);
 
     if (s->d1) {
+        DTLS_timer_cb timer_cb = s->d1->timer_cb;
+
         buffered_messages = s->d1->buffered_messages;
         sent_messages = s->d1->sent_messages;
         mtu = s->d1->mtu;
@@ -170,6 +172,9 @@ int dtls1_clear(SSL *s)
 
         memset(s->d1, 0, sizeof(*s->d1));
 
+        /* Restore the timer callback from previous state */
+        s->d1->timer_cb = timer_cb;
+
         if (s->server) {
             s->d1->cookie_len = sizeof(s->d1->cookie);
         }
@@ -231,11 +236,13 @@ long dtls1_ctrl(SSL *s, int cmd, long larg, void *parg)
         ret = ssl3_ctrl(s, cmd, larg, parg);
         break;
     }
-    return (ret);
+    return ret;
 }
 
 void dtls1_start_timer(SSL *s)
 {
+    unsigned int sec, usec;
+
 #ifndef OPENSSL_NO_SCTP
     /* Disable timer for SCTP */
     if (BIO_dgram_is_sctp(SSL_get_wbio(s))) {
@@ -244,16 +251,34 @@ void dtls1_start_timer(SSL *s)
     }
 #endif
 
-    /* If timer is not set, initialize duration with 1 second */
+    /*
+     * If timer is not set, initialize duration with 1 second or
+     * a user-specified value if the timer callback is installed.
+     */
     if (s->d1->next_timeout.tv_sec == 0 && s->d1->next_timeout.tv_usec == 0) {
-        s->d1->timeout_duration = 1;
+
+        if (s->d1->timer_cb != NULL)
+            s->d1->timeout_duration_us = s->d1->timer_cb(s, 0);
+        else
+            s->d1->timeout_duration_us = 1000000;
     }
 
     /* Set timeout to current time */
     get_current_time(&(s->d1->next_timeout));
 
     /* Add duration to current time */
-    s->d1->next_timeout.tv_sec += s->d1->timeout_duration;
+
+    sec  = s->d1->timeout_duration_us / 1000000;
+    usec = s->d1->timeout_duration_us - (sec * 1000000);
+
+    s->d1->next_timeout.tv_sec  += sec;
+    s->d1->next_timeout.tv_usec += usec;
+
+    if (s->d1->next_timeout.tv_usec >= 1000000) {
+        s->d1->next_timeout.tv_sec++;
+        s->d1->next_timeout.tv_usec -= 1000000;
+    }
+
     BIO_ctrl(SSL_get_rbio(s), BIO_CTRL_DGRAM_SET_NEXT_TIMEOUT, 0,
              &(s->d1->next_timeout));
 }
@@ -318,9 +343,9 @@ int dtls1_is_timer_expired(SSL *s)
 
 void dtls1_double_timeout(SSL *s)
 {
-    s->d1->timeout_duration *= 2;
-    if (s->d1->timeout_duration > 60)
-        s->d1->timeout_duration = 60;
+    s->d1->timeout_duration_us *= 2;
+    if (s->d1->timeout_duration_us > 60000000)
+        s->d1->timeout_duration_us = 60000000;
     dtls1_start_timer(s);
 }
 
@@ -329,7 +354,7 @@ void dtls1_stop_timer(SSL *s)
     /* Reset everything */
     memset(&s->d1->timeout, 0, sizeof(s->d1->timeout));
     memset(&s->d1->next_timeout, 0, sizeof(s->d1->next_timeout));
-    s->d1->timeout_duration = 1;
+    s->d1->timeout_duration_us = 1000000;
     BIO_ctrl(SSL_get_rbio(s), BIO_CTRL_DGRAM_SET_NEXT_TIMEOUT, 0,
              &(s->d1->next_timeout));
     /* Clear retransmission buffer */
@@ -353,7 +378,8 @@ int dtls1_check_timeout_num(SSL *s)
 
     if (s->d1->timeout.num_alerts > DTLS1_TMO_ALERT_COUNT) {
         /* fail the connection, enough alerts have been sent */
-        SSLerr(SSL_F_DTLS1_CHECK_TIMEOUT_NUM, SSL_R_READ_TIMEOUT_EXPIRED);
+        SSLfatal(s, SSL_AD_NO_ALERT, SSL_F_DTLS1_CHECK_TIMEOUT_NUM,
+                 SSL_R_READ_TIMEOUT_EXPIRED);
         return -1;
     }
 
@@ -367,10 +393,15 @@ int dtls1_handle_timeout(SSL *s)
         return 0;
     }
 
-    dtls1_double_timeout(s);
+    if (s->d1->timer_cb != NULL)
+        s->d1->timeout_duration_us = s->d1->timer_cb(s, s->d1->timeout_duration_us);
+    else
+        dtls1_double_timeout(s);
 
-    if (dtls1_check_timeout_num(s) < 0)
+    if (dtls1_check_timeout_num(s) < 0) {
+        /* SSLfatal() already called */
         return -1;
+    }
 
     s->d1->timeout.read_timeouts++;
     if (s->d1->timeout.read_timeouts > DTLS1_TMO_READ_COUNT) {
@@ -378,6 +409,7 @@ int dtls1_handle_timeout(SSL *s)
     }
 
     dtls1_start_timer(s);
+    /* Calls SSLfatal() if required */
     return dtls1_retransmit_buffered_messages(s);
 }
 
@@ -952,3 +984,8 @@ size_t DTLS_get_data_mtu(const SSL *s)
 
     return mtu;
 }
+
+void DTLS_set_timer_cb(SSL *s, DTLS_timer_cb cb)
+{
+    s->d1->timer_cb = cb;
+}