Defer Finished MAC handling until after state transition
authorMatt Caswell <matt@openssl.org>
Mon, 19 Apr 2021 14:21:54 +0000 (15:21 +0100)
committerMatt Caswell <matt@openssl.org>
Wed, 28 Apr 2021 15:23:08 +0000 (16:23 +0100)
In TLS we process received messages like this:

1) Read Message Header
2) Validate and transition state based on received message type
3) Read Message Body
4) Process Message

In DTLS we read messages like this:

1) Read Message Header and Body
2) Validate and transition state based on received message type
3) Process Message

The difference is because of the stream vs datagram semantics of the
underlying transport.

In both TLS and DTLS we were doing finished MAC processing as part of
reading the message body. This means that in DTLS this was occurring
*before* the state transition has been validated. A crash was occurring
in DTLS if a Finished message was sent in an invalid state due to
assumptions in the code that certain variables would have been setup by
the time a Finished message arrives.

To avoid this problem we shift the finished MAC processing to be after
the state transition in DTLS.

Thanks to github user @bathooman for reporting this issue.

Fixes #14906

Reviewed-by: Paul Dale <pauli@openssl.org>
Reviewed-by: Tim Hudson <tjh@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/14930)

ssl/statem/statem.c
ssl/statem/statem_dtls.c
ssl/statem/statem_local.h

index 3b6e78e3f84a84d2d8359a9e54a63ede09d3365c..4c463974eaa142247bbf4f5a6576e9122d8964f3 100644 (file)
@@ -582,7 +582,7 @@ static SUB_STATE_RETURN read_state_machine(SSL *s)
                 /*
                  * In DTLS we get the whole message in one go - header and body
                  */
-                ret = dtls_get_message(s, &mt, &len);
+                ret = dtls_get_message(s, &mt);
             } else {
                 ret = tls_get_message_header(s, &mt);
             }
@@ -625,13 +625,18 @@ static SUB_STATE_RETURN read_state_machine(SSL *s)
             /* Fall through */
 
         case READ_STATE_BODY:
-            if (!SSL_IS_DTLS(s)) {
-                /* We already got this above for DTLS */
+            if (SSL_IS_DTLS(s)) {
+                /*
+                 * Actually we already have the body, but we give DTLS the
+                 * opportunity to do any further processing.
+                 */
+                ret = dtls_get_message_body(s, &len);
+            } else {
                 ret = tls_get_message_body(s, &len);
-                if (ret == 0) {
-                    /* Could be non-blocking IO */
-                    return SUB_STATE_ERROR;
-                }
+            }
+            if (ret == 0) {
+                /* Could be non-blocking IO */
+                return SUB_STATE_ERROR;
             }
 
             s->first_packet = 0;
index c4bed3d3eeb0d21cf1341b674250d83c4d2b2d9a..1fcd064ea6d4b001544dd1832d2584df061348be 100644 (file)
@@ -328,7 +328,7 @@ int dtls1_do_write(SSL *s, int type)
     return 0;
 }
 
-int dtls_get_message(SSL *s, int *mt, size_t *len)
+int dtls_get_message(SSL *s, int *mt)
 {
     struct hm_header_st *msg_hdr;
     unsigned char *p;
@@ -352,7 +352,6 @@ int dtls_get_message(SSL *s, int *mt, size_t *len)
     *mt = s->s3.tmp.message_type;
 
     p = (unsigned char *)s->init_buf->data;
-    *len = s->init_num;
 
     if (*mt == SSL3_MT_CHANGE_CIPHER_SPEC) {
         if (s->msg_callback) {
@@ -373,32 +372,54 @@ int dtls_get_message(SSL *s, int *mt, size_t *len)
     s2n(msg_hdr->seq, p);
     l2n3(0, p);
     l2n3(msg_len, p);
-    if (s->version != DTLS1_BAD_VER) {
-        p -= DTLS1_HM_HEADER_LENGTH;
-        msg_len += DTLS1_HM_HEADER_LENGTH;
-    }
 
+    memset(msg_hdr, 0, sizeof(*msg_hdr));
+
+    s->d1->handshake_read_seq++;
+
+    s->init_msg = s->init_buf->data + DTLS1_HM_HEADER_LENGTH;
+
+    return 1;
+}
+
+/*
+ * Actually we already have the message body - but this is an opportunity for
+ * DTLS to do any further processing it wants at the same point that TLS would
+ * be asked for the message body.
+ */
+int dtls_get_message_body(SSL *s, size_t *len)
+{
+    unsigned char *msg = (unsigned char *)s->init_buf->data;
+    size_t msg_len = s->init_num + DTLS1_HM_HEADER_LENGTH;
+
+    if (s->s3.tmp.message_type == SSL3_MT_CHANGE_CIPHER_SPEC) {
+        /* Nothing to be done */
+        goto end;
+    }
     /*
      * If receiving Finished, record MAC of prior handshake messages for
      * Finished verification.
      */
-    if (*mt == SSL3_MT_FINISHED && !ssl3_take_mac(s)) {
+    if (*(s->init_buf->data) == SSL3_MT_FINISHED && !ssl3_take_mac(s)) {
         /* SSLfatal() already called */
         return 0;
     }
 
-    if (!ssl3_finish_mac(s, p, msg_len))
+    if (s->version == DTLS1_BAD_VER) {
+        msg += DTLS1_HM_HEADER_LENGTH;
+        msg_len -= DTLS1_HM_HEADER_LENGTH;
+    }
+
+    if (!ssl3_finish_mac(s, msg, msg_len))
         return 0;
+
     if (s->msg_callback)
         s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE,
-                        p, msg_len, s, s->msg_callback_arg);
-
-    memset(msg_hdr, 0, sizeof(*msg_hdr));
-
-    s->d1->handshake_read_seq++;
-
-    s->init_msg = s->init_buf->data + DTLS1_HM_HEADER_LENGTH;
+                        s->init_buf->data, s->init_num + DTLS1_HM_HEADER_LENGTH,
+                        s, s->msg_callback_arg);
 
+ end:
+    *len = s->init_num;
     return 1;
 }
 
index 61de225584d55f2d8fdbd79638a404bfcc5ab84b..25bfdffc6c0768df559bfe707cc700c02a388c69 100644 (file)
@@ -95,7 +95,8 @@ WORK_STATE ossl_statem_server_post_process_message(SSL *s, WORK_STATE wst);
 /* Functions for getting new message data */
 __owur int tls_get_message_header(SSL *s, int *mt);
 __owur int tls_get_message_body(SSL *s, size_t *len);
-__owur int dtls_get_message(SSL *s, int *mt, size_t *len);
+__owur int dtls_get_message(SSL *s, int *mt);
+__owur int dtls_get_message_body(SSL *s, size_t *len);
 
 /* Message construction and processing functions */
 __owur int tls_process_initial_server_flight(SSL *s);