Add a test for the new early data callback
[openssl.git] / test / sslapitest.c
index 61619a327ca55c5a98f31d39c3c60e936434f051..91d3bf97cf1fb49bbd8b5d96a80f0bc2e7865d18 100644 (file)
@@ -1847,11 +1847,14 @@ static unsigned int psk_server_cb(SSL *ssl, const char *identity,
 static int setupearly_data_test(SSL_CTX **cctx, SSL_CTX **sctx, SSL **clientssl,
                                 SSL **serverssl, SSL_SESSION **sess, int idx)
 {
 static int setupearly_data_test(SSL_CTX **cctx, SSL_CTX **sctx, SSL **clientssl,
                                 SSL **serverssl, SSL_SESSION **sess, int idx)
 {
-    if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
-                                       TLS1_VERSION, TLS_MAX_VERSION,
-                                       sctx, cctx, cert, privkey))
-        || !TEST_true(SSL_CTX_set_max_early_data(*sctx,
-                                                 SSL3_RT_MAX_PLAIN_LENGTH)))
+    if (*sctx == NULL
+            && !TEST_true(create_ssl_ctx_pair(TLS_server_method(),
+                                              TLS_client_method(),
+                                              TLS1_VERSION, TLS_MAX_VERSION,
+                                              sctx, cctx, cert, privkey)))
+        return 0;
+
+    if (!TEST_true(SSL_CTX_set_max_early_data(*sctx, SSL3_RT_MAX_PLAIN_LENGTH)))
         return 0;
 
     if (idx == 1) {
         return 0;
 
     if (idx == 1) {
@@ -2156,12 +2159,47 @@ static int test_early_data_read_write(int idx)
     return testresult;
 }
 
     return testresult;
 }
 
-static int test_early_data_replay(int idx)
+static int allow_ed_cb_called = 0;
+
+static int allow_early_data_cb(SSL *s, void *arg)
+{
+    int *usecb = (int *)arg;
+
+    allow_ed_cb_called++;
+
+    if (*usecb == 1)
+        return 0;
+
+    return 1;
+}
+
+/*
+ * idx == 0: Standard early_data setup
+ * idx == 1: early_data setup using read_ahead
+ * usecb == 0: Don't use a custom early data callback
+ * usecb == 1: Use a custom early data callback and reject the early data
+ * usecb == 2: Use a custom early data callback and accept the early data
+ */
+static int test_early_data_replay_int(int idx, int usecb)
 {
     SSL_CTX *cctx = NULL, *sctx = NULL;
     SSL *clientssl = NULL, *serverssl = NULL;
     int testresult = 0;
     SSL_SESSION *sess = NULL;
 {
     SSL_CTX *cctx = NULL, *sctx = NULL;
     SSL *clientssl = NULL, *serverssl = NULL;
     int testresult = 0;
     SSL_SESSION *sess = NULL;
+    size_t readbytes, written;
+    unsigned char buf[20];
+
+    allow_ed_cb_called = 0;
+
+    if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(),
+                                       TLS1_VERSION, TLS_MAX_VERSION, &sctx,
+                                       &cctx, cert, privkey)))
+        return 0;
+
+    if (usecb > 0) {
+        SSL_CTX_set_options(sctx, SSL_OP_NO_ANTI_REPLAY);
+        SSL_CTX_set_allow_early_data_cb(sctx, allow_early_data_cb, &usecb);
+    }
 
     if (!TEST_true(setupearly_data_test(&cctx, &sctx, &clientssl,
                                         &serverssl, &sess, idx)))
 
     if (!TEST_true(setupearly_data_test(&cctx, &sctx, &clientssl,
                                         &serverssl, &sess, idx)))
@@ -2183,14 +2221,49 @@ static int test_early_data_replay(int idx)
 
     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
                                       &clientssl, NULL, NULL))
 
     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
                                       &clientssl, NULL, NULL))
-            || !TEST_true(SSL_set_session(clientssl, sess))
-            || !TEST_true(create_ssl_connection(serverssl, clientssl,
-                          SSL_ERROR_NONE))
-               /*
-                * This time we should not have resumed the session because we
-                * already used it once.
-                */
-            || !TEST_false(SSL_session_reused(clientssl)))
+            || !TEST_true(SSL_set_session(clientssl, sess)))
+        goto end;
+
+    /* Write and read some early data */
+    if (!TEST_true(SSL_write_early_data(clientssl, MSG1, strlen(MSG1),
+                                        &written))
+            || !TEST_size_t_eq(written, strlen(MSG1)))
+        goto end;
+
+    if (usecb <= 1) {
+        if (!TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
+                                             &readbytes),
+                         SSL_READ_EARLY_DATA_FINISH)
+                   /*
+                    * The ticket was reused, so the we should have rejected the
+                    * early data
+                    */
+                || !TEST_int_eq(SSL_get_early_data_status(serverssl),
+                                SSL_EARLY_DATA_REJECTED))
+            goto end;
+    } else {
+        /* In this case the callback decides to accept the early data */
+        if (!TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
+                                             &readbytes),
+                         SSL_READ_EARLY_DATA_SUCCESS)
+                || !TEST_mem_eq(MSG1, strlen(MSG1), buf, readbytes)
+                   /*
+                    * Server will have sent its flight so client can now send
+                    * end of early data and complete its half of the handshake
+                    */
+                || !TEST_int_gt(SSL_connect(clientssl), 0)
+                || !TEST_int_eq(SSL_read_early_data(serverssl, buf, sizeof(buf),
+                                             &readbytes),
+                                SSL_READ_EARLY_DATA_FINISH)
+                || !TEST_int_eq(SSL_get_early_data_status(serverssl),
+                                SSL_EARLY_DATA_ACCEPTED))
+            goto end;
+    }
+
+    /* Complete the connection */
+    if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))
+            || !TEST_int_eq(SSL_session_reused(clientssl), (usecb > 0) ? 1 : 0)
+            || !TEST_int_eq(allow_ed_cb_called, usecb > 0 ? 1 : 0))
         goto end;
 
     testresult = 1;
         goto end;
 
     testresult = 1;
@@ -2207,6 +2280,17 @@ static int test_early_data_replay(int idx)
     return testresult;
 }
 
     return testresult;
 }
 
+static int test_early_data_replay(int idx)
+{
+    int ret;
+
+    ret = test_early_data_replay_int(idx, 0);
+    ret &= test_early_data_replay_int(idx, 1);
+    ret &= test_early_data_replay_int(idx, 2);
+
+    return ret;
+}
+
 /*
  * Helper function to test that a server attempting to read early data can
  * handle a connection from a client where the early data should be skipped.
 /*
  * Helper function to test that a server attempting to read early data can
  * handle a connection from a client where the early data should be skipped.
@@ -4972,6 +5056,135 @@ static int test_ticket_callbacks(int tst)
     return testresult;
 }
 
     return testresult;
 }
 
+/*
+ * Test bi-directional shutdown.
+ * Test 0: TLSv1.2
+ * Test 1: TLSv1.2, server continues to read/write after client shutdown
+ * Test 2: TLSv1.3, no pending NewSessionTicket messages
+ * Test 3: TLSv1.3, pending NewSessionTicket messages
+ * Test 4: TLSv1.3, server continues to read/write after client shutdown, client
+ *                  reads it
+ * Test 5: TLSv1.3, server continues to read/write after client shutdown, client
+ *                  doesn't read it
+ */
+static int test_shutdown(int tst)
+{
+    SSL_CTX *cctx = NULL, *sctx = NULL;
+    SSL *clientssl = NULL, *serverssl = NULL;
+    int testresult = 0;
+    char msg[] = "A test message";
+    char buf[80];
+    size_t written, readbytes;
+
+#ifdef OPENSSL_NO_TLS1_2
+    if (tst == 0)
+        return 1;
+#endif
+#ifdef OPENSSL_NO_TLS1_3
+    if (tst != 0)
+        return 1;
+#endif
+
+    if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(),
+                                       TLS_client_method(),
+                                       TLS1_VERSION,
+                                       (tst <= 1) ? TLS1_2_VERSION
+                                                  : TLS1_3_VERSION,
+                                       &sctx, &cctx, cert, privkey))
+            || !TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
+                                             NULL, NULL)))
+        goto end;
+
+    if (tst == 3) {
+        if (!TEST_true(create_bare_ssl_connection(serverssl, clientssl,
+                                                  SSL_ERROR_NONE)))
+            goto end;
+    } else if (!TEST_true(create_ssl_connection(serverssl, clientssl,
+                                              SSL_ERROR_NONE))) {
+        goto end;
+    }
+
+    if (!TEST_int_eq(SSL_shutdown(clientssl), 0))
+        goto end;
+
+    if (tst >= 4) {
+        /*
+         * Reading on the server after the client has sent close_notify should
+         * fail and provide SSL_ERROR_ZERO_RETURN
+         */
+        if (!TEST_false(SSL_read_ex(serverssl, buf, sizeof(buf), &readbytes))
+                || !TEST_int_eq(SSL_get_error(serverssl, 0),
+                                SSL_ERROR_ZERO_RETURN)
+                || !TEST_int_eq(SSL_get_shutdown(serverssl),
+                                SSL_RECEIVED_SHUTDOWN)
+                   /*
+                    * Even though we're shutdown on receive we should still be
+                    * able to write.
+                    */
+                || !TEST_true(SSL_write(serverssl, msg, sizeof(msg)))
+                || !TEST_int_eq(SSL_shutdown(serverssl), 1))
+            goto end;
+        if (tst == 4) {
+                   /* Should still be able to read data from server */
+            if (!TEST_true(SSL_read_ex(clientssl, buf, sizeof(buf),
+                                          &readbytes))
+                    || !TEST_size_t_eq(readbytes, sizeof(msg))
+                    || !TEST_int_eq(memcmp(msg, buf, readbytes), 0))
+                goto end;
+        }
+    }
+
+    /* Writing on the client after sending close_notify shouldn't be possible */
+    if (!TEST_false(SSL_write_ex(clientssl, msg, sizeof(msg), &written)))
+        goto end;
+
+    if (tst < 4) {
+        /*
+         * For these tests the client has sent close_notify but it has not yet
+         * been received by the server. The server has not sent close_notify
+         * yet.
+         */
+        if (!TEST_int_eq(SSL_shutdown(serverssl), 0)
+                   /*
+                    * Writing on the server after sending close_notify shouldn't
+                    * be possible.
+                    */
+                || !TEST_false(SSL_write_ex(serverssl, msg, sizeof(msg), &written))
+                || !TEST_int_eq(SSL_shutdown(clientssl), 1)
+                || !TEST_int_eq(SSL_shutdown(serverssl), 1))
+            goto end;
+    } else if (tst == 4) {
+        /*
+         * In this test the client has sent close_notify and it has been
+         * received by the server which has responded with a close_notify. The
+         * client needs to read the close_notify sent by the server.
+         */
+        if (!TEST_int_eq(SSL_shutdown(clientssl), 1))
+            goto end;
+    } else {
+        /*
+         * tst == 5
+         *
+         * The client has sent close_notify and is expecting a close_notify
+         * back, but instead there is application data first. The shutdown
+         * should fail with a fatal error.
+         */
+        if (!TEST_int_eq(SSL_shutdown(clientssl), -1)
+                || !TEST_int_eq(SSL_get_error(clientssl, -1), SSL_ERROR_SSL))
+            goto end;
+    }
+
+    testresult = 1;
+
+ end:
+    SSL_free(serverssl);
+    SSL_free(clientssl);
+    SSL_CTX_free(sctx);
+    SSL_CTX_free(cctx);
+
+    return testresult;
+}
+
 int setup_tests(void)
 {
     if (!TEST_ptr(cert = test_get_argument(0))
 int setup_tests(void)
 {
     if (!TEST_ptr(cert = test_get_argument(0))
@@ -5069,6 +5282,7 @@ int setup_tests(void)
     ADD_ALL_TESTS(test_ssl_pending, 2);
     ADD_ALL_TESTS(test_ssl_get_shared_ciphers, OSSL_NELEM(shared_ciphers_data));
     ADD_ALL_TESTS(test_ticket_callbacks, 12);
     ADD_ALL_TESTS(test_ssl_pending, 2);
     ADD_ALL_TESTS(test_ssl_get_shared_ciphers, OSSL_NELEM(shared_ciphers_data));
     ADD_ALL_TESTS(test_ticket_callbacks, 12);
+    ADD_ALL_TESTS(test_shutdown, 6);
     return 1;
 }
 
     return 1;
 }