Change rsa gen so it can use the propq from OSSL_PKEY_PARAM_RSA_DIGEST
[openssl.git] / crypto / rsa / rsa_backend.c
index 871aa17a2222333e5db2f9349544c58c4f0eaafd..fae09d706752469e30af960f11d3feebd64cbea1 100644 (file)
@@ -163,7 +163,7 @@ int rsa_todata(RSA *rsa, OSSL_PARAM_BLD *bld, OSSL_PARAM params[])
     return ret;
 }
 
-int rsa_pss_params_30_todata(const RSA_PSS_PARAMS_30 *pss, const char *propq,
+int rsa_pss_params_30_todata(const RSA_PSS_PARAMS_30 *pss,
                              OSSL_PARAM_BLD *bld, OSSL_PARAM params[])
 {
     if (!rsa_pss_params_30_is_unrestricted(pss)) {
@@ -211,13 +211,16 @@ int rsa_pss_params_30_fromdata(RSA_PSS_PARAMS_30 *pss_params,
                                const OSSL_PARAM params[], OPENSSL_CTX *libctx)
 {
     const OSSL_PARAM *param_md, *param_mgf, *param_mgf1md,  *param_saltlen;
+    const OSSL_PARAM *param_propq;
+    const char *propq = NULL;
     EVP_MD *md = NULL, *mgf1md = NULL;
     int saltlen;
     int ret = 0;
 
     if (pss_params == NULL)
         return 0;
-
+    param_propq =
+        OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_RSA_DIGEST_PROPS);
     param_md =
         OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_RSA_DIGEST);
     param_mgf =
@@ -227,6 +230,10 @@ int rsa_pss_params_30_fromdata(RSA_PSS_PARAMS_30 *pss_params,
     param_saltlen =
         OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_RSA_PSS_SALTLEN);
 
+    if (param_propq != NULL) {
+        if (param_propq->data_type == OSSL_PARAM_UTF8_STRING)
+            propq = param_propq->data;
+    }
     /*
      * If we get any of the parameters, we know we have at least some
      * restrictions, so we start by setting default values, and let each
@@ -265,7 +272,7 @@ int rsa_pss_params_30_fromdata(RSA_PSS_PARAMS_30 *pss_params,
         else if (!OSSL_PARAM_get_utf8_ptr(param_mgf, &mdname))
             goto err;
 
-        if ((md = EVP_MD_fetch(libctx, mdname, NULL)) == NULL
+        if ((md = EVP_MD_fetch(libctx, mdname, propq)) == NULL
             || !rsa_pss_params_30_set_hashalg(pss_params,
                                               rsa_oaeppss_md2nid(md)))
             goto err;
@@ -279,7 +286,7 @@ int rsa_pss_params_30_fromdata(RSA_PSS_PARAMS_30 *pss_params,
         else if (!OSSL_PARAM_get_utf8_ptr(param_mgf, &mgf1mdname))
             goto err;
 
-        if ((mgf1md = EVP_MD_fetch(libctx, mgf1mdname, NULL)) == NULL
+        if ((mgf1md = EVP_MD_fetch(libctx, mgf1mdname, propq)) == NULL
             || !rsa_pss_params_30_set_maskgenhashalg(pss_params,
                                                      rsa_oaeppss_md2nid(mgf1md)))
             goto err;