Do buildtests on our public header files with C++ as well
[openssl.git] / test / ssltestlib.c
index c7689631f1b0d67b9e2efb4b10d8b65d2e845384..8d7ab610fb1bb0fb6bec32d8d3c56634174beb8b 100644 (file)
@@ -1,7 +1,7 @@
 /*
- * Copyright 2016-2018 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2016-2019 The OpenSSL Project Authors. All Rights Reserved.
  *
- * Licensed under the OpenSSL license (the "License").  You may not use
+ * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
  * in the file LICENSE in the source distribution or at
  * https://www.openssl.org/source/license.html
 
 #ifdef OPENSSL_SYS_UNIX
 # include <unistd.h>
+#ifndef OPENSSL_NO_KTLS
+# include <netinet/in.h>
+# include <netinet/in.h>
+# include <arpa/inet.h>
+# include <sys/socket.h>
+# include <unistd.h>
+# include <fcntl.h>
+#endif
 
-static ossl_inline void ossl_sleep(unsigned int millis) {
+static ossl_inline void ossl_sleep(unsigned int millis)
+{
+# ifdef OPENSSL_SYS_VXWORKS
+    struct timespec ts;
+    ts.tv_sec = (long int) (millis / 1000);
+    ts.tv_nsec = (long int) (millis % 1000) * 1000000ul;
+    nanosleep(&ts, NULL);
+# else
     usleep(millis * 1000);
+# endif
 }
 #elif defined(_WIN32)
 # include <windows.h>
 
-static ossl_inline void ossl_sleep(unsigned int millis) {
+static ossl_inline void ossl_sleep(unsigned int millis)
+{
     Sleep(millis);
 }
 #else
 /* Fallback to a busy wait */
-static ossl_inline void ossl_sleep(unsigned int millis) {
+static ossl_inline void ossl_sleep(unsigned int millis)
+{
     struct timeval start, now;
     unsigned int elapsedms;
 
@@ -284,6 +302,7 @@ typedef struct mempacket_test_ctx_st {
     unsigned int noinject;
     unsigned int dropepoch;
     int droprec;
+    int duprec;
 } MEMPACKET_TEST_CTX;
 
 static int mempacket_test_new(BIO *bi);
@@ -426,12 +445,27 @@ int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
                           int type)
 {
     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
-    MEMPACKET *thispkt, *looppkt, *nextpkt;
-    int i;
+    MEMPACKET *thispkt = NULL, *looppkt, *nextpkt, *allpkts[3];
+    int i, duprec;
+    const unsigned char *inu = (const unsigned char *)in;
+    size_t len = ((inu[RECORD_LEN_HI] << 8) | inu[RECORD_LEN_LO])
+                 + DTLS1_RT_HEADER_LENGTH;
 
     if (ctx == NULL)
         return -1;
 
+    if ((size_t)inl < len)
+        return -1;
+
+    if ((size_t)inl == len)
+        duprec = 0;
+    else
+        duprec = ctx->duprec > 0;
+
+    /* We don't support arbitrary injection when duplicating records */
+    if (duprec && pktnum != -1)
+        return -1;
+
     /* We only allow injection before we've started writing any data */
     if (pktnum >= 0) {
         if (ctx->noinject)
@@ -441,25 +475,36 @@ int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
         ctx->noinject = 1;
     }
 
-    if (!TEST_ptr(thispkt = OPENSSL_malloc(sizeof(*thispkt))))
-        return -1;
-    if (!TEST_ptr(thispkt->data = OPENSSL_malloc(inl))) {
-        mempacket_free(thispkt);
-        return -1;
-    }
+    for (i = 0; i < (duprec ? 3 : 1); i++) {
+        if (!TEST_ptr(allpkts[i] = OPENSSL_malloc(sizeof(*thispkt))))
+            goto err;
+        thispkt = allpkts[i];
 
-    memcpy(thispkt->data, in, inl);
-    thispkt->len = inl;
-    thispkt->num = (pktnum >= 0) ? (unsigned int)pktnum : ctx->lastpkt;
-    thispkt->type = type;
+        if (!TEST_ptr(thispkt->data = OPENSSL_malloc(inl)))
+            goto err;
+        /*
+         * If we are duplicating the packet, we duplicate it three times. The
+         * first two times we drop the first record if there are more than one.
+         * In this way we know that libssl will not be able to make progress
+         * until it receives the last packet, and hence will be forced to
+         * buffer these records.
+         */
+        if (duprec && i != 2) {
+            memcpy(thispkt->data, in + len, inl - len);
+            thispkt->len = inl - len;
+        } else {
+            memcpy(thispkt->data, in, inl);
+            thispkt->len = inl;
+        }
+        thispkt->num = (pktnum >= 0) ? (unsigned int)pktnum : ctx->lastpkt + i;
+        thispkt->type = type;
+    }
 
     for(i = 0; (looppkt = sk_MEMPACKET_value(ctx->pkts, i)) != NULL; i++) {
         /* Check if we found the right place to insert this packet */
         if (looppkt->num > thispkt->num) {
-            if (sk_MEMPACKET_insert(ctx->pkts, thispkt, i) == 0) {
-                mempacket_free(thispkt);
-                return -1;
-            }
+            if (sk_MEMPACKET_insert(ctx->pkts, thispkt, i) == 0)
+                goto err;
             /* If we're doing up front injection then we're done */
             if (pktnum >= 0)
                 return inl;
@@ -480,7 +525,7 @@ int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
         } else if (looppkt->num == thispkt->num) {
             if (!ctx->noinject) {
                 /* We injected two packets with the same packet number! */
-                return -1;
+                goto err;
             }
             ctx->lastpkt++;
             thispkt->num++;
@@ -490,15 +535,21 @@ int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
      * We didn't find any packets with a packet number equal to or greater than
      * this one, so we just add it onto the end
      */
-    if (!sk_MEMPACKET_push(ctx->pkts, thispkt)) {
-        mempacket_free(thispkt);
-        return -1;
-    }
+    for (i = 0; i < (duprec ? 3 : 1); i++) {
+        thispkt = allpkts[i];
+        if (!sk_MEMPACKET_push(ctx->pkts, thispkt))
+            goto err;
 
-    if (pktnum < 0)
-        ctx->lastpkt++;
+        if (pktnum < 0)
+            ctx->lastpkt++;
+    }
 
     return inl;
+
+ err:
+    for (i = 0; i < (ctx->duprec > 0 ? 3 : 1); i++)
+        mempacket_free(allpkts[i]);
+    return -1;
 }
 
 static int mempacket_test_write(BIO *bio, const char *in, int inl)
@@ -544,6 +595,9 @@ static long mempacket_test_ctrl(BIO *bio, int cmd, long num, void *ptr)
     case MEMPACKET_CTRL_GET_DROP_REC:
         ret = ctx->droprec;
         break;
+    case MEMPACKET_CTRL_SET_DUPLICATE_REC:
+        ctx->duprec = (int)num;
+        break;
     case BIO_CTRL_RESET:
     case BIO_CTRL_DUP:
     case BIO_CTRL_PUSH:
@@ -621,6 +675,114 @@ int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,
 
 #define MAXLOOPS    1000000
 
+#if !defined(OPENSSL_NO_KTLS) && !defined(OPENSSL_NO_SOCK)
+static int set_nb(int fd)
+{
+    int flags;
+
+    flags = fcntl(fd,F_GETFL,0);
+    if (flags == -1)
+        return flags;
+    flags = fcntl(fd, F_SETFL, flags | O_NONBLOCK);
+    return flags;
+}
+
+int create_test_sockets(int *cfd, int *sfd)
+{
+    struct sockaddr_in sin;
+    const char *host = "127.0.0.1";
+    int cfd_connected = 0, ret = 0;
+    socklen_t slen = sizeof(sin);
+    int afd = -1;
+
+    *cfd = -1;
+    *sfd = -1;
+
+    memset ((char *) &sin, 0, sizeof(sin));
+    sin.sin_family = AF_INET;
+    sin.sin_addr.s_addr = inet_addr(host);
+
+    afd = socket(AF_INET, SOCK_STREAM, 0);
+    if (afd < 0)
+        return 0;
+
+    if (bind(afd, (struct sockaddr*)&sin, sizeof(sin)) < 0)
+        goto out;
+
+    if (getsockname(afd, (struct sockaddr*)&sin, &slen) < 0)
+        goto out;
+
+    if (listen(afd, 1) < 0)
+        goto out;
+
+    *cfd = socket(AF_INET, SOCK_STREAM, 0);
+    if (*cfd < 0)
+        goto out;
+
+    if (set_nb(afd) == -1)
+        goto out;
+
+    while (*sfd == -1 || !cfd_connected ) {
+        *sfd = accept(afd, NULL, 0);
+        if (*sfd == -1 && errno != EAGAIN)
+            goto out;
+
+        if (!cfd_connected && connect(*cfd, (struct sockaddr*)&sin, sizeof(sin)) < 0)
+            goto out;
+        else
+            cfd_connected = 1;
+    }
+
+    if (set_nb(*cfd) == -1 || set_nb(*sfd) == -1)
+        goto out;
+    ret = 1;
+    goto success;
+
+out:
+        if (*cfd != -1)
+            close(*cfd);
+        if (*sfd != -1)
+            close(*sfd);
+success:
+        if (afd != -1)
+            close(afd);
+    return ret;
+}
+
+int create_ssl_objects2(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
+                          SSL **cssl, int sfd, int cfd)
+{
+    SSL *serverssl = NULL, *clientssl = NULL;
+    BIO *s_to_c_bio = NULL, *c_to_s_bio = NULL;
+
+    if (*sssl != NULL)
+        serverssl = *sssl;
+    else if (!TEST_ptr(serverssl = SSL_new(serverctx)))
+        goto error;
+    if (*cssl != NULL)
+        clientssl = *cssl;
+    else if (!TEST_ptr(clientssl = SSL_new(clientctx)))
+        goto error;
+
+    if (!TEST_ptr(s_to_c_bio = BIO_new_socket(sfd, BIO_NOCLOSE))
+            || !TEST_ptr(c_to_s_bio = BIO_new_socket(cfd, BIO_NOCLOSE)))
+        goto error;
+
+    SSL_set_bio(clientssl, c_to_s_bio, c_to_s_bio);
+    SSL_set_bio(serverssl, s_to_c_bio, s_to_c_bio);
+    *sssl = serverssl;
+    *cssl = clientssl;
+    return 1;
+
+ error:
+    SSL_free(serverssl);
+    SSL_free(clientssl);
+    BIO_free(s_to_c_bio);
+    BIO_free(c_to_s_bio);
+    return 0;
+}
+#endif
+
 /*
  * NOTE: Transfers control of the BIOs - this function will free them on error
  */
@@ -680,12 +842,18 @@ int create_ssl_objects(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
     return 0;
 }
 
-int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
+/*
+ * Create an SSL connection, but does not ready any post-handshake
+ * NewSessionTicket messages.
+ * If |read| is set and we're using DTLS then we will attempt to SSL_read on
+ * the connection once we've completed one half of it, to ensure any retransmits
+ * get triggered.
+ */
+int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want,
+                               int read)
 {
     int retc = -1, rets = -1, err, abortctr = 0;
     int clienterr = 0, servererr = 0;
-    unsigned char buf;
-    size_t readbytes;
     int isdtls = SSL_is_dtls(serverssl);
 
     do {
@@ -710,7 +878,9 @@ int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
                 err = SSL_get_error(serverssl, rets);
         }
 
-        if (!servererr && rets <= 0 && err != SSL_ERROR_WANT_READ) {
+        if (!servererr && rets <= 0
+                && err != SSL_ERROR_WANT_READ
+                && err != SSL_ERROR_WANT_X509_LOOKUP) {
             TEST_info("SSL_accept() failed %d, %d", rets, err);
             servererr = 1;
         }
@@ -718,11 +888,24 @@ int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
             return 0;
         if (clienterr && servererr)
             return 0;
-        if (isdtls) {
-            if (rets > 0 && retc <= 0)
-                DTLSv1_handle_timeout(serverssl);
-            if (retc > 0 && rets <= 0)
-                DTLSv1_handle_timeout(clientssl);
+        if (isdtls && read) {
+            unsigned char buf[20];
+
+            /* Trigger any retransmits that may be appropriate */
+            if (rets > 0 && retc <= 0) {
+                if (SSL_read(serverssl, buf, sizeof(buf)) > 0) {
+                    /* We don't expect this to succeed! */
+                    TEST_info("Unexpected SSL_read() success!");
+                    return 0;
+                }
+            }
+            if (retc > 0 && rets <= 0) {
+                if (SSL_read(clientssl, buf, sizeof(buf)) > 0) {
+                    /* We don't expect this to succeed! */
+                    TEST_info("Unexpected SSL_read() success!");
+                    return 0;
+                }
+            }
         }
         if (++abortctr == MAXLOOPS) {
             TEST_info("No progress made");
@@ -738,16 +921,35 @@ int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
         }
     } while (retc <=0 || rets <= 0);
 
+    return 1;
+}
+
+/*
+ * Create an SSL connection including any post handshake NewSessionTicket
+ * messages.
+ */
+int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
+{
+    int i;
+    unsigned char buf;
+    size_t readbytes;
+
+    if (!create_bare_ssl_connection(serverssl, clientssl, want, 1))
+        return 0;
+
     /*
      * We attempt to read some data on the client side which we expect to fail.
      * This will ensure we have received the NewSessionTicket in TLSv1.3 where
-     * appropriate.
+     * appropriate. We do this twice because there are 2 NewSesionTickets.
      */
-    if (SSL_read_ex(clientssl, &buf, sizeof(buf), &readbytes) > 0) {
-        if (!TEST_ulong_eq(readbytes, 0))
+    for (i = 0; i < 2; i++) {
+        if (SSL_read_ex(clientssl, &buf, sizeof(buf), &readbytes) > 0) {
+            if (!TEST_ulong_eq(readbytes, 0))
+                return 0;
+        } else if (!TEST_int_eq(SSL_get_error(clientssl, 0),
+                                SSL_ERROR_WANT_READ)) {
             return 0;
-    } else if (!TEST_int_eq(SSL_get_error(clientssl, 0), SSL_ERROR_WANT_READ)) {
-        return 0;
+        }
     }
 
     return 1;