ct_locl.h moved, reflect it in crypto/ct/Makefile
[openssl.git] / ssl / d1_both.c
index a1499da3eb9f44827e37bb69dc412727939b4458..02a464e4f0e6445b0b2e87835ad13a2fd32a567b 100644 (file)
@@ -160,8 +160,8 @@ static void dtls1_set_message_header_int(SSL *s, unsigned char mt,
                                          unsigned short seq_num,
                                          unsigned long frag_off,
                                          unsigned long frag_len);
-static long dtls1_get_message_fragment(SSL *s, int st1, int stn, long max,
-                                       int *ok);
+static long dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt,
+                                       long max, int *ok);
 
 static hm_fragment *dtls1_hm_fragment_new(unsigned long frag_len,
                                           int reassembly)
@@ -187,13 +187,12 @@ static hm_fragment *dtls1_hm_fragment_new(unsigned long frag_len,
 
     /* Initialize reassembly bitmask if necessary */
     if (reassembly) {
-        bitmask = OPENSSL_malloc(RSMBLY_BITMASK_SIZE(frag_len));
+        bitmask = OPENSSL_zalloc(RSMBLY_BITMASK_SIZE(frag_len));
         if (bitmask == NULL) {
             OPENSSL_free(buf);
             OPENSSL_free(frag);
             return NULL;
         }
-        memset(bitmask, 0, RSMBLY_BITMASK_SIZE(frag_len));
     }
 
     frag->reassembly = bitmask;
@@ -270,7 +269,8 @@ int dtls1_do_write(SSL *s, int type)
 
     if (s->write_hash) {
         if (s->enc_write_ctx
-            && EVP_CIPHER_CTX_mode(s->enc_write_ctx) == EVP_CIPH_GCM_MODE)
+            && ((EVP_CIPHER_CTX_mode(s->enc_write_ctx) == EVP_CIPH_GCM_MODE) ||
+                (EVP_CIPHER_CTX_mode(s->enc_write_ctx) == EVP_CIPH_CCM_MODE)))
             mac_size = 0;
         else
             mac_size = EVP_MD_CTX_size(s->write_hash);
@@ -454,15 +454,26 @@ long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
      * absence of an optional handshake message
      */
     if (s->s3->tmp.reuse_message) {
-        s->s3->tmp.reuse_message = 0;
         if ((mt >= 0) && (s->s3->tmp.message_type != mt)) {
             al = SSL_AD_UNEXPECTED_MESSAGE;
             SSLerr(SSL_F_DTLS1_GET_MESSAGE, SSL_R_UNEXPECTED_MESSAGE);
             goto f_err;
         }
         *ok = 1;
-        s->init_msg = s->init_buf->data + DTLS1_HM_HEADER_LENGTH;
+
+
+        /*
+         * Messages reused from dtls1_listen also have the record header in
+         * the buffer which we need to skip over.
+         */
+        if (s->s3->tmp.reuse_message == DTLS1_SKIP_RECORD_HEADER) {
+            s->init_msg = s->init_buf->data + DTLS1_HM_HEADER_LENGTH
+                          + DTLS1_RT_HEADER_LENGTH;
+        } else {
+            s->init_msg = s->init_buf->data + DTLS1_HM_HEADER_LENGTH;
+        }
         s->init_num = (int)s->s3->tmp.message_size;
+        s->s3->tmp.reuse_message = 0;
         return s->init_num;
     }
 
@@ -470,7 +481,7 @@ long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
     memset(msg_hdr, 0, sizeof(*msg_hdr));
 
  again:
-    i = dtls1_get_message_fragment(s, st1, stn, max, ok);
+    i = dtls1_get_message_fragment(s, st1, stn, mt, max, ok);
     if (i == DTLS1_HM_BAD_FRAGMENT || i == DTLS1_HM_FRAGMENT_RETRY) {
         /* bad fragment received */
         goto again;
@@ -485,6 +496,20 @@ long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
     }
 
     p = (unsigned char *)s->init_buf->data;
+
+    if (mt == SSL3_MT_CHANGE_CIPHER_SPEC) {
+        if (s->msg_callback) {
+            s->msg_callback(0, s->version, SSL3_RT_CHANGE_CIPHER_SPEC,
+                            p, 1, s, s->msg_callback_arg);
+        }
+        /*
+         * This isn't a real handshake message so skip the processing below.
+         * dtls1_get_message_fragment() will never return a CCS if mt == -1,
+         * so we are ok to continue in that case.
+         */
+        return i;
+    }
+
     msg_len = msg_hdr->msg_len;
 
     /* reconstruct message header */
@@ -505,9 +530,8 @@ long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
 
     memset(msg_hdr, 0, sizeof(*msg_hdr));
 
-    /* Don't change sequence numbers while listening */
-    if (!s->d1->listen)
-        s->d1->handshake_read_seq++;
+    s->d1->handshake_read_seq++;
+
 
     s->init_msg = s->init_buf->data + DTLS1_HM_HEADER_LENGTH;
     return s->init_num;
@@ -835,11 +859,11 @@ dtls1_process_out_of_seq_message(SSL *s, const struct hm_header_st *msg_hdr,
 }
 
 static long
-dtls1_get_message_fragment(SSL *s, int st1, int stn, long max, int *ok)
+dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
 {
     unsigned char wire[DTLS1_HM_HEADER_LENGTH];
     unsigned long len, frag_off, frag_len;
-    int i, al;
+    int i, al, recvd_type;
     struct hm_header_st msg_hdr;
 
  redo:
@@ -851,13 +875,46 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, long max, int *ok)
     }
 
     /* read handshake message header */
-    i = s->method->ssl_read_bytes(s, SSL3_RT_HANDSHAKE, NULL, wire,
+    i = s->method->ssl_read_bytes(s, SSL3_RT_HANDSHAKE, &recvd_type, wire,
                                   DTLS1_HM_HEADER_LENGTH, 0);
     if (i <= 0) {               /* nbio, or an error */
         s->rwstate = SSL_READING;
         *ok = 0;
         return i;
     }
+    if(recvd_type == SSL3_RT_CHANGE_CIPHER_SPEC) {
+        /* This isn't a real handshake message - its a CCS.
+         * There is no message sequence number in a CCS to give us confidence
+         * that this was really intended to be at this point in the handshake
+         * sequence. Therefore we only allow this if we were explicitly looking
+         * for it (i.e. if |mt| is -1 we still don't allow it).
+         */
+        if(mt == SSL3_MT_CHANGE_CIPHER_SPEC) {
+            if (wire[0] != SSL3_MT_CCS) {
+                al = SSL_AD_UNEXPECTED_MESSAGE;
+                SSLerr(SSL_F_DTLS1_GET_MESSAGE_FRAGMENT, SSL_R_BAD_CHANGE_CIPHER_SPEC);
+                goto f_err;
+            }
+
+            memcpy(s->init_buf->data, wire, i);
+            s->init_num = i - 1;
+            s->init_msg = s->init_buf->data + 1;
+            s->s3->tmp.message_type = SSL3_MT_CHANGE_CIPHER_SPEC;
+            s->s3->tmp.message_size = i - 1;
+            s->state = stn;
+            *ok = 1;
+            return i-1;
+        } else {
+            /*
+             * We weren't expecting a CCS yet. Probably something got
+             * re-ordered or this is a retransmit. We should drop this and try
+             * again.
+             */
+            s->init_num = 0;
+            goto redo;
+        }
+    }
+
     /* Handshake fails if message header is incomplete */
     if (i != DTLS1_HM_HEADER_LENGTH) {
         al = SSL_AD_UNEXPECTED_MESSAGE;
@@ -888,8 +945,7 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, long max, int *ok)
      * While listening, we accept seq 1 (ClientHello with cookie)
      * although we're still expecting seq 0 (ClientHello)
      */
-    if (msg_hdr.seq != s->d1->handshake_read_seq
-        && !(s->d1->listen && msg_hdr.seq == 1))
+    if (msg_hdr.seq != s->d1->handshake_read_seq)
         return dtls1_process_out_of_seq_message(s, &msg_hdr, ok);
 
     if (frag_len && frag_len < len)
@@ -1242,8 +1298,7 @@ void dtls1_set_message_header(SSL *s, unsigned char *p,
                                         unsigned long frag_off,
                                         unsigned long frag_len)
 {
-    /* Don't change sequence numbers while listening */
-    if (frag_off == 0 && !s->d1->listen) {
+    if (frag_off == 0) {
         s->d1->handshake_write_seq = s->d1->next_handshake_write_seq;
         s->d1->next_handshake_write_seq++;
     }
@@ -1318,9 +1373,12 @@ int dtls1_shutdown(SSL *s)
 {
     int ret;
 #ifndef OPENSSL_NO_SCTP
-    if (BIO_dgram_is_sctp(SSL_get_wbio(s)) &&
+    BIO *wbio;
+
+    wbio = SSL_get_wbio(s);
+    if (wbio != NULL && BIO_dgram_is_sctp(wbio) &&
         !(s->shutdown & SSL_SENT_SHUTDOWN)) {
-        ret = BIO_dgram_sctp_wait_for_dry(SSL_get_wbio(s));
+        ret = BIO_dgram_sctp_wait_for_dry(wbio);
         if (ret < 0)
             return -1;