From b9e37f8f573de1951655f6d8684f2f65ffc6905b Mon Sep 17 00:00:00 2001 From: Matt Caswell Date: Thu, 13 Oct 2022 16:44:22 +0100 Subject: [PATCH] Convert dtls_write_records to use standard record layer functions We have standard functions for most of the work that dtls_write_records does - so we convert it to use those functions instead. Reviewed-by: Richard Levitte Reviewed-by: Tomas Mraz Reviewed-by: Hugo Landau (Merged from https://github.com/openssl/openssl/pull/19424) --- ssl/record/methods/dtls_meth.c | 182 +++++++++------------------ ssl/record/methods/recmethod_local.h | 6 + ssl/record/methods/tls1_meth.c | 5 +- ssl/record/methods/tls_common.c | 13 +- ssl/record/methods/tlsany_meth.c | 6 +- ssl/record/rec_layer_s3.c | 18 ++- ssl/ssl_local.h | 7 +- ssl/statem/statem_dtls.c | 47 ++++--- 8 files changed, 120 insertions(+), 164 deletions(-) diff --git a/ssl/record/methods/dtls_meth.c b/ssl/record/methods/dtls_meth.c index 1b51c84893..e6c71ed1e7 100644 --- a/ssl/record/methods/dtls_meth.c +++ b/ssl/record/methods/dtls_meth.c @@ -689,36 +689,52 @@ dtls_new_record_layer(OSSL_LIB_CTX *libctx, const char *propq, int vers, return ret; } +int dtls_prepare_record_header(OSSL_RECORD_LAYER *rl, + WPACKET *thispkt, + OSSL_RECORD_TEMPLATE *templ, + unsigned int rectype, + unsigned char **recdata) +{ + size_t maxcomplen; + + *recdata = NULL; + + maxcomplen = templ->buflen; + if (rl->compctx != NULL) + maxcomplen += SSL3_RT_MAX_COMPRESSED_OVERHEAD; + + if (!WPACKET_put_bytes_u8(thispkt, rectype) + || !WPACKET_put_bytes_u16(thispkt, templ->version) + || !WPACKET_put_bytes_u16(thispkt, rl->epoch) + || !WPACKET_memcpy(thispkt, &(rl->sequence[2]), 6) + || !WPACKET_start_sub_packet_u16(thispkt) + || (rl->eivlen > 0 + && !WPACKET_allocate_bytes(thispkt, rl->eivlen, NULL)) + || (maxcomplen > 0 + && !WPACKET_reserve_bytes(thispkt, maxcomplen, + recdata))) { + RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR); + return 0; + } + + return 1; +} + int dtls_write_records(OSSL_RECORD_LAYER *rl, OSSL_RECORD_TEMPLATE *templates, size_t numtempl) { - /* TODO(RECLAYER): Remove me */ - SSL_CONNECTION *sc = (SSL_CONNECTION *)rl->cbarg; - unsigned char *p, *pseq; - int mac_size, clear = 0; - int eivlen; + int mac_size = 0; SSL3_RECORD wr; SSL3_BUFFER *wb; - SSL_SESSION *sess; - SSL *s = SSL_CONNECTION_GET_SSL(sc); WPACKET pkt, *thispkt = &pkt; size_t wpinited = 0; int ret = 0; + unsigned char *compressdata = NULL; - sess = sc->session; - - if ((sess == NULL) - || (sc->enc_write_ctx == NULL) - || (EVP_MD_CTX_get0_md(sc->write_hash) == NULL)) - clear = 1; - - if (clear) - mac_size = 0; - else { - mac_size = EVP_MD_CTX_get_size(sc->write_hash); + if (rl->md_ctx != NULL && EVP_MD_CTX_get0_md(rl->md_ctx) != NULL) { + mac_size = EVP_MD_CTX_get_size(rl->md_ctx); if (mac_size < 0) { - RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, - SSL_R_EXCEEDS_MAX_FRAGMENT_SIZE); + RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR); return 0; } } @@ -741,45 +757,19 @@ int dtls_write_records(OSSL_RECORD_LAYER *rl, OSSL_RECORD_TEMPLATE *templates, return 0; } - wb = rl->wbuf; - p = SSL3_BUFFER_get_buf(wb); - /* write the header */ - - *(p++) = templates->type & 0xff; SSL3_RECORD_set_type(&wr, templates->type); - *(p++) = templates->version >> 8; - *(p++) = templates->version & 0xff; - - /* field where we are to write out packet epoch, seq num and len */ - pseq = p; - p += 10; - - /* Explicit IV length, block ciphers appropriate version flag */ - if (sc->enc_write_ctx) { - int mode = EVP_CIPHER_CTX_get_mode(sc->enc_write_ctx); - if (mode == EVP_CIPH_CBC_MODE) { - eivlen = EVP_CIPHER_CTX_get_iv_length(sc->enc_write_ctx); - if (eivlen < 0) { - RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, SSL_R_LIBRARY_BUG); - goto err; - } - if (eivlen <= 1) - eivlen = 0; - } - /* Need explicit part of IV for GCM mode */ - else if (mode == EVP_CIPH_GCM_MODE) - eivlen = EVP_GCM_TLS_EXPLICIT_IV_LEN; - else if (mode == EVP_CIPH_CCM_MODE) - eivlen = EVP_CCM_TLS_EXPLICIT_IV_LEN; - else - eivlen = 0; - } else - eivlen = 0; + SSL3_RECORD_set_rec_version(&wr, templates->version); + + if (!rl->funcs->prepare_record_header(rl, thispkt, templates, + templates->type, &compressdata)) { + /* RLAYERfatal() already called */ + goto err; + } /* lets setup the record stuff. */ - SSL3_RECORD_set_data(&wr, p + eivlen); /* make room for IV in case of CBC */ + SSL3_RECORD_set_data(&wr, compressdata); SSL3_RECORD_set_length(&wr, templates->buflen); SSL3_RECORD_set_input(&wr, (unsigned char *)templates->buf); @@ -788,91 +778,43 @@ int dtls_write_records(OSSL_RECORD_LAYER *rl, OSSL_RECORD_TEMPLATE *templates, */ /* first we compress */ - if (sc->compress != NULL) { - if (!ssl3_do_compress(sc, &wr)) { + if (rl->compctx != NULL) { + if (!tls_do_compress(rl, &wr) + || !WPACKET_allocate_bytes(thispkt, wr.length, NULL)) { RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, SSL_R_COMPRESSION_FAILURE); goto err; } - } else { - memcpy(SSL3_RECORD_get_data(&wr), SSL3_RECORD_get_input(&wr), - SSL3_RECORD_get_length(&wr)); - SSL3_RECORD_reset_input(&wr); - } - - /* - * we should still have the output to wr.data and the input from - * wr.input. Length should be wr.length. wr.data still points in the - * wb->buf - */ - - if (!SSL_WRITE_ETM(sc) && mac_size != 0) { - if (!s->method->ssl3_enc->mac(sc, &wr, - &(p[SSL3_RECORD_get_length(&wr) + eivlen]), - 1)) { + } else if (compressdata != NULL) { + if (!WPACKET_memcpy(thispkt, wr.input, wr.length)) { RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR); goto err; } - SSL3_RECORD_add_length(&wr, mac_size); + SSL3_RECORD_reset_input(&wr); } - /* this is true regardless of mac size */ - SSL3_RECORD_set_data(&wr, p); - SSL3_RECORD_reset_input(&wr); - - if (eivlen) - SSL3_RECORD_add_length(&wr, eivlen); - - if (s->method->ssl3_enc->enc(sc, &wr, 1, 1, NULL, mac_size) < 1) { - if (!ossl_statem_in_error(sc)) { - RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR); - } + if (!rl->funcs->prepare_for_encryption(rl, mac_size, thispkt, &wr)) { + /* RLAYERfatal() already called */ goto err; } - if (SSL_WRITE_ETM(sc) && mac_size != 0) { - if (!s->method->ssl3_enc->mac(sc, &wr, - &(p[SSL3_RECORD_get_length(&wr)]), 1)) { + if (rl->funcs->cipher(rl, &wr, 1, 1, NULL, mac_size) < 1) { + if (rl->alert == SSL_AD_NO_ALERT) { RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR); - goto err; } - SSL3_RECORD_add_length(&wr, mac_size); + goto err; } - /* record length after mac and block padding */ - - /* there's only one epoch between handshake and app data */ - - s2n(sc->rlayer.d->w_epoch, pseq); - - memcpy(pseq, &(sc->rlayer.write_sequence[2]), 6); - pseq += 6; - s2n(SSL3_RECORD_get_length(&wr), pseq); - - if (sc->msg_callback) - sc->msg_callback(1, 0, SSL3_RT_HEADER, pseq - DTLS1_RT_HEADER_LENGTH, - DTLS1_RT_HEADER_LENGTH, s, sc->msg_callback_arg); - - /* - * we should now have wr.data pointing to the encrypted data, which is - * wr->length long - */ - SSL3_RECORD_set_type(&wr, templates->type); /* not needed but helps for debugging */ - SSL3_RECORD_add_length(&wr, DTLS1_RT_HEADER_LENGTH); + if (!rl->funcs->post_encryption_processing(rl, mac_size, templates, + thispkt, &wr)) { + /* RLAYERfatal() already called */ + goto err; + } - ssl3_record_sequence_update(&(sc->rlayer.write_sequence[0])); + /* TODO(RECLAYER): FIXME */ + ssl3_record_sequence_update(rl->sequence); /* now let's set up wb */ SSL3_BUFFER_set_left(wb, SSL3_RECORD_get_length(&wr)); - SSL3_BUFFER_set_offset(wb, 0); - - /* - * memorize arguments so that ssl3_write_pending can detect bad write - * retries later - */ - sc->rlayer.wpend_tot = templates->buflen; - sc->rlayer.wpend_buf = templates->buf; - sc->rlayer.wpend_type = templates->type; - sc->rlayer.wpend_ret = templates->buflen; ret = 1; err: diff --git a/ssl/record/methods/recmethod_local.h b/ssl/record/methods/recmethod_local.h index b9ce61e4ef..2ee6c2e753 100644 --- a/ssl/record/methods/recmethod_local.h +++ b/ssl/record/methods/recmethod_local.h @@ -349,11 +349,17 @@ int tls_default_read_n(OSSL_RECORD_LAYER *rl, size_t n, size_t max, int extend, int tls_get_more_records(OSSL_RECORD_LAYER *rl); int dtls_get_more_records(OSSL_RECORD_LAYER *rl); +int dtls_prepare_record_header(OSSL_RECORD_LAYER *rl, + WPACKET *thispkt, + OSSL_RECORD_TEMPLATE *templ, + unsigned int rectype, + unsigned char **recdata); int dtls_write_records(OSSL_RECORD_LAYER *rl, OSSL_RECORD_TEMPLATE *templates, size_t numtempl); int tls_default_set_protocol_version(OSSL_RECORD_LAYER *rl, int version); int tls_default_validate_record_header(OSSL_RECORD_LAYER *rl, SSL3_RECORD *re); +int tls_do_compress(OSSL_RECORD_LAYER *rl, SSL3_RECORD *wr); int tls_do_uncompress(OSSL_RECORD_LAYER *rl, SSL3_RECORD *rec); int tls_default_post_process_record(OSSL_RECORD_LAYER *rl, SSL3_RECORD *rec); int tls13_common_post_process_record(OSSL_RECORD_LAYER *rl, SSL3_RECORD *rec); diff --git a/ssl/record/methods/tls1_meth.c b/ssl/record/methods/tls1_meth.c index 5f6ff3f806..166ee548eb 100644 --- a/ssl/record/methods/tls1_meth.c +++ b/ssl/record/methods/tls1_meth.c @@ -683,8 +683,9 @@ struct record_functions_st dtls_1_funcs = { /* Don't use tls1_initialise_write_packets for same reason as above */ tls_initialise_write_packets_default, NULL, + dtls_prepare_record_header, NULL, - NULL, - NULL, + tls_prepare_for_encryption_default, + tls_post_encryption_processing_default, NULL }; diff --git a/ssl/record/methods/tls_common.c b/ssl/record/methods/tls_common.c index dd497ed1de..238684a77b 100644 --- a/ssl/record/methods/tls_common.c +++ b/ssl/record/methods/tls_common.c @@ -962,7 +962,7 @@ int tls_default_validate_record_header(OSSL_RECORD_LAYER *rl, SSL3_RECORD *rec) return 1; } -static int tls_do_compress(OSSL_RECORD_LAYER *rl, SSL3_RECORD *wr) +int tls_do_compress(OSSL_RECORD_LAYER *rl, SSL3_RECORD *wr) { #ifndef OPENSSL_NO_COMP int i; @@ -1514,7 +1514,8 @@ int tls_initialise_write_packets_default(OSSL_RECORD_LAYER *rl, wb->type = templates[j].type; #if defined(SSL3_ALIGN_PAYLOAD) && SSL3_ALIGN_PAYLOAD != 0 - align = (size_t)SSL3_BUFFER_get_buf(wb) + SSL3_RT_HEADER_LENGTH; + align = (size_t)SSL3_BUFFER_get_buf(wb); + align += rl->isdtls ? DTLS1_RT_HEADER_LENGTH : SSL3_RT_HEADER_LENGTH; align = SSL3_ALIGN_PAYLOAD - 1 - ((align - 1) % SSL3_ALIGN_PAYLOAD); #endif @@ -1621,6 +1622,8 @@ int tls_post_encryption_processing_default(OSSL_RECORD_LAYER *rl, SSL3_RECORD *thiswr) { size_t origlen, len; + size_t headerlen = rl->isdtls ? DTLS1_RT_HEADER_LENGTH + : SSL3_RT_HEADER_LENGTH; /* Allocate bytes for the encryption overhead */ if (!WPACKET_get_length(thispkt, &origlen) @@ -1654,9 +1657,9 @@ int tls_post_encryption_processing_default(OSSL_RECORD_LAYER *rl, if (rl->msg_callback != NULL) { unsigned char *recordstart; - recordstart = WPACKET_get_curr(thispkt) - len - SSL3_RT_HEADER_LENGTH; + recordstart = WPACKET_get_curr(thispkt) - len - headerlen; rl->msg_callback(1, thiswr->rec_version, SSL3_RT_HEADER, recordstart, - SSL3_RT_HEADER_LENGTH, rl->cbarg); + headerlen, rl->cbarg); if (rl->version == TLS1_3_VERSION && rl->enc_ctx != NULL) { unsigned char ctype = thistempl->type; @@ -1671,7 +1674,7 @@ int tls_post_encryption_processing_default(OSSL_RECORD_LAYER *rl, return 0; } - SSL3_RECORD_add_length(thiswr, SSL3_RT_HEADER_LENGTH); + SSL3_RECORD_add_length(thiswr, headerlen); return 1; } diff --git a/ssl/record/methods/tlsany_meth.c b/ssl/record/methods/tlsany_meth.c index 4cdb0e8ca6..ff08c11d0d 100644 --- a/ssl/record/methods/tlsany_meth.c +++ b/ssl/record/methods/tlsany_meth.c @@ -187,9 +187,9 @@ struct record_functions_st dtls_any_funcs = { tls_allocate_write_buffers_default, tls_initialise_write_packets_default, NULL, + dtls_prepare_record_header, NULL, - NULL, - NULL, - NULL, + tls_prepare_for_encryption_default, + tls_post_encryption_processing_default, NULL }; diff --git a/ssl/record/rec_layer_s3.c b/ssl/record/rec_layer_s3.c index aa81d589b5..04f130bc2e 100644 --- a/ssl/record/rec_layer_s3.c +++ b/ssl/record/rec_layer_s3.c @@ -1267,6 +1267,10 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version, return 0; } s->rlayer.rrlnext = next; + } else { + if (SSL_CONNECTION_IS_DTLS(s) + && level != OSSL_RECORD_PROTECTION_LEVEL_NONE) + epoch = DTLS_RECORD_LAYER_get_w_epoch(&s->rlayer) + 1; /* new epoch */ } /* @@ -1325,9 +1329,17 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version, break; } - if (*thismethod != NULL && !(*thismethod)->free(*thisrl)) { - SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR); - return 0; + /* + * Free the old record layer if we have one except in the case of DTLS when + * writing. In that case the record layer is still referenced by buffered + * messages for potential retransmit. Only when those buffered messages get + * freed do we free the record layer object (see dtls1_hm_fragment_free) + */ + if (!SSL_CONNECTION_IS_DTLS(s) || direction == OSSL_RECORD_DIRECTION_READ) { + if (*thismethod != NULL && !(*thismethod)->free(*thisrl)) { + SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR); + return 0; + } } *thisrl = newrl; diff --git a/ssl/ssl_local.h b/ssl/ssl_local.h index fd21d0be82..a1f15a712f 100644 --- a/ssl/ssl_local.h +++ b/ssl/ssl_local.h @@ -1915,11 +1915,8 @@ typedef struct { # define DTLS1_SKIP_RECORD_HEADER 2 struct dtls1_retransmit_state { - EVP_CIPHER_CTX *enc_write_ctx; /* cryptographic state */ - EVP_MD_CTX *write_hash; /* used for mac generation */ - COMP_CTX *compress; /* compression */ - SSL_SESSION *session; - uint16_t epoch; + const OSSL_RECORD_METHOD *wrlmethod; + OSSL_RECORD_LAYER *wrl; }; struct hm_header_st { diff --git a/ssl/statem/statem_dtls.c b/ssl/statem/statem_dtls.c index 93c49011a2..b673c860ab 100644 --- a/ssl/statem/statem_dtls.c +++ b/ssl/statem/statem_dtls.c @@ -94,9 +94,12 @@ void dtls1_hm_fragment_free(hm_fragment *frag) if (!frag) return; if (frag->msg_header.is_ccs) { - EVP_CIPHER_CTX_free(frag->msg_header. - saved_retransmit_state.enc_write_ctx); - EVP_MD_CTX_free(frag->msg_header.saved_retransmit_state.write_hash); + /* + * If we're freeing the CCS then we're done with the old wrl and it + * can bee freed + */ + if (frag->msg_header.saved_retransmit_state.wrlmethod != NULL) + frag->msg_header.saved_retransmit_state.wrlmethod->free(frag->msg_header.saved_retransmit_state.wrl); } OPENSSL_free(frag->fragment); OPENSSL_free(frag->reassembly); @@ -1161,12 +1164,9 @@ int dtls1_buffer_message(SSL_CONNECTION *s, int is_ccs) frag->msg_header.is_ccs = is_ccs; /* save current state */ - frag->msg_header.saved_retransmit_state.enc_write_ctx = s->enc_write_ctx; - frag->msg_header.saved_retransmit_state.write_hash = s->write_hash; - frag->msg_header.saved_retransmit_state.compress = s->compress; - frag->msg_header.saved_retransmit_state.session = s->session; - frag->msg_header.saved_retransmit_state.epoch = - DTLS_RECORD_LAYER_get_w_epoch(&s->rlayer); + frag->msg_header.saved_retransmit_state.wrlmethod = s->rlayer.wrlmethod; + frag->msg_header.saved_retransmit_state.wrl = s->rlayer.wrl; + memset(seq64be, 0, sizeof(seq64be)); seq64be[6] = @@ -1228,32 +1228,27 @@ int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq, int *found) frag->msg_header.frag_len); /* save current state */ - saved_state.enc_write_ctx = s->enc_write_ctx; - saved_state.write_hash = s->write_hash; - saved_state.compress = s->compress; - saved_state.session = s->session; - saved_state.epoch = DTLS_RECORD_LAYER_get_w_epoch(&s->rlayer); + saved_state.wrlmethod = s->rlayer.wrlmethod; + saved_state.wrl = s->rlayer.wrl; s->d1->retransmitting = 1; /* restore state in which the message was originally sent */ - s->enc_write_ctx = frag->msg_header.saved_retransmit_state.enc_write_ctx; - s->write_hash = frag->msg_header.saved_retransmit_state.write_hash; - s->compress = frag->msg_header.saved_retransmit_state.compress; - s->session = frag->msg_header.saved_retransmit_state.session; - DTLS_RECORD_LAYER_set_saved_w_epoch(&s->rlayer, - frag->msg_header. - saved_retransmit_state.epoch); + s->rlayer.wrlmethod = frag->msg_header.saved_retransmit_state.wrlmethod; + s->rlayer.wrl = frag->msg_header.saved_retransmit_state.wrl; + + /* + * The old wrl may be still pointing at an old BIO. Update it to what we're + * using now. + */ + s->rlayer.wrlmethod->set1_bio(s->rlayer.wrl, s->wbio); ret = dtls1_do_write(s, frag->msg_header.is_ccs ? SSL3_RT_CHANGE_CIPHER_SPEC : SSL3_RT_HANDSHAKE); /* restore current state */ - s->enc_write_ctx = saved_state.enc_write_ctx; - s->write_hash = saved_state.write_hash; - s->compress = saved_state.compress; - s->session = saved_state.session; - DTLS_RECORD_LAYER_set_saved_w_epoch(&s->rlayer, saved_state.epoch); + s->rlayer.wrlmethod = saved_state.wrlmethod; + s->rlayer.wrl = saved_state.wrl; s->d1->retransmitting = 0; -- 2.34.1