Deal with BUF_MEM_grow ambiguity
[openssl.git] / crypto / evp / pkey_kdf.c
index 29a24ac..f4cf40e 100644 (file)
@@ -71,6 +71,31 @@ static void pkey_kdf_cleanup(EVP_PKEY_CTX *ctx)
     OPENSSL_free(pkctx);
 }
 
+static int collect(BUF_MEM **collector, void *data, size_t datalen)
+{
+    size_t i;
+
+    if (*collector == NULL)
+        *collector = BUF_MEM_new();
+    if (*collector == NULL) {
+        ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
+        return 0;
+    }
+
+    i = (*collector)->length; /* BUF_MEM_grow() changes it! */
+    /*
+     * The i + datalen check is to distinguish between BUF_MEM_grow()
+     * signaling an error and BUF_MEM_grow() simply returning the (zero)
+     * length.
+     */
+    if (!BUF_MEM_grow(*collector, i + datalen)
+        && i + datalen != 0)
+        return 0;
+    if (data != NULL)
+        memcpy((*collector)->data + i, data, datalen);
+    return 1;
+}
+
 static int pkey_kdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
 {
     EVP_PKEY_KDF_CTX *pkctx = ctx->data;
@@ -144,16 +169,9 @@ static int pkey_kdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
     }
 
     if (collector != NULL) {
-        size_t i;
-
         switch (cmd) {
         case T_OCTET_STRING:
-            if (*collector == NULL)
-                *collector = BUF_MEM_new();
-            i = (*collector)->length; /* BUF_MEM_grow() changes it! */
-            BUF_MEM_grow(*collector, i + p1);
-            memcpy((*collector)->data + i, p2, p1);
-            break;
+            return collect(collector, p2, p1);
         default:
             OPENSSL_assert("You shouldn't be here");
             break;
@@ -196,6 +214,7 @@ static int pkey_kdf_ctrl_str(EVP_PKEY_CTX *ctx, const char *type,
     EVP_PKEY_KDF_CTX *pkctx = ctx->data;
     EVP_KDF_CTX *kctx = pkctx->kctx;
     const EVP_KDF *kdf = EVP_KDF_CTX_kdf(kctx);
+    BUF_MEM **collector = NULL;
     const OSSL_PARAM *defs = EVP_KDF_CTX_settable_params(kdf);
     OSSL_PARAM params[2] = { OSSL_PARAM_END, OSSL_PARAM_END };
     int ok = 0;
@@ -210,7 +229,20 @@ static int pkey_kdf_ctrl_str(EVP_PKEY_CTX *ctx, const char *type,
     if (!OSSL_PARAM_allocate_from_text(&params[0], defs, type,
                                        value, strlen(value)))
         return 0;
-    ok = EVP_KDF_CTX_set_params(kctx, params);
+
+    /*
+     * We do the same special casing of seed and info here as in
+     * pkey_kdf_ctrl()
+     */
+    if (strcmp(params[0].key, OSSL_KDF_PARAM_SEED) == 0)
+        collector = &pkctx->collected_seed;
+    else if (strcmp(params[0].key, OSSL_KDF_PARAM_INFO) == 0)
+        collector = &pkctx->collected_info;
+
+    if (collector != NULL)
+        ok = collect(collector, params[0].data, params[0].data_size);
+    else
+        ok = EVP_KDF_CTX_set_params(kctx, params);
     OPENSSL_free(params[0].data);
     return ok;
 }