Refactor TLS1-PRF to create the MAC contexts early
[openssl.git] / providers / common / kdfs / tls1_prf.c
index 5d7e599e64e0030fa1a22a25411805806f4fe6cb..0acdcdf3b88db4399494e038c60b9d44dd631ac8 100644 (file)
@@ -58,6 +58,7 @@
 #include "internal/provider_ctx.h"
 #include "internal/providercommonerr.h"
 #include "internal/provider_algs.h"
+#include "internal/provider_util.h"
 #include "e_os.h"
 
 static OSSL_OP_kdf_newctx_fn kdf_tls1_prf_new;
@@ -67,7 +68,7 @@ static OSSL_OP_kdf_derive_fn kdf_tls1_prf_derive;
 static OSSL_OP_kdf_settable_ctx_params_fn kdf_tls1_prf_settable_ctx_params;
 static OSSL_OP_kdf_set_ctx_params_fn kdf_tls1_prf_set_ctx_params;
 
-static int tls1_prf_alg(const EVP_MD *md, const EVP_MD *sha1,
+static int tls1_prf_alg(EVP_MAC_CTX *mdctx, EVP_MAC_CTX *sha1ctx,
                         const unsigned char *sec, size_t slen,
                         const unsigned char *seed, size_t seed_len,
                         unsigned char *out, size_t olen);
@@ -77,10 +78,12 @@ static int tls1_prf_alg(const EVP_MD *md, const EVP_MD *sha1,
 /* TLS KDF kdf context structure */
 typedef struct {
     void *provctx;
-    /* Digest to use for PRF */
-    EVP_MD *md;
-    /* Second digest for the MD5/SHA-1 combined PRF */
-    EVP_MD *sha1;
+
+    /* MAC context for the main digest */
+    EVP_MAC_CTX *P_hash;
+    /* MAC context for SHA1 for the MD5/SHA-1 combined PRF */
+    EVP_MAC_CTX *P_sha1;
+
     /* Secret value to use for PRF */
     unsigned char *sec;
     size_t seclen;
@@ -104,8 +107,6 @@ static void kdf_tls1_prf_free(void *vctx)
     TLS1_PRF *ctx = (TLS1_PRF *)vctx;
 
     kdf_tls1_prf_reset(ctx);
-    EVP_MD_meth_free(ctx->sha1);
-    EVP_MD_meth_free(ctx->md);
     OPENSSL_free(ctx);
 }
 
@@ -113,6 +114,8 @@ static void kdf_tls1_prf_reset(void *vctx)
 {
     TLS1_PRF *ctx = (TLS1_PRF *)vctx;
 
+    EVP_MAC_CTX_free(ctx->P_hash);
+    EVP_MAC_CTX_free(ctx->P_sha1);
     OPENSSL_clear_free(ctx->sec, ctx->seclen);
     OPENSSL_cleanse(ctx->seed, ctx->seedlen);
     memset(ctx, 0, sizeof(*ctx));
@@ -123,7 +126,7 @@ static int kdf_tls1_prf_derive(void *vctx, unsigned char *key,
 {
     TLS1_PRF *ctx = (TLS1_PRF *)vctx;
 
-    if (ctx->md == NULL) {
+    if (ctx->P_hash == NULL) {
         ERR_raise(ERR_LIB_PROV, PROV_R_MISSING_MESSAGE_DIGEST);
         return 0;
     }
@@ -135,50 +138,73 @@ static int kdf_tls1_prf_derive(void *vctx, unsigned char *key,
         ERR_raise(ERR_LIB_PROV, PROV_R_MISSING_SEED);
         return 0;
     }
-    return tls1_prf_alg(ctx->md, ctx->sha1, ctx->sec, ctx->seclen,
+
+    return tls1_prf_alg(ctx->P_hash, ctx->P_sha1,
+                        ctx->sec, ctx->seclen,
                         ctx->seed, ctx->seedlen,
                         key, keylen);
 }
 
+static EVP_MAC_CTX *kdf_tls1_prf_mkmacctx(OPENSSL_CTX *libctx,
+                                          const char *mdname,
+                                          const OSSL_PARAM params[])
+{
+    const OSSL_PARAM *p;
+    OSSL_PARAM mac_params[5], *mp = mac_params;
+    const char *properties = NULL;
+    /* TODO(3.0) rethink "flags", also see hmac.c in providers */
+    int mac_flags = EVP_MD_CTX_FLAG_NON_FIPS_ALLOW;
+    EVP_MAC_CTX *macctx = NULL;
+
+    *mp++ = OSSL_PARAM_construct_utf8_string(OSSL_MAC_PARAM_DIGEST,
+                                             (char *)mdname, 0);
+#if !defined(OPENSSL_NO_ENGINE) && !defined(FIPS_MODE)
+    if ((p = OSSL_PARAM_locate_const(params, OSSL_KDF_PARAM_ENGINE)) != NULL)
+        *mp++ = *p;
+#endif
+    if ((p = OSSL_PARAM_locate_const(params,
+                                     OSSL_KDF_PARAM_PROPERTIES)) != NULL) {
+        properties = p->data;
+        *mp++ = *p;
+    }
+    *mp++ = OSSL_PARAM_construct_int(OSSL_MAC_PARAM_FLAGS, &mac_flags);
+    *mp = OSSL_PARAM_construct_end();
+
+    /* Implicit fetch */
+    {
+        EVP_MAC *mac = EVP_MAC_fetch(libctx, OSSL_MAC_NAME_HMAC, properties);
+
+        macctx = EVP_MAC_CTX_new(mac);
+        /* The context holds on to the MAC */
+        EVP_MAC_free(mac);
+        if (macctx == NULL)
+            goto err;
+    }
+
+    if (EVP_MAC_CTX_set_params(macctx, mac_params))
+        goto done;
+ err:
+    EVP_MAC_CTX_free(macctx);
+    macctx = NULL;
+ done:
+    return macctx;
+}
+
 static int kdf_tls1_prf_set_ctx_params(void *vctx, const OSSL_PARAM params[])
 {
     const OSSL_PARAM *p;
     TLS1_PRF *ctx = vctx;
-    EVP_MD *md, *sha = NULL;
-    const char *properties = NULL, *name;
+    OPENSSL_CTX *libctx = PROV_LIBRARY_CONTEXT_OF(ctx->provctx);
 
-    /* Grab search properties, this should be before the digest lookup */
-    if ((p = OSSL_PARAM_locate_const(params, OSSL_KDF_PARAM_PROPERTIES))
-        != NULL) {
-        if (p->data_type != OSSL_PARAM_UTF8_STRING)
-            return 0;
-        properties = p->data;
-    }
-    /* Handle aliasing of digest parameter names */
     if ((p = OSSL_PARAM_locate_const(params, OSSL_KDF_PARAM_DIGEST)) != NULL) {
-        if (p->data_type != OSSL_PARAM_UTF8_STRING)
-            return 0;
-        name = p->data;
-        if (strcasecmp(name, SN_md5_sha1) == 0) {
-            sha = EVP_MD_fetch(PROV_LIBRARY_CONTEXT_OF(ctx->provctx), SN_sha1,
-                               properties);
-            if (sha == NULL) {
-                ERR_raise(ERR_LIB_PROV, PROV_R_UNABLE_TO_LOAD_SHA1);
-                return 0;
-            }
-            name = SN_md5;
+        EVP_MAC_CTX_free(ctx->P_hash);
+        EVP_MAC_CTX_free(ctx->P_sha1);
+        if (strcasecmp(p->data, SN_md5_sha1) == 0) {
+            ctx->P_hash = kdf_tls1_prf_mkmacctx(libctx, SN_md5, params);
+            ctx->P_sha1 = kdf_tls1_prf_mkmacctx(libctx, SN_sha1, params);
+        } else {
+            ctx->P_hash = kdf_tls1_prf_mkmacctx(libctx, p->data, params);
         }
-        md = EVP_MD_fetch(PROV_LIBRARY_CONTEXT_OF(ctx->provctx), name,
-                          properties);
-        if (md == NULL) {
-            ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST);
-            EVP_MD_meth_free(sha);
-            return 0;
-        }
-        EVP_MD_meth_free(ctx->sha1);
-        EVP_MD_meth_free(ctx->md);
-        ctx->md = md;
-        ctx->sha1 = sha;
     }
 
     if ((p = OSSL_PARAM_locate_const(params, OSSL_KDF_PARAM_SECRET)) != NULL) {
@@ -275,34 +301,21 @@ const OSSL_DISPATCH kdf_tls1_prf_functions[] = {
  *     A(0) = seed
  *     A(i) = HMAC_<hash>(secret, A(i-1))
  */
-static int tls1_prf_P_hash(const EVP_MD *md,
+static int tls1_prf_P_hash(EVP_MAC_CTX *ctx_init,
                            const unsigned char *sec, size_t sec_len,
                            const unsigned char *seed, size_t seed_len,
                            unsigned char *out, size_t olen)
 {
     size_t chunk;
-    EVP_MAC *mac = NULL;
-    EVP_MAC_CTX *ctx = NULL, *ctx_Ai = NULL, *ctx_init = NULL;
+    EVP_MAC_CTX *ctx = NULL, *ctx_Ai = NULL;
     unsigned char Ai[EVP_MAX_MD_SIZE];
     size_t Ai_len;
     int ret = 0;
-    OSSL_PARAM params[4];
-    int mac_flags;
-    const char *mdname = EVP_MD_name(md);
-
-    mac = EVP_MAC_fetch(NULL, OSSL_MAC_NAME_HMAC, NULL); /* Implicit fetch */
-    ctx_init = EVP_MAC_CTX_new(mac);
-    if (ctx_init == NULL)
-        goto err;
+    OSSL_PARAM params[2], *p = params;
 
-    /* TODO(3.0) rethink "flags", also see hmac.c in providers */
-    mac_flags = EVP_MD_CTX_FLAG_NON_FIPS_ALLOW;
-    params[0] = OSSL_PARAM_construct_int(OSSL_MAC_PARAM_FLAGS, &mac_flags);
-    params[1] = OSSL_PARAM_construct_utf8_string(OSSL_MAC_PARAM_DIGEST,
-                                                 (char *)mdname, 0);
-    params[2] = OSSL_PARAM_construct_octet_string(OSSL_MAC_PARAM_KEY,
-                                                  (void *)sec, sec_len);
-    params[3] = OSSL_PARAM_construct_end();
+    *p++ = OSSL_PARAM_construct_octet_string(OSSL_MAC_PARAM_KEY,
+                                             (void *)sec, sec_len);
+    *p = OSSL_PARAM_construct_end();
     if (!EVP_MAC_CTX_set_params(ctx_init, params))
         goto err;
     if (!EVP_MAC_init(ctx_init))
@@ -356,8 +369,6 @@ static int tls1_prf_P_hash(const EVP_MD *md,
  err:
     EVP_MAC_CTX_free(ctx);
     EVP_MAC_CTX_free(ctx_Ai);
-    EVP_MAC_CTX_free(ctx_init);
-    EVP_MAC_free(mac);
     OPENSSL_cleanse(Ai, sizeof(Ai));
     return ret;
 }
@@ -382,12 +393,12 @@ static int tls1_prf_P_hash(const EVP_MD *md,
  *
  *   PRF(secret, label, seed) = P_<hash>(secret, label + seed)
  */
-static int tls1_prf_alg(const EVP_MD *md, const EVP_MD *sha1,
+static int tls1_prf_alg(EVP_MAC_CTX *mdctx, EVP_MAC_CTX *sha1ctx,
                         const unsigned char *sec, size_t slen,
                         const unsigned char *seed, size_t seed_len,
                         unsigned char *out, size_t olen)
 {
-    if (sha1 != NULL) {
+    if (sha1ctx != NULL) {
         /* TLS v1.0 and TLS v1.1 */
         size_t i;
         unsigned char *tmp;
@@ -395,7 +406,7 @@ static int tls1_prf_alg(const EVP_MD *md, const EVP_MD *sha1,
         size_t L_S1 = (slen + 1) / 2;
         size_t L_S2 = L_S1;
 
-        if (!tls1_prf_P_hash(md, sec, L_S1,
+        if (!tls1_prf_P_hash(mdctx, sec, L_S1,
                              seed, seed_len, out, olen))
             return 0;
 
@@ -403,7 +414,8 @@ static int tls1_prf_alg(const EVP_MD *md, const EVP_MD *sha1,
             ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
             return 0;
         }
-        if (!tls1_prf_P_hash(sha1, sec + slen - L_S2, L_S2,
+
+        if (!tls1_prf_P_hash(sha1ctx, sec + slen - L_S2, L_S2,
                              seed, seed_len, tmp, olen)) {
             OPENSSL_clear_free(tmp, olen);
             return 0;
@@ -415,7 +427,7 @@ static int tls1_prf_alg(const EVP_MD *md, const EVP_MD *sha1,
     }
 
     /* TLS v1.2 */
-    if (!tls1_prf_P_hash(md, sec, slen, seed, seed_len, out, olen))
+    if (!tls1_prf_P_hash(mdctx, sec, slen, seed, seed_len, out, olen))
         return 0;
 
     return 1;