Ensure the record layer is responsible for calculating record overheads
[openssl.git] / ssl / statem / statem_dtls.c
index b673c860abc2c5c3525d5d1db680f958109f7812..2e71014ef8aced4c4ef7ea014016b993633f75af 100644 (file)
@@ -116,7 +116,7 @@ int dtls1_do_write(SSL_CONNECTION *s, int type)
     size_t written;
     size_t curr_mtu;
     int retry = 1;
-    size_t len, frag_off, mac_size, blocksize, used_len;
+    size_t len, frag_off, overhead, used_len;
     SSL *ssl = SSL_CONNECTION_GET_SSL(s);
 
     if (!dtls1_query_mtu(s))
@@ -132,21 +132,7 @@ int dtls1_do_write(SSL_CONNECTION *s, int type)
             return -1;
     }
 
-    if (s->write_hash) {
-        if (s->enc_write_ctx
-            && (EVP_CIPHER_get_flags(EVP_CIPHER_CTX_get0_cipher(s->enc_write_ctx)) &
-                EVP_CIPH_FLAG_AEAD_CIPHER) != 0)
-            mac_size = 0;
-        else
-            mac_size = EVP_MD_CTX_get_size(s->write_hash);
-    } else
-        mac_size = 0;
-
-    if (s->enc_write_ctx &&
-        (EVP_CIPHER_CTX_get_mode(s->enc_write_ctx) == EVP_CIPH_CBC_MODE))
-        blocksize = 2 * EVP_CIPHER_CTX_get_block_size(s->enc_write_ctx);
-    else
-        blocksize = 0;
+    overhead = s->rlayer.wrlmethod->get_max_record_overhead(s->rlayer.wrl);
 
     frag_off = 0;
     s->rwstate = SSL_NOTHING;
@@ -187,8 +173,7 @@ int dtls1_do_write(SSL_CONNECTION *s, int type)
             }
         }
 
-        used_len = BIO_wpending(s->wbio) + DTLS1_RT_HEADER_LENGTH
-            + mac_size + blocksize;
+        used_len = BIO_wpending(s->wbio) + overhead;
         if (s->d1->mtu > used_len)
             curr_mtu = s->d1->mtu - used_len;
         else
@@ -203,9 +188,8 @@ int dtls1_do_write(SSL_CONNECTION *s, int type)
                 s->rwstate = SSL_WRITING;
                 return ret;
             }
-            used_len = DTLS1_RT_HEADER_LENGTH + mac_size + blocksize;
-            if (s->d1->mtu > used_len + DTLS1_HM_HEADER_LENGTH) {
-                curr_mtu = s->d1->mtu - used_len;
+            if (s->d1->mtu > overhead + DTLS1_HM_HEADER_LENGTH) {
+                curr_mtu = s->d1->mtu - overhead;
             } else {
                 /* Shouldn't happen */
                 return -1;