Add a test for the new early data callback
[openssl.git] / test / sslapitest.c
1 /*
2  * Copyright 2016-2018 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <string.h>
11
12 #include <openssl/opensslconf.h>
13 #include <openssl/bio.h>
14 #include <openssl/crypto.h>
15 #include <openssl/ssl.h>
16 #include <openssl/ocsp.h>
17 #include <openssl/srp.h>
18 #include <openssl/txt_db.h>
19 #include <openssl/aes.h>
20
21 #include "ssltestlib.h"
22 #include "testutil.h"
23 #include "testutil/output.h"
24 #include "internal/nelem.h"
25 #include "../ssl/ssl_locl.h"
26
27 static char *cert = NULL;
28 static char *privkey = NULL;
29 static char *srpvfile = NULL;
30 static char *tmpfilename = NULL;
31
32 #define LOG_BUFFER_SIZE 2048
33 static char server_log_buffer[LOG_BUFFER_SIZE + 1] = {0};
34 static size_t server_log_buffer_index = 0;
35 static char client_log_buffer[LOG_BUFFER_SIZE + 1] = {0};
36 static size_t client_log_buffer_index = 0;
37 static int error_writing_log = 0;
38
39 #ifndef OPENSSL_NO_OCSP
40 static const unsigned char orespder[] = "Dummy OCSP Response";
41 static int ocsp_server_called = 0;
42 static int ocsp_client_called = 0;
43
44 static int cdummyarg = 1;
45 static X509 *ocspcert = NULL;
46 #endif
47
48 #define NUM_EXTRA_CERTS 40
49 #define CLIENT_VERSION_LEN      2
50
51 /*
52  * This structure is used to validate that the correct number of log messages
53  * of various types are emitted when emitting secret logs.
54  */
55 struct sslapitest_log_counts {
56     unsigned int rsa_key_exchange_count;
57     unsigned int master_secret_count;
58     unsigned int client_early_secret_count;
59     unsigned int client_handshake_secret_count;
60     unsigned int server_handshake_secret_count;
61     unsigned int client_application_secret_count;
62     unsigned int server_application_secret_count;
63     unsigned int early_exporter_secret_count;
64     unsigned int exporter_secret_count;
65 };
66
67
68 static unsigned char serverinfov1[] = {
69     0xff, 0xff, /* Dummy extension type */
70     0x00, 0x01, /* Extension length is 1 byte */
71     0xff        /* Dummy extension data */
72 };
73
74 static unsigned char serverinfov2[] = {
75     0x00, 0x00, 0x00,
76     (unsigned char)(SSL_EXT_CLIENT_HELLO & 0xff), /* Dummy context - 4 bytes */
77     0xff, 0xff, /* Dummy extension type */
78     0x00, 0x01, /* Extension length is 1 byte */
79     0xff        /* Dummy extension data */
80 };
81
82 static void client_keylog_callback(const SSL *ssl, const char *line)
83 {
84     int line_length = strlen(line);
85
86     /* If the log doesn't fit, error out. */
87     if (client_log_buffer_index + line_length > sizeof(client_log_buffer) - 1) {
88         TEST_info("Client log too full");
89         error_writing_log = 1;
90         return;
91     }
92
93     strcat(client_log_buffer, line);
94     client_log_buffer_index += line_length;
95     client_log_buffer[client_log_buffer_index++] = '\n';
96 }
97
98 static void server_keylog_callback(const SSL *ssl, const char *line)
99 {
100     int line_length = strlen(line);
101
102     /* If the log doesn't fit, error out. */
103     if (server_log_buffer_index + line_length > sizeof(server_log_buffer) - 1) {
104         TEST_info("Server log too full");
105         error_writing_log = 1;
106         return;
107     }
108
109     strcat(server_log_buffer, line);
110     server_log_buffer_index += line_length;
111     server_log_buffer[server_log_buffer_index++] = '\n';
112 }
113
114 static int compare_hex_encoded_buffer(const char *hex_encoded,
115                                       size_t hex_length,
116                                       const uint8_t *raw,
117                                       size_t raw_length)
118 {
119     size_t i, j;
120     char hexed[3];
121
122     if (!TEST_size_t_eq(raw_length * 2, hex_length))
123         return 1;
124
125     for (i = j = 0; i < raw_length && j + 1 < hex_length; i++, j += 2) {
126         sprintf(hexed, "%02x", raw[i]);
127         if (!TEST_int_eq(hexed[0], hex_encoded[j])
128                 || !TEST_int_eq(hexed[1], hex_encoded[j + 1]))
129             return 1;
130     }
131
132     return 0;
133 }
134
135 static int test_keylog_output(char *buffer, const SSL *ssl,
136                               const SSL_SESSION *session,
137                               struct sslapitest_log_counts *expected)
138 {
139     char *token = NULL;
140     unsigned char actual_client_random[SSL3_RANDOM_SIZE] = {0};
141     size_t client_random_size = SSL3_RANDOM_SIZE;
142     unsigned char actual_master_key[SSL_MAX_MASTER_KEY_LENGTH] = {0};
143     size_t master_key_size = SSL_MAX_MASTER_KEY_LENGTH;
144     unsigned int rsa_key_exchange_count = 0;
145     unsigned int master_secret_count = 0;
146     unsigned int client_early_secret_count = 0;
147     unsigned int client_handshake_secret_count = 0;
148     unsigned int server_handshake_secret_count = 0;
149     unsigned int client_application_secret_count = 0;
150     unsigned int server_application_secret_count = 0;
151     unsigned int early_exporter_secret_count = 0;
152     unsigned int exporter_secret_count = 0;
153
154     for (token = strtok(buffer, " \n"); token != NULL;
155          token = strtok(NULL, " \n")) {
156         if (strcmp(token, "RSA") == 0) {
157             /*
158              * Premaster secret. Tokens should be: 16 ASCII bytes of
159              * hex-encoded encrypted secret, then the hex-encoded pre-master
160              * secret.
161              */
162             if (!TEST_ptr(token = strtok(NULL, " \n")))
163                 return 0;
164             if (!TEST_size_t_eq(strlen(token), 16))
165                 return 0;
166             if (!TEST_ptr(token = strtok(NULL, " \n")))
167                 return 0;
168             /*
169              * We can't sensibly check the log because the premaster secret is
170              * transient, and OpenSSL doesn't keep hold of it once the master
171              * secret is generated.
172              */
173             rsa_key_exchange_count++;
174         } else if (strcmp(token, "CLIENT_RANDOM") == 0) {
175             /*
176              * Master secret. Tokens should be: 64 ASCII bytes of hex-encoded
177              * client random, then the hex-encoded master secret.
178              */
179             client_random_size = SSL_get_client_random(ssl,
180                                                        actual_client_random,
181                                                        SSL3_RANDOM_SIZE);
182             if (!TEST_size_t_eq(client_random_size, SSL3_RANDOM_SIZE))
183                 return 0;
184
185             if (!TEST_ptr(token = strtok(NULL, " \n")))
186                 return 0;
187             if (!TEST_size_t_eq(strlen(token), 64))
188                 return 0;
189             if (!TEST_false(compare_hex_encoded_buffer(token, 64,
190                                                        actual_client_random,
191                                                        client_random_size)))
192                 return 0;
193
194             if (!TEST_ptr(token = strtok(NULL, " \n")))
195                 return 0;
196             master_key_size = SSL_SESSION_get_master_key(session,
197                                                          actual_master_key,
198                                                          master_key_size);
199             if (!TEST_size_t_ne(master_key_size, 0))
200                 return 0;
201             if (!TEST_false(compare_hex_encoded_buffer(token, strlen(token),
202                                                        actual_master_key,
203                                                        master_key_size)))
204                 return 0;
205             master_secret_count++;
206         } else if (strcmp(token, "CLIENT_EARLY_TRAFFIC_SECRET") == 0
207                     || strcmp(token, "CLIENT_HANDSHAKE_TRAFFIC_SECRET") == 0
208                     || strcmp(token, "SERVER_HANDSHAKE_TRAFFIC_SECRET") == 0
209                     || strcmp(token, "CLIENT_TRAFFIC_SECRET_0") == 0
210                     || strcmp(token, "SERVER_TRAFFIC_SECRET_0") == 0
211                     || strcmp(token, "EARLY_EXPORTER_SECRET") == 0
212                     || strcmp(token, "EXPORTER_SECRET") == 0) {
213             /*
214              * TLSv1.3 secret. Tokens should be: 64 ASCII bytes of hex-encoded
215              * client random, and then the hex-encoded secret. In this case,
216              * we treat all of these secrets identically and then just
217              * distinguish between them when counting what we saw.
218              */
219             if (strcmp(token, "CLIENT_EARLY_TRAFFIC_SECRET") == 0)
220                 client_early_secret_count++;
221             else if (strcmp(token, "CLIENT_HANDSHAKE_TRAFFIC_SECRET") == 0)
222                 client_handshake_secret_count++;
223             else if (strcmp(token, "SERVER_HANDSHAKE_TRAFFIC_SECRET") == 0)
224                 server_handshake_secret_count++;
225             else if (strcmp(token, "CLIENT_TRAFFIC_SECRET_0") == 0)
226                 client_application_secret_count++;
227             else if (strcmp(token, "SERVER_TRAFFIC_SECRET_0") == 0)
228                 server_application_secret_count++;
229             else if (strcmp(token, "EARLY_EXPORTER_SECRET") == 0)
230                 early_exporter_secret_count++;
231             else if (strcmp(token, "EXPORTER_SECRET") == 0)
232                 exporter_secret_count++;
233
234             client_random_size = SSL_get_client_random(ssl,
235                                                        actual_client_random,
236                                                        SSL3_RANDOM_SIZE);
237             if (!TEST_size_t_eq(client_random_size, SSL3_RANDOM_SIZE))
238                 return 0;
239
240             if (!TEST_ptr(token = strtok(NULL, " \n")))
241                 return 0;
242             if (!TEST_size_t_eq(strlen(token), 64))
243                 return 0;
244             if (!TEST_false(compare_hex_encoded_buffer(token, 64,
245                                                        actual_client_random,
246                                                        client_random_size)))
247                 return 0;
248
249             if (!TEST_ptr(token = strtok(NULL, " \n")))
250                 return 0;
251
252             /*
253              * TODO(TLS1.3): test that application traffic secrets are what
254              * we expect */
255         } else {
256             TEST_info("Unexpected token %s\n", token);
257             return 0;
258         }
259     }
260
261     /* Got what we expected? */
262     if (!TEST_size_t_eq(rsa_key_exchange_count,
263                         expected->rsa_key_exchange_count)
264             || !TEST_size_t_eq(master_secret_count,
265                                expected->master_secret_count)
266             || !TEST_size_t_eq(client_early_secret_count,
267                                expected->client_early_secret_count)
268             || !TEST_size_t_eq(client_handshake_secret_count,
269                                expected->client_handshake_secret_count)
270             || !TEST_size_t_eq(server_handshake_secret_count,
271                                expected->server_handshake_secret_count)
272             || !TEST_size_t_eq(client_application_secret_count,
273                                expected->client_application_secret_count)
274             || !TEST_size_t_eq(server_application_secret_count,
275                                expected->server_application_secret_count)
276             || !TEST_size_t_eq(early_exporter_secret_count,
277                                expected->early_exporter_secret_count)
278             || !TEST_size_t_eq(exporter_secret_count,
279                                expected->exporter_secret_count))
280         return 0;
281     return 1;
282 }
283
284 #if !defined(OPENSSL_NO_TLS1_2) || defined(OPENSSL_NO_TLS1_3)
285 static int test_keylog(void)
286 {
287     SSL_CTX *cctx = NULL, *sctx = NULL;
288     SSL *clientssl = NULL, *serverssl = NULL;
289     int testresult = 0;
290     struct sslapitest_log_counts expected = {0};
291
292     /* Clean up logging space */
293     memset(client_log_buffer, 0, sizeof(client_log_buffer));
294     memset(server_log_buffer, 0, sizeof(server_log_buffer));
295     client_log_buffer_index = 0;
296     server_log_buffer_index = 0;
297     error_writing_log = 0;
298
299     if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(),
300                                        TLS_client_method(),
301                                        TLS1_VERSION, TLS_MAX_VERSION,
302                                        &sctx, &cctx, cert, privkey)))
303         return 0;
304
305     /* We cannot log the master secret for TLSv1.3, so we should forbid it. */
306     SSL_CTX_set_options(cctx, SSL_OP_NO_TLSv1_3);
307     SSL_CTX_set_options(sctx, SSL_OP_NO_TLSv1_3);
308
309     /* We also want to ensure that we use RSA-based key exchange. */
310     if (!TEST_true(SSL_CTX_set_cipher_list(cctx, "RSA")))
311         goto end;
312
313     if (!TEST_true(SSL_CTX_get_keylog_callback(cctx) == NULL)
314             || !TEST_true(SSL_CTX_get_keylog_callback(sctx) == NULL))
315         goto end;
316     SSL_CTX_set_keylog_callback(cctx, client_keylog_callback);
317     if (!TEST_true(SSL_CTX_get_keylog_callback(cctx)
318                    == client_keylog_callback))
319         goto end;
320     SSL_CTX_set_keylog_callback(sctx, server_keylog_callback);
321     if (!TEST_true(SSL_CTX_get_keylog_callback(sctx)
322                    == server_keylog_callback))
323         goto end;
324
325     /* Now do a handshake and check that the logs have been written to. */
326     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
327                                       &clientssl, NULL, NULL))
328             || !TEST_true(create_ssl_connection(serverssl, clientssl,
329                                                 SSL_ERROR_NONE))
330             || !TEST_false(error_writing_log)
331             || !TEST_int_gt(client_log_buffer_index, 0)
332             || !TEST_int_gt(server_log_buffer_index, 0))
333         goto end;
334
335     /*
336      * Now we want to test that our output data was vaguely sensible. We
337      * do that by using strtok and confirming that we have more or less the
338      * data we expect. For both client and server, we expect to see one master
339      * secret. The client should also see a RSA key exchange.
340      */
341     expected.rsa_key_exchange_count = 1;
342     expected.master_secret_count = 1;
343     if (!TEST_true(test_keylog_output(client_log_buffer, clientssl,
344                                       SSL_get_session(clientssl), &expected)))
345         goto end;
346
347     expected.rsa_key_exchange_count = 0;
348     if (!TEST_true(test_keylog_output(server_log_buffer, serverssl,
349                                       SSL_get_session(serverssl), &expected)))
350         goto end;
351
352     testresult = 1;
353
354 end:
355     SSL_free(serverssl);
356     SSL_free(clientssl);
357     SSL_CTX_free(sctx);
358     SSL_CTX_free(cctx);
359
360     return testresult;
361 }
362 #endif
363
364 #ifndef OPENSSL_NO_TLS1_3
365 static int test_keylog_no_master_key(void)
366 {
367     SSL_CTX *cctx = NULL, *sctx = NULL;
368     SSL *clientssl = NULL, *serverssl = NULL;
369     SSL_SESSION *sess = NULL;
370     int testresult = 0;
371     struct sslapitest_log_counts expected = {0};
372     unsigned char buf[1];
373     size_t readbytes, written;
374
375     /* Clean up logging space */
376     memset(client_log_buffer, 0, sizeof(client_log_buffer));
377     memset(server_log_buffer, 0, sizeof(server_log_buffer));
378     client_log_buffer_index = 0;
379     server_log_buffer_index = 0;
380     error_writing_log = 0;
381
382     if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
383                                        TLS1_VERSION, TLS_MAX_VERSION,
384                                        &sctx, &cctx, cert, privkey))
385         || !TEST_true(SSL_CTX_set_max_early_data(sctx,
386                                                  SSL3_RT_MAX_PLAIN_LENGTH)))
387         return 0;
388
389     if (!TEST_true(SSL_CTX_get_keylog_callback(cctx) == NULL)
390             || !TEST_true(SSL_CTX_get_keylog_callback(sctx) == NULL))
391         goto end;
392
393     SSL_CTX_set_keylog_callback(cctx, client_keylog_callback);
394     if (!TEST_true(SSL_CTX_get_keylog_callback(cctx)
395                    == client_keylog_callback))
396         goto end;
397
398     SSL_CTX_set_keylog_callback(sctx, server_keylog_callback);
399     if (!TEST_true(SSL_CTX_get_keylog_callback(sctx)
400                    == server_keylog_callback))
401         goto end;
402
403     /* Now do a handshake and check that the logs have been written to. */
404     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
405                                       &clientssl, NULL, NULL))
406             || !TEST_true(create_ssl_connection(serverssl, clientssl,
407                                                 SSL_ERROR_NONE))
408             || !TEST_false(error_writing_log))
409         goto end;
410
411     /*
412      * Now we want to test that our output data was vaguely sensible. For this
413      * test, we expect no CLIENT_RANDOM entry because it doesn't make sense for
414      * TLSv1.3, but we do expect both client and server to emit keys.
415      */
416     expected.client_handshake_secret_count = 1;
417     expected.server_handshake_secret_count = 1;
418     expected.client_application_secret_count = 1;
419     expected.server_application_secret_count = 1;
420     expected.exporter_secret_count = 1;
421     if (!TEST_true(test_keylog_output(client_log_buffer, clientssl,
422                                       SSL_get_session(clientssl), &expected))
423             || !TEST_true(test_keylog_output(server_log_buffer, serverssl,
424                                              SSL_get_session(serverssl),
425                                              &expected)))
426         goto end;
427
428     /* Terminate old session and resume with early data. */
429     sess = SSL_get1_session(clientssl);
430     SSL_shutdown(clientssl);
431     SSL_shutdown(serverssl);
432     SSL_free(serverssl);
433     SSL_free(clientssl);
434     serverssl = clientssl = NULL;
435
436     /* Reset key log */
437     memset(client_log_buffer, 0, sizeof(client_log_buffer));
438     memset(server_log_buffer, 0, sizeof(server_log_buffer));
439     client_log_buffer_index = 0;
440     server_log_buffer_index = 0;
441
442     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
443                                       &clientssl, NULL, NULL))
444             || !TEST_true(SSL_set_session(clientssl, sess))
445             /* Here writing 0 length early data is enough. */
446             || !TEST_true(SSL_write_early_data(clientssl, NULL, 0, &written))
447             || !TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
448                                                 &readbytes),
449                             SSL_READ_EARLY_DATA_ERROR)
450             || !TEST_int_eq(SSL_get_early_data_status(serverssl),
451                             SSL_EARLY_DATA_ACCEPTED)
452             || !TEST_true(create_ssl_connection(serverssl, clientssl,
453                           SSL_ERROR_NONE))
454             || !TEST_true(SSL_session_reused(clientssl)))
455         goto end;
456
457     /* In addition to the previous entries, expect early secrets. */
458     expected.client_early_secret_count = 1;
459     expected.early_exporter_secret_count = 1;
460     if (!TEST_true(test_keylog_output(client_log_buffer, clientssl,
461                                       SSL_get_session(clientssl), &expected))
462             || !TEST_true(test_keylog_output(server_log_buffer, serverssl,
463                                              SSL_get_session(serverssl),
464                                              &expected)))
465         goto end;
466
467     testresult = 1;
468
469 end:
470     SSL_SESSION_free(sess);
471     SSL_free(serverssl);
472     SSL_free(clientssl);
473     SSL_CTX_free(sctx);
474     SSL_CTX_free(cctx);
475
476     return testresult;
477 }
478 #endif
479
480 #ifndef OPENSSL_NO_TLS1_2
481 static int full_client_hello_callback(SSL *s, int *al, void *arg)
482 {
483     int *ctr = arg;
484     const unsigned char *p;
485     int *exts;
486     /* We only configure two ciphers, but the SCSV is added automatically. */
487 #ifdef OPENSSL_NO_EC
488     const unsigned char expected_ciphers[] = {0x00, 0x9d, 0x00, 0xff};
489 #else
490     const unsigned char expected_ciphers[] = {0x00, 0x9d, 0xc0,
491                                               0x2c, 0x00, 0xff};
492 #endif
493     const int expected_extensions[] = {
494 #ifndef OPENSSL_NO_EC
495                                        11, 10,
496 #endif
497                                        35, 22, 23, 13};
498     size_t len;
499
500     /* Make sure we can defer processing and get called back. */
501     if ((*ctr)++ == 0)
502         return SSL_CLIENT_HELLO_RETRY;
503
504     len = SSL_client_hello_get0_ciphers(s, &p);
505     if (!TEST_mem_eq(p, len, expected_ciphers, sizeof(expected_ciphers))
506             || !TEST_size_t_eq(
507                        SSL_client_hello_get0_compression_methods(s, &p), 1)
508             || !TEST_int_eq(*p, 0))
509         return SSL_CLIENT_HELLO_ERROR;
510     if (!SSL_client_hello_get1_extensions_present(s, &exts, &len))
511         return SSL_CLIENT_HELLO_ERROR;
512     if (len != OSSL_NELEM(expected_extensions) ||
513         memcmp(exts, expected_extensions, len * sizeof(*exts)) != 0) {
514         printf("ClientHello callback expected extensions mismatch\n");
515         OPENSSL_free(exts);
516         return SSL_CLIENT_HELLO_ERROR;
517     }
518     OPENSSL_free(exts);
519     return SSL_CLIENT_HELLO_SUCCESS;
520 }
521
522 static int test_client_hello_cb(void)
523 {
524     SSL_CTX *cctx = NULL, *sctx = NULL;
525     SSL *clientssl = NULL, *serverssl = NULL;
526     int testctr = 0, testresult = 0;
527
528     if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
529                                        TLS1_VERSION, TLS_MAX_VERSION,
530                                        &sctx, &cctx, cert, privkey)))
531         goto end;
532     SSL_CTX_set_client_hello_cb(sctx, full_client_hello_callback, &testctr);
533
534     /* The gimpy cipher list we configure can't do TLS 1.3. */
535     SSL_CTX_set_max_proto_version(cctx, TLS1_2_VERSION);
536
537     if (!TEST_true(SSL_CTX_set_cipher_list(cctx,
538                         "AES256-GCM-SHA384:ECDHE-ECDSA-AES256-GCM-SHA384"))
539             || !TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
540                                              &clientssl, NULL, NULL))
541             || !TEST_false(create_ssl_connection(serverssl, clientssl,
542                         SSL_ERROR_WANT_CLIENT_HELLO_CB))
543                 /*
544                  * Passing a -1 literal is a hack since
545                  * the real value was lost.
546                  * */
547             || !TEST_int_eq(SSL_get_error(serverssl, -1),
548                             SSL_ERROR_WANT_CLIENT_HELLO_CB)
549             || !TEST_true(create_ssl_connection(serverssl, clientssl,
550                                                 SSL_ERROR_NONE)))
551         goto end;
552
553     testresult = 1;
554
555 end:
556     SSL_free(serverssl);
557     SSL_free(clientssl);
558     SSL_CTX_free(sctx);
559     SSL_CTX_free(cctx);
560
561     return testresult;
562 }
563 #endif
564
565 static int execute_test_large_message(const SSL_METHOD *smeth,
566                                       const SSL_METHOD *cmeth,
567                                       int min_version, int max_version,
568                                       int read_ahead)
569 {
570     SSL_CTX *cctx = NULL, *sctx = NULL;
571     SSL *clientssl = NULL, *serverssl = NULL;
572     int testresult = 0;
573     int i;
574     BIO *certbio = NULL;
575     X509 *chaincert = NULL;
576     int certlen;
577
578     if (!TEST_ptr(certbio = BIO_new_file(cert, "r")))
579         goto end;
580     chaincert = PEM_read_bio_X509(certbio, NULL, NULL, NULL);
581     BIO_free(certbio);
582     certbio = NULL;
583     if (!TEST_ptr(chaincert))
584         goto end;
585
586     if (!TEST_true(create_ssl_ctx_pair(smeth, cmeth, min_version, max_version,
587                                        &sctx, &cctx, cert, privkey)))
588         goto end;
589
590     if (read_ahead) {
591         /*
592          * Test that read_ahead works correctly when dealing with large
593          * records
594          */
595         SSL_CTX_set_read_ahead(cctx, 1);
596     }
597
598     /*
599      * We assume the supplied certificate is big enough so that if we add
600      * NUM_EXTRA_CERTS it will make the overall message large enough. The
601      * default buffer size is requested to be 16k, but due to the way BUF_MEM
602      * works, it ends up allocating a little over 21k (16 * 4/3). So, in this
603      * test we need to have a message larger than that.
604      */
605     certlen = i2d_X509(chaincert, NULL);
606     OPENSSL_assert(certlen * NUM_EXTRA_CERTS >
607                    (SSL3_RT_MAX_PLAIN_LENGTH * 4) / 3);
608     for (i = 0; i < NUM_EXTRA_CERTS; i++) {
609         if (!X509_up_ref(chaincert))
610             goto end;
611         if (!SSL_CTX_add_extra_chain_cert(sctx, chaincert)) {
612             X509_free(chaincert);
613             goto end;
614         }
615     }
616
617     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
618                                       NULL, NULL))
619             || !TEST_true(create_ssl_connection(serverssl, clientssl,
620                                                 SSL_ERROR_NONE)))
621         goto end;
622
623     /*
624      * Calling SSL_clear() first is not required but this tests that SSL_clear()
625      * doesn't leak (when using enable-crypto-mdebug).
626      */
627     if (!TEST_true(SSL_clear(serverssl)))
628         goto end;
629
630     testresult = 1;
631  end:
632     X509_free(chaincert);
633     SSL_free(serverssl);
634     SSL_free(clientssl);
635     SSL_CTX_free(sctx);
636     SSL_CTX_free(cctx);
637
638     return testresult;
639 }
640
641 static int test_large_message_tls(void)
642 {
643     return execute_test_large_message(TLS_server_method(), TLS_client_method(),
644                                       TLS1_VERSION, TLS_MAX_VERSION,
645                                       0);
646 }
647
648 static int test_large_message_tls_read_ahead(void)
649 {
650     return execute_test_large_message(TLS_server_method(), TLS_client_method(),
651                                       TLS1_VERSION, TLS_MAX_VERSION,
652                                       1);
653 }
654
655 #ifndef OPENSSL_NO_DTLS
656 static int test_large_message_dtls(void)
657 {
658     /*
659      * read_ahead is not relevant to DTLS because DTLS always acts as if
660      * read_ahead is set.
661      */
662     return execute_test_large_message(DTLS_server_method(),
663                                       DTLS_client_method(),
664                                       DTLS1_VERSION, DTLS_MAX_VERSION,
665                                       0);
666 }
667 #endif
668
669 #ifndef OPENSSL_NO_OCSP
670 static int ocsp_server_cb(SSL *s, void *arg)
671 {
672     int *argi = (int *)arg;
673     unsigned char *copy = NULL;
674     STACK_OF(OCSP_RESPID) *ids = NULL;
675     OCSP_RESPID *id = NULL;
676
677     if (*argi == 2) {
678         /* In this test we are expecting exactly 1 OCSP_RESPID */
679         SSL_get_tlsext_status_ids(s, &ids);
680         if (ids == NULL || sk_OCSP_RESPID_num(ids) != 1)
681             return SSL_TLSEXT_ERR_ALERT_FATAL;
682
683         id = sk_OCSP_RESPID_value(ids, 0);
684         if (id == NULL || !OCSP_RESPID_match(id, ocspcert))
685             return SSL_TLSEXT_ERR_ALERT_FATAL;
686     } else if (*argi != 1) {
687         return SSL_TLSEXT_ERR_ALERT_FATAL;
688     }
689
690     if (!TEST_ptr(copy = OPENSSL_memdup(orespder, sizeof(orespder))))
691         return SSL_TLSEXT_ERR_ALERT_FATAL;
692
693     SSL_set_tlsext_status_ocsp_resp(s, copy, sizeof(orespder));
694     ocsp_server_called = 1;
695     return SSL_TLSEXT_ERR_OK;
696 }
697
698 static int ocsp_client_cb(SSL *s, void *arg)
699 {
700     int *argi = (int *)arg;
701     const unsigned char *respderin;
702     size_t len;
703
704     if (*argi != 1 && *argi != 2)
705         return 0;
706
707     len = SSL_get_tlsext_status_ocsp_resp(s, &respderin);
708     if (!TEST_mem_eq(orespder, len, respderin, len))
709         return 0;
710
711     ocsp_client_called = 1;
712     return 1;
713 }
714
715 static int test_tlsext_status_type(void)
716 {
717     SSL_CTX *cctx = NULL, *sctx = NULL;
718     SSL *clientssl = NULL, *serverssl = NULL;
719     int testresult = 0;
720     STACK_OF(OCSP_RESPID) *ids = NULL;
721     OCSP_RESPID *id = NULL;
722     BIO *certbio = NULL;
723
724     if (!create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
725                              TLS1_VERSION, TLS_MAX_VERSION,
726                              &sctx, &cctx, cert, privkey))
727         return 0;
728
729     if (SSL_CTX_get_tlsext_status_type(cctx) != -1)
730         goto end;
731
732     /* First just do various checks getting and setting tlsext_status_type */
733
734     clientssl = SSL_new(cctx);
735     if (!TEST_int_eq(SSL_get_tlsext_status_type(clientssl), -1)
736             || !TEST_true(SSL_set_tlsext_status_type(clientssl,
737                                                       TLSEXT_STATUSTYPE_ocsp))
738             || !TEST_int_eq(SSL_get_tlsext_status_type(clientssl),
739                             TLSEXT_STATUSTYPE_ocsp))
740         goto end;
741
742     SSL_free(clientssl);
743     clientssl = NULL;
744
745     if (!SSL_CTX_set_tlsext_status_type(cctx, TLSEXT_STATUSTYPE_ocsp)
746      || SSL_CTX_get_tlsext_status_type(cctx) != TLSEXT_STATUSTYPE_ocsp)
747         goto end;
748
749     clientssl = SSL_new(cctx);
750     if (SSL_get_tlsext_status_type(clientssl) != TLSEXT_STATUSTYPE_ocsp)
751         goto end;
752     SSL_free(clientssl);
753     clientssl = NULL;
754
755     /*
756      * Now actually do a handshake and check OCSP information is exchanged and
757      * the callbacks get called
758      */
759     SSL_CTX_set_tlsext_status_cb(cctx, ocsp_client_cb);
760     SSL_CTX_set_tlsext_status_arg(cctx, &cdummyarg);
761     SSL_CTX_set_tlsext_status_cb(sctx, ocsp_server_cb);
762     SSL_CTX_set_tlsext_status_arg(sctx, &cdummyarg);
763     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
764                                       &clientssl, NULL, NULL))
765             || !TEST_true(create_ssl_connection(serverssl, clientssl,
766                                                 SSL_ERROR_NONE))
767             || !TEST_true(ocsp_client_called)
768             || !TEST_true(ocsp_server_called))
769         goto end;
770     SSL_free(serverssl);
771     SSL_free(clientssl);
772     serverssl = NULL;
773     clientssl = NULL;
774
775     /* Try again but this time force the server side callback to fail */
776     ocsp_client_called = 0;
777     ocsp_server_called = 0;
778     cdummyarg = 0;
779     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
780                                       &clientssl, NULL, NULL))
781                 /* This should fail because the callback will fail */
782             || !TEST_false(create_ssl_connection(serverssl, clientssl,
783                                                  SSL_ERROR_NONE))
784             || !TEST_false(ocsp_client_called)
785             || !TEST_false(ocsp_server_called))
786         goto end;
787     SSL_free(serverssl);
788     SSL_free(clientssl);
789     serverssl = NULL;
790     clientssl = NULL;
791
792     /*
793      * This time we'll get the client to send an OCSP_RESPID that it will
794      * accept.
795      */
796     ocsp_client_called = 0;
797     ocsp_server_called = 0;
798     cdummyarg = 2;
799     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
800                                       &clientssl, NULL, NULL)))
801         goto end;
802
803     /*
804      * We'll just use any old cert for this test - it doesn't have to be an OCSP
805      * specific one. We'll use the server cert.
806      */
807     if (!TEST_ptr(certbio = BIO_new_file(cert, "r"))
808             || !TEST_ptr(id = OCSP_RESPID_new())
809             || !TEST_ptr(ids = sk_OCSP_RESPID_new_null())
810             || !TEST_ptr(ocspcert = PEM_read_bio_X509(certbio,
811                                                       NULL, NULL, NULL))
812             || !TEST_true(OCSP_RESPID_set_by_key(id, ocspcert))
813             || !TEST_true(sk_OCSP_RESPID_push(ids, id)))
814         goto end;
815     id = NULL;
816     SSL_set_tlsext_status_ids(clientssl, ids);
817     /* Control has been transferred */
818     ids = NULL;
819
820     BIO_free(certbio);
821     certbio = NULL;
822
823     if (!TEST_true(create_ssl_connection(serverssl, clientssl,
824                                          SSL_ERROR_NONE))
825             || !TEST_true(ocsp_client_called)
826             || !TEST_true(ocsp_server_called))
827         goto end;
828
829     testresult = 1;
830
831  end:
832     SSL_free(serverssl);
833     SSL_free(clientssl);
834     SSL_CTX_free(sctx);
835     SSL_CTX_free(cctx);
836     sk_OCSP_RESPID_pop_free(ids, OCSP_RESPID_free);
837     OCSP_RESPID_free(id);
838     BIO_free(certbio);
839     X509_free(ocspcert);
840     ocspcert = NULL;
841
842     return testresult;
843 }
844 #endif
845
846 #if !defined(OPENSSL_NO_TLS1_3) || !defined(OPENSSL_NO_TLS1_2)
847 static int new_called, remove_called, get_called;
848
849 static int new_session_cb(SSL *ssl, SSL_SESSION *sess)
850 {
851     new_called++;
852     /*
853      * sess has been up-refed for us, but we don't actually need it so free it
854      * immediately.
855      */
856     SSL_SESSION_free(sess);
857     return 1;
858 }
859
860 static void remove_session_cb(SSL_CTX *ctx, SSL_SESSION *sess)
861 {
862     remove_called++;
863 }
864
865 static SSL_SESSION *get_sess_val = NULL;
866
867 static SSL_SESSION *get_session_cb(SSL *ssl, const unsigned char *id, int len,
868                                    int *copy)
869 {
870     get_called++;
871     *copy = 1;
872     return get_sess_val;
873 }
874
875 static int execute_test_session(int maxprot, int use_int_cache,
876                                 int use_ext_cache)
877 {
878     SSL_CTX *sctx = NULL, *cctx = NULL;
879     SSL *serverssl1 = NULL, *clientssl1 = NULL;
880     SSL *serverssl2 = NULL, *clientssl2 = NULL;
881 # ifndef OPENSSL_NO_TLS1_1
882     SSL *serverssl3 = NULL, *clientssl3 = NULL;
883 # endif
884     SSL_SESSION *sess1 = NULL, *sess2 = NULL;
885     int testresult = 0, numnewsesstick = 1;
886
887     new_called = remove_called = 0;
888
889     /* TLSv1.3 sends 2 NewSessionTickets */
890     if (maxprot == TLS1_3_VERSION)
891         numnewsesstick = 2;
892
893     if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
894                                        TLS1_VERSION, TLS_MAX_VERSION,
895                                        &sctx, &cctx, cert, privkey)))
896         return 0;
897
898     /*
899      * Only allow the max protocol version so we can force a connection failure
900      * later
901      */
902     SSL_CTX_set_min_proto_version(cctx, maxprot);
903     SSL_CTX_set_max_proto_version(cctx, maxprot);
904
905     /* Set up session cache */
906     if (use_ext_cache) {
907         SSL_CTX_sess_set_new_cb(cctx, new_session_cb);
908         SSL_CTX_sess_set_remove_cb(cctx, remove_session_cb);
909     }
910     if (use_int_cache) {
911         /* Also covers instance where both are set */
912         SSL_CTX_set_session_cache_mode(cctx, SSL_SESS_CACHE_CLIENT);
913     } else {
914         SSL_CTX_set_session_cache_mode(cctx,
915                                        SSL_SESS_CACHE_CLIENT
916                                        | SSL_SESS_CACHE_NO_INTERNAL_STORE);
917     }
918
919     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl1, &clientssl1,
920                                       NULL, NULL))
921             || !TEST_true(create_ssl_connection(serverssl1, clientssl1,
922                                                 SSL_ERROR_NONE))
923             || !TEST_ptr(sess1 = SSL_get1_session(clientssl1)))
924         goto end;
925
926     /* Should fail because it should already be in the cache */
927     if (use_int_cache && !TEST_false(SSL_CTX_add_session(cctx, sess1)))
928         goto end;
929     if (use_ext_cache
930             && (!TEST_int_eq(new_called, numnewsesstick)
931
932                 || !TEST_int_eq(remove_called, 0)))
933         goto end;
934
935     new_called = remove_called = 0;
936     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl2,
937                                       &clientssl2, NULL, NULL))
938             || !TEST_true(SSL_set_session(clientssl2, sess1))
939             || !TEST_true(create_ssl_connection(serverssl2, clientssl2,
940                                                 SSL_ERROR_NONE))
941             || !TEST_true(SSL_session_reused(clientssl2)))
942         goto end;
943
944     if (maxprot == TLS1_3_VERSION) {
945         /*
946          * In TLSv1.3 we should have created a new session even though we have
947          * resumed.
948          */
949         if (use_ext_cache
950                 && (!TEST_int_eq(new_called, 1)
951                     || !TEST_int_eq(remove_called, 0)))
952             goto end;
953     } else {
954         /*
955          * In TLSv1.2 we expect to have resumed so no sessions added or
956          * removed.
957          */
958         if (use_ext_cache
959                 && (!TEST_int_eq(new_called, 0)
960                     || !TEST_int_eq(remove_called, 0)))
961             goto end;
962     }
963
964     SSL_SESSION_free(sess1);
965     if (!TEST_ptr(sess1 = SSL_get1_session(clientssl2)))
966         goto end;
967     shutdown_ssl_connection(serverssl2, clientssl2);
968     serverssl2 = clientssl2 = NULL;
969
970     new_called = remove_called = 0;
971     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl2,
972                                       &clientssl2, NULL, NULL))
973             || !TEST_true(create_ssl_connection(serverssl2, clientssl2,
974                                                 SSL_ERROR_NONE)))
975         goto end;
976
977     if (!TEST_ptr(sess2 = SSL_get1_session(clientssl2)))
978         goto end;
979
980     if (use_ext_cache
981             && (!TEST_int_eq(new_called, numnewsesstick)
982                 || !TEST_int_eq(remove_called, 0)))
983         goto end;
984
985     new_called = remove_called = 0;
986     /*
987      * This should clear sess2 from the cache because it is a "bad" session.
988      * See SSL_set_session() documentation.
989      */
990     if (!TEST_true(SSL_set_session(clientssl2, sess1)))
991         goto end;
992     if (use_ext_cache
993             && (!TEST_int_eq(new_called, 0) || !TEST_int_eq(remove_called, 1)))
994         goto end;
995     if (!TEST_ptr_eq(SSL_get_session(clientssl2), sess1))
996         goto end;
997
998     if (use_int_cache) {
999         /* Should succeeded because it should not already be in the cache */
1000         if (!TEST_true(SSL_CTX_add_session(cctx, sess2))
1001                 || !TEST_true(SSL_CTX_remove_session(cctx, sess2)))
1002             goto end;
1003     }
1004
1005     new_called = remove_called = 0;
1006     /* This shouldn't be in the cache so should fail */
1007     if (!TEST_false(SSL_CTX_remove_session(cctx, sess2)))
1008         goto end;
1009
1010     if (use_ext_cache
1011             && (!TEST_int_eq(new_called, 0) || !TEST_int_eq(remove_called, 1)))
1012         goto end;
1013
1014 # if !defined(OPENSSL_NO_TLS1_1)
1015     new_called = remove_called = 0;
1016     /* Force a connection failure */
1017     SSL_CTX_set_max_proto_version(sctx, TLS1_1_VERSION);
1018     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl3,
1019                                       &clientssl3, NULL, NULL))
1020             || !TEST_true(SSL_set_session(clientssl3, sess1))
1021             /* This should fail because of the mismatched protocol versions */
1022             || !TEST_false(create_ssl_connection(serverssl3, clientssl3,
1023                                                  SSL_ERROR_NONE)))
1024         goto end;
1025
1026     /* We should have automatically removed the session from the cache */
1027     if (use_ext_cache
1028             && (!TEST_int_eq(new_called, 0) || !TEST_int_eq(remove_called, 1)))
1029         goto end;
1030
1031     /* Should succeed because it should not already be in the cache */
1032     if (use_int_cache && !TEST_true(SSL_CTX_add_session(cctx, sess2)))
1033         goto end;
1034 # endif
1035
1036     /* Now do some tests for server side caching */
1037     if (use_ext_cache) {
1038         SSL_CTX_sess_set_new_cb(cctx, NULL);
1039         SSL_CTX_sess_set_remove_cb(cctx, NULL);
1040         SSL_CTX_sess_set_new_cb(sctx, new_session_cb);
1041         SSL_CTX_sess_set_remove_cb(sctx, remove_session_cb);
1042         SSL_CTX_sess_set_get_cb(sctx, get_session_cb);
1043         get_sess_val = NULL;
1044     }
1045
1046     SSL_CTX_set_session_cache_mode(cctx, 0);
1047     /* Internal caching is the default on the server side */
1048     if (!use_int_cache)
1049         SSL_CTX_set_session_cache_mode(sctx,
1050                                        SSL_SESS_CACHE_SERVER
1051                                        | SSL_SESS_CACHE_NO_INTERNAL_STORE);
1052
1053     SSL_free(serverssl1);
1054     SSL_free(clientssl1);
1055     serverssl1 = clientssl1 = NULL;
1056     SSL_free(serverssl2);
1057     SSL_free(clientssl2);
1058     serverssl2 = clientssl2 = NULL;
1059     SSL_SESSION_free(sess1);
1060     sess1 = NULL;
1061     SSL_SESSION_free(sess2);
1062     sess2 = NULL;
1063
1064     SSL_CTX_set_max_proto_version(sctx, maxprot);
1065     if (maxprot == TLS1_2_VERSION)
1066         SSL_CTX_set_options(sctx, SSL_OP_NO_TICKET);
1067     new_called = remove_called = get_called = 0;
1068     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl1, &clientssl1,
1069                                       NULL, NULL))
1070             || !TEST_true(create_ssl_connection(serverssl1, clientssl1,
1071                                                 SSL_ERROR_NONE))
1072             || !TEST_ptr(sess1 = SSL_get1_session(clientssl1))
1073             || !TEST_ptr(sess2 = SSL_get1_session(serverssl1)))
1074         goto end;
1075
1076     if (use_int_cache) {
1077         if (maxprot == TLS1_3_VERSION && !use_ext_cache) {
1078             /*
1079              * In TLSv1.3 it should not have been added to the internal cache,
1080              * except in the case where we also have an external cache (in that
1081              * case it gets added to the cache in order to generate remove
1082              * events after timeout).
1083              */
1084             if (!TEST_false(SSL_CTX_remove_session(sctx, sess2)))
1085                 goto end;
1086         } else {
1087             /* Should fail because it should already be in the cache */
1088             if (!TEST_false(SSL_CTX_add_session(sctx, sess2)))
1089                 goto end;
1090         }
1091     }
1092
1093     if (use_ext_cache) {
1094         SSL_SESSION *tmp = sess2;
1095
1096         if (!TEST_int_eq(new_called, numnewsesstick)
1097                 || !TEST_int_eq(remove_called, 0)
1098                 || !TEST_int_eq(get_called, 0))
1099             goto end;
1100         /*
1101          * Delete the session from the internal cache to force a lookup from
1102          * the external cache. We take a copy first because
1103          * SSL_CTX_remove_session() also marks the session as non-resumable.
1104          */
1105         if (use_int_cache && maxprot != TLS1_3_VERSION) {
1106             if (!TEST_ptr(tmp = SSL_SESSION_dup(sess2))
1107                     || !TEST_true(SSL_CTX_remove_session(sctx, sess2)))
1108                 goto end;
1109             SSL_SESSION_free(sess2);
1110         }
1111         sess2 = tmp;
1112     }
1113
1114     new_called = remove_called = get_called = 0;
1115     get_sess_val = sess2;
1116     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl2,
1117                                       &clientssl2, NULL, NULL))
1118             || !TEST_true(SSL_set_session(clientssl2, sess1))
1119             || !TEST_true(create_ssl_connection(serverssl2, clientssl2,
1120                                                 SSL_ERROR_NONE))
1121             || !TEST_true(SSL_session_reused(clientssl2)))
1122         goto end;
1123
1124     if (use_ext_cache) {
1125         if (!TEST_int_eq(remove_called, 0))
1126             goto end;
1127
1128         if (maxprot == TLS1_3_VERSION) {
1129             if (!TEST_int_eq(new_called, 1)
1130                     || !TEST_int_eq(get_called, 0))
1131                 goto end;
1132         } else {
1133             if (!TEST_int_eq(new_called, 0)
1134                     || !TEST_int_eq(get_called, 1))
1135                 goto end;
1136         }
1137     }
1138
1139     testresult = 1;
1140
1141  end:
1142     SSL_free(serverssl1);
1143     SSL_free(clientssl1);
1144     SSL_free(serverssl2);
1145     SSL_free(clientssl2);
1146 # ifndef OPENSSL_NO_TLS1_1
1147     SSL_free(serverssl3);
1148     SSL_free(clientssl3);
1149 # endif
1150     SSL_SESSION_free(sess1);
1151     SSL_SESSION_free(sess2);
1152     SSL_CTX_free(sctx);
1153     SSL_CTX_free(cctx);
1154
1155     return testresult;
1156 }
1157 #endif /* !defined(OPENSSL_NO_TLS1_3) || !defined(OPENSSL_NO_TLS1_2) */
1158
1159 static int test_session_with_only_int_cache(void)
1160 {
1161 #ifndef OPENSSL_NO_TLS1_3
1162     if (!execute_test_session(TLS1_3_VERSION, 1, 0))
1163         return 0;
1164 #endif
1165
1166 #ifndef OPENSSL_NO_TLS1_2
1167     return execute_test_session(TLS1_2_VERSION, 1, 0);
1168 #else
1169     return 1;
1170 #endif
1171 }
1172
1173 static int test_session_with_only_ext_cache(void)
1174 {
1175 #ifndef OPENSSL_NO_TLS1_3
1176     if (!execute_test_session(TLS1_3_VERSION, 0, 1))
1177         return 0;
1178 #endif
1179
1180 #ifndef OPENSSL_NO_TLS1_2
1181     return execute_test_session(TLS1_2_VERSION, 0, 1);
1182 #else
1183     return 1;
1184 #endif
1185 }
1186
1187 static int test_session_with_both_cache(void)
1188 {
1189 #ifndef OPENSSL_NO_TLS1_3
1190     if (!execute_test_session(TLS1_3_VERSION, 1, 1))
1191         return 0;
1192 #endif
1193
1194 #ifndef OPENSSL_NO_TLS1_2
1195     return execute_test_session(TLS1_2_VERSION, 1, 1);
1196 #else
1197     return 1;
1198 #endif
1199 }
1200
1201 #ifndef OPENSSL_NO_TLS1_3
1202 static SSL_SESSION *sesscache[6];
1203 static int do_cache;
1204
1205 static int new_cachesession_cb(SSL *ssl, SSL_SESSION *sess)
1206 {
1207     if (do_cache) {
1208         sesscache[new_called] = sess;
1209     } else {
1210         /* We don't need the reference to the session, so free it */
1211         SSL_SESSION_free(sess);
1212     }
1213     new_called++;
1214
1215     return 1;
1216 }
1217
1218 static int post_handshake_verify(SSL *sssl, SSL *cssl)
1219 {
1220     SSL_set_verify(sssl, SSL_VERIFY_PEER, NULL);
1221     if (!TEST_true(SSL_verify_client_post_handshake(sssl)))
1222         return 0;
1223
1224     /* Start handshake on the server and client */
1225     if (!TEST_int_eq(SSL_do_handshake(sssl), 1)
1226             || !TEST_int_le(SSL_read(cssl, NULL, 0), 0)
1227             || !TEST_int_le(SSL_read(sssl, NULL, 0), 0)
1228             || !TEST_true(create_ssl_connection(sssl, cssl,
1229                                                 SSL_ERROR_NONE)))
1230         return 0;
1231
1232     return 1;
1233 }
1234
1235 static int test_tickets(int idx)
1236 {
1237     SSL_CTX *sctx = NULL, *cctx = NULL;
1238     SSL *serverssl = NULL, *clientssl = NULL;
1239     int testresult = 0, i;
1240     size_t j;
1241
1242     /* idx is the test number, but also the number of tickets we want */
1243
1244     new_called = 0;
1245     do_cache = 1;
1246
1247     if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
1248                                        TLS1_VERSION, TLS_MAX_VERSION, &sctx,
1249                                        &cctx, cert, privkey))
1250             || !TEST_true(SSL_CTX_set_num_tickets(sctx, idx)))
1251         goto end;
1252
1253     SSL_CTX_set_session_cache_mode(cctx, SSL_SESS_CACHE_CLIENT
1254                                          | SSL_SESS_CACHE_NO_INTERNAL_STORE);
1255     SSL_CTX_sess_set_new_cb(cctx, new_cachesession_cb);
1256
1257     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
1258                                           &clientssl, NULL, NULL)))
1259         goto end;
1260
1261     SSL_force_post_handshake_auth(clientssl);
1262
1263     if (!TEST_true(create_ssl_connection(serverssl, clientssl,
1264                                                 SSL_ERROR_NONE))
1265                /* Check we got the number of tickets we were expecting */
1266             || !TEST_int_eq(idx, new_called))
1267         goto end;
1268
1269     /* After a post-handshake authentication we should get new tickets issued */
1270     if (!post_handshake_verify(serverssl, clientssl)
1271             || !TEST_int_eq(idx * 2, new_called))
1272         goto end;
1273
1274     SSL_shutdown(clientssl);
1275     SSL_shutdown(serverssl);
1276     SSL_free(serverssl);
1277     SSL_free(clientssl);
1278     serverssl = clientssl = NULL;
1279
1280     /* Stop caching sessions - just count them */
1281     do_cache = 0;
1282
1283     /* Test that we can resume with all the tickets we got given */
1284     for (i = 0; i < idx * 2; i++) {
1285         new_called = 0;
1286         if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
1287                                               &clientssl, NULL, NULL))
1288                 || !TEST_true(SSL_set_session(clientssl, sesscache[i])))
1289             goto end;
1290
1291         SSL_force_post_handshake_auth(clientssl);
1292
1293         if (!TEST_true(create_ssl_connection(serverssl, clientssl,
1294                                                     SSL_ERROR_NONE))
1295                 || !TEST_true(SSL_session_reused(clientssl))
1296                    /* Following a resumption we only get 1 ticket */
1297                 || !TEST_int_eq(new_called, 1))
1298             goto end;
1299
1300         new_called = 0;
1301         /* After a post-handshake authentication we should get 1 new ticket */
1302         if (!post_handshake_verify(serverssl, clientssl)
1303                 || !TEST_int_eq(new_called, 1))
1304             goto end;
1305
1306         SSL_shutdown(clientssl);
1307         SSL_shutdown(serverssl);
1308         SSL_free(serverssl);
1309         SSL_free(clientssl);
1310         serverssl = clientssl = NULL;
1311         SSL_SESSION_free(sesscache[i]);
1312         sesscache[i] = NULL;
1313     }
1314
1315     testresult = 1;
1316
1317  end:
1318     SSL_free(serverssl);
1319     SSL_free(clientssl);
1320     for (j = 0; j < OSSL_NELEM(sesscache); j++) {
1321         SSL_SESSION_free(sesscache[j]);
1322         sesscache[j] = NULL;
1323     }
1324     SSL_CTX_free(sctx);
1325     SSL_CTX_free(cctx);
1326
1327     return testresult;
1328 }
1329 #endif
1330
1331 #define USE_NULL            0
1332 #define USE_BIO_1           1
1333 #define USE_BIO_2           2
1334 #define USE_DEFAULT         3
1335
1336 #define CONNTYPE_CONNECTION_SUCCESS  0
1337 #define CONNTYPE_CONNECTION_FAIL     1
1338 #define CONNTYPE_NO_CONNECTION       2
1339
1340 #define TOTAL_NO_CONN_SSL_SET_BIO_TESTS         (3 * 3 * 3 * 3)
1341 #define TOTAL_CONN_SUCCESS_SSL_SET_BIO_TESTS    (2 * 2)
1342 #if !defined(OPENSSL_NO_TLS1_3) && !defined(OPENSSL_NO_TLS1_2)
1343 # define TOTAL_CONN_FAIL_SSL_SET_BIO_TESTS       (2 * 2)
1344 #else
1345 # define TOTAL_CONN_FAIL_SSL_SET_BIO_TESTS       0
1346 #endif
1347
1348 #define TOTAL_SSL_SET_BIO_TESTS TOTAL_NO_CONN_SSL_SET_BIO_TESTS \
1349                                 + TOTAL_CONN_SUCCESS_SSL_SET_BIO_TESTS \
1350                                 + TOTAL_CONN_FAIL_SSL_SET_BIO_TESTS
1351
1352 static void setupbio(BIO **res, BIO *bio1, BIO *bio2, int type)
1353 {
1354     switch (type) {
1355     case USE_NULL:
1356         *res = NULL;
1357         break;
1358     case USE_BIO_1:
1359         *res = bio1;
1360         break;
1361     case USE_BIO_2:
1362         *res = bio2;
1363         break;
1364     }
1365 }
1366
1367
1368 /*
1369  * Tests calls to SSL_set_bio() under various conditions.
1370  *
1371  * For the first 3 * 3 * 3 * 3 = 81 tests we do 2 calls to SSL_set_bio() with
1372  * various combinations of valid BIOs or NULL being set for the rbio/wbio. We
1373  * then do more tests where we create a successful connection first using our
1374  * standard connection setup functions, and then call SSL_set_bio() with
1375  * various combinations of valid BIOs or NULL. We then repeat these tests
1376  * following a failed connection. In this last case we are looking to check that
1377  * SSL_set_bio() functions correctly in the case where s->bbio is not NULL.
1378  */
1379 static int test_ssl_set_bio(int idx)
1380 {
1381     SSL_CTX *sctx = NULL, *cctx = NULL;
1382     BIO *bio1 = NULL;
1383     BIO *bio2 = NULL;
1384     BIO *irbio = NULL, *iwbio = NULL, *nrbio = NULL, *nwbio = NULL;
1385     SSL *serverssl = NULL, *clientssl = NULL;
1386     int initrbio, initwbio, newrbio, newwbio, conntype;
1387     int testresult = 0;
1388
1389     if (idx < TOTAL_NO_CONN_SSL_SET_BIO_TESTS) {
1390         initrbio = idx % 3;
1391         idx /= 3;
1392         initwbio = idx % 3;
1393         idx /= 3;
1394         newrbio = idx % 3;
1395         idx /= 3;
1396         newwbio = idx % 3;
1397         conntype = CONNTYPE_NO_CONNECTION;
1398     } else {
1399         idx -= TOTAL_NO_CONN_SSL_SET_BIO_TESTS;
1400         initrbio = initwbio = USE_DEFAULT;
1401         newrbio = idx % 2;
1402         idx /= 2;
1403         newwbio = idx % 2;
1404         idx /= 2;
1405         conntype = idx % 2;
1406     }
1407
1408     if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
1409                                        TLS1_VERSION, TLS_MAX_VERSION,
1410                                        &sctx, &cctx, cert, privkey)))
1411         goto end;
1412
1413     if (conntype == CONNTYPE_CONNECTION_FAIL) {
1414         /*
1415          * We won't ever get here if either TLSv1.3 or TLSv1.2 is disabled
1416          * because we reduced the number of tests in the definition of
1417          * TOTAL_CONN_FAIL_SSL_SET_BIO_TESTS to avoid this scenario. By setting
1418          * mismatched protocol versions we will force a connection failure.
1419          */
1420         SSL_CTX_set_min_proto_version(sctx, TLS1_3_VERSION);
1421         SSL_CTX_set_max_proto_version(cctx, TLS1_2_VERSION);
1422     }
1423
1424     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
1425                                       NULL, NULL)))
1426         goto end;
1427
1428     if (initrbio == USE_BIO_1
1429             || initwbio == USE_BIO_1
1430             || newrbio == USE_BIO_1
1431             || newwbio == USE_BIO_1) {
1432         if (!TEST_ptr(bio1 = BIO_new(BIO_s_mem())))
1433             goto end;
1434     }
1435
1436     if (initrbio == USE_BIO_2
1437             || initwbio == USE_BIO_2
1438             || newrbio == USE_BIO_2
1439             || newwbio == USE_BIO_2) {
1440         if (!TEST_ptr(bio2 = BIO_new(BIO_s_mem())))
1441             goto end;
1442     }
1443
1444     if (initrbio != USE_DEFAULT) {
1445         setupbio(&irbio, bio1, bio2, initrbio);
1446         setupbio(&iwbio, bio1, bio2, initwbio);
1447         SSL_set_bio(clientssl, irbio, iwbio);
1448
1449         /*
1450          * We want to maintain our own refs to these BIO, so do an up ref for
1451          * each BIO that will have ownership transferred in the SSL_set_bio()
1452          * call
1453          */
1454         if (irbio != NULL)
1455             BIO_up_ref(irbio);
1456         if (iwbio != NULL && iwbio != irbio)
1457             BIO_up_ref(iwbio);
1458     }
1459
1460     if (conntype != CONNTYPE_NO_CONNECTION
1461             && !TEST_true(create_ssl_connection(serverssl, clientssl,
1462                                                 SSL_ERROR_NONE)
1463                           == (conntype == CONNTYPE_CONNECTION_SUCCESS)))
1464         goto end;
1465
1466     setupbio(&nrbio, bio1, bio2, newrbio);
1467     setupbio(&nwbio, bio1, bio2, newwbio);
1468
1469     /*
1470      * We will (maybe) transfer ownership again so do more up refs.
1471      * SSL_set_bio() has some really complicated ownership rules where BIOs have
1472      * already been set!
1473      */
1474     if (nrbio != NULL
1475             && nrbio != irbio
1476             && (nwbio != iwbio || nrbio != nwbio))
1477         BIO_up_ref(nrbio);
1478     if (nwbio != NULL
1479             && nwbio != nrbio
1480             && (nwbio != iwbio || (nwbio == iwbio && irbio == iwbio)))
1481         BIO_up_ref(nwbio);
1482
1483     SSL_set_bio(clientssl, nrbio, nwbio);
1484
1485     testresult = 1;
1486
1487  end:
1488     BIO_free(bio1);
1489     BIO_free(bio2);
1490
1491     /*
1492      * This test is checking that the ref counting for SSL_set_bio is correct.
1493      * If we get here and we did too many frees then we will fail in the above
1494      * functions. If we haven't done enough then this will only be detected in
1495      * a crypto-mdebug build
1496      */
1497     SSL_free(serverssl);
1498     SSL_free(clientssl);
1499     SSL_CTX_free(sctx);
1500     SSL_CTX_free(cctx);
1501     return testresult;
1502 }
1503
1504 typedef enum { NO_BIO_CHANGE, CHANGE_RBIO, CHANGE_WBIO } bio_change_t;
1505
1506 static int execute_test_ssl_bio(int pop_ssl, bio_change_t change_bio)
1507 {
1508     BIO *sslbio = NULL, *membio1 = NULL, *membio2 = NULL;
1509     SSL_CTX *ctx;
1510     SSL *ssl = NULL;
1511     int testresult = 0;
1512
1513     if (!TEST_ptr(ctx = SSL_CTX_new(TLS_method()))
1514             || !TEST_ptr(ssl = SSL_new(ctx))
1515             || !TEST_ptr(sslbio = BIO_new(BIO_f_ssl()))
1516             || !TEST_ptr(membio1 = BIO_new(BIO_s_mem())))
1517         goto end;
1518
1519     BIO_set_ssl(sslbio, ssl, BIO_CLOSE);
1520
1521     /*
1522      * If anything goes wrong here then we could leak memory, so this will
1523      * be caught in a crypto-mdebug build
1524      */
1525     BIO_push(sslbio, membio1);
1526
1527     /* Verify changing the rbio/wbio directly does not cause leaks */
1528     if (change_bio != NO_BIO_CHANGE) {
1529         if (!TEST_ptr(membio2 = BIO_new(BIO_s_mem())))
1530             goto end;
1531         if (change_bio == CHANGE_RBIO)
1532             SSL_set0_rbio(ssl, membio2);
1533         else
1534             SSL_set0_wbio(ssl, membio2);
1535     }
1536     ssl = NULL;
1537
1538     if (pop_ssl)
1539         BIO_pop(sslbio);
1540     else
1541         BIO_pop(membio1);
1542
1543     testresult = 1;
1544  end:
1545     BIO_free(membio1);
1546     BIO_free(sslbio);
1547     SSL_free(ssl);
1548     SSL_CTX_free(ctx);
1549
1550     return testresult;
1551 }
1552
1553 static int test_ssl_bio_pop_next_bio(void)
1554 {
1555     return execute_test_ssl_bio(0, NO_BIO_CHANGE);
1556 }
1557
1558 static int test_ssl_bio_pop_ssl_bio(void)
1559 {
1560     return execute_test_ssl_bio(1, NO_BIO_CHANGE);
1561 }
1562
1563 static int test_ssl_bio_change_rbio(void)
1564 {
1565     return execute_test_ssl_bio(0, CHANGE_RBIO);
1566 }
1567
1568 static int test_ssl_bio_change_wbio(void)
1569 {
1570     return execute_test_ssl_bio(0, CHANGE_WBIO);
1571 }
1572
1573 #if !defined(OPENSSL_NO_TLS1_2) || defined(OPENSSL_NO_TLS1_3)
1574 typedef struct {
1575     /* The list of sig algs */
1576     const int *list;
1577     /* The length of the list */
1578     size_t listlen;
1579     /* A sigalgs list in string format */
1580     const char *liststr;
1581     /* Whether setting the list should succeed */
1582     int valid;
1583     /* Whether creating a connection with the list should succeed */
1584     int connsuccess;
1585 } sigalgs_list;
1586
1587 static const int validlist1[] = {NID_sha256, EVP_PKEY_RSA};
1588 # ifndef OPENSSL_NO_EC
1589 static const int validlist2[] = {NID_sha256, EVP_PKEY_RSA, NID_sha512, EVP_PKEY_EC};
1590 static const int validlist3[] = {NID_sha512, EVP_PKEY_EC};
1591 # endif
1592 static const int invalidlist1[] = {NID_undef, EVP_PKEY_RSA};
1593 static const int invalidlist2[] = {NID_sha256, NID_undef};
1594 static const int invalidlist3[] = {NID_sha256, EVP_PKEY_RSA, NID_sha256};
1595 static const int invalidlist4[] = {NID_sha256};
1596 static const sigalgs_list testsigalgs[] = {
1597     {validlist1, OSSL_NELEM(validlist1), NULL, 1, 1},
1598 # ifndef OPENSSL_NO_EC
1599     {validlist2, OSSL_NELEM(validlist2), NULL, 1, 1},
1600     {validlist3, OSSL_NELEM(validlist3), NULL, 1, 0},
1601 # endif
1602     {NULL, 0, "RSA+SHA256", 1, 1},
1603 # ifndef OPENSSL_NO_EC
1604     {NULL, 0, "RSA+SHA256:ECDSA+SHA512", 1, 1},
1605     {NULL, 0, "ECDSA+SHA512", 1, 0},
1606 # endif
1607     {invalidlist1, OSSL_NELEM(invalidlist1), NULL, 0, 0},
1608     {invalidlist2, OSSL_NELEM(invalidlist2), NULL, 0, 0},
1609     {invalidlist3, OSSL_NELEM(invalidlist3), NULL, 0, 0},
1610     {invalidlist4, OSSL_NELEM(invalidlist4), NULL, 0, 0},
1611     {NULL, 0, "RSA", 0, 0},
1612     {NULL, 0, "SHA256", 0, 0},
1613     {NULL, 0, "RSA+SHA256:SHA256", 0, 0},
1614     {NULL, 0, "Invalid", 0, 0}
1615 };
1616
1617 static int test_set_sigalgs(int idx)
1618 {
1619     SSL_CTX *cctx = NULL, *sctx = NULL;
1620     SSL *clientssl = NULL, *serverssl = NULL;
1621     int testresult = 0;
1622     const sigalgs_list *curr;
1623     int testctx;
1624
1625     /* Should never happen */
1626     if (!TEST_size_t_le((size_t)idx, OSSL_NELEM(testsigalgs) * 2))
1627         return 0;
1628
1629     testctx = ((size_t)idx < OSSL_NELEM(testsigalgs));
1630     curr = testctx ? &testsigalgs[idx]
1631                    : &testsigalgs[idx - OSSL_NELEM(testsigalgs)];
1632
1633     if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
1634                                        TLS1_VERSION, TLS_MAX_VERSION,
1635                                        &sctx, &cctx, cert, privkey)))
1636         return 0;
1637
1638     /*
1639      * TODO(TLS1.3): These APIs cannot set TLSv1.3 sig algs so we just test it
1640      * for TLSv1.2 for now until we add a new API.
1641      */
1642     SSL_CTX_set_max_proto_version(cctx, TLS1_2_VERSION);
1643
1644     if (testctx) {
1645         int ret;
1646
1647         if (curr->list != NULL)
1648             ret = SSL_CTX_set1_sigalgs(cctx, curr->list, curr->listlen);
1649         else
1650             ret = SSL_CTX_set1_sigalgs_list(cctx, curr->liststr);
1651
1652         if (!ret) {
1653             if (curr->valid)
1654                 TEST_info("Failure setting sigalgs in SSL_CTX (%d)\n", idx);
1655             else
1656                 testresult = 1;
1657             goto end;
1658         }
1659         if (!curr->valid) {
1660             TEST_info("Not-failed setting sigalgs in SSL_CTX (%d)\n", idx);
1661             goto end;
1662         }
1663     }
1664
1665     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
1666                                       &clientssl, NULL, NULL)))
1667         goto end;
1668
1669     if (!testctx) {
1670         int ret;
1671
1672         if (curr->list != NULL)
1673             ret = SSL_set1_sigalgs(clientssl, curr->list, curr->listlen);
1674         else
1675             ret = SSL_set1_sigalgs_list(clientssl, curr->liststr);
1676         if (!ret) {
1677             if (curr->valid)
1678                 TEST_info("Failure setting sigalgs in SSL (%d)\n", idx);
1679             else
1680                 testresult = 1;
1681             goto end;
1682         }
1683         if (!curr->valid)
1684             goto end;
1685     }
1686
1687     if (!TEST_int_eq(create_ssl_connection(serverssl, clientssl,
1688                                            SSL_ERROR_NONE),
1689                 curr->connsuccess))
1690         goto end;
1691
1692     testresult = 1;
1693
1694  end:
1695     SSL_free(serverssl);
1696     SSL_free(clientssl);
1697     SSL_CTX_free(sctx);
1698     SSL_CTX_free(cctx);
1699
1700     return testresult;
1701 }
1702 #endif
1703
1704 #ifndef OPENSSL_NO_TLS1_3
1705
1706 static SSL_SESSION *clientpsk = NULL;
1707 static SSL_SESSION *serverpsk = NULL;
1708 static const char *pskid = "Identity";
1709 static const char *srvid;
1710
1711 static int use_session_cb_cnt = 0;
1712 static int find_session_cb_cnt = 0;
1713 static int psk_client_cb_cnt = 0;
1714 static int psk_server_cb_cnt = 0;
1715
1716 static int use_session_cb(SSL *ssl, const EVP_MD *md, const unsigned char **id,
1717                           size_t *idlen, SSL_SESSION **sess)
1718 {
1719     switch (++use_session_cb_cnt) {
1720     case 1:
1721         /* The first call should always have a NULL md */
1722         if (md != NULL)
1723             return 0;
1724         break;
1725
1726     case 2:
1727         /* The second call should always have an md */
1728         if (md == NULL)
1729             return 0;
1730         break;
1731
1732     default:
1733         /* We should only be called a maximum of twice */
1734         return 0;
1735     }
1736
1737     if (clientpsk != NULL)
1738         SSL_SESSION_up_ref(clientpsk);
1739
1740     *sess = clientpsk;
1741     *id = (const unsigned char *)pskid;
1742     *idlen = strlen(pskid);
1743
1744     return 1;
1745 }
1746
1747 #ifndef OPENSSL_NO_PSK
1748 static unsigned int psk_client_cb(SSL *ssl, const char *hint, char *id,
1749                                   unsigned int max_id_len,
1750                                   unsigned char *psk,
1751                                   unsigned int max_psk_len)
1752 {
1753     unsigned int psklen = 0;
1754
1755     psk_client_cb_cnt++;
1756
1757     if (strlen(pskid) + 1 > max_id_len)
1758         return 0;
1759
1760     /* We should only ever be called a maximum of twice per connection */
1761     if (psk_client_cb_cnt > 2)
1762         return 0;
1763
1764     if (clientpsk == NULL)
1765         return 0;
1766
1767     /* We'll reuse the PSK we set up for TLSv1.3 */
1768     if (SSL_SESSION_get_master_key(clientpsk, NULL, 0) > max_psk_len)
1769         return 0;
1770     psklen = SSL_SESSION_get_master_key(clientpsk, psk, max_psk_len);
1771     strncpy(id, pskid, max_id_len);
1772
1773     return psklen;
1774 }
1775 #endif /* OPENSSL_NO_PSK */
1776
1777 static int find_session_cb(SSL *ssl, const unsigned char *identity,
1778                            size_t identity_len, SSL_SESSION **sess)
1779 {
1780     find_session_cb_cnt++;
1781
1782     /* We should only ever be called a maximum of twice per connection */
1783     if (find_session_cb_cnt > 2)
1784         return 0;
1785
1786     if (serverpsk == NULL)
1787         return 0;
1788
1789     /* Identity should match that set by the client */
1790     if (strlen(srvid) != identity_len
1791             || strncmp(srvid, (const char *)identity, identity_len) != 0) {
1792         /* No PSK found, continue but without a PSK */
1793         *sess = NULL;
1794         return 1;
1795     }
1796
1797     SSL_SESSION_up_ref(serverpsk);
1798     *sess = serverpsk;
1799
1800     return 1;
1801 }
1802
1803 #ifndef OPENSSL_NO_PSK
1804 static unsigned int psk_server_cb(SSL *ssl, const char *identity,
1805                                   unsigned char *psk, unsigned int max_psk_len)
1806 {
1807     unsigned int psklen = 0;
1808
1809     psk_server_cb_cnt++;
1810
1811     /* We should only ever be called a maximum of twice per connection */
1812     if (find_session_cb_cnt > 2)
1813         return 0;
1814
1815     if (serverpsk == NULL)
1816         return 0;
1817
1818     /* Identity should match that set by the client */
1819     if (strcmp(srvid, identity) != 0) {
1820         return 0;
1821     }
1822
1823     /* We'll reuse the PSK we set up for TLSv1.3 */
1824     if (SSL_SESSION_get_master_key(serverpsk, NULL, 0) > max_psk_len)
1825         return 0;
1826     psklen = SSL_SESSION_get_master_key(serverpsk, psk, max_psk_len);
1827
1828     return psklen;
1829 }
1830 #endif /* OPENSSL_NO_PSK */
1831
1832 #define MSG1    "Hello"
1833 #define MSG2    "World."
1834 #define MSG3    "This"
1835 #define MSG4    "is"
1836 #define MSG5    "a"
1837 #define MSG6    "test"
1838 #define MSG7    "message."
1839
1840 #define TLS13_AES_256_GCM_SHA384_BYTES  ((const unsigned char *)"\x13\x02")
1841 #define TLS13_AES_128_GCM_SHA256_BYTES  ((const unsigned char *)"\x13\x01")
1842
1843 /*
1844  * Helper method to setup objects for early data test. Caller frees objects on
1845  * error.
1846  */
1847 static int setupearly_data_test(SSL_CTX **cctx, SSL_CTX **sctx, SSL **clientssl,
1848                                 SSL **serverssl, SSL_SESSION **sess, int idx)
1849 {
1850     if (*sctx == NULL
1851             && !TEST_true(create_ssl_ctx_pair(TLS_server_method(),
1852                                               TLS_client_method(),
1853                                               TLS1_VERSION, TLS_MAX_VERSION,
1854                                               sctx, cctx, cert, privkey)))
1855         return 0;
1856
1857     if (!TEST_true(SSL_CTX_set_max_early_data(*sctx, SSL3_RT_MAX_PLAIN_LENGTH)))
1858         return 0;
1859
1860     if (idx == 1) {
1861         /* When idx == 1 we repeat the tests with read_ahead set */
1862         SSL_CTX_set_read_ahead(*cctx, 1);
1863         SSL_CTX_set_read_ahead(*sctx, 1);
1864     } else if (idx == 2) {
1865         /* When idx == 2 we are doing early_data with a PSK. Set up callbacks */
1866         SSL_CTX_set_psk_use_session_callback(*cctx, use_session_cb);
1867         SSL_CTX_set_psk_find_session_callback(*sctx, find_session_cb);
1868         use_session_cb_cnt = 0;
1869         find_session_cb_cnt = 0;
1870         srvid = pskid;
1871     }
1872
1873     if (!TEST_true(create_ssl_objects(*sctx, *cctx, serverssl, clientssl,
1874                                       NULL, NULL)))
1875         return 0;
1876
1877     /*
1878      * For one of the run throughs (doesn't matter which one), we'll try sending
1879      * some SNI data in the initial ClientHello. This will be ignored (because
1880      * there is no SNI cb set up by the server), so it should not impact
1881      * early_data.
1882      */
1883     if (idx == 1
1884             && !TEST_true(SSL_set_tlsext_host_name(*clientssl, "localhost")))
1885         return 0;
1886
1887     if (idx == 2) {
1888         /* Create the PSK */
1889         const SSL_CIPHER *cipher = NULL;
1890         const unsigned char key[] = {
1891             0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a,
1892             0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
1893             0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
1894             0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b,
1895             0x2c, 0x2d, 0x2e, 0x2f
1896         };
1897
1898         cipher = SSL_CIPHER_find(*clientssl, TLS13_AES_256_GCM_SHA384_BYTES);
1899         clientpsk = SSL_SESSION_new();
1900         if (!TEST_ptr(clientpsk)
1901                 || !TEST_ptr(cipher)
1902                 || !TEST_true(SSL_SESSION_set1_master_key(clientpsk, key,
1903                                                           sizeof(key)))
1904                 || !TEST_true(SSL_SESSION_set_cipher(clientpsk, cipher))
1905                 || !TEST_true(
1906                         SSL_SESSION_set_protocol_version(clientpsk,
1907                                                          TLS1_3_VERSION))
1908                    /*
1909                     * We just choose an arbitrary value for max_early_data which
1910                     * should be big enough for testing purposes.
1911                     */
1912                 || !TEST_true(SSL_SESSION_set_max_early_data(clientpsk,
1913                                                              0x100))
1914                 || !TEST_true(SSL_SESSION_up_ref(clientpsk))) {
1915             SSL_SESSION_free(clientpsk);
1916             clientpsk = NULL;
1917             return 0;
1918         }
1919         serverpsk = clientpsk;
1920
1921         if (sess != NULL) {
1922             if (!TEST_true(SSL_SESSION_up_ref(clientpsk))) {
1923                 SSL_SESSION_free(clientpsk);
1924                 SSL_SESSION_free(serverpsk);
1925                 clientpsk = serverpsk = NULL;
1926                 return 0;
1927             }
1928             *sess = clientpsk;
1929         }
1930         return 1;
1931     }
1932
1933     if (sess == NULL)
1934         return 1;
1935
1936     if (!TEST_true(create_ssl_connection(*serverssl, *clientssl,
1937                                          SSL_ERROR_NONE)))
1938         return 0;
1939
1940     *sess = SSL_get1_session(*clientssl);
1941     SSL_shutdown(*clientssl);
1942     SSL_shutdown(*serverssl);
1943     SSL_free(*serverssl);
1944     SSL_free(*clientssl);
1945     *serverssl = *clientssl = NULL;
1946
1947     if (!TEST_true(create_ssl_objects(*sctx, *cctx, serverssl,
1948                                       clientssl, NULL, NULL))
1949             || !TEST_true(SSL_set_session(*clientssl, *sess)))
1950         return 0;
1951
1952     return 1;
1953 }
1954
1955 static int test_early_data_read_write(int idx)
1956 {
1957     SSL_CTX *cctx = NULL, *sctx = NULL;
1958     SSL *clientssl = NULL, *serverssl = NULL;
1959     int testresult = 0;
1960     SSL_SESSION *sess = NULL;
1961     unsigned char buf[20], data[1024];
1962     size_t readbytes, written, eoedlen, rawread, rawwritten;
1963     BIO *rbio;
1964
1965     if (!TEST_true(setupearly_data_test(&cctx, &sctx, &clientssl,
1966                                         &serverssl, &sess, idx)))
1967         goto end;
1968
1969     /* Write and read some early data */
1970     if (!TEST_true(SSL_write_early_data(clientssl, MSG1, strlen(MSG1),
1971                                         &written))
1972             || !TEST_size_t_eq(written, strlen(MSG1))
1973             || !TEST_int_eq(SSL_read_early_data(serverssl, buf,
1974                                                 sizeof(buf), &readbytes),
1975                             SSL_READ_EARLY_DATA_SUCCESS)
1976             || !TEST_mem_eq(MSG1, readbytes, buf, strlen(MSG1))
1977             || !TEST_int_eq(SSL_get_early_data_status(serverssl),
1978                             SSL_EARLY_DATA_ACCEPTED))
1979         goto end;
1980
1981     /*
1982      * Server should be able to write data, and client should be able to
1983      * read it.
1984      */
1985     if (!TEST_true(SSL_write_early_data(serverssl, MSG2, strlen(MSG2),
1986                                         &written))
1987             || !TEST_size_t_eq(written, strlen(MSG2))
1988             || !TEST_true(SSL_read_ex(clientssl, buf, sizeof(buf), &readbytes))
1989             || !TEST_mem_eq(buf, readbytes, MSG2, strlen(MSG2)))
1990         goto end;
1991
1992     /* Even after reading normal data, client should be able write early data */
1993     if (!TEST_true(SSL_write_early_data(clientssl, MSG3, strlen(MSG3),
1994                                         &written))
1995             || !TEST_size_t_eq(written, strlen(MSG3)))
1996         goto end;
1997
1998     /* Server should still be able read early data after writing data */
1999     if (!TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
2000                                          &readbytes),
2001                      SSL_READ_EARLY_DATA_SUCCESS)
2002             || !TEST_mem_eq(buf, readbytes, MSG3, strlen(MSG3)))
2003         goto end;
2004
2005     /* Write more data from server and read it from client */
2006     if (!TEST_true(SSL_write_early_data(serverssl, MSG4, strlen(MSG4),
2007                                         &written))
2008             || !TEST_size_t_eq(written, strlen(MSG4))
2009             || !TEST_true(SSL_read_ex(clientssl, buf, sizeof(buf), &readbytes))
2010             || !TEST_mem_eq(buf, readbytes, MSG4, strlen(MSG4)))
2011         goto end;
2012
2013     /*
2014      * If client writes normal data it should mean writing early data is no
2015      * longer possible.
2016      */
2017     if (!TEST_true(SSL_write_ex(clientssl, MSG5, strlen(MSG5), &written))
2018             || !TEST_size_t_eq(written, strlen(MSG5))
2019             || !TEST_int_eq(SSL_get_early_data_status(clientssl),
2020                             SSL_EARLY_DATA_ACCEPTED))
2021         goto end;
2022
2023     /*
2024      * At this point the client has written EndOfEarlyData, ClientFinished and
2025      * normal (fully protected) data. We are going to cause a delay between the
2026      * arrival of EndOfEarlyData and ClientFinished. We read out all the data
2027      * in the read BIO, and then just put back the EndOfEarlyData message.
2028      */
2029     rbio = SSL_get_rbio(serverssl);
2030     if (!TEST_true(BIO_read_ex(rbio, data, sizeof(data), &rawread))
2031             || !TEST_size_t_lt(rawread, sizeof(data))
2032             || !TEST_size_t_gt(rawread, SSL3_RT_HEADER_LENGTH))
2033         goto end;
2034
2035     /* Record length is in the 4th and 5th bytes of the record header */
2036     eoedlen = SSL3_RT_HEADER_LENGTH + (data[3] << 8 | data[4]);
2037     if (!TEST_true(BIO_write_ex(rbio, data, eoedlen, &rawwritten))
2038             || !TEST_size_t_eq(rawwritten, eoedlen))
2039         goto end;
2040
2041     /* Server should be told that there is no more early data */
2042     if (!TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
2043                                          &readbytes),
2044                      SSL_READ_EARLY_DATA_FINISH)
2045             || !TEST_size_t_eq(readbytes, 0))
2046         goto end;
2047
2048     /*
2049      * Server has not finished init yet, so should still be able to write early
2050      * data.
2051      */
2052     if (!TEST_true(SSL_write_early_data(serverssl, MSG6, strlen(MSG6),
2053                                         &written))
2054             || !TEST_size_t_eq(written, strlen(MSG6)))
2055         goto end;
2056
2057     /* Push the ClientFinished and the normal data back into the server rbio */
2058     if (!TEST_true(BIO_write_ex(rbio, data + eoedlen, rawread - eoedlen,
2059                                 &rawwritten))
2060             || !TEST_size_t_eq(rawwritten, rawread - eoedlen))
2061         goto end;
2062
2063     /* Server should be able to read normal data */
2064     if (!TEST_true(SSL_read_ex(serverssl, buf, sizeof(buf), &readbytes))
2065             || !TEST_size_t_eq(readbytes, strlen(MSG5)))
2066         goto end;
2067
2068     /* Client and server should not be able to write/read early data now */
2069     if (!TEST_false(SSL_write_early_data(clientssl, MSG6, strlen(MSG6),
2070                                          &written)))
2071         goto end;
2072     ERR_clear_error();
2073     if (!TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
2074                                          &readbytes),
2075                      SSL_READ_EARLY_DATA_ERROR))
2076         goto end;
2077     ERR_clear_error();
2078
2079     /* Client should be able to read the data sent by the server */
2080     if (!TEST_true(SSL_read_ex(clientssl, buf, sizeof(buf), &readbytes))
2081             || !TEST_mem_eq(buf, readbytes, MSG6, strlen(MSG6)))
2082         goto end;
2083
2084     /*
2085      * Make sure we process the two NewSessionTickets. These arrive
2086      * post-handshake. We attempt reads which we do not expect to return any
2087      * data.
2088      */
2089     if (!TEST_false(SSL_read_ex(clientssl, buf, sizeof(buf), &readbytes))
2090             || !TEST_false(SSL_read_ex(clientssl, buf, sizeof(buf),
2091                            &readbytes)))
2092         goto end;
2093
2094     /* Server should be able to write normal data */
2095     if (!TEST_true(SSL_write_ex(serverssl, MSG7, strlen(MSG7), &written))
2096             || !TEST_size_t_eq(written, strlen(MSG7))
2097             || !TEST_true(SSL_read_ex(clientssl, buf, sizeof(buf), &readbytes))
2098             || !TEST_mem_eq(buf, readbytes, MSG7, strlen(MSG7)))
2099         goto end;
2100
2101     SSL_SESSION_free(sess);
2102     sess = SSL_get1_session(clientssl);
2103     use_session_cb_cnt = 0;
2104     find_session_cb_cnt = 0;
2105
2106     SSL_shutdown(clientssl);
2107     SSL_shutdown(serverssl);
2108     SSL_free(serverssl);
2109     SSL_free(clientssl);
2110     serverssl = clientssl = NULL;
2111     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
2112                                       &clientssl, NULL, NULL))
2113             || !TEST_true(SSL_set_session(clientssl, sess)))
2114         goto end;
2115
2116     /* Write and read some early data */
2117     if (!TEST_true(SSL_write_early_data(clientssl, MSG1, strlen(MSG1),
2118                                         &written))
2119             || !TEST_size_t_eq(written, strlen(MSG1))
2120             || !TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
2121                                                 &readbytes),
2122                             SSL_READ_EARLY_DATA_SUCCESS)
2123             || !TEST_mem_eq(buf, readbytes, MSG1, strlen(MSG1)))
2124         goto end;
2125
2126     if (!TEST_int_gt(SSL_connect(clientssl), 0)
2127             || !TEST_int_gt(SSL_accept(serverssl), 0))
2128         goto end;
2129
2130     /* Client and server should not be able to write/read early data now */
2131     if (!TEST_false(SSL_write_early_data(clientssl, MSG6, strlen(MSG6),
2132                                          &written)))
2133         goto end;
2134     ERR_clear_error();
2135     if (!TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
2136                                          &readbytes),
2137                      SSL_READ_EARLY_DATA_ERROR))
2138         goto end;
2139     ERR_clear_error();
2140
2141     /* Client and server should be able to write/read normal data */
2142     if (!TEST_true(SSL_write_ex(clientssl, MSG5, strlen(MSG5), &written))
2143             || !TEST_size_t_eq(written, strlen(MSG5))
2144             || !TEST_true(SSL_read_ex(serverssl, buf, sizeof(buf), &readbytes))
2145             || !TEST_size_t_eq(readbytes, strlen(MSG5)))
2146         goto end;
2147
2148     testresult = 1;
2149
2150  end:
2151     SSL_SESSION_free(sess);
2152     SSL_SESSION_free(clientpsk);
2153     SSL_SESSION_free(serverpsk);
2154     clientpsk = serverpsk = NULL;
2155     SSL_free(serverssl);
2156     SSL_free(clientssl);
2157     SSL_CTX_free(sctx);
2158     SSL_CTX_free(cctx);
2159     return testresult;
2160 }
2161
2162 static int allow_ed_cb_called = 0;
2163
2164 static int allow_early_data_cb(SSL *s, void *arg)
2165 {
2166     int *usecb = (int *)arg;
2167
2168     allow_ed_cb_called++;
2169
2170     if (*usecb == 1)
2171         return 0;
2172
2173     return 1;
2174 }
2175
2176 /*
2177  * idx == 0: Standard early_data setup
2178  * idx == 1: early_data setup using read_ahead
2179  * usecb == 0: Don't use a custom early data callback
2180  * usecb == 1: Use a custom early data callback and reject the early data
2181  * usecb == 2: Use a custom early data callback and accept the early data
2182  */
2183 static int test_early_data_replay_int(int idx, int usecb)
2184 {
2185     SSL_CTX *cctx = NULL, *sctx = NULL;
2186     SSL *clientssl = NULL, *serverssl = NULL;
2187     int testresult = 0;
2188     SSL_SESSION *sess = NULL;
2189     size_t readbytes, written;
2190     unsigned char buf[20];
2191
2192     allow_ed_cb_called = 0;
2193
2194     if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
2195                                        TLS1_VERSION, TLS_MAX_VERSION, &sctx,
2196                                        &cctx, cert, privkey)))
2197         return 0;
2198
2199     if (usecb > 0) {
2200         SSL_CTX_set_options(sctx, SSL_OP_NO_ANTI_REPLAY);
2201         SSL_CTX_set_allow_early_data_cb(sctx, allow_early_data_cb, &usecb);
2202     }
2203
2204     if (!TEST_true(setupearly_data_test(&cctx, &sctx, &clientssl,
2205                                         &serverssl, &sess, idx)))
2206         goto end;
2207
2208     /*
2209      * The server is configured to accept early data. Create a connection to
2210      * "use up" the ticket
2211      */
2212     if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))
2213             || !TEST_true(SSL_session_reused(clientssl)))
2214         goto end;
2215
2216     SSL_shutdown(clientssl);
2217     SSL_shutdown(serverssl);
2218     SSL_free(serverssl);
2219     SSL_free(clientssl);
2220     serverssl = clientssl = NULL;
2221
2222     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
2223                                       &clientssl, NULL, NULL))
2224             || !TEST_true(SSL_set_session(clientssl, sess)))
2225         goto end;
2226
2227     /* Write and read some early data */
2228     if (!TEST_true(SSL_write_early_data(clientssl, MSG1, strlen(MSG1),
2229                                         &written))
2230             || !TEST_size_t_eq(written, strlen(MSG1)))
2231         goto end;
2232
2233     if (usecb <= 1) {
2234         if (!TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
2235                                              &readbytes),
2236                          SSL_READ_EARLY_DATA_FINISH)
2237                    /*
2238                     * The ticket was reused, so the we should have rejected the
2239                     * early data
2240                     */
2241                 || !TEST_int_eq(SSL_get_early_data_status(serverssl),
2242                                 SSL_EARLY_DATA_REJECTED))
2243             goto end;
2244     } else {
2245         /* In this case the callback decides to accept the early data */
2246         if (!TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
2247                                              &readbytes),
2248                          SSL_READ_EARLY_DATA_SUCCESS)
2249                 || !TEST_mem_eq(MSG1, strlen(MSG1), buf, readbytes)
2250                    /*
2251                     * Server will have sent its flight so client can now send
2252                     * end of early data and complete its half of the handshake
2253                     */
2254                 || !TEST_int_gt(SSL_connect(clientssl), 0)
2255                 || !TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
2256                                              &readbytes),
2257                                 SSL_READ_EARLY_DATA_FINISH)
2258                 || !TEST_int_eq(SSL_get_early_data_status(serverssl),
2259                                 SSL_EARLY_DATA_ACCEPTED))
2260             goto end;
2261     }
2262
2263     /* Complete the connection */
2264     if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))
2265             || !TEST_int_eq(SSL_session_reused(clientssl), (usecb > 0) ? 1 : 0)
2266             || !TEST_int_eq(allow_ed_cb_called, usecb > 0 ? 1 : 0))
2267         goto end;
2268
2269     testresult = 1;
2270
2271  end:
2272     SSL_SESSION_free(sess);
2273     SSL_SESSION_free(clientpsk);
2274     SSL_SESSION_free(serverpsk);
2275     clientpsk = serverpsk = NULL;
2276     SSL_free(serverssl);
2277     SSL_free(clientssl);
2278     SSL_CTX_free(sctx);
2279     SSL_CTX_free(cctx);
2280     return testresult;
2281 }
2282
2283 static int test_early_data_replay(int idx)
2284 {
2285     int ret;
2286
2287     ret = test_early_data_replay_int(idx, 0);
2288     ret &= test_early_data_replay_int(idx, 1);
2289     ret &= test_early_data_replay_int(idx, 2);
2290
2291     return ret;
2292 }
2293
2294 /*
2295  * Helper function to test that a server attempting to read early data can
2296  * handle a connection from a client where the early data should be skipped.
2297  */
2298 static int early_data_skip_helper(int hrr, int idx)
2299 {
2300     SSL_CTX *cctx = NULL, *sctx = NULL;
2301     SSL *clientssl = NULL, *serverssl = NULL;
2302     int testresult = 0;
2303     SSL_SESSION *sess = NULL;
2304     unsigned char buf[20];
2305     size_t readbytes, written;
2306
2307     if (!TEST_true(setupearly_data_test(&cctx, &sctx, &clientssl,
2308                                         &serverssl, &sess, idx)))
2309         goto end;
2310
2311     if (hrr) {
2312         /* Force an HRR to occur */
2313         if (!TEST_true(SSL_set1_groups_list(serverssl, "P-256")))
2314             goto end;
2315     } else if (idx == 2) {
2316         /*
2317          * We force early_data rejection by ensuring the PSK identity is
2318          * unrecognised
2319          */
2320         srvid = "Dummy Identity";
2321     } else {
2322         /*
2323          * Deliberately corrupt the creation time. We take 20 seconds off the
2324          * time. It could be any value as long as it is not within tolerance.
2325          * This should mean the ticket is rejected.
2326          */
2327         if (!TEST_true(SSL_SESSION_set_time(sess, (long)(time(NULL) - 20))))
2328             goto end;
2329     }
2330
2331     /* Write some early data */
2332     if (!TEST_true(SSL_write_early_data(clientssl, MSG1, strlen(MSG1),
2333                                         &written))
2334             || !TEST_size_t_eq(written, strlen(MSG1)))
2335         goto end;
2336
2337     /* Server should reject the early data and skip over it */
2338     if (!TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
2339                                          &readbytes),
2340                      SSL_READ_EARLY_DATA_FINISH)
2341             || !TEST_size_t_eq(readbytes, 0)
2342             || !TEST_int_eq(SSL_get_early_data_status(serverssl),
2343                             SSL_EARLY_DATA_REJECTED))
2344         goto end;
2345
2346     if (hrr) {
2347         /*
2348          * Finish off the handshake. We perform the same writes and reads as
2349          * further down but we expect them to fail due to the incomplete
2350          * handshake.
2351          */
2352         if (!TEST_false(SSL_write_ex(clientssl, MSG2, strlen(MSG2), &written))
2353                 || !TEST_false(SSL_read_ex(serverssl, buf, sizeof(buf),
2354                                &readbytes)))
2355             goto end;
2356     }
2357
2358     /* Should be able to send normal data despite rejection of early data */
2359     if (!TEST_true(SSL_write_ex(clientssl, MSG2, strlen(MSG2), &written))
2360             || !TEST_size_t_eq(written, strlen(MSG2))
2361             || !TEST_int_eq(SSL_get_early_data_status(clientssl),
2362                             SSL_EARLY_DATA_REJECTED)
2363             || !TEST_true(SSL_read_ex(serverssl, buf, sizeof(buf), &readbytes))
2364             || !TEST_mem_eq(buf, readbytes, MSG2, strlen(MSG2)))
2365         goto end;
2366
2367     testresult = 1;
2368
2369  end:
2370     SSL_SESSION_free(clientpsk);
2371     SSL_SESSION_free(serverpsk);
2372     clientpsk = serverpsk = NULL;
2373     SSL_SESSION_free(sess);
2374     SSL_free(serverssl);
2375     SSL_free(clientssl);
2376     SSL_CTX_free(sctx);
2377     SSL_CTX_free(cctx);
2378     return testresult;
2379 }
2380
2381 /*
2382  * Test that a server attempting to read early data can handle a connection
2383  * from a client where the early data is not acceptable.
2384  */
2385 static int test_early_data_skip(int idx)
2386 {
2387     return early_data_skip_helper(0, idx);
2388 }
2389
2390 /*
2391  * Test that a server attempting to read early data can handle a connection
2392  * from a client where an HRR occurs.
2393  */
2394 static int test_early_data_skip_hrr(int idx)
2395 {
2396     return early_data_skip_helper(1, idx);
2397 }
2398
2399 /*
2400  * Test that a server attempting to read early data can handle a connection
2401  * from a client that doesn't send any.
2402  */
2403 static int test_early_data_not_sent(int idx)
2404 {
2405     SSL_CTX *cctx = NULL, *sctx = NULL;
2406     SSL *clientssl = NULL, *serverssl = NULL;
2407     int testresult = 0;
2408     SSL_SESSION *sess = NULL;
2409     unsigned char buf[20];
2410     size_t readbytes, written;
2411
2412     if (!TEST_true(setupearly_data_test(&cctx, &sctx, &clientssl,
2413                                         &serverssl, &sess, idx)))
2414         goto end;
2415
2416     /* Write some data - should block due to handshake with server */
2417     SSL_set_connect_state(clientssl);
2418     if (!TEST_false(SSL_write_ex(clientssl, MSG1, strlen(MSG1), &written)))
2419         goto end;
2420
2421     /* Server should detect that early data has not been sent */
2422     if (!TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
2423                                          &readbytes),
2424                      SSL_READ_EARLY_DATA_FINISH)
2425             || !TEST_size_t_eq(readbytes, 0)
2426             || !TEST_int_eq(SSL_get_early_data_status(serverssl),
2427                             SSL_EARLY_DATA_NOT_SENT)
2428             || !TEST_int_eq(SSL_get_early_data_status(clientssl),
2429                             SSL_EARLY_DATA_NOT_SENT))
2430         goto end;
2431
2432     /* Continue writing the message we started earlier */
2433     if (!TEST_true(SSL_write_ex(clientssl, MSG1, strlen(MSG1), &written))
2434             || !TEST_size_t_eq(written, strlen(MSG1))
2435             || !TEST_true(SSL_read_ex(serverssl, buf, sizeof(buf), &readbytes))
2436             || !TEST_mem_eq(buf, readbytes, MSG1, strlen(MSG1))
2437             || !SSL_write_ex(serverssl, MSG2, strlen(MSG2), &written)
2438             || !TEST_size_t_eq(written, strlen(MSG2)))
2439         goto end;
2440
2441     if (!TEST_true(SSL_read_ex(clientssl, buf, sizeof(buf), &readbytes))
2442             || !TEST_mem_eq(buf, readbytes, MSG2, strlen(MSG2)))
2443         goto end;
2444
2445     testresult = 1;
2446
2447  end:
2448     SSL_SESSION_free(sess);
2449     SSL_SESSION_free(clientpsk);
2450     SSL_SESSION_free(serverpsk);
2451     clientpsk = serverpsk = NULL;
2452     SSL_free(serverssl);
2453     SSL_free(clientssl);
2454     SSL_CTX_free(sctx);
2455     SSL_CTX_free(cctx);
2456     return testresult;
2457 }
2458
2459 static int hostname_cb(SSL *s, int *al, void *arg)
2460 {
2461     const char *hostname = SSL_get_servername(s, TLSEXT_NAMETYPE_host_name);
2462
2463     if (hostname != NULL && strcmp(hostname, "goodhost") == 0)
2464         return  SSL_TLSEXT_ERR_OK;
2465
2466     return SSL_TLSEXT_ERR_NOACK;
2467 }
2468
2469 static const char *servalpn;
2470
2471 static int alpn_select_cb(SSL *ssl, const unsigned char **out,
2472                           unsigned char *outlen, const unsigned char *in,
2473                           unsigned int inlen, void *arg)
2474 {
2475     unsigned int protlen = 0;
2476     const unsigned char *prot;
2477
2478     for (prot = in; prot < in + inlen; prot += protlen) {
2479         protlen = *prot++;
2480         if (in + inlen < prot + protlen)
2481             return SSL_TLSEXT_ERR_NOACK;
2482
2483         if (protlen == strlen(servalpn)
2484                 && memcmp(prot, servalpn, protlen) == 0) {
2485             *out = prot;
2486             *outlen = protlen;
2487             return SSL_TLSEXT_ERR_OK;
2488         }
2489     }
2490
2491     return SSL_TLSEXT_ERR_NOACK;
2492 }
2493
2494 /* Test that a PSK can be used to send early_data */
2495 static int test_early_data_psk(int idx)
2496 {
2497     SSL_CTX *cctx = NULL, *sctx = NULL;
2498     SSL *clientssl = NULL, *serverssl = NULL;
2499     int testresult = 0;
2500     SSL_SESSION *sess = NULL;
2501     unsigned char alpnlist[] = {
2502         0x08, 'g', 'o', 'o', 'd', 'a', 'l', 'p', 'n', 0x07, 'b', 'a', 'd', 'a',
2503         'l', 'p', 'n'
2504     };
2505 #define GOODALPNLEN     9
2506 #define BADALPNLEN      8
2507 #define GOODALPN        (alpnlist)
2508 #define BADALPN         (alpnlist + GOODALPNLEN)
2509     int err = 0;
2510     unsigned char buf[20];
2511     size_t readbytes, written;
2512     int readearlyres = SSL_READ_EARLY_DATA_SUCCESS, connectres = 1;
2513     int edstatus = SSL_EARLY_DATA_ACCEPTED;
2514
2515     /* We always set this up with a final parameter of "2" for PSK */
2516     if (!TEST_true(setupearly_data_test(&cctx, &sctx, &clientssl,
2517                                         &serverssl, &sess, 2)))
2518         goto end;
2519
2520     servalpn = "goodalpn";
2521
2522     /*
2523      * Note: There is no test for inconsistent SNI with late client detection.
2524      * This is because servers do not acknowledge SNI even if they are using
2525      * it in a resumption handshake - so it is not actually possible for a
2526      * client to detect a problem.
2527      */
2528     switch (idx) {
2529     case 0:
2530         /* Set inconsistent SNI (early client detection) */
2531         err = SSL_R_INCONSISTENT_EARLY_DATA_SNI;
2532         if (!TEST_true(SSL_SESSION_set1_hostname(sess, "goodhost"))
2533                 || !TEST_true(SSL_set_tlsext_host_name(clientssl, "badhost")))
2534             goto end;
2535         break;
2536
2537     case 1:
2538         /* Set inconsistent ALPN (early client detection) */
2539         err = SSL_R_INCONSISTENT_EARLY_DATA_ALPN;
2540         /* SSL_set_alpn_protos returns 0 for success and 1 for failure */
2541         if (!TEST_true(SSL_SESSION_set1_alpn_selected(sess, GOODALPN,
2542                                                       GOODALPNLEN))
2543                 || !TEST_false(SSL_set_alpn_protos(clientssl, BADALPN,
2544                                                    BADALPNLEN)))
2545             goto end;
2546         break;
2547
2548     case 2:
2549         /*
2550          * Set invalid protocol version. Technically this affects PSKs without
2551          * early_data too, but we test it here because it is similar to the
2552          * SNI/ALPN consistency tests.
2553          */
2554         err = SSL_R_BAD_PSK;
2555         if (!TEST_true(SSL_SESSION_set_protocol_version(sess, TLS1_2_VERSION)))
2556             goto end;
2557         break;
2558
2559     case 3:
2560         /*
2561          * Set inconsistent SNI (server detected). In this case the connection
2562          * will succeed but reject early_data.
2563          */
2564         SSL_SESSION_free(serverpsk);
2565         serverpsk = SSL_SESSION_dup(clientpsk);
2566         if (!TEST_ptr(serverpsk)
2567                 || !TEST_true(SSL_SESSION_set1_hostname(serverpsk, "badhost")))
2568             goto end;
2569         edstatus = SSL_EARLY_DATA_REJECTED;
2570         readearlyres = SSL_READ_EARLY_DATA_FINISH;
2571         /* Fall through */
2572     case 4:
2573         /* Set consistent SNI */
2574         if (!TEST_true(SSL_SESSION_set1_hostname(sess, "goodhost"))
2575                 || !TEST_true(SSL_set_tlsext_host_name(clientssl, "goodhost"))
2576                 || !TEST_true(SSL_CTX_set_tlsext_servername_callback(sctx,
2577                                 hostname_cb)))
2578             goto end;
2579         break;
2580
2581     case 5:
2582         /*
2583          * Set inconsistent ALPN (server detected). In this case the connection
2584          * will succeed but reject early_data.
2585          */
2586         servalpn = "badalpn";
2587         edstatus = SSL_EARLY_DATA_REJECTED;
2588         readearlyres = SSL_READ_EARLY_DATA_FINISH;
2589         /* Fall through */
2590     case 6:
2591         /*
2592          * Set consistent ALPN.
2593          * SSL_set_alpn_protos returns 0 for success and 1 for failure. It
2594          * accepts a list of protos (each one length prefixed).
2595          * SSL_set1_alpn_selected accepts a single protocol (not length
2596          * prefixed)
2597          */
2598         if (!TEST_true(SSL_SESSION_set1_alpn_selected(sess, GOODALPN + 1,
2599                                                       GOODALPNLEN - 1))
2600                 || !TEST_false(SSL_set_alpn_protos(clientssl, GOODALPN,
2601                                                    GOODALPNLEN)))
2602             goto end;
2603
2604         SSL_CTX_set_alpn_select_cb(sctx, alpn_select_cb, NULL);
2605         break;
2606
2607     case 7:
2608         /* Set inconsistent ALPN (late client detection) */
2609         SSL_SESSION_free(serverpsk);
2610         serverpsk = SSL_SESSION_dup(clientpsk);
2611         if (!TEST_ptr(serverpsk)
2612                 || !TEST_true(SSL_SESSION_set1_alpn_selected(clientpsk,
2613                                                              BADALPN + 1,
2614                                                              BADALPNLEN - 1))
2615                 || !TEST_true(SSL_SESSION_set1_alpn_selected(serverpsk,
2616                                                              GOODALPN + 1,
2617                                                              GOODALPNLEN - 1))
2618                 || !TEST_false(SSL_set_alpn_protos(clientssl, alpnlist,
2619                                                    sizeof(alpnlist))))
2620             goto end;
2621         SSL_CTX_set_alpn_select_cb(sctx, alpn_select_cb, NULL);
2622         edstatus = SSL_EARLY_DATA_ACCEPTED;
2623         readearlyres = SSL_READ_EARLY_DATA_SUCCESS;
2624         /* SSL_connect() call should fail */
2625         connectres = -1;
2626         break;
2627
2628     default:
2629         TEST_error("Bad test index");
2630         goto end;
2631     }
2632
2633     SSL_set_connect_state(clientssl);
2634     if (err != 0) {
2635         if (!TEST_false(SSL_write_early_data(clientssl, MSG1, strlen(MSG1),
2636                                             &written))
2637                 || !TEST_int_eq(SSL_get_error(clientssl, 0), SSL_ERROR_SSL)
2638                 || !TEST_int_eq(ERR_GET_REASON(ERR_get_error()), err))
2639             goto end;
2640     } else {
2641         if (!TEST_true(SSL_write_early_data(clientssl, MSG1, strlen(MSG1),
2642                                             &written)))
2643             goto end;
2644
2645         if (!TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
2646                                              &readbytes), readearlyres)
2647                 || (readearlyres == SSL_READ_EARLY_DATA_SUCCESS
2648                     && !TEST_mem_eq(buf, readbytes, MSG1, strlen(MSG1)))
2649                 || !TEST_int_eq(SSL_get_early_data_status(serverssl), edstatus)
2650                 || !TEST_int_eq(SSL_connect(clientssl), connectres))
2651             goto end;
2652     }
2653
2654     testresult = 1;
2655
2656  end:
2657     SSL_SESSION_free(sess);
2658     SSL_SESSION_free(clientpsk);
2659     SSL_SESSION_free(serverpsk);
2660     clientpsk = serverpsk = NULL;
2661     SSL_free(serverssl);
2662     SSL_free(clientssl);
2663     SSL_CTX_free(sctx);
2664     SSL_CTX_free(cctx);
2665     return testresult;
2666 }
2667
2668 /*
2669  * Test that a server that doesn't try to read early data can handle a
2670  * client sending some.
2671  */
2672 static int test_early_data_not_expected(int idx)
2673 {
2674     SSL_CTX *cctx = NULL, *sctx = NULL;
2675     SSL *clientssl = NULL, *serverssl = NULL;
2676     int testresult = 0;
2677     SSL_SESSION *sess = NULL;
2678     unsigned char buf[20];
2679     size_t readbytes, written;
2680
2681     if (!TEST_true(setupearly_data_test(&cctx, &sctx, &clientssl,
2682                                         &serverssl, &sess, idx)))
2683         goto end;
2684
2685     /* Write some early data */
2686     if (!TEST_true(SSL_write_early_data(clientssl, MSG1, strlen(MSG1),
2687                                         &written)))
2688         goto end;
2689
2690     /*
2691      * Server should skip over early data and then block waiting for client to
2692      * continue handshake
2693      */
2694     if (!TEST_int_le(SSL_accept(serverssl), 0)
2695      || !TEST_int_gt(SSL_connect(clientssl), 0)
2696      || !TEST_int_eq(SSL_get_early_data_status(serverssl),
2697                      SSL_EARLY_DATA_REJECTED)
2698      || !TEST_int_gt(SSL_accept(serverssl), 0)
2699      || !TEST_int_eq(SSL_get_early_data_status(clientssl),
2700                      SSL_EARLY_DATA_REJECTED))
2701         goto end;
2702
2703     /* Send some normal data from client to server */
2704     if (!TEST_true(SSL_write_ex(clientssl, MSG2, strlen(MSG2), &written))
2705             || !TEST_size_t_eq(written, strlen(MSG2)))
2706         goto end;
2707
2708     if (!TEST_true(SSL_read_ex(serverssl, buf, sizeof(buf), &readbytes))
2709             || !TEST_mem_eq(buf, readbytes, MSG2, strlen(MSG2)))
2710         goto end;
2711
2712     testresult = 1;
2713
2714  end:
2715     SSL_SESSION_free(sess);
2716     SSL_SESSION_free(clientpsk);
2717     SSL_SESSION_free(serverpsk);
2718     clientpsk = serverpsk = NULL;
2719     SSL_free(serverssl);
2720     SSL_free(clientssl);
2721     SSL_CTX_free(sctx);
2722     SSL_CTX_free(cctx);
2723     return testresult;
2724 }
2725
2726
2727 # ifndef OPENSSL_NO_TLS1_2
2728 /*
2729  * Test that a server attempting to read early data can handle a connection
2730  * from a TLSv1.2 client.
2731  */
2732 static int test_early_data_tls1_2(int idx)
2733 {
2734     SSL_CTX *cctx = NULL, *sctx = NULL;
2735     SSL *clientssl = NULL, *serverssl = NULL;
2736     int testresult = 0;
2737     unsigned char buf[20];
2738     size_t readbytes, written;
2739
2740     if (!TEST_true(setupearly_data_test(&cctx, &sctx, &clientssl,
2741                                         &serverssl, NULL, idx)))
2742         goto end;
2743
2744     /* Write some data - should block due to handshake with server */
2745     SSL_set_max_proto_version(clientssl, TLS1_2_VERSION);
2746     SSL_set_connect_state(clientssl);
2747     if (!TEST_false(SSL_write_ex(clientssl, MSG1, strlen(MSG1), &written)))
2748         goto end;
2749
2750     /*
2751      * Server should do TLSv1.2 handshake. First it will block waiting for more
2752      * messages from client after ServerDone. Then SSL_read_early_data should
2753      * finish and detect that early data has not been sent
2754      */
2755     if (!TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
2756                                          &readbytes),
2757                      SSL_READ_EARLY_DATA_ERROR))
2758         goto end;
2759
2760     /*
2761      * Continue writing the message we started earlier. Will still block waiting
2762      * for the CCS/Finished from server
2763      */
2764     if (!TEST_false(SSL_write_ex(clientssl, MSG1, strlen(MSG1), &written))
2765             || !TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
2766                                                 &readbytes),
2767                             SSL_READ_EARLY_DATA_FINISH)
2768             || !TEST_size_t_eq(readbytes, 0)
2769             || !TEST_int_eq(SSL_get_early_data_status(serverssl),
2770                             SSL_EARLY_DATA_NOT_SENT))
2771         goto end;
2772
2773     /* Continue writing the message we started earlier */
2774     if (!TEST_true(SSL_write_ex(clientssl, MSG1, strlen(MSG1), &written))
2775             || !TEST_size_t_eq(written, strlen(MSG1))
2776             || !TEST_int_eq(SSL_get_early_data_status(clientssl),
2777                             SSL_EARLY_DATA_NOT_SENT)
2778             || !TEST_true(SSL_read_ex(serverssl, buf, sizeof(buf), &readbytes))
2779             || !TEST_mem_eq(buf, readbytes, MSG1, strlen(MSG1))
2780             || !TEST_true(SSL_write_ex(serverssl, MSG2, strlen(MSG2), &written))
2781             || !TEST_size_t_eq(written, strlen(MSG2))
2782             || !SSL_read_ex(clientssl, buf, sizeof(buf), &readbytes)
2783             || !TEST_mem_eq(buf, readbytes, MSG2, strlen(MSG2)))
2784         goto end;
2785
2786     testresult = 1;
2787
2788  end:
2789     SSL_SESSION_free(clientpsk);
2790     SSL_SESSION_free(serverpsk);
2791     clientpsk = serverpsk = NULL;
2792     SSL_free(serverssl);
2793     SSL_free(clientssl);
2794     SSL_CTX_free(sctx);
2795     SSL_CTX_free(cctx);
2796
2797     return testresult;
2798 }
2799 # endif /* OPENSSL_NO_TLS1_2 */
2800
2801 /*
2802  * Test configuring the TLSv1.3 ciphersuites
2803  *
2804  * Test 0: Set a default ciphersuite in the SSL_CTX (no explicit cipher_list)
2805  * Test 1: Set a non-default ciphersuite in the SSL_CTX (no explicit cipher_list)
2806  * Test 2: Set a default ciphersuite in the SSL (no explicit cipher_list)
2807  * Test 3: Set a non-default ciphersuite in the SSL (no explicit cipher_list)
2808  * Test 4: Set a default ciphersuite in the SSL_CTX (SSL_CTX cipher_list)
2809  * Test 5: Set a non-default ciphersuite in the SSL_CTX (SSL_CTX cipher_list)
2810  * Test 6: Set a default ciphersuite in the SSL (SSL_CTX cipher_list)
2811  * Test 7: Set a non-default ciphersuite in the SSL (SSL_CTX cipher_list)
2812  * Test 8: Set a default ciphersuite in the SSL (SSL cipher_list)
2813  * Test 9: Set a non-default ciphersuite in the SSL (SSL cipher_list)
2814  */
2815 static int test_set_ciphersuite(int idx)
2816 {
2817     SSL_CTX *cctx = NULL, *sctx = NULL;
2818     SSL *clientssl = NULL, *serverssl = NULL;
2819     int testresult = 0;
2820
2821     if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
2822                                        TLS1_VERSION, TLS_MAX_VERSION,
2823                                        &sctx, &cctx, cert, privkey))
2824             || !TEST_true(SSL_CTX_set_ciphersuites(sctx,
2825                            "TLS_AES_128_GCM_SHA256:TLS_AES_128_CCM_SHA256")))
2826         goto end;
2827
2828     if (idx >=4 && idx <= 7) {
2829         /* SSL_CTX explicit cipher list */
2830         if (!TEST_true(SSL_CTX_set_cipher_list(cctx, "AES256-GCM-SHA384")))
2831             goto end;
2832     }
2833
2834     if (idx == 0 || idx == 4) {
2835         /* Default ciphersuite */
2836         if (!TEST_true(SSL_CTX_set_ciphersuites(cctx,
2837                                                 "TLS_AES_128_GCM_SHA256")))
2838             goto end;
2839     } else if (idx == 1 || idx == 5) {
2840         /* Non default ciphersuite */
2841         if (!TEST_true(SSL_CTX_set_ciphersuites(cctx,
2842                                                 "TLS_AES_128_CCM_SHA256")))
2843             goto end;
2844     }
2845
2846     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
2847                                           &clientssl, NULL, NULL)))
2848         goto end;
2849
2850     if (idx == 8 || idx == 9) {
2851         /* SSL explicit cipher list */
2852         if (!TEST_true(SSL_set_cipher_list(clientssl, "AES256-GCM-SHA384")))
2853             goto end;
2854     }
2855
2856     if (idx == 2 || idx == 6 || idx == 8) {
2857         /* Default ciphersuite */
2858         if (!TEST_true(SSL_set_ciphersuites(clientssl,
2859                                             "TLS_AES_128_GCM_SHA256")))
2860             goto end;
2861     } else if (idx == 3 || idx == 7 || idx == 9) {
2862         /* Non default ciphersuite */
2863         if (!TEST_true(SSL_set_ciphersuites(clientssl,
2864                                             "TLS_AES_128_CCM_SHA256")))
2865             goto end;
2866     }
2867
2868     if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE)))
2869         goto end;
2870
2871     testresult = 1;
2872
2873  end:
2874     SSL_free(serverssl);
2875     SSL_free(clientssl);
2876     SSL_CTX_free(sctx);
2877     SSL_CTX_free(cctx);
2878
2879     return testresult;
2880 }
2881
2882 static int test_ciphersuite_change(void)
2883 {
2884     SSL_CTX *cctx = NULL, *sctx = NULL;
2885     SSL *clientssl = NULL, *serverssl = NULL;
2886     SSL_SESSION *clntsess = NULL;
2887     int testresult = 0;
2888     const SSL_CIPHER *aes_128_gcm_sha256 = NULL;
2889
2890     /* Create a session based on SHA-256 */
2891     if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
2892                                        TLS1_VERSION, TLS_MAX_VERSION,
2893                                        &sctx, &cctx, cert, privkey))
2894             || !TEST_true(SSL_CTX_set_ciphersuites(cctx,
2895                                                    "TLS_AES_128_GCM_SHA256"))
2896             || !TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
2897                                           &clientssl, NULL, NULL))
2898             || !TEST_true(create_ssl_connection(serverssl, clientssl,
2899                                                 SSL_ERROR_NONE)))
2900         goto end;
2901
2902     clntsess = SSL_get1_session(clientssl);
2903     /* Save for later */
2904     aes_128_gcm_sha256 = SSL_SESSION_get0_cipher(clntsess);
2905     SSL_shutdown(clientssl);
2906     SSL_shutdown(serverssl);
2907     SSL_free(serverssl);
2908     SSL_free(clientssl);
2909     serverssl = clientssl = NULL;
2910
2911 # if !defined(OPENSSL_NO_CHACHA) && !defined(OPENSSL_NO_POLY1305)
2912     /* Check we can resume a session with a different SHA-256 ciphersuite */
2913     if (!TEST_true(SSL_CTX_set_ciphersuites(cctx,
2914                                             "TLS_CHACHA20_POLY1305_SHA256"))
2915             || !TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
2916                                              NULL, NULL))
2917             || !TEST_true(SSL_set_session(clientssl, clntsess))
2918             || !TEST_true(create_ssl_connection(serverssl, clientssl,
2919                                                 SSL_ERROR_NONE))
2920             || !TEST_true(SSL_session_reused(clientssl)))
2921         goto end;
2922
2923     SSL_SESSION_free(clntsess);
2924     clntsess = SSL_get1_session(clientssl);
2925     SSL_shutdown(clientssl);
2926     SSL_shutdown(serverssl);
2927     SSL_free(serverssl);
2928     SSL_free(clientssl);
2929     serverssl = clientssl = NULL;
2930 # endif
2931
2932     /*
2933      * Check attempting to resume a SHA-256 session with no SHA-256 ciphersuites
2934      * succeeds but does not resume.
2935      */
2936     if (!TEST_true(SSL_CTX_set_ciphersuites(cctx, "TLS_AES_256_GCM_SHA384"))
2937             || !TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
2938                                              NULL, NULL))
2939             || !TEST_true(SSL_set_session(clientssl, clntsess))
2940             || !TEST_true(create_ssl_connection(serverssl, clientssl,
2941                                                 SSL_ERROR_SSL))
2942             || !TEST_false(SSL_session_reused(clientssl)))
2943         goto end;
2944
2945     SSL_SESSION_free(clntsess);
2946     clntsess = NULL;
2947     SSL_shutdown(clientssl);
2948     SSL_shutdown(serverssl);
2949     SSL_free(serverssl);
2950     SSL_free(clientssl);
2951     serverssl = clientssl = NULL;
2952
2953     /* Create a session based on SHA384 */
2954     if (!TEST_true(SSL_CTX_set_ciphersuites(cctx, "TLS_AES_256_GCM_SHA384"))
2955             || !TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
2956                                           &clientssl, NULL, NULL))
2957             || !TEST_true(create_ssl_connection(serverssl, clientssl,
2958                                                 SSL_ERROR_NONE)))
2959         goto end;
2960
2961     clntsess = SSL_get1_session(clientssl);
2962     SSL_shutdown(clientssl);
2963     SSL_shutdown(serverssl);
2964     SSL_free(serverssl);
2965     SSL_free(clientssl);
2966     serverssl = clientssl = NULL;
2967
2968     if (!TEST_true(SSL_CTX_set_ciphersuites(cctx,
2969                    "TLS_AES_128_GCM_SHA256:TLS_AES_256_GCM_SHA384"))
2970             || !TEST_true(SSL_CTX_set_ciphersuites(sctx,
2971                                                    "TLS_AES_256_GCM_SHA384"))
2972             || !TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
2973                                              NULL, NULL))
2974             || !TEST_true(SSL_set_session(clientssl, clntsess))
2975                /*
2976                 * We use SSL_ERROR_WANT_READ below so that we can pause the
2977                 * connection after the initial ClientHello has been sent to
2978                 * enable us to make some session changes.
2979                 */
2980             || !TEST_false(create_ssl_connection(serverssl, clientssl,
2981                                                 SSL_ERROR_WANT_READ)))
2982         goto end;
2983
2984     /* Trick the client into thinking this session is for a different digest */
2985     clntsess->cipher = aes_128_gcm_sha256;
2986     clntsess->cipher_id = clntsess->cipher->id;
2987
2988     /*
2989      * Continue the previously started connection. Server has selected a SHA-384
2990      * ciphersuite, but client thinks the session is for SHA-256, so it should
2991      * bail out.
2992      */
2993     if (!TEST_false(create_ssl_connection(serverssl, clientssl,
2994                                                 SSL_ERROR_SSL))
2995             || !TEST_int_eq(ERR_GET_REASON(ERR_get_error()),
2996                             SSL_R_CIPHERSUITE_DIGEST_HAS_CHANGED))
2997         goto end;
2998
2999     testresult = 1;
3000
3001  end:
3002     SSL_SESSION_free(clntsess);
3003     SSL_free(serverssl);
3004     SSL_free(clientssl);
3005     SSL_CTX_free(sctx);
3006     SSL_CTX_free(cctx);
3007
3008     return testresult;
3009 }
3010
3011 /*
3012  * Test TLSv1.3 PSKs
3013  * Test 0 = Test new style callbacks
3014  * Test 1 = Test both new and old style callbacks
3015  * Test 2 = Test old style callbacks
3016  * Test 3 = Test old style callbacks with no certificate
3017  */
3018 static int test_tls13_psk(int idx)
3019 {
3020     SSL_CTX *sctx = NULL, *cctx = NULL;
3021     SSL *serverssl = NULL, *clientssl = NULL;
3022     const SSL_CIPHER *cipher = NULL;
3023     const unsigned char key[] = {
3024         0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
3025         0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
3026         0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23,
3027         0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f
3028     };
3029     int testresult = 0;
3030
3031     if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
3032                                        TLS1_VERSION, TLS_MAX_VERSION,
3033                                        &sctx, &cctx, idx == 3 ? NULL : cert,
3034                                        idx == 3 ? NULL : privkey)))
3035         goto end;
3036
3037     if (idx != 3) {
3038         /*
3039          * We use a ciphersuite with SHA256 to ease testing old style PSK
3040          * callbacks which will always default to SHA256. This should not be
3041          * necessary if we have no cert/priv key. In that case the server should
3042          * prefer SHA256 automatically.
3043          */
3044         if (!TEST_true(SSL_CTX_set_ciphersuites(cctx,
3045                                                 "TLS_AES_128_GCM_SHA256")))
3046             goto end;
3047     }
3048
3049     /*
3050      * Test 0: New style callbacks only
3051      * Test 1: New and old style callbacks (only the new ones should be used)
3052      * Test 2: Old style callbacks only
3053      */
3054     if (idx == 0 || idx == 1) {
3055         SSL_CTX_set_psk_use_session_callback(cctx, use_session_cb);
3056         SSL_CTX_set_psk_find_session_callback(sctx, find_session_cb);
3057     }
3058 #ifndef OPENSSL_NO_PSK
3059     if (idx >= 1) {
3060         SSL_CTX_set_psk_client_callback(cctx, psk_client_cb);
3061         SSL_CTX_set_psk_server_callback(sctx, psk_server_cb);
3062     }
3063 #endif
3064     srvid = pskid;
3065     use_session_cb_cnt = 0;
3066     find_session_cb_cnt = 0;
3067     psk_client_cb_cnt = 0;
3068     psk_server_cb_cnt = 0;
3069
3070     if (idx != 3) {
3071         /*
3072          * Check we can create a connection if callback decides not to send a
3073          * PSK
3074          */
3075         if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
3076                                                  NULL, NULL))
3077                 || !TEST_true(create_ssl_connection(serverssl, clientssl,
3078                                                     SSL_ERROR_NONE))
3079                 || !TEST_false(SSL_session_reused(clientssl))
3080                 || !TEST_false(SSL_session_reused(serverssl)))
3081             goto end;
3082
3083         if (idx == 0 || idx == 1) {
3084             if (!TEST_true(use_session_cb_cnt == 1)
3085                     || !TEST_true(find_session_cb_cnt == 0)
3086                        /*
3087                         * If no old style callback then below should be 0
3088                         * otherwise 1
3089                         */
3090                     || !TEST_true(psk_client_cb_cnt == idx)
3091                     || !TEST_true(psk_server_cb_cnt == 0))
3092                 goto end;
3093         } else {
3094             if (!TEST_true(use_session_cb_cnt == 0)
3095                     || !TEST_true(find_session_cb_cnt == 0)
3096                     || !TEST_true(psk_client_cb_cnt == 1)
3097                     || !TEST_true(psk_server_cb_cnt == 0))
3098                 goto end;
3099         }
3100
3101         shutdown_ssl_connection(serverssl, clientssl);
3102         serverssl = clientssl = NULL;
3103         use_session_cb_cnt = psk_client_cb_cnt = 0;
3104     }
3105
3106     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
3107                                              NULL, NULL)))
3108         goto end;
3109
3110     /* Create the PSK */
3111     cipher = SSL_CIPHER_find(clientssl, TLS13_AES_128_GCM_SHA256_BYTES);
3112     clientpsk = SSL_SESSION_new();
3113     if (!TEST_ptr(clientpsk)
3114             || !TEST_ptr(cipher)
3115             || !TEST_true(SSL_SESSION_set1_master_key(clientpsk, key,
3116                                                       sizeof(key)))
3117             || !TEST_true(SSL_SESSION_set_cipher(clientpsk, cipher))
3118             || !TEST_true(SSL_SESSION_set_protocol_version(clientpsk,
3119                                                            TLS1_3_VERSION))
3120             || !TEST_true(SSL_SESSION_up_ref(clientpsk)))
3121         goto end;
3122     serverpsk = clientpsk;
3123
3124     /* Check we can create a connection and the PSK is used */
3125     if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))
3126             || !TEST_true(SSL_session_reused(clientssl))
3127             || !TEST_true(SSL_session_reused(serverssl)))
3128         goto end;
3129
3130     if (idx == 0 || idx == 1) {
3131         if (!TEST_true(use_session_cb_cnt == 1)
3132                 || !TEST_true(find_session_cb_cnt == 1)
3133                 || !TEST_true(psk_client_cb_cnt == 0)
3134                 || !TEST_true(psk_server_cb_cnt == 0))
3135             goto end;
3136     } else {
3137         if (!TEST_true(use_session_cb_cnt == 0)
3138                 || !TEST_true(find_session_cb_cnt == 0)
3139                 || !TEST_true(psk_client_cb_cnt == 1)
3140                 || !TEST_true(psk_server_cb_cnt == 1))
3141             goto end;
3142     }
3143
3144     shutdown_ssl_connection(serverssl, clientssl);
3145     serverssl = clientssl = NULL;
3146     use_session_cb_cnt = find_session_cb_cnt = 0;
3147     psk_client_cb_cnt = psk_server_cb_cnt = 0;
3148
3149     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
3150                                              NULL, NULL)))
3151         goto end;
3152
3153     /* Force an HRR */
3154     if (!TEST_true(SSL_set1_groups_list(serverssl, "P-256")))
3155         goto end;
3156
3157     /*
3158      * Check we can create a connection, the PSK is used and the callbacks are
3159      * called twice.
3160      */
3161     if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))
3162             || !TEST_true(SSL_session_reused(clientssl))
3163             || !TEST_true(SSL_session_reused(serverssl)))
3164         goto end;
3165
3166     if (idx == 0 || idx == 1) {
3167         if (!TEST_true(use_session_cb_cnt == 2)
3168                 || !TEST_true(find_session_cb_cnt == 2)
3169                 || !TEST_true(psk_client_cb_cnt == 0)
3170                 || !TEST_true(psk_server_cb_cnt == 0))
3171             goto end;
3172     } else {
3173         if (!TEST_true(use_session_cb_cnt == 0)
3174                 || !TEST_true(find_session_cb_cnt == 0)
3175                 || !TEST_true(psk_client_cb_cnt == 2)
3176                 || !TEST_true(psk_server_cb_cnt == 2))
3177             goto end;
3178     }
3179
3180     shutdown_ssl_connection(serverssl, clientssl);
3181     serverssl = clientssl = NULL;
3182     use_session_cb_cnt = find_session_cb_cnt = 0;
3183     psk_client_cb_cnt = psk_server_cb_cnt = 0;
3184
3185     if (idx != 3) {
3186         /*
3187          * Check that if the server rejects the PSK we can still connect, but with
3188          * a full handshake
3189          */
3190         srvid = "Dummy Identity";
3191         if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
3192                                                  NULL, NULL))
3193                 || !TEST_true(create_ssl_connection(serverssl, clientssl,
3194                                                     SSL_ERROR_NONE))
3195                 || !TEST_false(SSL_session_reused(clientssl))
3196                 || !TEST_false(SSL_session_reused(serverssl)))
3197             goto end;
3198
3199         if (idx == 0 || idx == 1) {
3200             if (!TEST_true(use_session_cb_cnt == 1)
3201                     || !TEST_true(find_session_cb_cnt == 1)
3202                     || !TEST_true(psk_client_cb_cnt == 0)
3203                        /*
3204                         * If no old style callback then below should be 0
3205                         * otherwise 1
3206                         */
3207                     || !TEST_true(psk_server_cb_cnt == idx))
3208                 goto end;
3209         } else {
3210             if (!TEST_true(use_session_cb_cnt == 0)
3211                     || !TEST_true(find_session_cb_cnt == 0)
3212                     || !TEST_true(psk_client_cb_cnt == 1)
3213                     || !TEST_true(psk_server_cb_cnt == 1))
3214                 goto end;
3215         }
3216
3217         shutdown_ssl_connection(serverssl, clientssl);
3218         serverssl = clientssl = NULL;
3219     }
3220     testresult = 1;
3221
3222  end:
3223     SSL_SESSION_free(clientpsk);
3224     SSL_SESSION_free(serverpsk);
3225     clientpsk = serverpsk = NULL;
3226     SSL_free(serverssl);
3227     SSL_free(clientssl);
3228     SSL_CTX_free(sctx);
3229     SSL_CTX_free(cctx);
3230     return testresult;
3231 }
3232
3233 static unsigned char cookie_magic_value[] = "cookie magic";
3234
3235 static int generate_cookie_callback(SSL *ssl, unsigned char *cookie,
3236                                     unsigned int *cookie_len)
3237 {
3238     /*
3239      * Not suitable as a real cookie generation function but good enough for
3240      * testing!
3241      */
3242     memcpy(cookie, cookie_magic_value, sizeof(cookie_magic_value) - 1);
3243     *cookie_len = sizeof(cookie_magic_value) - 1;
3244
3245     return 1;
3246 }
3247
3248 static int verify_cookie_callback(SSL *ssl, const unsigned char *cookie,
3249                                   unsigned int cookie_len)
3250 {
3251     if (cookie_len == sizeof(cookie_magic_value) - 1
3252         && memcmp(cookie, cookie_magic_value, cookie_len) == 0)
3253         return 1;
3254
3255     return 0;
3256 }
3257
3258 static int generate_stateless_cookie_callback(SSL *ssl, unsigned char *cookie,
3259                                         size_t *cookie_len)
3260 {
3261     unsigned int temp;
3262     int res = generate_cookie_callback(ssl, cookie, &temp);
3263     *cookie_len = temp;
3264     return res;
3265 }
3266
3267 static int verify_stateless_cookie_callback(SSL *ssl, const unsigned char *cookie,
3268                                       size_t cookie_len)
3269 {
3270     return verify_cookie_callback(ssl, cookie, cookie_len);
3271 }
3272
3273 static int test_stateless(void)
3274 {
3275     SSL_CTX *sctx = NULL, *cctx = NULL;
3276     SSL *serverssl = NULL, *clientssl = NULL;
3277     int testresult = 0;
3278
3279     if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
3280                                        TLS1_VERSION, TLS_MAX_VERSION,
3281                                        &sctx, &cctx, cert, privkey)))
3282         goto end;
3283
3284     /* The arrival of CCS messages can confuse the test */
3285     SSL_CTX_clear_options(cctx, SSL_OP_ENABLE_MIDDLEBOX_COMPAT);
3286
3287     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
3288                                       NULL, NULL))
3289                /* Send the first ClientHello */
3290             || !TEST_false(create_ssl_connection(serverssl, clientssl,
3291                                                  SSL_ERROR_WANT_READ))
3292                /*
3293                 * This should fail with a -1 return because we have no callbacks
3294                 * set up
3295                 */
3296             || !TEST_int_eq(SSL_stateless(serverssl), -1))
3297         goto end;
3298
3299     /* Fatal error so abandon the connection from this client */
3300     SSL_free(clientssl);
3301     clientssl = NULL;
3302
3303     /* Set up the cookie generation and verification callbacks */
3304     SSL_CTX_set_stateless_cookie_generate_cb(sctx, generate_stateless_cookie_callback);
3305     SSL_CTX_set_stateless_cookie_verify_cb(sctx, verify_stateless_cookie_callback);
3306
3307     /*
3308      * Create a new connection from the client (we can reuse the server SSL
3309      * object).
3310      */
3311     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
3312                                              NULL, NULL))
3313                /* Send the first ClientHello */
3314             || !TEST_false(create_ssl_connection(serverssl, clientssl,
3315                                                 SSL_ERROR_WANT_READ))
3316                /* This should fail because there is no cookie */
3317             || !TEST_int_eq(SSL_stateless(serverssl), 0))
3318         goto end;
3319
3320     /* Abandon the connection from this client */
3321     SSL_free(clientssl);
3322     clientssl = NULL;
3323
3324     /*
3325      * Now create a connection from a new client but with the same server SSL
3326      * object
3327      */
3328     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
3329                                              NULL, NULL))
3330                /* Send the first ClientHello */
3331             || !TEST_false(create_ssl_connection(serverssl, clientssl,
3332                                                 SSL_ERROR_WANT_READ))
3333                /* This should fail because there is no cookie */
3334             || !TEST_int_eq(SSL_stateless(serverssl), 0)
3335                /* Send the second ClientHello */
3336             || !TEST_false(create_ssl_connection(serverssl, clientssl,
3337                                                 SSL_ERROR_WANT_READ))
3338                /* This should succeed because a cookie is now present */
3339             || !TEST_int_eq(SSL_stateless(serverssl), 1)
3340                /* Complete the connection */
3341             || !TEST_true(create_ssl_connection(serverssl, clientssl,
3342                                                 SSL_ERROR_NONE)))
3343         goto end;
3344
3345     shutdown_ssl_connection(serverssl, clientssl);
3346     serverssl = clientssl = NULL;
3347     testresult = 1;
3348
3349  end:
3350     SSL_free(serverssl);
3351     SSL_free(clientssl);
3352     SSL_CTX_free(sctx);
3353     SSL_CTX_free(cctx);
3354     return testresult;
3355
3356 }
3357 #endif /* OPENSSL_NO_TLS1_3 */
3358
3359 static int clntaddoldcb = 0;
3360 static int clntparseoldcb = 0;
3361 static int srvaddoldcb = 0;
3362 static int srvparseoldcb = 0;
3363 static int clntaddnewcb = 0;
3364 static int clntparsenewcb = 0;
3365 static int srvaddnewcb = 0;
3366 static int srvparsenewcb = 0;
3367 static int snicb = 0;
3368
3369 #define TEST_EXT_TYPE1  0xff00
3370
3371 static int old_add_cb(SSL *s, unsigned int ext_type, const unsigned char **out,
3372                       size_t *outlen, int *al, void *add_arg)
3373 {
3374     int *server = (int *)add_arg;
3375     unsigned char *data;
3376
3377     if (SSL_is_server(s))
3378         srvaddoldcb++;
3379     else
3380         clntaddoldcb++;
3381
3382     if (*server != SSL_is_server(s)
3383             || (data = OPENSSL_malloc(sizeof(*data))) == NULL)
3384         return -1;
3385
3386     *data = 1;
3387     *out = data;
3388     *outlen = sizeof(char);
3389     return 1;
3390 }
3391
3392 static void old_free_cb(SSL *s, unsigned int ext_type, const unsigned char *out,
3393                         void *add_arg)
3394 {
3395     OPENSSL_free((unsigned char *)out);
3396 }
3397
3398 static int old_parse_cb(SSL *s, unsigned int ext_type, const unsigned char *in,
3399                         size_t inlen, int *al, void *parse_arg)
3400 {
3401     int *server = (int *)parse_arg;
3402
3403     if (SSL_is_server(s))
3404         srvparseoldcb++;
3405     else
3406         clntparseoldcb++;
3407
3408     if (*server != SSL_is_server(s)
3409             || inlen != sizeof(char)
3410             || *in != 1)
3411         return -1;
3412
3413     return 1;
3414 }
3415
3416 static int new_add_cb(SSL *s, unsigned int ext_type, unsigned int context,
3417                       const unsigned char **out, size_t *outlen, X509 *x,
3418                       size_t chainidx, int *al, void *add_arg)
3419 {
3420     int *server = (int *)add_arg;
3421     unsigned char *data;
3422
3423     if (SSL_is_server(s))
3424         srvaddnewcb++;
3425     else
3426         clntaddnewcb++;
3427
3428     if (*server != SSL_is_server(s)
3429             || (data = OPENSSL_malloc(sizeof(*data))) == NULL)
3430         return -1;
3431
3432     *data = 1;
3433     *out = data;
3434     *outlen = sizeof(*data);
3435     return 1;
3436 }
3437
3438 static void new_free_cb(SSL *s, unsigned int ext_type, unsigned int context,
3439                         const unsigned char *out, void *add_arg)
3440 {
3441     OPENSSL_free((unsigned char *)out);
3442 }
3443
3444 static int new_parse_cb(SSL *s, unsigned int ext_type, unsigned int context,
3445                         const unsigned char *in, size_t inlen, X509 *x,
3446                         size_t chainidx, int *al, void *parse_arg)
3447 {
3448     int *server = (int *)parse_arg;
3449
3450     if (SSL_is_server(s))
3451         srvparsenewcb++;
3452     else
3453         clntparsenewcb++;
3454
3455     if (*server != SSL_is_server(s)
3456             || inlen != sizeof(char) || *in != 1)
3457         return -1;
3458
3459     return 1;
3460 }
3461
3462 static int sni_cb(SSL *s, int *al, void *arg)
3463 {
3464     SSL_CTX *ctx = (SSL_CTX *)arg;
3465
3466     if (SSL_set_SSL_CTX(s, ctx) == NULL) {
3467         *al = SSL_AD_INTERNAL_ERROR;
3468         return SSL_TLSEXT_ERR_ALERT_FATAL;
3469     }
3470     snicb++;
3471     return SSL_TLSEXT_ERR_OK;
3472 }
3473
3474 /*
3475  * Custom call back tests.
3476  * Test 0: Old style callbacks in TLSv1.2
3477  * Test 1: New style callbacks in TLSv1.2
3478  * Test 2: New style callbacks in TLSv1.2 with SNI
3479  * Test 3: New style callbacks in TLSv1.3. Extensions in CH and EE
3480  * Test 4: New style callbacks in TLSv1.3. Extensions in CH, SH, EE, Cert + NST
3481  */
3482 static int test_custom_exts(int tst)
3483 {
3484     SSL_CTX *cctx = NULL, *sctx = NULL, *sctx2 = NULL;
3485     SSL *clientssl = NULL, *serverssl = NULL;
3486     int testresult = 0;
3487     static int server = 1;
3488     static int client = 0;
3489     SSL_SESSION *sess = NULL;
3490     unsigned int context;
3491
3492 #if defined(OPENSSL_NO_TLS1_2) && !defined(OPENSSL_NO_TLS1_3)
3493     /* Skip tests for TLSv1.2 and below in this case */
3494     if (tst < 3)
3495         return 1;
3496 #endif
3497
3498     /* Reset callback counters */
3499     clntaddoldcb = clntparseoldcb = srvaddoldcb = srvparseoldcb = 0;
3500     clntaddnewcb = clntparsenewcb = srvaddnewcb = srvparsenewcb = 0;
3501     snicb = 0;
3502
3503     if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
3504                                        TLS1_VERSION, TLS_MAX_VERSION,
3505                                        &sctx, &cctx, cert, privkey)))
3506         goto end;
3507
3508     if (tst == 2
3509             && !TEST_true(create_ssl_ctx_pair(TLS_server_method(), NULL,
3510                                               TLS1_VERSION, TLS_MAX_VERSION,
3511                                               &sctx2, NULL, cert, privkey)))
3512         goto end;
3513
3514
3515     if (tst < 3) {