QUIC DEMUX: (Server support) Add support for default handler
[openssl.git] / ssl / ssl_cert_comp.c
1 /*
2  * Copyright 2022 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <stdio.h>
11 #include "ssl_local.h"
12 #include "internal/e_os.h"
13 #include "internal/refcount.h"
14
15 size_t ossl_calculate_comp_expansion(int alg, size_t length)
16 {
17     size_t ret;
18     /*
19      * Uncompressibility expansion:
20      * ZLIB: N + 11 + 5 * (N >> 14)
21      * Brotli: per RFC7932: N + 5 + 3 * (N >> 16)
22      * ZSTD: N + 4 + 14 + 3 * (N >> 17) + 4
23      */
24     
25     switch (alg) {
26     case TLSEXT_comp_cert_zlib:
27         ret = length + 11 + 5 * (length >> 14);
28         break;
29     case TLSEXT_comp_cert_brotli:
30         ret = length + 5 + 3 * (length >> 16);
31         break;
32     case TLSEXT_comp_cert_zstd:
33         ret = length + 22 + 3 * (length >> 17);
34         break;
35     default:
36         return 0;
37     }
38     /* Check for overflow */
39     if (ret < length)
40         return 0;
41     return ret;
42 }
43
44 int ossl_comp_has_alg(int a)
45 {
46 #ifndef OPENSSL_NO_COMP_ALG
47     /* 0 means "any" algorithm */
48     if ((a == 0 || a == TLSEXT_comp_cert_brotli) && BIO_f_brotli() != NULL)
49         return 1;
50     if ((a == 0 || a == TLSEXT_comp_cert_zstd) && BIO_f_zstd() != NULL)
51         return 1;
52     if ((a == 0 || a == TLSEXT_comp_cert_zlib) && BIO_f_zlib() != NULL)
53         return 1;
54 #endif
55     return 0;
56 }
57
58 /* New operation Helper routine */
59 #ifndef OPENSSL_NO_COMP_ALG
60 static OSSL_COMP_CERT *OSSL_COMP_CERT_new(unsigned char *data, size_t len, size_t orig_len, int alg)
61 {
62     OSSL_COMP_CERT *ret = NULL;
63
64     if (!ossl_comp_has_alg(alg)
65             || data == NULL
66             || (ret = OPENSSL_zalloc(sizeof(*ret))) == NULL
67             || (ret->lock = CRYPTO_THREAD_lock_new()) == NULL)
68         goto err;
69
70     ret->references = 1;
71     ret->data = data;
72     ret->len = len;
73     ret->orig_len = orig_len;
74     ret->alg = alg;
75     return ret;
76  err:
77     ERR_raise(ERR_LIB_SSL, ERR_R_MALLOC_FAILURE);
78     OPENSSL_free(data);
79     OPENSSL_free(ret);
80     return NULL;
81 }
82
83 __owur static OSSL_COMP_CERT *OSSL_COMP_CERT_from_compressed_data(unsigned char *data, size_t len,
84                                                                   size_t orig_len, int alg)
85 {
86     return OSSL_COMP_CERT_new(OPENSSL_memdup(data, len), len, orig_len, alg);
87 }
88
89 __owur static OSSL_COMP_CERT *OSSL_COMP_CERT_from_uncompressed_data(unsigned char *data, size_t len,
90                                                                     int alg)
91 {
92     OSSL_COMP_CERT *ret = NULL;
93     size_t max_length;
94     int comp_length;
95     COMP_METHOD *method;
96     unsigned char *comp_data = NULL;
97     COMP_CTX *comp_ctx = NULL;
98
99     switch (alg) {
100     case TLSEXT_comp_cert_brotli:
101         method = COMP_brotli_oneshot();
102         break;
103     case TLSEXT_comp_cert_zlib:
104         method = COMP_zlib_oneshot();
105         break;
106     case TLSEXT_comp_cert_zstd:
107         method = COMP_zstd_oneshot();
108         break;
109     default:
110         goto err;
111     }
112
113     if ((max_length = ossl_calculate_comp_expansion(alg, len)) == 0
114           || method == NULL
115           || (comp_ctx = COMP_CTX_new(method)) == NULL
116           || (comp_data = OPENSSL_zalloc(max_length)) == NULL)
117         goto err;
118
119     comp_length = COMP_compress_block(comp_ctx, comp_data, max_length, data, len);
120     if (comp_length <= 0)
121         goto err;
122
123     ret = OSSL_COMP_CERT_new(comp_data, comp_length, len, alg);
124     comp_data = NULL;
125
126  err:
127     OPENSSL_free(comp_data);
128     COMP_CTX_free(comp_ctx);
129     return ret;
130 }
131
132 void OSSL_COMP_CERT_free(OSSL_COMP_CERT *cc)
133 {
134     int i;
135
136     if (cc == NULL)
137         return;
138
139     CRYPTO_DOWN_REF(&cc->references, &i, cc->lock);
140     REF_PRINT_COUNT("OSSL_COMP_CERT", cc);
141     if (i > 0)
142         return;
143     REF_ASSERT_ISNT(i < 0);
144
145     OPENSSL_free(cc->data);
146     CRYPTO_THREAD_lock_free(cc->lock);
147     OPENSSL_free(cc);
148 }
149 int OSSL_COMP_CERT_up_ref(OSSL_COMP_CERT *cc)
150 {
151     int i;
152
153     if (CRYPTO_UP_REF(&cc->references, &i, cc->lock) <= 0)
154         return 0;
155
156     REF_PRINT_COUNT("OSSL_COMP_CERT", cc);
157     REF_ASSERT_ISNT(i < 2);
158     return ((i > 1) ? 1 : 0);
159 }
160
161 static int ssl_set_cert_comp_pref(int *prefs, int *algs, size_t len)
162 {
163     size_t j = 0;
164     size_t i;
165     int found = 0;
166     int already_set[TLSEXT_comp_cert_limit];
167     int tmp_prefs[TLSEXT_comp_cert_limit];
168
169     /* Note that |len| is the number of |algs| elements */
170     /* clear all algorithms */
171     if (len == 0 || algs == NULL) {
172         memset(prefs, 0, sizeof(tmp_prefs));
173         return 1;
174     }
175
176     /* This will 0-terminate the array */
177     memset(tmp_prefs, 0, sizeof(tmp_prefs));
178     memset(already_set, 0, sizeof(already_set));
179     /* Include only those algorithms we support, ignoring duplicates and unknowns */
180     for (i = 0; i < len; i++) {
181         if (algs[i] != 0 && ossl_comp_has_alg(algs[i])) {
182             /* Check for duplicate */
183             if (already_set[algs[i]])
184                 return 0;
185             tmp_prefs[j++] = algs[i];
186             already_set[algs[i]] = 1;
187             found = 1;
188         }
189     }
190     if (found)
191         memcpy(prefs, tmp_prefs, sizeof(tmp_prefs));
192     return found;
193 }
194
195 static size_t ssl_get_cert_to_compress(SSL *ssl, CERT_PKEY *cpk, unsigned char **data)
196 {
197     SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(ssl);
198     WPACKET tmppkt;
199     BUF_MEM buf = { 0 };
200     size_t ret = 0;
201
202     if (sc == NULL
203             || cpk == NULL
204             || !sc->server
205             || !SSL_in_before(ssl))
206         return 0;
207
208     /* Use the |tmppkt| for the to-be-compressed data */
209     if (!WPACKET_init(&tmppkt, &buf))
210         goto out;
211
212     /* no context present, add 0-length context */
213     if (!WPACKET_put_bytes_u8(&tmppkt, 0))
214         goto out;
215
216     /*
217      * ssl3_output_cert_chain() may generate an SSLfatal() error,
218      * for this case, we want to ignore it, argument for_comp = 1
219      */
220     if (!ssl3_output_cert_chain(sc, &tmppkt, cpk, 1))
221         goto out;
222     WPACKET_get_total_written(&tmppkt, &ret);
223
224  out:
225     WPACKET_cleanup(&tmppkt);
226     if (ret != 0 && data != NULL)
227         *data = (unsigned char *)buf.data;
228     else
229         OPENSSL_free(buf.data);
230     return ret;
231 }
232
233 static int ssl_compress_one_cert(SSL *ssl, CERT_PKEY *cpk, int alg)
234 {
235     unsigned char *cert_data = NULL;
236     OSSL_COMP_CERT *comp_cert = NULL;
237     size_t length;
238
239     if (cpk == NULL
240             || alg == TLSEXT_comp_cert_none
241             || !ossl_comp_has_alg(alg))
242         return 0;
243
244     if ((length = ssl_get_cert_to_compress(ssl, cpk, &cert_data)) == 0)
245         return 0;
246     comp_cert = OSSL_COMP_CERT_from_uncompressed_data(cert_data, length, alg);
247     OPENSSL_free(cert_data);
248     if (comp_cert == NULL)
249         return 0;
250
251     OSSL_COMP_CERT_free(cpk->comp_cert[alg]);
252     cpk->comp_cert[alg] = comp_cert;
253     return 1;
254 }
255
256 /* alg_in can be 0, meaning any/all algorithms */
257 static int ssl_compress_certs(SSL *ssl, CERT_PKEY *cpks, int alg_in)
258 {
259     SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(ssl);
260     int i;
261     int j;
262     int alg;
263     int count = 0;
264
265     if (sc == NULL
266             || cpks == NULL
267             || !ossl_comp_has_alg(alg_in))
268         return 0;
269
270     /* Look through the preferences to see what we have */
271     for (i = 0; i < TLSEXT_comp_cert_limit; i++) {
272         /*
273          * alg = 0 means compress for everything, but only for algorithms enabled
274          * alg != 0 means compress for that algorithm if enabled
275          */
276         alg = sc->cert_comp_prefs[i];
277         if ((alg_in == 0 && alg != TLSEXT_comp_cert_none)
278                 || (alg_in != 0 && alg == alg_in)) {
279
280             for (j = 0; j < SSL_PKEY_NUM; j++) {
281                 /* No cert, move on */
282                 if (cpks[j].x509 == NULL)
283                     continue;
284
285                 if (!ssl_compress_one_cert(ssl, &cpks[j], alg))
286                     return 0;
287
288                 /* if the cert expanded, set the value in the CERT_PKEY to NULL */
289                 if (cpks[j].comp_cert[alg]->len >= cpks[j].comp_cert[alg]->orig_len) {
290                     OSSL_COMP_CERT_free(cpks[j].comp_cert[alg]);
291                     cpks[j].comp_cert[alg] = NULL;
292                 } else {
293                     count++;
294                 }
295             }
296         }
297     }
298     return (count > 0);
299 }
300
301 static size_t ssl_get_compressed_cert(SSL *ssl, CERT_PKEY *cpk, int alg, unsigned char **data,
302                                       size_t *orig_len)
303 {
304     SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(ssl);
305     size_t cert_len = 0;
306     size_t comp_len = 0;
307     unsigned char *cert_data = NULL;
308     OSSL_COMP_CERT *comp_cert = NULL;
309
310     if (sc == NULL
311             || cpk == NULL
312             || data == NULL
313             || orig_len == NULL
314             || !sc->server
315             || !SSL_in_before(ssl)
316             || !ossl_comp_has_alg(alg))
317         return 0;
318
319     if ((cert_len = ssl_get_cert_to_compress(ssl, cpk, &cert_data)) == 0)
320         goto err;
321
322     comp_cert = OSSL_COMP_CERT_from_uncompressed_data(cert_data, cert_len, alg);
323     OPENSSL_free(cert_data);
324     if (comp_cert == NULL)
325         goto err;
326
327     comp_len = comp_cert->len;
328     *orig_len = comp_cert->orig_len;
329     *data = comp_cert->data;
330     comp_cert->data = NULL;
331  err:
332     OSSL_COMP_CERT_free(comp_cert);
333     return comp_len;
334 }
335
336 static int ossl_set1_compressed_cert(CERT *cert, int algorithm,
337                                      unsigned char *comp_data, size_t comp_length,
338                                      size_t orig_length)
339 {
340     OSSL_COMP_CERT *comp_cert;
341
342     /* No explicit cert set */
343     if (cert == NULL || cert->key == NULL)
344         return 0;
345
346     comp_cert = OSSL_COMP_CERT_from_compressed_data(comp_data, comp_length,
347                                                     orig_length, algorithm);
348     if (comp_cert == NULL)
349         return 0;
350
351     OSSL_COMP_CERT_free(cert->key->comp_cert[algorithm]);
352     cert->key->comp_cert[algorithm] = comp_cert;
353
354     return 1;
355 }
356 #endif
357
358 /*-
359  * Public API
360  */
361 int SSL_CTX_set1_cert_comp_preference(SSL_CTX *ctx, int *algs, size_t len)
362 {
363 #ifndef OPENSSL_NO_COMP_ALG
364     return ssl_set_cert_comp_pref(ctx->cert_comp_prefs, algs, len);
365 #else
366     return 0;
367 #endif
368 }
369
370 int SSL_set1_cert_comp_preference(SSL *ssl, int *algs, size_t len)
371 {
372 #ifndef OPENSSL_NO_COMP_ALG
373     SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(ssl);
374
375     if (sc == NULL)
376         return 0;
377     return ssl_set_cert_comp_pref(sc->cert_comp_prefs, algs, len);
378 #else
379     return 0;
380 #endif
381 }
382
383 int SSL_compress_certs(SSL *ssl, int alg)
384 {
385 #ifndef OPENSSL_NO_COMP_ALG
386     SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(ssl);
387
388     if (sc == NULL || sc->cert == NULL)
389         return 0;
390
391     return ssl_compress_certs(ssl, sc->cert->pkeys, alg);
392 #endif
393     return 0;
394 }
395
396 int SSL_CTX_compress_certs(SSL_CTX *ctx, int alg)
397 {
398     int ret = 0;
399 #ifndef OPENSSL_NO_COMP_ALG
400     SSL *new = SSL_new(ctx);
401
402     if (new == NULL)
403         return 0;
404
405     ret = ssl_compress_certs(new, ctx->cert->pkeys, alg);
406     SSL_free(new);
407 #endif
408     return ret;
409 }
410
411 size_t SSL_get1_compressed_cert(SSL *ssl, int alg, unsigned char **data, size_t *orig_len)
412 {
413 #ifndef OPENSSL_NO_COMP_ALG
414     SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(ssl);
415     CERT_PKEY *cpk = NULL;
416
417     if (sc->cert != NULL)
418         cpk = sc->cert->key;
419     else
420         cpk = ssl->ctx->cert->key;
421
422     return ssl_get_compressed_cert(ssl, cpk, alg, data, orig_len);
423 #else
424     return 0;
425 #endif
426 }
427
428 size_t SSL_CTX_get1_compressed_cert(SSL_CTX *ctx, int alg, unsigned char **data, size_t *orig_len)
429 {
430 #ifndef OPENSSL_NO_COMP_ALG
431     size_t ret;
432     SSL *new = SSL_new(ctx);
433
434     ret = ssl_get_compressed_cert(new, ctx->cert->key, alg, data, orig_len);
435     SSL_free(new);
436     return ret;
437 #else
438         return 0;
439 #endif
440 }
441
442 int SSL_CTX_set1_compressed_cert(SSL_CTX *ctx, int algorithm, unsigned char *comp_data,
443                                  size_t comp_length, size_t orig_length)
444 {
445 #ifndef OPENSSL_NO_COMP_ALG
446     return ossl_set1_compressed_cert(ctx->cert, algorithm, comp_data, comp_length, orig_length);
447 #else
448     return 0;
449 #endif
450 }
451
452 int SSL_set1_compressed_cert(SSL *ssl, int algorithm, unsigned char *comp_data,
453                              size_t comp_length, size_t orig_length)
454 {
455 #ifndef OPENSSL_NO_COMP_ALG
456     SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(ssl);
457
458     /* Cannot set a pre-compressed certificate on a client */
459     if (sc == NULL || !sc->server)
460         return 0;
461
462     return ossl_set1_compressed_cert(sc->cert, algorithm, comp_data, comp_length, orig_length);
463 #else
464     return 0;
465 #endif
466 }