Implement the QUIC Fault injector support for plaintext packets
[openssl.git] / test / helpers / quictestlib.c
1 /*
2  * Copyright 2022 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <assert.h>
11 #include "quictestlib.h"
12 #include "../testutil.h"
13 #include "internal/quic_wire_pkt.h"
14 #include "internal/quic_record_tx.h"
15
16 #define GROWTH_ALLOWANCE 1024
17
18 struct ossl_quic_fault {
19     QUIC_TSERVER *qtserv;
20
21     /* Plain packet mutations */
22     /* Header for the plaintext packet */
23     QUIC_PKT_HDR pplainhdr;
24     /* iovec for the plaintext packet data buffer */
25     OSSL_QTX_IOVEC pplainio;
26     /* Allocted size of the plaintext packet data buffer */
27     size_t pplainbuf_alloc;
28     ossl_quic_fault_on_packet_plain_cb pplaincb;
29     void *pplaincbarg;
30 };
31
32 int qtest_create_quic_objects(SSL_CTX *clientctx, char *certfile, char *keyfile,
33                               QUIC_TSERVER **qtserv, SSL **cssl,
34                               OSSL_QUIC_FAULT **fault)
35 {
36     /* ALPN value as recognised by QUIC_TSERVER */
37     unsigned char alpn[] = { 8, 'o', 's', 's', 'l', 't', 'e', 's', 't' };
38     QUIC_TSERVER_ARGS tserver_args = {0};
39     BIO *bio1 = NULL, *bio2 = NULL;
40     BIO_ADDR *peeraddr = NULL;
41     struct in_addr ina = {0};
42
43     *qtserv = NULL;
44     if (fault != NULL)
45         *fault = NULL;
46     *cssl = SSL_new(clientctx);
47     if (!TEST_ptr(*cssl))
48         return 0;
49
50     if (!TEST_true(SSL_set_blocking_mode(*cssl, 0)))
51         goto err;
52
53     /* SSL_set_alpn_protos returns 0 for success! */
54     if (!TEST_false(SSL_set_alpn_protos(*cssl, alpn, sizeof(alpn))))
55         goto err;
56
57     if (!TEST_true(BIO_new_bio_dgram_pair(&bio1, 0, &bio2, 0)))
58         goto err;
59
60     if (!TEST_true(BIO_dgram_set_caps(bio1, BIO_DGRAM_CAP_HANDLES_DST_ADDR))
61             || !TEST_true(BIO_dgram_set_caps(bio2, BIO_DGRAM_CAP_HANDLES_DST_ADDR)))
62         goto err;
63
64     SSL_set_bio(*cssl, bio1, bio1);
65
66     if (!TEST_ptr(peeraddr = BIO_ADDR_new()))
67         goto err;
68
69     /* Dummy server address */
70     if (!TEST_true(BIO_ADDR_rawmake(peeraddr, AF_INET, &ina, sizeof(ina),
71                                     htons(0))))
72         goto err;
73
74     if (!TEST_true(SSL_set_initial_peer_addr(*cssl, peeraddr)))
75         goto err;
76
77     /* 2 refs are passed for bio2 */
78     if (!BIO_up_ref(bio2))
79         goto err;
80     tserver_args.net_rbio = bio2;
81     tserver_args.net_wbio = bio2;
82
83     if (!TEST_ptr(*qtserv = ossl_quic_tserver_new(&tserver_args, certfile,
84                                                   keyfile))) {
85         /* We hold 2 refs to bio2 at the moment */
86         BIO_free(bio2);
87         goto err;
88     }
89     /* Ownership of bio2 is now held by *qtserv */
90     bio2 = NULL;
91
92     if (fault != NULL) {
93         *fault = OPENSSL_zalloc(sizeof(**fault));
94         if (*fault == NULL)
95             goto err;
96
97         (*fault)->qtserv = *qtserv;
98     }
99
100     BIO_ADDR_free(peeraddr);
101
102     return 1;
103  err:
104     BIO_ADDR_free(peeraddr);
105     BIO_free(bio1);
106     BIO_free(bio2);
107     SSL_free(*cssl);
108     ossl_quic_tserver_free(*qtserv);
109     if (fault != NULL)
110         OPENSSL_free(*fault);
111
112     return 0;
113 }
114
115 #define MAXLOOPS    1000
116
117 int qtest_create_quic_connection(QUIC_TSERVER *qtserv, SSL *clientssl)
118 {
119     int retc = -1, rets = 0, err, abortctr = 0, ret = 0;
120     int clienterr = 0, servererr = 0;
121
122     do {
123         err = SSL_ERROR_WANT_WRITE;
124         while (!clienterr && retc <= 0 && err == SSL_ERROR_WANT_WRITE) {
125             retc = SSL_connect(clientssl);
126             if (retc <= 0)
127                 err = SSL_get_error(clientssl, retc);
128         }
129
130         if (!clienterr && retc <= 0 && err != SSL_ERROR_WANT_READ) {
131             TEST_info("SSL_connect() failed %d, %d", retc, err);
132             TEST_openssl_errors();
133             clienterr = 1;
134         }
135
136         /*
137          * We're cheating. We don't take any notice of SSL_get_tick_timeout()
138          * and tick everytime around the loop anyway. This is inefficient. We
139          * can get away with it in test code because we control both ends of
140          * the communications and don't expect network delays. This shouldn't
141          * be done in a real application.
142          */
143         if (!clienterr)
144             SSL_tick(clientssl);
145         if (!servererr) {
146             ossl_quic_tserver_tick(qtserv);
147             servererr = ossl_quic_tserver_is_term_any(qtserv);
148             if (!servererr && !rets)
149                 rets = ossl_quic_tserver_is_connected(qtserv);
150         }
151
152         if (clienterr && servererr)
153             goto err;
154
155         if (++abortctr == MAXLOOPS) {
156             TEST_info("No progress made");
157             goto err;
158         }
159     } while (retc <=0 || rets <= 0);
160
161     ret = 1;
162  err:
163     return ret;
164 }
165
166 void ossl_quic_fault_free(OSSL_QUIC_FAULT *fault)
167 {
168     if (fault == NULL)
169         return;
170
171     OPENSSL_free(fault);
172 }
173
174 static int packet_plain_mutate(const QUIC_PKT_HDR *hdrin,
175                                const OSSL_QTX_IOVEC *iovecin, size_t numin,
176                                QUIC_PKT_HDR **hdrout,
177                                const OSSL_QTX_IOVEC **iovecout,
178                                size_t *numout,
179                                void *arg)
180 {
181     OSSL_QUIC_FAULT *fault = arg;
182     size_t i, bufsz = 0;
183     unsigned char *cur;
184
185     /* Coalesce our data into a single buffer */
186
187     /* First calculate required buffer size */
188     for (i = 0; i < numin; i++)
189         bufsz += iovecin[i].buf_len;
190
191     fault->pplainio.buf_len = bufsz;
192
193     /* Add an allowance for possible growth */
194     bufsz += GROWTH_ALLOWANCE;
195
196     fault->pplainio.buf = cur = OPENSSL_malloc(bufsz);
197     if (cur == NULL) {
198         fault->pplainio.buf_len = 0;
199         return 0;
200     }
201
202     fault->pplainbuf_alloc = bufsz;
203
204     /* Copy in the data from the input buffers */
205     for (i = 0; i < numin; i++) {
206         memcpy(cur, iovecin[i].buf, iovecin[i].buf_len);
207         cur += iovecin[i].buf_len;
208     }
209
210     fault->pplainhdr = *hdrin;
211
212     /* Cast below is safe because we allocated the buffer */
213     if (fault->pplaincb != NULL
214             && !fault->pplaincb(fault, &fault->pplainhdr,
215                                 (unsigned char *)fault->pplainio.buf,
216                                 fault->pplainio.buf_len, fault->pplaincbarg))
217         return 0;
218
219     *hdrout = &fault->pplainhdr;
220     *iovecout = &fault->pplainio;
221     *numout = 1;
222
223     return 1;
224 }
225
226 static void packet_plain_finish(void *arg)
227 {
228     OSSL_QUIC_FAULT *fault = arg;
229
230     /* Cast below is safe because we allocated the buffer */
231     OPENSSL_free((unsigned char *)fault->pplainio.buf);
232     fault->pplainio.buf_len = 0;
233     fault->pplainbuf_alloc = 0;
234 }
235
236 int ossl_quic_fault_set_packet_plain_listener(OSSL_QUIC_FAULT *fault,
237                                               ossl_quic_fault_on_packet_plain_cb pplaincb,
238                                               void *pplaincbarg)
239 {
240     fault->pplaincb = pplaincb;
241     fault->pplaincbarg = pplaincbarg;
242
243     return ossl_quic_tserver_set_mutator(fault->qtserv, packet_plain_mutate,
244                                          packet_plain_finish, fault);
245 }
246
247 /* To be called from a packet_plain_listener callback */
248 int ossl_quic_fault_resize_plain_packet(OSSL_QUIC_FAULT *fault, size_t newlen)
249 {
250     unsigned char *buf;
251     size_t oldlen = fault->pplainio.buf_len;
252
253     /*
254      * Alloc'd size should always be non-zero, so if this fails we've been
255      * incorrectly called
256      */
257     if (fault->pplainbuf_alloc == 0)
258         return 0;
259
260     if (newlen > fault->pplainbuf_alloc) {
261         /* This exceeds our growth allowance. Fail */
262         return 0;
263     }
264
265     /* Cast below is safe because we allocated the buffer */
266     buf = (unsigned char *)fault->pplainio.buf;
267
268     if (newlen > oldlen) {
269         /* Extend packet with 0 bytes */
270         memset(buf + oldlen, 0, newlen - oldlen);
271     } /* else we're truncating or staying the same */
272
273     fault->pplainio.buf_len = newlen;
274     fault->pplainhdr.len = newlen;
275
276     return 1;
277 }