Fix typos found by codespell
[openssl.git] / crypto / hpke / hpke.c
index 78341d358f1867b5138e48f55a65e106ab8bc474..8178ff249a19d26385ab0a11bf84044d01b6a169 100644 (file)
@@ -52,6 +52,11 @@ struct ossl_hpke_ctx_st
     char *propq; /* properties */
     int mode; /* HPKE mode */
     OSSL_HPKE_SUITE suite; /* suite */
+    const OSSL_HPKE_KEM_INFO *kem_info;
+    const OSSL_HPKE_KDF_INFO *kdf_info;
+    const OSSL_HPKE_AEAD_INFO *aead_info;
+    EVP_CIPHER *aead_ciph;
+    int role; /* sender(0) or receiver(1) */
     uint64_t seq; /* aead sequence number */
     unsigned char *shared_secret; /* KEM output, zz */
     size_t shared_secretlen;
@@ -125,13 +130,8 @@ static EVP_PKEY *evp_pkey_new_raw_nist_public_key(OSSL_LIB_CTX *libctx,
 
 /**
  * @brief do the AEAD decryption
- * @param libctx is the context to use
- * @param propq is a properties string
- * @param suite is the ciphersuite
- * @param key is the secret
- * @param keylen is the length of the secret
+ * @param hctx is the context to use
  * @param iv is the initialisation vector
- * @param ivlen is the length of the iv
  * @param aad is the additional authenticated data
  * @param aadlen is the length of the aad
  * @param ct is the ciphertext buffer
@@ -140,10 +140,7 @@ static EVP_PKEY *evp_pkey_new_raw_nist_public_key(OSSL_LIB_CTX *libctx,
  * @param ptlen input/output, better be big enough on input, exact on output
  * @return 1 on success, 0 otherwise
  */
-static int hpke_aead_dec(OSSL_LIB_CTX *libctx, const char *propq,
-                         OSSL_HPKE_SUITE suite,
-                         const unsigned char *key, size_t keylen,
-                         const unsigned char *iv, size_t ivlen,
+static int hpke_aead_dec(OSSL_HPKE_CTX *hctx, const unsigned char *iv,
                          const unsigned char *aad, size_t aadlen,
                          const unsigned char *ct, size_t ctlen,
                          unsigned char *pt, size_t *ptlen)
@@ -152,46 +149,28 @@ static int hpke_aead_dec(OSSL_LIB_CTX *libctx, const char *propq,
     EVP_CIPHER_CTX *ctx = NULL;
     int len = 0;
     size_t taglen;
-    EVP_CIPHER *enc = NULL;
-    const OSSL_HPKE_AEAD_INFO *aead_info = NULL;
 
-    if (pt == NULL || ptlen == NULL) {
-        ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-    aead_info = ossl_HPKE_AEAD_INFO_find_id(suite.aead_id);
-    if (aead_info == NULL) {
-        ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-    taglen = aead_info->taglen;
+    taglen = hctx->aead_info->taglen;
     if (ctlen <= taglen || *ptlen < ctlen - taglen) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
-        goto err;
+        return 0;
     }
     /* Create and initialise the context */
-    if ((ctx = EVP_CIPHER_CTX_new()) == NULL) {
-        ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-    /* Initialise the encryption operation */
-    enc = EVP_CIPHER_fetch(libctx, aead_info->name, propq);
-    if (enc == NULL) {
-        ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-    if (EVP_DecryptInit_ex(ctx, enc, NULL, NULL, NULL) != 1) {
+    if ((ctx = EVP_CIPHER_CTX_new()) == NULL)
+        return 0;
+
+    /* Initialise the decryption operation. */
+    if (EVP_DecryptInit_ex(ctx, hctx->aead_ciph, NULL, NULL, NULL) != 1) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
         goto err;
     }
-    EVP_CIPHER_free(enc);
-    enc = NULL;
-    if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_IVLEN, ivlen, NULL) != 1) {
+    if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_IVLEN,
+                            hctx->noncelen, NULL) != 1) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
         goto err;
     }
     /* Initialise key and IV */
-    if (EVP_DecryptInit_ex(ctx, NULL, NULL, key, iv) != 1) {
+    if (EVP_DecryptInit_ex(ctx, NULL, NULL, hctx->key, iv) != 1) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
         goto err;
     }
@@ -223,19 +202,13 @@ err:
     if (erv != 1)
         OPENSSL_cleanse(pt, *ptlen);
     EVP_CIPHER_CTX_free(ctx);
-    EVP_CIPHER_free(enc);
     return erv;
 }
 
 /**
  * @brief do AEAD encryption as per the RFC
- * @param libctx is the context to use
- * @param propq is a properties string
- * @param suite is the ciphersuite
- * @param key is the secret
- * @param keylen is the length of the secret
+ * @param hctx is the context to use
  * @param iv is the initialisation vector
- * @param ivlen is the length of the iv
  * @param aad is the additional authenticated data
  * @param aadlen is the length of the aad
  * @param pt is the plaintext buffer
@@ -244,10 +217,7 @@ err:
  * @param ctlen input/output, needs space for tag on input, exact on output
  * @return 1 for success, 0 otherwise
  */
-static int hpke_aead_enc(OSSL_LIB_CTX *libctx, const char *propq,
-                         OSSL_HPKE_SUITE suite,
-                         const unsigned char *key, size_t keylen,
-                         const unsigned char *iv, size_t ivlen,
+static int hpke_aead_enc(OSSL_HPKE_CTX *hctx, const unsigned char *iv,
                          const unsigned char *aad, size_t aadlen,
                          const unsigned char *pt, size_t ptlen,
                          unsigned char *ct, size_t *ctlen)
@@ -256,47 +226,29 @@ static int hpke_aead_enc(OSSL_LIB_CTX *libctx, const char *propq,
     EVP_CIPHER_CTX *ctx = NULL;
     int len;
     size_t taglen = 0;
-    const OSSL_HPKE_AEAD_INFO *aead_info = NULL;
-    EVP_CIPHER *enc = NULL;
     unsigned char tag[16];
 
-    if (ct == NULL || ctlen == NULL) {
-        ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-    aead_info = ossl_HPKE_AEAD_INFO_find_id(suite.aead_id);
-    if (aead_info == NULL) {
-        ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-    taglen = aead_info->taglen;
+    taglen = hctx->aead_info->taglen;
     if (*ctlen <= taglen || ptlen > *ctlen - taglen) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
-        goto err;
+        return 0;
     }
     /* Create and initialise the context */
-    if ((ctx = EVP_CIPHER_CTX_new()) == NULL) {
-        ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
+    if ((ctx = EVP_CIPHER_CTX_new()) == NULL)
+        return 0;
+
     /* Initialise the encryption operation. */
-    enc = EVP_CIPHER_fetch(libctx, aead_info->name, propq);
-    if (enc == NULL) {
+    if (EVP_EncryptInit_ex(ctx, hctx->aead_ciph, NULL, NULL, NULL) != 1) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
         goto err;
     }
-    if (EVP_EncryptInit_ex(ctx, enc, NULL, NULL, NULL) != 1) {
-        ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-    EVP_CIPHER_free(enc);
-    enc = NULL;
-    if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_IVLEN, ivlen, NULL) != 1) {
+    if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_IVLEN,
+                            hctx->noncelen, NULL) != 1) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
         goto err;
     }
     /* Initialise key and IV */
-    if (EVP_EncryptInit_ex(ctx, NULL, NULL, key, iv) != 1) {
+    if (EVP_EncryptInit_ex(ctx, NULL, NULL, hctx->key, iv) != 1) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
         goto err;
     }
@@ -331,7 +283,6 @@ err:
     if (erv != 1)
         OPENSSL_cleanse(ct, *ctlen);
     EVP_CIPHER_CTX_free(ctx);
-    EVP_CIPHER_free(enc);
     return erv;
 }
 
@@ -359,15 +310,30 @@ static int hpke_mode_check(unsigned int mode)
  * @param suite is the suite to check
  * @return 1 for good, 0 otherwise
  */
-static int hpke_suite_check(OSSL_HPKE_SUITE suite)
+static int hpke_suite_check(OSSL_HPKE_SUITE suite,
+                            const OSSL_HPKE_KEM_INFO **kem_info,
+                            const OSSL_HPKE_KDF_INFO **kdf_info,
+                            const OSSL_HPKE_AEAD_INFO **aead_info)
 {
+    const OSSL_HPKE_KEM_INFO *kem_info_;
+    const OSSL_HPKE_KDF_INFO *kdf_info_;
+    const OSSL_HPKE_AEAD_INFO *aead_info_;
+
     /* check KEM, KDF and AEAD are supported here */
-    if (ossl_HPKE_KEM_INFO_find_id(suite.kem_id) == NULL)
+    if ((kem_info_ = ossl_HPKE_KEM_INFO_find_id(suite.kem_id)) == NULL)
         return 0;
-    if (ossl_HPKE_KDF_INFO_find_id(suite.kdf_id) == NULL)
+    if ((kdf_info_ = ossl_HPKE_KDF_INFO_find_id(suite.kdf_id)) == NULL)
         return 0;
-    if (ossl_HPKE_AEAD_INFO_find_id(suite.aead_id) == NULL)
+    if ((aead_info_ = ossl_HPKE_AEAD_INFO_find_id(suite.aead_id)) == NULL)
         return 0;
+
+    if (kem_info != NULL)
+        *kem_info = kem_info_;
+    if (kdf_info != NULL)
+        *kdf_info = kdf_info_;
+    if (aead_info != NULL)
+        *aead_info = aead_info_;
+
     return 1;
 }
 
@@ -433,21 +399,11 @@ static int hpke_expansion(OSSL_HPKE_SUITE suite,
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
         return 0;
     }
-    if (hpke_suite_check(suite) != 1) {
+    if (hpke_suite_check(suite, &kem_info, NULL, &aead_info) != 1) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
         return 0;
     }
-    aead_info = ossl_HPKE_AEAD_INFO_find_id(suite.aead_id);
-    if (aead_info == NULL) {
-        ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
-        return 0;
-    }
     *cipherlen = clearlen + aead_info->taglen;
-    kem_info = ossl_HPKE_KEM_INFO_find_id(suite.kem_id);
-    if (kem_info == NULL) {
-        ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
-        return 0;
-    }
     *enclen = kem_info->Nenc;
     return 1;
 }
@@ -482,7 +438,7 @@ static size_t hpke_seqnonce2buf(OSSL_HPKE_CTX *ctx,
  * @brief call the underlying KEM to encap
  * @param ctx is the OSSL_HPKE_CTX
  * @param enc is a buffer for the sender's ephemeral public value
- * @param enclen is the size of enc on input, number of octets used on ouptut
+ * @param enclen is the size of enc on input, number of octets used on output
  * @param pub is the recipient's public value
  * @param publen is the length of pub
  * @return 1 for success, 0 for error
@@ -846,39 +802,61 @@ err:
  * in doc/man3/OSSL_HPKE_CTX_new.pod to avoid duplication
  */
 
-OSSL_HPKE_CTX *OSSL_HPKE_CTX_new(int mode, OSSL_HPKE_SUITE suite,
+OSSL_HPKE_CTX *OSSL_HPKE_CTX_new(int mode, OSSL_HPKE_SUITE suite, int role,
                                  OSSL_LIB_CTX *libctx, const char *propq)
 {
     OSSL_HPKE_CTX *ctx = NULL;
+    const OSSL_HPKE_KEM_INFO *kem_info;
+    const OSSL_HPKE_KDF_INFO *kdf_info;
+    const OSSL_HPKE_AEAD_INFO *aead_info;
 
     if (hpke_mode_check(mode) != 1) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
         return NULL;
     }
-    if (hpke_suite_check(suite) != 1) {
+    if (hpke_suite_check(suite, &kem_info, &kdf_info, &aead_info) != 1) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
         return NULL;
     }
+    if (role != OSSL_HPKE_ROLE_SENDER && role != OSSL_HPKE_ROLE_RECEIVER) {
+        ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
+        return 0;
+    }
     ctx = OPENSSL_zalloc(sizeof(*ctx));
     if (ctx == NULL)
         return NULL;
     ctx->libctx = libctx;
     if (propq != NULL) {
         ctx->propq = OPENSSL_strdup(propq);
-        if (ctx->propq == NULL) {
-            OPENSSL_free(ctx);
-            return NULL;
+        if (ctx->propq == NULL)
+            goto err;
+    }
+    if (suite.aead_id != OSSL_HPKE_AEAD_ID_EXPORTONLY) {
+        ctx->aead_ciph = EVP_CIPHER_fetch(libctx, aead_info->name, propq);
+        if (ctx->aead_ciph == NULL) {
+            ERR_raise(ERR_LIB_CRYPTO, ERR_R_FETCH_FAILED);
+            goto err;
         }
     }
+    ctx->role = role;
     ctx->mode = mode;
     ctx->suite = suite;
+    ctx->kem_info = kem_info;
+    ctx->kdf_info = kdf_info;
+    ctx->aead_info = aead_info;
     return ctx;
+
+ err:
+    EVP_CIPHER_free(ctx->aead_ciph);
+    OPENSSL_free(ctx);
+    return NULL;
 }
 
 void OSSL_HPKE_CTX_free(OSSL_HPKE_CTX *ctx)
 {
     if (ctx == NULL)
         return;
+    EVP_CIPHER_free(ctx->aead_ciph);
     OPENSSL_free(ctx->propq);
     OPENSSL_clear_free(ctx->exportersec, ctx->exporterseclen);
     OPENSSL_free(ctx->pskid);
@@ -943,6 +921,10 @@ int OSSL_HPKE_CTX_set1_ikme(OSSL_HPKE_CTX *ctx,
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
         return 0;
     }
+    if (ctx->role != OSSL_HPKE_ROLE_SENDER) {
+        ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
+        return 0;
+    }
     OPENSSL_clear_free(ctx->ikme, ctx->ikmelen);
     ctx->ikme = OPENSSL_memdup(ikme, ikmelen);
     if (ctx->ikme == NULL)
@@ -962,6 +944,10 @@ int OSSL_HPKE_CTX_set1_authpriv(OSSL_HPKE_CTX *ctx, EVP_PKEY *priv)
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
         return 0;
     }
+    if (ctx->role != OSSL_HPKE_ROLE_SENDER) {
+        ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
+        return 0;
+    }
     EVP_PKEY_free(ctx->authpriv);
     ctx->authpriv = EVP_PKEY_dup(priv);
     if (ctx->authpriv == NULL)
@@ -987,6 +973,10 @@ int OSSL_HPKE_CTX_set1_authpub(OSSL_HPKE_CTX *ctx,
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
         return 0;
     }
+    if (ctx->role != OSSL_HPKE_ROLE_RECEIVER) {
+        ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
+        return 0;
+    }
     /* check the value seems like a good public key for this kem */
     kem_info = ossl_HPKE_KEM_INFO_find_id(ctx->suite.kem_id);
     if (kem_info == NULL)
@@ -1048,6 +1038,15 @@ int OSSL_HPKE_CTX_set_seq(OSSL_HPKE_CTX *ctx, uint64_t seq)
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_NULL_PARAMETER);
         return 0;
     }
+    /*
+     * We disallow senders from doing this as it's dangerous
+     * Receivers are ok to use this, as no harm should ensue
+     * if they go wrong.
+     */
+    if (ctx->role == OSSL_HPKE_ROLE_SENDER) {
+        ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
+        return 0;
+    }
     ctx->seq = seq;
     return 1;
 }
@@ -1064,6 +1063,10 @@ int OSSL_HPKE_encap(OSSL_HPKE_CTX *ctx,
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_NULL_PARAMETER);
         return 0;
     }
+    if (ctx->role != OSSL_HPKE_ROLE_SENDER) {
+        ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
+        return 0;
+    }
     if (infolen > OSSL_HPKE_MAX_INFOLEN) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
         return 0;
@@ -1097,6 +1100,10 @@ int OSSL_HPKE_decap(OSSL_HPKE_CTX *ctx,
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_NULL_PARAMETER);
         return 0;
     }
+    if (ctx->role != OSSL_HPKE_ROLE_RECEIVER) {
+        ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
+        return 0;
+    }
     if (infolen > OSSL_HPKE_MAX_INFOLEN) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
         return 0;
@@ -1133,6 +1140,10 @@ int OSSL_HPKE_seal(OSSL_HPKE_CTX *ctx,
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_NULL_PARAMETER);
         return 0;
     }
+    if (ctx->role != OSSL_HPKE_ROLE_SENDER) {
+        ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
+        return 0;
+    }
     if ((ctx->seq + 1) == 0) { /* wrap around imminent !!! */
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
         return 0;
@@ -1147,9 +1158,7 @@ int OSSL_HPKE_seal(OSSL_HPKE_CTX *ctx,
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
         return 0;
     }
-    if (hpke_aead_enc(ctx->libctx, ctx->propq, ctx->suite,
-                      ctx->key, ctx->keylen, seqbuf, ctx->noncelen,
-                      aad, aadlen, pt, ptlen, ct, ctlen) != 1) {
+    if (hpke_aead_enc(ctx, seqbuf, aad, aadlen, pt, ptlen, ct, ctlen) != 1) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
         OPENSSL_cleanse(seqbuf, sizeof(seqbuf));
         return 0;
@@ -1173,6 +1182,10 @@ int OSSL_HPKE_open(OSSL_HPKE_CTX *ctx,
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_NULL_PARAMETER);
         return 0;
     }
+    if (ctx->role != OSSL_HPKE_ROLE_RECEIVER) {
+        ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
+        return 0;
+    }
     if ((ctx->seq + 1) == 0) { /* wrap around imminent !!! */
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
         return 0;
@@ -1187,9 +1200,7 @@ int OSSL_HPKE_open(OSSL_HPKE_CTX *ctx,
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
         return 0;
     }
-    if (hpke_aead_dec(ctx->libctx, ctx->propq, ctx->suite,
-                      ctx->key, ctx->keylen, seqbuf, ctx->noncelen,
-                      aad, aadlen, ct, ctlen, pt, ptlen) != 1) {
+    if (hpke_aead_dec(ctx, seqbuf, aad, aadlen, ct, ctlen, pt, ptlen) != 1) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
         OPENSSL_cleanse(seqbuf, sizeof(seqbuf));
         return 0;
@@ -1266,7 +1277,7 @@ int OSSL_HPKE_keygen(OSSL_HPKE_SUITE suite,
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_NULL_PARAMETER);
         return 0;
     }
-    if (hpke_suite_check(suite) != 1) {
+    if (hpke_suite_check(suite, &kem_info, NULL, NULL) != 1) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT);
         return 0;
     }
@@ -1277,9 +1288,6 @@ int OSSL_HPKE_keygen(OSSL_HPKE_SUITE suite,
         return 0;
     }
 
-    kem_info = ossl_HPKE_KEM_INFO_find_id(suite.kem_id);
-    if (kem_info == NULL)
-        return 0;
     if (hpke_kem_id_nist_curve(suite.kem_id) == 1) {
         *p++ = OSSL_PARAM_construct_utf8_string(OSSL_PKEY_PARAM_GROUP_NAME,
                                                 (char *)kem_info->groupname, 0);
@@ -1323,16 +1331,14 @@ err:
 
 int OSSL_HPKE_suite_check(OSSL_HPKE_SUITE suite)
 {
-    return hpke_suite_check(suite);
+    return hpke_suite_check(suite, NULL, NULL, NULL);
 }
 
-int OSSL_HPKE_get_grease_value(OSSL_LIB_CTX *libctx, const char *propq,
-                               const OSSL_HPKE_SUITE *suite_in,
+int OSSL_HPKE_get_grease_value(const OSSL_HPKE_SUITE *suite_in,
                                OSSL_HPKE_SUITE *suite,
-                               unsigned char *enc,
-                               size_t *enclen,
-                               unsigned char *ct,
-                               size_t ctlen)
+                               unsigned char *enc, size_t *enclen,
+                               unsigned char *ct, size_t ctlen,
+                               OSSL_LIB_CTX *libctx, const char *propq)
 {
     OSSL_HPKE_SUITE chosen;
     size_t plen = 0;
@@ -1354,17 +1360,7 @@ int OSSL_HPKE_get_grease_value(OSSL_LIB_CTX *libctx, const char *propq,
     } else {
         chosen = *suite_in;
     }
-    kem_info = ossl_HPKE_KEM_INFO_find_id(chosen.kem_id);
-    if (kem_info == NULL) {
-        ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-    aead_info = ossl_HPKE_AEAD_INFO_find_id(chosen.aead_id);
-    if (aead_info == NULL) {
-        ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-    if (hpke_suite_check(chosen) != 1) {
+    if (hpke_suite_check(chosen, &kem_info, NULL, &aead_info) != 1) {
         ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
         goto err;
     }
@@ -1432,8 +1428,10 @@ size_t OSSL_HPKE_get_recommended_ikmelen(OSSL_HPKE_SUITE suite)
 {
     const OSSL_HPKE_KEM_INFO *kem_info = NULL;
 
-    if (hpke_suite_check(suite) != 1)
+    if (hpke_suite_check(suite, &kem_info, NULL, NULL) != 1)
         return 0;
-    kem_info = ossl_HPKE_KEM_INFO_find_id(suite.kem_id);
+    if (kem_info == NULL)
+        return 0;
+
     return kem_info->Nsk;
 }