Don't exclude quite so much in a no-sock build
[openssl.git] / test / ssltestlib.c
index 64acea3f7dc6d3cbd06b13f7946085feb0080f2f..66d4e9b3a072268ed0e66718441f45b9a4cc1caf 100644 (file)
@@ -1,7 +1,7 @@
 /*
- * Copyright 2016 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2016-2019 The OpenSSL Project Authors. All Rights Reserved.
  *
- * Licensed under the OpenSSL license (the "License").  You may not use
+ * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
  * in the file LICENSE in the source distribution or at
  * https://www.openssl.org/source/license.html
@@ -9,9 +9,23 @@
 
 #include <string.h>
 
-#include "e_os.h"
+#include "internal/nelem.h"
+#include "internal/cryptlib.h" /* for ossl_sleep() */
 #include "ssltestlib.h"
 #include "testutil.h"
+#include "e_os.h"
+
+#ifdef OPENSSL_SYS_UNIX
+# include <unistd.h>
+# ifndef OPENSSL_NO_KTLS
+#  include <netinet/in.h>
+#  include <netinet/in.h>
+#  include <arpa/inet.h>
+#  include <sys/socket.h>
+#  include <unistd.h>
+#  include <fcntl.h>
+# endif
+#endif
 
 static int tls_dump_new(BIO *bi);
 static int tls_dump_free(BIO *a);
@@ -24,9 +38,11 @@ static int tls_dump_puts(BIO *bp, const char *str);
 /* Choose a sufficiently large type likely to be unused for this custom BIO */
 #define BIO_TYPE_TLS_DUMP_FILTER  (0x80 | BIO_TYPE_FILTER)
 #define BIO_TYPE_MEMPACKET_TEST    0x81
+#define BIO_TYPE_ALWAYS_RETRY      0x82
 
 static BIO_METHOD *method_tls_dump = NULL;
 static BIO_METHOD *meth_mem = NULL;
+static BIO_METHOD *meth_always_retry = NULL;
 
 /* Note: Not thread safe! */
 const BIO_METHOD *bio_f_tls_dump_filter(void)
@@ -252,7 +268,11 @@ typedef struct mempacket_test_ctx_st {
     unsigned int currrec;
     unsigned int currpkt;
     unsigned int lastpkt;
+    unsigned int injected;
     unsigned int noinject;
+    unsigned int dropepoch;
+    int droprec;
+    int duprec;
 } MEMPACKET_TEST_CTX;
 
 static int mempacket_test_new(BIO *bi);
@@ -295,6 +315,8 @@ static int mempacket_test_new(BIO *bio)
         OPENSSL_free(ctx);
         return 0;
     }
+    ctx->dropepoch = 0;
+    ctx->droprec = -1;
     BIO_set_init(bio, 1);
     BIO_set_data(bio, ctx);
     return 1;
@@ -312,8 +334,8 @@ static int mempacket_test_free(BIO *bio)
 }
 
 /* Record Header values */
-#define EPOCH_HI        4
-#define EPOCH_LO        5
+#define EPOCH_HI        3
+#define EPOCH_LO        4
 #define RECORD_SEQUENCE 10
 #define RECORD_LEN_HI   11
 #define RECORD_LEN_LO   12
@@ -341,15 +363,15 @@ static int mempacket_test_read(BIO *bio, char *out, int outl)
     if (outl > thispkt->len)
         outl = thispkt->len;
 
-    if (thispkt->type != INJECT_PACKET_IGNORE_REC_SEQ) {
+    if (thispkt->type != INJECT_PACKET_IGNORE_REC_SEQ
+            && (ctx->injected || ctx->droprec >= 0)) {
         /*
          * Overwrite the record sequence number. We strictly number them in
          * the order received. Since we are actually a reliable transport
          * we know that there won't be any re-ordering. We overwrite to deal
          * with any packets that have been injected
          */
-        for (rem = thispkt->len, rec = thispkt->data
-                ; rem > 0; rec += len, rem -= len) {
+        for (rem = thispkt->len, rec = thispkt->data; rem > 0; rem -= len) {
             if (rem < DTLS1_RT_HEADER_LENGTH)
                 return -1;
             epoch = (rec[EPOCH_HI] << 8) | rec[EPOCH_LO];
@@ -364,10 +386,23 @@ static int mempacket_test_read(BIO *bio, char *out, int outl)
                 seq >>= 8;
                 offset++;
             } while (seq > 0);
-            ctx->currrec++;
 
             len = ((rec[RECORD_LEN_HI] << 8) | rec[RECORD_LEN_LO])
                   + DTLS1_RT_HEADER_LENGTH;
+            if (rem < (int)len)
+                return -1;
+            if (ctx->droprec == (int)ctx->currrec && ctx->dropepoch == epoch) {
+                if (rem > (int)len)
+                    memmove(rec, rec + len, rem - len);
+                outl -= len;
+                ctx->droprec = -1;
+                if (outl == 0)
+                    BIO_set_retry_read(bio);
+            } else {
+                rec += len;
+            }
+
+            ctx->currrec++;
         }
     }
 
@@ -380,39 +415,66 @@ int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
                           int type)
 {
     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
-    MEMPACKET *thispkt, *looppkt, *nextpkt;
-    int i;
+    MEMPACKET *thispkt = NULL, *looppkt, *nextpkt, *allpkts[3];
+    int i, duprec;
+    const unsigned char *inu = (const unsigned char *)in;
+    size_t len = ((inu[RECORD_LEN_HI] << 8) | inu[RECORD_LEN_LO])
+                 + DTLS1_RT_HEADER_LENGTH;
 
     if (ctx == NULL)
         return -1;
 
+    if ((size_t)inl < len)
+        return -1;
+
+    if ((size_t)inl == len)
+        duprec = 0;
+    else
+        duprec = ctx->duprec > 0;
+
+    /* We don't support arbitrary injection when duplicating records */
+    if (duprec && pktnum != -1)
+        return -1;
+
     /* We only allow injection before we've started writing any data */
     if (pktnum >= 0) {
         if (ctx->noinject)
             return -1;
+        ctx->injected  = 1;
     } else {
         ctx->noinject = 1;
     }
 
-    if (!TEST_ptr(thispkt = OPENSSL_malloc(sizeof(*thispkt))))
-        return -1;
-    if (!TEST_ptr(thispkt->data = OPENSSL_malloc(inl))) {
-        mempacket_free(thispkt);
-        return -1;
-    }
+    for (i = 0; i < (duprec ? 3 : 1); i++) {
+        if (!TEST_ptr(allpkts[i] = OPENSSL_malloc(sizeof(*thispkt))))
+            goto err;
+        thispkt = allpkts[i];
 
-    memcpy(thispkt->data, in, inl);
-    thispkt->len = inl;
-    thispkt->num = (pktnum >= 0) ? (unsigned int)pktnum : ctx->lastpkt;
-    thispkt->type = type;
+        if (!TEST_ptr(thispkt->data = OPENSSL_malloc(inl)))
+            goto err;
+        /*
+         * If we are duplicating the packet, we duplicate it three times. The
+         * first two times we drop the first record if there are more than one.
+         * In this way we know that libssl will not be able to make progress
+         * until it receives the last packet, and hence will be forced to
+         * buffer these records.
+         */
+        if (duprec && i != 2) {
+            memcpy(thispkt->data, in + len, inl - len);
+            thispkt->len = inl - len;
+        } else {
+            memcpy(thispkt->data, in, inl);
+            thispkt->len = inl;
+        }
+        thispkt->num = (pktnum >= 0) ? (unsigned int)pktnum : ctx->lastpkt + i;
+        thispkt->type = type;
+    }
 
     for(i = 0; (looppkt = sk_MEMPACKET_value(ctx->pkts, i)) != NULL; i++) {
         /* Check if we found the right place to insert this packet */
         if (looppkt->num > thispkt->num) {
-            if (sk_MEMPACKET_insert(ctx->pkts, thispkt, i) == 0) {
-                mempacket_free(thispkt);
-                return -1;
-            }
+            if (sk_MEMPACKET_insert(ctx->pkts, thispkt, i) == 0)
+                goto err;
             /* If we're doing up front injection then we're done */
             if (pktnum >= 0)
                 return inl;
@@ -433,7 +495,7 @@ int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
         } else if (looppkt->num == thispkt->num) {
             if (!ctx->noinject) {
                 /* We injected two packets with the same packet number! */
-                return -1;
+                goto err;
             }
             ctx->lastpkt++;
             thispkt->num++;
@@ -443,15 +505,21 @@ int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
      * We didn't find any packets with a packet number equal to or greater than
      * this one, so we just add it onto the end
      */
-    if (!sk_MEMPACKET_push(ctx->pkts, thispkt)) {
-        mempacket_free(thispkt);
-        return -1;
-    }
+    for (i = 0; i < (duprec ? 3 : 1); i++) {
+        thispkt = allpkts[i];
+        if (!sk_MEMPACKET_push(ctx->pkts, thispkt))
+            goto err;
 
-    if (pktnum < 0)
-        ctx->lastpkt++;
+        if (pktnum < 0)
+            ctx->lastpkt++;
+    }
 
     return inl;
+
+ err:
+    for (i = 0; i < (ctx->duprec > 0 ? 3 : 1); i++)
+        mempacket_free(allpkts[i]);
+    return -1;
 }
 
 static int mempacket_test_write(BIO *bio, const char *in, int inl)
@@ -488,6 +556,18 @@ static long mempacket_test_ctrl(BIO *bio, int cmd, long num, void *ptr)
     case BIO_CTRL_FLUSH:
         ret = 1;
         break;
+    case MEMPACKET_CTRL_SET_DROP_EPOCH:
+        ctx->dropepoch = (unsigned int)num;
+        break;
+    case MEMPACKET_CTRL_SET_DROP_REC:
+        ctx->droprec = (int)num;
+        break;
+    case MEMPACKET_CTRL_GET_DROP_REC:
+        ret = ctx->droprec;
+        break;
+    case MEMPACKET_CTRL_SET_DUPLICATE_REC:
+        ctx->duprec = (int)num;
+        break;
     case BIO_CTRL_RESET:
     case BIO_CTRL_DUP:
     case BIO_CTRL_PUSH:
@@ -510,23 +590,145 @@ static int mempacket_test_puts(BIO *bio, const char *str)
     return mempacket_test_write(bio, str, strlen(str));
 }
 
+static int always_retry_new(BIO *bi);
+static int always_retry_free(BIO *a);
+static int always_retry_read(BIO *b, char *out, int outl);
+static int always_retry_write(BIO *b, const char *in, int inl);
+static long always_retry_ctrl(BIO *b, int cmd, long num, void *ptr);
+static int always_retry_gets(BIO *bp, char *buf, int size);
+static int always_retry_puts(BIO *bp, const char *str);
+
+const BIO_METHOD *bio_s_always_retry(void)
+{
+    if (meth_always_retry == NULL) {
+        if (!TEST_ptr(meth_always_retry = BIO_meth_new(BIO_TYPE_ALWAYS_RETRY,
+                                                       "Always Retry"))
+            || !TEST_true(BIO_meth_set_write(meth_always_retry,
+                                             always_retry_write))
+            || !TEST_true(BIO_meth_set_read(meth_always_retry,
+                                            always_retry_read))
+            || !TEST_true(BIO_meth_set_puts(meth_always_retry,
+                                            always_retry_puts))
+            || !TEST_true(BIO_meth_set_gets(meth_always_retry,
+                                            always_retry_gets))
+            || !TEST_true(BIO_meth_set_ctrl(meth_always_retry,
+                                            always_retry_ctrl))
+            || !TEST_true(BIO_meth_set_create(meth_always_retry,
+                                              always_retry_new))
+            || !TEST_true(BIO_meth_set_destroy(meth_always_retry,
+                                               always_retry_free)))
+            return NULL;
+    }
+    return meth_always_retry;
+}
+
+void bio_s_always_retry_free(void)
+{
+    BIO_meth_free(meth_always_retry);
+}
+
+static int always_retry_new(BIO *bio)
+{
+    BIO_set_init(bio, 1);
+    return 1;
+}
+
+static int always_retry_free(BIO *bio)
+{
+    BIO_set_data(bio, NULL);
+    BIO_set_init(bio, 0);
+    return 1;
+}
+
+static int always_retry_read(BIO *bio, char *out, int outl)
+{
+    BIO_set_retry_read(bio);
+    return -1;
+}
+
+static int always_retry_write(BIO *bio, const char *in, int inl)
+{
+    BIO_set_retry_write(bio);
+    return -1;
+}
+
+static long always_retry_ctrl(BIO *bio, int cmd, long num, void *ptr)
+{
+    long ret = 1;
+
+    switch (cmd) {
+    case BIO_CTRL_FLUSH:
+        BIO_set_retry_write(bio);
+        /* fall through */
+    case BIO_CTRL_EOF:
+    case BIO_CTRL_RESET:
+    case BIO_CTRL_DUP:
+    case BIO_CTRL_PUSH:
+    case BIO_CTRL_POP:
+    default:
+        ret = 0;
+        break;
+    }
+    return ret;
+}
+
+static int always_retry_gets(BIO *bio, char *buf, int size)
+{
+    BIO_set_retry_read(bio);
+    return -1;
+}
+
+static int always_retry_puts(BIO *bio, const char *str)
+{
+    BIO_set_retry_write(bio);
+    return -1;
+}
+
 int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,
+                        int min_proto_version, int max_proto_version,
                         SSL_CTX **sctx, SSL_CTX **cctx, char *certfile,
                         char *privkeyfile)
 {
     SSL_CTX *serverctx = NULL;
     SSL_CTX *clientctx = NULL;
 
-    if (!TEST_ptr(serverctx = SSL_CTX_new(sm))
-            || (cctx != NULL && !TEST_ptr(clientctx = SSL_CTX_new(cm))))
+    if (*sctx != NULL)
+        serverctx = *sctx;
+    else if (!TEST_ptr(serverctx = SSL_CTX_new(sm)))
         goto err;
 
-    if (!TEST_int_eq(SSL_CTX_use_certificate_file(serverctx, certfile,
-                                                  SSL_FILETYPE_PEM), 1)
-            || !TEST_int_eq(SSL_CTX_use_PrivateKey_file(serverctx, privkeyfile,
-                                                        SSL_FILETYPE_PEM), 1)
-            || !TEST_int_eq(SSL_CTX_check_private_key(serverctx), 1))
+    if (cctx != NULL) {
+        if (*cctx != NULL)
+            clientctx = *cctx;
+        else if (!TEST_ptr(clientctx = SSL_CTX_new(cm)))
+            goto err;
+    }
+
+    if ((min_proto_version > 0
+         && !TEST_true(SSL_CTX_set_min_proto_version(serverctx,
+                                                     min_proto_version)))
+        || (max_proto_version > 0
+            && !TEST_true(SSL_CTX_set_max_proto_version(serverctx,
+                                                        max_proto_version))))
         goto err;
+    if (clientctx != NULL
+        && ((min_proto_version > 0
+             && !TEST_true(SSL_CTX_set_min_proto_version(clientctx,
+                                                         min_proto_version)))
+            || (max_proto_version > 0
+                && !TEST_true(SSL_CTX_set_max_proto_version(clientctx,
+                                                            max_proto_version)))))
+        goto err;
+
+    if (certfile != NULL && privkeyfile != NULL) {
+        if (!TEST_int_eq(SSL_CTX_use_certificate_file(serverctx, certfile,
+                                                      SSL_FILETYPE_PEM), 1)
+                || !TEST_int_eq(SSL_CTX_use_PrivateKey_file(serverctx,
+                                                            privkeyfile,
+                                                            SSL_FILETYPE_PEM), 1)
+                || !TEST_int_eq(SSL_CTX_check_private_key(serverctx), 1))
+            goto err;
+    }
 
 #ifndef OPENSSL_NO_DH
     SSL_CTX_set_dh_auto(serverctx, 1);
@@ -545,6 +747,114 @@ int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,
 
 #define MAXLOOPS    1000000
 
+#if !defined(OPENSSL_NO_KTLS) && !defined(OPENSSL_NO_SOCK)
+static int set_nb(int fd)
+{
+    int flags;
+
+    flags = fcntl(fd,F_GETFL,0);
+    if (flags == -1)
+        return flags;
+    flags = fcntl(fd, F_SETFL, flags | O_NONBLOCK);
+    return flags;
+}
+
+int create_test_sockets(int *cfd, int *sfd)
+{
+    struct sockaddr_in sin;
+    const char *host = "127.0.0.1";
+    int cfd_connected = 0, ret = 0;
+    socklen_t slen = sizeof(sin);
+    int afd = -1;
+
+    *cfd = -1;
+    *sfd = -1;
+
+    memset ((char *) &sin, 0, sizeof(sin));
+    sin.sin_family = AF_INET;
+    sin.sin_addr.s_addr = inet_addr(host);
+
+    afd = socket(AF_INET, SOCK_STREAM, 0);
+    if (afd < 0)
+        return 0;
+
+    if (bind(afd, (struct sockaddr*)&sin, sizeof(sin)) < 0)
+        goto out;
+
+    if (getsockname(afd, (struct sockaddr*)&sin, &slen) < 0)
+        goto out;
+
+    if (listen(afd, 1) < 0)
+        goto out;
+
+    *cfd = socket(AF_INET, SOCK_STREAM, 0);
+    if (*cfd < 0)
+        goto out;
+
+    if (set_nb(afd) == -1)
+        goto out;
+
+    while (*sfd == -1 || !cfd_connected ) {
+        *sfd = accept(afd, NULL, 0);
+        if (*sfd == -1 && errno != EAGAIN)
+            goto out;
+
+        if (!cfd_connected && connect(*cfd, (struct sockaddr*)&sin, sizeof(sin)) < 0)
+            goto out;
+        else
+            cfd_connected = 1;
+    }
+
+    if (set_nb(*cfd) == -1 || set_nb(*sfd) == -1)
+        goto out;
+    ret = 1;
+    goto success;
+
+out:
+        if (*cfd != -1)
+            close(*cfd);
+        if (*sfd != -1)
+            close(*sfd);
+success:
+        if (afd != -1)
+            close(afd);
+    return ret;
+}
+
+int create_ssl_objects2(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
+                          SSL **cssl, int sfd, int cfd)
+{
+    SSL *serverssl = NULL, *clientssl = NULL;
+    BIO *s_to_c_bio = NULL, *c_to_s_bio = NULL;
+
+    if (*sssl != NULL)
+        serverssl = *sssl;
+    else if (!TEST_ptr(serverssl = SSL_new(serverctx)))
+        goto error;
+    if (*cssl != NULL)
+        clientssl = *cssl;
+    else if (!TEST_ptr(clientssl = SSL_new(clientctx)))
+        goto error;
+
+    if (!TEST_ptr(s_to_c_bio = BIO_new_socket(sfd, BIO_NOCLOSE))
+            || !TEST_ptr(c_to_s_bio = BIO_new_socket(cfd, BIO_NOCLOSE)))
+        goto error;
+
+    SSL_set_bio(clientssl, c_to_s_bio, c_to_s_bio);
+    SSL_set_bio(serverssl, s_to_c_bio, s_to_c_bio);
+    *sssl = serverssl;
+    *cssl = clientssl;
+    return 1;
+
+ error:
+    SSL_free(serverssl);
+    SSL_free(clientssl);
+    BIO_free(s_to_c_bio);
+    BIO_free(c_to_s_bio);
+    return 0;
+}
+#endif
+
 /*
  * NOTE: Transfers control of the BIOs - this function will free them on error
  */
@@ -604,12 +914,19 @@ int create_ssl_objects(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
     return 0;
 }
 
-int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
+/*
+ * Create an SSL connection, but does not ready any post-handshake
+ * NewSessionTicket messages.
+ * If |read| is set and we're using DTLS then we will attempt to SSL_read on
+ * the connection once we've completed one half of it, to ensure any retransmits
+ * get triggered.
+ */
+int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want,
+                               int read)
 {
     int retc = -1, rets = -1, err, abortctr = 0;
     int clienterr = 0, servererr = 0;
-    unsigned char buf;
-    size_t readbytes;
+    int isdtls = SSL_is_dtls(serverssl);
 
     do {
         err = SSL_ERROR_WANT_WRITE;
@@ -633,7 +950,9 @@ int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
                 err = SSL_get_error(serverssl, rets);
         }
 
-        if (!servererr && rets <= 0 && err != SSL_ERROR_WANT_READ) {
+        if (!servererr && rets <= 0
+                && err != SSL_ERROR_WANT_READ
+                && err != SSL_ERROR_WANT_X509_LOOKUP) {
             TEST_info("SSL_accept() failed %d, %d", rets, err);
             servererr = 1;
         }
@@ -641,22 +960,68 @@ int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
             return 0;
         if (clienterr && servererr)
             return 0;
+        if (isdtls && read) {
+            unsigned char buf[20];
+
+            /* Trigger any retransmits that may be appropriate */
+            if (rets > 0 && retc <= 0) {
+                if (SSL_read(serverssl, buf, sizeof(buf)) > 0) {
+                    /* We don't expect this to succeed! */
+                    TEST_info("Unexpected SSL_read() success!");
+                    return 0;
+                }
+            }
+            if (retc > 0 && rets <= 0) {
+                if (SSL_read(clientssl, buf, sizeof(buf)) > 0) {
+                    /* We don't expect this to succeed! */
+                    TEST_info("Unexpected SSL_read() success!");
+                    return 0;
+                }
+            }
+        }
         if (++abortctr == MAXLOOPS) {
             TEST_info("No progress made");
             return 0;
         }
+        if (isdtls && abortctr <= 50 && (abortctr % 10) == 0) {
+            /*
+             * It looks like we're just spinning. Pause for a short period to
+             * give the DTLS timer a chance to do something. We only do this for
+             * the first few times to prevent hangs.
+             */
+            ossl_sleep(50);
+        }
     } while (retc <=0 || rets <= 0);
 
+    return 1;
+}
+
+/*
+ * Create an SSL connection including any post handshake NewSessionTicket
+ * messages.
+ */
+int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
+{
+    int i;
+    unsigned char buf;
+    size_t readbytes;
+
+    if (!create_bare_ssl_connection(serverssl, clientssl, want, 1))
+        return 0;
+
     /*
      * We attempt to read some data on the client side which we expect to fail.
      * This will ensure we have received the NewSessionTicket in TLSv1.3 where
-     * appropriate.
+     * appropriate. We do this twice because there are 2 NewSessionTickets.
      */
-    if (SSL_read_ex(clientssl, &buf, sizeof(buf), &readbytes) > 0) {
-        if (!TEST_ulong_eq(readbytes, 0))
+    for (i = 0; i < 2; i++) {
+        if (SSL_read_ex(clientssl, &buf, sizeof(buf), &readbytes) > 0) {
+            if (!TEST_ulong_eq(readbytes, 0))
+                return 0;
+        } else if (!TEST_int_eq(SSL_get_error(clientssl, 0),
+                                SSL_ERROR_WANT_READ)) {
             return 0;
-    } else if (!TEST_int_eq(SSL_get_error(clientssl, 0), SSL_ERROR_WANT_READ)) {
-        return 0;
+        }
     }
 
     return 1;