add locking around fake_now
[openssl.git] / test / helpers / quictestlib.c
1 /*
2  * Copyright 2022-2023 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <assert.h>
11 #include <openssl/configuration.h>
12 #include <openssl/bio.h>
13 #include "internal/e_os.h" /* For struct timeval */
14 #include "quictestlib.h"
15 #include "ssltestlib.h"
16 #include "../testutil.h"
17 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
18 # include "../threadstest.h"
19 #endif
20 #include "internal/quic_ssl.h"
21 #include "internal/quic_wire_pkt.h"
22 #include "internal/quic_record_tx.h"
23 #include "internal/quic_error.h"
24 #include "internal/packet.h"
25 #include "internal/tsan_assist.h"
26
27 #define GROWTH_ALLOWANCE 1024
28
29 struct noise_args_data_st {
30     BIO *cbio;
31     BIO *sbio;
32     BIO *tracebio;
33     int flags;
34 };
35
36 struct qtest_fault {
37     QUIC_TSERVER *qtserv;
38
39     /* Plain packet mutations */
40     /* Header for the plaintext packet */
41     QUIC_PKT_HDR pplainhdr;
42     /* iovec for the plaintext packet data buffer */
43     OSSL_QTX_IOVEC pplainio;
44     /* Allocated size of the plaintext packet data buffer */
45     size_t pplainbuf_alloc;
46     qtest_fault_on_packet_plain_cb pplaincb;
47     void *pplaincbarg;
48
49     /* Handshake message mutations */
50     /* Handshake message buffer */
51     unsigned char *handbuf;
52     /* Allocated size of the handshake message buffer */
53     size_t handbufalloc;
54     /* Actual length of the handshake message */
55     size_t handbuflen;
56     qtest_fault_on_handshake_cb handshakecb;
57     void *handshakecbarg;
58     qtest_fault_on_enc_ext_cb encextcb;
59     void *encextcbarg;
60
61     /* Cipher packet mutations */
62     qtest_fault_on_packet_cipher_cb pciphercb;
63     void *pciphercbarg;
64
65     /* Datagram mutations */
66     qtest_fault_on_datagram_cb datagramcb;
67     void *datagramcbarg;
68     /* The currently processed message */
69     BIO_MSG msg;
70     /* Allocated size of msg data buffer */
71     size_t msgalloc;
72     struct noise_args_data_st noiseargs;
73 };
74
75 static void packet_plain_finish(void *arg);
76 static void handshake_finish(void *arg);
77 static OSSL_TIME qtest_get_time(void);
78 static void qtest_reset_time(void);
79
80 static int using_fake_time = 0;
81 static OSSL_TIME fake_now;
82 static CRYPTO_RWLOCK *fake_now_lock = NULL;
83
84 static OSSL_TIME fake_now_cb(void *arg)
85 {
86     return qtest_get_time();
87 }
88
89 static void noise_msg_callback(int write_p, int version, int content_type,
90                                const void *buf, size_t len, SSL *ssl,
91                                void *arg)
92 {
93     struct noise_args_data_st *noiseargs = (struct noise_args_data_st *)arg;
94
95     if (content_type == SSL3_RT_QUIC_FRAME_FULL) {
96         PACKET pkt;
97         uint64_t frame_type;
98
99         if (!PACKET_buf_init(&pkt, buf, len))
100             return;
101
102         if (!ossl_quic_wire_peek_frame_header(&pkt, &frame_type, NULL))
103             return;
104
105         if (frame_type == OSSL_QUIC_FRAME_TYPE_PING) {
106             /*
107              * If either endpoint issues a ping frame then we are in danger
108              * of our noise being too much such that the connection itself
109              * fails. We back off on the noise for a bit to avoid that.
110              */
111             (void)BIO_ctrl(noiseargs->cbio, BIO_CTRL_NOISE_BACK_OFF, 0, NULL);
112             (void)BIO_ctrl(noiseargs->sbio, BIO_CTRL_NOISE_BACK_OFF, 0, NULL);
113         }
114     }
115
116 #ifndef OPENSSL_NO_SSL_TRACE
117     if ((noiseargs->flags & QTEST_FLAG_CLIENT_TRACE) != 0
118             && !SSL_is_server(ssl))
119         SSL_trace(write_p, version, content_type, buf, len, ssl,
120                   noiseargs->tracebio);
121 #endif
122 }
123
124 int qtest_create_quic_objects(OSSL_LIB_CTX *libctx, SSL_CTX *clientctx,
125                               SSL_CTX *serverctx, char *certfile, char *keyfile,
126                               int flags, QUIC_TSERVER **qtserv, SSL **cssl,
127                               QTEST_FAULT **fault, BIO **tracebio)
128 {
129     /* ALPN value as recognised by QUIC_TSERVER */
130     unsigned char alpn[] = { 8, 'o', 's', 's', 'l', 't', 'e', 's', 't' };
131     QUIC_TSERVER_ARGS tserver_args = {0};
132     BIO *cbio = NULL, *sbio = NULL, *fisbio = NULL;
133     BIO_ADDR *peeraddr = NULL;
134     struct in_addr ina = {0};
135     BIO *tmpbio = NULL;
136
137     *qtserv = NULL;
138     if (*cssl == NULL) {
139         *cssl = SSL_new(clientctx);
140         if (!TEST_ptr(*cssl))
141             return 0;
142     }
143
144     if (fault != NULL) {
145         *fault = OPENSSL_zalloc(sizeof(**fault));
146         if (*fault == NULL)
147             goto err;
148     }
149
150 #ifndef OPENSSL_NO_SSL_TRACE
151     if ((flags & QTEST_FLAG_CLIENT_TRACE) != 0) {
152         tmpbio = BIO_new_fp(stdout, BIO_NOCLOSE);
153         if (!TEST_ptr(tmpbio))
154             goto err;
155
156         SSL_set_msg_callback(*cssl, SSL_trace);
157         SSL_set_msg_callback_arg(*cssl, tmpbio);
158     }
159 #endif
160     if (tracebio != NULL)
161         *tracebio = tmpbio;
162
163     /* SSL_set_alpn_protos returns 0 for success! */
164     if (!TEST_false(SSL_set_alpn_protos(*cssl, alpn, sizeof(alpn))))
165         goto err;
166
167     if (!TEST_ptr(peeraddr = BIO_ADDR_new()))
168         goto err;
169
170     if ((flags & QTEST_FLAG_BLOCK) != 0) {
171 #if !defined(OPENSSL_NO_POSIX_IO)
172         int cfd, sfd;
173
174         /*
175          * For blocking mode we need to create actual sockets rather than doing
176          * everything in memory
177          */
178         if (!TEST_true(create_test_sockets(&cfd, &sfd, SOCK_DGRAM, peeraddr)))
179             goto err;
180         cbio = BIO_new_dgram(cfd, 1);
181         if (!TEST_ptr(cbio)) {
182             close(cfd);
183             close(sfd);
184             goto err;
185         }
186         sbio = BIO_new_dgram(sfd, 1);
187         if (!TEST_ptr(sbio)) {
188             close(sfd);
189             goto err;
190         }
191 #else
192         goto err;
193 #endif
194     } else {
195         if (!TEST_true(BIO_new_bio_dgram_pair(&cbio, 0, &sbio, 0)))
196             goto err;
197
198         if (!TEST_true(BIO_dgram_set_caps(cbio, BIO_DGRAM_CAP_HANDLES_DST_ADDR))
199                 || !TEST_true(BIO_dgram_set_caps(sbio, BIO_DGRAM_CAP_HANDLES_DST_ADDR)))
200             goto err;
201
202         /* Dummy server address */
203         if (!TEST_true(BIO_ADDR_rawmake(peeraddr, AF_INET, &ina, sizeof(ina),
204                                         htons(0))))
205             goto err;
206     }
207
208     if ((flags & QTEST_FLAG_PACKET_SPLIT) != 0) {
209         BIO *pktsplitbio = BIO_new(bio_f_pkt_split_dgram_filter());
210
211         if (!TEST_ptr(pktsplitbio))
212             goto err;
213         cbio = BIO_push(pktsplitbio, cbio);
214
215         pktsplitbio = BIO_new(bio_f_pkt_split_dgram_filter());
216         if (!TEST_ptr(pktsplitbio))
217             goto err;
218         sbio = BIO_push(pktsplitbio, sbio);
219     }
220
221     if ((flags & QTEST_FLAG_NOISE) != 0) {
222         BIO *noisebio;
223
224         /*
225          * It is an error to not have a QTEST_FAULT object when introducing noise
226          */
227         if (!TEST_ptr(fault))
228             goto err;
229
230         noisebio = BIO_new(bio_f_noisy_dgram_filter());
231
232         if (!TEST_ptr(noisebio))
233             goto err;
234         cbio = BIO_push(noisebio, cbio);
235
236         noisebio = BIO_new(bio_f_noisy_dgram_filter());
237
238         if (!TEST_ptr(noisebio))
239             goto err;
240         sbio = BIO_push(noisebio, sbio);
241         /*
242          * TODO(QUIC SERVER):
243          *    Currently the simplistic handler of the quic tserver cannot cope
244          *    with noise introduced in the first packet received from the
245          *    client. This needs to be removed once we have proper server side
246          *    handling.
247          */
248         (void)BIO_ctrl(sbio, BIO_CTRL_NOISE_BACK_OFF, 0, NULL);
249
250         (*fault)->noiseargs.cbio = cbio;
251         (*fault)->noiseargs.sbio = sbio;
252         (*fault)->noiseargs.tracebio = tmpbio;
253         (*fault)->noiseargs.flags = flags;
254
255         SSL_set_msg_callback(*cssl, noise_msg_callback);
256         SSL_set_msg_callback_arg(*cssl, &(*fault)->noiseargs);
257     }
258
259     SSL_set_bio(*cssl, cbio, cbio);
260
261     if (!TEST_true(SSL_set_blocking_mode(*cssl,
262                                          (flags & QTEST_FLAG_BLOCK) != 0 ? 1 : 0)))
263         goto err;
264
265     if (!TEST_true(SSL_set1_initial_peer_addr(*cssl, peeraddr)))
266         goto err;
267
268     fisbio = BIO_new(qtest_get_bio_method());
269     if (!TEST_ptr(fisbio))
270         goto err;
271
272     BIO_set_data(fisbio, fault == NULL ? NULL : *fault);
273
274     if (!BIO_up_ref(sbio))
275         goto err;
276     if (!TEST_ptr(BIO_push(fisbio, sbio))) {
277         BIO_free(sbio);
278         goto err;
279     }
280
281     tserver_args.libctx = libctx;
282     tserver_args.net_rbio = sbio;
283     tserver_args.net_wbio = fisbio;
284     tserver_args.alpn = NULL;
285     if (serverctx != NULL && !TEST_true(SSL_CTX_up_ref(serverctx)))
286         goto err;
287     tserver_args.ctx = serverctx;
288     if (fake_now_lock == NULL) {
289         fake_now_lock = CRYPTO_THREAD_lock_new();
290         if (fake_now_lock == NULL)
291             goto err;
292     }
293     if ((flags & QTEST_FLAG_FAKE_TIME) != 0) {
294         using_fake_time = 1;
295         qtest_reset_time();
296         tserver_args.now_cb = fake_now_cb;
297         (void)ossl_quic_conn_set_override_now_cb(*cssl, fake_now_cb, NULL);
298     } else {
299         using_fake_time = 0;
300     }
301
302     if (!TEST_ptr(*qtserv = ossl_quic_tserver_new(&tserver_args, certfile,
303                                                   keyfile)))
304         goto err;
305
306     /* Ownership of fisbio and sbio is now held by *qtserv */
307     sbio = NULL;
308     fisbio = NULL;
309
310     if ((flags & QTEST_FLAG_NOISE) != 0)
311         ossl_quic_tserver_set_msg_callback(*qtserv, noise_msg_callback,
312                                            &(*fault)->noiseargs);
313
314     if (fault != NULL)
315         (*fault)->qtserv = *qtserv;
316
317     BIO_ADDR_free(peeraddr);
318
319     return 1;
320  err:
321     SSL_CTX_free(tserver_args.ctx);
322     BIO_ADDR_free(peeraddr);
323     BIO_free_all(cbio);
324     BIO_free_all(fisbio);
325     BIO_free_all(sbio);
326     SSL_free(*cssl);
327     *cssl = NULL;
328     ossl_quic_tserver_free(*qtserv);
329     if (fault != NULL)
330         OPENSSL_free(*fault);
331     BIO_free(tmpbio);
332     if (tracebio != NULL)
333         *tracebio = NULL;
334
335     return 0;
336 }
337
338 void qtest_add_time(uint64_t millis)
339 {
340     if (!CRYPTO_THREAD_write_lock(fake_now_lock))
341         return;
342     fake_now = ossl_time_add(fake_now, ossl_ms2time(millis));
343     CRYPTO_THREAD_unlock(fake_now_lock);
344 }
345
346 static OSSL_TIME qtest_get_time(void)
347 {
348     OSSL_TIME ret;
349
350     if (!CRYPTO_THREAD_read_lock(fake_now_lock))
351         return ossl_time_zero();
352     ret = fake_now;
353     CRYPTO_THREAD_unlock(fake_now_lock);
354     return ret;
355 }
356
357 static void qtest_reset_time(void)
358 {
359     if (!CRYPTO_THREAD_write_lock(fake_now_lock))
360         return;
361     fake_now = ossl_time_zero();
362     CRYPTO_THREAD_unlock(fake_now_lock);
363     /* zero time can have a special meaning, bump it */
364     qtest_add_time(1);
365 }
366
367 QTEST_FAULT *qtest_create_injector(QUIC_TSERVER *ts)
368 {
369     QTEST_FAULT *f;
370
371     f = OPENSSL_zalloc(sizeof(*f));
372     if (f == NULL)
373         return NULL;
374
375     f->qtserv = ts;
376     return f;
377
378 }
379
380 int qtest_supports_blocking(void)
381 {
382 #if !defined(OPENSSL_NO_POSIX_IO) && defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
383     return 1;
384 #else
385     return 0;
386 #endif
387 }
388
389 #define MAXLOOPS    1000
390
391 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
392 static int globserverret = 0;
393 static TSAN_QUALIFIER int abortserverthread = 0;
394 static QUIC_TSERVER *globtserv;
395 static const thread_t thread_zero;
396
397 static void run_server_thread(void)
398 {
399     /*
400      * This will operate in a busy loop because the server does not block,
401      * but should be acceptable because it is local and we expect this to be
402      * fast
403      */
404     globserverret = qtest_create_quic_connection(globtserv, NULL);
405 }
406 #endif
407
408 int qtest_wait_for_timeout(SSL *s, QUIC_TSERVER *qtserv)
409 {
410     struct timeval tv;
411     OSSL_TIME ctimeout, stimeout, mintimeout, now;
412     int cinf;
413
414     /* We don't need to wait in blocking mode */
415     if (s == NULL || SSL_get_blocking_mode(s))
416         return 1;
417
418     /* Don't wait if either BIO has data waiting */
419     if (BIO_pending(SSL_get_rbio(s)) > 0
420             || BIO_pending(ossl_quic_tserver_get0_rbio(qtserv)) > 0)
421         return 1;
422
423     /*
424      * Neither endpoint has data waiting to be read. We assume data transmission
425      * is instantaneous due to using mem based BIOs, so there is no data "in
426      * flight" and no more data will be sent by either endpoint until some time
427      * based event has occurred. Therefore, wait for a timeout to occur. This
428      * might happen if we are using the noisy BIO and datagrams have been lost.
429      */
430     if (!SSL_get_event_timeout(s, &tv, &cinf))
431         return 0;
432
433     if (using_fake_time)
434         now = qtest_get_time();
435     else
436         now = ossl_time_now();
437
438     ctimeout = cinf ? ossl_time_infinite() : ossl_time_from_timeval(tv);
439     stimeout = ossl_time_subtract(ossl_quic_tserver_get_deadline(qtserv), now);
440     mintimeout = ossl_time_min(ctimeout, stimeout);
441     if (ossl_time_is_infinite(mintimeout))
442         return 0;
443
444     if (using_fake_time)
445         qtest_add_time(ossl_time2ms(mintimeout));
446     else
447         OSSL_sleep(ossl_time2ms(mintimeout));
448
449     return 1;
450 }
451
452 int qtest_create_quic_connection_ex(QUIC_TSERVER *qtserv, SSL *clientssl,
453                                     int wanterr)
454 {
455     int retc = -1, rets = 0, abortctr = 0, ret = 0;
456     int clienterr = 0, servererr = 0;
457 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
458     /*
459      * Pointless initialisation to avoid bogus compiler warnings about using
460      * t uninitialised
461      */
462     thread_t t = thread_zero;
463
464     if (clientssl != NULL)
465         abortserverthread = 0;
466 #endif
467
468     if (!TEST_ptr(qtserv)) {
469         goto err;
470     } else if (clientssl == NULL) {
471         retc = 1;
472     } else if (SSL_get_blocking_mode(clientssl) > 0) {
473 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
474         /*
475          * clientssl is blocking. We will need a thread to complete the
476          * connection
477          */
478         globtserv = qtserv;
479         if (!TEST_true(run_thread(&t, run_server_thread)))
480             goto err;
481
482         qtserv = NULL;
483         rets = 1;
484 #else
485         TEST_error("No thread support in this build");
486         goto err;
487 #endif
488     }
489
490     do {
491         if (!clienterr && retc <= 0) {
492             int err;
493
494             retc = SSL_connect(clientssl);
495             if (retc <= 0) {
496                 err = SSL_get_error(clientssl, retc);
497
498                 if (err == wanterr) {
499                     retc = 1;
500 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
501                     if (qtserv == NULL && rets > 0)
502                         tsan_store(&abortserverthread, 1);
503                     else
504 #endif
505                         rets = 1;
506                 } else {
507                     if (err != SSL_ERROR_WANT_READ
508                             && err != SSL_ERROR_WANT_WRITE) {
509                         TEST_info("SSL_connect() failed %d, %d", retc, err);
510                         TEST_openssl_errors();
511                         clienterr = 1;
512                     }
513                 }
514             }
515         }
516
517         qtest_add_time(1);
518         if (clientssl != NULL)
519             SSL_handle_events(clientssl);
520         if (qtserv != NULL)
521             ossl_quic_tserver_tick(qtserv);
522
523         if (!servererr && rets <= 0) {
524             servererr = ossl_quic_tserver_is_term_any(qtserv);
525             if (!servererr)
526                 rets = ossl_quic_tserver_is_handshake_confirmed(qtserv);
527         }
528
529         if (clienterr && servererr)
530             goto err;
531
532         if (clientssl != NULL && ++abortctr == MAXLOOPS) {
533             TEST_info("No progress made");
534             goto err;
535         }
536
537         if ((retc <= 0 && !clienterr) || (rets <= 0 && !servererr)) {
538             if (!qtest_wait_for_timeout(clientssl, qtserv))
539                 goto err;
540         }
541     } while ((retc <= 0 && !clienterr)
542              || (rets <= 0 && !servererr
543 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
544                  && !tsan_load(&abortserverthread)
545 #endif
546                 ));
547
548     if (qtserv == NULL && rets > 0) {
549 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
550         if (!TEST_true(wait_for_thread(t)) || !TEST_true(globserverret))
551             goto err;
552 #else
553         TEST_error("Should not happen");
554         goto err;
555 #endif
556     }
557
558     if (!clienterr && !servererr)
559         ret = 1;
560  err:
561     return ret;
562 }
563
564 int qtest_create_quic_connection(QUIC_TSERVER *qtserv, SSL *clientssl)
565 {
566     return qtest_create_quic_connection_ex(qtserv, clientssl, SSL_ERROR_NONE);
567 }
568
569 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
570 static TSAN_QUALIFIER int shutdowndone;
571
572 static void run_server_shutdown_thread(void)
573 {
574     /*
575      * This will operate in a busy loop because the server does not block,
576      * but should be acceptable because it is local and we expect this to be
577      * fast
578      */
579     do {
580         ossl_quic_tserver_tick(globtserv);
581     } while(!tsan_load(&shutdowndone));
582 }
583 #endif
584
585 int qtest_shutdown(QUIC_TSERVER *qtserv, SSL *clientssl)
586 {
587     int tickserver = 1;
588     int ret = 0;
589 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
590     /*
591      * Pointless initialisation to avoid bogus compiler warnings about using
592      * t uninitialised
593      */
594     thread_t t = thread_zero;
595 #endif
596
597     if (SSL_get_blocking_mode(clientssl) > 0) {
598 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
599         /*
600          * clientssl is blocking. We will need a thread to complete the
601          * connection
602          */
603         globtserv = qtserv;
604         shutdowndone = 0;
605         if (!TEST_true(run_thread(&t, run_server_shutdown_thread)))
606             return 0;
607
608         tickserver = 0;
609 #else
610         TEST_error("No thread support in this build");
611         return 0;
612 #endif
613     }
614
615     /* Busy loop in non-blocking mode. It should be quick because its local */
616     for (;;) {
617         int rc = SSL_shutdown(clientssl);
618
619         if (rc == 1) {
620             ret = 1;
621             break;
622         }
623
624         if (rc < 0)
625             break;
626
627         if (tickserver)
628             ossl_quic_tserver_tick(qtserv);
629     }
630
631 #if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
632     tsan_store(&shutdowndone, 1);
633     if (!tickserver) {
634         if (!TEST_true(wait_for_thread(t)))
635             ret = 0;
636     }
637 #endif
638
639     return ret;
640 }
641
642 int qtest_check_server_transport_err(QUIC_TSERVER *qtserv, uint64_t code)
643 {
644     const QUIC_TERMINATE_CAUSE *cause;
645
646     ossl_quic_tserver_tick(qtserv);
647
648     /*
649      * Check that the server has closed with the specified code from the client
650      */
651     if (!TEST_true(ossl_quic_tserver_is_term_any(qtserv)))
652         return 0;
653
654     cause = ossl_quic_tserver_get_terminate_cause(qtserv);
655     if  (!TEST_ptr(cause)
656             || !TEST_true(cause->remote)
657             || !TEST_false(cause->app)
658             || !TEST_uint64_t_eq(cause->error_code, code))
659         return 0;
660
661     return 1;
662 }
663
664 int qtest_check_server_protocol_err(QUIC_TSERVER *qtserv)
665 {
666     return qtest_check_server_transport_err(qtserv, QUIC_ERR_PROTOCOL_VIOLATION);
667 }
668
669 int qtest_check_server_frame_encoding_err(QUIC_TSERVER *qtserv)
670 {
671     return qtest_check_server_transport_err(qtserv, QUIC_ERR_FRAME_ENCODING_ERROR);
672 }
673
674 void qtest_fault_free(QTEST_FAULT *fault)
675 {
676     if (fault == NULL)
677         return;
678
679     packet_plain_finish(fault);
680     handshake_finish(fault);
681
682     OPENSSL_free(fault);
683 }
684
685 static int packet_plain_mutate(const QUIC_PKT_HDR *hdrin,
686                                const OSSL_QTX_IOVEC *iovecin, size_t numin,
687                                QUIC_PKT_HDR **hdrout,
688                                const OSSL_QTX_IOVEC **iovecout,
689                                size_t *numout,
690                                void *arg)
691 {
692     QTEST_FAULT *fault = arg;
693     size_t i, bufsz = 0;
694     unsigned char *cur;
695
696     /* Coalesce our data into a single buffer */
697
698     /* First calculate required buffer size */
699     for (i = 0; i < numin; i++)
700         bufsz += iovecin[i].buf_len;
701
702     fault->pplainio.buf_len = bufsz;
703
704     /* Add an allowance for possible growth */
705     bufsz += GROWTH_ALLOWANCE;
706
707     fault->pplainio.buf = cur = OPENSSL_malloc(bufsz);
708     if (cur == NULL) {
709         fault->pplainio.buf_len = 0;
710         return 0;
711     }
712
713     fault->pplainbuf_alloc = bufsz;
714
715     /* Copy in the data from the input buffers */
716     for (i = 0; i < numin; i++) {
717         memcpy(cur, iovecin[i].buf, iovecin[i].buf_len);
718         cur += iovecin[i].buf_len;
719     }
720
721     fault->pplainhdr = *hdrin;
722
723     /* Cast below is safe because we allocated the buffer */
724     if (fault->pplaincb != NULL
725             && !fault->pplaincb(fault, &fault->pplainhdr,
726                                 (unsigned char *)fault->pplainio.buf,
727                                 fault->pplainio.buf_len, fault->pplaincbarg))
728         return 0;
729
730     *hdrout = &fault->pplainhdr;
731     *iovecout = &fault->pplainio;
732     *numout = 1;
733
734     return 1;
735 }
736
737 static void packet_plain_finish(void *arg)
738 {
739     QTEST_FAULT *fault = arg;
740
741     /* Cast below is safe because we allocated the buffer */
742     OPENSSL_free((unsigned char *)fault->pplainio.buf);
743     fault->pplainio.buf_len = 0;
744     fault->pplainbuf_alloc = 0;
745     fault->pplainio.buf = NULL;
746 }
747
748 int qtest_fault_set_packet_plain_listener(QTEST_FAULT *fault,
749                                           qtest_fault_on_packet_plain_cb pplaincb,
750                                           void *pplaincbarg)
751 {
752     fault->pplaincb = pplaincb;
753     fault->pplaincbarg = pplaincbarg;
754
755     return ossl_quic_tserver_set_plain_packet_mutator(fault->qtserv,
756                                                       packet_plain_mutate,
757                                                       packet_plain_finish,
758                                                       fault);
759 }
760
761 /* To be called from a packet_plain_listener callback */
762 int qtest_fault_resize_plain_packet(QTEST_FAULT *fault, size_t newlen)
763 {
764     unsigned char *buf;
765     size_t oldlen = fault->pplainio.buf_len;
766
767     /*
768      * Alloc'd size should always be non-zero, so if this fails we've been
769      * incorrectly called
770      */
771     if (fault->pplainbuf_alloc == 0)
772         return 0;
773
774     if (newlen > fault->pplainbuf_alloc) {
775         /* This exceeds our growth allowance. Fail */
776         return 0;
777     }
778
779     /* Cast below is safe because we allocated the buffer */
780     buf = (unsigned char *)fault->pplainio.buf;
781
782     if (newlen > oldlen) {
783         /* Extend packet with 0 bytes */
784         memset(buf + oldlen, 0, newlen - oldlen);
785     } /* else we're truncating or staying the same */
786
787     fault->pplainio.buf_len = newlen;
788     fault->pplainhdr.len = newlen;
789
790     return 1;
791 }
792
793 /*
794  * Prepend frame data into a packet. To be called from a packet_plain_listener
795  * callback
796  */
797 int qtest_fault_prepend_frame(QTEST_FAULT *fault, const unsigned char *frame,
798                               size_t frame_len)
799 {
800     unsigned char *buf;
801     size_t old_len;
802
803     /*
804      * Alloc'd size should always be non-zero, so if this fails we've been
805      * incorrectly called
806      */
807     if (fault->pplainbuf_alloc == 0)
808         return 0;
809
810     /* Cast below is safe because we allocated the buffer */
811     buf = (unsigned char *)fault->pplainio.buf;
812     old_len = fault->pplainio.buf_len;
813
814     /* Extend the size of the packet by the size of the new frame */
815     if (!TEST_true(qtest_fault_resize_plain_packet(fault,
816                                                    old_len + frame_len)))
817         return 0;
818
819     memmove(buf + frame_len, buf, old_len);
820     memcpy(buf, frame, frame_len);
821
822     return 1;
823 }
824
825 static int handshake_mutate(const unsigned char *msgin, size_t msginlen,
826                             unsigned char **msgout, size_t *msgoutlen,
827                             void *arg)
828 {
829     QTEST_FAULT *fault = arg;
830     unsigned char *buf;
831     unsigned long payloadlen;
832     unsigned int msgtype;
833     PACKET pkt;
834
835     buf = OPENSSL_malloc(msginlen + GROWTH_ALLOWANCE);
836     if (buf == NULL)
837         return 0;
838
839     fault->handbuf = buf;
840     fault->handbuflen = msginlen;
841     fault->handbufalloc = msginlen + GROWTH_ALLOWANCE;
842     memcpy(buf, msgin, msginlen);
843
844     if (!PACKET_buf_init(&pkt, buf, msginlen)
845             || !PACKET_get_1(&pkt, &msgtype)
846             || !PACKET_get_net_3(&pkt, &payloadlen)
847             || PACKET_remaining(&pkt) != payloadlen)
848         return 0;
849
850     /* Parse specific message types */
851     switch (msgtype) {
852     case SSL3_MT_ENCRYPTED_EXTENSIONS:
853     {
854         QTEST_ENCRYPTED_EXTENSIONS ee;
855
856         if (fault->encextcb == NULL)
857             break;
858
859         /*
860          * The EncryptedExtensions message is very simple. It just has an
861          * extensions block in it and nothing else.
862          */
863         ee.extensions = (unsigned char *)PACKET_data(&pkt);
864         ee.extensionslen = payloadlen;
865         if (!fault->encextcb(fault, &ee, payloadlen, fault->encextcbarg))
866             return 0;
867     }
868
869     default:
870         /* No specific handlers for these message types yet */
871         break;
872     }
873
874     if (fault->handshakecb != NULL
875             && !fault->handshakecb(fault, buf, fault->handbuflen,
876                                    fault->handshakecbarg))
877         return 0;
878
879     *msgout = buf;
880     *msgoutlen = fault->handbuflen;
881
882     return 1;
883 }
884
885 static void handshake_finish(void *arg)
886 {
887     QTEST_FAULT *fault = arg;
888
889     OPENSSL_free(fault->handbuf);
890     fault->handbuf = NULL;
891 }
892
893 int qtest_fault_set_handshake_listener(QTEST_FAULT *fault,
894                                        qtest_fault_on_handshake_cb handshakecb,
895                                        void *handshakecbarg)
896 {
897     fault->handshakecb = handshakecb;
898     fault->handshakecbarg = handshakecbarg;
899
900     return ossl_quic_tserver_set_handshake_mutator(fault->qtserv,
901                                                    handshake_mutate,
902                                                    handshake_finish,
903                                                    fault);
904 }
905
906 int qtest_fault_set_hand_enc_ext_listener(QTEST_FAULT *fault,
907                                           qtest_fault_on_enc_ext_cb encextcb,
908                                           void *encextcbarg)
909 {
910     fault->encextcb = encextcb;
911     fault->encextcbarg = encextcbarg;
912
913     return ossl_quic_tserver_set_handshake_mutator(fault->qtserv,
914                                                    handshake_mutate,
915                                                    handshake_finish,
916                                                    fault);
917 }
918
919 /* To be called from a handshake_listener callback */
920 int qtest_fault_resize_handshake(QTEST_FAULT *fault, size_t newlen)
921 {
922     unsigned char *buf;
923     size_t oldlen = fault->handbuflen;
924
925     /*
926      * Alloc'd size should always be non-zero, so if this fails we've been
927      * incorrectly called
928      */
929     if (fault->handbufalloc == 0)
930         return 0;
931
932     if (newlen > fault->handbufalloc) {
933         /* This exceeds our growth allowance. Fail */
934         return 0;
935     }
936
937     buf = (unsigned char *)fault->handbuf;
938
939     if (newlen > oldlen) {
940         /* Extend packet with 0 bytes */
941         memset(buf + oldlen, 0, newlen - oldlen);
942     } /* else we're truncating or staying the same */
943
944     fault->handbuflen = newlen;
945     return 1;
946 }
947
948 /* To be called from message specific listener callbacks */
949 int qtest_fault_resize_message(QTEST_FAULT *fault, size_t newlen)
950 {
951     /* First resize the underlying message */
952     if (!qtest_fault_resize_handshake(fault, newlen + SSL3_HM_HEADER_LENGTH))
953         return 0;
954
955     /* Fixup the handshake message header */
956     fault->handbuf[1] = (unsigned char)((newlen >> 16) & 0xff);
957     fault->handbuf[2] = (unsigned char)((newlen >>  8) & 0xff);
958     fault->handbuf[3] = (unsigned char)((newlen      ) & 0xff);
959
960     return 1;
961 }
962
963 int qtest_fault_delete_extension(QTEST_FAULT *fault,
964                                  unsigned int exttype, unsigned char *ext,
965                                  size_t *extlen,
966                                  BUF_MEM *old_ext)
967 {
968     PACKET pkt, sub, subext;
969     WPACKET old_ext_wpkt;
970     unsigned int type;
971     const unsigned char *start, *end;
972     size_t newlen, w;
973     size_t msglen = fault->handbuflen;
974
975     if (!PACKET_buf_init(&pkt, ext, *extlen))
976         return 0;
977
978     /* Extension block starts with 2 bytes for extension block length */
979     if (!PACKET_as_length_prefixed_2(&pkt, &sub))
980         return 0;
981
982     do {
983         start = PACKET_data(&sub);
984         if (!PACKET_get_net_2(&sub, &type)
985                 || !PACKET_get_length_prefixed_2(&sub, &subext))
986             return 0;
987     } while (type != exttype);
988
989     /* Found it */
990     end = PACKET_data(&sub);
991
992     if (old_ext != NULL) {
993         if (!WPACKET_init(&old_ext_wpkt, old_ext))
994             return 0;
995
996         if (!WPACKET_memcpy(&old_ext_wpkt, PACKET_data(&subext),
997                             PACKET_remaining(&subext))
998             || !WPACKET_get_total_written(&old_ext_wpkt, &w)) {
999             WPACKET_cleanup(&old_ext_wpkt);
1000             return 0;
1001         }
1002
1003         WPACKET_finish(&old_ext_wpkt);
1004         old_ext->length = w;
1005     }
1006
1007     /*
1008      * If we're not the last extension we need to move the rest earlier. The
1009      * cast below is safe because we own the underlying buffer and we're no
1010      * longer making PACKET calls.
1011      */
1012     if (end < ext + *extlen)
1013         memmove((unsigned char *)start, end, end - start);
1014
1015     /*
1016      * Calculate new extensions payload length =
1017      * Original length
1018      * - 2 extension block length bytes
1019      * - length of removed extension
1020      */
1021     newlen = *extlen - 2 - (end - start);
1022
1023     /* Fixup the length bytes for the extension block */
1024     ext[0] = (unsigned char)((newlen >> 8) & 0xff);
1025     ext[1] = (unsigned char)((newlen     ) & 0xff);
1026
1027     /*
1028      * Length of the whole extension block is the new payload length plus the
1029      * 2 bytes for the length
1030      */
1031     *extlen = newlen + 2;
1032
1033     /* We can now resize the message */
1034     if ((size_t)(end - start) + SSL3_HM_HEADER_LENGTH > msglen)
1035         return 0; /* Should not happen */
1036     msglen -= (end - start) + SSL3_HM_HEADER_LENGTH;
1037     if (!qtest_fault_resize_message(fault, msglen))
1038         return 0;
1039
1040     return 1;
1041 }
1042
1043 #define BIO_TYPE_CIPHER_PACKET_FILTER  (0x80 | BIO_TYPE_FILTER)
1044
1045 static BIO_METHOD *pcipherbiometh = NULL;
1046
1047 # define BIO_MSG_N(array, stride, n) (*(BIO_MSG *)((char *)(array) + (n)*(stride)))
1048
1049 static int pcipher_sendmmsg(BIO *b, BIO_MSG *msg, size_t stride,
1050                             size_t num_msg, uint64_t flags,
1051                             size_t *num_processed)
1052 {
1053     QTEST_FAULT *fault;
1054     BIO *next = BIO_next(b);
1055     ossl_ssize_t ret = 0;
1056     size_t i = 0, tmpnump;
1057     QUIC_PKT_HDR hdr;
1058     PACKET pkt;
1059     unsigned char *tmpdata;
1060
1061     if (next == NULL)
1062         return 0;
1063
1064     fault = BIO_get_data(b);
1065     if (fault == NULL
1066             || (fault->pciphercb == NULL && fault->datagramcb == NULL))
1067         return BIO_sendmmsg(next, msg, stride, num_msg, flags, num_processed);
1068
1069     if (num_msg == 0) {
1070         *num_processed = 0;
1071         return 1;
1072     }
1073
1074     for (i = 0; i < num_msg; ++i) {
1075         fault->msg = BIO_MSG_N(msg, stride, i);
1076
1077         /* Take a copy of the data so that callbacks can modify it */
1078         tmpdata = OPENSSL_malloc(fault->msg.data_len + GROWTH_ALLOWANCE);
1079         if (tmpdata == NULL)
1080             return 0;
1081         memcpy(tmpdata, fault->msg.data, fault->msg.data_len);
1082         fault->msg.data = tmpdata;
1083         fault->msgalloc = fault->msg.data_len + GROWTH_ALLOWANCE;
1084
1085         if (fault->pciphercb != NULL) {
1086             if (!PACKET_buf_init(&pkt, fault->msg.data, fault->msg.data_len))
1087                 return 0;
1088
1089             do {
1090                 if (!ossl_quic_wire_decode_pkt_hdr(&pkt,
1091                         /*
1092                          * TODO(QUIC SERVER):
1093                          * Needs to be set to the actual short header CID length
1094                          * when testing the server implementation.
1095                          */
1096                         0,
1097                         1,
1098                         0, &hdr, NULL))
1099                     goto out;
1100
1101                 /*
1102                  * hdr.data is const - but its our buffer so casting away the
1103                  * const is safe
1104                  */
1105                 if (!fault->pciphercb(fault, &hdr, (unsigned char *)hdr.data,
1106                                     hdr.len, fault->pciphercbarg))
1107                     goto out;
1108
1109                 /*
1110                  * At the moment modifications to hdr by the callback
1111                  * are ignored. We might need to rewrite the QUIC header to
1112                  * enable tests to change this. We also don't yet have a
1113                  * mechanism for the callback to change the encrypted data
1114                  * length. It's not clear if that's needed or not.
1115                  */
1116             } while (PACKET_remaining(&pkt) > 0);
1117         }
1118
1119         if (fault->datagramcb != NULL
1120                 && !fault->datagramcb(fault, &fault->msg, stride,
1121                                       fault->datagramcbarg))
1122             goto out;
1123
1124         if (!BIO_sendmmsg(next, &fault->msg, stride, 1, flags, &tmpnump)) {
1125             *num_processed = i;
1126             goto out;
1127         }
1128
1129         OPENSSL_free(fault->msg.data);
1130         fault->msg.data = NULL;
1131         fault->msgalloc = 0;
1132     }
1133
1134     *num_processed = i;
1135 out:
1136     ret = i > 0;
1137     OPENSSL_free(fault->msg.data);
1138     fault->msg.data = NULL;
1139     return ret;
1140 }
1141
1142 static long pcipher_ctrl(BIO *b, int cmd, long larg, void *parg)
1143 {
1144     BIO *next = BIO_next(b);
1145
1146     if (next == NULL)
1147         return -1;
1148
1149     return BIO_ctrl(next, cmd, larg, parg);
1150 }
1151
1152 BIO_METHOD *qtest_get_bio_method(void)
1153 {
1154     BIO_METHOD *tmp;
1155
1156     if (pcipherbiometh != NULL)
1157         return pcipherbiometh;
1158
1159     tmp = BIO_meth_new(BIO_TYPE_CIPHER_PACKET_FILTER, "Cipher Packet Filter");
1160
1161     if (!TEST_ptr(tmp))
1162         return NULL;
1163
1164     if (!TEST_true(BIO_meth_set_sendmmsg(tmp, pcipher_sendmmsg))
1165             || !TEST_true(BIO_meth_set_ctrl(tmp, pcipher_ctrl)))
1166         goto err;
1167
1168     pcipherbiometh = tmp;
1169     tmp = NULL;
1170  err:
1171     BIO_meth_free(tmp);
1172     return pcipherbiometh;
1173 }
1174
1175 int qtest_fault_set_packet_cipher_listener(QTEST_FAULT *fault,
1176                                            qtest_fault_on_packet_cipher_cb pciphercb,
1177                                            void *pciphercbarg)
1178 {
1179     fault->pciphercb = pciphercb;
1180     fault->pciphercbarg = pciphercbarg;
1181
1182     return 1;
1183 }
1184
1185 int qtest_fault_set_datagram_listener(QTEST_FAULT *fault,
1186                                       qtest_fault_on_datagram_cb datagramcb,
1187                                       void *datagramcbarg)
1188 {
1189     fault->datagramcb = datagramcb;
1190     fault->datagramcbarg = datagramcbarg;
1191
1192     return 1;
1193 }
1194
1195 /* To be called from a datagram_listener callback */
1196 int qtest_fault_resize_datagram(QTEST_FAULT *fault, size_t newlen)
1197 {
1198     if (newlen > fault->msgalloc)
1199             return 0;
1200
1201     if (newlen > fault->msg.data_len)
1202         memset((unsigned char *)fault->msg.data + fault->msg.data_len, 0,
1203                 newlen - fault->msg.data_len);
1204
1205     fault->msg.data_len = newlen;
1206
1207     return 1;
1208 }
1209
1210 int bio_msg_copy(BIO_MSG *dst, BIO_MSG *src)
1211 {
1212     /*
1213      * Note it is assumed that the originally allocated data sizes for dst and
1214      * src are the same
1215      */
1216     memcpy(dst->data, src->data, src->data_len);
1217     dst->data_len = src->data_len;
1218     dst->flags = src->flags;
1219     if (dst->local != NULL) {
1220         if (src->local != NULL) {
1221             if (!TEST_true(BIO_ADDR_copy(dst->local, src->local)))
1222                 return 0;
1223         } else {
1224             BIO_ADDR_clear(dst->local);
1225         }
1226     }
1227     if (!TEST_true(BIO_ADDR_copy(dst->peer, src->peer)))
1228         return 0;
1229
1230     return 1;
1231 }