crypto/bn: fix return value in BN_generate_prime
[openssl.git] / test / ssltestlib.c
1 /*
2  * Copyright 2016-2019 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <string.h>
11
12 #include "internal/nelem.h"
13 #include "ssltestlib.h"
14 #include "testutil.h"
15 #include "e_os.h"
16
17 #ifdef OPENSSL_SYS_UNIX
18 # include <unistd.h>
19
20 static ossl_inline void ossl_sleep(unsigned int millis)
21 {
22 # ifdef OPENSSL_SYS_VXWORKS
23     struct timespec ts;
24     ts.tv_sec = (long int) (millis / 1000);
25     ts.tv_nsec = (long int) (millis % 1000) * 1000000ul;
26     nanosleep(&ts, NULL);
27 # else
28     usleep(millis * 1000);
29 # endif
30 }
31 #elif defined(_WIN32)
32 # include <windows.h>
33
34 static ossl_inline void ossl_sleep(unsigned int millis)
35 {
36     Sleep(millis);
37 }
38 #else
39 /* Fallback to a busy wait */
40 static ossl_inline void ossl_sleep(unsigned int millis)
41 {
42     struct timeval start, now;
43     unsigned int elapsedms;
44
45     gettimeofday(&start, NULL);
46     do {
47         gettimeofday(&now, NULL);
48         elapsedms = (((now.tv_sec - start.tv_sec) * 1000000)
49                      + now.tv_usec - start.tv_usec) / 1000;
50     } while (elapsedms < millis);
51 }
52 #endif
53
54 static int tls_dump_new(BIO *bi);
55 static int tls_dump_free(BIO *a);
56 static int tls_dump_read(BIO *b, char *out, int outl);
57 static int tls_dump_write(BIO *b, const char *in, int inl);
58 static long tls_dump_ctrl(BIO *b, int cmd, long num, void *ptr);
59 static int tls_dump_gets(BIO *bp, char *buf, int size);
60 static int tls_dump_puts(BIO *bp, const char *str);
61
62 /* Choose a sufficiently large type likely to be unused for this custom BIO */
63 #define BIO_TYPE_TLS_DUMP_FILTER  (0x80 | BIO_TYPE_FILTER)
64 #define BIO_TYPE_MEMPACKET_TEST    0x81
65
66 static BIO_METHOD *method_tls_dump = NULL;
67 static BIO_METHOD *meth_mem = NULL;
68
69 /* Note: Not thread safe! */
70 const BIO_METHOD *bio_f_tls_dump_filter(void)
71 {
72     if (method_tls_dump == NULL) {
73         method_tls_dump = BIO_meth_new(BIO_TYPE_TLS_DUMP_FILTER,
74                                         "TLS dump filter");
75         if (   method_tls_dump == NULL
76             || !BIO_meth_set_write(method_tls_dump, tls_dump_write)
77             || !BIO_meth_set_read(method_tls_dump, tls_dump_read)
78             || !BIO_meth_set_puts(method_tls_dump, tls_dump_puts)
79             || !BIO_meth_set_gets(method_tls_dump, tls_dump_gets)
80             || !BIO_meth_set_ctrl(method_tls_dump, tls_dump_ctrl)
81             || !BIO_meth_set_create(method_tls_dump, tls_dump_new)
82             || !BIO_meth_set_destroy(method_tls_dump, tls_dump_free))
83             return NULL;
84     }
85     return method_tls_dump;
86 }
87
88 void bio_f_tls_dump_filter_free(void)
89 {
90     BIO_meth_free(method_tls_dump);
91 }
92
93 static int tls_dump_new(BIO *bio)
94 {
95     BIO_set_init(bio, 1);
96     return 1;
97 }
98
99 static int tls_dump_free(BIO *bio)
100 {
101     BIO_set_init(bio, 0);
102
103     return 1;
104 }
105
106 static void copy_flags(BIO *bio)
107 {
108     int flags;
109     BIO *next = BIO_next(bio);
110
111     flags = BIO_test_flags(next, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_RWS);
112     BIO_clear_flags(bio, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_RWS);
113     BIO_set_flags(bio, flags);
114 }
115
116 #define RECORD_CONTENT_TYPE     0
117 #define RECORD_VERSION_HI       1
118 #define RECORD_VERSION_LO       2
119 #define RECORD_EPOCH_HI         3
120 #define RECORD_EPOCH_LO         4
121 #define RECORD_SEQUENCE_START   5
122 #define RECORD_SEQUENCE_END     10
123 #define RECORD_LEN_HI           11
124 #define RECORD_LEN_LO           12
125
126 #define MSG_TYPE                0
127 #define MSG_LEN_HI              1
128 #define MSG_LEN_MID             2
129 #define MSG_LEN_LO              3
130 #define MSG_SEQ_HI              4
131 #define MSG_SEQ_LO              5
132 #define MSG_FRAG_OFF_HI         6
133 #define MSG_FRAG_OFF_MID        7
134 #define MSG_FRAG_OFF_LO         8
135 #define MSG_FRAG_LEN_HI         9
136 #define MSG_FRAG_LEN_MID        10
137 #define MSG_FRAG_LEN_LO         11
138
139
140 static void dump_data(const char *data, int len)
141 {
142     int rem, i, content, reclen, msglen, fragoff, fraglen, epoch;
143     unsigned char *rec;
144
145     printf("---- START OF PACKET ----\n");
146
147     rem = len;
148     rec = (unsigned char *)data;
149
150     while (rem > 0) {
151         if (rem != len)
152             printf("*\n");
153         printf("*---- START OF RECORD ----\n");
154         if (rem < DTLS1_RT_HEADER_LENGTH) {
155             printf("*---- RECORD TRUNCATED ----\n");
156             break;
157         }
158         content = rec[RECORD_CONTENT_TYPE];
159         printf("** Record Content-type: %d\n", content);
160         printf("** Record Version: %02x%02x\n",
161                rec[RECORD_VERSION_HI], rec[RECORD_VERSION_LO]);
162         epoch = (rec[RECORD_EPOCH_HI] << 8) | rec[RECORD_EPOCH_LO];
163         printf("** Record Epoch: %d\n", epoch);
164         printf("** Record Sequence: ");
165         for (i = RECORD_SEQUENCE_START; i <= RECORD_SEQUENCE_END; i++)
166             printf("%02x", rec[i]);
167         reclen = (rec[RECORD_LEN_HI] << 8) | rec[RECORD_LEN_LO];
168         printf("\n** Record Length: %d\n", reclen);
169
170         /* Now look at message */
171         rec += DTLS1_RT_HEADER_LENGTH;
172         rem -= DTLS1_RT_HEADER_LENGTH;
173         if (content == SSL3_RT_HANDSHAKE) {
174             printf("**---- START OF HANDSHAKE MESSAGE FRAGMENT ----\n");
175             if (epoch > 0) {
176                 printf("**---- HANDSHAKE MESSAGE FRAGMENT ENCRYPTED ----\n");
177             } else if (rem < DTLS1_HM_HEADER_LENGTH
178                     || reclen < DTLS1_HM_HEADER_LENGTH) {
179                 printf("**---- HANDSHAKE MESSAGE FRAGMENT TRUNCATED ----\n");
180             } else {
181                 printf("*** Message Type: %d\n", rec[MSG_TYPE]);
182                 msglen = (rec[MSG_LEN_HI] << 16) | (rec[MSG_LEN_MID] << 8)
183                          | rec[MSG_LEN_LO];
184                 printf("*** Message Length: %d\n", msglen);
185                 printf("*** Message sequence: %d\n",
186                        (rec[MSG_SEQ_HI] << 8) | rec[MSG_SEQ_LO]);
187                 fragoff = (rec[MSG_FRAG_OFF_HI] << 16)
188                           | (rec[MSG_FRAG_OFF_MID] << 8)
189                           | rec[MSG_FRAG_OFF_LO];
190                 printf("*** Message Fragment offset: %d\n", fragoff);
191                 fraglen = (rec[MSG_FRAG_LEN_HI] << 16)
192                           | (rec[MSG_FRAG_LEN_MID] << 8)
193                           | rec[MSG_FRAG_LEN_LO];
194                 printf("*** Message Fragment len: %d\n", fraglen);
195                 if (fragoff + fraglen > msglen)
196                     printf("***---- HANDSHAKE MESSAGE FRAGMENT INVALID ----\n");
197                 else if (reclen < fraglen)
198                     printf("**---- HANDSHAKE MESSAGE FRAGMENT TRUNCATED ----\n");
199                 else
200                     printf("**---- END OF HANDSHAKE MESSAGE FRAGMENT ----\n");
201             }
202         }
203         if (rem < reclen) {
204             printf("*---- RECORD TRUNCATED ----\n");
205             rem = 0;
206         } else {
207             rec += reclen;
208             rem -= reclen;
209             printf("*---- END OF RECORD ----\n");
210         }
211     }
212     printf("---- END OF PACKET ----\n\n");
213     fflush(stdout);
214 }
215
216 static int tls_dump_read(BIO *bio, char *out, int outl)
217 {
218     int ret;
219     BIO *next = BIO_next(bio);
220
221     ret = BIO_read(next, out, outl);
222     copy_flags(bio);
223
224     if (ret > 0) {
225         dump_data(out, ret);
226     }
227
228     return ret;
229 }
230
231 static int tls_dump_write(BIO *bio, const char *in, int inl)
232 {
233     int ret;
234     BIO *next = BIO_next(bio);
235
236     ret = BIO_write(next, in, inl);
237     copy_flags(bio);
238
239     return ret;
240 }
241
242 static long tls_dump_ctrl(BIO *bio, int cmd, long num, void *ptr)
243 {
244     long ret;
245     BIO *next = BIO_next(bio);
246
247     if (next == NULL)
248         return 0;
249
250     switch (cmd) {
251     case BIO_CTRL_DUP:
252         ret = 0L;
253         break;
254     default:
255         ret = BIO_ctrl(next, cmd, num, ptr);
256         break;
257     }
258     return ret;
259 }
260
261 static int tls_dump_gets(BIO *bio, char *buf, int size)
262 {
263     /* We don't support this - not needed anyway */
264     return -1;
265 }
266
267 static int tls_dump_puts(BIO *bio, const char *str)
268 {
269     return tls_dump_write(bio, str, strlen(str));
270 }
271
272
273 struct mempacket_st {
274     unsigned char *data;
275     int len;
276     unsigned int num;
277     unsigned int type;
278 };
279
280 static void mempacket_free(MEMPACKET *pkt)
281 {
282     if (pkt->data != NULL)
283         OPENSSL_free(pkt->data);
284     OPENSSL_free(pkt);
285 }
286
287 typedef struct mempacket_test_ctx_st {
288     STACK_OF(MEMPACKET) *pkts;
289     unsigned int epoch;
290     unsigned int currrec;
291     unsigned int currpkt;
292     unsigned int lastpkt;
293     unsigned int injected;
294     unsigned int noinject;
295     unsigned int dropepoch;
296     int droprec;
297     int duprec;
298 } MEMPACKET_TEST_CTX;
299
300 static int mempacket_test_new(BIO *bi);
301 static int mempacket_test_free(BIO *a);
302 static int mempacket_test_read(BIO *b, char *out, int outl);
303 static int mempacket_test_write(BIO *b, const char *in, int inl);
304 static long mempacket_test_ctrl(BIO *b, int cmd, long num, void *ptr);
305 static int mempacket_test_gets(BIO *bp, char *buf, int size);
306 static int mempacket_test_puts(BIO *bp, const char *str);
307
308 const BIO_METHOD *bio_s_mempacket_test(void)
309 {
310     if (meth_mem == NULL) {
311         if (!TEST_ptr(meth_mem = BIO_meth_new(BIO_TYPE_MEMPACKET_TEST,
312                                               "Mem Packet Test"))
313             || !TEST_true(BIO_meth_set_write(meth_mem, mempacket_test_write))
314             || !TEST_true(BIO_meth_set_read(meth_mem, mempacket_test_read))
315             || !TEST_true(BIO_meth_set_puts(meth_mem, mempacket_test_puts))
316             || !TEST_true(BIO_meth_set_gets(meth_mem, mempacket_test_gets))
317             || !TEST_true(BIO_meth_set_ctrl(meth_mem, mempacket_test_ctrl))
318             || !TEST_true(BIO_meth_set_create(meth_mem, mempacket_test_new))
319             || !TEST_true(BIO_meth_set_destroy(meth_mem, mempacket_test_free)))
320             return NULL;
321     }
322     return meth_mem;
323 }
324
325 void bio_s_mempacket_test_free(void)
326 {
327     BIO_meth_free(meth_mem);
328 }
329
330 static int mempacket_test_new(BIO *bio)
331 {
332     MEMPACKET_TEST_CTX *ctx;
333
334     if (!TEST_ptr(ctx = OPENSSL_zalloc(sizeof(*ctx))))
335         return 0;
336     if (!TEST_ptr(ctx->pkts = sk_MEMPACKET_new_null())) {
337         OPENSSL_free(ctx);
338         return 0;
339     }
340     ctx->dropepoch = 0;
341     ctx->droprec = -1;
342     BIO_set_init(bio, 1);
343     BIO_set_data(bio, ctx);
344     return 1;
345 }
346
347 static int mempacket_test_free(BIO *bio)
348 {
349     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
350
351     sk_MEMPACKET_pop_free(ctx->pkts, mempacket_free);
352     OPENSSL_free(ctx);
353     BIO_set_data(bio, NULL);
354     BIO_set_init(bio, 0);
355     return 1;
356 }
357
358 /* Record Header values */
359 #define EPOCH_HI        3
360 #define EPOCH_LO        4
361 #define RECORD_SEQUENCE 10
362 #define RECORD_LEN_HI   11
363 #define RECORD_LEN_LO   12
364
365 #define STANDARD_PACKET                 0
366
367 static int mempacket_test_read(BIO *bio, char *out, int outl)
368 {
369     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
370     MEMPACKET *thispkt;
371     unsigned char *rec;
372     int rem;
373     unsigned int seq, offset, len, epoch;
374
375     BIO_clear_retry_flags(bio);
376     thispkt = sk_MEMPACKET_value(ctx->pkts, 0);
377     if (thispkt == NULL || thispkt->num != ctx->currpkt) {
378         /* Probably run out of data */
379         BIO_set_retry_read(bio);
380         return -1;
381     }
382     (void)sk_MEMPACKET_shift(ctx->pkts);
383     ctx->currpkt++;
384
385     if (outl > thispkt->len)
386         outl = thispkt->len;
387
388     if (thispkt->type != INJECT_PACKET_IGNORE_REC_SEQ
389             && (ctx->injected || ctx->droprec >= 0)) {
390         /*
391          * Overwrite the record sequence number. We strictly number them in
392          * the order received. Since we are actually a reliable transport
393          * we know that there won't be any re-ordering. We overwrite to deal
394          * with any packets that have been injected
395          */
396         for (rem = thispkt->len, rec = thispkt->data; rem > 0; rem -= len) {
397             if (rem < DTLS1_RT_HEADER_LENGTH)
398                 return -1;
399             epoch = (rec[EPOCH_HI] << 8) | rec[EPOCH_LO];
400             if (epoch != ctx->epoch) {
401                 ctx->epoch = epoch;
402                 ctx->currrec = 0;
403             }
404             seq = ctx->currrec;
405             offset = 0;
406             do {
407                 rec[RECORD_SEQUENCE - offset] = seq & 0xFF;
408                 seq >>= 8;
409                 offset++;
410             } while (seq > 0);
411
412             len = ((rec[RECORD_LEN_HI] << 8) | rec[RECORD_LEN_LO])
413                   + DTLS1_RT_HEADER_LENGTH;
414             if (rem < (int)len)
415                 return -1;
416             if (ctx->droprec == (int)ctx->currrec && ctx->dropepoch == epoch) {
417                 if (rem > (int)len)
418                     memmove(rec, rec + len, rem - len);
419                 outl -= len;
420                 ctx->droprec = -1;
421                 if (outl == 0)
422                     BIO_set_retry_read(bio);
423             } else {
424                 rec += len;
425             }
426
427             ctx->currrec++;
428         }
429     }
430
431     memcpy(out, thispkt->data, outl);
432     mempacket_free(thispkt);
433     return outl;
434 }
435
436 int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
437                           int type)
438 {
439     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
440     MEMPACKET *thispkt = NULL, *looppkt, *nextpkt, *allpkts[3];
441     int i, duprec;
442     const unsigned char *inu = (const unsigned char *)in;
443     size_t len = ((inu[RECORD_LEN_HI] << 8) | inu[RECORD_LEN_LO])
444                  + DTLS1_RT_HEADER_LENGTH;
445
446     if (ctx == NULL)
447         return -1;
448
449     if ((size_t)inl < len)
450         return -1;
451
452     if ((size_t)inl == len)
453         duprec = 0;
454     else
455         duprec = ctx->duprec > 0;
456
457     /* We don't support arbitrary injection when duplicating records */
458     if (duprec && pktnum != -1)
459         return -1;
460
461     /* We only allow injection before we've started writing any data */
462     if (pktnum >= 0) {
463         if (ctx->noinject)
464             return -1;
465         ctx->injected  = 1;
466     } else {
467         ctx->noinject = 1;
468     }
469
470     for (i = 0; i < (duprec ? 3 : 1); i++) {
471         if (!TEST_ptr(allpkts[i] = OPENSSL_malloc(sizeof(*thispkt))))
472             goto err;
473         thispkt = allpkts[i];
474
475         if (!TEST_ptr(thispkt->data = OPENSSL_malloc(inl)))
476             goto err;
477         /*
478          * If we are duplicating the packet, we duplicate it three times. The
479          * first two times we drop the first record if there are more than one.
480          * In this way we know that libssl will not be able to make progress
481          * until it receives the last packet, and hence will be forced to
482          * buffer these records.
483          */
484         if (duprec && i != 2) {
485             memcpy(thispkt->data, in + len, inl - len);
486             thispkt->len = inl - len;
487         } else {
488             memcpy(thispkt->data, in, inl);
489             thispkt->len = inl;
490         }
491         thispkt->num = (pktnum >= 0) ? (unsigned int)pktnum : ctx->lastpkt + i;
492         thispkt->type = type;
493     }
494
495     for(i = 0; (looppkt = sk_MEMPACKET_value(ctx->pkts, i)) != NULL; i++) {
496         /* Check if we found the right place to insert this packet */
497         if (looppkt->num > thispkt->num) {
498             if (sk_MEMPACKET_insert(ctx->pkts, thispkt, i) == 0)
499                 goto err;
500             /* If we're doing up front injection then we're done */
501             if (pktnum >= 0)
502                 return inl;
503             /*
504              * We need to do some accounting on lastpkt. We increment it first,
505              * but it might now equal the value of injected packets, so we need
506              * to skip over those
507              */
508             ctx->lastpkt++;
509             do {
510                 i++;
511                 nextpkt = sk_MEMPACKET_value(ctx->pkts, i);
512                 if (nextpkt != NULL && nextpkt->num == ctx->lastpkt)
513                     ctx->lastpkt++;
514                 else
515                     return inl;
516             } while(1);
517         } else if (looppkt->num == thispkt->num) {
518             if (!ctx->noinject) {
519                 /* We injected two packets with the same packet number! */
520                 goto err;
521             }
522             ctx->lastpkt++;
523             thispkt->num++;
524         }
525     }
526     /*
527      * We didn't find any packets with a packet number equal to or greater than
528      * this one, so we just add it onto the end
529      */
530     for (i = 0; i < (duprec ? 3 : 1); i++) {
531         thispkt = allpkts[i];
532         if (!sk_MEMPACKET_push(ctx->pkts, thispkt))
533             goto err;
534
535         if (pktnum < 0)
536             ctx->lastpkt++;
537     }
538
539     return inl;
540
541  err:
542     for (i = 0; i < (ctx->duprec > 0 ? 3 : 1); i++)
543         mempacket_free(allpkts[i]);
544     return -1;
545 }
546
547 static int mempacket_test_write(BIO *bio, const char *in, int inl)
548 {
549     return mempacket_test_inject(bio, in, inl, -1, STANDARD_PACKET);
550 }
551
552 static long mempacket_test_ctrl(BIO *bio, int cmd, long num, void *ptr)
553 {
554     long ret = 1;
555     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
556     MEMPACKET *thispkt;
557
558     switch (cmd) {
559     case BIO_CTRL_EOF:
560         ret = (long)(sk_MEMPACKET_num(ctx->pkts) == 0);
561         break;
562     case BIO_CTRL_GET_CLOSE:
563         ret = BIO_get_shutdown(bio);
564         break;
565     case BIO_CTRL_SET_CLOSE:
566         BIO_set_shutdown(bio, (int)num);
567         break;
568     case BIO_CTRL_WPENDING:
569         ret = 0L;
570         break;
571     case BIO_CTRL_PENDING:
572         thispkt = sk_MEMPACKET_value(ctx->pkts, 0);
573         if (thispkt == NULL)
574             ret = 0;
575         else
576             ret = thispkt->len;
577         break;
578     case BIO_CTRL_FLUSH:
579         ret = 1;
580         break;
581     case MEMPACKET_CTRL_SET_DROP_EPOCH:
582         ctx->dropepoch = (unsigned int)num;
583         break;
584     case MEMPACKET_CTRL_SET_DROP_REC:
585         ctx->droprec = (int)num;
586         break;
587     case MEMPACKET_CTRL_GET_DROP_REC:
588         ret = ctx->droprec;
589         break;
590     case MEMPACKET_CTRL_SET_DUPLICATE_REC:
591         ctx->duprec = (int)num;
592         break;
593     case BIO_CTRL_RESET:
594     case BIO_CTRL_DUP:
595     case BIO_CTRL_PUSH:
596     case BIO_CTRL_POP:
597     default:
598         ret = 0;
599         break;
600     }
601     return ret;
602 }
603
604 static int mempacket_test_gets(BIO *bio, char *buf, int size)
605 {
606     /* We don't support this - not needed anyway */
607     return -1;
608 }
609
610 static int mempacket_test_puts(BIO *bio, const char *str)
611 {
612     return mempacket_test_write(bio, str, strlen(str));
613 }
614
615 int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,
616                         int min_proto_version, int max_proto_version,
617                         SSL_CTX **sctx, SSL_CTX **cctx, char *certfile,
618                         char *privkeyfile)
619 {
620     SSL_CTX *serverctx = NULL;
621     SSL_CTX *clientctx = NULL;
622
623     if (!TEST_ptr(serverctx = SSL_CTX_new(sm))
624             || (cctx != NULL && !TEST_ptr(clientctx = SSL_CTX_new(cm))))
625         goto err;
626
627     if ((min_proto_version > 0
628          && !TEST_true(SSL_CTX_set_min_proto_version(serverctx,
629                                                      min_proto_version)))
630         || (max_proto_version > 0
631             && !TEST_true(SSL_CTX_set_max_proto_version(serverctx,
632                                                         max_proto_version))))
633         goto err;
634     if (clientctx != NULL
635         && ((min_proto_version > 0
636              && !TEST_true(SSL_CTX_set_min_proto_version(clientctx,
637                                                          min_proto_version)))
638             || (max_proto_version > 0
639                 && !TEST_true(SSL_CTX_set_max_proto_version(clientctx,
640                                                             max_proto_version)))))
641         goto err;
642
643     if (certfile != NULL && privkeyfile != NULL) {
644         if (!TEST_int_eq(SSL_CTX_use_certificate_file(serverctx, certfile,
645                                                       SSL_FILETYPE_PEM), 1)
646                 || !TEST_int_eq(SSL_CTX_use_PrivateKey_file(serverctx,
647                                                             privkeyfile,
648                                                             SSL_FILETYPE_PEM), 1)
649                 || !TEST_int_eq(SSL_CTX_check_private_key(serverctx), 1))
650             goto err;
651     }
652
653 #ifndef OPENSSL_NO_DH
654     SSL_CTX_set_dh_auto(serverctx, 1);
655 #endif
656
657     *sctx = serverctx;
658     if (cctx != NULL)
659         *cctx = clientctx;
660     return 1;
661
662  err:
663     SSL_CTX_free(serverctx);
664     SSL_CTX_free(clientctx);
665     return 0;
666 }
667
668 #define MAXLOOPS    1000000
669
670 /*
671  * NOTE: Transfers control of the BIOs - this function will free them on error
672  */
673 int create_ssl_objects(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
674                           SSL **cssl, BIO *s_to_c_fbio, BIO *c_to_s_fbio)
675 {
676     SSL *serverssl = NULL, *clientssl = NULL;
677     BIO *s_to_c_bio = NULL, *c_to_s_bio = NULL;
678
679     if (*sssl != NULL)
680         serverssl = *sssl;
681     else if (!TEST_ptr(serverssl = SSL_new(serverctx)))
682         goto error;
683     if (*cssl != NULL)
684         clientssl = *cssl;
685     else if (!TEST_ptr(clientssl = SSL_new(clientctx)))
686         goto error;
687
688     if (SSL_is_dtls(clientssl)) {
689         if (!TEST_ptr(s_to_c_bio = BIO_new(bio_s_mempacket_test()))
690                 || !TEST_ptr(c_to_s_bio = BIO_new(bio_s_mempacket_test())))
691             goto error;
692     } else {
693         if (!TEST_ptr(s_to_c_bio = BIO_new(BIO_s_mem()))
694                 || !TEST_ptr(c_to_s_bio = BIO_new(BIO_s_mem())))
695             goto error;
696     }
697
698     if (s_to_c_fbio != NULL
699             && !TEST_ptr(s_to_c_bio = BIO_push(s_to_c_fbio, s_to_c_bio)))
700         goto error;
701     if (c_to_s_fbio != NULL
702             && !TEST_ptr(c_to_s_bio = BIO_push(c_to_s_fbio, c_to_s_bio)))
703         goto error;
704
705     /* Set Non-blocking IO behaviour */
706     BIO_set_mem_eof_return(s_to_c_bio, -1);
707     BIO_set_mem_eof_return(c_to_s_bio, -1);
708
709     /* Up ref these as we are passing them to two SSL objects */
710     SSL_set_bio(serverssl, c_to_s_bio, s_to_c_bio);
711     BIO_up_ref(s_to_c_bio);
712     BIO_up_ref(c_to_s_bio);
713     SSL_set_bio(clientssl, s_to_c_bio, c_to_s_bio);
714     *sssl = serverssl;
715     *cssl = clientssl;
716     return 1;
717
718  error:
719     SSL_free(serverssl);
720     SSL_free(clientssl);
721     BIO_free(s_to_c_bio);
722     BIO_free(c_to_s_bio);
723     BIO_free(s_to_c_fbio);
724     BIO_free(c_to_s_fbio);
725
726     return 0;
727 }
728
729 /*
730  * Create an SSL connection, but does not ready any post-handshake
731  * NewSessionTicket messages.
732  * If |read| is set and we're using DTLS then we will attempt to SSL_read on
733  * the connection once we've completed one half of it, to ensure any retransmits
734  * get triggered.
735  */
736 int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want,
737                                int read)
738 {
739     int retc = -1, rets = -1, err, abortctr = 0;
740     int clienterr = 0, servererr = 0;
741     int isdtls = SSL_is_dtls(serverssl);
742
743     do {
744         err = SSL_ERROR_WANT_WRITE;
745         while (!clienterr && retc <= 0 && err == SSL_ERROR_WANT_WRITE) {
746             retc = SSL_connect(clientssl);
747             if (retc <= 0)
748                 err = SSL_get_error(clientssl, retc);
749         }
750
751         if (!clienterr && retc <= 0 && err != SSL_ERROR_WANT_READ) {
752             TEST_info("SSL_connect() failed %d, %d", retc, err);
753             clienterr = 1;
754         }
755         if (want != SSL_ERROR_NONE && err == want)
756             return 0;
757
758         err = SSL_ERROR_WANT_WRITE;
759         while (!servererr && rets <= 0 && err == SSL_ERROR_WANT_WRITE) {
760             rets = SSL_accept(serverssl);
761             if (rets <= 0)
762                 err = SSL_get_error(serverssl, rets);
763         }
764
765         if (!servererr && rets <= 0
766                 && err != SSL_ERROR_WANT_READ
767                 && err != SSL_ERROR_WANT_X509_LOOKUP) {
768             TEST_info("SSL_accept() failed %d, %d", rets, err);
769             servererr = 1;
770         }
771         if (want != SSL_ERROR_NONE && err == want)
772             return 0;
773         if (clienterr && servererr)
774             return 0;
775         if (isdtls && read) {
776             unsigned char buf[20];
777
778             /* Trigger any retransmits that may be appropriate */
779             if (rets > 0 && retc <= 0) {
780                 if (SSL_read(serverssl, buf, sizeof(buf)) > 0) {
781                     /* We don't expect this to succeed! */
782                     TEST_info("Unexpected SSL_read() success!");
783                     return 0;
784                 }
785             }
786             if (retc > 0 && rets <= 0) {
787                 if (SSL_read(clientssl, buf, sizeof(buf)) > 0) {
788                     /* We don't expect this to succeed! */
789                     TEST_info("Unexpected SSL_read() success!");
790                     return 0;
791                 }
792             }
793         }
794         if (++abortctr == MAXLOOPS) {
795             TEST_info("No progress made");
796             return 0;
797         }
798         if (isdtls && abortctr <= 50 && (abortctr % 10) == 0) {
799             /*
800              * It looks like we're just spinning. Pause for a short period to
801              * give the DTLS timer a chance to do something. We only do this for
802              * the first few times to prevent hangs.
803              */
804             ossl_sleep(50);
805         }
806     } while (retc <=0 || rets <= 0);
807
808     return 1;
809 }
810
811 /*
812  * Create an SSL connection including any post handshake NewSessionTicket
813  * messages.
814  */
815 int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
816 {
817     int i;
818     unsigned char buf;
819     size_t readbytes;
820
821     if (!create_bare_ssl_connection(serverssl, clientssl, want, 1))
822         return 0;
823
824     /*
825      * We attempt to read some data on the client side which we expect to fail.
826      * This will ensure we have received the NewSessionTicket in TLSv1.3 where
827      * appropriate. We do this twice because there are 2 NewSesionTickets.
828      */
829     for (i = 0; i < 2; i++) {
830         if (SSL_read_ex(clientssl, &buf, sizeof(buf), &readbytes) > 0) {
831             if (!TEST_ulong_eq(readbytes, 0))
832                 return 0;
833         } else if (!TEST_int_eq(SSL_get_error(clientssl, 0),
834                                 SSL_ERROR_WANT_READ)) {
835             return 0;
836         }
837     }
838
839     return 1;
840 }
841
842 void shutdown_ssl_connection(SSL *serverssl, SSL *clientssl)
843 {
844     SSL_shutdown(clientssl);
845     SSL_shutdown(serverssl);
846     SSL_free(serverssl);
847     SSL_free(clientssl);
848 }