TLS: reject duplicate extensions
[openssl.git] / ssl / t1_lib.c
index 7a2047dcca1e5c7d676fc7b9238f2ff5fda79d57..db5f0f6b442a86b535a79276fccf491a8df8e410 100644 (file)
  */
 
 #include <stdio.h>
+#include <stdlib.h>
 #include <openssl/objects.h>
 #include <openssl/evp.h>
 #include <openssl/hmac.h>
@@ -124,7 +125,7 @@ static int tls_decrypt_ticket(SSL *s, const unsigned char *tick, int ticklen,
                               const unsigned char *sess_id, int sesslen,
                               SSL_SESSION **psess);
 static int ssl_check_clienthello_tlsext_early(SSL *s);
-int ssl_check_serverhello_tlsext(SSL *s);
+static int ssl_check_serverhello_tlsext(SSL *s);
 
 SSL3_ENC_METHOD const TLSv1_enc_data = {
     tls1_enc,
@@ -1037,6 +1038,79 @@ static int tls_use_ticket(SSL *s)
     return ssl_security(s, SSL_SECOP_TICKET, 0, 0, NULL);
 }
 
+static int compare_uint(const void *p1, const void *p2) {
+    unsigned int u1 = *((const unsigned int *)p1);
+    unsigned int u2 = *((const unsigned int *)p2);
+    if (u1 < u2)
+        return -1;
+    else if (u1 > u2)
+        return 1;
+    else
+        return 0;
+}
+
+/*
+ * Per http://tools.ietf.org/html/rfc5246#section-7.4.1.4, there may not be
+ * more than one extension of the same type in a ClientHello or ServerHello.
+ * This function does an initial scan over the extensions block to filter those
+ * out. It returns 1 if all extensions are unique, and 0 if the extensions
+ * contain duplicates, could not be successfully parsed, or an internal error
+ * occurred.
+ */
+static int tls1_check_duplicate_extensions(const PACKET *packet) {
+    PACKET extensions = *packet;
+    size_t num_extensions = 0, i = 0;
+    unsigned int *extension_types = NULL;
+    int ret = 0;
+
+    /* First pass: count the extensions. */
+    while (PACKET_remaining(&extensions) > 0) {
+        unsigned int type;
+        PACKET extension;
+        if (!PACKET_get_net_2(&extensions, &type) ||
+            !PACKET_get_length_prefixed_2(&extensions, &extension)) {
+            goto done;
+        }
+        num_extensions++;
+    }
+
+    if (num_extensions <= 1)
+        return 1;
+
+    extension_types = OPENSSL_malloc(sizeof(unsigned int) * num_extensions);
+    if (extension_types == NULL) {
+        SSLerr(SSL_F_TLS1_CHECK_DUPLICATE_EXTENSIONS, ERR_R_MALLOC_FAILURE);
+        goto done;
+    }
+
+    /* Second pass: gather the extension types. */
+    extensions = *packet;
+    for (i = 0; i < num_extensions; i++) {
+        PACKET extension;
+        if (!PACKET_get_net_2(&extensions, &extension_types[i]) ||
+            !PACKET_get_length_prefixed_2(&extensions, &extension)) {
+            /* This should not happen. */
+            SSLerr(SSL_F_TLS1_CHECK_DUPLICATE_EXTENSIONS, ERR_R_INTERNAL_ERROR);
+            goto done;
+        }
+    }
+
+    if (PACKET_remaining(&extensions) != 0) {
+        SSLerr(SSL_F_TLS1_CHECK_DUPLICATE_EXTENSIONS, ERR_R_INTERNAL_ERROR);
+        goto done;
+    }
+    /* Sort the extensions and make sure there are no duplicates. */
+    qsort(extension_types, num_extensions, sizeof(unsigned int), compare_uint);
+    for (i = 1; i < num_extensions; i++) {
+        if (extension_types[i - 1] == extension_types[i])
+            goto done;
+    }
+    ret = 1;
+ done:
+    OPENSSL_free(extension_types);
+    return ret;
+}
+
 unsigned char *ssl_add_clienthello_tlsext(SSL *s, unsigned char *buf,
                                           unsigned char *limit, int *al)
 {
@@ -1837,6 +1911,9 @@ static int ssl_scan_clienthello_tlsext(SSL *s, PACKET *pkt, int *al)
     if (PACKET_remaining(pkt) != len)
         goto err;
 
+    if (!tls1_check_duplicate_extensions(pkt))
+        goto err;
+
     while (PACKET_get_net_2(pkt, &type) && PACKET_get_net_2(pkt, &size)) {
         PACKET subpkt;
 
@@ -2301,6 +2378,11 @@ static int ssl_scan_serverhello_tlsext(SSL *s, PACKET *pkt, int *al)
         return 0;
     }
 
+    if (!tls1_check_duplicate_extensions(pkt)) {
+        *al = SSL_AD_DECODE_ERROR;
+        return 0;
+    }
+
     while (PACKET_get_net_2(pkt, &type) && PACKET_get_net_2(pkt, &size)) {
         const unsigned char *data;
         PACKET spkt;