Add a qtest_check_server_transport_err helper function
[openssl.git] / test / helpers / quictestlib.c
index 34672a5913b05dc26f62d1230daaf2ba7aa8cadb..ca3719c267553ed2b6780a71799cd2c876e0b529 100644 (file)
@@ -8,10 +8,13 @@
  */
 
 #include <assert.h>
+#include <openssl/bio.h>
 #include "quictestlib.h"
 #include "../testutil.h"
 #include "internal/quic_wire_pkt.h"
 #include "internal/quic_record_tx.h"
+#include "internal/quic_error.h"
+#include "internal/packet.h"
 
 #define GROWTH_ALLOWANCE 1024
 
@@ -27,8 +30,37 @@ struct ossl_quic_fault {
     size_t pplainbuf_alloc;
     ossl_quic_fault_on_packet_plain_cb pplaincb;
     void *pplaincbarg;
+
+    /* Handshake message mutations */
+    /* Handshake message buffer */
+    unsigned char *handbuf;
+    /* Allocated size of the handshake message buffer */
+    size_t handbufalloc;
+    /* Actual length of the handshake message */
+    size_t handbuflen;
+    ossl_quic_fault_on_handshake_cb handshakecb;
+    void *handshakecbarg;
+    ossl_quic_fault_on_enc_ext_cb encextcb;
+    void *encextcbarg;
+
+    /* Cipher packet mutations */
+    ossl_quic_fault_on_packet_cipher_cb pciphercb;
+    void *pciphercbarg;
+
+    /* Datagram mutations */
+    ossl_quic_fault_on_datagram_cb datagramcb;
+    void *datagramcbarg;
+    /* The currently processed message */
+    BIO_MSG msg;
+    /* Allocated size of msg data buffer */
+    size_t msgalloc;
 };
 
+static void packet_plain_finish(void *arg);
+static void handshake_finish(void *arg);
+
+static BIO_METHOD *get_bio_method(void);
+
 int qtest_create_quic_objects(SSL_CTX *clientctx, char *certfile, char *keyfile,
                               QUIC_TSERVER **qtserv, SSL **cssl,
                               OSSL_QUIC_FAULT **fault)
@@ -36,7 +68,7 @@ int qtest_create_quic_objects(SSL_CTX *clientctx, char *certfile, char *keyfile,
     /* ALPN value as recognised by QUIC_TSERVER */
     unsigned char alpn[] = { 8, 'o', 's', 's', 'l', 't', 'e', 's', 't' };
     QUIC_TSERVER_ARGS tserver_args = {0};
-    BIO *bio1 = NULL, *bio2 = NULL;
+    BIO *cbio = NULL, *sbio = NULL, *fisbio = NULL;
     BIO_ADDR *peeraddr = NULL;
     struct in_addr ina = {0};
 
@@ -54,14 +86,14 @@ int qtest_create_quic_objects(SSL_CTX *clientctx, char *certfile, char *keyfile,
     if (!TEST_false(SSL_set_alpn_protos(*cssl, alpn, sizeof(alpn))))
         goto err;
 
-    if (!TEST_true(BIO_new_bio_dgram_pair(&bio1, 0, &bio2, 0)))
+    if (!TEST_true(BIO_new_bio_dgram_pair(&cbio, 0, &sbio, 0)))
         goto err;
 
-    if (!TEST_true(BIO_dgram_set_caps(bio1, BIO_DGRAM_CAP_HANDLES_DST_ADDR))
-            || !TEST_true(BIO_dgram_set_caps(bio2, BIO_DGRAM_CAP_HANDLES_DST_ADDR)))
+    if (!TEST_true(BIO_dgram_set_caps(cbio, BIO_DGRAM_CAP_HANDLES_DST_ADDR))
+            || !TEST_true(BIO_dgram_set_caps(sbio, BIO_DGRAM_CAP_HANDLES_DST_ADDR)))
         goto err;
 
-    SSL_set_bio(*cssl, bio1, bio1);
+    SSL_set_bio(*cssl, cbio, cbio);
 
     if (!TEST_ptr(peeraddr = BIO_ADDR_new()))
         goto err;
@@ -74,36 +106,43 @@ int qtest_create_quic_objects(SSL_CTX *clientctx, char *certfile, char *keyfile,
     if (!TEST_true(SSL_set_initial_peer_addr(*cssl, peeraddr)))
         goto err;
 
-    /* 2 refs are passed for bio2 */
-    if (!BIO_up_ref(bio2))
+    if (fault != NULL) {
+        *fault = OPENSSL_zalloc(sizeof(**fault));
+        if (*fault == NULL)
+            goto err;
+    }
+
+    fisbio = BIO_new(get_bio_method());
+    if (!TEST_ptr(fisbio))
+        goto err;
+
+    BIO_set_data(fisbio, fault == NULL ? NULL : *fault);
+
+    if (!TEST_ptr(BIO_push(fisbio, sbio)))
         goto err;
-    tserver_args.net_rbio = bio2;
-    tserver_args.net_wbio = bio2;
+
+    tserver_args.net_rbio = sbio;
+    tserver_args.net_wbio = fisbio;
 
     if (!TEST_ptr(*qtserv = ossl_quic_tserver_new(&tserver_args, certfile,
-                                                  keyfile))) {
-        /* We hold 2 refs to bio2 at the moment */
-        BIO_free(bio2);
+                                                  keyfile)))
         goto err;
-    }
-    /* Ownership of bio2 is now held by *qtserv */
-    bio2 = NULL;
 
-    if (fault != NULL) {
-        *fault = OPENSSL_zalloc(sizeof(**fault));
-        if (*fault == NULL)
-            goto err;
+    /* Ownership of fisbio and sbio is now held by *qtserv */
+    sbio = NULL;
+    fisbio = NULL;
 
+    if (fault != NULL)
         (*fault)->qtserv = *qtserv;
-    }
 
     BIO_ADDR_free(peeraddr);
 
     return 1;
  err:
     BIO_ADDR_free(peeraddr);
-    BIO_free(bio1);
-    BIO_free(bio2);
+    BIO_free(cbio);
+    BIO_free(fisbio);
+    BIO_free(sbio);
     SSL_free(*cssl);
     ossl_quic_tserver_free(*qtserv);
     if (fault != NULL)
@@ -140,13 +179,13 @@ int qtest_create_quic_connection(QUIC_TSERVER *qtserv, SSL *clientssl)
          * the communications and don't expect network delays. This shouldn't
          * be done in a real application.
          */
-        if (!clienterr)
+        if (!clienterr && retc <= 0)
             SSL_tick(clientssl);
-        if (!servererr) {
+        if (!servererr && rets <= 0) {
             ossl_quic_tserver_tick(qtserv);
             servererr = ossl_quic_tserver_is_term_any(qtserv, NULL);
-            if (!servererr && !rets)
-                rets = ossl_quic_tserver_is_connected(qtserv);
+            if (!servererr)
+                rets = ossl_quic_tserver_is_handshake_confirmed(qtserv);
         }
 
         if (clienterr && servererr)
@@ -156,18 +195,47 @@ int qtest_create_quic_connection(QUIC_TSERVER *qtserv, SSL *clientssl)
             TEST_info("No progress made");
             goto err;
         }
-    } while (retc <=0 || rets <= 0);
+    } while ((retc <= 0 && !clienterr) || (rets <= 0 && !servererr));
 
-    ret = 1;
+    if (!clienterr && !servererr)
+        ret = 1;
  err:
     return ret;
 }
 
+int qtest_check_server_transport_err(QUIC_TSERVER *qtserv, uint64_t code)
+{
+    QUIC_TERMINATE_CAUSE cause;
+
+    ossl_quic_tserver_tick(qtserv);
+
+    /*
+     * Check that the server has closed with the specified code from the client
+     */
+    if (!TEST_true(ossl_quic_tserver_is_term_any(qtserv)))
+        return 0;
+
+    cause = ossl_quic_tserver_get_terminate_cause(qtserv);
+    if  (!TEST_true(cause.remote)
+            || !TEST_uint64_t_eq(cause.error_code, code))
+        return 0;
+
+    return 1;
+}
+
+int qtest_check_server_protocol_err(QUIC_TSERVER *qtserv)
+{
+    return qtest_check_server_transport_err(qtserv, QUIC_ERR_PROTOCOL_VIOLATION);
+}
+
 void ossl_quic_fault_free(OSSL_QUIC_FAULT *fault)
 {
     if (fault == NULL)
         return;
 
+    packet_plain_finish(fault);
+    handshake_finish(fault);
+
     OPENSSL_free(fault);
 }
 
@@ -231,6 +299,7 @@ static void packet_plain_finish(void *arg)
     OPENSSL_free((unsigned char *)fault->pplainio.buf);
     fault->pplainio.buf_len = 0;
     fault->pplainbuf_alloc = 0;
+    fault->pplainio.buf = NULL;
 }
 
 int ossl_quic_fault_set_packet_plain_listener(OSSL_QUIC_FAULT *fault,
@@ -240,8 +309,10 @@ int ossl_quic_fault_set_packet_plain_listener(OSSL_QUIC_FAULT *fault,
     fault->pplaincb = pplaincb;
     fault->pplaincbarg = pplaincbarg;
 
-    return ossl_quic_tserver_set_mutator(fault->qtserv, packet_plain_mutate,
-                                         packet_plain_finish, fault);
+    return ossl_quic_tserver_set_plain_packet_mutator(fault->qtserv,
+                                                      packet_plain_mutate,
+                                                      packet_plain_finish,
+                                                      fault);
 }
 
 /* To be called from a packet_plain_listener callback */
@@ -275,3 +346,369 @@ int ossl_quic_fault_resize_plain_packet(OSSL_QUIC_FAULT *fault, size_t newlen)
 
     return 1;
 }
+
+static int handshake_mutate(const unsigned char *msgin, size_t msginlen,
+                            unsigned char **msgout, size_t *msgoutlen,
+                            void *arg)
+{
+    OSSL_QUIC_FAULT *fault = arg;
+    unsigned char *buf;
+    unsigned long payloadlen;
+    unsigned int msgtype;
+    PACKET pkt;
+
+    buf = OPENSSL_malloc(msginlen + GROWTH_ALLOWANCE);
+    if (buf == NULL)
+        return 0;
+
+    fault->handbuf = buf;
+    fault->handbuflen = msginlen;
+    fault->handbufalloc = msginlen + GROWTH_ALLOWANCE;
+    memcpy(buf, msgin, msginlen);
+
+    if (!PACKET_buf_init(&pkt, buf, msginlen)
+            || !PACKET_get_1(&pkt, &msgtype)
+            || !PACKET_get_net_3(&pkt, &payloadlen)
+            || PACKET_remaining(&pkt) != payloadlen)
+        return 0;
+
+    /* Parse specific message types */
+    switch (msgtype) {
+    case SSL3_MT_ENCRYPTED_EXTENSIONS:
+    {
+        OSSL_QF_ENCRYPTED_EXTENSIONS ee;
+
+        if (fault->encextcb == NULL)
+            break;
+
+        /*
+         * The EncryptedExtensions message is very simple. It just has an
+         * extensions block in it and nothing else.
+         */
+        ee.extensions = (unsigned char *)PACKET_data(&pkt);
+        ee.extensionslen = payloadlen;
+        if (!fault->encextcb(fault, &ee, payloadlen, fault->encextcbarg))
+            return 0;
+    }
+
+    default:
+        /* No specific handlers for these message types yet */
+        break;
+    }
+
+    if (fault->handshakecb != NULL
+            && !fault->handshakecb(fault, buf, fault->handbuflen,
+                                   fault->handshakecbarg))
+        return 0;
+
+    *msgout = buf;
+    *msgoutlen = fault->handbuflen;
+
+    return 1;
+}
+
+static void handshake_finish(void *arg)
+{
+    OSSL_QUIC_FAULT *fault = arg;
+
+    OPENSSL_free(fault->handbuf);
+    fault->handbuf = NULL;
+}
+
+int ossl_quic_fault_set_handshake_listener(OSSL_QUIC_FAULT *fault,
+                                           ossl_quic_fault_on_handshake_cb handshakecb,
+                                           void *handshakecbarg)
+{
+    fault->handshakecb = handshakecb;
+    fault->handshakecbarg = handshakecbarg;
+
+    return ossl_quic_tserver_set_handshake_mutator(fault->qtserv,
+                                                   handshake_mutate,
+                                                   handshake_finish,
+                                                   fault);
+}
+
+int ossl_quic_fault_set_hand_enc_ext_listener(OSSL_QUIC_FAULT *fault,
+                                              ossl_quic_fault_on_enc_ext_cb encextcb,
+                                              void *encextcbarg)
+{
+    fault->encextcb = encextcb;
+    fault->encextcbarg = encextcbarg;
+
+    return ossl_quic_tserver_set_handshake_mutator(fault->qtserv,
+                                                   handshake_mutate,
+                                                   handshake_finish,
+                                                   fault);
+}
+
+/* To be called from a handshake_listener callback */
+int ossl_quic_fault_resize_handshake(OSSL_QUIC_FAULT *fault, size_t newlen)
+{
+    unsigned char *buf;
+    size_t oldlen = fault->handbuflen;
+
+    /*
+     * Alloc'd size should always be non-zero, so if this fails we've been
+     * incorrectly called
+     */
+    if (fault->handbufalloc == 0)
+        return 0;
+
+    if (newlen > fault->handbufalloc) {
+        /* This exceeds our growth allowance. Fail */
+        return 0;
+    }
+
+    buf = (unsigned char *)fault->handbuf;
+
+    if (newlen > oldlen) {
+        /* Extend packet with 0 bytes */
+        memset(buf + oldlen, 0, newlen - oldlen);
+    } /* else we're truncating or staying the same */
+
+    fault->handbuflen = newlen;
+    return 1;
+}
+
+/* To be called from message specific listener callbacks */
+int ossl_quic_fault_resize_message(OSSL_QUIC_FAULT *fault, size_t newlen)
+{
+    /* First resize the underlying message */
+    if (!ossl_quic_fault_resize_handshake(fault, newlen + SSL3_HM_HEADER_LENGTH))
+        return 0;
+
+    /* Fixup the handshake message header */
+    fault->handbuf[1] = (unsigned char)((newlen >> 16) & 0xff);
+    fault->handbuf[2] = (unsigned char)((newlen >>  8) & 0xff);
+    fault->handbuf[3] = (unsigned char)((newlen      ) & 0xff);
+
+    return 1;
+}
+
+int ossl_quic_fault_delete_extension(OSSL_QUIC_FAULT *fault,
+                                     unsigned int exttype, unsigned char *ext,
+                                     size_t *extlen)
+{
+    PACKET pkt, sub, subext;
+    unsigned int type;
+    const unsigned char *start, *end;
+    size_t newlen;
+    size_t msglen = fault->handbuflen;
+
+    if (!PACKET_buf_init(&pkt, ext, *extlen))
+        return 0;
+
+    /* Extension block starts with 2 bytes for extension block length */
+    if (!PACKET_as_length_prefixed_2(&pkt, &sub))
+        return 0;
+
+    do {
+        start = PACKET_data(&sub);
+        if (!PACKET_get_net_2(&sub, &type)
+                || !PACKET_get_length_prefixed_2(&sub, &subext))
+            return 0;
+    } while (type != exttype);
+
+    /* Found it */
+    end = PACKET_data(&sub);
+
+    /*
+     * If we're not the last extension we need to move the rest earlier. The
+     * cast below is safe because we own the underlying buffer and we're no
+     * longer making PACKET calls.
+     */
+    if (end < ext + *extlen)
+        memmove((unsigned char *)start, end, end - start);
+
+    /*
+     * Calculate new extensions payload length =
+     * Original length
+     * - 2 extension block length bytes
+     * - length of removed extension
+     */
+    newlen = *extlen - 2 - (end - start);
+
+    /* Fixup the length bytes for the extension block */
+    ext[0] = (unsigned char)((newlen >> 8) & 0xff);
+    ext[1] = (unsigned char)((newlen     ) & 0xff);
+
+    /*
+     * Length of the whole extension block is the new payload length plus the
+     * 2 bytes for the length
+     */
+    *extlen = newlen + 2;
+
+    /* We can now resize the message */
+    if ((size_t)(end - start) + SSL3_HM_HEADER_LENGTH > msglen)
+        return 0; /* Should not happen */
+    msglen -= (end - start) + SSL3_HM_HEADER_LENGTH;
+    if (!ossl_quic_fault_resize_message(fault, msglen))
+        return 0;
+
+    return 1;
+}
+
+#define BIO_TYPE_CIPHER_PACKET_FILTER  (0x80 | BIO_TYPE_FILTER)
+
+static BIO_METHOD *pcipherbiometh = NULL;
+
+# define BIO_MSG_N(array, stride, n) (*(BIO_MSG *)((char *)(array) + (n)*(stride)))
+
+static int pcipher_sendmmsg(BIO *b, BIO_MSG *msg, size_t stride,
+                            size_t num_msg, uint64_t flags,
+                            size_t *num_processed)
+{
+    OSSL_QUIC_FAULT *fault;
+    BIO *next = BIO_next(b);
+    ossl_ssize_t ret = 0;
+    size_t i = 0, tmpnump;
+    QUIC_PKT_HDR hdr;
+    PACKET pkt;
+    unsigned char *tmpdata;
+
+    if (next == NULL)
+        return 0;
+
+    fault = BIO_get_data(b);
+    if (fault == NULL
+            || (fault->pciphercb == NULL && fault->datagramcb == NULL))
+        return BIO_sendmmsg(next, msg, stride, num_msg, flags, num_processed);
+
+    if (num_msg == 0) {
+        *num_processed = 0;
+        return 1;
+    }
+
+    for (i = 0; i < num_msg; ++i) {
+        fault->msg = BIO_MSG_N(msg, stride, i);
+
+        /* Take a copy of the data so that callbacks can modify it */
+        tmpdata = OPENSSL_malloc(fault->msg.data_len + GROWTH_ALLOWANCE);
+        if (tmpdata == NULL)
+            return 0;
+        memcpy(tmpdata, fault->msg.data, fault->msg.data_len);
+        fault->msg.data = tmpdata;
+        fault->msgalloc = fault->msg.data_len + GROWTH_ALLOWANCE;
+
+        if (fault->pciphercb != NULL) {
+            if (!PACKET_buf_init(&pkt, fault->msg.data, fault->msg.data_len))
+                return 0;
+
+            do {
+                if (!ossl_quic_wire_decode_pkt_hdr(&pkt,
+                        0 /* TODO(QUIC): Not sure how this should be set*/, 1,
+                        &hdr, NULL))
+                    goto out;
+
+                /*
+                 * hdr.data is const - but its our buffer so casting away the
+                 * const is safe
+                 */
+                if (!fault->pciphercb(fault, &hdr, (unsigned char *)hdr.data,
+                                    hdr.len, fault->pciphercbarg))
+                    goto out;
+
+                /*
+                 * TODO(QUIC): At the moment modifications to hdr by the callback
+                 * are ignored. We might need to rewrite the QUIC header to
+                 * enable tests to change this. We also don't yet have a
+                 * mechanism for the callback to change the encrypted data
+                 * length. It's not clear if that's needed or not.
+                 */
+            } while (PACKET_remaining(&pkt) > 0);
+        }
+
+        if (fault->datagramcb != NULL
+                && !fault->datagramcb(fault, &fault->msg, stride,
+                                      fault->datagramcbarg))
+            goto out;
+
+        if (!BIO_sendmmsg(next, &fault->msg, stride, 1, flags, &tmpnump)) {
+            *num_processed = i;
+            goto out;
+        }
+
+        OPENSSL_free(fault->msg.data);
+        fault->msg.data = NULL;
+        fault->msgalloc = 0;
+    }
+
+    *num_processed = i;
+    ret = 1;
+out:
+    if (i > 0)
+        ret = 1;
+    else
+        ret = 0;
+    OPENSSL_free(fault->msg.data);
+    fault->msg.data = NULL;
+    return ret;
+}
+
+static long pcipher_ctrl(BIO *b, int cmd, long larg, void *parg)
+{
+    BIO *next = BIO_next(b);
+
+    if (next == NULL)
+        return -1;
+
+    return BIO_ctrl(next, cmd, larg, parg);
+}
+
+static BIO_METHOD *get_bio_method(void)
+{
+    BIO_METHOD *tmp;
+
+    if (pcipherbiometh != NULL)
+        return pcipherbiometh;
+
+    tmp = BIO_meth_new(BIO_TYPE_CIPHER_PACKET_FILTER, "Cipher Packet Filter");
+
+    if (!TEST_ptr(tmp))
+        return NULL;
+
+    if (!TEST_true(BIO_meth_set_sendmmsg(tmp, pcipher_sendmmsg))
+            || !TEST_true(BIO_meth_set_ctrl(tmp, pcipher_ctrl)))
+        goto err;
+
+    pcipherbiometh = tmp;
+    tmp = NULL;
+ err:
+    BIO_meth_free(tmp);
+    return pcipherbiometh;
+}
+
+int ossl_quic_fault_set_packet_cipher_listener(OSSL_QUIC_FAULT *fault,
+                                               ossl_quic_fault_on_packet_cipher_cb pciphercb,
+                                               void *pciphercbarg)
+{
+    fault->pciphercb = pciphercb;
+    fault->pciphercbarg = pciphercbarg;
+
+    return 1;
+}
+
+int ossl_quic_fault_set_datagram_listener(OSSL_QUIC_FAULT *fault,
+                                          ossl_quic_fault_on_datagram_cb datagramcb,
+                                          void *datagramcbarg)
+{
+    fault->datagramcb = datagramcb;
+    fault->datagramcbarg = datagramcbarg;
+
+    return 1;
+}
+
+/* To be called from a datagram_listener callback */
+int ossl_quic_fault_resize_datagram(OSSL_QUIC_FAULT *fault, size_t newlen)
+{
+    if (newlen > fault->msgalloc)
+            return 0;
+
+    if (newlen > fault->msg.data_len)
+        memset((unsigned char *)fault->msg.data + fault->msg.data_len, 0,
+                newlen - fault->msg.data_len);
+
+    fault->msg.data_len = newlen;
+
+    return 1;
+}