Refactor SSKDF to create the MAC contexts early
[openssl.git] / providers / common / kdfs / sskdf.c
index 61e4607bee88512c5dd4ab15adb24c55fe9b0c15..d74da436a2712e068683ed1ea49ee29c9d5067d5 100644 (file)
 #include "internal/provider_ctx.h"
 #include "internal/providercommonerr.h"
 #include "internal/provider_algs.h"
+#include "internal/provider_util.h"
 
 typedef struct {
     void *provctx;
-    EVP_MAC *mac;       /* H(x) = HMAC_hash OR H(x) = KMAC */
-    EVP_MD *md;         /* H(x) = hash OR when H(x) = HMAC_hash */
+    EVP_MAC_CTX *macctx;         /* H(x) = HMAC_hash OR H(x) = KMAC */
+    PROV_DIGEST digest;          /* H(x) = hash(x) */
     unsigned char *secret;
     size_t secret_len;
     unsigned char *info;
@@ -206,7 +207,7 @@ static int kmac_init(EVP_MAC_CTX *ctx, const unsigned char *custom,
  *     H(x) = HMAC-hash(salt, x) OR
  *     H(x) = KMAC#(salt, x, outbits, CustomString='KDF')
  */
-static int SSKDF_mac_kdm(EVP_MAC *kdf_mac, const EVP_MD *hmac_md,
+static int SSKDF_mac_kdm(EVP_MAC_CTX *ctx_init,
                          const unsigned char *kmac_custom,
                          size_t kmac_custom_len, size_t kmac_out_len,
                          const unsigned char *salt, size_t salt_len,
@@ -219,30 +220,18 @@ static int SSKDF_mac_kdm(EVP_MAC *kdf_mac, const EVP_MD *hmac_md,
     unsigned char c[4];
     unsigned char mac_buf[EVP_MAX_MD_SIZE];
     unsigned char *out = derived_key;
-    EVP_MAC_CTX *ctx = NULL, *ctx_init = NULL;
+    EVP_MAC_CTX *ctx = NULL;
     unsigned char *mac = mac_buf, *kmac_buffer = NULL;
-    OSSL_PARAM params[3];
-    size_t params_n = 0;
+    OSSL_PARAM params[2], *p = params;
 
     if (z_len > SSKDF_MAX_INLEN || info_len > SSKDF_MAX_INLEN
             || derived_key_len > SSKDF_MAX_INLEN
             || derived_key_len == 0)
         return 0;
 
-    ctx_init = EVP_MAC_CTX_new(kdf_mac);
-    if (ctx_init == NULL)
-        goto end;
-
-    if (hmac_md != NULL) {
-        const char *mdname = EVP_MD_name(hmac_md);
-        params[params_n++] =
-            OSSL_PARAM_construct_utf8_string(OSSL_MAC_PARAM_DIGEST,
-                                             (char *)mdname, 0);
-    }
-    params[params_n++] =
-        OSSL_PARAM_construct_octet_string(OSSL_MAC_PARAM_KEY, (void *)salt,
-                                          salt_len);
-    params[params_n] = OSSL_PARAM_construct_end();
+    *p++ = OSSL_PARAM_construct_octet_string(OSSL_MAC_PARAM_KEY,
+                                             (void *)salt, salt_len);
+    *p = OSSL_PARAM_construct_end();
 
     if (!EVP_MAC_CTX_set_params(ctx_init, params))
         goto end;
@@ -297,7 +286,6 @@ end:
         OPENSSL_cleanse(mac_buf, sizeof(mac_buf));
 
     EVP_MAC_CTX_free(ctx);
-    EVP_MAC_CTX_free(ctx_init);
     return ret;
 }
 
@@ -315,10 +303,11 @@ static void sskdf_reset(void *vctx)
 {
     KDF_SSKDF *ctx = (KDF_SSKDF *)vctx;
 
+    EVP_MAC_CTX_free(ctx->macctx);
+    ossl_prov_digest_reset(&ctx->digest);
     OPENSSL_clear_free(ctx->secret, ctx->secret_len);
     OPENSSL_clear_free(ctx->info, ctx->info_len);
     OPENSSL_clear_free(ctx->salt, ctx->salt_len);
-    EVP_MAC_free(ctx->mac);
     memset(ctx, 0, sizeof(*ctx));
 }
 
@@ -327,8 +316,6 @@ static void sskdf_free(void *vctx)
     KDF_SSKDF *ctx = (KDF_SSKDF *)vctx;
 
     sskdf_reset(ctx);
-    EVP_MD_meth_free(ctx->md);
-    EVP_MAC_free(ctx->mac);
     OPENSSL_free(ctx);
 }
 
@@ -345,53 +332,54 @@ static int sskdf_set_buffer(unsigned char **out, size_t *out_len,
 static size_t sskdf_size(KDF_SSKDF *ctx)
 {
     int len;
+    const EVP_MD *md = ossl_prov_digest_md(&ctx->digest);
 
-    if (ctx->md == NULL) {
+    if (md == NULL) {
         ERR_raise(ERR_LIB_PROV, PROV_R_MISSING_MESSAGE_DIGEST);
         return 0;
     }
-    len = EVP_MD_size(ctx->md);
+    len = EVP_MD_size(md);
     return (len <= 0) ? 0 : (size_t)len;
 }
 
 static int sskdf_derive(void *vctx, unsigned char *key, size_t keylen)
 {
     KDF_SSKDF *ctx = (KDF_SSKDF *)vctx;
+    const EVP_MD *md = ossl_prov_digest_md(&ctx->digest);
 
     if (ctx->secret == NULL) {
         ERR_raise(ERR_LIB_PROV, PROV_R_MISSING_SECRET);
         return 0;
     }
 
-    if (ctx->mac != NULL) {
+    if (ctx->macctx != NULL) {
         /* H(x) = KMAC or H(x) = HMAC */
         int ret;
         const unsigned char *custom = NULL;
         size_t custom_len = 0;
-        const char *macname;
         int default_salt_len;
+        EVP_MAC *mac = EVP_MAC_CTX_mac(ctx->macctx);
 
         /*
          * TODO(3.0) investigate the necessity to have all these controls.
          * Why does KMAC require a salt length that's shorter than the MD
          * block size?
          */
-        macname = EVP_MAC_name(ctx->mac);
-        if (strcmp(macname, OSSL_MAC_NAME_HMAC) == 0) {
+        if (EVP_MAC_is_a(mac, OSSL_MAC_NAME_HMAC)) {
             /* H(x) = HMAC(x, salt, hash) */
-            if (ctx->md == NULL) {
+            if (md == NULL) {
                 ERR_raise(ERR_LIB_PROV, PROV_R_MISSING_MESSAGE_DIGEST);
                 return 0;
             }
-            default_salt_len = EVP_MD_block_size(ctx->md);
+            default_salt_len = EVP_MD_size(md);
             if (default_salt_len <= 0)
                 return 0;
-        } else if (strcmp(macname, OSSL_MAC_NAME_KMAC128) == 0
-                   || strcmp(macname, OSSL_MAC_NAME_KMAC256) == 0) {
+        } else if (EVP_MAC_is_a(mac, OSSL_MAC_NAME_KMAC128)
+                   || EVP_MAC_is_a(mac, OSSL_MAC_NAME_KMAC256)) {
             /* H(x) = KMACzzz(x, salt, custom) */
             custom = kmac_custom_str;
             custom_len = sizeof(kmac_custom_str);
-            if (strcmp(macname, OSSL_MAC_NAME_KMAC128) == 0)
+            if (EVP_MAC_is_a(mac, OSSL_MAC_NAME_KMAC128))
                 default_salt_len = SSKDF_KMAC128_DEFAULT_SALT_SIZE;
             else
                 default_salt_len = SSKDF_KMAC256_DEFAULT_SALT_SIZE;
@@ -408,7 +396,7 @@ static int sskdf_derive(void *vctx, unsigned char *key, size_t keylen)
             }
             ctx->salt_len = default_salt_len;
         }
-        ret = SSKDF_mac_kdm(ctx->mac, ctx->md,
+        ret = SSKDF_mac_kdm(ctx->macctx,
                             custom, custom_len, ctx->out_len,
                             ctx->salt, ctx->salt_len,
                             ctx->secret, ctx->secret_len,
@@ -416,11 +404,11 @@ static int sskdf_derive(void *vctx, unsigned char *key, size_t keylen)
         return ret;
     } else {
         /* H(x) = hash */
-        if (ctx->md == NULL) {
+        if (md == NULL) {
             ERR_raise(ERR_LIB_PROV, PROV_R_MISSING_MESSAGE_DIGEST);
             return 0;
         }
-        return SSKDF_hash_kdm(ctx->md, ctx->secret, ctx->secret_len,
+        return SSKDF_hash_kdm(md, ctx->secret, ctx->secret_len,
                               ctx->info, ctx->info_len, 0, key, keylen);
     }
 }
@@ -428,67 +416,41 @@ static int sskdf_derive(void *vctx, unsigned char *key, size_t keylen)
 static int x963kdf_derive(void *vctx, unsigned char *key, size_t keylen)
 {
     KDF_SSKDF *ctx = (KDF_SSKDF *)vctx;
+    const EVP_MD *md = ossl_prov_digest_md(&ctx->digest);
 
     if (ctx->secret == NULL) {
         ERR_raise(ERR_LIB_PROV, PROV_R_MISSING_SECRET);
         return 0;
     }
 
-    if (ctx->mac != NULL) {
+    if (ctx->macctx != NULL) {
         ERR_raise(ERR_LIB_PROV, PROV_R_NOT_SUPPORTED);
         return 0;
-    } else {
-        /* H(x) = hash */
-        if (ctx->md == NULL) {
-            ERR_raise(ERR_LIB_PROV, PROV_R_MISSING_MESSAGE_DIGEST);
-            return 0;
-        }
-        return SSKDF_hash_kdm(ctx->md, ctx->secret, ctx->secret_len,
-                              ctx->info, ctx->info_len, 1, key, keylen);
     }
+
+    /* H(x) = hash */
+    if (md == NULL) {
+        ERR_raise(ERR_LIB_PROV, PROV_R_MISSING_MESSAGE_DIGEST);
+        return 0;
+    }
+
+    return SSKDF_hash_kdm(md, ctx->secret, ctx->secret_len,
+                          ctx->info, ctx->info_len, 1, key, keylen);
 }
 
 static int sskdf_set_ctx_params(void *vctx, const OSSL_PARAM params[])
 {
     const OSSL_PARAM *p;
     KDF_SSKDF *ctx = vctx;
-    EVP_MD *md;
-    EVP_MAC *mac;
+    OPENSSL_CTX *libctx = PROV_LIBRARY_CONTEXT_OF(ctx->provctx);
     size_t sz;
-    const char *properties = NULL;
 
-    /* Grab search properties, should be before the digest and mac lookups */
-    if ((p = OSSL_PARAM_locate_const(params, OSSL_KDF_PARAM_PROPERTIES))
-        != NULL) {
-        if (p->data_type != OSSL_PARAM_UTF8_STRING)
-            return 0;
-        properties = p->data;
-    }
-    /* Handle aliasing of digest parameter names */
-    if ((p = OSSL_PARAM_locate_const(params, OSSL_KDF_PARAM_DIGEST)) != NULL) {
-        if (p->data_type != OSSL_PARAM_UTF8_STRING)
-            return 0;
-        md = EVP_MD_fetch(PROV_LIBRARY_CONTEXT_OF(ctx->provctx), p->data,
-                          properties);
-        if (md == NULL) {
-            ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST);
-            return 0;
-        }
-        EVP_MD_meth_free(ctx->md);
-        ctx->md = md;
-    }
-
-    if ((p = OSSL_PARAM_locate_const(params, OSSL_KDF_PARAM_MAC)) != NULL) {
-        EVP_MAC_free(ctx->mac);
-        ctx->mac = NULL;
+    if (!ossl_prov_digest_load_from_params(&ctx->digest, params, libctx))
+        return 0;
 
-        mac = EVP_MAC_fetch(PROV_LIBRARY_CONTEXT_OF(ctx->provctx), p->data,
-                            properties);
-        if (mac == NULL)
-            return 0;
-        EVP_MAC_free(ctx->mac);
-        ctx->mac = mac;
-    }
+    if (!ossl_prov_macctx_load_from_params(&ctx->macctx, params,
+                                           NULL, NULL, NULL, libctx))
+        return 0;
 
     if ((p = OSSL_PARAM_locate_const(params, OSSL_KDF_PARAM_SECRET)) != NULL
         || (p = OSSL_PARAM_locate_const(params, OSSL_KDF_PARAM_KEY)) != NULL)