Convert dtls_write_records to use standard record layer functions
authorMatt Caswell <matt@openssl.org>
Thu, 13 Oct 2022 15:44:22 +0000 (16:44 +0100)
committerMatt Caswell <matt@openssl.org>
Thu, 20 Oct 2022 13:39:33 +0000 (14:39 +0100)
We have standard functions for most of the work that dtls_write_records
does - so we convert it to use those functions instead.

Reviewed-by: Richard Levitte <levitte@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
Reviewed-by: Hugo Landau <hlandau@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/19424)

ssl/record/methods/dtls_meth.c
ssl/record/methods/recmethod_local.h
ssl/record/methods/tls1_meth.c
ssl/record/methods/tls_common.c
ssl/record/methods/tlsany_meth.c
ssl/record/rec_layer_s3.c
ssl/ssl_local.h
ssl/statem/statem_dtls.c

index 1b51c84893678a8e9a7008b42be81a263d35a98a..e6c71ed1e7b9b5c05d8b30144951254161d45f7e 100644 (file)
@@ -689,36 +689,52 @@ dtls_new_record_layer(OSSL_LIB_CTX *libctx, const char *propq, int vers,
     return ret;
 }
 
+int dtls_prepare_record_header(OSSL_RECORD_LAYER *rl,
+                               WPACKET *thispkt,
+                               OSSL_RECORD_TEMPLATE *templ,
+                               unsigned int rectype,
+                               unsigned char **recdata)
+{
+    size_t maxcomplen;
+
+    *recdata = NULL;
+
+    maxcomplen = templ->buflen;
+    if (rl->compctx != NULL)
+        maxcomplen += SSL3_RT_MAX_COMPRESSED_OVERHEAD;
+
+    if (!WPACKET_put_bytes_u8(thispkt, rectype)
+            || !WPACKET_put_bytes_u16(thispkt, templ->version)
+            || !WPACKET_put_bytes_u16(thispkt, rl->epoch)
+            || !WPACKET_memcpy(thispkt, &(rl->sequence[2]), 6)
+            || !WPACKET_start_sub_packet_u16(thispkt)
+            || (rl->eivlen > 0
+                && !WPACKET_allocate_bytes(thispkt, rl->eivlen, NULL))
+            || (maxcomplen > 0
+                && !WPACKET_reserve_bytes(thispkt, maxcomplen,
+                                          recdata))) {
+        RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
+
+    return 1;
+}
+
 int dtls_write_records(OSSL_RECORD_LAYER *rl, OSSL_RECORD_TEMPLATE *templates,
                        size_t numtempl)
 {
-    /* TODO(RECLAYER): Remove me */
-    SSL_CONNECTION *sc = (SSL_CONNECTION *)rl->cbarg;
-    unsigned char *p, *pseq;
-    int mac_size, clear = 0;
-    int eivlen;
+    int mac_size = 0;
     SSL3_RECORD wr;
     SSL3_BUFFER *wb;
-    SSL_SESSION *sess;
-    SSL *s = SSL_CONNECTION_GET_SSL(sc);
     WPACKET pkt, *thispkt = &pkt;
     size_t wpinited = 0;
     int ret = 0;
+    unsigned char *compressdata = NULL;
 
-    sess = sc->session;
-
-    if ((sess == NULL)
-            || (sc->enc_write_ctx == NULL)
-            || (EVP_MD_CTX_get0_md(sc->write_hash) == NULL))
-        clear = 1;
-
-    if (clear)
-        mac_size = 0;
-    else {
-        mac_size = EVP_MD_CTX_get_size(sc->write_hash);
+    if (rl->md_ctx != NULL && EVP_MD_CTX_get0_md(rl->md_ctx) != NULL) {
+        mac_size = EVP_MD_CTX_get_size(rl->md_ctx);
         if (mac_size < 0) {
-            RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR,
-                        SSL_R_EXCEEDS_MAX_FRAGMENT_SIZE);
+            RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
             return 0;
         }
     }
@@ -741,45 +757,19 @@ int dtls_write_records(OSSL_RECORD_LAYER *rl, OSSL_RECORD_TEMPLATE *templates,
         return 0;
     }
 
-
     wb = rl->wbuf;
-    p = SSL3_BUFFER_get_buf(wb);
 
-    /* write the header */
-
-    *(p++) = templates->type & 0xff;
     SSL3_RECORD_set_type(&wr, templates->type);
-    *(p++) = templates->version >> 8;
-    *(p++) = templates->version & 0xff;
-
-    /* field where we are to write out packet epoch, seq num and len */
-    pseq = p;
-    p += 10;
-
-    /* Explicit IV length, block ciphers appropriate version flag */
-    if (sc->enc_write_ctx) {
-        int mode = EVP_CIPHER_CTX_get_mode(sc->enc_write_ctx);
-        if (mode == EVP_CIPH_CBC_MODE) {
-            eivlen = EVP_CIPHER_CTX_get_iv_length(sc->enc_write_ctx);
-            if (eivlen < 0) {
-                RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, SSL_R_LIBRARY_BUG);
-                goto err;
-            }
-            if (eivlen <= 1)
-                eivlen = 0;
-        }
-        /* Need explicit part of IV for GCM mode */
-        else if (mode == EVP_CIPH_GCM_MODE)
-            eivlen = EVP_GCM_TLS_EXPLICIT_IV_LEN;
-        else if (mode == EVP_CIPH_CCM_MODE)
-            eivlen = EVP_CCM_TLS_EXPLICIT_IV_LEN;
-        else
-            eivlen = 0;
-    } else
-        eivlen = 0;
+    SSL3_RECORD_set_rec_version(&wr, templates->version);
+
+    if (!rl->funcs->prepare_record_header(rl, thispkt, templates,
+                                          templates->type, &compressdata)) {
+        /* RLAYERfatal() already called */
+        goto err;
+    }
 
     /* lets setup the record stuff. */
-    SSL3_RECORD_set_data(&wr, p + eivlen); /* make room for IV in case of CBC */
+    SSL3_RECORD_set_data(&wr, compressdata);
     SSL3_RECORD_set_length(&wr, templates->buflen);
     SSL3_RECORD_set_input(&wr, (unsigned char *)templates->buf);
 
@@ -788,91 +778,43 @@ int dtls_write_records(OSSL_RECORD_LAYER *rl, OSSL_RECORD_TEMPLATE *templates,
      */
 
     /* first we compress */
-    if (sc->compress != NULL) {
-        if (!ssl3_do_compress(sc, &wr)) {
+    if (rl->compctx != NULL) {
+        if (!tls_do_compress(rl, &wr)
+                || !WPACKET_allocate_bytes(thispkt, wr.length, NULL)) {
             RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, SSL_R_COMPRESSION_FAILURE);
             goto err;
         }
-    } else {
-        memcpy(SSL3_RECORD_get_data(&wr), SSL3_RECORD_get_input(&wr),
-               SSL3_RECORD_get_length(&wr));
-        SSL3_RECORD_reset_input(&wr);
-    }
-
-    /*
-     * we should still have the output to wr.data and the input from
-     * wr.input.  Length should be wr.length. wr.data still points in the
-     * wb->buf
-     */
-
-    if (!SSL_WRITE_ETM(sc) && mac_size != 0) {
-        if (!s->method->ssl3_enc->mac(sc, &wr,
-                                      &(p[SSL3_RECORD_get_length(&wr) + eivlen]),
-                                      1)) {
+    } else if (compressdata != NULL) {
+        if (!WPACKET_memcpy(thispkt, wr.input, wr.length)) {
             RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
             goto err;
         }
-        SSL3_RECORD_add_length(&wr, mac_size);
+        SSL3_RECORD_reset_input(&wr);
     }
 
-    /* this is true regardless of mac size */
-    SSL3_RECORD_set_data(&wr, p);
-    SSL3_RECORD_reset_input(&wr);
-
-    if (eivlen)
-        SSL3_RECORD_add_length(&wr, eivlen);
-
-    if (s->method->ssl3_enc->enc(sc, &wr, 1, 1, NULL, mac_size) < 1) {
-        if (!ossl_statem_in_error(sc)) {
-            RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
-        }
+    if (!rl->funcs->prepare_for_encryption(rl, mac_size, thispkt, &wr)) {
+        /* RLAYERfatal() already called */
         goto err;
     }
 
-    if (SSL_WRITE_ETM(sc) && mac_size != 0) {
-        if (!s->method->ssl3_enc->mac(sc, &wr,
-                                      &(p[SSL3_RECORD_get_length(&wr)]), 1)) {
+    if (rl->funcs->cipher(rl, &wr, 1, 1, NULL, mac_size) < 1) {
+        if (rl->alert == SSL_AD_NO_ALERT) {
             RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
-            goto err;
         }
-        SSL3_RECORD_add_length(&wr, mac_size);
+        goto err;
     }
 
-    /* record length after mac and block padding */
-
-    /* there's only one epoch between handshake and app data */
-
-    s2n(sc->rlayer.d->w_epoch, pseq);
-
-    memcpy(pseq, &(sc->rlayer.write_sequence[2]), 6);
-    pseq += 6;
-    s2n(SSL3_RECORD_get_length(&wr), pseq);
-
-    if (sc->msg_callback)
-        sc->msg_callback(1, 0, SSL3_RT_HEADER, pseq - DTLS1_RT_HEADER_LENGTH,
-                         DTLS1_RT_HEADER_LENGTH, s, sc->msg_callback_arg);
-
-    /*
-     * we should now have wr.data pointing to the encrypted data, which is
-     * wr->length long
-     */
-    SSL3_RECORD_set_type(&wr, templates->type); /* not needed but helps for debugging */
-    SSL3_RECORD_add_length(&wr, DTLS1_RT_HEADER_LENGTH);
+    if (!rl->funcs->post_encryption_processing(rl, mac_size, templates,
+                                               thispkt, &wr)) {
+        /* RLAYERfatal() already called */
+        goto err;
+    }
 
-    ssl3_record_sequence_update(&(sc->rlayer.write_sequence[0]));
+    /* TODO(RECLAYER): FIXME */
+    ssl3_record_sequence_update(rl->sequence);
 
     /* now let's set up wb */
     SSL3_BUFFER_set_left(wb, SSL3_RECORD_get_length(&wr));
-    SSL3_BUFFER_set_offset(wb, 0);
-
-    /*
-     * memorize arguments so that ssl3_write_pending can detect bad write
-     * retries later
-     */
-    sc->rlayer.wpend_tot = templates->buflen;
-    sc->rlayer.wpend_buf = templates->buf;
-    sc->rlayer.wpend_type = templates->type;
-    sc->rlayer.wpend_ret = templates->buflen;
 
     ret = 1;
  err:
index b9ce61e4efaaa55eb906f2b5b4e106b635c24ab7..2ee6c2e7531f9af2e7883e90ea204e7ec125ed90 100644 (file)
@@ -349,11 +349,17 @@ int tls_default_read_n(OSSL_RECORD_LAYER *rl, size_t n, size_t max, int extend,
 int tls_get_more_records(OSSL_RECORD_LAYER *rl);
 int dtls_get_more_records(OSSL_RECORD_LAYER *rl);
 
+int dtls_prepare_record_header(OSSL_RECORD_LAYER *rl,
+                               WPACKET *thispkt,
+                               OSSL_RECORD_TEMPLATE *templ,
+                               unsigned int rectype,
+                               unsigned char **recdata);
 int dtls_write_records(OSSL_RECORD_LAYER *rl, OSSL_RECORD_TEMPLATE *templates,
                        size_t numtempl);
 
 int tls_default_set_protocol_version(OSSL_RECORD_LAYER *rl, int version);
 int tls_default_validate_record_header(OSSL_RECORD_LAYER *rl, SSL3_RECORD *re);
+int tls_do_compress(OSSL_RECORD_LAYER *rl, SSL3_RECORD *wr);
 int tls_do_uncompress(OSSL_RECORD_LAYER *rl, SSL3_RECORD *rec);
 int tls_default_post_process_record(OSSL_RECORD_LAYER *rl, SSL3_RECORD *rec);
 int tls13_common_post_process_record(OSSL_RECORD_LAYER *rl, SSL3_RECORD *rec);
index 5f6ff3f806b1b54268d121ee79d4bf49dcf79015..166ee548eb910c7e344cf686ab8584bafc734058 100644 (file)
@@ -683,8 +683,9 @@ struct record_functions_st dtls_1_funcs = {
     /* Don't use tls1_initialise_write_packets for same reason as above */
     tls_initialise_write_packets_default,
     NULL,
+    dtls_prepare_record_header,
     NULL,
-    NULL,
-    NULL,
+    tls_prepare_for_encryption_default,
+    tls_post_encryption_processing_default,
     NULL
 };
index dd497ed1de98b8a9b76a8b474336509a685ba48c..238684a77b25f4063eb9483620b7a7e216286675 100644 (file)
@@ -962,7 +962,7 @@ int tls_default_validate_record_header(OSSL_RECORD_LAYER *rl, SSL3_RECORD *rec)
     return 1;
 }
 
-static int tls_do_compress(OSSL_RECORD_LAYER *rl, SSL3_RECORD *wr)
+int tls_do_compress(OSSL_RECORD_LAYER *rl, SSL3_RECORD *wr)
 {
 #ifndef OPENSSL_NO_COMP
     int i;
@@ -1514,7 +1514,8 @@ int tls_initialise_write_packets_default(OSSL_RECORD_LAYER *rl,
         wb->type = templates[j].type;
 
 #if defined(SSL3_ALIGN_PAYLOAD) && SSL3_ALIGN_PAYLOAD != 0
-        align = (size_t)SSL3_BUFFER_get_buf(wb) + SSL3_RT_HEADER_LENGTH;
+        align = (size_t)SSL3_BUFFER_get_buf(wb);
+        align += rl->isdtls ? DTLS1_RT_HEADER_LENGTH : SSL3_RT_HEADER_LENGTH;
         align = SSL3_ALIGN_PAYLOAD - 1
                 - ((align - 1) % SSL3_ALIGN_PAYLOAD);
 #endif
@@ -1621,6 +1622,8 @@ int tls_post_encryption_processing_default(OSSL_RECORD_LAYER *rl,
                                            SSL3_RECORD *thiswr)
 {
     size_t origlen, len;
+    size_t headerlen = rl->isdtls ? DTLS1_RT_HEADER_LENGTH
+                                  : SSL3_RT_HEADER_LENGTH;
 
     /* Allocate bytes for the encryption overhead */
     if (!WPACKET_get_length(thispkt, &origlen)
@@ -1654,9 +1657,9 @@ int tls_post_encryption_processing_default(OSSL_RECORD_LAYER *rl,
     if (rl->msg_callback != NULL) {
         unsigned char *recordstart;
 
-        recordstart = WPACKET_get_curr(thispkt) - len - SSL3_RT_HEADER_LENGTH;
+        recordstart = WPACKET_get_curr(thispkt) - len - headerlen;
         rl->msg_callback(1, thiswr->rec_version, SSL3_RT_HEADER, recordstart,
-                         SSL3_RT_HEADER_LENGTH, rl->cbarg);
+                         headerlen, rl->cbarg);
 
         if (rl->version == TLS1_3_VERSION && rl->enc_ctx != NULL) {
             unsigned char ctype = thistempl->type;
@@ -1671,7 +1674,7 @@ int tls_post_encryption_processing_default(OSSL_RECORD_LAYER *rl,
         return 0;
     }
 
-    SSL3_RECORD_add_length(thiswr, SSL3_RT_HEADER_LENGTH);
+    SSL3_RECORD_add_length(thiswr, headerlen);
 
     return 1;
 }
index 4cdb0e8ca67afa556653467544f72e900ead8b9f..ff08c11d0dd8eacf9ade7f481697cda44bc18795 100644 (file)
@@ -187,9 +187,9 @@ struct record_functions_st dtls_any_funcs = {
     tls_allocate_write_buffers_default,
     tls_initialise_write_packets_default,
     NULL,
+    dtls_prepare_record_header,
     NULL,
-    NULL,
-    NULL,
-    NULL,
+    tls_prepare_for_encryption_default,
+    tls_post_encryption_processing_default,
     NULL
 };
index aa81d589b552be5cb3f41f78da5f6b3b01ddb35c..04f130bc2eba140185c0c5205945f124bda295bd 100644 (file)
@@ -1267,6 +1267,10 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
                 return 0;
             }
             s->rlayer.rrlnext = next;
+        } else {
+            if (SSL_CONNECTION_IS_DTLS(s)
+                    && level != OSSL_RECORD_PROTECTION_LEVEL_NONE)
+                epoch =  DTLS_RECORD_LAYER_get_w_epoch(&s->rlayer) + 1; /* new epoch */
         }
 
         /*
@@ -1325,9 +1329,17 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
         break;
     }
 
-    if (*thismethod != NULL && !(*thismethod)->free(*thisrl)) {
-        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
-        return 0;
+    /*
+     * Free the old record layer if we have one except in the case of DTLS when
+     * writing. In that case the record layer is still referenced by buffered
+     * messages for potential retransmit. Only when those buffered messages get
+     * freed do we free the record layer object (see dtls1_hm_fragment_free)
+     */
+    if (!SSL_CONNECTION_IS_DTLS(s) || direction == OSSL_RECORD_DIRECTION_READ) {
+        if (*thismethod != NULL && !(*thismethod)->free(*thisrl)) {
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+            return 0;
+        }
     }
 
     *thisrl = newrl;
index fd21d0be82ec3d62ed221a3db7d09bafaa02305f..a1f15a712fa3611847c24dcf34ef934e9e87bef9 100644 (file)
@@ -1915,11 +1915,8 @@ typedef struct {
 # define DTLS1_SKIP_RECORD_HEADER                 2
 
 struct dtls1_retransmit_state {
-    EVP_CIPHER_CTX *enc_write_ctx; /* cryptographic state */
-    EVP_MD_CTX *write_hash;     /* used for mac generation */
-    COMP_CTX *compress;         /* compression */
-    SSL_SESSION *session;
-    uint16_t epoch;
+    const OSSL_RECORD_METHOD *wrlmethod;
+    OSSL_RECORD_LAYER *wrl;
 };
 
 struct hm_header_st {
index 93c49011a2ca9b09b8ba837f98ea9dfa9568697e..b673c860abc2c5c3525d5d1db680f958109f7812 100644 (file)
@@ -94,9 +94,12 @@ void dtls1_hm_fragment_free(hm_fragment *frag)
     if (!frag)
         return;
     if (frag->msg_header.is_ccs) {
-        EVP_CIPHER_CTX_free(frag->msg_header.
-                            saved_retransmit_state.enc_write_ctx);
-        EVP_MD_CTX_free(frag->msg_header.saved_retransmit_state.write_hash);
+        /*
+         * If we're freeing the CCS then we're done with the old wrl and it
+         * can bee freed
+         */
+        if (frag->msg_header.saved_retransmit_state.wrlmethod != NULL)
+            frag->msg_header.saved_retransmit_state.wrlmethod->free(frag->msg_header.saved_retransmit_state.wrl);
     }
     OPENSSL_free(frag->fragment);
     OPENSSL_free(frag->reassembly);
@@ -1161,12 +1164,9 @@ int dtls1_buffer_message(SSL_CONNECTION *s, int is_ccs)
     frag->msg_header.is_ccs = is_ccs;
 
     /* save current state */
-    frag->msg_header.saved_retransmit_state.enc_write_ctx = s->enc_write_ctx;
-    frag->msg_header.saved_retransmit_state.write_hash = s->write_hash;
-    frag->msg_header.saved_retransmit_state.compress = s->compress;
-    frag->msg_header.saved_retransmit_state.session = s->session;
-    frag->msg_header.saved_retransmit_state.epoch =
-        DTLS_RECORD_LAYER_get_w_epoch(&s->rlayer);
+    frag->msg_header.saved_retransmit_state.wrlmethod = s->rlayer.wrlmethod;
+    frag->msg_header.saved_retransmit_state.wrl = s->rlayer.wrl;
+
 
     memset(seq64be, 0, sizeof(seq64be));
     seq64be[6] =
@@ -1228,32 +1228,27 @@ int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq, int *found)
                                  frag->msg_header.frag_len);
 
     /* save current state */
-    saved_state.enc_write_ctx = s->enc_write_ctx;
-    saved_state.write_hash = s->write_hash;
-    saved_state.compress = s->compress;
-    saved_state.session = s->session;
-    saved_state.epoch = DTLS_RECORD_LAYER_get_w_epoch(&s->rlayer);
+    saved_state.wrlmethod = s->rlayer.wrlmethod;
+    saved_state.wrl = s->rlayer.wrl;
 
     s->d1->retransmitting = 1;
 
     /* restore state in which the message was originally sent */
-    s->enc_write_ctx = frag->msg_header.saved_retransmit_state.enc_write_ctx;
-    s->write_hash = frag->msg_header.saved_retransmit_state.write_hash;
-    s->compress = frag->msg_header.saved_retransmit_state.compress;
-    s->session = frag->msg_header.saved_retransmit_state.session;
-    DTLS_RECORD_LAYER_set_saved_w_epoch(&s->rlayer,
-                                        frag->msg_header.
-                                        saved_retransmit_state.epoch);
+    s->rlayer.wrlmethod = frag->msg_header.saved_retransmit_state.wrlmethod;
+    s->rlayer.wrl = frag->msg_header.saved_retransmit_state.wrl;
+
+    /*
+     * The old wrl may be still pointing at an old BIO. Update it to what we're
+     * using now.
+     */
+    s->rlayer.wrlmethod->set1_bio(s->rlayer.wrl, s->wbio);
 
     ret = dtls1_do_write(s, frag->msg_header.is_ccs ?
                          SSL3_RT_CHANGE_CIPHER_SPEC : SSL3_RT_HANDSHAKE);
 
     /* restore current state */
-    s->enc_write_ctx = saved_state.enc_write_ctx;
-    s->write_hash = saved_state.write_hash;
-    s->compress = saved_state.compress;
-    s->session = saved_state.session;
-    DTLS_RECORD_LAYER_set_saved_w_epoch(&s->rlayer, saved_state.epoch);
+    s->rlayer.wrlmethod = saved_state.wrlmethod;
+    s->rlayer.wrl = saved_state.wrl;
 
     s->d1->retransmitting = 0;