Revert "Avoid duplication."
[openssl.git] / ssl / d1_both.c
index 3af3ba15cc0ea8f585b553f575fc83817ee28004..155b8bffe0820792cff2fed5b9b2318b6cfa32ab 100644 (file)
@@ -170,7 +170,7 @@ static hm_fragment *dtls1_hm_fragment_new(unsigned long frag_len,
     unsigned char *buf = NULL;
     unsigned char *bitmask = NULL;
 
     unsigned char *buf = NULL;
     unsigned char *bitmask = NULL;
 
-    frag = OPENSSL_malloc(sizeof(hm_fragment));
+    frag = OPENSSL_malloc(sizeof(*frag));
     if (frag == NULL)
         return NULL;
 
     if (frag == NULL)
         return NULL;
 
@@ -467,7 +467,7 @@ long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
     }
 
     msg_hdr = &s->d1->r_msg_hdr;
     }
 
     msg_hdr = &s->d1->r_msg_hdr;
-    memset(msg_hdr, 0x00, sizeof(struct hm_header_st));
+    memset(msg_hdr, 0, sizeof(*msg_hdr));
 
  again:
     i = dtls1_get_message_fragment(s, st1, stn, max, ok);
 
  again:
     i = dtls1_get_message_fragment(s, st1, stn, max, ok);
@@ -478,6 +478,12 @@ long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
         return i;
     }
 
         return i;
     }
 
+    if (mt >= 0 && s->s3->tmp.message_type != mt) {
+        al = SSL_AD_UNEXPECTED_MESSAGE;
+        SSLerr(SSL_F_DTLS1_GET_MESSAGE, SSL_R_UNEXPECTED_MESSAGE);
+        goto f_err;
+    }
+
     p = (unsigned char *)s->init_buf->data;
     msg_len = msg_hdr->msg_len;
 
     p = (unsigned char *)s->init_buf->data;
     msg_len = msg_hdr->msg_len;
 
@@ -497,7 +503,7 @@ long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
         s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE,
                         p, msg_len, s, s->msg_callback_arg);
 
         s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE,
                         p, msg_len, s, s->msg_callback_arg);
 
-    memset(msg_hdr, 0x00, sizeof(struct hm_header_st));
+    memset(msg_hdr, 0, sizeof(*msg_hdr));
 
     /* Don't change sequence numbers while listening */
     if (!s->d1->listen)
 
     /* Don't change sequence numbers while listening */
     if (!s->d1->listen)
@@ -862,6 +868,20 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, long max, int *ok)
     /* parse the message fragment header */
     dtls1_get_message_header(wire, &msg_hdr);
 
     /* parse the message fragment header */
     dtls1_get_message_header(wire, &msg_hdr);
 
+    len = msg_hdr.msg_len;
+    frag_off = msg_hdr.frag_off;
+    frag_len = msg_hdr.frag_len;
+
+    /*
+     * We must have at least frag_len bytes left in the record to be read.
+     * Fragments must not span records.
+     */
+    if (frag_len > RECORD_LAYER_get_rrec_length(&s->rlayer)) {
+        al = SSL3_AD_ILLEGAL_PARAMETER;
+        SSLerr(SSL_F_DTLS1_GET_MESSAGE_FRAGMENT, SSL_R_BAD_LENGTH);
+        goto f_err;
+    }
+
     /*
      * if this is a future (or stale) message it gets buffered
      * (or dropped)--no further processing at this time
     /*
      * if this is a future (or stale) message it gets buffered
      * (or dropped)--no further processing at this time
@@ -872,10 +892,6 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, long max, int *ok)
         && !(s->d1->listen && msg_hdr.seq == 1))
         return dtls1_process_out_of_seq_message(s, &msg_hdr, ok);
 
         && !(s->d1->listen && msg_hdr.seq == 1))
         return dtls1_process_out_of_seq_message(s, &msg_hdr, ok);
 
-    len = msg_hdr.msg_len;
-    frag_off = msg_hdr.frag_off;
-    frag_len = msg_hdr.frag_len;
-
     if (frag_len && frag_len < len)
         return dtls1_reassemble_fragment(s, &msg_hdr, ok);
 
     if (frag_len && frag_len < len)
         return dtls1_reassemble_fragment(s, &msg_hdr, ok);
 
@@ -906,17 +922,16 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, long max, int *ok)
     if ((al = dtls1_preprocess_fragment(s, &msg_hdr, max)))
         goto f_err;
 
     if ((al = dtls1_preprocess_fragment(s, &msg_hdr, max)))
         goto f_err;
 
-    /* XDTLS:  ressurect this when restart is in place */
-    s->state = stn;
-
     if (frag_len > 0) {
         unsigned char *p =
             (unsigned char *)s->init_buf->data + DTLS1_HM_HEADER_LENGTH;
 
         i = s->method->ssl_read_bytes(s, SSL3_RT_HANDSHAKE,
                                       &p[frag_off], frag_len, 0);
     if (frag_len > 0) {
         unsigned char *p =
             (unsigned char *)s->init_buf->data + DTLS1_HM_HEADER_LENGTH;
 
         i = s->method->ssl_read_bytes(s, SSL3_RT_HANDSHAKE,
                                       &p[frag_off], frag_len, 0);
+
         /*
         /*
-         * XDTLS: fix this--message fragments cannot span multiple packets
+         * This shouldn't ever fail due to NBIO because we already checked
+         * that we have enough data in the record
          */
         if (i <= 0) {
             s->rwstate = SSL_READING;
          */
         if (i <= 0) {
             s->rwstate = SSL_READING;
@@ -937,6 +952,7 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, long max, int *ok)
     }
 
     *ok = 1;
     }
 
     *ok = 1;
+    s->state = stn;
 
     /*
      * Note that s->init_num is *not* used as current offset in
 
     /*
      * Note that s->init_num is *not* used as current offset in
@@ -1289,7 +1305,7 @@ unsigned int dtls1_min_mtu(SSL *s)
 void
 dtls1_get_message_header(unsigned char *data, struct hm_header_st *msg_hdr)
 {
 void
 dtls1_get_message_header(unsigned char *data, struct hm_header_st *msg_hdr)
 {
-    memset(msg_hdr, 0x00, sizeof(struct hm_header_st));
+    memset(msg_hdr, 0, sizeof(*msg_hdr));
     msg_hdr->type = *(data++);
     n2l3(data, msg_hdr->msg_len);
 
     msg_hdr->type = *(data++);
     n2l3(data, msg_hdr->msg_len);
 
@@ -1298,13 +1314,6 @@ dtls1_get_message_header(unsigned char *data, struct hm_header_st *msg_hdr)
     n2l3(data, msg_hdr->frag_len);
 }
 
     n2l3(data, msg_hdr->frag_len);
 }
 
-void dtls1_get_ccs_header(unsigned char *data, struct ccs_header_st *ccs_hdr)
-{
-    memset(ccs_hdr, 0x00, sizeof(struct ccs_header_st));
-
-    ccs_hdr->type = *(data++);
-}
-
 int dtls1_shutdown(SSL *s)
 {
     int ret;
 int dtls1_shutdown(SSL *s)
 {
     int ret;