Convert dtls_write_records to use standard record layer functions
[openssl.git] / ssl / record / methods / dtls_meth.c
index d5dae75c4ff4e56e73da8a1fc34b5c54a86749cd..e6c71ed1e7b9b5c05d8b30144951254161d45f7e 100644 (file)
@@ -93,9 +93,9 @@ static DTLS_BITMAP *dtls_get_bitmap(OSSL_RECORD_LAYER *rl, SSL3_RECORD *rr,
      * have already processed all of the unprocessed records from the last
      * epoch
      */
-    else if (rr->epoch == (unsigned long)(rl->epoch + 1) &&
-             rl->unprocessed_rcds.epoch != rl->epoch &&
-             (rr->type == SSL3_RT_HANDSHAKE || rr->type == SSL3_RT_ALERT)) {
+    else if (rr->epoch == (unsigned long)(rl->epoch + 1)
+             && rl->unprocessed_rcds.epoch != rl->epoch
+             && (rr->type == SSL3_RT_HANDSHAKE || rr->type == SSL3_RT_ALERT)) {
         *is_next_epoch = 1;
         return &rl->next_bitmap;
     }
@@ -122,7 +122,7 @@ static int dtls_process_record(OSSL_RECORD_LAYER *rl, DTLS_BITMAP *bitmap)
     rr = &rl->rrec[0];
 
     /*
-     * At this point, rl->packet_length == SSL3_RT_HEADER_LNGTH + rr->length,
+     * At this point, rl->packet_length == DTLS1_RT_HEADER_LENGTH + rr->length,
      * and we have that many bytes in rl->packet
      */
     rr->input = &(rl->packet[DTLS1_RT_HEADER_LENGTH]);
@@ -155,14 +155,14 @@ static int dtls_process_record(OSSL_RECORD_LAYER *rl, DTLS_BITMAP *bitmap)
         if (tmpmd != NULL) {
             imac_size = EVP_MD_get_size(tmpmd);
             if (!ossl_assert(imac_size >= 0 && imac_size <= EVP_MAX_MD_SIZE)) {
-                    RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_EVP_LIB);
-                    return 0;
+                RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_EVP_LIB);
+                return 0;
             }
             mac_size = (size_t)imac_size;
         }
     }
 
-    if (rl->use_etm && rl->md_ctx) {
+    if (rl->use_etm && rl->md_ctx != NULL) {
         unsigned char *mac;
 
         if (rr->orig_len < mac_size) {
@@ -221,7 +221,7 @@ static int dtls_process_record(OSSL_RECORD_LAYER *rl, DTLS_BITMAP *bitmap)
             && (EVP_MD_CTX_get0_md(rl->md_ctx) != NULL)) {
         /* rl->md_ctx != NULL => mac_size != -1 */
 
-        i = rl->funcs->mac(rl, rr, md, 0 /* not send */ );
+        i = rl->funcs->mac(rl, rr, md, 0 /* not send */);
         if (i == 0 || macbuf.mac == NULL
             || CRYPTO_memcmp(md, macbuf.mac, mac_size) != 0)
             enc_err = 0;
@@ -237,7 +237,7 @@ static int dtls_process_record(OSSL_RECORD_LAYER *rl, DTLS_BITMAP *bitmap)
     }
 
     /* r->length is now just compressed */
-    if (rl->expand != NULL) {
+    if (rl->compctx != NULL) {
         if (rr->length > SSL3_RT_MAX_COMPRESSED_LENGTH) {
             RLAYERfatal(rl, SSL_AD_RECORD_OVERFLOW,
                         SSL_R_COMPRESSED_LENGTH_TOO_LONG);
@@ -253,12 +253,11 @@ static int dtls_process_record(OSSL_RECORD_LAYER *rl, DTLS_BITMAP *bitmap)
      * Check if the received packet overflows the current Max Fragment
      * Length setting.
      */
-    if (rl->max_frag_len > 0 && rr->length > rl->max_frag_len) {
+    if (rr->length > rl->max_frag_len) {
         RLAYERfatal(rl, SSL_AD_RECORD_OVERFLOW, SSL_R_DATA_LENGTH_TOO_LONG);
         goto end;
     }
 
-
     rr->off = 0;
     /*-
      * So at this point the following is true
@@ -313,9 +312,8 @@ static int dtls_rlayer_buffer_record(OSSL_RECORD_LAYER *rl, record_pqueue *queue
     memset(&rl->rbuf, 0, sizeof(SSL3_BUFFER));
     memset(&rl->rrec[0], 0, sizeof(rl->rrec[0]));
 
-
-    if (!rlayer_setup_read_buffer(rl)) {
-        /* SSLfatal() already called */
+    if (!tls_setup_read_buffer(rl)) {
+        /* RLAYERfatal() already called */
         OPENSSL_free(rdata->rbuf.buf);
         OPENSSL_free(rdata);
         pitem_free(item);
@@ -397,7 +395,7 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
     rr = rl->rrec;
 
     if (rl->rbuf.buf == NULL) {
-        if (!rlayer_setup_read_buffer(rl)) {
+        if (!tls_setup_read_buffer(rl)) {
             /* RLAYERfatal() already called */
             return OSSL_RECORD_RETURN_FATAL;
         }
@@ -419,7 +417,7 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
                                  SSL3_BUFFER_get_len(&rl->rbuf), 0, 1, &n);
         /* read timeout is handled by dtls1_read_bytes */
         if (rret < OSSL_RECORD_RETURN_SUCCESS) {
-            /* SSLfatal() already called if appropriate */
+            /* RLAYERfatal() already called if appropriate */
             return rret;         /* error or non-blocking */
         }
 
@@ -433,8 +431,9 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
 
         p = rl->packet;
 
-        rl->msg_callback(0, 0, SSL3_RT_HEADER, p, DTLS1_RT_HEADER_LENGTH,
-                         rl->cbarg);
+        if (rl->msg_callback != NULL)
+            rl->msg_callback(0, 0, SSL3_RT_HEADER, p, DTLS1_RT_HEADER_LENGTH,
+                            rl->cbarg);
 
         /* Pull apart the header into the DTLS1_RECORD */
         rr->type = *(p++);
@@ -463,10 +462,9 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
             }
         }
 
-
         if (ssl_major !=
                 (rl->version == DTLS_ANY_VERSION ? DTLS1_VERSION_MAJOR
-                                                   : rl->version >> 8)) {
+                                                 : rl->version >> 8)) {
             /* wrong version, silently discard record */
             rr->length = 0;
             rl->packet_length = 0;
@@ -480,13 +478,11 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
             goto again;
         }
 
-
         /*
          * If received packet overflows maximum possible fragment length then
          * silently discard it
          */
-        if (rl->max_frag_len > 0
-                && rr->length > rl->max_frag_len + SSL3_RT_MAX_ENCRYPTED_OVERHEAD) {
+        if (rr->length > rl->max_frag_len + SSL3_RT_MAX_ENCRYPTED_OVERHEAD) {
             /* record too long, silently discard it */
             rr->length = 0;
             rl->packet_length = 0;
@@ -498,8 +494,7 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
 
     /* rl->rstate == SSL_ST_READ_BODY, get and decode the data */
 
-    if (rr->length >
-        rl->packet_length - DTLS1_RT_HEADER_LENGTH) {
+    if (rr->length > rl->packet_length - DTLS1_RT_HEADER_LENGTH) {
         /* now rl->packet_length == DTLS1_RT_HEADER_LENGTH */
         more = rr->length;
         rret = rl->funcs->read_n(rl, more, more, 1, 1, &n);
@@ -554,10 +549,9 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
      */
     if (is_next_epoch) {
         if (rl->in_init) {
-            if (dtls_rlayer_buffer_record(rl,
-                    &(rl->unprocessed_rcds),
-                    rr->seq_num) < 0) {
-                /* SSLfatal() already called */
+            if (dtls_rlayer_buffer_record(rl, &(rl->unprocessed_rcds),
+                                          rr->seq_num) < 0) {
+                /* RLAYERfatal() already called */
                 return OSSL_RECORD_RETURN_FATAL;
             }
         }
@@ -578,7 +572,6 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
 
     rl->num_recs = 1;
     return OSSL_RECORD_RETURN_SUCCESS;
-
 }
 
 static int dtls_free(OSSL_RECORD_LAYER *rl)
@@ -634,7 +627,7 @@ dtls_new_record_layer(OSSL_LIB_CTX *libctx, const char *propq, int vers,
                       size_t ivlen, unsigned char *mackey, size_t mackeylen,
                       const EVP_CIPHER *ciph, size_t taglen,
                       int mactype,
-                      const EVP_MD *md, const SSL_COMP *comp, BIO *prev,
+                      const EVP_MD *md, COMP_METHOD *comp, BIO *prev,
                       BIO *transport, BIO *next, BIO_ADDR *local, BIO_ADDR *peer,
                       const OSSL_PARAM *settings, const OSSL_PARAM *options,
                       const OSSL_DISPATCH *fns, void *cbarg,
@@ -642,7 +635,6 @@ dtls_new_record_layer(OSSL_LIB_CTX *libctx, const char *propq, int vers,
 {
     int ret;
 
-
     ret = tls_int_new_record_layer(libctx, propq, vers, role, direction, level,
                                    key, keylen, iv, ivlen, mackey, mackeylen,
                                    ciph, taglen, mactype, md, comp, prev,
@@ -654,10 +646,11 @@ dtls_new_record_layer(OSSL_LIB_CTX *libctx, const char *propq, int vers,
 
     (*retrl)->unprocessed_rcds.q = pqueue_new();
     (*retrl)->processed_rcds.q = pqueue_new();
-    if ((*retrl)->unprocessed_rcds.q == NULL || (*retrl)->processed_rcds.q == NULL) {
+    if ((*retrl)->unprocessed_rcds.q == NULL
+            || (*retrl)->processed_rcds.q == NULL) {
         dtls_free(*retrl);
         *retrl = NULL;
-        RLAYERfatal(*retrl, SSL_AD_INTERNAL_ERROR, ERR_R_MALLOC_FAILURE);
+        ERR_raise(ERR_LIB_SSL, ERR_R_SSL_LIB);
         return OSSL_RECORD_RETURN_FATAL;
     }
 
@@ -685,8 +678,8 @@ dtls_new_record_layer(OSSL_LIB_CTX *libctx, const char *propq, int vers,
     }
 
     ret = (*retrl)->funcs->set_crypto_state(*retrl, level, key, keylen, iv,
-                                             ivlen, mackey, mackeylen, ciph,
-                                             taglen, mactype, md, comp);
+                                            ivlen, mackey, mackeylen, ciph,
+                                            taglen, mactype, md, comp);
 
  err:
     if (ret != OSSL_RECORD_RETURN_SUCCESS) {
@@ -696,6 +689,140 @@ dtls_new_record_layer(OSSL_LIB_CTX *libctx, const char *propq, int vers,
     return ret;
 }
 
+int dtls_prepare_record_header(OSSL_RECORD_LAYER *rl,
+                               WPACKET *thispkt,
+                               OSSL_RECORD_TEMPLATE *templ,
+                               unsigned int rectype,
+                               unsigned char **recdata)
+{
+    size_t maxcomplen;
+
+    *recdata = NULL;
+
+    maxcomplen = templ->buflen;
+    if (rl->compctx != NULL)
+        maxcomplen += SSL3_RT_MAX_COMPRESSED_OVERHEAD;
+
+    if (!WPACKET_put_bytes_u8(thispkt, rectype)
+            || !WPACKET_put_bytes_u16(thispkt, templ->version)
+            || !WPACKET_put_bytes_u16(thispkt, rl->epoch)
+            || !WPACKET_memcpy(thispkt, &(rl->sequence[2]), 6)
+            || !WPACKET_start_sub_packet_u16(thispkt)
+            || (rl->eivlen > 0
+                && !WPACKET_allocate_bytes(thispkt, rl->eivlen, NULL))
+            || (maxcomplen > 0
+                && !WPACKET_reserve_bytes(thispkt, maxcomplen,
+                                          recdata))) {
+        RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
+
+    return 1;
+}
+
+int dtls_write_records(OSSL_RECORD_LAYER *rl, OSSL_RECORD_TEMPLATE *templates,
+                       size_t numtempl)
+{
+    int mac_size = 0;
+    SSL3_RECORD wr;
+    SSL3_BUFFER *wb;
+    WPACKET pkt, *thispkt = &pkt;
+    size_t wpinited = 0;
+    int ret = 0;
+    unsigned char *compressdata = NULL;
+
+    if (rl->md_ctx != NULL && EVP_MD_CTX_get0_md(rl->md_ctx) != NULL) {
+        mac_size = EVP_MD_CTX_get_size(rl->md_ctx);
+        if (mac_size < 0) {
+            RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+            return 0;
+        }
+    }
+
+    if (numtempl != 1) {
+        /* Should not happen */
+        RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
+
+    if (!rl->funcs->allocate_write_buffers(rl, templates, numtempl, NULL)) {
+        /* RLAYERfatal() already called */
+        return 0;
+    }
+
+    if (!rl->funcs->initialise_write_packets(rl, templates, numtempl,
+                                             NULL, thispkt, rl->wbuf,
+                                             &wpinited)) {
+        /* RLAYERfatal() already called */
+        return 0;
+    }
+
+    wb = rl->wbuf;
+
+    SSL3_RECORD_set_type(&wr, templates->type);
+    SSL3_RECORD_set_rec_version(&wr, templates->version);
+
+    if (!rl->funcs->prepare_record_header(rl, thispkt, templates,
+                                          templates->type, &compressdata)) {
+        /* RLAYERfatal() already called */
+        goto err;
+    }
+
+    /* lets setup the record stuff. */
+    SSL3_RECORD_set_data(&wr, compressdata);
+    SSL3_RECORD_set_length(&wr, templates->buflen);
+    SSL3_RECORD_set_input(&wr, (unsigned char *)templates->buf);
+
+    /*
+     * we now 'read' from wr.input, wr.length bytes into wr.data
+     */
+
+    /* first we compress */
+    if (rl->compctx != NULL) {
+        if (!tls_do_compress(rl, &wr)
+                || !WPACKET_allocate_bytes(thispkt, wr.length, NULL)) {
+            RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, SSL_R_COMPRESSION_FAILURE);
+            goto err;
+        }
+    } else if (compressdata != NULL) {
+        if (!WPACKET_memcpy(thispkt, wr.input, wr.length)) {
+            RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+            goto err;
+        }
+        SSL3_RECORD_reset_input(&wr);
+    }
+
+    if (!rl->funcs->prepare_for_encryption(rl, mac_size, thispkt, &wr)) {
+        /* RLAYERfatal() already called */
+        goto err;
+    }
+
+    if (rl->funcs->cipher(rl, &wr, 1, 1, NULL, mac_size) < 1) {
+        if (rl->alert == SSL_AD_NO_ALERT) {
+            RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        }
+        goto err;
+    }
+
+    if (!rl->funcs->post_encryption_processing(rl, mac_size, templates,
+                                               thispkt, &wr)) {
+        /* RLAYERfatal() already called */
+        goto err;
+    }
+
+    /* TODO(RECLAYER): FIXME */
+    ssl3_record_sequence_update(rl->sequence);
+
+    /* now let's set up wb */
+    SSL3_BUFFER_set_left(wb, SSL3_RECORD_get_length(&wr));
+
+    ret = 1;
+ err:
+    if (wpinited > 0)
+        WPACKET_cleanup(thispkt);
+    return ret;
+}
+
 const OSSL_RECORD_METHOD ossl_dtls_record_method = {
     dtls_new_record_layer,
     dtls_free,
@@ -718,5 +845,7 @@ const OSSL_RECORD_METHOD ossl_dtls_record_method = {
     tls_set_max_pipelines,
     dtls_set_in_init,
     tls_get_state,
-    tls_set_options
+    tls_set_options,
+    tls_get_compression,
+    tls_set_max_frag_len
 };