Add a test for using a PSK with QUIC
authorMatt Caswell <matt@openssl.org>
Thu, 7 Sep 2023 16:36:13 +0000 (17:36 +0100)
committerTomas Mraz <tomas@openssl.org>
Tue, 12 Sep 2023 13:29:00 +0000 (15:29 +0200)
Check that we can set and use a PSK when establishing a QUIC connection.

Fixes openssl/project#83

Reviewed-by: Hugo Landau <hlandau@openssl.org>
Reviewed-by: Paul Dale <pauli@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/22011)

include/internal/quic_tserver.h
ssl/quic/quic_tserver.c
test/helpers/ssltestlib.c
test/helpers/ssltestlib.h
test/quicapitest.c
test/sslapitest.c

index 9213f60666aa6ac8392fd3480d15ed3736e8b46d..4f358dd4e87c58bdef39817361c52a0fdd5ab470 100644 (file)
@@ -211,6 +211,10 @@ int ossl_quic_tserver_new_ticket(QUIC_TSERVER *srv);
 int ossl_quic_tserver_set_max_early_data(QUIC_TSERVER *srv,
                                          uint32_t max_early_data);
 
+/* Set the find session callback for getting a server PSK */
+void ossl_quic_tserver_set_psk_find_session_cb(QUIC_TSERVER *srv,
+                                               SSL_psk_find_session_cb_func cb);
+
 # endif
 
 #endif
index 788d4780d8e38c037f1db96638104ea3bf76f238..92c17d10f3c64846a61390e195b1716c097ddf43 100644 (file)
@@ -99,10 +99,12 @@ QUIC_TSERVER *ossl_quic_tserver_new(const QUIC_TSERVER_ARGS *args,
     if (srv->ctx == NULL)
         goto err;
 
-    if (SSL_CTX_use_certificate_file(srv->ctx, certfile, SSL_FILETYPE_PEM) <= 0)
+    if (certfile != NULL
+            && SSL_CTX_use_certificate_file(srv->ctx, certfile, SSL_FILETYPE_PEM) <= 0)
         goto err;
 
-    if (SSL_CTX_use_PrivateKey_file(srv->ctx, keyfile, SSL_FILETYPE_PEM) <= 0)
+    if (keyfile != NULL
+            && SSL_CTX_use_PrivateKey_file(srv->ctx, keyfile, SSL_FILETYPE_PEM) <= 0)
         goto err;
 
     SSL_CTX_set_alpn_select_cb(srv->ctx, alpn_select_cb, srv);
@@ -556,3 +558,9 @@ int ossl_quic_tserver_set_max_early_data(QUIC_TSERVER *srv,
 {
     return SSL_set_max_early_data(srv->tls, max_early_data);
 }
+
+void ossl_quic_tserver_set_psk_find_session_cb(QUIC_TSERVER *srv,
+                                               SSL_psk_find_session_cb_func cb)
+{
+    SSL_set_psk_find_session_callback(srv->tls, cb);
+}
index 94a170b9a52f440a2a0364880938195215359392..0b1e56f064ca6ab11e0bf8ec8c27835fcf2c0290 100644 (file)
@@ -1247,3 +1247,41 @@ void shutdown_ssl_connection(SSL *serverssl, SSL *clientssl)
     SSL_free(serverssl);
     SSL_free(clientssl);
 }
+
+SSL_SESSION *create_a_psk(SSL *ssl, size_t mdsize)
+{
+    const SSL_CIPHER *cipher = NULL;
+    const unsigned char key[SHA384_DIGEST_LENGTH] = {
+        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a,
+        0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
+        0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
+        0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b,
+        0x2c, 0x2d, 0x2e, 0x2f
+    };
+    SSL_SESSION *sess = NULL;
+
+    if (mdsize == SHA384_DIGEST_LENGTH) {
+        cipher = SSL_CIPHER_find(ssl, TLS13_AES_256_GCM_SHA384_BYTES);
+    } else if (mdsize == SHA256_DIGEST_LENGTH) {
+        /*
+         * Any ciphersuite using SHA256 will do - it will be compatible with
+         * the actual ciphersuite selected as long as it too is based on SHA256
+         */
+        cipher = SSL_CIPHER_find(ssl, TLS13_AES_128_GCM_SHA256_BYTES);
+    } else {
+        /* Should not happen */
+        return NULL;
+    }
+    sess = SSL_SESSION_new();
+    if (!TEST_ptr(sess)
+            || !TEST_ptr(cipher)
+            || !TEST_true(SSL_SESSION_set1_master_key(sess, key, mdsize))
+            || !TEST_true(SSL_SESSION_set_cipher(sess, cipher))
+            || !TEST_true(
+                    SSL_SESSION_set_protocol_version(sess,
+                                                     TLS1_3_VERSION))) {
+        SSL_SESSION_free(sess);
+        return NULL;
+    }
+    return sess;
+}
index c8dcb8a82d28d269d1b4bda47bc3f2cf4108277b..c513769ddd956027957dd7022af260a84e6e47ba 100644 (file)
 
 # include <openssl/ssl.h>
 
+#define TLS13_AES_128_GCM_SHA256_BYTES  ((const unsigned char *)"\x13\x01")
+#define TLS13_AES_256_GCM_SHA384_BYTES  ((const unsigned char *)"\x13\x02")
+#define TLS13_CHACHA20_POLY1305_SHA256_BYTES ((const unsigned char *)"\x13\x03")
+#define TLS13_AES_128_CCM_SHA256_BYTES ((const unsigned char *)"\x13\x04")
+#define TLS13_AES_128_CCM_8_SHA256_BYTES ((const unsigned char *)"\x13\05")
+
 int create_ssl_ctx_pair(OSSL_LIB_CTX *libctx, const SSL_METHOD *sm,
                         const SSL_METHOD *cm, int min_proto_version,
                         int max_proto_version, SSL_CTX **sctx, SSL_CTX **cctx,
@@ -60,4 +66,6 @@ typedef struct mempacket_st MEMPACKET;
 
 DEFINE_STACK_OF(MEMPACKET)
 
+SSL_SESSION *create_a_psk(SSL *ssl, size_t mdsize);
+
 #endif /* OSSL_TEST_SSLTESTLIB_H */
index 87c134eb88bf36527282f3443844984cfe674465..a24946a649ef45a4e6d5dbf75bdb5249f53977b4 100644 (file)
@@ -1061,6 +1061,92 @@ static int test_non_io_retry(int idx)
     return testresult;
 }
 
+static int use_session_cb_cnt = 0;
+static int find_session_cb_cnt = 0;
+static const char *pskid = "Identity";
+static SSL_SESSION *serverpsk = NULL, *clientpsk = NULL;
+
+static int use_session_cb(SSL *ssl, const EVP_MD *md, const unsigned char **id,
+                          size_t *idlen, SSL_SESSION **sess)
+{
+    use_session_cb_cnt++;
+
+    if (clientpsk == NULL)
+        return 0;
+
+    SSL_SESSION_up_ref(clientpsk);
+
+    *sess = clientpsk;
+    *id = (const unsigned char *)pskid;
+    *idlen = strlen(pskid);
+
+    return 1;
+}
+
+static int find_session_cb(SSL *ssl, const unsigned char *identity,
+                           size_t identity_len, SSL_SESSION **sess)
+{
+    find_session_cb_cnt++;
+
+    if (serverpsk == NULL)
+        return 0;
+
+    /* Identity should match that set by the client */
+    if (strlen(pskid) != identity_len
+            || strncmp(pskid, (const char *)identity, identity_len) != 0)
+        return 0;
+
+    SSL_SESSION_up_ref(serverpsk);
+    *sess = serverpsk;
+
+    return 1;
+}
+
+static int test_quic_psk(void)
+{
+    SSL_CTX *cctx = SSL_CTX_new_ex(libctx, NULL, OSSL_QUIC_client_method());
+    SSL *clientquic = NULL;
+    QUIC_TSERVER *qtserv = NULL;
+    int testresult = 0;
+
+    if (!TEST_ptr(cctx)
+               /* No cert or private key for the server, i.e. PSK only */
+            || !TEST_true(qtest_create_quic_objects(libctx, cctx, NULL, NULL,
+                                                    NULL, 0, &qtserv,
+                                                    &clientquic, NULL)))
+        goto end;
+
+    SSL_set_psk_use_session_callback(clientquic, use_session_cb);
+    ossl_quic_tserver_set_psk_find_session_cb(qtserv, find_session_cb);
+    use_session_cb_cnt = 0;
+    find_session_cb_cnt = 0;
+
+    clientpsk = serverpsk = create_a_psk(clientquic, SHA384_DIGEST_LENGTH);
+    if (!TEST_ptr(clientpsk))
+        goto end;
+    /* We already had one ref. Add another one */
+    SSL_SESSION_up_ref(clientpsk);
+
+    if (!TEST_true(qtest_create_quic_connection(qtserv, clientquic))
+            || !TEST_int_eq(1, find_session_cb_cnt)
+            || !TEST_int_eq(1, use_session_cb_cnt)
+               /* Check that we actually used the PSK */
+            || !TEST_true(SSL_session_reused(clientquic)))
+        goto end;
+
+    testresult = 1;
+
+ end:
+    SSL_free(clientquic);
+    ossl_quic_tserver_free(qtserv);
+    SSL_CTX_free(cctx);
+    SSL_SESSION_free(clientpsk);
+    SSL_SESSION_free(serverpsk);
+    clientpsk = serverpsk = NULL;
+
+    return testresult;
+}
+
 OPT_TEST_DECLARE_USAGE("provider config certsdir datadir\n")
 
 int setup_tests(void)
@@ -1131,6 +1217,7 @@ int setup_tests(void)
     ADD_TEST(test_back_pressure);
     ADD_TEST(test_multiple_dgrams);
     ADD_ALL_TESTS(test_non_io_retry, 2);
+    ADD_TEST(test_quic_psk);
     return 1;
  err:
     cleanup_tests();
index 756675c1dce6aa4ce7314ddafd9cef20985220db..ec29157007c74764d0ae7fd82ba050b01c3c3269 100644 (file)
@@ -77,8 +77,6 @@ static int find_session_cb(SSL *ssl, const unsigned char *identity,
 
 static int use_session_cb_cnt = 0;
 static int find_session_cb_cnt = 0;
-
-static SSL_SESSION *create_a_psk(SSL *ssl, size_t mdsize);
 #endif
 
 static char *certsdir = NULL;
@@ -3385,51 +3383,6 @@ static unsigned int psk_server_cb(SSL *ssl, const char *identity,
 #define MSG6    "test"
 #define MSG7    "message."
 
-#define TLS13_AES_128_GCM_SHA256_BYTES  ((const unsigned char *)"\x13\x01")
-#define TLS13_AES_256_GCM_SHA384_BYTES  ((const unsigned char *)"\x13\x02")
-#define TLS13_CHACHA20_POLY1305_SHA256_BYTES ((const unsigned char *)"\x13\x03")
-#define TLS13_AES_128_CCM_SHA256_BYTES ((const unsigned char *)"\x13\x04")
-#define TLS13_AES_128_CCM_8_SHA256_BYTES ((const unsigned char *)"\x13\05")
-
-
-static SSL_SESSION *create_a_psk(SSL *ssl, size_t mdsize)
-{
-    const SSL_CIPHER *cipher = NULL;
-    const unsigned char key[] = {
-        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a,
-        0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
-        0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
-        0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b,
-        0x2c, 0x2d, 0x2e, 0x2f /* SHA384_DIGEST_LENGTH bytes */
-    };
-    SSL_SESSION *sess = NULL;
-
-    if (mdsize == SHA384_DIGEST_LENGTH) {
-        cipher = SSL_CIPHER_find(ssl, TLS13_AES_256_GCM_SHA384_BYTES);
-    } else if (mdsize == SHA256_DIGEST_LENGTH) {
-        /*
-         * Any ciphersuite using SHA256 will do - it will be compatible with
-         * the actual ciphersuite selected as long as it too is based on SHA256
-         */
-        cipher = SSL_CIPHER_find(ssl, TLS13_AES_128_GCM_SHA256_BYTES);
-    } else {
-        /* Should not happen */
-        return NULL;
-    }
-    sess = SSL_SESSION_new();
-    if (!TEST_ptr(sess)
-            || !TEST_ptr(cipher)
-            || !TEST_true(SSL_SESSION_set1_master_key(sess, key, mdsize))
-            || !TEST_true(SSL_SESSION_set_cipher(sess, cipher))
-            || !TEST_true(
-                    SSL_SESSION_set_protocol_version(sess,
-                                                     TLS1_3_VERSION))) {
-        SSL_SESSION_free(sess);
-        return NULL;
-    }
-    return sess;
-}
-
 static int artificial_ticket_time = 0;
 
 static int ed_gen_cb(SSL *s, void *arg)