Fix typo in CONTRIBUTING.md
[openssl.git] / ssl / statem / statem_dtls.c
index 6068c833c657495a12269ab48c56ce8c07578f99..b37ac80a6065fa3fad0e26a5bfbccf188c580a27 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2005-2022 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2005-2024 The OpenSSL Project Authors. All Rights Reserved.
  *
  * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
@@ -7,6 +7,7 @@
  * https://www.openssl.org/source/license.html
  */
 
+#include <assert.h>
 #include <limits.h>
 #include <string.h>
 #include <stdio.h>
@@ -38,9 +39,9 @@
                         if (is_complete) for (ii = (((msg_len) - 1) >> 3) - 1; ii >= 0 ; ii--) \
                                 if (bitmask[ii] != 0xff) { is_complete = 0; break; } }
 
-static unsigned char bitmask_start_values[] =
+static const unsigned char bitmask_start_values[] =
     { 0xff, 0xfe, 0xfc, 0xf8, 0xf0, 0xe0, 0xc0, 0x80 };
-static unsigned char bitmask_end_values[] =
+static const unsigned char bitmask_end_values[] =
     { 0xff, 0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f };
 
 static void dtls1_fix_message_header(SSL_CONNECTION *s, size_t frag_off,
@@ -61,14 +62,11 @@ static hm_fragment *dtls1_hm_fragment_new(size_t frag_len, int reassembly)
     unsigned char *buf = NULL;
     unsigned char *bitmask = NULL;
 
-    if ((frag = OPENSSL_malloc(sizeof(*frag))) == NULL) {
-        ERR_raise(ERR_LIB_SSL, ERR_R_MALLOC_FAILURE);
+    if ((frag = OPENSSL_zalloc(sizeof(*frag))) == NULL)
         return NULL;
-    }
 
     if (frag_len) {
         if ((buf = OPENSSL_malloc(frag_len)) == NULL) {
-            ERR_raise(ERR_LIB_SSL, ERR_R_MALLOC_FAILURE);
             OPENSSL_free(frag);
             return NULL;
         }
@@ -81,7 +79,6 @@ static hm_fragment *dtls1_hm_fragment_new(size_t frag_len, int reassembly)
     if (reassembly) {
         bitmask = OPENSSL_zalloc(RSMBLY_BITMASK_SIZE(frag_len));
         if (bitmask == NULL) {
-            ERR_raise(ERR_LIB_SSL, ERR_R_MALLOC_FAILURE);
             OPENSSL_free(buf);
             OPENSSL_free(frag);
             return NULL;
@@ -97,11 +94,7 @@ void dtls1_hm_fragment_free(hm_fragment *frag)
 {
     if (!frag)
         return;
-    if (frag->msg_header.is_ccs) {
-        EVP_CIPHER_CTX_free(frag->msg_header.
-                            saved_retransmit_state.enc_write_ctx);
-        EVP_MD_CTX_free(frag->msg_header.saved_retransmit_state.write_hash);
-    }
+
     OPENSSL_free(frag->fragment);
     OPENSSL_free(frag->reassembly);
     OPENSSL_free(frag);
@@ -111,13 +104,13 @@ void dtls1_hm_fragment_free(hm_fragment *frag)
  * send s->init_buf in records of type 'type' (SSL3_RT_HANDSHAKE or
  * SSL3_RT_CHANGE_CIPHER_SPEC)
  */
-int dtls1_do_write(SSL_CONNECTION *s, int type)
+int dtls1_do_write(SSL_CONNECTION *s, uint8_t type)
 {
     int ret;
     size_t written;
     size_t curr_mtu;
     int retry = 1;
-    size_t len, frag_off, mac_size, blocksize, used_len;
+    size_t len, frag_off, overhead, used_len;
     SSL *ssl = SSL_CONNECTION_GET_SSL(s);
 
     if (!dtls1_query_mtu(s))
@@ -133,21 +126,7 @@ int dtls1_do_write(SSL_CONNECTION *s, int type)
             return -1;
     }
 
-    if (s->write_hash) {
-        if (s->enc_write_ctx
-            && (EVP_CIPHER_get_flags(EVP_CIPHER_CTX_get0_cipher(s->enc_write_ctx)) &
-                EVP_CIPH_FLAG_AEAD_CIPHER) != 0)
-            mac_size = 0;
-        else
-            mac_size = EVP_MD_CTX_get_size(s->write_hash);
-    } else
-        mac_size = 0;
-
-    if (s->enc_write_ctx &&
-        (EVP_CIPHER_CTX_get_mode(s->enc_write_ctx) == EVP_CIPH_CBC_MODE))
-        blocksize = 2 * EVP_CIPHER_CTX_get_block_size(s->enc_write_ctx);
-    else
-        blocksize = 0;
+    overhead = s->rlayer.wrlmethod->get_max_record_overhead(s->rlayer.wrl);
 
     frag_off = 0;
     s->rwstate = SSL_NOTHING;
@@ -188,8 +167,7 @@ int dtls1_do_write(SSL_CONNECTION *s, int type)
             }
         }
 
-        used_len = BIO_wpending(s->wbio) + DTLS1_RT_HEADER_LENGTH
-            + mac_size + blocksize;
+        used_len = BIO_wpending(s->wbio) + overhead;
         if (s->d1->mtu > used_len)
             curr_mtu = s->d1->mtu - used_len;
         else
@@ -204,9 +182,8 @@ int dtls1_do_write(SSL_CONNECTION *s, int type)
                 s->rwstate = SSL_WRITING;
                 return ret;
             }
-            used_len = DTLS1_RT_HEADER_LENGTH + mac_size + blocksize;
-            if (s->d1->mtu > used_len + DTLS1_HM_HEADER_LENGTH) {
-                curr_mtu = s->d1->mtu - used_len;
+            if (s->d1->mtu > overhead + DTLS1_HM_HEADER_LENGTH) {
+                curr_mtu = s->d1->mtu - overhead;
             } else {
                 /* Shouldn't happen */
                 return -1;
@@ -272,6 +249,16 @@ int dtls1_do_write(SSL_CONNECTION *s, int type)
             if (!ossl_assert(len == written))
                 return -1;
 
+            /*
+             * We should not exceed the MTU size. If compression is in use
+             * then the max record overhead calculation is unreliable so we do
+             * not check in that case. We use assert rather than ossl_assert
+             * because in a production build, if this assert were ever to fail,
+             * then the best thing to do is probably carry on regardless.
+             */
+            assert(s->s3.tmp.new_compression != NULL
+                   || BIO_wpending(s->wbio) <= (int)s->d1->mtu);
+
             if (type == SSL3_RT_HANDSHAKE && !s->d1->retransmitting) {
                 /*
                  * should not be done for 'Hello Request's, but in that case
@@ -496,23 +483,64 @@ static int dtls1_retrieve_buffered_fragment(SSL_CONNECTION *s, size_t *len)
      * (2) update s->init_num
      */
     pitem *item;
+    piterator iter;
     hm_fragment *frag;
     int ret;
+    int chretran = 0;
 
+    iter = pqueue_iterator(s->d1->buffered_messages);
     do {
-        item = pqueue_peek(s->d1->buffered_messages);
+        item = pqueue_next(&iter);
         if (item == NULL)
             return 0;
 
         frag = (hm_fragment *)item->data;
 
         if (frag->msg_header.seq < s->d1->handshake_read_seq) {
-            /* This is a stale message that has been buffered so clear it */
-            pqueue_pop(s->d1->buffered_messages);
-            dtls1_hm_fragment_free(frag);
-            pitem_free(item);
-            item = NULL;
-            frag = NULL;
+            pitem *next;
+            hm_fragment *nextfrag;
+
+            if (!s->server
+                    || frag->msg_header.seq != 0
+                    || s->d1->handshake_read_seq != 1
+                    || s->statem.hand_state != DTLS_ST_SW_HELLO_VERIFY_REQUEST) {
+                /*
+                 * This is a stale message that has been buffered so clear it.
+                 * It is safe to pop this message from the queue even though
+                 * we have an active iterator
+                 */
+                pqueue_pop(s->d1->buffered_messages);
+                dtls1_hm_fragment_free(frag);
+                pitem_free(item);
+                item = NULL;
+                frag = NULL;
+            } else {
+                /*
+                 * We have fragments for a ClientHello without a cookie,
+                 * even though we have sent a HelloVerifyRequest. It is possible
+                 * that the HelloVerifyRequest got lost and this is a
+                 * retransmission of the original ClientHello
+                 */
+                next = pqueue_next(&iter);
+                if (next != NULL) {
+                    nextfrag = (hm_fragment *)next->data;
+                    if (nextfrag->msg_header.seq == s->d1->handshake_read_seq) {
+                        /*
+                        * We have fragments for both a ClientHello without
+                        * cookie and one with. Ditch the one without.
+                        */
+                        pqueue_pop(s->d1->buffered_messages);
+                        dtls1_hm_fragment_free(frag);
+                        pitem_free(item);
+                        item = next;
+                        frag = nextfrag;
+                    } else {
+                        chretran = 1;
+                    }
+                } else {
+                    chretran = 1;
+                }
+            }
         }
     } while (item == NULL);
 
@@ -520,7 +548,7 @@ static int dtls1_retrieve_buffered_fragment(SSL_CONNECTION *s, size_t *len)
     if (frag->reassembly != NULL)
         return 0;
 
-    if (s->d1->handshake_read_seq == frag->msg_header.seq) {
+    if (s->d1->handshake_read_seq == frag->msg_header.seq || chretran) {
         size_t frag_len = frag->msg_header.frag_len;
         pqueue_pop(s->d1->buffered_messages);
 
@@ -538,6 +566,16 @@ static int dtls1_retrieve_buffered_fragment(SSL_CONNECTION *s, size_t *len)
         pitem_free(item);
 
         if (ret) {
+            if (chretran) {
+                /*
+                 * We got a new ClientHello with a message sequence of 0.
+                 * Reset the read/write sequences back to the beginning.
+                 * We process it like this is the first time we've seen a
+                 * ClientHello from the client.
+                 */
+                s->d1->handshake_read_seq = 0;
+                s->d1->next_handshake_write_seq = 0;
+            }
             *len = frag_len;
             return 1;
         }
@@ -762,15 +800,19 @@ static int dtls1_process_out_of_seq_message(SSL_CONNECTION *s,
 static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype,
                                         size_t *len)
 {
-    unsigned char wire[DTLS1_HM_HEADER_LENGTH];
     size_t mlen, frag_off, frag_len;
-    int i, ret, recvd_type;
+    int i, ret;
+    uint8_t recvd_type;
     struct hm_header_st msg_hdr;
     size_t readbytes;
     SSL *ssl = SSL_CONNECTION_GET_SSL(s);
+    int chretran = 0;
+    unsigned char *p;
 
     *errtype = 0;
 
+    p = (unsigned char *)s->init_buf->data;
+
  redo:
     /* see if we have the required fragment already */
     ret = dtls1_retrieve_buffered_fragment(s, &frag_len);
@@ -785,7 +827,7 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype,
     }
 
     /* read handshake message header */
-    i = ssl->method->ssl_read_bytes(ssl, SSL3_RT_HANDSHAKE, &recvd_type, wire,
+    i = ssl->method->ssl_read_bytes(ssl, SSL3_RT_HANDSHAKE, &recvd_type, p,
                                     DTLS1_HM_HEADER_LENGTH, 0, &readbytes);
     if (i <= 0) {               /* nbio, or an error */
         s->rwstate = SSL_READING;
@@ -793,13 +835,12 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype,
         return 0;
     }
     if (recvd_type == SSL3_RT_CHANGE_CIPHER_SPEC) {
-        if (wire[0] != SSL3_MT_CCS) {
+        if (p[0] != SSL3_MT_CCS) {
             SSLfatal(s, SSL_AD_UNEXPECTED_MESSAGE,
                      SSL_R_BAD_CHANGE_CIPHER_SPEC);
             goto f_err;
         }
 
-        memcpy(s->init_buf->data, wire, readbytes);
         s->init_num = readbytes - 1;
         s->init_msg = s->init_buf->data + 1;
         s->s3.tmp.message_type = SSL3_MT_CHANGE_CIPHER_SPEC;
@@ -815,7 +856,7 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype,
     }
 
     /* parse the message fragment header */
-    dtls1_get_message_header(wire, &msg_hdr);
+    dtls1_get_message_header(p, &msg_hdr);
 
     mlen = msg_hdr.msg_len;
     frag_off = msg_hdr.frag_off;
@@ -825,7 +866,7 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype,
      * We must have at least frag_len bytes left in the record to be read.
      * Fragments must not span records.
      */
-    if (frag_len > RECORD_LAYER_get_rrec_length(&s->rlayer)) {
+    if (frag_len > s->rlayer.tlsrecs[s->rlayer.curr_rec].length) {
         SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_R_BAD_LENGTH);
         goto f_err;
     }
@@ -837,8 +878,20 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype,
      * although we're still expecting seq 0 (ClientHello)
      */
     if (msg_hdr.seq != s->d1->handshake_read_seq) {
-        *errtype = dtls1_process_out_of_seq_message(s, &msg_hdr);
-        return 0;
+        if (!s->server
+                || msg_hdr.seq != 0
+                || s->d1->handshake_read_seq != 1
+                || p[0] != SSL3_MT_CLIENT_HELLO
+                || s->statem.hand_state != DTLS_ST_SW_HELLO_VERIFY_REQUEST) {
+            *errtype = dtls1_process_out_of_seq_message(s, &msg_hdr);
+            return 0;
+        }
+        /*
+         * We received a ClientHello and sent back a HelloVerifyRequest. We
+         * now seem to have received a retransmitted initial ClientHello. That
+         * is allowed (possibly our HelloVerifyRequest got lost).
+         */
+        chretran = 1;
     }
 
     if (frag_len && frag_len < mlen) {
@@ -849,16 +902,16 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype,
     if (!s->server
             && s->d1->r_msg_hdr.frag_off == 0
             && s->statem.hand_state != TLS_ST_OK
-            && wire[0] == SSL3_MT_HELLO_REQUEST) {
+            && p[0] == SSL3_MT_HELLO_REQUEST) {
         /*
          * The server may always send 'Hello Request' messages -- we are
          * doing a handshake anyway now, so ignore them if their format is
          * correct. Does not count for 'Finished' MAC.
          */
-        if (wire[1] == 0 && wire[2] == 0 && wire[3] == 0) {
+        if (p[1] == 0 && p[2] == 0 && p[3] == 0) {
             if (s->msg_callback)
                 s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE,
-                                wire, DTLS1_HM_HEADER_LENGTH, ssl,
+                                p, DTLS1_HM_HEADER_LENGTH, ssl,
                                 s->msg_callback_arg);
 
             s->init_num = 0;
@@ -876,8 +929,7 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype,
     }
 
     if (frag_len > 0) {
-        unsigned char *p =
-            (unsigned char *)s->init_buf->data + DTLS1_HM_HEADER_LENGTH;
+        p += DTLS1_HM_HEADER_LENGTH;
 
         i = ssl->method->ssl_read_bytes(ssl, SSL3_RT_HANDSHAKE, NULL,
                                         &p[frag_off], frag_len, 0, &readbytes);
@@ -904,6 +956,17 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype,
         goto f_err;
     }
 
+    if (chretran) {
+        /*
+         * We got a new ClientHello with a message sequence of 0.
+         * Reset the read/write sequences back to the beginning.
+         * We process it like this is the first time we've seen a ClientHello
+         * from the client.
+         */
+        s->d1->handshake_read_seq = 0;
+        s->d1->next_handshake_write_seq = 0;
+    }
+
     /*
      * Note that s->init_num is *not* used as current offset in
      * s->init_buf->data, but as a counter summing up fragments' lengths: as
@@ -921,25 +984,23 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype,
 
 /*-
  * for these 2 messages, we need to
- * ssl->enc_read_ctx                    re-init
- * ssl->rlayer.read_sequence            zero
- * ssl->s3.read_mac_secret             re-init
  * ssl->session->read_sym_enc           assign
  * ssl->session->read_compression       assign
  * ssl->session->read_hash              assign
  */
-int dtls_construct_change_cipher_spec(SSL_CONNECTION *s, WPACKET *pkt)
+CON_FUNC_RETURN dtls_construct_change_cipher_spec(SSL_CONNECTION *s,
+                                                  WPACKET *pkt)
 {
     if (s->version == DTLS1_BAD_VER) {
         s->d1->next_handshake_write_seq++;
 
         if (!WPACKET_put_bytes_u16(pkt, s->d1->handshake_write_seq)) {
             SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
-            return 0;
+            return CON_FUNC_ERROR;
         }
     }
 
-    return 1;
+    return CON_FUNC_SUCCESS;
 }
 
 #ifndef OPENSSL_NO_SCTP
@@ -1090,12 +1151,9 @@ int dtls1_buffer_message(SSL_CONNECTION *s, int is_ccs)
     frag->msg_header.is_ccs = is_ccs;
 
     /* save current state */
-    frag->msg_header.saved_retransmit_state.enc_write_ctx = s->enc_write_ctx;
-    frag->msg_header.saved_retransmit_state.write_hash = s->write_hash;
-    frag->msg_header.saved_retransmit_state.compress = s->compress;
-    frag->msg_header.saved_retransmit_state.session = s->session;
-    frag->msg_header.saved_retransmit_state.epoch =
-        DTLS_RECORD_LAYER_get_w_epoch(&s->rlayer);
+    frag->msg_header.saved_retransmit_state.wrlmethod = s->rlayer.wrlmethod;
+    frag->msg_header.saved_retransmit_state.wrl = s->rlayer.wrl;
+
 
     memset(seq64be, 0, sizeof(seq64be));
     seq64be[6] =
@@ -1157,32 +1215,27 @@ int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq, int *found)
                                  frag->msg_header.frag_len);
 
     /* save current state */
-    saved_state.enc_write_ctx = s->enc_write_ctx;
-    saved_state.write_hash = s->write_hash;
-    saved_state.compress = s->compress;
-    saved_state.session = s->session;
-    saved_state.epoch = DTLS_RECORD_LAYER_get_w_epoch(&s->rlayer);
+    saved_state.wrlmethod = s->rlayer.wrlmethod;
+    saved_state.wrl = s->rlayer.wrl;
 
     s->d1->retransmitting = 1;
 
     /* restore state in which the message was originally sent */
-    s->enc_write_ctx = frag->msg_header.saved_retransmit_state.enc_write_ctx;
-    s->write_hash = frag->msg_header.saved_retransmit_state.write_hash;
-    s->compress = frag->msg_header.saved_retransmit_state.compress;
-    s->session = frag->msg_header.saved_retransmit_state.session;
-    DTLS_RECORD_LAYER_set_saved_w_epoch(&s->rlayer,
-                                        frag->msg_header.
-                                        saved_retransmit_state.epoch);
+    s->rlayer.wrlmethod = frag->msg_header.saved_retransmit_state.wrlmethod;
+    s->rlayer.wrl = frag->msg_header.saved_retransmit_state.wrl;
+
+    /*
+     * The old wrl may be still pointing at an old BIO. Update it to what we're
+     * using now.
+     */
+    s->rlayer.wrlmethod->set1_bio(s->rlayer.wrl, s->wbio);
 
     ret = dtls1_do_write(s, frag->msg_header.is_ccs ?
                          SSL3_RT_CHANGE_CIPHER_SPEC : SSL3_RT_HANDSHAKE);
 
     /* restore current state */
-    s->enc_write_ctx = saved_state.enc_write_ctx;
-    s->write_hash = saved_state.write_hash;
-    s->compress = saved_state.compress;
-    s->session = saved_state.session;
-    DTLS_RECORD_LAYER_set_saved_w_epoch(&s->rlayer, saved_state.epoch);
+    s->rlayer.wrlmethod = saved_state.wrlmethod;
+    s->rlayer.wrl = saved_state.wrl;
 
     s->d1->retransmitting = 0;
 
@@ -1242,7 +1295,8 @@ static unsigned char *dtls1_write_message_header(SSL_CONNECTION *s,
     return p;
 }
 
-void dtls1_get_message_header(unsigned char *data, struct hm_header_st *msg_hdr)
+void dtls1_get_message_header(const unsigned char *data, struct
+                              hm_header_st *msg_hdr)
 {
     memset(msg_hdr, 0, sizeof(*msg_hdr));
     msg_hdr->type = *(data++);