Convert dtls_write_records to use standard record layer functions
[openssl.git] / ssl / record / rec_layer_s3.c
index 0c70995312a09ea6a40f48507a6247fcde894f00..04f130bc2eba140185c0c5205945f124bda295bd 100644 (file)
 #include "record_local.h"
 #include "internal/packet.h"
 
-#if     defined(OPENSSL_SMALL_FOOTPRINT) || \
-        !(      defined(AES_ASM) &&     ( \
-                defined(__x86_64)       || defined(__x86_64__)  || \
-                defined(_M_AMD64)       || defined(_M_X64)      ) \
-        )
-# undef EVP_CIPH_FLAG_TLS1_1_MULTIBLOCK
-# define EVP_CIPH_FLAG_TLS1_1_MULTIBLOCK 0
-#endif
-
 void RECORD_LAYER_init(RECORD_LAYER *rl, SSL_CONNECTION *s)
 {
     rl->s = s;
@@ -197,10 +188,6 @@ int ssl3_write_bytes(SSL *ssl, int type, const void *buf_, size_t len,
     const unsigned char *buf = buf_;
     size_t tot;
     size_t n, max_send_fragment, split_send_fragment, maxpipes;
-    /* TODO(RECLAYER): Re-enable multiblock code */
-#if 0 && !defined(OPENSSL_NO_MULTIBLOCK) && EVP_CIPH_FLAG_TLS1_1_MULTIBLOCK
-    size_t nw;
-#endif
     int i;
     SSL_CONNECTION *s = SSL_CONNECTION_FROM_SSL_ONLY(ssl);
     OSSL_RECORD_TEMPLATE tmpls[SSL_MAX_PIPELINES];
@@ -285,141 +272,6 @@ int ssl3_write_bytes(SSL *ssl, int type, const void *buf_, size_t len,
         s->rlayer.wpend_ret = len;
     }
 
-/* TODO(RECLAYER): Re-enable multiblock code */
-#if 0 && !defined(OPENSSL_NO_MULTIBLOCK) && EVP_CIPH_FLAG_TLS1_1_MULTIBLOCK
-    /*
-     * Depending on platform multi-block can deliver several *times*
-     * better performance. Downside is that it has to allocate
-     * jumbo buffer to accommodate up to 8 records, but the
-     * compromise is considered worthy.
-     */
-    if (type == SSL3_RT_APPLICATION_DATA
-            && len >= 4 * (max_send_fragment = ssl_get_max_send_fragment(s))
-            && s->compress == NULL
-            && s->msg_callback == NULL
-            && !SSL_WRITE_ETM(s)
-            && SSL_USE_EXPLICIT_IV(s)
-            && !BIO_get_ktls_send(s->wbio)
-            && (EVP_CIPHER_get_flags(EVP_CIPHER_CTX_get0_cipher(s->enc_write_ctx))
-                & EVP_CIPH_FLAG_TLS1_1_MULTIBLOCK) != 0) {
-        unsigned char aad[13];
-        EVP_CTRL_TLS1_1_MULTIBLOCK_PARAM mb_param;
-        size_t packlen;
-        int packleni;
-
-        /* minimize address aliasing conflicts */
-        if ((max_send_fragment & 0xfff) == 0)
-            max_send_fragment -= 512;
-
-        if (tot == 0 || wb->buf == NULL) { /* allocate jumbo buffer */
-            ssl3_release_write_buffer(s);
-
-            packlen = EVP_CIPHER_CTX_ctrl(s->enc_write_ctx,
-                                          EVP_CTRL_TLS1_1_MULTIBLOCK_MAX_BUFSIZE,
-                                          (int)max_send_fragment, NULL);
-
-            if (len >= 8 * max_send_fragment)
-                packlen *= 8;
-            else
-                packlen *= 4;
-
-            if (!ssl3_setup_write_buffer(s, 1, packlen)) {
-                /* SSLfatal() already called */
-                return -1;
-            }
-        } else if (tot == len) { /* done? */
-            /* free jumbo buffer */
-            ssl3_release_write_buffer(s);
-            *written = tot;
-            return 1;
-        }
-
-        n = (len - tot);
-        for (;;) {
-            if (n < 4 * max_send_fragment) {
-                /* free jumbo buffer */
-                ssl3_release_write_buffer(s);
-                break;
-            }
-
-            if (s->s3.alert_dispatch) {
-                i = ssl->method->ssl_dispatch_alert(ssl);
-                if (i <= 0) {
-                    /* SSLfatal() already called if appropriate */
-                    s->rlayer.wnum = tot;
-                    return i;
-                }
-            }
-
-            if (n >= 8 * max_send_fragment)
-                nw = max_send_fragment * (mb_param.interleave = 8);
-            else
-                nw = max_send_fragment * (mb_param.interleave = 4);
-
-            memcpy(aad, s->rlayer.write_sequence, 8);
-            aad[8] = type;
-            aad[9] = (unsigned char)(s->version >> 8);
-            aad[10] = (unsigned char)(s->version);
-            aad[11] = 0;
-            aad[12] = 0;
-            mb_param.out = NULL;
-            mb_param.inp = aad;
-            mb_param.len = nw;
-
-            packleni = EVP_CIPHER_CTX_ctrl(s->enc_write_ctx,
-                                          EVP_CTRL_TLS1_1_MULTIBLOCK_AAD,
-                                          sizeof(mb_param), &mb_param);
-            packlen = (size_t)packleni;
-            if (packleni <= 0 || packlen > wb->len) { /* never happens */
-                /* free jumbo buffer */
-                ssl3_release_write_buffer(s);
-                break;
-            }
-
-            mb_param.out = wb->buf;
-            mb_param.inp = &buf[tot];
-            mb_param.len = nw;
-
-            if (EVP_CIPHER_CTX_ctrl(s->enc_write_ctx,
-                                    EVP_CTRL_TLS1_1_MULTIBLOCK_ENCRYPT,
-                                    sizeof(mb_param), &mb_param) <= 0)
-                return -1;
-
-            s->rlayer.write_sequence[7] += mb_param.interleave;
-            if (s->rlayer.write_sequence[7] < mb_param.interleave) {
-                int j = 6;
-                while (j >= 0 && (++s->rlayer.write_sequence[j--]) == 0) ;
-            }
-
-            wb->offset = 0;
-            wb->left = packlen;
-
-            s->rlayer.wpend_tot = nw;
-            s->rlayer.wpend_buf = &buf[tot];
-            s->rlayer.wpend_type = type;
-            s->rlayer.wpend_ret = nw;
-
-            i = ssl3_write_pending(s, type, &buf[tot], nw, &tmpwrit);
-            if (i <= 0) {
-                /* SSLfatal() already called if appropriate */
-                if (i < 0 && (!s->wbio || !BIO_should_retry(s->wbio))) {
-                    /* free jumbo buffer */
-                    ssl3_release_write_buffer(s);
-                }
-                s->rlayer.wnum = tot;
-                return i;
-            }
-            if (tmpwrit == n) {
-                /* free jumbo buffer */
-                ssl3_release_write_buffer(s);
-                *written = tot + tmpwrit;
-                return 1;
-            }
-            n -= tmpwrit;
-            tot += tmpwrit;
-        }
-    } else
-#endif  /* !defined(OPENSSL_NO_MULTIBLOCK) && EVP_CIPH_FLAG_TLS1_1_MULTIBLOCK */
     if (tot == len) {           /* done? */
         *written = tot;
         return 1;
@@ -439,37 +291,7 @@ int ssl3_write_bytes(SSL *ssl, int type, const void *buf_, size_t len,
 
     max_send_fragment = ssl_get_max_send_fragment(s);
     split_send_fragment = ssl_get_split_send_fragment(s);
-    /*
-     * TODO(RECLAYER): This comment is now out-of-date and probably needs to
-     * move somewhere else
-     *
-     * If max_pipelines is 0 then this means "undefined" and we default to
-     * 1 pipeline. Similarly if the cipher does not support pipelined
-     * processing then we also only use 1 pipeline, or if we're not using
-     * explicit IVs
-     */
-    maxpipes = s->max_pipelines;
-    if (maxpipes > SSL_MAX_PIPELINES) {
-        /*
-         * We should have prevented this when we set max_pipelines so we
-         * shouldn't get here
-         */
-        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
-        return -1;
-    }
-    /* If no explicit maxpipes configuration - default to 1 */
-    /* TODO(RECLAYER): Should we ask the record layer how many pipes it supports? */
-    if (maxpipes <= 0)
-        maxpipes = 1;
-#if 0
-    /* TODO(RECLAYER): FIX ME */
-    if (maxpipes == 0
-        || s->enc_write_ctx == NULL
-        || (EVP_CIPHER_get_flags(EVP_CIPHER_CTX_get0_cipher(s->enc_write_ctx))
-            & EVP_CIPH_FLAG_PIPELINE) == 0
-        || !SSL_USE_EXPLICIT_IV(s))
-        maxpipes = 1;
-#endif
+
     if (max_send_fragment == 0
             || split_send_fragment == 0
             || split_send_fragment > max_send_fragment) {
@@ -495,39 +317,56 @@ int ssl3_write_bytes(SSL *ssl, int type, const void *buf_, size_t len,
 
     for (;;) {
         size_t tmppipelen, remain;
-        size_t numpipes, j, lensofar = 0;
+        size_t j, lensofar = 0;
 
-        if (n == 0)
-            numpipes = 1;
-        else
-            numpipes = ((n - 1) / split_send_fragment) + 1;
-        if (numpipes > maxpipes)
-            numpipes = maxpipes;
+        /*
+        * Ask the record layer how it would like to split the amount of data
+        * that we have, and how many of those records it would like in one go.
+        */
+        maxpipes = s->rlayer.wrlmethod->get_max_records(s->rlayer.wrl, type, n,
+                                                        max_send_fragment,
+                                                        &split_send_fragment);
+        /*
+        * If max_pipelines is 0 then this means "undefined" and we default to
+        * whatever the record layer wants to do. Otherwise we use the smallest
+        * value from the number requested by the record layer, and max number
+        * configured by the user.
+        */
+        if (s->max_pipelines > 0 && maxpipes > s->max_pipelines)
+            maxpipes = s->max_pipelines;
+
+        if (maxpipes > SSL_MAX_PIPELINES)
+            maxpipes = SSL_MAX_PIPELINES;
+
+        if (split_send_fragment > max_send_fragment) {
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+            return -1;
+        }
 
-        if (n / numpipes >= max_send_fragment) {
+        if (n / maxpipes >= split_send_fragment) {
             /*
              * We have enough data to completely fill all available
              * pipelines
              */
-            for (j = 0; j < numpipes; j++) {
+            for (j = 0; j < maxpipes; j++) {
                 tmpls[j].type = type;
                 tmpls[j].version = recversion;
-                tmpls[j].buf = &(buf[tot]) + (j * max_send_fragment);
-                tmpls[j].buflen = max_send_fragment;
+                tmpls[j].buf = &(buf[tot]) + (j * split_send_fragment);
+                tmpls[j].buflen = split_send_fragment;
             }
             /* Remember how much data we are going to be sending */
-            s->rlayer.wpend_tot = numpipes * max_send_fragment;
+            s->rlayer.wpend_tot = maxpipes * split_send_fragment;
         } else {
             /* We can partially fill all available pipelines */
-            tmppipelen = n / numpipes;
-            remain = n % numpipes;
+            tmppipelen = n / maxpipes;
+            remain = n % maxpipes;
             /*
              * If there is a remainder we add an extra byte to the first few
              * pipelines
              */
             if (remain > 0)
                 tmppipelen++;
-            for (j = 0; j < numpipes; j++) {
+            for (j = 0; j < maxpipes; j++) {
                 tmpls[j].type = type;
                 tmpls[j].version = recversion;
                 tmpls[j].buf = &(buf[tot]) + lensofar;
@@ -541,7 +380,7 @@ int ssl3_write_bytes(SSL *ssl, int type, const void *buf_, size_t len,
         }
 
         i = HANDLE_RLAYER_WRITE_RETURN(s,
-            s->rlayer.wrlmethod->write_records(s->rlayer.wrl, tmpls, numpipes));
+            s->rlayer.wrlmethod->write_records(s->rlayer.wrl, tmpls, maxpipes));
         if (i <= 0) {
             /* SSLfatal() already called if appropriate */
             s->rlayer.wnum = tot;
@@ -1223,6 +1062,7 @@ static const OSSL_DISPATCH rlayer_dispatch[] = {
 };
 
 static const OSSL_RECORD_METHOD *ssl_select_next_record_layer(SSL_CONNECTION *s,
+                                                              int direction,
                                                               int level)
 {
 
@@ -1242,7 +1082,8 @@ static const OSSL_RECORD_METHOD *ssl_select_next_record_layer(SSL_CONNECTION *s,
 #endif
 
     /* Default to the current OSSL_RECORD_METHOD */
-    return s->rlayer.rrlmethod;
+    return direction == OSSL_RECORD_DIRECTION_READ ? s->rlayer.rrlmethod
+                                                   : s->rlayer.wrlmethod;
 }
 
 static int ssl_post_record_layer_select(SSL_CONNECTION *s, int direction)
@@ -1294,11 +1135,14 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
     SSL_CTX *sctx = SSL_CONNECTION_GET_CTX(s);
     const OSSL_RECORD_METHOD *meth;
     int use_etm, stream_mac = 0, tlstree = 0;
-    unsigned int maxfrag = SSL3_RT_MAX_PLAIN_LENGTH;
+    unsigned int maxfrag = (direction == OSSL_RECORD_DIRECTION_WRITE)
+                           ? ssl_get_max_send_fragment(s)
+                           : SSL3_RT_MAX_PLAIN_LENGTH;
     int use_early_data = 0;
     uint32_t max_early_data;
+    COMP_METHOD *compm = (comp == NULL) ? NULL : comp->method;
 
-    meth = ssl_select_next_record_layer(s, level);
+    meth = ssl_select_next_record_layer(s, direction, level);
 
     if (direction == OSSL_RECORD_DIRECTION_READ) {
         thismethod = &s->rlayer.rrlmethod;
@@ -1363,9 +1207,16 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
         *set++ = OSSL_PARAM_construct_int(OSSL_LIBSSL_RECORD_LAYER_PARAM_TLSTREE,
                                           &tlstree);
 
-    if (s->session != NULL && USE_MAX_FRAGMENT_LENGTH_EXT(s->session))
+    /*
+     * We only need to do this for the read side. The write side should already
+     * have the correct value due to the ssl_get_max_send_fragment() call above
+     */
+    if (direction == OSSL_RECORD_DIRECTION_READ
+            && s->session != NULL
+            && USE_MAX_FRAGMENT_LENGTH_EXT(s->session))
         maxfrag = GET_MAX_FRAGMENT_LENGTH(s->session);
 
+
     if (maxfrag != SSL3_RT_MAX_PLAIN_LENGTH)
         *set++ = OSSL_PARAM_construct_uint(OSSL_LIBSSL_RECORD_LAYER_PARAM_MAX_FRAG_LEN,
                                            &maxfrag);
@@ -1375,7 +1226,6 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
      * using the early keys. A server also needs to worry about rejected early
      * data that might arrive when the handshake keys are in force.
      */
-    /* TODO(RECLAYER): Check this when doing the "write" record layer */
     if (s->server && direction == OSSL_RECORD_DIRECTION_READ) {
         use_early_data = (level == OSSL_RECORD_PROTECTION_LEVEL_EARLY
                           || level == OSSL_RECORD_PROTECTION_LEVEL_HANDSHAKE);
@@ -1417,6 +1267,10 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
                 return 0;
             }
             s->rlayer.rrlnext = next;
+        } else {
+            if (SSL_CONNECTION_IS_DTLS(s)
+                    && level != OSSL_RECORD_PROTECTION_LEVEL_NONE)
+                epoch =  DTLS_RECORD_LAYER_get_w_epoch(&s->rlayer) + 1; /* new epoch */
         }
 
         /*
@@ -1443,7 +1297,7 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
                                        s->server, direction, level, epoch,
                                        key, keylen, iv, ivlen, mackey,
                                        mackeylen, ciph, taglen, mactype, md,
-                                       comp, prev, thisbio, next, NULL, NULL,
+                                       compm, prev, thisbio, next, NULL, NULL,
                                        settings, options, rlayer_dispatch_tmp,
                                        s, &newrl);
         BIO_free(prev);
@@ -1475,9 +1329,17 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
         break;
     }
 
-    if (*thismethod != NULL && !(*thismethod)->free(*thisrl)) {
-        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
-        return 0;
+    /*
+     * Free the old record layer if we have one except in the case of DTLS when
+     * writing. In that case the record layer is still referenced by buffered
+     * messages for potential retransmit. Only when those buffered messages get
+     * freed do we free the record layer object (see dtls1_hm_fragment_free)
+     */
+    if (!SSL_CONNECTION_IS_DTLS(s) || direction == OSSL_RECORD_DIRECTION_READ) {
+        if (*thismethod != NULL && !(*thismethod)->free(*thisrl)) {
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+            return 0;
+        }
     }
 
     *thisrl = newrl;