Add a packet splitting BIO
[openssl.git] / test / helpers / quictestlib.c
1 /*
2  * Copyright 2022-2023 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <assert.h>
11 #include <openssl/configuration.h>
12 #include <openssl/bio.h>
13 #include "quictestlib.h"
14 #include "ssltestlib.h"
15 #include "../testutil.h"
16 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
17 # include "../threadstest.h"
18 #endif
19 #include "internal/quic_ssl.h"
20 #include "internal/quic_wire_pkt.h"
21 #include "internal/quic_record_tx.h"
22 #include "internal/quic_error.h"
23 #include "internal/packet.h"
24 #include "internal/tsan_assist.h"
25
26 #define GROWTH_ALLOWANCE 1024
27
28 struct qtest_fault {
29     QUIC_TSERVER *qtserv;
30
31     /* Plain packet mutations */
32     /* Header for the plaintext packet */
33     QUIC_PKT_HDR pplainhdr;
34     /* iovec for the plaintext packet data buffer */
35     OSSL_QTX_IOVEC pplainio;
36     /* Allocated size of the plaintext packet data buffer */
37     size_t pplainbuf_alloc;
38     qtest_fault_on_packet_plain_cb pplaincb;
39     void *pplaincbarg;
40
41     /* Handshake message mutations */
42     /* Handshake message buffer */
43     unsigned char *handbuf;
44     /* Allocated size of the handshake message buffer */
45     size_t handbufalloc;
46     /* Actual length of the handshake message */
47     size_t handbuflen;
48     qtest_fault_on_handshake_cb handshakecb;
49     void *handshakecbarg;
50     qtest_fault_on_enc_ext_cb encextcb;
51     void *encextcbarg;
52
53     /* Cipher packet mutations */
54     qtest_fault_on_packet_cipher_cb pciphercb;
55     void *pciphercbarg;
56
57     /* Datagram mutations */
58     qtest_fault_on_datagram_cb datagramcb;
59     void *datagramcbarg;
60     /* The currently processed message */
61     BIO_MSG msg;
62     /* Allocated size of msg data buffer */
63     size_t msgalloc;
64 };
65
66 static void packet_plain_finish(void *arg);
67 static void handshake_finish(void *arg);
68
69 static int using_fake_time = 0;
70 static OSSL_TIME fake_now;
71
72 static OSSL_TIME fake_now_cb(void *arg)
73 {
74     return fake_now;
75 }
76
77 int qtest_create_quic_objects(OSSL_LIB_CTX *libctx, SSL_CTX *clientctx,
78                               SSL_CTX *serverctx, char *certfile, char *keyfile,
79                               int flags, QUIC_TSERVER **qtserv, SSL **cssl,
80                               QTEST_FAULT **fault)
81 {
82     /* ALPN value as recognised by QUIC_TSERVER */
83     unsigned char alpn[] = { 8, 'o', 's', 's', 'l', 't', 'e', 's', 't' };
84     QUIC_TSERVER_ARGS tserver_args = {0};
85     BIO *cbio = NULL, *sbio = NULL, *fisbio = NULL;
86     BIO_ADDR *peeraddr = NULL;
87     struct in_addr ina = {0};
88
89     *qtserv = NULL;
90     if (fault != NULL)
91         *fault = NULL;
92
93     if (*cssl == NULL) {
94         *cssl = SSL_new(clientctx);
95         if (!TEST_ptr(*cssl))
96             return 0;
97     }
98
99     /* SSL_set_alpn_protos returns 0 for success! */
100     if (!TEST_false(SSL_set_alpn_protos(*cssl, alpn, sizeof(alpn))))
101         goto err;
102
103     if (!TEST_ptr(peeraddr = BIO_ADDR_new()))
104         goto err;
105
106     if ((flags & QTEST_FLAG_BLOCK) != 0) {
107 #if !defined(OPENSSL_NO_POSIX_IO)
108         int cfd, sfd;
109
110         /*
111          * For blocking mode we need to create actual sockets rather than doing
112          * everything in memory
113          */
114         if (!TEST_true(create_test_sockets(&cfd, &sfd, SOCK_DGRAM, peeraddr)))
115             goto err;
116         cbio = BIO_new_dgram(cfd, 1);
117         if (!TEST_ptr(cbio)) {
118             close(cfd);
119             close(sfd);
120             goto err;
121         }
122         sbio = BIO_new_dgram(sfd, 1);
123         if (!TEST_ptr(sbio)) {
124             close(sfd);
125             goto err;
126         }
127 #else
128         goto err;
129 #endif
130     } else {
131         if (!TEST_true(BIO_new_bio_dgram_pair(&cbio, 0, &sbio, 0)))
132             goto err;
133
134         if (!TEST_true(BIO_dgram_set_caps(cbio, BIO_DGRAM_CAP_HANDLES_DST_ADDR))
135                 || !TEST_true(BIO_dgram_set_caps(sbio, BIO_DGRAM_CAP_HANDLES_DST_ADDR)))
136             goto err;
137
138         /* Dummy server address */
139         if (!TEST_true(BIO_ADDR_rawmake(peeraddr, AF_INET, &ina, sizeof(ina),
140                                         htons(0))))
141             goto err;
142     }
143
144     if ((flags & QTEST_FLAG_NOISE) != 0) {
145         BIO *noisebio = BIO_new(bio_f_noisy_dgram_filter());
146
147         if (!TEST_ptr(noisebio))
148             goto err;
149         cbio = BIO_push(noisebio, cbio);
150     }
151
152     SSL_set_bio(*cssl, cbio, cbio);
153
154     if (!TEST_true(SSL_set_blocking_mode(*cssl,
155                                          (flags & QTEST_FLAG_BLOCK) != 0 ? 1 : 0)))
156         goto err;
157
158     if (!TEST_true(SSL_set1_initial_peer_addr(*cssl, peeraddr)))
159         goto err;
160
161     if (fault != NULL) {
162         *fault = OPENSSL_zalloc(sizeof(**fault));
163         if (*fault == NULL)
164             goto err;
165     }
166
167     fisbio = BIO_new(qtest_get_bio_method());
168     if (!TEST_ptr(fisbio))
169         goto err;
170
171     BIO_set_data(fisbio, fault == NULL ? NULL : *fault);
172
173     if (!TEST_ptr(BIO_push(fisbio, sbio)))
174         goto err;
175
176     tserver_args.libctx = libctx;
177     tserver_args.net_rbio = sbio;
178     tserver_args.net_wbio = fisbio;
179     tserver_args.alpn = NULL;
180     if (serverctx != NULL && !TEST_true(SSL_CTX_up_ref(serverctx)))
181         goto err;
182     tserver_args.ctx = serverctx;
183     if ((flags & QTEST_FLAG_FAKE_TIME) != 0) {
184         using_fake_time = 1;
185         fake_now = ossl_time_zero();
186         /* zero time can have a special meaning, bump it */
187         qtest_add_time(1);
188         tserver_args.now_cb = fake_now_cb;
189         (void)ossl_quic_conn_set_override_now_cb(*cssl, fake_now_cb, NULL);
190     } else {
191         using_fake_time = 0;
192     }
193
194     if (!TEST_ptr(*qtserv = ossl_quic_tserver_new(&tserver_args, certfile,
195                                                   keyfile)))
196         goto err;
197
198     /* Ownership of fisbio and sbio is now held by *qtserv */
199     sbio = NULL;
200     fisbio = NULL;
201
202     if (fault != NULL)
203         (*fault)->qtserv = *qtserv;
204
205     BIO_ADDR_free(peeraddr);
206
207     return 1;
208  err:
209     SSL_CTX_free(tserver_args.ctx);
210     BIO_ADDR_free(peeraddr);
211     BIO_free(cbio);
212     BIO_free(fisbio);
213     BIO_free(sbio);
214     SSL_free(*cssl);
215     *cssl = NULL;
216     ossl_quic_tserver_free(*qtserv);
217     if (fault != NULL)
218         OPENSSL_free(*fault);
219
220     return 0;
221 }
222
223 void qtest_add_time(uint64_t millis)
224 {
225     fake_now = ossl_time_add(fake_now, ossl_ms2time(millis));
226 }
227
228 QTEST_FAULT *qtest_create_injector(QUIC_TSERVER *ts)
229 {
230     QTEST_FAULT *f;
231
232     f = OPENSSL_zalloc(sizeof(*f));
233     if (f == NULL)
234         return NULL;
235
236     f->qtserv = ts;
237     return f;
238
239 }
240
241 int qtest_supports_blocking(void)
242 {
243 #if !defined(OPENSSL_NO_POSIX_IO) && defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
244     return 1;
245 #else
246     return 0;
247 #endif
248 }
249
250 #define MAXLOOPS    1000
251
252 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
253 static int globserverret = 0;
254 static TSAN_QUALIFIER int abortserverthread = 0;
255 static QUIC_TSERVER *globtserv;
256 static const thread_t thread_zero;
257
258 static void run_server_thread(void)
259 {
260     /*
261      * This will operate in a busy loop because the server does not block,
262      * but should be acceptable because it is local and we expect this to be
263      * fast
264      */
265     globserverret = qtest_create_quic_connection(globtserv, NULL);
266 }
267 #endif
268
269 static int wait_for_timeout(SSL *s, QUIC_TSERVER *qtserv)
270 {
271     struct timeval tv;
272     OSSL_TIME ctimeout, stimeout, mintimeout, now;
273     int cinf;
274
275     /* We don't need to wait in blocking mode */
276     if (s == NULL || qtserv == NULL)
277         return 1;
278
279     /* Don't wait if either BIO has data waiting */
280     if (BIO_pending(SSL_get_rbio(s)) > 0
281             || BIO_pending(ossl_quic_tserver_get0_rbio(qtserv)) > 0)
282         return 1;
283
284     /*
285      * Neither endpoint has data waiting to be read. We assume data transmission
286      * is instantaneous due to using mem based BIOs, so there is no data "in
287      * flight" and no more data will be sent by either endpoint until some time
288      * based event has occurred. Therefore, wait for a timeout to occur. This
289      * might happen if we are using the noisy BIO and datagrams have been lost.
290      */
291     if (!SSL_get_event_timeout(s, &tv, &cinf))
292         return 0;
293     if (using_fake_time)
294         now = fake_now;
295     else
296         now = ossl_time_now();
297     ctimeout = cinf ? ossl_time_infinite() : ossl_time_from_timeval(tv);
298     stimeout = ossl_time_subtract(ossl_quic_tserver_get_deadline(qtserv), now);
299     mintimeout = ossl_time_min(ctimeout, stimeout);
300     if (ossl_time_is_infinite(mintimeout))
301         return 0;
302     if (using_fake_time)
303         fake_now = ossl_time_add(now, mintimeout);
304     else
305         OSSL_sleep(ossl_time2ms(mintimeout));
306
307     return 1;
308 }
309
310 int qtest_create_quic_connection_ex(QUIC_TSERVER *qtserv, SSL *clientssl,
311                                     int wanterr)
312 {
313     int retc = -1, rets = 0, abortctr = 0, ret = 0;
314     int clienterr = 0, servererr = 0;
315 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
316     /*
317      * Pointless initialisation to avoid bogus compiler warnings about using
318      * t uninitialised
319      */
320     thread_t t = thread_zero;
321
322     if (clientssl != NULL)
323         abortserverthread = 0;
324 #endif
325
326     if (!TEST_ptr(qtserv)) {
327         goto err;
328     } else if (clientssl == NULL) {
329         retc = 1;
330     } else if (SSL_get_blocking_mode(clientssl) > 0) {
331 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
332         /*
333          * clientssl is blocking. We will need a thread to complete the
334          * connection
335          */
336         globtserv = qtserv;
337         if (!TEST_true(run_thread(&t, run_server_thread)))
338             goto err;
339
340         qtserv = NULL;
341         rets = 1;
342 #else
343         TEST_error("No thread support in this build");
344         goto err;
345 #endif
346     }
347
348     do {
349         if (!clienterr && retc <= 0) {
350             int err;
351
352             retc = SSL_connect(clientssl);
353             if (retc <= 0) {
354                 err = SSL_get_error(clientssl, retc);
355
356                 if (err == wanterr) {
357                     retc = 1;
358 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
359                     if (qtserv == NULL && rets > 0)
360                         tsan_store(&abortserverthread, 1);
361                     else
362 #endif
363                         rets = 1;
364                 } else {
365                     if (err != SSL_ERROR_WANT_READ
366                             && err != SSL_ERROR_WANT_WRITE) {
367                         TEST_info("SSL_connect() failed %d, %d", retc, err);
368                         TEST_openssl_errors();
369                         clienterr = 1;
370                     }
371                 }
372             }
373         }
374
375         if (!clienterr && retc <= 0)
376             SSL_handle_events(clientssl);
377
378         if (!servererr && rets <= 0) {
379             qtest_add_time(1);
380             ossl_quic_tserver_tick(qtserv);
381             servererr = ossl_quic_tserver_is_term_any(qtserv);
382             if (!servererr)
383                 rets = ossl_quic_tserver_is_handshake_confirmed(qtserv);
384         }
385
386         if (clienterr && servererr)
387             goto err;
388
389         if (clientssl != NULL && ++abortctr == MAXLOOPS) {
390             TEST_info("No progress made");
391             goto err;
392         }
393
394         if (!wait_for_timeout(clientssl, qtserv))
395             goto err;
396     } while ((retc <= 0 && !clienterr)
397              || (rets <= 0 && !servererr
398 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
399                  && !tsan_load(&abortserverthread)
400 #endif
401                 ));
402
403     if (qtserv == NULL && rets > 0) {
404 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
405         if (!TEST_true(wait_for_thread(t)) || !TEST_true(globserverret))
406             goto err;
407 #else
408         TEST_error("Should not happen");
409         goto err;
410 #endif
411     }
412
413     if (!clienterr && !servererr)
414         ret = 1;
415  err:
416     return ret;
417 }
418
419 int qtest_create_quic_connection(QUIC_TSERVER *qtserv, SSL *clientssl)
420 {
421     return qtest_create_quic_connection_ex(qtserv, clientssl, SSL_ERROR_NONE);
422 }
423
424 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
425 static TSAN_QUALIFIER int shutdowndone;
426
427 static void run_server_shutdown_thread(void)
428 {
429     /*
430      * This will operate in a busy loop because the server does not block,
431      * but should be acceptable because it is local and we expect this to be
432      * fast
433      */
434     do {
435         ossl_quic_tserver_tick(globtserv);
436     } while(!tsan_load(&shutdowndone));
437 }
438 #endif
439
440 int qtest_shutdown(QUIC_TSERVER *qtserv, SSL *clientssl)
441 {
442     int tickserver = 1;
443     int ret = 0;
444 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
445     /*
446      * Pointless initialisation to avoid bogus compiler warnings about using
447      * t uninitialised
448      */
449     thread_t t = thread_zero;
450 #endif
451
452     if (SSL_get_blocking_mode(clientssl) > 0) {
453 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
454         /*
455          * clientssl is blocking. We will need a thread to complete the
456          * connection
457          */
458         globtserv = qtserv;
459         shutdowndone = 0;
460         if (!TEST_true(run_thread(&t, run_server_shutdown_thread)))
461             return 0;
462
463         tickserver = 0;
464 #else
465         TEST_error("No thread support in this build");
466         return 0;
467 #endif
468     }
469
470     /* Busy loop in non-blocking mode. It should be quick because its local */
471     for (;;) {
472         int rc = SSL_shutdown(clientssl);
473
474         if (rc == 1) {
475             ret = 1;
476             break;
477         }
478
479         if (rc < 0)
480             break;
481
482         if (tickserver)
483             ossl_quic_tserver_tick(qtserv);
484     }
485
486 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
487     tsan_store(&shutdowndone, 1);
488     if (!tickserver) {
489         if (!TEST_true(wait_for_thread(t)))
490             ret = 0;
491     }
492 #endif
493
494     return ret;
495 }
496
497 int qtest_check_server_transport_err(QUIC_TSERVER *qtserv, uint64_t code)
498 {
499     const QUIC_TERMINATE_CAUSE *cause;
500
501     ossl_quic_tserver_tick(qtserv);
502
503     /*
504      * Check that the server has closed with the specified code from the client
505      */
506     if (!TEST_true(ossl_quic_tserver_is_term_any(qtserv)))
507         return 0;
508
509     cause = ossl_quic_tserver_get_terminate_cause(qtserv);
510     if  (!TEST_ptr(cause)
511             || !TEST_true(cause->remote)
512             || !TEST_false(cause->app)
513             || !TEST_uint64_t_eq(cause->error_code, code))
514         return 0;
515
516     return 1;
517 }
518
519 int qtest_check_server_protocol_err(QUIC_TSERVER *qtserv)
520 {
521     return qtest_check_server_transport_err(qtserv, QUIC_ERR_PROTOCOL_VIOLATION);
522 }
523
524 int qtest_check_server_frame_encoding_err(QUIC_TSERVER *qtserv)
525 {
526     return qtest_check_server_transport_err(qtserv, QUIC_ERR_FRAME_ENCODING_ERROR);
527 }
528
529 void qtest_fault_free(QTEST_FAULT *fault)
530 {
531     if (fault == NULL)
532         return;
533
534     packet_plain_finish(fault);
535     handshake_finish(fault);
536
537     OPENSSL_free(fault);
538 }
539
540 static int packet_plain_mutate(const QUIC_PKT_HDR *hdrin,
541                                const OSSL_QTX_IOVEC *iovecin, size_t numin,
542                                QUIC_PKT_HDR **hdrout,
543                                const OSSL_QTX_IOVEC **iovecout,
544                                size_t *numout,
545                                void *arg)
546 {
547     QTEST_FAULT *fault = arg;
548     size_t i, bufsz = 0;
549     unsigned char *cur;
550
551     /* Coalesce our data into a single buffer */
552
553     /* First calculate required buffer size */
554     for (i = 0; i < numin; i++)
555         bufsz += iovecin[i].buf_len;
556
557     fault->pplainio.buf_len = bufsz;
558
559     /* Add an allowance for possible growth */
560     bufsz += GROWTH_ALLOWANCE;
561
562     fault->pplainio.buf = cur = OPENSSL_malloc(bufsz);
563     if (cur == NULL) {
564         fault->pplainio.buf_len = 0;
565         return 0;
566     }
567
568     fault->pplainbuf_alloc = bufsz;
569
570     /* Copy in the data from the input buffers */
571     for (i = 0; i < numin; i++) {
572         memcpy(cur, iovecin[i].buf, iovecin[i].buf_len);
573         cur += iovecin[i].buf_len;
574     }
575
576     fault->pplainhdr = *hdrin;
577
578     /* Cast below is safe because we allocated the buffer */
579     if (fault->pplaincb != NULL
580             && !fault->pplaincb(fault, &fault->pplainhdr,
581                                 (unsigned char *)fault->pplainio.buf,
582                                 fault->pplainio.buf_len, fault->pplaincbarg))
583         return 0;
584
585     *hdrout = &fault->pplainhdr;
586     *iovecout = &fault->pplainio;
587     *numout = 1;
588
589     return 1;
590 }
591
592 static void packet_plain_finish(void *arg)
593 {
594     QTEST_FAULT *fault = arg;
595
596     /* Cast below is safe because we allocated the buffer */
597     OPENSSL_free((unsigned char *)fault->pplainio.buf);
598     fault->pplainio.buf_len = 0;
599     fault->pplainbuf_alloc = 0;
600     fault->pplainio.buf = NULL;
601 }
602
603 int qtest_fault_set_packet_plain_listener(QTEST_FAULT *fault,
604                                           qtest_fault_on_packet_plain_cb pplaincb,
605                                           void *pplaincbarg)
606 {
607     fault->pplaincb = pplaincb;
608     fault->pplaincbarg = pplaincbarg;
609
610     return ossl_quic_tserver_set_plain_packet_mutator(fault->qtserv,
611                                                       packet_plain_mutate,
612                                                       packet_plain_finish,
613                                                       fault);
614 }
615
616 /* To be called from a packet_plain_listener callback */
617 int qtest_fault_resize_plain_packet(QTEST_FAULT *fault, size_t newlen)
618 {
619     unsigned char *buf;
620     size_t oldlen = fault->pplainio.buf_len;
621
622     /*
623      * Alloc'd size should always be non-zero, so if this fails we've been
624      * incorrectly called
625      */
626     if (fault->pplainbuf_alloc == 0)
627         return 0;
628
629     if (newlen > fault->pplainbuf_alloc) {
630         /* This exceeds our growth allowance. Fail */
631         return 0;
632     }
633
634     /* Cast below is safe because we allocated the buffer */
635     buf = (unsigned char *)fault->pplainio.buf;
636
637     if (newlen > oldlen) {
638         /* Extend packet with 0 bytes */
639         memset(buf + oldlen, 0, newlen - oldlen);
640     } /* else we're truncating or staying the same */
641
642     fault->pplainio.buf_len = newlen;
643     fault->pplainhdr.len = newlen;
644
645     return 1;
646 }
647
648 /*
649  * Prepend frame data into a packet. To be called from a packet_plain_listener
650  * callback
651  */
652 int qtest_fault_prepend_frame(QTEST_FAULT *fault, const unsigned char *frame,
653                               size_t frame_len)
654 {
655     unsigned char *buf;
656     size_t old_len;
657
658     /*
659      * Alloc'd size should always be non-zero, so if this fails we've been
660      * incorrectly called
661      */
662     if (fault->pplainbuf_alloc == 0)
663         return 0;
664
665     /* Cast below is safe because we allocated the buffer */
666     buf = (unsigned char *)fault->pplainio.buf;
667     old_len = fault->pplainio.buf_len;
668
669     /* Extend the size of the packet by the size of the new frame */
670     if (!TEST_true(qtest_fault_resize_plain_packet(fault,
671                                                    old_len + frame_len)))
672         return 0;
673
674     memmove(buf + frame_len, buf, old_len);
675     memcpy(buf, frame, frame_len);
676
677     return 1;
678 }
679
680 static int handshake_mutate(const unsigned char *msgin, size_t msginlen,
681                             unsigned char **msgout, size_t *msgoutlen,
682                             void *arg)
683 {
684     QTEST_FAULT *fault = arg;
685     unsigned char *buf;
686     unsigned long payloadlen;
687     unsigned int msgtype;
688     PACKET pkt;
689
690     buf = OPENSSL_malloc(msginlen + GROWTH_ALLOWANCE);
691     if (buf == NULL)
692         return 0;
693
694     fault->handbuf = buf;
695     fault->handbuflen = msginlen;
696     fault->handbufalloc = msginlen + GROWTH_ALLOWANCE;
697     memcpy(buf, msgin, msginlen);
698
699     if (!PACKET_buf_init(&pkt, buf, msginlen)
700             || !PACKET_get_1(&pkt, &msgtype)
701             || !PACKET_get_net_3(&pkt, &payloadlen)
702             || PACKET_remaining(&pkt) != payloadlen)
703         return 0;
704
705     /* Parse specific message types */
706     switch (msgtype) {
707     case SSL3_MT_ENCRYPTED_EXTENSIONS:
708     {
709         QTEST_ENCRYPTED_EXTENSIONS ee;
710
711         if (fault->encextcb == NULL)
712             break;
713
714         /*
715          * The EncryptedExtensions message is very simple. It just has an
716          * extensions block in it and nothing else.
717          */
718         ee.extensions = (unsigned char *)PACKET_data(&pkt);
719         ee.extensionslen = payloadlen;
720         if (!fault->encextcb(fault, &ee, payloadlen, fault->encextcbarg))
721             return 0;
722     }
723
724     default:
725         /* No specific handlers for these message types yet */
726         break;
727     }
728
729     if (fault->handshakecb != NULL
730             && !fault->handshakecb(fault, buf, fault->handbuflen,
731                                    fault->handshakecbarg))
732         return 0;
733
734     *msgout = buf;
735     *msgoutlen = fault->handbuflen;
736
737     return 1;
738 }
739
740 static void handshake_finish(void *arg)
741 {
742     QTEST_FAULT *fault = arg;
743
744     OPENSSL_free(fault->handbuf);
745     fault->handbuf = NULL;
746 }
747
748 int qtest_fault_set_handshake_listener(QTEST_FAULT *fault,
749                                        qtest_fault_on_handshake_cb handshakecb,
750                                        void *handshakecbarg)
751 {
752     fault->handshakecb = handshakecb;
753     fault->handshakecbarg = handshakecbarg;
754
755     return ossl_quic_tserver_set_handshake_mutator(fault->qtserv,
756                                                    handshake_mutate,
757                                                    handshake_finish,
758                                                    fault);
759 }
760
761 int qtest_fault_set_hand_enc_ext_listener(QTEST_FAULT *fault,
762                                           qtest_fault_on_enc_ext_cb encextcb,
763                                           void *encextcbarg)
764 {
765     fault->encextcb = encextcb;
766     fault->encextcbarg = encextcbarg;
767
768     return ossl_quic_tserver_set_handshake_mutator(fault->qtserv,
769                                                    handshake_mutate,
770                                                    handshake_finish,
771                                                    fault);
772 }
773
774 /* To be called from a handshake_listener callback */
775 int qtest_fault_resize_handshake(QTEST_FAULT *fault, size_t newlen)
776 {
777     unsigned char *buf;
778     size_t oldlen = fault->handbuflen;
779
780     /*
781      * Alloc'd size should always be non-zero, so if this fails we've been
782      * incorrectly called
783      */
784     if (fault->handbufalloc == 0)
785         return 0;
786
787     if (newlen > fault->handbufalloc) {
788         /* This exceeds our growth allowance. Fail */
789         return 0;
790     }
791
792     buf = (unsigned char *)fault->handbuf;
793
794     if (newlen > oldlen) {
795         /* Extend packet with 0 bytes */
796         memset(buf + oldlen, 0, newlen - oldlen);
797     } /* else we're truncating or staying the same */
798
799     fault->handbuflen = newlen;
800     return 1;
801 }
802
803 /* To be called from message specific listener callbacks */
804 int qtest_fault_resize_message(QTEST_FAULT *fault, size_t newlen)
805 {
806     /* First resize the underlying message */
807     if (!qtest_fault_resize_handshake(fault, newlen + SSL3_HM_HEADER_LENGTH))
808         return 0;
809
810     /* Fixup the handshake message header */
811     fault->handbuf[1] = (unsigned char)((newlen >> 16) & 0xff);
812     fault->handbuf[2] = (unsigned char)((newlen >>  8) & 0xff);
813     fault->handbuf[3] = (unsigned char)((newlen      ) & 0xff);
814
815     return 1;
816 }
817
818 int qtest_fault_delete_extension(QTEST_FAULT *fault,
819                                  unsigned int exttype, unsigned char *ext,
820                                  size_t *extlen)
821 {
822     PACKET pkt, sub, subext;
823     unsigned int type;
824     const unsigned char *start, *end;
825     size_t newlen;
826     size_t msglen = fault->handbuflen;
827
828     if (!PACKET_buf_init(&pkt, ext, *extlen))
829         return 0;
830
831     /* Extension block starts with 2 bytes for extension block length */
832     if (!PACKET_as_length_prefixed_2(&pkt, &sub))
833         return 0;
834
835     do {
836         start = PACKET_data(&sub);
837         if (!PACKET_get_net_2(&sub, &type)
838                 || !PACKET_get_length_prefixed_2(&sub, &subext))
839             return 0;
840     } while (type != exttype);
841
842     /* Found it */
843     end = PACKET_data(&sub);
844
845     /*
846      * If we're not the last extension we need to move the rest earlier. The
847      * cast below is safe because we own the underlying buffer and we're no
848      * longer making PACKET calls.
849      */
850     if (end < ext + *extlen)
851         memmove((unsigned char *)start, end, end - start);
852
853     /*
854      * Calculate new extensions payload length =
855      * Original length
856      * - 2 extension block length bytes
857      * - length of removed extension
858      */
859     newlen = *extlen - 2 - (end - start);
860
861     /* Fixup the length bytes for the extension block */
862     ext[0] = (unsigned char)((newlen >> 8) & 0xff);
863     ext[1] = (unsigned char)((newlen     ) & 0xff);
864
865     /*
866      * Length of the whole extension block is the new payload length plus the
867      * 2 bytes for the length
868      */
869     *extlen = newlen + 2;
870
871     /* We can now resize the message */
872     if ((size_t)(end - start) + SSL3_HM_HEADER_LENGTH > msglen)
873         return 0; /* Should not happen */
874     msglen -= (end - start) + SSL3_HM_HEADER_LENGTH;
875     if (!qtest_fault_resize_message(fault, msglen))
876         return 0;
877
878     return 1;
879 }
880
881 #define BIO_TYPE_CIPHER_PACKET_FILTER  (0x80 | BIO_TYPE_FILTER)
882
883 static BIO_METHOD *pcipherbiometh = NULL;
884
885 # define BIO_MSG_N(array, stride, n) (*(BIO_MSG *)((char *)(array) + (n)*(stride)))
886
887 static int pcipher_sendmmsg(BIO *b, BIO_MSG *msg, size_t stride,
888                             size_t num_msg, uint64_t flags,
889                             size_t *num_processed)
890 {
891     QTEST_FAULT *fault;
892     BIO *next = BIO_next(b);
893     ossl_ssize_t ret = 0;
894     size_t i = 0, tmpnump;
895     QUIC_PKT_HDR hdr;
896     PACKET pkt;
897     unsigned char *tmpdata;
898
899     if (next == NULL)
900         return 0;
901
902     fault = BIO_get_data(b);
903     if (fault == NULL
904             || (fault->pciphercb == NULL && fault->datagramcb == NULL))
905         return BIO_sendmmsg(next, msg, stride, num_msg, flags, num_processed);
906
907     if (num_msg == 0) {
908         *num_processed = 0;
909         return 1;
910     }
911
912     for (i = 0; i < num_msg; ++i) {
913         fault->msg = BIO_MSG_N(msg, stride, i);
914
915         /* Take a copy of the data so that callbacks can modify it */
916         tmpdata = OPENSSL_malloc(fault->msg.data_len + GROWTH_ALLOWANCE);
917         if (tmpdata == NULL)
918             return 0;
919         memcpy(tmpdata, fault->msg.data, fault->msg.data_len);
920         fault->msg.data = tmpdata;
921         fault->msgalloc = fault->msg.data_len + GROWTH_ALLOWANCE;
922
923         if (fault->pciphercb != NULL) {
924             if (!PACKET_buf_init(&pkt, fault->msg.data, fault->msg.data_len))
925                 return 0;
926
927             do {
928                 if (!ossl_quic_wire_decode_pkt_hdr(&pkt,
929                         /*
930                          * TODO(QUIC SERVER):
931                          * Needs to be set to the actual short header CID length
932                          * when testing the server implementation.
933                          */
934                         0,
935                         1,
936                         0, &hdr, NULL))
937                     goto out;
938
939                 /*
940                  * hdr.data is const - but its our buffer so casting away the
941                  * const is safe
942                  */
943                 if (!fault->pciphercb(fault, &hdr, (unsigned char *)hdr.data,
944                                     hdr.len, fault->pciphercbarg))
945                     goto out;
946
947                 /*
948                  * At the moment modifications to hdr by the callback
949                  * are ignored. We might need to rewrite the QUIC header to
950                  * enable tests to change this. We also don't yet have a
951                  * mechanism for the callback to change the encrypted data
952                  * length. It's not clear if that's needed or not.
953                  */
954             } while (PACKET_remaining(&pkt) > 0);
955         }
956
957         if (fault->datagramcb != NULL
958                 && !fault->datagramcb(fault, &fault->msg, stride,
959                                       fault->datagramcbarg))
960             goto out;
961
962         if (!BIO_sendmmsg(next, &fault->msg, stride, 1, flags, &tmpnump)) {
963             *num_processed = i;
964             goto out;
965         }
966
967         OPENSSL_free(fault->msg.data);
968         fault->msg.data = NULL;
969         fault->msgalloc = 0;
970     }
971
972     *num_processed = i;
973 out:
974     ret = i > 0;
975     OPENSSL_free(fault->msg.data);
976     fault->msg.data = NULL;
977     return ret;
978 }
979
980 static long pcipher_ctrl(BIO *b, int cmd, long larg, void *parg)
981 {
982     BIO *next = BIO_next(b);
983
984     if (next == NULL)
985         return -1;
986
987     return BIO_ctrl(next, cmd, larg, parg);
988 }
989
990 BIO_METHOD *qtest_get_bio_method(void)
991 {
992     BIO_METHOD *tmp;
993
994     if (pcipherbiometh != NULL)
995         return pcipherbiometh;
996
997     tmp = BIO_meth_new(BIO_TYPE_CIPHER_PACKET_FILTER, "Cipher Packet Filter");
998
999     if (!TEST_ptr(tmp))
1000         return NULL;
1001
1002     if (!TEST_true(BIO_meth_set_sendmmsg(tmp, pcipher_sendmmsg))
1003             || !TEST_true(BIO_meth_set_ctrl(tmp, pcipher_ctrl)))
1004         goto err;
1005
1006     pcipherbiometh = tmp;
1007     tmp = NULL;
1008  err:
1009     BIO_meth_free(tmp);
1010     return pcipherbiometh;
1011 }
1012
1013 int qtest_fault_set_packet_cipher_listener(QTEST_FAULT *fault,
1014                                            qtest_fault_on_packet_cipher_cb pciphercb,
1015                                            void *pciphercbarg)
1016 {
1017     fault->pciphercb = pciphercb;
1018     fault->pciphercbarg = pciphercbarg;
1019
1020     return 1;
1021 }
1022
1023 int qtest_fault_set_datagram_listener(QTEST_FAULT *fault,
1024                                       qtest_fault_on_datagram_cb datagramcb,
1025                                       void *datagramcbarg)
1026 {
1027     fault->datagramcb = datagramcb;
1028     fault->datagramcbarg = datagramcbarg;
1029
1030     return 1;
1031 }
1032
1033 /* To be called from a datagram_listener callback */
1034 int qtest_fault_resize_datagram(QTEST_FAULT *fault, size_t newlen)
1035 {
1036     if (newlen > fault->msgalloc)
1037             return 0;
1038
1039     if (newlen > fault->msg.data_len)
1040         memset((unsigned char *)fault->msg.data + fault->msg.data_len, 0,
1041                 newlen - fault->msg.data_len);
1042
1043     fault->msg.data_len = newlen;
1044
1045     return 1;
1046 }
1047
1048 /* There isn't a public function to do BIO_ADDR_copy() so we create one */
1049 int bio_addr_copy(BIO_ADDR *dst, BIO_ADDR *src)
1050 {
1051     size_t len;
1052     void *data = NULL;
1053     int res = 0;
1054     int family;
1055
1056     if (src == NULL || dst == NULL)
1057         return 0;
1058
1059     family = BIO_ADDR_family(src);
1060     if (family == AF_UNSPEC) {
1061         BIO_ADDR_clear(dst);
1062         return 1;
1063     }
1064
1065     if (!BIO_ADDR_rawaddress(src, NULL, &len))
1066         return 0;
1067
1068     if (len > 0) {
1069         data = OPENSSL_malloc(len);
1070         if (!TEST_ptr(data))
1071             return 0;
1072     }
1073
1074     if (!BIO_ADDR_rawaddress(src, data, &len))
1075         goto err;
1076
1077     if (!BIO_ADDR_rawmake(src, family, data, len, BIO_ADDR_rawport(src)))
1078         goto err;
1079
1080     res = 1;
1081  err:
1082     OPENSSL_free(data);
1083     return res;
1084 }
1085
1086 int bio_msg_copy(BIO_MSG *dst, BIO_MSG *src)
1087 {
1088     /*
1089      * Note it is assumed that the originally allocated data sizes for dst and
1090      * src are the same
1091      */
1092     memcpy(dst->data, src->data, src->data_len);
1093     dst->data_len = src->data_len;
1094     dst->flags = src->flags;
1095     if (dst->local != NULL) {
1096         if (src->local != NULL) {
1097             if (!TEST_true(bio_addr_copy(dst->local, src->local)))
1098                 return 0;
1099         } else {
1100             BIO_ADDR_clear(dst->local);
1101         }
1102     }
1103     if (!TEST_true(bio_addr_copy(dst->peer, src->peer)))
1104         return 0;
1105
1106     return 1;
1107 }