Avoid passing NULL to memcpy
[openssl.git] / crypto / evp / pkey_kdf.c
index 9774408..f32d213 100644 (file)
@@ -11,6 +11,7 @@
 #include <string.h>
 #include <openssl/evp.h>
 #include <openssl/err.h>
 #include <string.h>
 #include <openssl/evp.h>
 #include <openssl/err.h>
+#include <openssl/buffer.h>
 #include <openssl/kdf.h>
 #include <openssl/core.h>
 #include <openssl/core_names.h>
 #include <openssl/kdf.h>
 #include <openssl/core.h>
 #include <openssl/core_names.h>
 
 typedef struct {
     EVP_KDF_CTX *kctx;
 
 typedef struct {
     EVP_KDF_CTX *kctx;
-    /* TODO(3.0): come up with a better way to do this */
-    OSSL_PARAM params[MAX_PARAM];
-    int palloc[MAX_PARAM];
-    uint64_t uint64s[MAX_PARAM];
-    int ints[MAX_PARAM];
-    int pidx;
+    /*
+     * EVP_PKEY implementations collect bits of certain data
+     */
+    BUF_MEM *collected_seed;
+    BUF_MEM *collected_info;
 } EVP_PKEY_KDF_CTX;
 
 } EVP_PKEY_KDF_CTX;
 
-static void pkey_kdf_free_param_data(EVP_PKEY_KDF_CTX *pkctx)
+static void pkey_kdf_free_collected(EVP_PKEY_KDF_CTX *pkctx)
 {
 {
-    int i;
-
-    for (i = 0; i < pkctx->pidx; i++)
-        if (pkctx->palloc[i])
-            OPENSSL_free(pkctx->params[i].data);
-    pkctx->pidx = 0;
+    BUF_MEM_free(pkctx->collected_seed);
+    pkctx->collected_seed = NULL;
+    BUF_MEM_free(pkctx->collected_info);
+    pkctx->collected_info = NULL;
 }
 
 static int pkey_kdf_init(EVP_PKEY_CTX *ctx)
 }
 
 static int pkey_kdf_init(EVP_PKEY_CTX *ctx)
@@ -69,16 +67,39 @@ static void pkey_kdf_cleanup(EVP_PKEY_CTX *ctx)
     EVP_PKEY_KDF_CTX *pkctx = ctx->data;
 
     EVP_KDF_CTX_free(pkctx->kctx);
     EVP_PKEY_KDF_CTX *pkctx = ctx->data;
 
     EVP_KDF_CTX_free(pkctx->kctx);
-    pkey_kdf_free_param_data(pkctx);
+    pkey_kdf_free_collected(pkctx);
     OPENSSL_free(pkctx);
 }
 
     OPENSSL_free(pkctx);
 }
 
+static int collect(BUF_MEM **collector, void *data, size_t datalen)
+{
+    size_t i;
+
+    if (*collector == NULL)
+        *collector = BUF_MEM_new();
+    if (*collector == NULL) {
+        ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
+        return 0;
+    }
+
+    if (data != NULL && datalen > 0) {
+        i = (*collector)->length; /* BUF_MEM_grow() changes it! */
+
+        if (!BUF_MEM_grow(*collector, i + datalen))
+            return 0;
+        memcpy((*collector)->data + i, data, datalen);
+    }
+    return 1;
+}
+
 static int pkey_kdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
 {
     EVP_PKEY_KDF_CTX *pkctx = ctx->data;
 static int pkey_kdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
 {
     EVP_PKEY_KDF_CTX *pkctx = ctx->data;
+    EVP_KDF_CTX *kctx = pkctx->kctx;
     enum { T_OCTET_STRING, T_UINT64, T_DIGEST, T_INT } cmd;
     const char *name, *mdname;
     enum { T_OCTET_STRING, T_UINT64, T_DIGEST, T_INT } cmd;
     const char *name, *mdname;
-    OSSL_PARAM *p = pkctx->params + pkctx->pidx;
+    BUF_MEM **collector = NULL;
+    OSSL_PARAM params[2] = { OSSL_PARAM_END, OSSL_PARAM_END };
 
     switch (type) {
     case EVP_PKEY_CTRL_PASS:
 
     switch (type) {
     case EVP_PKEY_CTRL_PASS:
@@ -98,10 +119,19 @@ static int pkey_kdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
     case EVP_PKEY_CTRL_TLS_SECRET:
         cmd = T_OCTET_STRING;
         name = OSSL_KDF_PARAM_SECRET;
     case EVP_PKEY_CTRL_TLS_SECRET:
         cmd = T_OCTET_STRING;
         name = OSSL_KDF_PARAM_SECRET;
+        /*
+         * Perform the semantics described in
+         * EVP_PKEY_CTX_add1_tls1_prf_seed(3)
+         */
+        if (ctx->pmeth->pkey_id == NID_tls1_prf) {
+            BUF_MEM_free(pkctx->collected_seed);
+            pkctx->collected_seed = NULL;
+        }
         break;
     case EVP_PKEY_CTRL_TLS_SEED:
         cmd = T_OCTET_STRING;
         name = OSSL_KDF_PARAM_SEED;
         break;
     case EVP_PKEY_CTRL_TLS_SEED:
         cmd = T_OCTET_STRING;
         name = OSSL_KDF_PARAM_SEED;
+        collector = &pkctx->collected_seed;
         break;
     case EVP_PKEY_CTRL_HKDF_KEY:
         cmd = T_OCTET_STRING;
         break;
     case EVP_PKEY_CTRL_HKDF_KEY:
         cmd = T_OCTET_STRING;
@@ -110,6 +140,7 @@ static int pkey_kdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
     case EVP_PKEY_CTRL_HKDF_INFO:
         cmd = T_OCTET_STRING;
         name = OSSL_KDF_PARAM_INFO;
     case EVP_PKEY_CTRL_HKDF_INFO:
         cmd = T_OCTET_STRING;
         name = OSSL_KDF_PARAM_INFO;
+        collector = &pkctx->collected_info;
         break;
     case EVP_PKEY_CTRL_HKDF_MODE:
         cmd = T_INT;
         break;
     case EVP_PKEY_CTRL_HKDF_MODE:
         cmd = T_INT;
@@ -135,16 +166,28 @@ static int pkey_kdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
         return -2;
     }
 
         return -2;
     }
 
+    if (collector != NULL) {
+        switch (cmd) {
+        case T_OCTET_STRING:
+            return collect(collector, p2, p1);
+        default:
+            OPENSSL_assert("You shouldn't be here");
+            break;
+        }
+        return 1;
+    }
+
     switch (cmd) {
     case T_OCTET_STRING:
     switch (cmd) {
     case T_OCTET_STRING:
-        *p = OSSL_PARAM_construct_octet_string(name, (unsigned char *)p2,
-                                               (size_t)p1);
+        params[0] =
+            OSSL_PARAM_construct_octet_string(name, (unsigned char *)p2,
+                                              (size_t)p1);
         break;
 
     case T_DIGEST:
         mdname = EVP_MD_name((const EVP_MD *)p2);
         break;
 
     case T_DIGEST:
         mdname = EVP_MD_name((const EVP_MD *)p2);
-        *p = OSSL_PARAM_construct_utf8_string(name, (char *)mdname,
-                                              strlen(mdname) + 1);
+        params[0] = OSSL_PARAM_construct_utf8_string(name, (char *)mdname,
+                                                     strlen(mdname) + 1);
         break;
 
         /*
         break;
 
         /*
@@ -152,17 +195,15 @@ static int pkey_kdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
          * stack, so a local copy is required.
          */
     case T_INT:
          * stack, so a local copy is required.
          */
     case T_INT:
-        pkctx->ints[pkctx->pidx] = *(int *)p2;
-        *p = OSSL_PARAM_construct_int(name, pkctx->ints + pkctx->pidx);
+        params[0] = OSSL_PARAM_construct_int(name, &p1);
         break;
 
     case T_UINT64:
         break;
 
     case T_UINT64:
-        pkctx->uint64s[pkctx->pidx] = *(uint64_t *)p2;
-        *p = OSSL_PARAM_construct_uint64(name, pkctx->uint64s + pkctx->pidx);
+        params[0] = OSSL_PARAM_construct_uint64(name, (uint64_t *)p2);
         break;
     }
         break;
     }
-    pkctx->palloc[pkctx->pidx++] = 0;
-    return 1;
+
+    return EVP_KDF_CTX_set_params(kctx, params);
 }
 
 static int pkey_kdf_ctrl_str(EVP_PKEY_CTX *ctx, const char *type,
 }
 
 static int pkey_kdf_ctrl_str(EVP_PKEY_CTX *ctx, const char *type,
@@ -171,8 +212,10 @@ static int pkey_kdf_ctrl_str(EVP_PKEY_CTX *ctx, const char *type,
     EVP_PKEY_KDF_CTX *pkctx = ctx->data;
     EVP_KDF_CTX *kctx = pkctx->kctx;
     const EVP_KDF *kdf = EVP_KDF_CTX_kdf(kctx);
     EVP_PKEY_KDF_CTX *pkctx = ctx->data;
     EVP_KDF_CTX *kctx = pkctx->kctx;
     const EVP_KDF *kdf = EVP_KDF_CTX_kdf(kctx);
+    BUF_MEM **collector = NULL;
     const OSSL_PARAM *defs = EVP_KDF_CTX_settable_params(kdf);
     const OSSL_PARAM *defs = EVP_KDF_CTX_settable_params(kdf);
-    OSSL_PARAM *p = pkctx->params + pkctx->pidx;
+    OSSL_PARAM params[2] = { OSSL_PARAM_END, OSSL_PARAM_END };
+    int ok = 0;
 
     /* Deal with ctrl name aliasing */
     if (strcmp(type, "md") == 0)
 
     /* Deal with ctrl name aliasing */
     if (strcmp(type, "md") == 0)
@@ -181,18 +224,34 @@ static int pkey_kdf_ctrl_str(EVP_PKEY_CTX *ctx, const char *type,
     if (strcmp(type, "N") == 0)
         type = OSSL_KDF_PARAM_SCRYPT_N;
 
     if (strcmp(type, "N") == 0)
         type = OSSL_KDF_PARAM_SCRYPT_N;
 
-    if (!OSSL_PARAM_allocate_from_text(p, defs, type, value, strlen(value)))
+    if (!OSSL_PARAM_allocate_from_text(&params[0], defs, type,
+                                       value, strlen(value)))
         return 0;
         return 0;
-    pkctx->palloc[pkctx->pidx++] = 1;
-    return 1;
+
+    /*
+     * We do the same special casing of seed and info here as in
+     * pkey_kdf_ctrl()
+     */
+    if (strcmp(params[0].key, OSSL_KDF_PARAM_SEED) == 0)
+        collector = &pkctx->collected_seed;
+    else if (strcmp(params[0].key, OSSL_KDF_PARAM_INFO) == 0)
+        collector = &pkctx->collected_info;
+
+    if (collector != NULL)
+        ok = collect(collector, params[0].data, params[0].data_size);
+    else
+        ok = EVP_KDF_CTX_set_params(kctx, params);
+    OPENSSL_free(params[0].data);
+    return ok;
 }
 
 static int pkey_kdf_derive_init(EVP_PKEY_CTX *ctx)
 {
     EVP_PKEY_KDF_CTX *pkctx = ctx->data;
 
 }
 
 static int pkey_kdf_derive_init(EVP_PKEY_CTX *ctx)
 {
     EVP_PKEY_KDF_CTX *pkctx = ctx->data;
 
-    pkey_kdf_free_param_data(pkctx);
-    EVP_KDF_reset(pkctx->kctx);
+    pkey_kdf_free_collected(pkctx);
+    if (pkctx->kctx != NULL)
+        EVP_KDF_reset(pkctx->kctx);
     return 1;
 }
 
     return 1;
 }
 
@@ -208,10 +267,29 @@ static int pkey_kdf_derive(EVP_PKEY_CTX *ctx, unsigned char *key,
     size_t outlen = EVP_KDF_size(kctx);
     int r;
 
     size_t outlen = EVP_KDF_size(kctx);
     int r;
 
-    if (pkctx->pidx > 0) {
-        pkctx->params[pkctx->pidx] = OSSL_PARAM_construct_end();
-        r = EVP_KDF_CTX_set_params(kctx, pkctx->params);
-        pkey_kdf_free_param_data(pkctx);
+    if (pkctx->collected_seed != NULL) {
+        OSSL_PARAM params[] = { OSSL_PARAM_END, OSSL_PARAM_END };
+
+        params[0] =
+            OSSL_PARAM_construct_octet_string(OSSL_KDF_PARAM_SEED,
+                                              pkctx->collected_seed->data,
+                                              pkctx->collected_seed->length);
+
+        r = EVP_KDF_CTX_set_params(kctx, params);
+        pkey_kdf_free_collected(pkctx);
+        if (!r)
+            return 0;
+    }
+    if (pkctx->collected_info != NULL) {
+        OSSL_PARAM params[] = { OSSL_PARAM_END, OSSL_PARAM_END };
+
+        params[0] =
+            OSSL_PARAM_construct_octet_string(OSSL_KDF_PARAM_INFO,
+                                              pkctx->collected_info->data,
+                                              pkctx->collected_info->length);
+
+        r = EVP_KDF_CTX_set_params(kctx, params);
+        pkey_kdf_free_collected(pkctx);
         if (!r)
             return 0;
     }
         if (!r)
             return 0;
     }