Secure memory fixes
[openssl.git] / crypto / mem_sec.c
index 5f4f733fd06f3f11f5871be37a7c25e39d5cd224..d61d945d6366cbd746732af9f65c67b534a19232 100644 (file)
 #include <openssl/crypto.h>
 #include <e_os.h>
 
+#include <string.h>
+
 #if defined(OPENSSL_SYS_LINUX) || defined(OPENSSL_SYS_UNIX)
 # define IMPLEMENTED
 # include <stdlib.h>
-# include <string.h>
 # include <assert.h>
 # include <unistd.h>
 # include <sys/types.h>
 # include <sys/param.h>
 # include <sys/stat.h>
 # include <fcntl.h>
+# include "internal/threads.h"
 #endif
 
-#define LOCK()      CRYPTO_w_lock(CRYPTO_LOCK_MALLOC)
-#define UNLOCK()    CRYPTO_w_unlock(CRYPTO_LOCK_MALLOC)
 #define CLEAR(p, s) OPENSSL_cleanse(p, s)
 #ifndef PAGE_SIZE
 # define PAGE_SIZE    4096
@@ -37,7 +37,8 @@
 static size_t secure_mem_used;
 
 static int secure_mem_initialized;
-static int too_late;
+
+static CRYPTO_RWLOCK *sec_malloc_lock = NULL;
 
 /*
  * These are the functions that must be implemented by a secure heap (sh).
@@ -46,7 +47,7 @@ static int sh_init(size_t size, int minsize);
 static char *sh_malloc(size_t size);
 static void sh_free(char *ptr);
 static void sh_done(void);
-static int sh_actual_size(char *ptr);
+static size_t sh_actual_size(char *ptr);
 static int sh_allocated(const char *ptr);
 #endif
 
@@ -55,29 +56,31 @@ int CRYPTO_secure_malloc_init(size_t size, int minsize)
 #ifdef IMPLEMENTED
     int ret = 0;
 
-    if (too_late)
-        return ret;
-    LOCK();
-    OPENSSL_assert(!secure_mem_initialized);
     if (!secure_mem_initialized) {
+        sec_malloc_lock = CRYPTO_THREAD_lock_new();
+        if (sec_malloc_lock == NULL)
+            return 0;
         ret = sh_init(size, minsize);
         secure_mem_initialized = 1;
     }
-    UNLOCK();
+
     return ret;
 #else
     return 0;
 #endif /* IMPLEMENTED */
 }
 
-void CRYPTO_secure_malloc_done()
+int CRYPTO_secure_malloc_done()
 {
 #ifdef IMPLEMENTED
-    LOCK();
-    sh_done();
-    secure_mem_initialized = 0;
-    UNLOCK();
+    if (secure_mem_used == 0) {
+        sh_done();
+        secure_mem_initialized = 0;
+        CRYPTO_THREAD_lock_free(sec_malloc_lock);
+        return 1;
+    }
 #endif /* IMPLEMENTED */
+    return 0;
 }
 
 int CRYPTO_secure_malloc_initialized()
@@ -96,39 +99,47 @@ void *CRYPTO_secure_malloc(size_t num, const char *file, int line)
     size_t actual_size;
 
     if (!secure_mem_initialized) {
-        too_late = 1;
         return CRYPTO_malloc(num, file, line);
     }
-    LOCK();
+    CRYPTO_THREAD_write_lock(sec_malloc_lock);
     ret = sh_malloc(num);
     actual_size = ret ? sh_actual_size(ret) : 0;
     secure_mem_used += actual_size;
-    UNLOCK();
+    CRYPTO_THREAD_unlock(sec_malloc_lock);
     return ret;
 #else
     return CRYPTO_malloc(num, file, line);
 #endif /* IMPLEMENTED */
 }
 
-void CRYPTO_secure_free(void *ptr)
+void *CRYPTO_secure_zalloc(size_t num, const char *file, int line)
+{
+    void *ret = CRYPTO_secure_malloc(num, file, line);
+
+    if (ret != NULL)
+        memset(ret, 0, num);
+    return ret;
+}
+
+void CRYPTO_secure_free(void *ptr, const char *file, int line)
 {
 #ifdef IMPLEMENTED
     size_t actual_size;
 
     if (ptr == NULL)
         return;
-    if (!secure_mem_initialized) {
-        CRYPTO_free(ptr);
+    if (!CRYPTO_secure_allocated(ptr)) {
+        CRYPTO_free(ptr, file, line);
         return;
     }
-    LOCK();
+    CRYPTO_THREAD_write_lock(sec_malloc_lock);
     actual_size = sh_actual_size(ptr);
     CLEAR(ptr, actual_size);
     secure_mem_used -= actual_size;
     sh_free(ptr);
-    UNLOCK();
+    CRYPTO_THREAD_unlock(sec_malloc_lock);
 #else
-    CRYPTO_free(ptr);
+    CRYPTO_free(ptr, file, line);
 #endif /* IMPLEMENTED */
 }
 
@@ -139,9 +150,9 @@ int CRYPTO_secure_allocated(const void *ptr)
 
     if (!secure_mem_initialized)
         return 0;
-    LOCK();
+    CRYPTO_THREAD_write_lock(sec_malloc_lock);
     ret = sh_allocated(ptr);
-    UNLOCK();
+    CRYPTO_THREAD_unlock(sec_malloc_lock);
     return ret;
 #else
     return 0;
@@ -157,6 +168,19 @@ size_t CRYPTO_secure_used()
 #endif /* IMPLEMENTED */
 }
 
+size_t CRYPTO_secure_actual_size(void *ptr)
+{
+#ifdef IMPLEMENTED
+    size_t actual_size;
+
+    CRYPTO_THREAD_write_lock(sec_malloc_lock);
+    actual_size = sh_actual_size(ptr);
+    CRYPTO_THREAD_unlock(sec_malloc_lock);
+    return actual_size;
+#else
+    return 0;
+#endif
+}
 /* END OF PAGE ...
 
    ... START OF PAGE */
@@ -182,9 +206,11 @@ size_t CRYPTO_secure_used()
  * place.
  */
 
-# define TESTBIT(t, b)  (t[(b) >> 3] &  (1 << ((b) & 7)))
-# define SETBIT(t, b)   (t[(b) >> 3] |= (1 << ((b) & 7)))
-# define CLEARBIT(t, b) (t[(b) >> 3] &= (0xFF & ~(1 << ((b) & 7))))
+#define ONE ((size_t)1)
+
+# define TESTBIT(t, b)  (t[(b) >> 3] &  (ONE << ((b) & 7)))
+# define SETBIT(t, b)   (t[(b) >> 3] |= (ONE << ((b) & 7)))
+# define CLEARBIT(t, b) (t[(b) >> 3] &= (0xFF & ~(ONE << ((b) & 7))))
 
 #define WITHIN_ARENA(p) \
     ((char*)(p) >= sh.arena && (char*)(p) < &sh.arena[sh.arena_size])
@@ -203,21 +229,21 @@ typedef struct sh_st
     char* map_result;
     size_t map_size;
     char *arena;
-    int arena_size;
+    size_t arena_size;
     char **freelist;
-    int freelist_size;
-    int minsize;
+    ossl_ssize_t freelist_size;
+    size_t minsize;
     unsigned char *bittable;
     unsigned char *bitmalloc;
-    int bittable_size; /* size in bits */
+    size_t bittable_size; /* size in bits */
 } SH;
 
 static SH sh;
 
-static int sh_getlist(char *ptr)
+static size_t sh_getlist(char *ptr)
 {
-    int list = sh.freelist_size - 1;
-    int bit = (sh.arena_size + ptr - sh.arena) / sh.minsize;
+    ossl_ssize_t list = sh.freelist_size - 1;
+    size_t bit = (sh.arena_size + ptr - sh.arena) / sh.minsize;
 
     for (; bit; bit >>= 1, list--) {
         if (TESTBIT(sh.bittable, bit))
@@ -231,22 +257,22 @@ static int sh_getlist(char *ptr)
 
 static int sh_testbit(char *ptr, int list, unsigned char *table)
 {
-    int bit;
+    size_t bit;
 
     OPENSSL_assert(list >= 0 && list < sh.freelist_size);
     OPENSSL_assert(((ptr - sh.arena) & ((sh.arena_size >> list) - 1)) == 0);
-    bit = (1 << list) + ((ptr - sh.arena) / (sh.arena_size >> list));
+    bit = (ONE << list) + ((ptr - sh.arena) / (sh.arena_size >> list));
     OPENSSL_assert(bit > 0 && bit < sh.bittable_size);
     return TESTBIT(table, bit);
 }
 
 static void sh_clearbit(char *ptr, int list, unsigned char *table)
 {
-    int bit;
+    size_t bit;
 
     OPENSSL_assert(list >= 0 && list < sh.freelist_size);
     OPENSSL_assert(((ptr - sh.arena) & ((sh.arena_size >> list) - 1)) == 0);
-    bit = (1 << list) + ((ptr - sh.arena) / (sh.arena_size >> list));
+    bit = (ONE << list) + ((ptr - sh.arena) / (sh.arena_size >> list));
     OPENSSL_assert(bit > 0 && bit < sh.bittable_size);
     OPENSSL_assert(TESTBIT(table, bit));
     CLEARBIT(table, bit);
@@ -254,11 +280,11 @@ static void sh_clearbit(char *ptr, int list, unsigned char *table)
 
 static void sh_setbit(char *ptr, int list, unsigned char *table)
 {
-    int bit;
+    size_t bit;
 
     OPENSSL_assert(list >= 0 && list < sh.freelist_size);
     OPENSSL_assert(((ptr - sh.arena) & ((sh.arena_size >> list) - 1)) == 0);
-    bit = (1 << list) + ((ptr - sh.arena) / (sh.arena_size >> list));
+    bit = (ONE << list) + ((ptr - sh.arena) / (sh.arena_size >> list));
     OPENSSL_assert(bit > 0 && bit < sh.bittable_size);
     OPENSSL_assert(!TESTBIT(table, bit));
     SETBIT(table, bit);
@@ -284,7 +310,7 @@ static void sh_add_to_list(char **list, char *ptr)
     *list = ptr;
 }
 
-static void sh_remove_from_list(char *ptr, char *list)
+static void sh_remove_from_list(char *ptr)
 {
     SH_LIST *temp, *temp2;
 
@@ -423,21 +449,21 @@ static int sh_allocated(const char *ptr)
 
 static char *sh_find_my_buddy(char *ptr, int list)
 {
-    int bit;
+    size_t bit;
     char *chunk = NULL;
 
-    bit = (1 << list) + (ptr - sh.arena) / (sh.arena_size >> list);
+    bit = (ONE << list) + (ptr - sh.arena) / (sh.arena_size >> list);
     bit ^= 1;
 
     if (TESTBIT(sh.bittable, bit) && !TESTBIT(sh.bitmalloc, bit))
-        chunk = sh.arena + ((bit & ((1 << list) - 1)) * (sh.arena_size >> list));
+        chunk = sh.arena + ((bit & ((ONE << list) - 1)) * (sh.arena_size >> list));
 
     return chunk;
 }
 
 static char *sh_malloc(size_t size)
 {
-    int list, slist;
+    ossl_ssize_t list, slist;
     size_t i;
     char *chunk;
 
@@ -461,7 +487,7 @@ static char *sh_malloc(size_t size)
         /* remove from bigger list */
         OPENSSL_assert(!sh_testbit(temp, slist, sh.bitmalloc));
         sh_clearbit(temp, slist, sh.bittable);
-        sh_remove_from_list(temp, sh.freelist[slist]);
+        sh_remove_from_list(temp);
         OPENSSL_assert(temp != sh.freelist[slist]);
 
         /* done with bigger list */
@@ -487,7 +513,7 @@ static char *sh_malloc(size_t size)
     chunk = sh.freelist[list];
     OPENSSL_assert(sh_testbit(chunk, list, sh.bittable));
     sh_setbit(chunk, list, sh.bitmalloc);
-    sh_remove_from_list(chunk, sh.freelist[list]);
+    sh_remove_from_list(chunk);
 
     OPENSSL_assert(WITHIN_ARENA(chunk));
 
@@ -496,7 +522,7 @@ static char *sh_malloc(size_t size)
 
 static void sh_free(char *ptr)
 {
-    int list;
+    size_t list;
     char *buddy;
 
     if (ptr == NULL)
@@ -516,10 +542,10 @@ static void sh_free(char *ptr)
         OPENSSL_assert(ptr != NULL);
         OPENSSL_assert(!sh_testbit(ptr, list, sh.bitmalloc));
         sh_clearbit(ptr, list, sh.bittable);
-        sh_remove_from_list(ptr, sh.freelist[list]);
+        sh_remove_from_list(ptr);
         OPENSSL_assert(!sh_testbit(ptr, list, sh.bitmalloc));
         sh_clearbit(buddy, list, sh.bittable);
-        sh_remove_from_list(buddy, sh.freelist[list]);
+        sh_remove_from_list(buddy);
 
         list--;
 
@@ -533,7 +559,7 @@ static void sh_free(char *ptr)
     }
 }
 
-static int sh_actual_size(char *ptr)
+static size_t sh_actual_size(char *ptr)
 {
     int list;
 
@@ -542,6 +568,6 @@ static int sh_actual_size(char *ptr)
         return 0;
     list = sh_getlist(ptr);
     OPENSSL_assert(sh_testbit(ptr, list, sh.bittable));
-    return sh.arena_size / (1 << list);
+    return sh.arena_size / (ONE << list);
 }
 #endif /* IMPLEMENTED */