Convert dtls_write_records to use standard record layer functions
[openssl.git] / ssl / record / methods / dtls_meth.c
index bf8244ce31d949a7336eb94652f395ea264fc36c..e6c71ed1e7b9b5c05d8b30144951254161d45f7e 100644 (file)
@@ -253,7 +253,7 @@ static int dtls_process_record(OSSL_RECORD_LAYER *rl, DTLS_BITMAP *bitmap)
      * Check if the received packet overflows the current Max Fragment
      * Length setting.
      */
-    if (rl->max_frag_len > 0 && rr->length > rl->max_frag_len) {
+    if (rr->length > rl->max_frag_len) {
         RLAYERfatal(rl, SSL_AD_RECORD_OVERFLOW, SSL_R_DATA_LENGTH_TOO_LONG);
         goto end;
     }
@@ -482,8 +482,7 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
          * If received packet overflows maximum possible fragment length then
          * silently discard it
          */
-        if (rl->max_frag_len > 0
-                && rr->length > rl->max_frag_len + SSL3_RT_MAX_ENCRYPTED_OVERHEAD) {
+        if (rr->length > rl->max_frag_len + SSL3_RT_MAX_ENCRYPTED_OVERHEAD) {
             /* record too long, silently discard it */
             rr->length = 0;
             rl->packet_length = 0;
@@ -628,7 +627,7 @@ dtls_new_record_layer(OSSL_LIB_CTX *libctx, const char *propq, int vers,
                       size_t ivlen, unsigned char *mackey, size_t mackeylen,
                       const EVP_CIPHER *ciph, size_t taglen,
                       int mactype,
-                      const EVP_MD *md, const SSL_COMP *comp, BIO *prev,
+                      const EVP_MD *md, COMP_METHOD *comp, BIO *prev,
                       BIO *transport, BIO *next, BIO_ADDR *local, BIO_ADDR *peer,
                       const OSSL_PARAM *settings, const OSSL_PARAM *options,
                       const OSSL_DISPATCH *fns, void *cbarg,
@@ -690,6 +689,140 @@ dtls_new_record_layer(OSSL_LIB_CTX *libctx, const char *propq, int vers,
     return ret;
 }
 
+int dtls_prepare_record_header(OSSL_RECORD_LAYER *rl,
+                               WPACKET *thispkt,
+                               OSSL_RECORD_TEMPLATE *templ,
+                               unsigned int rectype,
+                               unsigned char **recdata)
+{
+    size_t maxcomplen;
+
+    *recdata = NULL;
+
+    maxcomplen = templ->buflen;
+    if (rl->compctx != NULL)
+        maxcomplen += SSL3_RT_MAX_COMPRESSED_OVERHEAD;
+
+    if (!WPACKET_put_bytes_u8(thispkt, rectype)
+            || !WPACKET_put_bytes_u16(thispkt, templ->version)
+            || !WPACKET_put_bytes_u16(thispkt, rl->epoch)
+            || !WPACKET_memcpy(thispkt, &(rl->sequence[2]), 6)
+            || !WPACKET_start_sub_packet_u16(thispkt)
+            || (rl->eivlen > 0
+                && !WPACKET_allocate_bytes(thispkt, rl->eivlen, NULL))
+            || (maxcomplen > 0
+                && !WPACKET_reserve_bytes(thispkt, maxcomplen,
+                                          recdata))) {
+        RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
+
+    return 1;
+}
+
+int dtls_write_records(OSSL_RECORD_LAYER *rl, OSSL_RECORD_TEMPLATE *templates,
+                       size_t numtempl)
+{
+    int mac_size = 0;
+    SSL3_RECORD wr;
+    SSL3_BUFFER *wb;
+    WPACKET pkt, *thispkt = &pkt;
+    size_t wpinited = 0;
+    int ret = 0;
+    unsigned char *compressdata = NULL;
+
+    if (rl->md_ctx != NULL && EVP_MD_CTX_get0_md(rl->md_ctx) != NULL) {
+        mac_size = EVP_MD_CTX_get_size(rl->md_ctx);
+        if (mac_size < 0) {
+            RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+            return 0;
+        }
+    }
+
+    if (numtempl != 1) {
+        /* Should not happen */
+        RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
+
+    if (!rl->funcs->allocate_write_buffers(rl, templates, numtempl, NULL)) {
+        /* RLAYERfatal() already called */
+        return 0;
+    }
+
+    if (!rl->funcs->initialise_write_packets(rl, templates, numtempl,
+                                             NULL, thispkt, rl->wbuf,
+                                             &wpinited)) {
+        /* RLAYERfatal() already called */
+        return 0;
+    }
+
+    wb = rl->wbuf;
+
+    SSL3_RECORD_set_type(&wr, templates->type);
+    SSL3_RECORD_set_rec_version(&wr, templates->version);
+
+    if (!rl->funcs->prepare_record_header(rl, thispkt, templates,
+                                          templates->type, &compressdata)) {
+        /* RLAYERfatal() already called */
+        goto err;
+    }
+
+    /* lets setup the record stuff. */
+    SSL3_RECORD_set_data(&wr, compressdata);
+    SSL3_RECORD_set_length(&wr, templates->buflen);
+    SSL3_RECORD_set_input(&wr, (unsigned char *)templates->buf);
+
+    /*
+     * we now 'read' from wr.input, wr.length bytes into wr.data
+     */
+
+    /* first we compress */
+    if (rl->compctx != NULL) {
+        if (!tls_do_compress(rl, &wr)
+                || !WPACKET_allocate_bytes(thispkt, wr.length, NULL)) {
+            RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, SSL_R_COMPRESSION_FAILURE);
+            goto err;
+        }
+    } else if (compressdata != NULL) {
+        if (!WPACKET_memcpy(thispkt, wr.input, wr.length)) {
+            RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+            goto err;
+        }
+        SSL3_RECORD_reset_input(&wr);
+    }
+
+    if (!rl->funcs->prepare_for_encryption(rl, mac_size, thispkt, &wr)) {
+        /* RLAYERfatal() already called */
+        goto err;
+    }
+
+    if (rl->funcs->cipher(rl, &wr, 1, 1, NULL, mac_size) < 1) {
+        if (rl->alert == SSL_AD_NO_ALERT) {
+            RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        }
+        goto err;
+    }
+
+    if (!rl->funcs->post_encryption_processing(rl, mac_size, templates,
+                                               thispkt, &wr)) {
+        /* RLAYERfatal() already called */
+        goto err;
+    }
+
+    /* TODO(RECLAYER): FIXME */
+    ssl3_record_sequence_update(rl->sequence);
+
+    /* now let's set up wb */
+    SSL3_BUFFER_set_left(wb, SSL3_RECORD_get_length(&wr));
+
+    ret = 1;
+ err:
+    if (wpinited > 0)
+        WPACKET_cleanup(thispkt);
+    return ret;
+}
+
 const OSSL_RECORD_METHOD ossl_dtls_record_method = {
     dtls_new_record_layer,
     dtls_free,
@@ -712,5 +845,7 @@ const OSSL_RECORD_METHOD ossl_dtls_record_method = {
     tls_set_max_pipelines,
     dtls_set_in_init,
     tls_get_state,
-    tls_set_options
+    tls_set_options,
+    tls_get_compression,
+    tls_set_max_frag_len
 };