CORE: Add an internal function to distinguish the global default context
[openssl.git] / crypto / evp / pkey_kdf.c
index 29a24ac688276dafa8827528054534f761b5e0dd..dff16bfd41d7a05c29b7bfeccc67a50861fd5d7d 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2018 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2018-2020 The OpenSSL Project Authors. All Rights Reserved.
  * Copyright (c) 2018, Oracle and/or its affiliates.  All rights reserved.
  *
  * Licensed under the Apache License 2.0 (the "License").  You may not use
@@ -17,7 +17,7 @@
 #include <openssl/core_names.h>
 #include <openssl/params.h>
 #include "internal/numbers.h"
-#include "internal/evp_int.h"
+#include "crypto/evp.h"
 
 #define MAX_PARAM   20
 
@@ -50,7 +50,7 @@ static int pkey_kdf_init(EVP_PKEY_CTX *ctx)
         return 0;
 
     kdf = EVP_KDF_fetch(NULL, kdf_name, NULL);
-    kctx = EVP_KDF_CTX_new(kdf);
+    kctx = EVP_KDF_new_ctx(kdf);
     EVP_KDF_free(kdf);
     if (kctx == NULL) {
         OPENSSL_free(pkctx);
@@ -66,11 +66,32 @@ static void pkey_kdf_cleanup(EVP_PKEY_CTX *ctx)
 {
     EVP_PKEY_KDF_CTX *pkctx = ctx->data;
 
-    EVP_KDF_CTX_free(pkctx->kctx);
+    EVP_KDF_free_ctx(pkctx->kctx);
     pkey_kdf_free_collected(pkctx);
     OPENSSL_free(pkctx);
 }
 
+static int collect(BUF_MEM **collector, void *data, size_t datalen)
+{
+    size_t i;
+
+    if (*collector == NULL)
+        *collector = BUF_MEM_new();
+    if (*collector == NULL) {
+        ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
+        return 0;
+    }
+
+    if (data != NULL && datalen > 0) {
+        i = (*collector)->length; /* BUF_MEM_grow() changes it! */
+
+        if (!BUF_MEM_grow(*collector, i + datalen))
+            return 0;
+        memcpy((*collector)->data + i, data, datalen);
+    }
+    return 1;
+}
+
 static int pkey_kdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
 {
     EVP_PKEY_KDF_CTX *pkctx = ctx->data;
@@ -102,8 +123,10 @@ static int pkey_kdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
          * Perform the semantics described in
          * EVP_PKEY_CTX_add1_tls1_prf_seed(3)
          */
-        if (ctx->pmeth->pkey_id == NID_tls1_prf)
+        if (ctx->pmeth->pkey_id == NID_tls1_prf) {
             BUF_MEM_free(pkctx->collected_seed);
+            pkctx->collected_seed = NULL;
+        }
         break;
     case EVP_PKEY_CTRL_TLS_SEED:
         cmd = T_OCTET_STRING;
@@ -144,16 +167,9 @@ static int pkey_kdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
     }
 
     if (collector != NULL) {
-        size_t i;
-
         switch (cmd) {
         case T_OCTET_STRING:
-            if (*collector == NULL)
-                *collector = BUF_MEM_new();
-            i = (*collector)->length; /* BUF_MEM_grow() changes it! */
-            BUF_MEM_grow(*collector, i + p1);
-            memcpy((*collector)->data + i, p2, p1);
-            break;
+            return collect(collector, p2, p1);
         default:
             OPENSSL_assert("You shouldn't be here");
             break;
@@ -170,8 +186,7 @@ static int pkey_kdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
 
     case T_DIGEST:
         mdname = EVP_MD_name((const EVP_MD *)p2);
-        params[0] = OSSL_PARAM_construct_utf8_string(name, (char *)mdname,
-                                                     strlen(mdname) + 1);
+        params[0] = OSSL_PARAM_construct_utf8_string(name, (char *)mdname, 0);
         break;
 
         /*
@@ -187,7 +202,7 @@ static int pkey_kdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
         break;
     }
 
-    return EVP_KDF_CTX_set_params(kctx, params);
+    return EVP_KDF_set_ctx_params(kctx, params);
 }
 
 static int pkey_kdf_ctrl_str(EVP_PKEY_CTX *ctx, const char *type,
@@ -195,8 +210,9 @@ static int pkey_kdf_ctrl_str(EVP_PKEY_CTX *ctx, const char *type,
 {
     EVP_PKEY_KDF_CTX *pkctx = ctx->data;
     EVP_KDF_CTX *kctx = pkctx->kctx;
-    const EVP_KDF *kdf = EVP_KDF_CTX_kdf(kctx);
-    const OSSL_PARAM *defs = EVP_KDF_CTX_settable_params(kdf);
+    const EVP_KDF *kdf = EVP_KDF_get_ctx_kdf(kctx);
+    BUF_MEM **collector = NULL;
+    const OSSL_PARAM *defs = EVP_KDF_settable_ctx_params(kdf);
     OSSL_PARAM params[2] = { OSSL_PARAM_END, OSSL_PARAM_END };
     int ok = 0;
 
@@ -208,9 +224,22 @@ static int pkey_kdf_ctrl_str(EVP_PKEY_CTX *ctx, const char *type,
         type = OSSL_KDF_PARAM_SCRYPT_N;
 
     if (!OSSL_PARAM_allocate_from_text(&params[0], defs, type,
-                                       value, strlen(value)))
+                                       value, strlen(value), NULL))
         return 0;
-    ok = EVP_KDF_CTX_set_params(kctx, params);
+
+    /*
+     * We do the same special casing of seed and info here as in
+     * pkey_kdf_ctrl()
+     */
+    if (strcmp(params[0].key, OSSL_KDF_PARAM_SEED) == 0)
+        collector = &pkctx->collected_seed;
+    else if (strcmp(params[0].key, OSSL_KDF_PARAM_INFO) == 0)
+        collector = &pkctx->collected_info;
+
+    if (collector != NULL)
+        ok = collect(collector, params[0].data, params[0].data_size);
+    else
+        ok = EVP_KDF_set_ctx_params(kctx, params);
     OPENSSL_free(params[0].data);
     return ok;
 }
@@ -245,7 +274,7 @@ static int pkey_kdf_derive(EVP_PKEY_CTX *ctx, unsigned char *key,
                                               pkctx->collected_seed->data,
                                               pkctx->collected_seed->length);
 
-        r = EVP_KDF_CTX_set_params(kctx, params);
+        r = EVP_KDF_set_ctx_params(kctx, params);
         pkey_kdf_free_collected(pkctx);
         if (!r)
             return 0;
@@ -258,7 +287,7 @@ static int pkey_kdf_derive(EVP_PKEY_CTX *ctx, unsigned char *key,
                                               pkctx->collected_info->data,
                                               pkctx->collected_info->length);
 
-        r = EVP_KDF_CTX_set_params(kctx, params);
+        r = EVP_KDF_set_ctx_params(kctx, params);
         pkey_kdf_free_collected(pkctx);
         if (!r)
             return 0;
@@ -277,7 +306,7 @@ static int pkey_kdf_derive(EVP_PKEY_CTX *ctx, unsigned char *key,
 }
 
 #ifndef OPENSSL_NO_SCRYPT
-const EVP_PKEY_METHOD scrypt_pkey_meth = {
+static const EVP_PKEY_METHOD scrypt_pkey_meth = {
     EVP_PKEY_SCRYPT,
     0,
     pkey_kdf_init,
@@ -306,9 +335,14 @@ const EVP_PKEY_METHOD scrypt_pkey_meth = {
     pkey_kdf_ctrl,
     pkey_kdf_ctrl_str
 };
+
+const EVP_PKEY_METHOD *scrypt_pkey_method(void)
+{
+    return &scrypt_pkey_meth;
+}
 #endif
 
-const EVP_PKEY_METHOD tls1_prf_pkey_meth = {
+static const EVP_PKEY_METHOD tls1_prf_pkey_meth = {
     EVP_PKEY_TLS1_PRF,
     0,
     pkey_kdf_init,
@@ -338,7 +372,12 @@ const EVP_PKEY_METHOD tls1_prf_pkey_meth = {
     pkey_kdf_ctrl_str
 };
 
-const EVP_PKEY_METHOD hkdf_pkey_meth = {
+const EVP_PKEY_METHOD *tls1_prf_pkey_method(void)
+{
+    return &tls1_prf_pkey_meth;
+}
+
+static const EVP_PKEY_METHOD hkdf_pkey_meth = {
     EVP_PKEY_HKDF,
     0,
     pkey_kdf_init,
@@ -368,3 +407,7 @@ const EVP_PKEY_METHOD hkdf_pkey_meth = {
     pkey_kdf_ctrl_str
 };
 
+const EVP_PKEY_METHOD *hkdf_pkey_method(void)
+{
+    return &hkdf_pkey_meth;
+}