Make the CT code library context aware
[openssl.git] / test / ssltestlib.c
index eafac3cc42f78a6f61623cfb80fcb578f93bcb52..e579ceff92e04296470234d05b2ea3f7f16cb8b5 100644 (file)
@@ -1,7 +1,7 @@
 /*
- * Copyright 2016-2018 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2016-2019 The OpenSSL Project Authors. All Rights Reserved.
  *
- * Licensed under the OpenSSL license (the "License").  You may not use
+ * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
  * in the file LICENSE in the source distribution or at
  * https://www.openssl.org/source/license.html
 #include <string.h>
 
 #include "internal/nelem.h"
+#include "internal/cryptlib.h" /* for ossl_sleep() */
 #include "ssltestlib.h"
 #include "testutil.h"
 #include "e_os.h"
 
 #ifdef OPENSSL_SYS_UNIX
 # include <unistd.h>
-
-static ossl_inline void ossl_sleep(unsigned int millis) {
-    usleep(millis * 1000);
-}
-#elif defined(_WIN32)
-# include <windows.h>
-
-static ossl_inline void ossl_sleep(unsigned int millis) {
-    Sleep(millis);
-}
-#else
-/* Fallback to a busy wait */
-static ossl_inline void ossl_sleep(unsigned int millis) {
-    struct timeval start, now;
-    unsigned int elapsedms;
-
-    gettimeofday(&start, NULL);
-    do {
-        gettimeofday(&now, NULL);
-        elapsedms = (((now.tv_sec - start.tv_sec) * 1000000)
-                     + now.tv_usec - start.tv_usec) / 1000;
-    } while (elapsedms < millis);
-}
+# ifndef OPENSSL_NO_KTLS
+#  include <netinet/in.h>
+#  include <netinet/in.h>
+#  include <arpa/inet.h>
+#  include <sys/socket.h>
+#  include <unistd.h>
+#  include <fcntl.h>
+# endif
 #endif
 
 static int tls_dump_new(BIO *bi);
@@ -52,9 +38,11 @@ static int tls_dump_puts(BIO *bp, const char *str);
 /* Choose a sufficiently large type likely to be unused for this custom BIO */
 #define BIO_TYPE_TLS_DUMP_FILTER  (0x80 | BIO_TYPE_FILTER)
 #define BIO_TYPE_MEMPACKET_TEST    0x81
+#define BIO_TYPE_ALWAYS_RETRY      0x82
 
 static BIO_METHOD *method_tls_dump = NULL;
 static BIO_METHOD *meth_mem = NULL;
+static BIO_METHOD *meth_always_retry = NULL;
 
 /* Note: Not thread safe! */
 const BIO_METHOD *bio_f_tls_dump_filter(void)
@@ -428,7 +416,7 @@ int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
 {
     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
     MEMPACKET *thispkt = NULL, *looppkt, *nextpkt, *allpkts[3];
-    int i, duprec = ctx->duprec > 0;
+    int i, duprec;
     const unsigned char *inu = (const unsigned char *)in;
     size_t len = ((inu[RECORD_LEN_HI] << 8) | inu[RECORD_LEN_LO])
                  + DTLS1_RT_HEADER_LENGTH;
@@ -441,6 +429,8 @@ int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
 
     if ((size_t)inl == len)
         duprec = 0;
+    else
+        duprec = ctx->duprec > 0;
 
     /* We don't support arbitrary injection when duplicating records */
     if (duprec && pktnum != -1)
@@ -600,6 +590,100 @@ static int mempacket_test_puts(BIO *bio, const char *str)
     return mempacket_test_write(bio, str, strlen(str));
 }
 
+static int always_retry_new(BIO *bi);
+static int always_retry_free(BIO *a);
+static int always_retry_read(BIO *b, char *out, int outl);
+static int always_retry_write(BIO *b, const char *in, int inl);
+static long always_retry_ctrl(BIO *b, int cmd, long num, void *ptr);
+static int always_retry_gets(BIO *bp, char *buf, int size);
+static int always_retry_puts(BIO *bp, const char *str);
+
+const BIO_METHOD *bio_s_always_retry(void)
+{
+    if (meth_always_retry == NULL) {
+        if (!TEST_ptr(meth_always_retry = BIO_meth_new(BIO_TYPE_ALWAYS_RETRY,
+                                                       "Always Retry"))
+            || !TEST_true(BIO_meth_set_write(meth_always_retry,
+                                             always_retry_write))
+            || !TEST_true(BIO_meth_set_read(meth_always_retry,
+                                            always_retry_read))
+            || !TEST_true(BIO_meth_set_puts(meth_always_retry,
+                                            always_retry_puts))
+            || !TEST_true(BIO_meth_set_gets(meth_always_retry,
+                                            always_retry_gets))
+            || !TEST_true(BIO_meth_set_ctrl(meth_always_retry,
+                                            always_retry_ctrl))
+            || !TEST_true(BIO_meth_set_create(meth_always_retry,
+                                              always_retry_new))
+            || !TEST_true(BIO_meth_set_destroy(meth_always_retry,
+                                               always_retry_free)))
+            return NULL;
+    }
+    return meth_always_retry;
+}
+
+void bio_s_always_retry_free(void)
+{
+    BIO_meth_free(meth_always_retry);
+}
+
+static int always_retry_new(BIO *bio)
+{
+    BIO_set_init(bio, 1);
+    return 1;
+}
+
+static int always_retry_free(BIO *bio)
+{
+    BIO_set_data(bio, NULL);
+    BIO_set_init(bio, 0);
+    return 1;
+}
+
+static int always_retry_read(BIO *bio, char *out, int outl)
+{
+    BIO_set_retry_read(bio);
+    return -1;
+}
+
+static int always_retry_write(BIO *bio, const char *in, int inl)
+{
+    BIO_set_retry_write(bio);
+    return -1;
+}
+
+static long always_retry_ctrl(BIO *bio, int cmd, long num, void *ptr)
+{
+    long ret = 1;
+
+    switch (cmd) {
+    case BIO_CTRL_FLUSH:
+        BIO_set_retry_write(bio);
+        /* fall through */
+    case BIO_CTRL_EOF:
+    case BIO_CTRL_RESET:
+    case BIO_CTRL_DUP:
+    case BIO_CTRL_PUSH:
+    case BIO_CTRL_POP:
+    default:
+        ret = 0;
+        break;
+    }
+    return ret;
+}
+
+static int always_retry_gets(BIO *bio, char *buf, int size)
+{
+    BIO_set_retry_read(bio);
+    return -1;
+}
+
+static int always_retry_puts(BIO *bio, const char *str)
+{
+    BIO_set_retry_write(bio);
+    return -1;
+}
+
 int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,
                         int min_proto_version, int max_proto_version,
                         SSL_CTX **sctx, SSL_CTX **cctx, char *certfile,
@@ -608,10 +692,18 @@ int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,
     SSL_CTX *serverctx = NULL;
     SSL_CTX *clientctx = NULL;
 
-    if (!TEST_ptr(serverctx = SSL_CTX_new(sm))
-            || (cctx != NULL && !TEST_ptr(clientctx = SSL_CTX_new(cm))))
+    if (*sctx != NULL)
+        serverctx = *sctx;
+    else if (!TEST_ptr(serverctx = SSL_CTX_new(sm)))
         goto err;
 
+    if (cctx != NULL) {
+        if (*cctx != NULL)
+            clientctx = *cctx;
+        else if (!TEST_ptr(clientctx = SSL_CTX_new(cm)))
+            goto err;
+    }
+
     if ((min_proto_version > 0
          && !TEST_true(SSL_CTX_set_min_proto_version(serverctx,
                                                      min_proto_version)))
@@ -655,6 +747,114 @@ int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,
 
 #define MAXLOOPS    1000000
 
+#if !defined(OPENSSL_NO_KTLS) && !defined(OPENSSL_NO_SOCK)
+static int set_nb(int fd)
+{
+    int flags;
+
+    flags = fcntl(fd,F_GETFL,0);
+    if (flags == -1)
+        return flags;
+    flags = fcntl(fd, F_SETFL, flags | O_NONBLOCK);
+    return flags;
+}
+
+int create_test_sockets(int *cfd, int *sfd)
+{
+    struct sockaddr_in sin;
+    const char *host = "127.0.0.1";
+    int cfd_connected = 0, ret = 0;
+    socklen_t slen = sizeof(sin);
+    int afd = -1;
+
+    *cfd = -1;
+    *sfd = -1;
+
+    memset ((char *) &sin, 0, sizeof(sin));
+    sin.sin_family = AF_INET;
+    sin.sin_addr.s_addr = inet_addr(host);
+
+    afd = socket(AF_INET, SOCK_STREAM, 0);
+    if (afd < 0)
+        return 0;
+
+    if (bind(afd, (struct sockaddr*)&sin, sizeof(sin)) < 0)
+        goto out;
+
+    if (getsockname(afd, (struct sockaddr*)&sin, &slen) < 0)
+        goto out;
+
+    if (listen(afd, 1) < 0)
+        goto out;
+
+    *cfd = socket(AF_INET, SOCK_STREAM, 0);
+    if (*cfd < 0)
+        goto out;
+
+    if (set_nb(afd) == -1)
+        goto out;
+
+    while (*sfd == -1 || !cfd_connected ) {
+        *sfd = accept(afd, NULL, 0);
+        if (*sfd == -1 && errno != EAGAIN)
+            goto out;
+
+        if (!cfd_connected && connect(*cfd, (struct sockaddr*)&sin, sizeof(sin)) < 0)
+            goto out;
+        else
+            cfd_connected = 1;
+    }
+
+    if (set_nb(*cfd) == -1 || set_nb(*sfd) == -1)
+        goto out;
+    ret = 1;
+    goto success;
+
+out:
+        if (*cfd != -1)
+            close(*cfd);
+        if (*sfd != -1)
+            close(*sfd);
+success:
+        if (afd != -1)
+            close(afd);
+    return ret;
+}
+
+int create_ssl_objects2(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
+                          SSL **cssl, int sfd, int cfd)
+{
+    SSL *serverssl = NULL, *clientssl = NULL;
+    BIO *s_to_c_bio = NULL, *c_to_s_bio = NULL;
+
+    if (*sssl != NULL)
+        serverssl = *sssl;
+    else if (!TEST_ptr(serverssl = SSL_new(serverctx)))
+        goto error;
+    if (*cssl != NULL)
+        clientssl = *cssl;
+    else if (!TEST_ptr(clientssl = SSL_new(clientctx)))
+        goto error;
+
+    if (!TEST_ptr(s_to_c_bio = BIO_new_socket(sfd, BIO_NOCLOSE))
+            || !TEST_ptr(c_to_s_bio = BIO_new_socket(cfd, BIO_NOCLOSE)))
+        goto error;
+
+    SSL_set_bio(clientssl, c_to_s_bio, c_to_s_bio);
+    SSL_set_bio(serverssl, s_to_c_bio, s_to_c_bio);
+    *sssl = serverssl;
+    *cssl = clientssl;
+    return 1;
+
+ error:
+    SSL_free(serverssl);
+    SSL_free(clientssl);
+    BIO_free(s_to_c_bio);
+    BIO_free(c_to_s_bio);
+    return 0;
+}
+#endif
+
 /*
  * NOTE: Transfers control of the BIOs - this function will free them on error
  */
@@ -715,10 +915,17 @@ int create_ssl_objects(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
 }
 
 /*
- * Create an SSL connection, but does not ready any post-handshake
+ * Create an SSL connection, but does not read any post-handshake
  * NewSessionTicket messages.
+ * If |read| is set and we're using DTLS then we will attempt to SSL_read on
+ * the connection once we've completed one half of it, to ensure any retransmits
+ * get triggered.
+ * We stop the connection attempt (and return a failure value) if either peer
+ * has SSL_get_error() return the value in the |want| parameter. The connection
+ * attempt could be restarted by a subsequent call to this function.
  */
-int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
+int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want,
+                               int read)
 {
     int retc = -1, rets = -1, err, abortctr = 0;
     int clienterr = 0, servererr = 0;
@@ -734,6 +941,8 @@ int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
 
         if (!clienterr && retc <= 0 && err != SSL_ERROR_WANT_READ) {
             TEST_info("SSL_connect() failed %d, %d", retc, err);
+            if (want != SSL_ERROR_SSL)
+                TEST_openssl_errors();
             clienterr = 1;
         }
         if (want != SSL_ERROR_NONE && err == want)
@@ -750,17 +959,32 @@ int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
                 && err != SSL_ERROR_WANT_READ
                 && err != SSL_ERROR_WANT_X509_LOOKUP) {
             TEST_info("SSL_accept() failed %d, %d", rets, err);
+            if (want != SSL_ERROR_SSL)
+                TEST_openssl_errors();
             servererr = 1;
         }
         if (want != SSL_ERROR_NONE && err == want)
             return 0;
         if (clienterr && servererr)
             return 0;
-        if (isdtls) {
-            if (rets > 0 && retc <= 0)
-                DTLSv1_handle_timeout(serverssl);
-            if (retc > 0 && rets <= 0)
-                DTLSv1_handle_timeout(clientssl);
+        if (isdtls && read) {
+            unsigned char buf[20];
+
+            /* Trigger any retransmits that may be appropriate */
+            if (rets > 0 && retc <= 0) {
+                if (SSL_read(serverssl, buf, sizeof(buf)) > 0) {
+                    /* We don't expect this to succeed! */
+                    TEST_info("Unexpected SSL_read() success!");
+                    return 0;
+                }
+            }
+            if (retc > 0 && rets <= 0) {
+                if (SSL_read(clientssl, buf, sizeof(buf)) > 0) {
+                    /* We don't expect this to succeed! */
+                    TEST_info("Unexpected SSL_read() success!");
+                    return 0;
+                }
+            }
         }
         if (++abortctr == MAXLOOPS) {
             TEST_info("No progress made");
@@ -789,13 +1013,13 @@ int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
     unsigned char buf;
     size_t readbytes;
 
-    if (!create_bare_ssl_connection(serverssl, clientssl, want))
+    if (!create_bare_ssl_connection(serverssl, clientssl, want, 1))
         return 0;
 
     /*
      * We attempt to read some data on the client side which we expect to fail.
      * This will ensure we have received the NewSessionTicket in TLSv1.3 where
-     * appropriate. We do this twice because there are 2 NewSesionTickets.
+     * appropriate. We do this twice because there are 2 NewSessionTickets.
      */
     for (i = 0; i < 2; i++) {
         if (SSL_read_ex(clientssl, &buf, sizeof(buf), &readbytes) > 0) {