Make sure we trigger retransmits in DTLS testing
[openssl.git] / test / ssltestlib.c
1 /*
2  * Copyright 2016-2019 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <string.h>
11
12 #include "internal/nelem.h"
13 #include "ssltestlib.h"
14 #include "testutil.h"
15 #include "e_os.h"
16
17 #ifdef OPENSSL_SYS_UNIX
18 # include <unistd.h>
19 #ifndef OPENSSL_NO_KTLS
20 # include <netinet/in.h>
21 # include <netinet/in.h>
22 # include <arpa/inet.h>
23 # include <sys/socket.h>
24 # include <unistd.h>
25 # include <fcntl.h>
26 #endif
27
28 static ossl_inline void ossl_sleep(unsigned int millis) {
29     usleep(millis * 1000);
30 }
31 #elif defined(_WIN32)
32 # include <windows.h>
33
34 static ossl_inline void ossl_sleep(unsigned int millis) {
35     Sleep(millis);
36 }
37 #else
38 /* Fallback to a busy wait */
39 static ossl_inline void ossl_sleep(unsigned int millis) {
40     struct timeval start, now;
41     unsigned int elapsedms;
42
43     gettimeofday(&start, NULL);
44     do {
45         gettimeofday(&now, NULL);
46         elapsedms = (((now.tv_sec - start.tv_sec) * 1000000)
47                      + now.tv_usec - start.tv_usec) / 1000;
48     } while (elapsedms < millis);
49 }
50 #endif
51
52 static int tls_dump_new(BIO *bi);
53 static int tls_dump_free(BIO *a);
54 static int tls_dump_read(BIO *b, char *out, int outl);
55 static int tls_dump_write(BIO *b, const char *in, int inl);
56 static long tls_dump_ctrl(BIO *b, int cmd, long num, void *ptr);
57 static int tls_dump_gets(BIO *bp, char *buf, int size);
58 static int tls_dump_puts(BIO *bp, const char *str);
59
60 /* Choose a sufficiently large type likely to be unused for this custom BIO */
61 #define BIO_TYPE_TLS_DUMP_FILTER  (0x80 | BIO_TYPE_FILTER)
62 #define BIO_TYPE_MEMPACKET_TEST    0x81
63
64 static BIO_METHOD *method_tls_dump = NULL;
65 static BIO_METHOD *meth_mem = NULL;
66
67 /* Note: Not thread safe! */
68 const BIO_METHOD *bio_f_tls_dump_filter(void)
69 {
70     if (method_tls_dump == NULL) {
71         method_tls_dump = BIO_meth_new(BIO_TYPE_TLS_DUMP_FILTER,
72                                         "TLS dump filter");
73         if (   method_tls_dump == NULL
74             || !BIO_meth_set_write(method_tls_dump, tls_dump_write)
75             || !BIO_meth_set_read(method_tls_dump, tls_dump_read)
76             || !BIO_meth_set_puts(method_tls_dump, tls_dump_puts)
77             || !BIO_meth_set_gets(method_tls_dump, tls_dump_gets)
78             || !BIO_meth_set_ctrl(method_tls_dump, tls_dump_ctrl)
79             || !BIO_meth_set_create(method_tls_dump, tls_dump_new)
80             || !BIO_meth_set_destroy(method_tls_dump, tls_dump_free))
81             return NULL;
82     }
83     return method_tls_dump;
84 }
85
86 void bio_f_tls_dump_filter_free(void)
87 {
88     BIO_meth_free(method_tls_dump);
89 }
90
91 static int tls_dump_new(BIO *bio)
92 {
93     BIO_set_init(bio, 1);
94     return 1;
95 }
96
97 static int tls_dump_free(BIO *bio)
98 {
99     BIO_set_init(bio, 0);
100
101     return 1;
102 }
103
104 static void copy_flags(BIO *bio)
105 {
106     int flags;
107     BIO *next = BIO_next(bio);
108
109     flags = BIO_test_flags(next, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_RWS);
110     BIO_clear_flags(bio, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_RWS);
111     BIO_set_flags(bio, flags);
112 }
113
114 #define RECORD_CONTENT_TYPE     0
115 #define RECORD_VERSION_HI       1
116 #define RECORD_VERSION_LO       2
117 #define RECORD_EPOCH_HI         3
118 #define RECORD_EPOCH_LO         4
119 #define RECORD_SEQUENCE_START   5
120 #define RECORD_SEQUENCE_END     10
121 #define RECORD_LEN_HI           11
122 #define RECORD_LEN_LO           12
123
124 #define MSG_TYPE                0
125 #define MSG_LEN_HI              1
126 #define MSG_LEN_MID             2
127 #define MSG_LEN_LO              3
128 #define MSG_SEQ_HI              4
129 #define MSG_SEQ_LO              5
130 #define MSG_FRAG_OFF_HI         6
131 #define MSG_FRAG_OFF_MID        7
132 #define MSG_FRAG_OFF_LO         8
133 #define MSG_FRAG_LEN_HI         9
134 #define MSG_FRAG_LEN_MID        10
135 #define MSG_FRAG_LEN_LO         11
136
137
138 static void dump_data(const char *data, int len)
139 {
140     int rem, i, content, reclen, msglen, fragoff, fraglen, epoch;
141     unsigned char *rec;
142
143     printf("---- START OF PACKET ----\n");
144
145     rem = len;
146     rec = (unsigned char *)data;
147
148     while (rem > 0) {
149         if (rem != len)
150             printf("*\n");
151         printf("*---- START OF RECORD ----\n");
152         if (rem < DTLS1_RT_HEADER_LENGTH) {
153             printf("*---- RECORD TRUNCATED ----\n");
154             break;
155         }
156         content = rec[RECORD_CONTENT_TYPE];
157         printf("** Record Content-type: %d\n", content);
158         printf("** Record Version: %02x%02x\n",
159                rec[RECORD_VERSION_HI], rec[RECORD_VERSION_LO]);
160         epoch = (rec[RECORD_EPOCH_HI] << 8) | rec[RECORD_EPOCH_LO];
161         printf("** Record Epoch: %d\n", epoch);
162         printf("** Record Sequence: ");
163         for (i = RECORD_SEQUENCE_START; i <= RECORD_SEQUENCE_END; i++)
164             printf("%02x", rec[i]);
165         reclen = (rec[RECORD_LEN_HI] << 8) | rec[RECORD_LEN_LO];
166         printf("\n** Record Length: %d\n", reclen);
167
168         /* Now look at message */
169         rec += DTLS1_RT_HEADER_LENGTH;
170         rem -= DTLS1_RT_HEADER_LENGTH;
171         if (content == SSL3_RT_HANDSHAKE) {
172             printf("**---- START OF HANDSHAKE MESSAGE FRAGMENT ----\n");
173             if (epoch > 0) {
174                 printf("**---- HANDSHAKE MESSAGE FRAGMENT ENCRYPTED ----\n");
175             } else if (rem < DTLS1_HM_HEADER_LENGTH
176                     || reclen < DTLS1_HM_HEADER_LENGTH) {
177                 printf("**---- HANDSHAKE MESSAGE FRAGMENT TRUNCATED ----\n");
178             } else {
179                 printf("*** Message Type: %d\n", rec[MSG_TYPE]);
180                 msglen = (rec[MSG_LEN_HI] << 16) | (rec[MSG_LEN_MID] << 8)
181                          | rec[MSG_LEN_LO];
182                 printf("*** Message Length: %d\n", msglen);
183                 printf("*** Message sequence: %d\n",
184                        (rec[MSG_SEQ_HI] << 8) | rec[MSG_SEQ_LO]);
185                 fragoff = (rec[MSG_FRAG_OFF_HI] << 16)
186                           | (rec[MSG_FRAG_OFF_MID] << 8)
187                           | rec[MSG_FRAG_OFF_LO];
188                 printf("*** Message Fragment offset: %d\n", fragoff);
189                 fraglen = (rec[MSG_FRAG_LEN_HI] << 16)
190                           | (rec[MSG_FRAG_LEN_MID] << 8)
191                           | rec[MSG_FRAG_LEN_LO];
192                 printf("*** Message Fragment len: %d\n", fraglen);
193                 if (fragoff + fraglen > msglen)
194                     printf("***---- HANDSHAKE MESSAGE FRAGMENT INVALID ----\n");
195                 else if (reclen < fraglen)
196                     printf("**---- HANDSHAKE MESSAGE FRAGMENT TRUNCATED ----\n");
197                 else
198                     printf("**---- END OF HANDSHAKE MESSAGE FRAGMENT ----\n");
199             }
200         }
201         if (rem < reclen) {
202             printf("*---- RECORD TRUNCATED ----\n");
203             rem = 0;
204         } else {
205             rec += reclen;
206             rem -= reclen;
207             printf("*---- END OF RECORD ----\n");
208         }
209     }
210     printf("---- END OF PACKET ----\n\n");
211     fflush(stdout);
212 }
213
214 static int tls_dump_read(BIO *bio, char *out, int outl)
215 {
216     int ret;
217     BIO *next = BIO_next(bio);
218
219     ret = BIO_read(next, out, outl);
220     copy_flags(bio);
221
222     if (ret > 0) {
223         dump_data(out, ret);
224     }
225
226     return ret;
227 }
228
229 static int tls_dump_write(BIO *bio, const char *in, int inl)
230 {
231     int ret;
232     BIO *next = BIO_next(bio);
233
234     ret = BIO_write(next, in, inl);
235     copy_flags(bio);
236
237     return ret;
238 }
239
240 static long tls_dump_ctrl(BIO *bio, int cmd, long num, void *ptr)
241 {
242     long ret;
243     BIO *next = BIO_next(bio);
244
245     if (next == NULL)
246         return 0;
247
248     switch (cmd) {
249     case BIO_CTRL_DUP:
250         ret = 0L;
251         break;
252     default:
253         ret = BIO_ctrl(next, cmd, num, ptr);
254         break;
255     }
256     return ret;
257 }
258
259 static int tls_dump_gets(BIO *bio, char *buf, int size)
260 {
261     /* We don't support this - not needed anyway */
262     return -1;
263 }
264
265 static int tls_dump_puts(BIO *bio, const char *str)
266 {
267     return tls_dump_write(bio, str, strlen(str));
268 }
269
270
271 struct mempacket_st {
272     unsigned char *data;
273     int len;
274     unsigned int num;
275     unsigned int type;
276 };
277
278 static void mempacket_free(MEMPACKET *pkt)
279 {
280     if (pkt->data != NULL)
281         OPENSSL_free(pkt->data);
282     OPENSSL_free(pkt);
283 }
284
285 typedef struct mempacket_test_ctx_st {
286     STACK_OF(MEMPACKET) *pkts;
287     unsigned int epoch;
288     unsigned int currrec;
289     unsigned int currpkt;
290     unsigned int lastpkt;
291     unsigned int injected;
292     unsigned int noinject;
293     unsigned int dropepoch;
294     int droprec;
295     int duprec;
296 } MEMPACKET_TEST_CTX;
297
298 static int mempacket_test_new(BIO *bi);
299 static int mempacket_test_free(BIO *a);
300 static int mempacket_test_read(BIO *b, char *out, int outl);
301 static int mempacket_test_write(BIO *b, const char *in, int inl);
302 static long mempacket_test_ctrl(BIO *b, int cmd, long num, void *ptr);
303 static int mempacket_test_gets(BIO *bp, char *buf, int size);
304 static int mempacket_test_puts(BIO *bp, const char *str);
305
306 const BIO_METHOD *bio_s_mempacket_test(void)
307 {
308     if (meth_mem == NULL) {
309         if (!TEST_ptr(meth_mem = BIO_meth_new(BIO_TYPE_MEMPACKET_TEST,
310                                               "Mem Packet Test"))
311             || !TEST_true(BIO_meth_set_write(meth_mem, mempacket_test_write))
312             || !TEST_true(BIO_meth_set_read(meth_mem, mempacket_test_read))
313             || !TEST_true(BIO_meth_set_puts(meth_mem, mempacket_test_puts))
314             || !TEST_true(BIO_meth_set_gets(meth_mem, mempacket_test_gets))
315             || !TEST_true(BIO_meth_set_ctrl(meth_mem, mempacket_test_ctrl))
316             || !TEST_true(BIO_meth_set_create(meth_mem, mempacket_test_new))
317             || !TEST_true(BIO_meth_set_destroy(meth_mem, mempacket_test_free)))
318             return NULL;
319     }
320     return meth_mem;
321 }
322
323 void bio_s_mempacket_test_free(void)
324 {
325     BIO_meth_free(meth_mem);
326 }
327
328 static int mempacket_test_new(BIO *bio)
329 {
330     MEMPACKET_TEST_CTX *ctx;
331
332     if (!TEST_ptr(ctx = OPENSSL_zalloc(sizeof(*ctx))))
333         return 0;
334     if (!TEST_ptr(ctx->pkts = sk_MEMPACKET_new_null())) {
335         OPENSSL_free(ctx);
336         return 0;
337     }
338     ctx->dropepoch = 0;
339     ctx->droprec = -1;
340     BIO_set_init(bio, 1);
341     BIO_set_data(bio, ctx);
342     return 1;
343 }
344
345 static int mempacket_test_free(BIO *bio)
346 {
347     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
348
349     sk_MEMPACKET_pop_free(ctx->pkts, mempacket_free);
350     OPENSSL_free(ctx);
351     BIO_set_data(bio, NULL);
352     BIO_set_init(bio, 0);
353     return 1;
354 }
355
356 /* Record Header values */
357 #define EPOCH_HI        3
358 #define EPOCH_LO        4
359 #define RECORD_SEQUENCE 10
360 #define RECORD_LEN_HI   11
361 #define RECORD_LEN_LO   12
362
363 #define STANDARD_PACKET                 0
364
365 static int mempacket_test_read(BIO *bio, char *out, int outl)
366 {
367     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
368     MEMPACKET *thispkt;
369     unsigned char *rec;
370     int rem;
371     unsigned int seq, offset, len, epoch;
372
373     BIO_clear_retry_flags(bio);
374     thispkt = sk_MEMPACKET_value(ctx->pkts, 0);
375     if (thispkt == NULL || thispkt->num != ctx->currpkt) {
376         /* Probably run out of data */
377         BIO_set_retry_read(bio);
378         return -1;
379     }
380     (void)sk_MEMPACKET_shift(ctx->pkts);
381     ctx->currpkt++;
382
383     if (outl > thispkt->len)
384         outl = thispkt->len;
385
386     if (thispkt->type != INJECT_PACKET_IGNORE_REC_SEQ
387             && (ctx->injected || ctx->droprec >= 0)) {
388         /*
389          * Overwrite the record sequence number. We strictly number them in
390          * the order received. Since we are actually a reliable transport
391          * we know that there won't be any re-ordering. We overwrite to deal
392          * with any packets that have been injected
393          */
394         for (rem = thispkt->len, rec = thispkt->data; rem > 0; rem -= len) {
395             if (rem < DTLS1_RT_HEADER_LENGTH)
396                 return -1;
397             epoch = (rec[EPOCH_HI] << 8) | rec[EPOCH_LO];
398             if (epoch != ctx->epoch) {
399                 ctx->epoch = epoch;
400                 ctx->currrec = 0;
401             }
402             seq = ctx->currrec;
403             offset = 0;
404             do {
405                 rec[RECORD_SEQUENCE - offset] = seq & 0xFF;
406                 seq >>= 8;
407                 offset++;
408             } while (seq > 0);
409
410             len = ((rec[RECORD_LEN_HI] << 8) | rec[RECORD_LEN_LO])
411                   + DTLS1_RT_HEADER_LENGTH;
412             if (rem < (int)len)
413                 return -1;
414             if (ctx->droprec == (int)ctx->currrec && ctx->dropepoch == epoch) {
415                 if (rem > (int)len)
416                     memmove(rec, rec + len, rem - len);
417                 outl -= len;
418                 ctx->droprec = -1;
419                 if (outl == 0)
420                     BIO_set_retry_read(bio);
421             } else {
422                 rec += len;
423             }
424
425             ctx->currrec++;
426         }
427     }
428
429     memcpy(out, thispkt->data, outl);
430     mempacket_free(thispkt);
431     return outl;
432 }
433
434 int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
435                           int type)
436 {
437     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
438     MEMPACKET *thispkt = NULL, *looppkt, *nextpkt, *allpkts[3];
439     int i, duprec;
440     const unsigned char *inu = (const unsigned char *)in;
441     size_t len = ((inu[RECORD_LEN_HI] << 8) | inu[RECORD_LEN_LO])
442                  + DTLS1_RT_HEADER_LENGTH;
443
444     if (ctx == NULL)
445         return -1;
446
447     if ((size_t)inl < len)
448         return -1;
449
450     if ((size_t)inl == len)
451         duprec = 0;
452     else
453         duprec = ctx->duprec > 0;
454
455     /* We don't support arbitrary injection when duplicating records */
456     if (duprec && pktnum != -1)
457         return -1;
458
459     /* We only allow injection before we've started writing any data */
460     if (pktnum >= 0) {
461         if (ctx->noinject)
462             return -1;
463         ctx->injected  = 1;
464     } else {
465         ctx->noinject = 1;
466     }
467
468     for (i = 0; i < (duprec ? 3 : 1); i++) {
469         if (!TEST_ptr(allpkts[i] = OPENSSL_malloc(sizeof(*thispkt))))
470             goto err;
471         thispkt = allpkts[i];
472
473         if (!TEST_ptr(thispkt->data = OPENSSL_malloc(inl)))
474             goto err;
475         /*
476          * If we are duplicating the packet, we duplicate it three times. The
477          * first two times we drop the first record if there are more than one.
478          * In this way we know that libssl will not be able to make progress
479          * until it receives the last packet, and hence will be forced to
480          * buffer these records.
481          */
482         if (duprec && i != 2) {
483             memcpy(thispkt->data, in + len, inl - len);
484             thispkt->len = inl - len;
485         } else {
486             memcpy(thispkt->data, in, inl);
487             thispkt->len = inl;
488         }
489         thispkt->num = (pktnum >= 0) ? (unsigned int)pktnum : ctx->lastpkt + i;
490         thispkt->type = type;
491     }
492
493     for(i = 0; (looppkt = sk_MEMPACKET_value(ctx->pkts, i)) != NULL; i++) {
494         /* Check if we found the right place to insert this packet */
495         if (looppkt->num > thispkt->num) {
496             if (sk_MEMPACKET_insert(ctx->pkts, thispkt, i) == 0)
497                 goto err;
498             /* If we're doing up front injection then we're done */
499             if (pktnum >= 0)
500                 return inl;
501             /*
502              * We need to do some accounting on lastpkt. We increment it first,
503              * but it might now equal the value of injected packets, so we need
504              * to skip over those
505              */
506             ctx->lastpkt++;
507             do {
508                 i++;
509                 nextpkt = sk_MEMPACKET_value(ctx->pkts, i);
510                 if (nextpkt != NULL && nextpkt->num == ctx->lastpkt)
511                     ctx->lastpkt++;
512                 else
513                     return inl;
514             } while(1);
515         } else if (looppkt->num == thispkt->num) {
516             if (!ctx->noinject) {
517                 /* We injected two packets with the same packet number! */
518                 goto err;
519             }
520             ctx->lastpkt++;
521             thispkt->num++;
522         }
523     }
524     /*
525      * We didn't find any packets with a packet number equal to or greater than
526      * this one, so we just add it onto the end
527      */
528     for (i = 0; i < (duprec ? 3 : 1); i++) {
529         thispkt = allpkts[i];
530         if (!sk_MEMPACKET_push(ctx->pkts, thispkt))
531             goto err;
532
533         if (pktnum < 0)
534             ctx->lastpkt++;
535     }
536
537     return inl;
538
539  err:
540     for (i = 0; i < (ctx->duprec > 0 ? 3 : 1); i++)
541         mempacket_free(allpkts[i]);
542     return -1;
543 }
544
545 static int mempacket_test_write(BIO *bio, const char *in, int inl)
546 {
547     return mempacket_test_inject(bio, in, inl, -1, STANDARD_PACKET);
548 }
549
550 static long mempacket_test_ctrl(BIO *bio, int cmd, long num, void *ptr)
551 {
552     long ret = 1;
553     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
554     MEMPACKET *thispkt;
555
556     switch (cmd) {
557     case BIO_CTRL_EOF:
558         ret = (long)(sk_MEMPACKET_num(ctx->pkts) == 0);
559         break;
560     case BIO_CTRL_GET_CLOSE:
561         ret = BIO_get_shutdown(bio);
562         break;
563     case BIO_CTRL_SET_CLOSE:
564         BIO_set_shutdown(bio, (int)num);
565         break;
566     case BIO_CTRL_WPENDING:
567         ret = 0L;
568         break;
569     case BIO_CTRL_PENDING:
570         thispkt = sk_MEMPACKET_value(ctx->pkts, 0);
571         if (thispkt == NULL)
572             ret = 0;
573         else
574             ret = thispkt->len;
575         break;
576     case BIO_CTRL_FLUSH:
577         ret = 1;
578         break;
579     case MEMPACKET_CTRL_SET_DROP_EPOCH:
580         ctx->dropepoch = (unsigned int)num;
581         break;
582     case MEMPACKET_CTRL_SET_DROP_REC:
583         ctx->droprec = (int)num;
584         break;
585     case MEMPACKET_CTRL_GET_DROP_REC:
586         ret = ctx->droprec;
587         break;
588     case MEMPACKET_CTRL_SET_DUPLICATE_REC:
589         ctx->duprec = (int)num;
590         break;
591     case BIO_CTRL_RESET:
592     case BIO_CTRL_DUP:
593     case BIO_CTRL_PUSH:
594     case BIO_CTRL_POP:
595     default:
596         ret = 0;
597         break;
598     }
599     return ret;
600 }
601
602 static int mempacket_test_gets(BIO *bio, char *buf, int size)
603 {
604     /* We don't support this - not needed anyway */
605     return -1;
606 }
607
608 static int mempacket_test_puts(BIO *bio, const char *str)
609 {
610     return mempacket_test_write(bio, str, strlen(str));
611 }
612
613 int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,
614                         int min_proto_version, int max_proto_version,
615                         SSL_CTX **sctx, SSL_CTX **cctx, char *certfile,
616                         char *privkeyfile)
617 {
618     SSL_CTX *serverctx = NULL;
619     SSL_CTX *clientctx = NULL;
620
621     if (!TEST_ptr(serverctx = SSL_CTX_new(sm))
622             || (cctx != NULL && !TEST_ptr(clientctx = SSL_CTX_new(cm))))
623         goto err;
624
625     if ((min_proto_version > 0
626          && !TEST_true(SSL_CTX_set_min_proto_version(serverctx,
627                                                      min_proto_version)))
628         || (max_proto_version > 0
629             && !TEST_true(SSL_CTX_set_max_proto_version(serverctx,
630                                                         max_proto_version))))
631         goto err;
632     if (clientctx != NULL
633         && ((min_proto_version > 0
634              && !TEST_true(SSL_CTX_set_min_proto_version(clientctx,
635                                                          min_proto_version)))
636             || (max_proto_version > 0
637                 && !TEST_true(SSL_CTX_set_max_proto_version(clientctx,
638                                                             max_proto_version)))))
639         goto err;
640
641     if (certfile != NULL && privkeyfile != NULL) {
642         if (!TEST_int_eq(SSL_CTX_use_certificate_file(serverctx, certfile,
643                                                       SSL_FILETYPE_PEM), 1)
644                 || !TEST_int_eq(SSL_CTX_use_PrivateKey_file(serverctx,
645                                                             privkeyfile,
646                                                             SSL_FILETYPE_PEM), 1)
647                 || !TEST_int_eq(SSL_CTX_check_private_key(serverctx), 1))
648             goto err;
649     }
650
651 #ifndef OPENSSL_NO_DH
652     SSL_CTX_set_dh_auto(serverctx, 1);
653 #endif
654
655     *sctx = serverctx;
656     if (cctx != NULL)
657         *cctx = clientctx;
658     return 1;
659
660  err:
661     SSL_CTX_free(serverctx);
662     SSL_CTX_free(clientctx);
663     return 0;
664 }
665
666 #define MAXLOOPS    1000000
667
668 #if !defined(OPENSSL_NO_KTLS) && !defined(OPENSSL_NO_SOCK)
669 static int set_nb(int fd)
670 {
671     int flags;
672
673     flags = fcntl(fd,F_GETFL,0);
674     if (flags == -1)
675         return flags;
676     flags = fcntl(fd, F_SETFL, flags | O_NONBLOCK);
677     return flags;
678 }
679
680 int create_test_sockets(int *cfd, int *sfd)
681 {
682     struct sockaddr_in sin;
683     const char *host = "127.0.0.1";
684     int cfd_connected = 0, ret = 0;
685     socklen_t slen = sizeof(sin);
686     int afd = -1;
687
688     *cfd = -1;
689     *sfd = -1;
690
691     memset ((char *) &sin, 0, sizeof(sin));
692     sin.sin_family = AF_INET;
693     sin.sin_addr.s_addr = inet_addr(host);
694
695     afd = socket(AF_INET, SOCK_STREAM, 0);
696     if (afd < 0)
697         return 0;
698
699     if (bind(afd, (struct sockaddr*)&sin, sizeof(sin)) < 0)
700         goto out;
701
702     if (getsockname(afd, (struct sockaddr*)&sin, &slen) < 0)
703         goto out;
704
705     if (listen(afd, 1) < 0)
706         goto out;
707
708     *cfd = socket(AF_INET, SOCK_STREAM, 0);
709     if (*cfd < 0)
710         goto out;
711
712     if (set_nb(afd) == -1)
713         goto out;
714
715     while (*sfd == -1 || !cfd_connected ) {
716         *sfd = accept(afd, NULL, 0);
717         if (*sfd == -1 && errno != EAGAIN)
718             goto out;
719
720         if (!cfd_connected && connect(*cfd, (struct sockaddr*)&sin, sizeof(sin)) < 0)
721             goto out;
722         else
723             cfd_connected = 1;
724     }
725
726     if (set_nb(*cfd) == -1 || set_nb(*sfd) == -1)
727         goto out;
728     ret = 1;
729     goto success;
730
731 out:
732         if (*cfd != -1)
733             close(*cfd);
734         if (*sfd != -1)
735             close(*sfd);
736 success:
737         if (afd != -1)
738             close(afd);
739     return ret;
740 }
741
742 int create_ssl_objects2(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
743                           SSL **cssl, int sfd, int cfd)
744 {
745     SSL *serverssl = NULL, *clientssl = NULL;
746     BIO *s_to_c_bio = NULL, *c_to_s_bio = NULL;
747
748     if (*sssl != NULL)
749         serverssl = *sssl;
750     else if (!TEST_ptr(serverssl = SSL_new(serverctx)))
751         goto error;
752     if (*cssl != NULL)
753         clientssl = *cssl;
754     else if (!TEST_ptr(clientssl = SSL_new(clientctx)))
755         goto error;
756
757     if (!TEST_ptr(s_to_c_bio = BIO_new_socket(sfd, BIO_NOCLOSE))
758             || !TEST_ptr(c_to_s_bio = BIO_new_socket(cfd, BIO_NOCLOSE)))
759         goto error;
760
761     SSL_set_bio(clientssl, c_to_s_bio, c_to_s_bio);
762     SSL_set_bio(serverssl, s_to_c_bio, s_to_c_bio);
763     *sssl = serverssl;
764     *cssl = clientssl;
765     return 1;
766
767  error:
768     SSL_free(serverssl);
769     SSL_free(clientssl);
770     BIO_free(s_to_c_bio);
771     BIO_free(c_to_s_bio);
772     return 0;
773 }
774 #endif
775
776 /*
777  * NOTE: Transfers control of the BIOs - this function will free them on error
778  */
779 int create_ssl_objects(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
780                           SSL **cssl, BIO *s_to_c_fbio, BIO *c_to_s_fbio)
781 {
782     SSL *serverssl = NULL, *clientssl = NULL;
783     BIO *s_to_c_bio = NULL, *c_to_s_bio = NULL;
784
785     if (*sssl != NULL)
786         serverssl = *sssl;
787     else if (!TEST_ptr(serverssl = SSL_new(serverctx)))
788         goto error;
789     if (*cssl != NULL)
790         clientssl = *cssl;
791     else if (!TEST_ptr(clientssl = SSL_new(clientctx)))
792         goto error;
793
794     if (SSL_is_dtls(clientssl)) {
795         if (!TEST_ptr(s_to_c_bio = BIO_new(bio_s_mempacket_test()))
796                 || !TEST_ptr(c_to_s_bio = BIO_new(bio_s_mempacket_test())))
797             goto error;
798     } else {
799         if (!TEST_ptr(s_to_c_bio = BIO_new(BIO_s_mem()))
800                 || !TEST_ptr(c_to_s_bio = BIO_new(BIO_s_mem())))
801             goto error;
802     }
803
804     if (s_to_c_fbio != NULL
805             && !TEST_ptr(s_to_c_bio = BIO_push(s_to_c_fbio, s_to_c_bio)))
806         goto error;
807     if (c_to_s_fbio != NULL
808             && !TEST_ptr(c_to_s_bio = BIO_push(c_to_s_fbio, c_to_s_bio)))
809         goto error;
810
811     /* Set Non-blocking IO behaviour */
812     BIO_set_mem_eof_return(s_to_c_bio, -1);
813     BIO_set_mem_eof_return(c_to_s_bio, -1);
814
815     /* Up ref these as we are passing them to two SSL objects */
816     SSL_set_bio(serverssl, c_to_s_bio, s_to_c_bio);
817     BIO_up_ref(s_to_c_bio);
818     BIO_up_ref(c_to_s_bio);
819     SSL_set_bio(clientssl, s_to_c_bio, c_to_s_bio);
820     *sssl = serverssl;
821     *cssl = clientssl;
822     return 1;
823
824  error:
825     SSL_free(serverssl);
826     SSL_free(clientssl);
827     BIO_free(s_to_c_bio);
828     BIO_free(c_to_s_bio);
829     BIO_free(s_to_c_fbio);
830     BIO_free(c_to_s_fbio);
831
832     return 0;
833 }
834
835 /*
836  * Create an SSL connection, but does not ready any post-handshake
837  * NewSessionTicket messages.
838  * If |read| is set and we're using DTLS then we will attempt to SSL_read on
839  * the connection once we've completed one half of it, to ensure any retransmits
840  * get triggered.
841  */
842 int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want,
843                                int read)
844 {
845     int retc = -1, rets = -1, err, abortctr = 0;
846     int clienterr = 0, servererr = 0;
847     int isdtls = SSL_is_dtls(serverssl);
848
849     do {
850         err = SSL_ERROR_WANT_WRITE;
851         while (!clienterr && retc <= 0 && err == SSL_ERROR_WANT_WRITE) {
852             retc = SSL_connect(clientssl);
853             if (retc <= 0)
854                 err = SSL_get_error(clientssl, retc);
855         }
856
857         if (!clienterr && retc <= 0 && err != SSL_ERROR_WANT_READ) {
858             TEST_info("SSL_connect() failed %d, %d", retc, err);
859             clienterr = 1;
860         }
861         if (want != SSL_ERROR_NONE && err == want)
862             return 0;
863
864         err = SSL_ERROR_WANT_WRITE;
865         while (!servererr && rets <= 0 && err == SSL_ERROR_WANT_WRITE) {
866             rets = SSL_accept(serverssl);
867             if (rets <= 0)
868                 err = SSL_get_error(serverssl, rets);
869         }
870
871         if (!servererr && rets <= 0
872                 && err != SSL_ERROR_WANT_READ
873                 && err != SSL_ERROR_WANT_X509_LOOKUP) {
874             TEST_info("SSL_accept() failed %d, %d", rets, err);
875             servererr = 1;
876         }
877         if (want != SSL_ERROR_NONE && err == want)
878             return 0;
879         if (clienterr && servererr)
880             return 0;
881         if (isdtls && read) {
882             unsigned char buf[20];
883
884             /* Trigger any retransmits that may be appropriate */
885             if (rets > 0 && retc <= 0) {
886                 if (SSL_read(serverssl, buf, sizeof(buf)) > 0) {
887                     /* We don't expect this to succeed! */
888                     TEST_info("Unexpected SSL_read() success!");
889                     return 0;
890                 }
891             }
892             if (retc > 0 && rets <= 0) {
893                 if (SSL_read(clientssl, buf, sizeof(buf)) > 0) {
894                     /* We don't expect this to succeed! */
895                     TEST_info("Unexpected SSL_read() success!");
896                     return 0;
897                 }
898             }
899         }
900         if (++abortctr == MAXLOOPS) {
901             TEST_info("No progress made");
902             return 0;
903         }
904         if (isdtls && abortctr <= 50 && (abortctr % 10) == 0) {
905             /*
906              * It looks like we're just spinning. Pause for a short period to
907              * give the DTLS timer a chance to do something. We only do this for
908              * the first few times to prevent hangs.
909              */
910             ossl_sleep(50);
911         }
912     } while (retc <=0 || rets <= 0);
913
914     return 1;
915 }
916
917 /*
918  * Create an SSL connection including any post handshake NewSessionTicket
919  * messages.
920  */
921 int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
922 {
923     int i;
924     unsigned char buf;
925     size_t readbytes;
926
927     if (!create_bare_ssl_connection(serverssl, clientssl, want, 1))
928         return 0;
929
930     /*
931      * We attempt to read some data on the client side which we expect to fail.
932      * This will ensure we have received the NewSessionTicket in TLSv1.3 where
933      * appropriate. We do this twice because there are 2 NewSesionTickets.
934      */
935     for (i = 0; i < 2; i++) {
936         if (SSL_read_ex(clientssl, &buf, sizeof(buf), &readbytes) > 0) {
937             if (!TEST_ulong_eq(readbytes, 0))
938                 return 0;
939         } else if (!TEST_int_eq(SSL_get_error(clientssl, 0),
940                                 SSL_ERROR_WANT_READ)) {
941             return 0;
942         }
943     }
944
945     return 1;
946 }
947
948 void shutdown_ssl_connection(SSL *serverssl, SSL *clientssl)
949 {
950     SSL_shutdown(clientssl);
951     SSL_shutdown(serverssl);
952     SSL_free(serverssl);
953     SSL_free(clientssl);
954 }