EVP: For SIGNATURE operations, pass the propquery early
[openssl.git] / providers / implementations / signature / rsa.c
index a59b234a2c7861c11f6601b5dbfce7b611d1f6ef..27e35be3c9aeb063d0430e9d73acabe85bb7617e 100644 (file)
@@ -73,6 +73,7 @@ static OSSL_ITEM padding_item[] = {
 
 typedef struct {
     OPENSSL_CTX *libctx;
 
 typedef struct {
     OPENSSL_CTX *libctx;
+    char *propq;
     RSA *rsa;
     int operation;
 
     RSA *rsa;
     int operation;
 
@@ -180,7 +181,7 @@ static int rsa_check_padding(int mdnid, int padding)
     return 1;
 }
 
     return 1;
 }
 
-static void *rsa_newctx(void *provctx)
+static void *rsa_newctx(void *provctx, const char *propq)
 {
     PROV_RSA_CTX *prsactx = OPENSSL_zalloc(sizeof(PROV_RSA_CTX));
 
 {
     PROV_RSA_CTX *prsactx = OPENSSL_zalloc(sizeof(PROV_RSA_CTX));
 
@@ -189,6 +190,11 @@ static void *rsa_newctx(void *provctx)
 
     prsactx->libctx = PROV_LIBRARY_CONTEXT_OF(provctx);
     prsactx->flag_allow_md = 1;
 
     prsactx->libctx = PROV_LIBRARY_CONTEXT_OF(provctx);
     prsactx->flag_allow_md = 1;
+    if (propq != NULL && (prsactx->propq = OPENSSL_strdup(propq)) == NULL) {
+        OPENSSL_free(prsactx);
+        prsactx = NULL;
+        ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
+    }
     return prsactx;
 }
 
     return prsactx;
 }
 
@@ -219,6 +225,9 @@ static int rsa_signature_init(void *vprsactx, void *vrsa, int operation)
 static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
                         const char *mdprops)
 {
 static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
                         const char *mdprops)
 {
+    if (mdprops == NULL)
+        mdprops = ctx->propq;
+
     if (mdname != NULL) {
         EVP_MD *md = EVP_MD_fetch(ctx->libctx, mdname, mdprops);
         int md_nid = rsa_get_md_nid(md);
     if (mdname != NULL) {
         EVP_MD *md = EVP_MD_fetch(ctx->libctx, mdname, mdprops);
         int md_nid = rsa_get_md_nid(md);
@@ -260,12 +269,15 @@ static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
 }
 
 static int rsa_setup_mgf1_md(PROV_RSA_CTX *ctx, const char *mdname,
 }
 
 static int rsa_setup_mgf1_md(PROV_RSA_CTX *ctx, const char *mdname,
-                             const char *props)
+                             const char *mdprops)
 {
 {
+    if (mdprops == NULL)
+        mdprops = ctx->propq;
+
     if (ctx->mgf1_mdname[0] != '\0')
         EVP_MD_free(ctx->mgf1_md);
 
     if (ctx->mgf1_mdname[0] != '\0')
         EVP_MD_free(ctx->mgf1_md);
 
-    if ((ctx->mgf1_md = EVP_MD_fetch(ctx->libctx, mdname, props)) == NULL)
+    if ((ctx->mgf1_md = EVP_MD_fetch(ctx->libctx, mdname, mdprops)) == NULL)
         return 0;
     OPENSSL_strlcpy(ctx->mgf1_mdname, mdname, sizeof(ctx->mgf1_mdname));
 
         return 0;
     OPENSSL_strlcpy(ctx->mgf1_mdname, mdname, sizeof(ctx->mgf1_mdname));
 
@@ -592,14 +604,13 @@ static int rsa_verify(void *vprsactx, const unsigned char *sig, size_t siglen,
 }
 
 static int rsa_digest_signverify_init(void *vprsactx, const char *mdname,
 }
 
 static int rsa_digest_signverify_init(void *vprsactx, const char *mdname,
-                                      const char *props, void *vrsa,
-                                      int operation)
+                                      void *vrsa, int operation)
 {
     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
 
     prsactx->flag_allow_md = 0;
     if (!rsa_signature_init(vprsactx, vrsa, operation)
 {
     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
 
     prsactx->flag_allow_md = 0;
     if (!rsa_signature_init(vprsactx, vrsa, operation)
-        || !rsa_setup_md(prsactx, mdname, props))
+        || !rsa_setup_md(prsactx, mdname, NULL))
         return 0;
 
     prsactx->mdctx = EVP_MD_CTX_new();
         return 0;
 
     prsactx->mdctx = EVP_MD_CTX_new();
@@ -706,6 +717,7 @@ static void rsa_freectx(void *vprsactx)
     EVP_MD_CTX_free(prsactx->mdctx);
     EVP_MD_free(prsactx->md);
     EVP_MD_free(prsactx->mgf1_md);
     EVP_MD_CTX_free(prsactx->mdctx);
     EVP_MD_free(prsactx->md);
     EVP_MD_free(prsactx->mgf1_md);
+    OPENSSL_free(prsactx->propq);
     free_tbuf(prsactx);
 
     OPENSSL_clear_free(prsactx, sizeof(prsactx));
     free_tbuf(prsactx);
 
     OPENSSL_clear_free(prsactx, sizeof(prsactx));
@@ -869,8 +881,11 @@ static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
 
         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
             return 0;
 
         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
             return 0;
-        if (propsp != NULL
-            && !OSSL_PARAM_get_utf8_string(propsp, &pmdprops, sizeof(mdprops)))
+
+        if (propsp == NULL)
+            pmdprops = NULL;
+        else if (!OSSL_PARAM_get_utf8_string(propsp,
+                                             &pmdprops, sizeof(mdprops)))
             return 0;
 
         /* TODO(3.0) PSS check needs more work */
             return 0;
 
         /* TODO(3.0) PSS check needs more work */
@@ -883,7 +898,7 @@ static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
         }
 
         /* non-PSS code follows */
         }
 
         /* non-PSS code follows */
-        if (!rsa_setup_md(prsactx, mdname, mdprops))
+        if (!rsa_setup_md(prsactx, mdname, pmdprops))
             return 0;
     }
 
             return 0;
     }
 
@@ -1054,8 +1069,11 @@ static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
 
         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
             return 0;
 
         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
             return 0;
-        if (propsp != NULL
-            && !OSSL_PARAM_get_utf8_string(propsp, &pmdprops, sizeof(mdprops)))
+
+        if (propsp == NULL)
+            pmdprops = NULL;
+        else if (!OSSL_PARAM_get_utf8_string(propsp,
+                                             &pmdprops, sizeof(mdprops)))
             return 0;
 
         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
             return 0;
 
         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
@@ -1073,7 +1091,7 @@ static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
         }
 
         /* non-PSS code follows */
         }
 
         /* non-PSS code follows */
-        if (!rsa_setup_mgf1_md(prsactx, mdname, mdprops))
+        if (!rsa_setup_mgf1_md(prsactx, mdname, pmdprops))
             return 0;
     }
 
             return 0;
     }