Adds multiple checks to avoid buffer over reads
[openssl.git] / ssl / packet.c
index 87473fbbce9fdb44ccf60929d2b4792b539b8aa1..7a4414ae6a8012ffa7bbe39c2c54afbeffa7ef82 100644 (file)
@@ -7,7 +7,7 @@
  * https://www.openssl.org/source/license.html
  */
 
-#include <assert.h>
+#include "internal/cryptlib.h"
 #include "packet_locl.h"
 
 #define DEFAULT_BUF_SIZE    256
@@ -39,8 +39,7 @@ int WPACKET_sub_allocate_bytes__(WPACKET *pkt, size_t len,
 int WPACKET_reserve_bytes(WPACKET *pkt, size_t len, unsigned char **allocbytes)
 {
     /* Internal API, so should not fail */
-    assert(pkt->subs != NULL && len != 0);
-    if (pkt->subs == NULL || len == 0)
+    if (!ossl_assert(pkt->subs != NULL && len != 0))
         return 0;
 
     if (pkt->maxsize - pkt->written < len)
@@ -120,8 +119,7 @@ int WPACKET_init_static_len(WPACKET *pkt, unsigned char *buf, size_t len,
     size_t max = maxmaxsize(lenbytes);
 
     /* Internal API, so should not fail */
-    assert(buf != NULL && len > 0);
-    if (buf == NULL || len == 0)
+    if (!ossl_assert(buf != NULL && len > 0))
         return 0;
 
     pkt->staticbuf = buf;
@@ -134,8 +132,7 @@ int WPACKET_init_static_len(WPACKET *pkt, unsigned char *buf, size_t len,
 int WPACKET_init_len(WPACKET *pkt, BUF_MEM *buf, size_t lenbytes)
 {
     /* Internal API, so should not fail */
-    assert(buf != NULL);
-    if (buf == NULL)
+    if (!ossl_assert(buf != NULL))
         return 0;
 
     pkt->staticbuf = NULL;
@@ -153,8 +150,7 @@ int WPACKET_init(WPACKET *pkt, BUF_MEM *buf)
 int WPACKET_set_flags(WPACKET *pkt, unsigned int flags)
 {
     /* Internal API, so should not fail */
-    assert(pkt->subs != NULL);
-    if (pkt->subs == NULL)
+    if (!ossl_assert(pkt->subs != NULL))
         return 0;
 
     pkt->subs->flags = flags;
@@ -228,16 +224,13 @@ int WPACKET_fill_lengths(WPACKET *pkt)
 {
     WPACKET_SUB *sub;
 
-    assert(pkt->subs != NULL);
-    if (pkt->subs == NULL)
+    if (!ossl_assert(pkt->subs != NULL))
         return 0;
 
-    sub = pkt->subs;
-    do {
+    for (sub = pkt->subs; sub != NULL; sub = sub->parent) {
         if (!wpacket_intern_close(pkt, sub, 0))
             return 0;
-        sub = sub->parent;
-    } while (sub != NULL);
+    }
 
     return 1;
 }
@@ -280,8 +273,7 @@ int WPACKET_start_sub_packet_len__(WPACKET *pkt, size_t lenbytes)
     unsigned char *lenchars;
 
     /* Internal API, so should not fail */
-    assert(pkt->subs != NULL);
-    if (pkt->subs == NULL)
+    if (!ossl_assert(pkt->subs != NULL))
         return 0;
 
     sub = OPENSSL_zalloc(sizeof(*sub));
@@ -316,9 +308,7 @@ int WPACKET_put_bytes__(WPACKET *pkt, unsigned int val, size_t size)
     unsigned char *data;
 
     /* Internal API, so should not fail */
-    assert(size <= sizeof(unsigned int));
-
-    if (size > sizeof(unsigned int)
+    if (!ossl_assert(size <= sizeof(unsigned int))
             || !WPACKET_allocate_bytes(pkt, size, &data)
             || !put_value(data, val, size))
         return 0;
@@ -332,8 +322,7 @@ int WPACKET_set_max_size(WPACKET *pkt, size_t maxsize)
     size_t lenbytes;
 
     /* Internal API, so should not fail */
-    assert(pkt->subs != NULL);
-    if (pkt->subs == NULL)
+    if (!ossl_assert(pkt->subs != NULL))
         return 0;
 
     /* Find the WPACKET_SUB for the top level */
@@ -352,6 +341,21 @@ int WPACKET_set_max_size(WPACKET *pkt, size_t maxsize)
     return 1;
 }
 
+int WPACKET_memset(WPACKET *pkt, int ch, size_t len)
+{
+    unsigned char *dest;
+
+    if (len == 0)
+        return 1;
+
+    if (!WPACKET_allocate_bytes(pkt, len, &dest))
+        return 0;
+
+    memset(dest, ch, len);
+
+    return 1;
+}
+
 int WPACKET_memcpy(WPACKET *pkt, const void *src, size_t len)
 {
     unsigned char *dest;
@@ -381,8 +385,7 @@ int WPACKET_sub_memcpy__(WPACKET *pkt, const void *src, size_t len,
 int WPACKET_get_total_written(WPACKET *pkt, size_t *written)
 {
     /* Internal API, so should not fail */
-    assert(written != NULL);
-    if (written == NULL)
+    if (!ossl_assert(written != NULL))
         return 0;
 
     *written = pkt->written;
@@ -393,8 +396,7 @@ int WPACKET_get_total_written(WPACKET *pkt, size_t *written)
 int WPACKET_get_length(WPACKET *pkt, size_t *len)
 {
     /* Internal API, so should not fail */
-    assert(pkt->subs != NULL && len != NULL);
-    if (pkt->subs == NULL || len == NULL)
+    if (!ossl_assert(pkt->subs != NULL && len != NULL))
         return 0;
 
     *len = pkt->written - pkt->subs->pwritten;