Add some PSK early_data tests
[openssl.git] / test / sslapitest.c
index 70fbf80..e9d1961 100644 (file)
@@ -1399,6 +1399,70 @@ static int test_set_sigalgs(int idx)
     return testresult;
 }
 
+static SSL_SESSION *psk = NULL;
+static const char *pskid = "Identity";
+static const char *srvid;
+
+static int use_session_cb_cnt = 0;
+static int find_session_cb_cnt = 0;
+
+static int use_session_cb(SSL *ssl, const EVP_MD *md, const unsigned char **id,
+                          size_t *idlen, SSL_SESSION **sess)
+{
+    switch (++use_session_cb_cnt) {
+    case 1:
+        /* The first call should always have a NULL md */
+        if (md != NULL)
+            return 0;
+        break;
+
+    case 2:
+        /* The second call should always have an md */
+        if (md == NULL)
+            return 0;
+        break;
+
+    default:
+        /* We should only be called a maximum of twice */
+        return 0;
+    }
+
+    if (psk != NULL)
+        SSL_SESSION_up_ref(psk);
+
+    *sess = psk;
+    *id = (const unsigned char *)pskid;
+    *idlen = strlen(pskid);
+
+    return 1;
+}
+
+static int find_session_cb(SSL *ssl, const unsigned char *identity,
+                           size_t identity_len, SSL_SESSION **sess)
+{
+    find_session_cb_cnt++;
+
+    /* We should only ever be called a maximum of twice per connection */
+    if (find_session_cb_cnt > 2)
+        return 0;
+
+    if (psk == NULL)
+        return 0;
+
+    /* Identity should match that set by the client */
+    if (strlen(srvid) != identity_len
+            || strncmp(srvid, (const char *)identity, identity_len) != 0) {
+        /* No PSK found, continue but without a PSK */
+        *sess = NULL;
+        return 1;
+    }
+
+    SSL_SESSION_up_ref(psk);
+    *sess = psk;
+
+    return 1;
+}
+
 #ifndef OPENSSL_NO_TLS1_3
 
 #define MSG1    "Hello"
@@ -1409,6 +1473,8 @@ static int test_set_sigalgs(int idx)
 #define MSG6    "test"
 #define MSG7    "message."
 
+#define TLS13_AES_256_GCM_SHA384_BYTES  ((const unsigned char *)"\x13\x02")
+
 /*
  * Helper method to setup objects for early data test. Caller frees objects on
  * error.
@@ -1421,16 +1487,64 @@ static int setupearly_data_test(SSL_CTX **cctx, SSL_CTX **sctx, SSL **clientssl,
                                        cctx, cert, privkey)))
         return 0;
 
-    /* When idx == 1 we repeat the tests with read_ahead set */
-    if (idx > 0) {
+    if (idx == 1) {
+        /* When idx == 1 we repeat the tests with read_ahead set */
         SSL_CTX_set_read_ahead(*cctx, 1);
         SSL_CTX_set_read_ahead(*sctx, 1);
+    } else if (idx == 2) {
+        /* When idx == 2 we are doing early_data with a PSK. Set up callbacks */
+        SSL_CTX_set_psk_use_session_callback(*cctx, use_session_cb);
+        SSL_CTX_set_psk_find_session_callback(*sctx, find_session_cb);
+        use_session_cb_cnt = 0;
+        find_session_cb_cnt = 0;
+        srvid = pskid;
     }
 
     if (!TEST_true(create_ssl_objects(*sctx, *cctx, serverssl, clientssl,
-                                      NULL, NULL))
-            || !TEST_true(create_ssl_connection(*serverssl, *clientssl,
-                                                SSL_ERROR_NONE)))
+                                      NULL, NULL)))
+        return 0;
+
+    if (idx == 2) {
+        /* Create the PSK */
+        const SSL_CIPHER *cipher = NULL;
+        const unsigned char key[] = {
+            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a,
+            0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
+            0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
+            0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b,
+            0x2c, 0x2d, 0x2e, 0x2f
+        };
+
+        cipher = SSL_CIPHER_find(*clientssl, TLS13_AES_256_GCM_SHA384_BYTES);
+        psk = SSL_SESSION_new();
+        if (!TEST_ptr(psk)
+                || !TEST_ptr(cipher)
+                || !TEST_true(SSL_SESSION_set1_master_key(psk, key,
+                                                          sizeof(key)))
+                || !TEST_true(SSL_SESSION_set_cipher(psk, cipher))
+                || !TEST_true(
+                        SSL_SESSION_set_protocol_version(psk,
+                                                         TLS1_3_VERSION))
+                   /*
+                    * We just choose an arbitrary value for max_early_data which
+                    * should be big enough for testing purposes.
+                    */
+                || !TEST_true(SSL_SESSION_set_max_early_data(psk, 0x100))) {
+            SSL_SESSION_free(psk);
+            psk = NULL;
+            return 0;
+        }
+
+        if (sess != NULL)
+            *sess = psk;
+        return 1;
+    }
+
+    if (sess == NULL)
+        return 1;
+
+    if (!TEST_true(create_ssl_connection(*serverssl, *clientssl,
+                                         SSL_ERROR_NONE)))
         return 0;
 
     *sess = SSL_get1_session(*clientssl);
@@ -1591,8 +1705,12 @@ static int test_early_data_read_write(int idx)
             || !TEST_mem_eq(buf, readbytes, MSG7, strlen(MSG7)))
         goto end;
 
-    SSL_SESSION_free(sess);
+    /* We keep the PSK session around if using PSK */
+    if (idx != 2)
+        SSL_SESSION_free(sess);
     sess = SSL_get1_session(clientssl);
+    use_session_cb_cnt = 0;
+    find_session_cb_cnt = 0;
 
     SSL_shutdown(clientssl);
     SSL_shutdown(serverssl);
@@ -1640,6 +1758,7 @@ static int test_early_data_read_write(int idx)
 
  end:
     SSL_SESSION_free(sess);
+    psk = NULL;
     SSL_free(serverssl);
     SSL_free(clientssl);
     SSL_CTX_free(sctx);
@@ -1668,6 +1787,12 @@ static int early_data_skip_helper(int hrr, int idx)
         /* Force an HRR to occur */
         if (!TEST_true(SSL_set1_groups_list(serverssl, "P-256")))
             goto end;
+    } else if (idx == 2) {
+        /*
+         * We force early_data rejection by ensuring the PSK identity is
+         * unrecognised
+         */
+        srvid = "Dummy Identity";
     } else {
         /*
          * Deliberately corrupt the creation time. We take 20 seconds off the
@@ -1717,6 +1842,9 @@ static int early_data_skip_helper(int hrr, int idx)
     testresult = 1;
 
  end:
+    if (sess != psk)
+        SSL_SESSION_free(psk);
+    psk = NULL;
     SSL_SESSION_free(sess);
     SSL_free(serverssl);
     SSL_free(clientssl);
@@ -1789,7 +1917,7 @@ static int test_early_data_not_sent(int idx)
      * Should block due to the NewSessionTicket arrival unless we're using
      * read_ahead
      */
-    if (idx == 0) {
+    if (idx != 1) {
         if (!TEST_false(SSL_read_ex(clientssl, buf, sizeof(buf), &readbytes)))
             goto end;
     }
@@ -1802,6 +1930,7 @@ static int test_early_data_not_sent(int idx)
 
  end:
     SSL_SESSION_free(sess);
+    psk = NULL;
     SSL_free(serverssl);
     SSL_free(clientssl);
     SSL_CTX_free(sctx);
@@ -1822,7 +1951,6 @@ static int test_early_data_not_expected(int idx)
     unsigned char buf[20];
     size_t readbytes, written;
 
-
     if (!TEST_true(setupearly_data_test(&cctx, &sctx, &clientssl,
                                         &serverssl, &sess, idx)))
         goto end;
@@ -1858,6 +1986,7 @@ static int test_early_data_not_expected(int idx)
 
  end:
     SSL_SESSION_free(sess);
+    psk = NULL;
     SSL_free(serverssl);
     SSL_free(clientssl);
     SSL_CTX_free(sctx);
@@ -1879,19 +2008,8 @@ static int test_early_data_tls1_2(int idx)
     unsigned char buf[20];
     size_t readbytes, written;
 
-    if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(),
-                                       TLS_client_method(), &sctx,
-                                       &cctx, cert, privkey)))
-        goto end;
-
-    /* When idx == 1 we repeat the tests with read_ahead set */
-    if (idx > 0) {
-        SSL_CTX_set_read_ahead(cctx, 1);
-        SSL_CTX_set_read_ahead(sctx, 1);
-    }
-
-    if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
-                                      &clientssl, NULL, NULL)))
+    if (!TEST_true(setupearly_data_test(&cctx, &sctx, &clientssl,
+                                        &serverssl, NULL, idx)))
         goto end;
 
     /* Write some data - should block due to handshake with server */
@@ -1939,6 +2057,8 @@ static int test_early_data_tls1_2(int idx)
     testresult = 1;
 
  end:
+    SSL_SESSION_free(psk);
+    psk = NULL;
     SSL_free(serverssl);
     SSL_free(clientssl);
     SSL_CTX_free(sctx);
@@ -2075,73 +2195,6 @@ static int test_ciphersuite_change(void)
     return testresult;
 }
 
-
-static SSL_SESSION *psk = NULL;
-static const char *pskid = "Identity";
-static const char *srvid;
-
-static int use_session_cb_cnt = 0;
-static int find_session_cb_cnt = 0;
-
-static int use_session_cb(SSL *ssl, const EVP_MD *md, const unsigned char **id,
-                          size_t *idlen, SSL_SESSION **sess)
-{
-    switch (++use_session_cb_cnt) {
-    case 1:
-        /* The first call should always have a NULL md */
-        if (md != NULL)
-            return 0;
-        break;
-
-    case 2:
-        /* The second call should always have an md */
-        if (md == NULL)
-            return 0;
-        break;
-
-    default:
-        /* We should only be called a maximum of twice */
-        return 0;
-    }
-
-    if (psk != NULL)
-        SSL_SESSION_up_ref(psk);
-
-    *sess = psk;
-    *id = (const unsigned char *)pskid;
-    *idlen = strlen(pskid);
-
-    return 1;
-}
-
-static int find_session_cb(SSL *ssl, const unsigned char *identity,
-                           size_t identity_len, SSL_SESSION **sess)
-{
-    find_session_cb_cnt++;
-
-    /* We should only ever be called a maximum of twice per connection */
-    if (find_session_cb_cnt > 2)
-        return 0;
-
-    if (psk == NULL)
-        return 0;
-
-    /* Identity should match that set by the client */
-    if (strlen(srvid) != identity_len
-            || strncmp(srvid, (const char *)identity, identity_len) != 0) {
-        /* No PSK found, continue but without a PSK */
-        *sess = NULL;
-        return 1;
-    }
-
-    SSL_SESSION_up_ref(psk);
-    *sess = psk;
-
-    return 1;
-}
-
-#define TLS13_AES_256_GCM_SHA384_BYTES  ((const unsigned char *)"\x13\x02")
-
 static int test_tls13_psk(void)
 {
     SSL_CTX *sctx = NULL, *cctx = NULL;
@@ -2163,6 +2216,8 @@ static int test_tls13_psk(void)
     SSL_CTX_set_psk_use_session_callback(cctx, use_session_cb);
     SSL_CTX_set_psk_find_session_callback(sctx, find_session_cb);
     srvid = pskid;
+    use_session_cb_cnt = 0;
+    find_session_cb_cnt = 0;
 
     /* Check we can create a connection if callback decides not to send a PSK */
     if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
@@ -2846,13 +2901,13 @@ int setup_tests(void)
     ADD_TEST(test_early_cb);
 #endif
 #ifndef OPENSSL_NO_TLS1_3
-    ADD_ALL_TESTS(test_early_data_read_write, 2);
-    ADD_ALL_TESTS(test_early_data_skip, 2);
-    ADD_ALL_TESTS(test_early_data_skip_hrr, 2);
-    ADD_ALL_TESTS(test_early_data_not_sent, 2);
-    ADD_ALL_TESTS(test_early_data_not_expected, 2);
+    ADD_ALL_TESTS(test_early_data_read_write, 3);
+    ADD_ALL_TESTS(test_early_data_skip, 3);
+    ADD_ALL_TESTS(test_early_data_skip_hrr, 3);
+    ADD_ALL_TESTS(test_early_data_not_sent, 3);
+    ADD_ALL_TESTS(test_early_data_not_expected, 3);
 # ifndef OPENSSL_NO_TLS1_2
-    ADD_ALL_TESTS(test_early_data_tls1_2, 2);
+    ADD_ALL_TESTS(test_early_data_tls1_2, 3);
 # endif
 #endif
 #ifndef OPENSSL_NO_TLS1_3