Prevent DTLS Finished message injection
[openssl.git] / ssl / record / rec_layer_d1.c
index 6699d2a3f1c44937efd5b35cf9e889ee19464287..cd582f32225b15be617c3ce6a484dfcf8dde7d64 100644 (file)
 int DTLS_RECORD_LAYER_new(RECORD_LAYER *rl)
 {
     DTLS_RECORD_LAYER *d;
-    
+
     if ((d = OPENSSL_malloc(sizeof(*d))) == NULL)
         return (0);
 
-
     rl->d = d;
 
     d->unprocessed_rcds.q = pqueue_new();
@@ -62,7 +61,7 @@ void DTLS_RECORD_LAYER_clear(RECORD_LAYER *rl)
     pqueue *buffered_app_data;
 
     d = rl->d;
-    
+
     while ((item = pqueue_pop(d->unprocessed_rcds.q)) != NULL) {
         rdata = (DTLS1_RECORD_DATA *)item->data;
         OPENSSL_free(rdata->rbuf.buf);
@@ -97,18 +96,14 @@ void DTLS_RECORD_LAYER_set_saved_w_epoch(RECORD_LAYER *rl, unsigned short e)
 {
     if (e == rl->d->w_epoch - 1) {
         memcpy(rl->d->curr_write_sequence,
-               rl->write_sequence,
-               sizeof(rl->write_sequence));
+               rl->write_sequence, sizeof(rl->write_sequence));
         memcpy(rl->write_sequence,
-               rl->d->last_write_sequence,
-               sizeof(rl->write_sequence));
+               rl->d->last_write_sequence, sizeof(rl->write_sequence));
     } else if (e == rl->d->w_epoch + 1) {
         memcpy(rl->d->last_write_sequence,
-               rl->write_sequence,
-               sizeof(unsigned char[8]));
+               rl->write_sequence, sizeof(unsigned char[8]));
         memcpy(rl->write_sequence,
-               rl->d->curr_write_sequence,
-               sizeof(rl->write_sequence));
+               rl->d->curr_write_sequence, sizeof(rl->write_sequence));
     }
     rl->d->w_epoch = e;
 }
@@ -118,7 +113,6 @@ void DTLS_RECORD_LAYER_resync_write(RECORD_LAYER *rl)
     memcpy(rl->write_sequence, rl->read_sequence, sizeof(rl->write_sequence));
 }
 
-
 void DTLS_RECORD_LAYER_set_write_sequence(RECORD_LAYER *rl, unsigned char *seq)
 {
     memcpy(rl->write_sequence, seq, SEQ_NUM_SIZE);
@@ -232,25 +226,73 @@ int dtls1_retrieve_buffered_record(SSL *s, record_pqueue *queue)
                    dtls1_retrieve_buffered_record((s), \
                    &((s)->rlayer.d->unprocessed_rcds))
 
-
 int dtls1_process_buffered_records(SSL *s)
 {
     pitem *item;
+    SSL3_BUFFER *rb;
+    SSL3_RECORD *rr;
+    DTLS1_BITMAP *bitmap;
+    unsigned int is_next_epoch;
+    int replayok = 1;
 
     item = pqueue_peek(s->rlayer.d->unprocessed_rcds.q);
     if (item) {
         /* Check if epoch is current. */
         if (s->rlayer.d->unprocessed_rcds.epoch != s->rlayer.d->r_epoch)
-            return (1);         /* Nothing to do. */
+            return 1;         /* Nothing to do. */
+
+        rr = RECORD_LAYER_get_rrec(&s->rlayer);
+
+        rb = RECORD_LAYER_get_rbuf(&s->rlayer);
+
+        if (SSL3_BUFFER_get_left(rb) > 0) {
+            /*
+             * We've still got data from the current packet to read. There could
+             * be a record from the new epoch in it - so don't overwrite it
+             * with the unprocessed records yet (we'll do it when we've
+             * finished reading the current packet).
+             */
+            return 1;
+        }
 
         /* Process all the records. */
         while (pqueue_peek(s->rlayer.d->unprocessed_rcds.q)) {
             dtls1_get_unprocessed_record(s);
-            if (!dtls1_process_record(s))
-                return (0);
+            bitmap = dtls1_get_bitmap(s, rr, &is_next_epoch);
+            if (bitmap == NULL) {
+                /*
+                 * Should not happen. This will only ever be NULL when the
+                 * current record is from a different epoch. But that cannot
+                 * be the case because we already checked the epoch above
+                 */
+                 SSLerr(SSL_F_DTLS1_PROCESS_BUFFERED_RECORDS,
+                        ERR_R_INTERNAL_ERROR);
+                 return 0;
+            }
+#ifndef OPENSSL_NO_SCTP
+            /* Only do replay check if no SCTP bio */
+            if (!BIO_dgram_is_sctp(SSL_get_rbio(s)))
+#endif
+            {
+                /*
+                 * Check whether this is a repeat, or aged record. We did this
+                 * check once already when we first received the record - but
+                 * we might have updated the window since then due to
+                 * records we subsequently processed.
+                 */
+                replayok = dtls1_record_replay_check(s, bitmap);
+            }
+
+            if (!replayok || !dtls1_process_record(s, bitmap)) {
+                /* dump this record */
+                rr->length = 0;
+                RECORD_LAYER_reset_packet_length(&s->rlayer);
+                continue;
+            }
+
             if (dtls1_buffer_record(s, &(s->rlayer.d->processed_rcds),
-                SSL3_RECORD_get_seq_num(s->rlayer.rrec)) < 0)
-                return -1;
+                    SSL3_RECORD_get_seq_num(s->rlayer.rrec)) < 0)
+                return 0;
         }
     }
 
@@ -261,10 +303,9 @@ int dtls1_process_buffered_records(SSL *s)
     s->rlayer.d->processed_rcds.epoch = s->rlayer.d->r_epoch;
     s->rlayer.d->unprocessed_rcds.epoch = s->rlayer.d->r_epoch + 1;
 
-    return (1);
+    return 1;
 }
 
-
 /*-
  * Return up to 'len' payload bytes received in 'type' records.
  * 'type' is one of the following:
@@ -390,7 +431,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
 
     /* get new packet if necessary */
     if ((SSL3_RECORD_get_length(rr) == 0)
-            || (s->rlayer.rstate == SSL_ST_READ_BODY)) {
+        || (s->rlayer.rstate == SSL_ST_READ_BODY)) {
         ret = dtls1_get_record(s);
         if (ret <= 0) {
             ret = dtls1_read_failed(s, ret);
@@ -413,7 +454,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
          * data for later processing rather than dropping the connection.
          */
         if (dtls1_buffer_record(s, &(s->rlayer.d->buffered_app_data),
-            SSL3_RECORD_get_seq_num(rr)) < 0) {
+                                SSL3_RECORD_get_seq_num(rr)) < 0) {
             SSLerr(SSL_F_DTLS1_READ_BYTES, ERR_R_INTERNAL_ERROR);
             return -1;
         }
@@ -432,8 +473,8 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
     }
 
     if (type == SSL3_RECORD_get_type(rr)
-            || (SSL3_RECORD_get_type(rr) == SSL3_RT_CHANGE_CIPHER_SPEC
-                && type == SSL3_RT_HANDSHAKE && recvd_type != NULL)) {
+        || (SSL3_RECORD_get_type(rr) == SSL3_RT_CHANGE_CIPHER_SPEC
+            && type == SSL3_RT_HANDSHAKE && recvd_type != NULL)) {
         /*
          * SSL3_RT_APPLICATION_DATA or
          * SSL3_RT_HANDSHAKE or
@@ -525,7 +566,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
         else if (SSL3_RECORD_get_type(rr) == DTLS1_RT_HEARTBEAT) {
             /* We allow a 0 return */
             if (dtls1_process_heartbeat(s, SSL3_RECORD_get_data(rr),
-                    SSL3_RECORD_get_length(rr)) < 0) {
+                                        SSL3_RECORD_get_length(rr)) < 0) {
                 return -1;
             }
             /* Exit and notify application to read again */
@@ -542,7 +583,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
              * Application data while renegotiating is allowed. Try again
              * reading.
              */
-            if (SSL3_RECORD_get_type(rr)  == SSL3_RT_APPLICATION_DATA) {
+            if (SSL3_RECORD_get_type(rr) == SSL3_RT_APPLICATION_DATA) {
                 BIO *bio;
                 s->s3->in_read_app_data = 2;
                 bio = SSL_get_rbio(s);
@@ -563,14 +604,14 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
              * XDTLS: In a pathological case, the Client Hello may be
              * fragmented--don't always expect dest_maxlen bytes
              */
-            if (SSL3_RECORD_get_length(rr)  < dest_maxlen) {
+            if (SSL3_RECORD_get_length(rr) < dest_maxlen) {
 #ifdef DTLS1_AD_MISSING_HANDSHAKE_MESSAGE
                 /*
                  * for normal alerts rr->length is 2, while
                  * dest_maxlen is 7 if we were to handle this
                  * non-existing alert...
                  */
-                FIX ME
+                FIX ME;
 #endif
                 s->rlayer.rstate = SSL_ST_READ_HEADER;
                 SSL3_RECORD_set_length(rr, 0);
@@ -628,8 +669,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
                 if (i < 0)
                     return (i);
                 if (i == 0) {
-                    SSLerr(SSL_F_DTLS1_READ_BYTES,
-                           SSL_R_SSL_HANDSHAKE_FAILURE);
+                    SSLerr(SSL_F_DTLS1_READ_BYTES, SSL_R_SSL_HANDSHAKE_FAILURE);
                     return (-1);
                 }
 
@@ -734,8 +774,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
 
             s->rwstate = SSL_NOTHING;
             s->s3->fatal_alert = alert_descr;
-            SSLerr(SSL_F_DTLS1_READ_BYTES,
-                   SSL_AD_REASON_OFFSET + alert_descr);
+            SSLerr(SSL_F_DTLS1_READ_BYTES, SSL_AD_REASON_OFFSET + alert_descr);
             BIO_snprintf(tmp, sizeof tmp, "%d", alert_descr);
             ERR_add_error_data(2, "SSL alert number ", tmp);
             s->shutdown |= SSL_RECEIVED_SHUTDOWN;
@@ -874,7 +913,6 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
     return (-1);
 }
 
-
         /*
          * this only happens when a client hello is received and a handshake
          * is started.
@@ -884,7 +922,7 @@ static int have_handshake_fragment(SSL *s, int type, unsigned char *buf,
 {
 
     if ((type == SSL3_RT_HANDSHAKE)
-            && (s->rlayer.d->handshake_fragment_len > 0))
+        && (s->rlayer.d->handshake_fragment_len > 0))
         /* (partially) satisfy request from storage */
     {
         unsigned char *src = s->rlayer.d->handshake_fragment;
@@ -980,7 +1018,8 @@ int do_dtls1_write(SSL *s, int type, const unsigned char *buf,
      * haven't decided which version to use yet send back using version 1.0
      * header: otherwise some clients will ignore it.
      */
-    if (s->method->version == DTLS_ANY_VERSION) {
+    if (s->method->version == DTLS_ANY_VERSION &&
+        s->max_proto_version != DTLS1_BAD_VER) {
         *(p++) = DTLS1_VERSION >> 8;
         *(p++) = DTLS1_VERSION & 0xff;
     } else {
@@ -1039,7 +1078,8 @@ int do_dtls1_write(SSL *s, int type, const unsigned char *buf,
 
     if (mac_size != 0) {
         if (s->method->ssl3_enc->mac(s, &wr,
-                &(p[SSL3_RECORD_get_length(&wr) + eivlen]), 1) < 0)
+                                     &(p[SSL3_RECORD_get_length(&wr) + eivlen]),
+                                     1) < 0)
             goto err;
         SSL3_RECORD_add_length(&wr, mac_size);
     }
@@ -1114,7 +1154,7 @@ int do_dtls1_write(SSL *s, int type, const unsigned char *buf,
 }
 
 DTLS1_BITMAP *dtls1_get_bitmap(SSL *s, SSL3_RECORD *rr,
-                                      unsigned int *is_next_epoch)
+                               unsigned int *is_next_epoch)
 {
 
     *is_next_epoch = 0;
@@ -1123,9 +1163,14 @@ DTLS1_BITMAP *dtls1_get_bitmap(SSL *s, SSL3_RECORD *rr,
     if (rr->epoch == s->rlayer.d->r_epoch)
         return &s->rlayer.d->bitmap;
 
-    /* Only HM and ALERT messages can be from the next epoch */
+    /*
+     * Only HM and ALERT messages can be from the next epoch and only if we
+     * have already processed all of the unprocessed records from the last
+     * epoch
+     */
     else if (rr->epoch == (unsigned long)(s->rlayer.d->r_epoch + 1) &&
-            (rr->type == SSL3_RT_HANDSHAKE || rr->type == SSL3_RT_ALERT)) {
+             s->rlayer.d->unprocessed_rcds.epoch != s->rlayer.d->r_epoch &&
+             (rr->type == SSL3_RT_HANDSHAKE || rr->type == SSL3_RT_ALERT)) {
         *is_next_epoch = 1;
         return &s->rlayer.d->next_bitmap;
     }
@@ -1143,8 +1188,13 @@ void dtls1_reset_seq_numbers(SSL *s, int rw)
         s->rlayer.d->r_epoch++;
         memcpy(&s->rlayer.d->bitmap, &s->rlayer.d->next_bitmap,
                sizeof(s->rlayer.d->bitmap));
-        memset(&s->rlayer.d->next_bitmap, 0,
-               sizeof(s->rlayer.d->next_bitmap));
+        memset(&s->rlayer.d->next_bitmap, 0, sizeof(s->rlayer.d->next_bitmap));
+
+        /*
+         * We must not use any buffered messages received from the previous
+         * epoch
+         */
+        dtls1_clear_received_buffer(s);
     } else {
         seq = s->rlayer.write_sequence;
         memcpy(s->rlayer.d->last_write_sequence, seq,