Add a test to check that if a PSK extension is not last then we fail
[openssl.git] / test / asynciotest.c
index 23d0907dc9151ece990d757e68ff3012bd1d3695..f418bbeb2c6169b0e67550c1d12f9b4ab1537bef 100644 (file)
@@ -142,8 +142,9 @@ static int async_write(BIO *bio, const char *in, int inl)
                 abort();
 
             while (PACKET_remaining(&pkt) > 0) {
-                PACKET payload;
+                PACKET payload, wholebody;
                 unsigned int contenttype, versionhi, versionlo, data;
+                unsigned int msgtype = 0, negversion = 0;
 
                 if (   !PACKET_get_1(&pkt, &contenttype)
                     || !PACKET_get_1(&pkt, &versionhi)
@@ -154,6 +155,17 @@ static int async_write(BIO *bio, const char *in, int inl)
                 /* Pretend we wrote out the record header */
                 written += SSL3_RT_HEADER_LENGTH;
 
+                wholebody = payload;
+                if (contenttype == SSL3_RT_HANDSHAKE
+                        && !PACKET_get_1(&wholebody, &msgtype))
+                    abort();
+
+                if (msgtype == SSL3_MT_SERVER_HELLO
+                        && (!PACKET_forward(&wholebody,
+                                            SSL3_HM_HEADER_LENGTH - 1)
+                            || !PACKET_get_net_2(&wholebody, &negversion)))
+                    abort();
+
                 while (PACKET_get_1(&payload, &data)) {
                     /* Create a new one byte long record for each byte in the
                      * record in the input buffer
@@ -177,10 +189,14 @@ static int async_write(BIO *bio, const char *in, int inl)
                     written++;
                 }
                 /*
-                 * We can't fragment anything after the CCS, otherwise we
-                 * get a bad record MAC
+                 * We can't fragment anything after the ServerHello (or CCS <=
+                 * TLS1.2), otherwise we get a bad record MAC
+                 * TODO(TLS1.3): Change TLS1_3_VERSION_DRAFT to TLS1_3_VERSION
+                 * before release
                  */
-                if (contenttype == SSL3_RT_CHANGE_CIPHER_SPEC) {
+                if (contenttype == SSL3_RT_CHANGE_CIPHER_SPEC
+                        || (negversion == TLS1_3_VERSION_DRAFT
+                            && msgtype == SSL3_MT_SERVER_HELLO)) {
                     fragment = 0;
                     break;
                 }
@@ -189,7 +205,7 @@ static int async_write(BIO *bio, const char *in, int inl)
         /* Write any data we have left after fragmenting */
         ret = 0;
         if ((int)written < inl) {
-            ret = BIO_write(next, in + written , inl - written);
+            ret = BIO_write(next, in + written, inl - written);
         }
 
         if (ret <= 0 && BIO_should_write(next))
@@ -287,7 +303,7 @@ int main(int argc, char *argv[])
             goto end;
         }
 
-        if (!create_ssl_connection(serverssl, clientssl)) {
+        if (!create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE)) {
             printf("Test %d failed: Create SSL connection failed\n", test);
             goto end;
         }
@@ -297,32 +313,59 @@ int main(int argc, char *argv[])
          * we hit at least one async event in both reading and writing
          */
         for (j = 0; j < 2; j++) {
+            int len;
+
             /*
              * Write some test data. It should never take more than 2 attempts
-             * (the first one might be a retryable fail). A zero return from
-             * SSL_write() is a non-retryable failure, so fail immediately if
-             * we get that.
+             * (the first one might be a retryable fail).
              */
-            for (ret = -1, i = 0; ret < 0 && i < 2 * sizeof(testdata); i++)
-                ret = SSL_write(clientssl, testdata, sizeof(testdata));
-            if (ret <= 0) {
-                printf("Test %d failed: Failed to write app data\n", test);
+            for (ret = -1, i = 0, len = 0; len != sizeof(testdata) && i < 2;
+                i++) {
+                ret = SSL_write(clientssl, testdata + len,
+                    sizeof(testdata) - len);
+                if (ret > 0) {
+                    len += ret;
+                } else {
+                    int ssl_error = SSL_get_error(clientssl, ret);
+
+                    if (ssl_error == SSL_ERROR_SYSCALL ||
+                        ssl_error == SSL_ERROR_SSL) {
+                        printf("Test %d failed: Failed to write app data\n", test);
+                        err = -1;
+                        goto end;
+                    }
+                }
+            }
+            if (len != sizeof(testdata)) {
+                err = -1;
+                printf("Test %d failed: Failed to write all app data\n", test);
                 goto end;
             }
             /*
              * Now read the test data. It may take more attemps here because
              * it could fail once for each byte read, including all overhead
-             * bytes from the record header/padding etc. Fail immediately if we
-             * get a zero return from SSL_read().
+             * bytes from the record header/padding etc.
              */
-            for (ret = -1, i = 0; ret < 0 && i < MAX_ATTEMPTS; i++)
-                ret = SSL_read(serverssl, buf, sizeof(buf));
-            if (ret <= 0) {
-                printf("Test %d failed: Failed to read app data\n", test);
-                goto end;
+            for (ret = -1, i = 0, len = 0; len != sizeof(testdata) &&
+                i < MAX_ATTEMPTS; i++)
+            {
+                ret = SSL_read(serverssl, buf + len, sizeof(buf) - len);
+                if (ret > 0) {
+                    len += ret;
+                } else {
+                    int ssl_error = SSL_get_error(serverssl, ret);
+
+                    if (ssl_error == SSL_ERROR_SYSCALL ||
+                        ssl_error == SSL_ERROR_SSL) {
+                        printf("Test %d failed: Failed to read app data\n", test);
+                        err = -1;
+                        goto end;
+                    }
+                }
             }
-            if (ret != sizeof(testdata)
+            if (len != sizeof(testdata)
                     || memcmp(buf, testdata, sizeof(testdata)) != 0) {
+                err = -1;
                 printf("Test %d failed: Unexpected app data received\n", test);
                 goto end;
             }