Check whether buffers have actually been allocated/freed
[openssl.git] / test / sslbuffertest.c
index 3c3e69d61da80e7d2edfd46b6b7c4f3d1d237aa1..6b6bf19292bd19d3e6b284c3ffbc018654228947 100644 (file)
 #include <openssl/bio.h>
 #include <openssl/err.h>
 
+/* We include internal headers so we can check if the buffers are allocated */
+#include "../ssl/ssl_local.h"
+#include "../ssl/record/record_local.h"
+#include "../ssl/record/recordmethod.h"
+#include "../ssl/record/methods/recmethod_local.h"
+
 #include "internal/packet.h"
 
 #include "helpers/ssltestlib.h"
@@ -28,6 +34,17 @@ static SSL_CTX *clientctx = NULL;
 
 #define MAX_ATTEMPTS    100
 
+static int checkbuffers(SSL *s, int isalloced)
+{
+    SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(s);
+    OSSL_RECORD_LAYER *rrl = sc->rlayer.rrl;
+    OSSL_RECORD_LAYER *wrl = sc->rlayer.wrl;
+
+    if (isalloced)
+        return rrl->rbuf.buf != NULL && wrl->wbuf[0].buf != NULL;
+
+    return rrl->rbuf.buf == NULL && wrl->wbuf[0].buf == NULL;
+}
 
 /*
  * There are 9 passes in the tests
@@ -78,14 +95,18 @@ static int test_func(int test)
         for (ret = -1, i = 0, len = 0; len != sizeof(testdata) && i < 2;
              i++) {
             /* test == 0 mean to free/allocate = control */
-            if (test >= 1 && !TEST_true(SSL_free_buffers(clientssl)))
+            if (test >= 1 && (!TEST_true(SSL_free_buffers(clientssl))
+                              || !TEST_true(checkbuffers(clientssl, 0))))
                 goto end;
-            if (test >= 2 && !TEST_true(SSL_alloc_buffers(clientssl)))
+            if (test >= 2 && (!TEST_true(SSL_alloc_buffers(clientssl))
+                              || !TEST_true(checkbuffers(clientssl, 1))))
                 goto end;
             /* allocate a second time */
-            if (test >= 3 && !TEST_true(SSL_alloc_buffers(clientssl)))
+            if (test >= 3 && (!TEST_true(SSL_alloc_buffers(clientssl))
+                              || !TEST_true(checkbuffers(clientssl, 1))))
                 goto end;
-            if (test >= 4 && !TEST_true(SSL_free_buffers(clientssl)))
+            if (test >= 4 && (!TEST_true(SSL_free_buffers(clientssl))
+                              || !TEST_true(checkbuffers(clientssl, 0))))
                 goto end;
 
             ret = SSL_write(clientssl, testdata + len,
@@ -112,14 +133,18 @@ static int test_func(int test)
         for (ret = -1, i = 0, len = 0; len != sizeof(testdata) &&
                  i < MAX_ATTEMPTS; i++)
         {
-            if (test >= 5 && !TEST_true(SSL_free_buffers(serverssl)))
+            if (test >= 5 && (!TEST_true(SSL_free_buffers(serverssl))
+                              || !TEST_true(checkbuffers(serverssl, 0))))
                 goto end;
             /* free a second time */
-            if (test >= 6 && !TEST_true(SSL_free_buffers(serverssl)))
+            if (test >= 6 && (!TEST_true(SSL_free_buffers(serverssl))
+                              || !TEST_true(checkbuffers(serverssl, 0))))
                 goto end;
-            if (test >= 7 && !TEST_true(SSL_alloc_buffers(serverssl)))
+            if (test >= 7 && (!TEST_true(SSL_alloc_buffers(serverssl))
+                              || !TEST_true(checkbuffers(serverssl, 1))))
                 goto end;
-            if (test >= 8 && !TEST_true(SSL_free_buffers(serverssl)))
+            if (test >= 8 && (!TEST_true(SSL_free_buffers(serverssl))
+                              || !TEST_true(checkbuffers(serverssl, 0))))
                 goto end;
 
             ret = SSL_read(serverssl, buf + len, sizeof(buf) - len);