Don't calculate the Finished MAC twice
[openssl.git] / ssl / statem / statem_dtls.c
index b2ba35763a09b7b2d360f4ca4ad97fa5c1f0216b..0ac60cbf995cc8850b3987f5e56a3b169672527d 100644 (file)
@@ -12,6 +12,7 @@
 #include <stdio.h>
 #include "../ssl_locl.h"
 #include "statem_locl.h"
+#include "internal/cryptlib.h"
 #include <openssl/buffer.h>
 #include <openssl/objects.h>
 #include <openssl/evp.h>
@@ -32,7 +33,6 @@
 
 #define RSMBLY_BITMASK_IS_COMPLETE(bitmask, msg_len, is_complete) { \
                         long ii; \
-                        OPENSSL_assert((msg_len) > 0); \
                         is_complete = 1; \
                         if (bitmask[(((msg_len) - 1) >> 3)] != bitmask_end_values[((msg_len) & 7)]) is_complete = 0; \
                         if (is_complete) for (ii = (((msg_len) - 1) >> 3) - 1; ii >= 0 ; ii--) \
@@ -122,9 +122,11 @@ int dtls1_do_write(SSL *s, int type)
         /* should have something reasonable now */
         return -1;
 
-    if (s->init_off == 0 && type == SSL3_RT_HANDSHAKE)
-        OPENSSL_assert(s->init_num ==
-                       s->d1->w_msg_hdr.msg_len + DTLS1_HM_HEADER_LENGTH);
+    if (s->init_off == 0 && type == SSL3_RT_HANDSHAKE) {
+        if (!ossl_assert(s->init_num ==
+                         s->d1->w_msg_hdr.msg_len + DTLS1_HM_HEADER_LENGTH))
+            return -1;
+    }
 
     if (s->write_hash) {
         if (s->enc_write_ctx
@@ -254,7 +256,7 @@ int dtls1_do_write(SSL *s, int type)
                 } else
                     return -1;
             } else {
-                return (-1);
+                return -1;
             }
         } else {
 
@@ -262,7 +264,8 @@ int dtls1_do_write(SSL *s, int type)
              * bad if this assert fails, only part of the handshake message
              * got sent.  but why would this happen?
              */
-            OPENSSL_assert(len == written);
+            if (!ossl_assert(len == written))
+                return -1;
 
             if (type == SSL3_RT_HANDSHAKE && !s->d1->retransmitting) {
                 /*
@@ -373,6 +376,15 @@ int dtls_get_message(SSL *s, int *mt, size_t *len)
         msg_len += DTLS1_HM_HEADER_LENGTH;
     }
 
+    /*
+     * If receiving Finished, record MAC of prior handshake messages for
+     * Finished verification.
+     */
+    if (*mt == SSL3_MT_FINISHED && !ssl3_take_mac(s)) {
+        /* SSLfatal() already called */
+        return 0;
+    }
+
     if (!ssl3_finish_mac(s, p, msg_len))
         return 0;
     if (s->msg_callback)
@@ -412,8 +424,9 @@ static int dtls1_preprocess_fragment(SSL *s, struct hm_header_st *msg_hdr)
     /* sanity checking */
     if ((frag_off + frag_len) > msg_len
             || msg_len > dtls1_max_handshake_message_len(s)) {
-        SSLerr(SSL_F_DTLS1_PREPROCESS_FRAGMENT, SSL_R_EXCESSIVE_MESSAGE_SIZE);
-        return SSL_AD_ILLEGAL_PARAMETER;
+        SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_F_DTLS1_PREPROCESS_FRAGMENT,
+                 SSL_R_EXCESSIVE_MESSAGE_SIZE);
+        return 0;
     }
 
     if (s->d1->r_msg_hdr.frag_off == 0) { /* first fragment */
@@ -422,8 +435,9 @@ static int dtls1_preprocess_fragment(SSL *s, struct hm_header_st *msg_hdr)
          * dtls_max_handshake_message_len(s) above
          */
         if (!BUF_MEM_grow_clean(s->init_buf, msg_len + DTLS1_HM_HEADER_LENGTH)) {
-            SSLerr(SSL_F_DTLS1_PREPROCESS_FRAGMENT, ERR_R_BUF_LIB);
-            return SSL_AD_INTERNAL_ERROR;
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_F_DTLS1_PREPROCESS_FRAGMENT,
+                     ERR_R_BUF_LIB);
+            return 0;
         }
 
         s->s3->tmp.message_size = msg_len;
@@ -436,13 +450,18 @@ static int dtls1_preprocess_fragment(SSL *s, struct hm_header_st *msg_hdr)
          * They must be playing with us! BTW, failure to enforce upper limit
          * would open possibility for buffer overrun.
          */
-        SSLerr(SSL_F_DTLS1_PREPROCESS_FRAGMENT, SSL_R_EXCESSIVE_MESSAGE_SIZE);
-        return SSL_AD_ILLEGAL_PARAMETER;
+        SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_F_DTLS1_PREPROCESS_FRAGMENT,
+                 SSL_R_EXCESSIVE_MESSAGE_SIZE);
+        return 0;
     }
 
-    return 0;                   /* no error */
+    return 1;
 }
 
+/*
+ * Returns 1 if there is a buffered fragment available, 0 if not, or -1 on a
+ * fatal error.
+ */
 static int dtls1_retrieve_buffered_fragment(SSL *s, size_t *len)
 {
     /*-
@@ -453,7 +472,7 @@ static int dtls1_retrieve_buffered_fragment(SSL *s, size_t *len)
      */
     pitem *item;
     hm_fragment *frag;
-    int al;
+    int ret;
 
     do {
         item = pqueue_peek(s->d1->buffered_messages);
@@ -480,9 +499,10 @@ static int dtls1_retrieve_buffered_fragment(SSL *s, size_t *len)
         size_t frag_len = frag->msg_header.frag_len;
         pqueue_pop(s->d1->buffered_messages);
 
-        al = dtls1_preprocess_fragment(s, &frag->msg_header);
+        /* Calls SSLfatal() as required */
+        ret = dtls1_preprocess_fragment(s, &frag->msg_header);
 
-        if (al == 0) {          /* no alert */
+        if (ret) {
             unsigned char *p =
                 (unsigned char *)s->init_buf->data + DTLS1_HM_HEADER_LENGTH;
             memcpy(&p[frag->msg_header.frag_off], frag->fragment,
@@ -492,14 +512,14 @@ static int dtls1_retrieve_buffered_fragment(SSL *s, size_t *len)
         dtls1_hm_fragment_free(frag);
         pitem_free(item);
 
-        if (al == 0) {
+        if (ret) {
             *len = frag_len;
             return 1;
         }
 
-        ssl3_send_alert(s, SSL3_AL_FATAL, al);
+        /* Fatal error */
         s->init_num = 0;
-        return 0;
+        return -1;
     } else {
         return 0;
     }
@@ -578,6 +598,8 @@ dtls1_reassemble_fragment(SSL *s, const struct hm_header_st *msg_hdr)
     RSMBLY_BITMASK_MARK(frag->reassembly, (long)msg_hdr->frag_off,
                         (long)(msg_hdr->frag_off + frag_len));
 
+    if (!ossl_assert(msg_hdr->msg_len > 0))
+        goto err;
     RSMBLY_BITMASK_IS_COMPLETE(frag->reassembly, (long)msg_hdr->msg_len,
                                is_complete);
 
@@ -600,7 +622,8 @@ dtls1_reassemble_fragment(SSL *s, const struct hm_header_st *msg_hdr)
          * would have returned it and control would never have reached this
          * branch.
          */
-        OPENSSL_assert(item != NULL);
+        if (!ossl_assert(item != NULL))
+            goto err;
     }
 
     return DTLS1_HM_FRAGMENT_RETRY;
@@ -697,7 +720,8 @@ dtls1_process_out_of_seq_message(SSL *s, const struct hm_header_st *msg_hdr)
          * have been processed with |dtls1_reassemble_fragment|, above, or
          * the record will have been discarded.
          */
-        OPENSSL_assert(item != NULL);
+        if (!ossl_assert(item != NULL))
+            goto err;
     }
 
     return DTLS1_HM_FRAGMENT_RETRY;
@@ -712,7 +736,7 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
 {
     unsigned char wire[DTLS1_HM_HEADER_LENGTH];
     size_t mlen, frag_off, frag_len;
-    int i, al, recvd_type;
+    int i, ret, recvd_type;
     struct hm_header_st msg_hdr;
     size_t readbytes;
 
@@ -720,7 +744,12 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
 
  redo:
     /* see if we have the required fragment already */
-    if (dtls1_retrieve_buffered_fragment(s, &frag_len)) {
+    ret = dtls1_retrieve_buffered_fragment(s, &frag_len);
+    if (ret < 0) {
+        /* SSLfatal() already called */
+        return 0;
+    }
+    if (ret > 0) {
         s->init_num = frag_len;
         *len = frag_len;
         return 1;
@@ -736,9 +765,9 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
     }
     if (recvd_type == SSL3_RT_CHANGE_CIPHER_SPEC) {
         if (wire[0] != SSL3_MT_CCS) {
-            al = SSL_AD_UNEXPECTED_MESSAGE;
-            SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE,
-                   SSL_R_BAD_CHANGE_CIPHER_SPEC);
+            SSLfatal(s, SSL_AD_UNEXPECTED_MESSAGE,
+                     SSL_F_DTLS_GET_REASSEMBLED_MESSAGE,
+                     SSL_R_BAD_CHANGE_CIPHER_SPEC);
             goto f_err;
         }
 
@@ -753,8 +782,8 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
 
     /* Handshake fails if message header is incomplete */
     if (readbytes != DTLS1_HM_HEADER_LENGTH) {
-        al = SSL_AD_UNEXPECTED_MESSAGE;
-        SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_UNEXPECTED_MESSAGE);
+        SSLfatal(s, SSL_AD_UNEXPECTED_MESSAGE,
+                 SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_UNEXPECTED_MESSAGE);
         goto f_err;
     }
 
@@ -770,8 +799,8 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
      * Fragments must not span records.
      */
     if (frag_len > RECORD_LAYER_get_rrec_length(&s->rlayer)) {
-        al = SSL3_AD_ILLEGAL_PARAMETER;
-        SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_BAD_LENGTH);
+        SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER,
+                 SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_BAD_LENGTH);
         goto f_err;
     }
 
@@ -810,15 +839,17 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
             goto redo;
         } else {                /* Incorrectly formatted Hello request */
 
-            al = SSL_AD_UNEXPECTED_MESSAGE;
-            SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE,
-                   SSL_R_UNEXPECTED_MESSAGE);
+            SSLfatal(s, SSL_AD_UNEXPECTED_MESSAGE,
+                     SSL_F_DTLS_GET_REASSEMBLED_MESSAGE,
+                     SSL_R_UNEXPECTED_MESSAGE);
             goto f_err;
         }
     }
 
-    if ((al = dtls1_preprocess_fragment(s, &msg_hdr)))
+    if (!dtls1_preprocess_fragment(s, &msg_hdr)) {
+        /* SSLfatal() already called */
         goto f_err;
+    }
 
     if (frag_len > 0) {
         unsigned char *p =
@@ -845,8 +876,8 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
      * to fail
      */
     if (readbytes != frag_len) {
-        al = SSL3_AD_ILLEGAL_PARAMETER;
-        SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL3_AD_ILLEGAL_PARAMETER);
+        SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER,
+                 SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_BAD_LENGTH);
         goto f_err;
     }
 
@@ -860,7 +891,6 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
     return 1;
 
  f_err:
-    ssl3_send_alert(s, SSL3_AL_FATAL, al);
     s->init_num = 0;
     *len = 0;
     return 0;
@@ -881,8 +911,10 @@ int dtls_construct_change_cipher_spec(SSL *s, WPACKET *pkt)
         s->d1->next_handshake_write_seq++;
 
         if (!WPACKET_put_bytes_u16(pkt, s->d1->handshake_write_seq)) {
-            SSLerr(SSL_F_DTLS_CONSTRUCT_CHANGE_CIPHER_SPEC, ERR_R_INTERNAL_ERROR);
-            ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR,
+                     SSL_F_DTLS_CONSTRUCT_CHANGE_CIPHER_SPEC,
+                     ERR_R_INTERNAL_ERROR);
+            return 0;
         }
     }
 
@@ -896,8 +928,11 @@ WORK_STATE dtls_wait_for_dry(SSL *s)
 
     /* read app data until dry event */
     ret = BIO_dgram_sctp_wait_for_dry(SSL_get_wbio(s));
-    if (ret < 0)
+    if (ret < 0) {
+        SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_F_DTLS_WAIT_FOR_DRY,
+                 ERR_R_INTERNAL_ERROR);
         return WORK_ERROR;
+    }
 
     if (ret == 0) {
         s->s3->in_read_app_data = 2;
@@ -913,11 +948,12 @@ WORK_STATE dtls_wait_for_dry(SSL *s)
 int dtls1_read_failed(SSL *s, int code)
 {
     if (code > 0) {
-        SSLerr(SSL_F_DTLS1_READ_FAILED, ERR_R_INTERNAL_ERROR);
-        return 1;
+        SSLfatal(s, SSL_AD_INTERNAL_ERROR,
+                 SSL_F_DTLS1_READ_FAILED, ERR_R_INTERNAL_ERROR);
+        return 0;
     }
 
-    if (!dtls1_is_timer_expired(s)) {
+    if (!dtls1_is_timer_expired(s) || ossl_statem_in_error(s)) {
         /*
          * not a timeout, none of our business, let higher layers handle
          * this.  in fact it's probably an error
@@ -981,7 +1017,8 @@ int dtls1_buffer_message(SSL *s, int is_ccs)
      * this function is called immediately after a message has been
      * serialized
      */
-    OPENSSL_assert(s->init_off == 0);
+    if (!ossl_assert(s->init_off == 0))
+        return 0;
 
     frag = dtls1_hm_fragment_new(s->init_num, 0);
     if (frag == NULL)
@@ -991,13 +1028,15 @@ int dtls1_buffer_message(SSL *s, int is_ccs)
 
     if (is_ccs) {
         /* For DTLS1_BAD_VER the header length is non-standard */
-        OPENSSL_assert(s->d1->w_msg_hdr.msg_len +
-                       ((s->version ==
-                         DTLS1_BAD_VER) ? 3 : DTLS1_CCS_HEADER_LENGTH)
-                       == (unsigned int)s->init_num);
+        if (!ossl_assert(s->d1->w_msg_hdr.msg_len +
+                         ((s->version ==
+                           DTLS1_BAD_VER) ? 3 : DTLS1_CCS_HEADER_LENGTH)
+                         == (unsigned int)s->init_num))
+            return 0;
     } else {
-        OPENSSL_assert(s->d1->w_msg_hdr.msg_len +
-                       DTLS1_HM_HEADER_LENGTH == (unsigned int)s->init_num);
+        if (!ossl_assert(s->d1->w_msg_hdr.msg_len +
+                         DTLS1_HM_HEADER_LENGTH == (unsigned int)s->init_num))
+            return 0;
     }
 
     frag->msg_header.msg_len = s->d1->w_msg_hdr.msg_len;
@@ -1045,11 +1084,6 @@ int dtls1_retransmit_message(SSL *s, unsigned short seq, int *found)
     unsigned char seq64be[8];
     struct dtls1_retransmit_state saved_state;
 
-    /*-
-      OPENSSL_assert(s->init_num == 0);
-      OPENSSL_assert(s->init_off == 0);
-     */
-
     /* XDTLS:  the requested message ought to be found, otherwise error */
     memset(seq64be, 0, sizeof(seq64be));
     seq64be[6] = (unsigned char)(seq >> 8);
@@ -1057,7 +1091,8 @@ int dtls1_retransmit_message(SSL *s, unsigned short seq, int *found)
 
     item = pqueue_find(s->d1->sent_messages, seq64be);
     if (item == NULL) {
-        SSLerr(SSL_F_DTLS1_RETRANSMIT_MESSAGE, ERR_R_INTERNAL_ERROR);
+        SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_F_DTLS1_RETRANSMIT_MESSAGE,
+                 ERR_R_INTERNAL_ERROR);
         *found = 0;
         return 0;
     }