Ensure we unpad in constant time for read pipelining
[openssl.git] / ssl / record / ssl3_record.c
index ad240bc52d2a3e6b97759adbb3c60102f1505f23..f1d6f72d837da4e4495b6b743d969211960bc9c6 100644 (file)
@@ -159,19 +159,9 @@ int ssl3_get_record(SSL *s)
             p = RECORD_LAYER_get_packet(&s->rlayer);
 
             /*
-             * Check whether this is a regular record or an SSLv2 style record.
-             * The latter can only be used in the first record of an initial
-             * ClientHello for old clients. Initial ClientHello means
-             * s->first_packet is set and s->server is true. The first record
-             * means we've not received any data so far (s->init_num == 0) and
-             * have had no empty records. We check s->read_hash and
-             * s->enc_read_ctx to ensure this does not apply during
-             * renegotiation.
+             * The first record received by the server may be a V2ClientHello.
              */
-            if (s->first_packet && s->server
-                    && s->init_num == 0
-                    && RECORD_LAYER_get_empty_record_count(&s->rlayer) == 0
-                    && s->read_hash == NULL && s->enc_read_ctx == NULL
+            if (s->server && RECORD_LAYER_is_first_record(&s->rlayer)
                     && (p[0] & 0x80) && (p[2] == SSL2_MT_CLIENT_HELLO)) {
                 /*
                  *  SSLv2 style record
@@ -239,7 +229,7 @@ int ssl3_get_record(SSL *s)
                 }
 
                 if ((version >> 8) != SSL3_VERSION_MAJOR) {
-                    if (s->first_packet) {
+                    if (RECORD_LAYER_is_first_record(&s->rlayer)) {
                         /* Go back to start of packet, look at the five bytes
                          * that we have. */
                         p = RECORD_LAYER_get_packet(&s->rlayer);
@@ -254,9 +244,17 @@ int ssl3_get_record(SSL *s)
                                    SSL_R_HTTPS_PROXY_REQUEST);
                             goto err;
                         }
+
+                        /* Doesn't look like TLS - don't send an alert */
+                        SSLerr(SSL_F_SSL3_GET_RECORD,
+                               SSL_R_WRONG_VERSION_NUMBER);
+                        goto err;
+                    } else {
+                        SSLerr(SSL_F_SSL3_GET_RECORD,
+                               SSL_R_WRONG_VERSION_NUMBER);
+                        al = SSL_AD_PROTOCOL_VERSION;
+                        goto f_err;
                     }
-                    SSLerr(SSL_F_SSL3_GET_RECORD, SSL_R_WRONG_VERSION_NUMBER);
-                    goto err;
                 }
 
                 if (rr[num_recs].length >
@@ -335,6 +333,7 @@ int ssl3_get_record(SSL *s)
 
         /* we have pulled in a full packet so zero things */
         RECORD_LAYER_reset_packet_length(&s->rlayer);
+        RECORD_LAYER_clear_first_record(&s->rlayer);
     } while (num_recs < max_recs
              && rr[num_recs-1].type == SSL3_RT_APPLICATION_DATA
              && SSL_USE_EXPLICIT_IV(s)
@@ -832,9 +831,15 @@ int tls1_enc(SSL *s, SSL3_RECORD *recs, unsigned int n_recs, int send)
             int tmpret;
             for (ctr = 0; ctr < n_recs; ctr++) {
                 tmpret = tls1_cbc_remove_padding(s, &recs[ctr], bs, mac_size);
-                if (tmpret == -1)
-                    return -1;
-                ret &= tmpret;
+                /*
+                 * If tmpret == 0 then this means publicly invalid so we can
+                 * short circuit things here. Otherwise we must respect constant
+                 * time behaviour.
+                 */
+                if (tmpret == 0)
+                    return 0;
+                ret = constant_time_select_int(constant_time_eq_int(tmpret, 1),
+                                               ret, -1);
             }
         }
         if (pad && !send) {
@@ -1149,9 +1154,9 @@ int tls1_cbc_remove_padding(const SSL *s,
      * maximum amount of padding possible. (Again, the length of the record
      * is public information so we can use it.)
      */
-    to_check = 255;             /* maximum amount of padding. */
-    if (to_check > rec->length - 1)
-        to_check = rec->length - 1;
+    to_check = 256;            /* maximum amount of padding, inc length byte. */
+    if (to_check > rec->length)
+        to_check = rec->length;
 
     for (i = 0; i < to_check; i++) {
         unsigned char mask = constant_time_ge_8(padding_length, i);