Add some missing extensions to SSL_extension_supported()
[openssl.git] / ssl / statem / statem_extensions.c
1 /*
2  * Copyright 2016 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <stdlib.h>
11 #include "../ssl_locl.h"
12 #include "statem_locl.h"
13
14 /*
15  * Comparison function used in a call to qsort (see tls_collect_extensions()
16  * below.)
17  * The two arguments |p1| and |p2| are expected to be pointers to RAW_EXTENSIONs
18  *
19  * Returns:
20  *  1 if the type for p1 is greater than p2
21  *  0 if the type for p1 and p2 are the same
22  * -1 if the type for p1 is less than p2
23  */
24 static int compare_extensions(const void *p1, const void *p2)
25 {
26     const RAW_EXTENSION *e1 = (const RAW_EXTENSION *)p1;
27     const RAW_EXTENSION *e2 = (const RAW_EXTENSION *)p2;
28
29     if (e1->type < e2->type)
30         return -1;
31     else if (e1->type > e2->type)
32         return 1;
33
34     return 0;
35 }
36
37 /*
38  * Gather a list of all the extensions. We don't actually process the content
39  * of the extensions yet, except to check their types.
40  *
41  * Per http://tools.ietf.org/html/rfc5246#section-7.4.1.4, there may not be
42  * more than one extension of the same type in a ClientHello or ServerHello.
43  * This function returns 1 if all extensions are unique and we have parsed their
44  * types, and 0 if the extensions contain duplicates, could not be successfully
45  * parsed, or an internal error occurred.
46  */
47 /*
48  * TODO(TLS1.3): Refactor ServerHello extension parsing to use this and then
49  * remove tls1_check_duplicate_extensions()
50  */
51 int tls_collect_extensions(PACKET *packet, RAW_EXTENSION **res,
52                              size_t *numfound, int *ad)
53 {
54     PACKET extensions = *packet;
55     size_t num_extensions = 0, i = 0;
56     RAW_EXTENSION *raw_extensions = NULL;
57
58     /* First pass: count the extensions. */
59     while (PACKET_remaining(&extensions) > 0) {
60         unsigned int type;
61         PACKET extension;
62
63         if (!PACKET_get_net_2(&extensions, &type) ||
64             !PACKET_get_length_prefixed_2(&extensions, &extension)) {
65             *ad = SSL_AD_DECODE_ERROR;
66             goto err;
67         }
68         num_extensions++;
69     }
70
71     if (num_extensions > 0) {
72         raw_extensions = OPENSSL_malloc(sizeof(*raw_extensions)
73                                         * num_extensions);
74         if (raw_extensions == NULL) {
75             *ad = SSL_AD_INTERNAL_ERROR;
76             SSLerr(SSL_F_TLS_COLLECT_EXTENSIONS, ERR_R_MALLOC_FAILURE);
77             goto err;
78         }
79
80         /* Second pass: collect the extensions. */
81         for (i = 0; i < num_extensions; i++) {
82             if (!PACKET_get_net_2(packet, &raw_extensions[i].type) ||
83                 !PACKET_get_length_prefixed_2(packet,
84                                               &raw_extensions[i].data)) {
85                 /* This should not happen. */
86                 *ad = SSL_AD_INTERNAL_ERROR;
87                 SSLerr(SSL_F_TLS_COLLECT_EXTENSIONS, ERR_R_INTERNAL_ERROR);
88                 goto err;
89             }
90         }
91
92         if (PACKET_remaining(packet) != 0) {
93             *ad = SSL_AD_DECODE_ERROR;
94             SSLerr(SSL_F_TLS_COLLECT_EXTENSIONS, SSL_R_LENGTH_MISMATCH);
95             goto err;
96         }
97         /* Sort the extensions and make sure there are no duplicates. */
98         qsort(raw_extensions, num_extensions, sizeof(*raw_extensions),
99               compare_extensions);
100         for (i = 1; i < num_extensions; i++) {
101             if (raw_extensions[i - 1].type == raw_extensions[i].type) {
102                 *ad = SSL_AD_DECODE_ERROR;
103                 goto err;
104             }
105         }
106     }
107
108     *res = raw_extensions;
109     *numfound = num_extensions;
110     return 1;
111
112  err:
113     OPENSSL_free(raw_extensions);
114     return 0;
115 }