Add option to fipsinstall to disable fips security checks at run time.
[openssl.git] / providers / implementations / signature / rsa.c
1 /*
2  * Copyright 2019-2020 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 /*
11  * RSA low level APIs are deprecated for public use, but still ok for
12  * internal use.
13  */
14 #include "internal/deprecated.h"
15
16 #include <string.h>
17 #include <openssl/crypto.h>
18 #include <openssl/core_dispatch.h>
19 #include <openssl/core_names.h>
20 #include <openssl/err.h>
21 #include <openssl/rsa.h>
22 #include <openssl/params.h>
23 #include <openssl/evp.h>
24 #include "internal/cryptlib.h"
25 #include "internal/nelem.h"
26 #include "internal/sizes.h"
27 #include "crypto/rsa.h"
28 #include "prov/providercommon.h"
29 #include "prov/providercommonerr.h"
30 #include "prov/implementations.h"
31 #include "prov/provider_ctx.h"
32 #include "prov/der_rsa.h"
33 #include "prov/securitycheck.h"
34
35 #define RSA_DEFAULT_DIGEST_NAME OSSL_DIGEST_NAME_SHA1
36
37 static OSSL_FUNC_signature_newctx_fn rsa_newctx;
38 static OSSL_FUNC_signature_sign_init_fn rsa_sign_init;
39 static OSSL_FUNC_signature_verify_init_fn rsa_verify_init;
40 static OSSL_FUNC_signature_verify_recover_init_fn rsa_verify_recover_init;
41 static OSSL_FUNC_signature_sign_fn rsa_sign;
42 static OSSL_FUNC_signature_verify_fn rsa_verify;
43 static OSSL_FUNC_signature_verify_recover_fn rsa_verify_recover;
44 static OSSL_FUNC_signature_digest_sign_init_fn rsa_digest_sign_init;
45 static OSSL_FUNC_signature_digest_sign_update_fn rsa_digest_signverify_update;
46 static OSSL_FUNC_signature_digest_sign_final_fn rsa_digest_sign_final;
47 static OSSL_FUNC_signature_digest_verify_init_fn rsa_digest_verify_init;
48 static OSSL_FUNC_signature_digest_verify_update_fn rsa_digest_signverify_update;
49 static OSSL_FUNC_signature_digest_verify_final_fn rsa_digest_verify_final;
50 static OSSL_FUNC_signature_freectx_fn rsa_freectx;
51 static OSSL_FUNC_signature_dupctx_fn rsa_dupctx;
52 static OSSL_FUNC_signature_get_ctx_params_fn rsa_get_ctx_params;
53 static OSSL_FUNC_signature_gettable_ctx_params_fn rsa_gettable_ctx_params;
54 static OSSL_FUNC_signature_set_ctx_params_fn rsa_set_ctx_params;
55 static OSSL_FUNC_signature_settable_ctx_params_fn rsa_settable_ctx_params;
56 static OSSL_FUNC_signature_get_ctx_md_params_fn rsa_get_ctx_md_params;
57 static OSSL_FUNC_signature_gettable_ctx_md_params_fn rsa_gettable_ctx_md_params;
58 static OSSL_FUNC_signature_set_ctx_md_params_fn rsa_set_ctx_md_params;
59 static OSSL_FUNC_signature_settable_ctx_md_params_fn rsa_settable_ctx_md_params;
60
61 static OSSL_ITEM padding_item[] = {
62     { RSA_PKCS1_PADDING,        OSSL_PKEY_RSA_PAD_MODE_PKCSV15 },
63     { RSA_SSLV23_PADDING,       OSSL_PKEY_RSA_PAD_MODE_SSLV23 },
64     { RSA_NO_PADDING,           OSSL_PKEY_RSA_PAD_MODE_NONE },
65     { RSA_X931_PADDING,         OSSL_PKEY_RSA_PAD_MODE_X931 },
66     { RSA_PKCS1_PSS_PADDING,    OSSL_PKEY_RSA_PAD_MODE_PSS },
67     { 0,                        NULL     }
68 };
69
70 /*
71  * What's passed as an actual key is defined by the KEYMGMT interface.
72  * We happen to know that our KEYMGMT simply passes RSA structures, so
73  * we use that here too.
74  */
75
76 typedef struct {
77     OPENSSL_CTX *libctx;
78     char *propq;
79     RSA *rsa;
80     int operation;
81
82     /*
83      * Flag to determine if the hash function can be changed (1) or not (0)
84      * Because it's dangerous to change during a DigestSign or DigestVerify
85      * operation, this flag is cleared by their Init function, and set again
86      * by their Final function.
87      */
88     unsigned int flag_allow_md : 1;
89
90     /* The Algorithm Identifier of the combined signature algorithm */
91     unsigned char aid_buf[128];
92     unsigned char *aid;
93     size_t  aid_len;
94
95     /* main digest */
96     EVP_MD *md;
97     EVP_MD_CTX *mdctx;
98     int mdnid;
99     char mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
100
101     /* RSA padding mode */
102     int pad_mode;
103     /* message digest for MGF1 */
104     EVP_MD *mgf1_md;
105     char mgf1_mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
106     /* PSS salt length */
107     int saltlen;
108     /* Minimum salt length or -1 if no PSS parameter restriction */
109     int min_saltlen;
110
111     /* Temp buffer */
112     unsigned char *tbuf;
113
114 } PROV_RSA_CTX;
115
116 static size_t rsa_get_md_size(const PROV_RSA_CTX *prsactx)
117 {
118     if (prsactx->md != NULL)
119         return EVP_MD_size(prsactx->md);
120     return 0;
121 }
122
123 static int rsa_check_padding(int mdnid, int padding)
124 {
125     if (padding == RSA_NO_PADDING) {
126         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE);
127         return 0;
128     }
129
130     if (padding == RSA_X931_PADDING) {
131         if (RSA_X931_hash_id(mdnid) == -1) {
132             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_X931_DIGEST);
133             return 0;
134         }
135     }
136
137     return 1;
138 }
139
140 static int rsa_check_parameters(PROV_RSA_CTX *prsactx)
141 {
142     if (prsactx->pad_mode == RSA_PKCS1_PSS_PADDING) {
143         int max_saltlen;
144
145         /* See if minimum salt length exceeds maximum possible */
146         max_saltlen = RSA_size(prsactx->rsa) - EVP_MD_size(prsactx->md);
147         if ((RSA_bits(prsactx->rsa) & 0x7) == 1)
148             max_saltlen--;
149         if (prsactx->min_saltlen < 0 || prsactx->min_saltlen > max_saltlen) {
150             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_SALT_LENGTH);
151             return 0;
152         }
153     }
154     return 1;
155 }
156
157 static void *rsa_newctx(void *provctx, const char *propq)
158 {
159     PROV_RSA_CTX *prsactx = NULL;
160     char *propq_copy = NULL;
161
162     if (!ossl_prov_is_running())
163         return NULL;
164
165     if ((prsactx = OPENSSL_zalloc(sizeof(PROV_RSA_CTX))) == NULL
166         || (propq != NULL
167             && (propq_copy = OPENSSL_strdup(propq)) == NULL)) {
168         OPENSSL_free(prsactx);
169         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
170         return NULL;
171     }
172
173     prsactx->libctx = PROV_LIBRARY_CONTEXT_OF(provctx);
174     prsactx->flag_allow_md = 1;
175     prsactx->propq = propq_copy;
176     return prsactx;
177 }
178
179 /* True if PSS parameters are restricted */
180 #define rsa_pss_restricted(prsactx) (prsactx->min_saltlen != -1)
181
182 static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
183                         const char *mdprops)
184 {
185     if (mdprops == NULL)
186         mdprops = ctx->propq;
187
188     if (mdname != NULL) {
189         WPACKET pkt;
190         EVP_MD *md = EVP_MD_fetch(ctx->libctx, mdname, mdprops);
191         int sha1_allowed = (ctx->operation != EVP_PKEY_OP_SIGN);
192         int md_nid = digest_rsa_sign_get_md_nid(md, sha1_allowed);
193         size_t mdname_len = strlen(mdname);
194
195         if (md == NULL
196             || md_nid == NID_undef
197             || !rsa_check_padding(md_nid, ctx->pad_mode)
198             || mdname_len >= sizeof(ctx->mdname)) {
199             if (md == NULL)
200                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
201                                "%s could not be fetched", mdname);
202             if (md_nid == NID_undef)
203                 ERR_raise_data(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED,
204                                "digest=%s", mdname);
205             if (mdname_len >= sizeof(ctx->mdname))
206                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
207                                "%s exceeds name buffer length", mdname);
208             EVP_MD_free(md);
209             return 0;
210         }
211
212         EVP_MD_CTX_free(ctx->mdctx);
213         EVP_MD_free(ctx->md);
214
215         /*
216          * TODO(3.0) Should we care about DER writing errors?
217          * All it really means is that for some reason, there's no
218          * AlgorithmIdentifier to be had (consider RSA with MD5-SHA1),
219          * but the operation itself is still valid, just as long as it's
220          * not used to construct anything that needs an AlgorithmIdentifier.
221          */
222         ctx->aid_len = 0;
223         if (WPACKET_init_der(&pkt, ctx->aid_buf, sizeof(ctx->aid_buf))
224             && DER_w_algorithmIdentifier_MDWithRSAEncryption(&pkt, -1, ctx->rsa,
225                                                              md_nid)
226             && WPACKET_finish(&pkt)) {
227             WPACKET_get_total_written(&pkt, &ctx->aid_len);
228             ctx->aid = WPACKET_get_curr(&pkt);
229         }
230         WPACKET_cleanup(&pkt);
231
232         ctx->mdctx = NULL;
233         ctx->md = md;
234         ctx->mdnid = md_nid;
235         OPENSSL_strlcpy(ctx->mdname, mdname, sizeof(ctx->mdname));
236     }
237
238     return 1;
239 }
240
241 static int rsa_setup_mgf1_md(PROV_RSA_CTX *ctx, const char *mdname,
242                              const char *mdprops)
243 {
244     size_t len;
245     EVP_MD *md = NULL;
246
247     if (mdprops == NULL)
248         mdprops = ctx->propq;
249
250     if (ctx->mgf1_mdname[0] != '\0')
251         EVP_MD_free(ctx->mgf1_md);
252
253     if ((md = EVP_MD_fetch(ctx->libctx, mdname, mdprops)) == NULL) {
254         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
255                        "%s could not be fetched", mdname);
256         return 0;
257     }
258     /* The default for mgf1 is SHA1 - so allow SHA1 */
259     if (digest_rsa_sign_get_md_nid(md, 1) == NID_undef) {
260         ERR_raise_data(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED,
261                        "digest=%s", mdname);
262         EVP_MD_free(md);
263         return 0;
264     }
265     ctx->mgf1_md = md;
266     len = OPENSSL_strlcpy(ctx->mgf1_mdname, mdname, sizeof(ctx->mgf1_mdname));
267     if (len >= sizeof(ctx->mgf1_mdname)) {
268         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
269                        "%s exceeds name buffer length", mdname);
270         return 0;
271     }
272
273     return 1;
274 }
275
276 static int rsa_signverify_init(void *vprsactx, void *vrsa, int operation)
277 {
278     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
279
280     if (!ossl_prov_is_running())
281         return 0;
282
283     if (prsactx == NULL || vrsa == NULL || !RSA_up_ref(vrsa))
284         return 0;
285
286     RSA_free(prsactx->rsa);
287     prsactx->rsa = vrsa;
288     prsactx->operation = operation;
289
290     if (!rsa_check_key(vrsa, operation == EVP_PKEY_OP_SIGN)) {
291         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_KEY_LENGTH);
292         return 0;
293     }
294
295     /* Maximum for sign, auto for verify */
296     prsactx->saltlen = RSA_PSS_SALTLEN_AUTO;
297     prsactx->min_saltlen = -1;
298
299     switch (RSA_test_flags(prsactx->rsa, RSA_FLAG_TYPE_MASK)) {
300     case RSA_FLAG_TYPE_RSA:
301         prsactx->pad_mode = RSA_PKCS1_PADDING;
302         break;
303     case RSA_FLAG_TYPE_RSASSAPSS:
304         prsactx->pad_mode = RSA_PKCS1_PSS_PADDING;
305
306         {
307             const RSA_PSS_PARAMS_30 *pss =
308                 rsa_get0_pss_params_30(prsactx->rsa);
309
310             if (!rsa_pss_params_30_is_unrestricted(pss)) {
311                 int md_nid = rsa_pss_params_30_hashalg(pss);
312                 int mgf1md_nid = rsa_pss_params_30_maskgenhashalg(pss);
313                 int min_saltlen = rsa_pss_params_30_saltlen(pss);
314                 const char *mdname, *mgf1mdname;
315                 size_t len;
316
317                 mdname = rsa_oaeppss_nid2name(md_nid);
318                 mgf1mdname = rsa_oaeppss_nid2name(mgf1md_nid);
319                 prsactx->min_saltlen = min_saltlen;
320
321                 if (mdname == NULL) {
322                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
323                                    "PSS restrictions lack hash algorithm");
324                     return 0;
325                 }
326                 if (mgf1mdname == NULL) {
327                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
328                                    "PSS restrictions lack MGF1 hash algorithm");
329                     return 0;
330                 }
331
332                 len = OPENSSL_strlcpy(prsactx->mdname, mdname,
333                                       sizeof(prsactx->mdname));
334                 if (len >= sizeof(prsactx->mdname)) {
335                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
336                                    "hash algorithm name too long");
337                     return 0;
338                 }
339                 len = OPENSSL_strlcpy(prsactx->mgf1_mdname, mgf1mdname,
340                                       sizeof(prsactx->mgf1_mdname));
341                 if (len >= sizeof(prsactx->mgf1_mdname)) {
342                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
343                                    "MGF1 hash algorithm name too long");
344                     return 0;
345                 }
346                 prsactx->saltlen = min_saltlen;
347
348                 return rsa_setup_md(prsactx, mdname, prsactx->propq)
349                     && rsa_setup_mgf1_md(prsactx, mgf1mdname, prsactx->propq)
350                     && rsa_check_parameters(prsactx);
351             }
352         }
353
354         break;
355     default:
356         ERR_raise(ERR_LIB_RSA, PROV_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
357         return 0;
358     }
359
360     return 1;
361 }
362
363 static int setup_tbuf(PROV_RSA_CTX *ctx)
364 {
365     if (ctx->tbuf != NULL)
366         return 1;
367     if ((ctx->tbuf = OPENSSL_malloc(RSA_size(ctx->rsa))) == NULL) {
368         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
369         return 0;
370     }
371     return 1;
372 }
373
374 static void clean_tbuf(PROV_RSA_CTX *ctx)
375 {
376     if (ctx->tbuf != NULL)
377         OPENSSL_cleanse(ctx->tbuf, RSA_size(ctx->rsa));
378 }
379
380 static void free_tbuf(PROV_RSA_CTX *ctx)
381 {
382     clean_tbuf(ctx);
383     OPENSSL_free(ctx->tbuf);
384     ctx->tbuf = NULL;
385 }
386
387 static int rsa_sign_init(void *vprsactx, void *vrsa)
388 {
389     if (!ossl_prov_is_running())
390         return 0;
391     return rsa_signverify_init(vprsactx, vrsa, EVP_PKEY_OP_SIGN);
392 }
393
394 static int rsa_sign(void *vprsactx, unsigned char *sig, size_t *siglen,
395                     size_t sigsize, const unsigned char *tbs, size_t tbslen)
396 {
397     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
398     int ret;
399     size_t rsasize = RSA_size(prsactx->rsa);
400     size_t mdsize = rsa_get_md_size(prsactx);
401
402     if (!ossl_prov_is_running())
403         return 0;
404
405     if (sig == NULL) {
406         *siglen = rsasize;
407         return 1;
408     }
409
410     if (sigsize < rsasize) {
411         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_SIGNATURE_SIZE,
412                        "is %zu, should be at least %zu", sigsize, rsasize);
413         return 0;
414     }
415
416     if (mdsize != 0) {
417         if (tbslen != mdsize) {
418             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH);
419             return 0;
420         }
421
422 #ifndef FIPS_MODULE
423         if (EVP_MD_is_a(prsactx->md, OSSL_DIGEST_NAME_MDC2)) {
424             unsigned int sltmp;
425
426             if (prsactx->pad_mode != RSA_PKCS1_PADDING) {
427                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
428                                "only PKCS#1 padding supported with MDC2");
429                 return 0;
430             }
431             ret = RSA_sign_ASN1_OCTET_STRING(0, tbs, tbslen, sig, &sltmp,
432                                              prsactx->rsa);
433
434             if (ret <= 0) {
435                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
436                 return 0;
437             }
438             ret = sltmp;
439             goto end;
440         }
441 #endif
442         switch (prsactx->pad_mode) {
443         case RSA_X931_PADDING:
444             if ((size_t)RSA_size(prsactx->rsa) < tbslen + 1) {
445                 ERR_raise_data(ERR_LIB_PROV, PROV_R_KEY_SIZE_TOO_SMALL,
446                                "RSA key size = %d, expected minimum = %d",
447                                RSA_size(prsactx->rsa), tbslen + 1);
448                 return 0;
449             }
450             if (!setup_tbuf(prsactx)) {
451                 ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
452                 return 0;
453             }
454             memcpy(prsactx->tbuf, tbs, tbslen);
455             prsactx->tbuf[tbslen] = RSA_X931_hash_id(prsactx->mdnid);
456             ret = RSA_private_encrypt(tbslen + 1, prsactx->tbuf,
457                                       sig, prsactx->rsa, RSA_X931_PADDING);
458             clean_tbuf(prsactx);
459             break;
460
461         case RSA_PKCS1_PADDING:
462             {
463                 unsigned int sltmp;
464
465                 ret = RSA_sign(prsactx->mdnid, tbs, tbslen, sig, &sltmp,
466                                prsactx->rsa);
467                 if (ret <= 0) {
468                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
469                     return 0;
470                 }
471                 ret = sltmp;
472             }
473             break;
474
475         case RSA_PKCS1_PSS_PADDING:
476             /* Check PSS restrictions */
477             if (rsa_pss_restricted(prsactx)) {
478                 switch (prsactx->saltlen) {
479                 case RSA_PSS_SALTLEN_DIGEST:
480                     if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
481                         ERR_raise_data(ERR_LIB_PROV,
482                                        PROV_R_PSS_SALTLEN_TOO_SMALL,
483                                        "minimum salt length set to %d, "
484                                        "but the digest only gives %d",
485                                        prsactx->min_saltlen,
486                                        EVP_MD_size(prsactx->md));
487                         return 0;
488                     }
489                     /* FALLTHRU */
490                 default:
491                     if (prsactx->saltlen >= 0
492                         && prsactx->saltlen < prsactx->min_saltlen) {
493                         ERR_raise_data(ERR_LIB_PROV,
494                                        PROV_R_PSS_SALTLEN_TOO_SMALL,
495                                        "minimum salt length set to %d, but the"
496                                        "actual salt length is only set to %d",
497                                        prsactx->min_saltlen,
498                                        prsactx->saltlen);
499                         return 0;
500                     }
501                     break;
502                 }
503             }
504             if (!setup_tbuf(prsactx))
505                 return 0;
506             if (!RSA_padding_add_PKCS1_PSS_mgf1(prsactx->rsa,
507                                                 prsactx->tbuf, tbs,
508                                                 prsactx->md, prsactx->mgf1_md,
509                                                 prsactx->saltlen)) {
510                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
511                 return 0;
512             }
513             ret = RSA_private_encrypt(RSA_size(prsactx->rsa), prsactx->tbuf,
514                                       sig, prsactx->rsa, RSA_NO_PADDING);
515             clean_tbuf(prsactx);
516             break;
517
518         default:
519             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
520                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
521             return 0;
522         }
523     } else {
524         ret = RSA_private_encrypt(tbslen, tbs, sig, prsactx->rsa,
525                                   prsactx->pad_mode);
526     }
527
528 #ifndef FIPS_MODULE
529  end:
530 #endif
531     if (ret <= 0) {
532         ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
533         return 0;
534     }
535
536     *siglen = ret;
537     return 1;
538 }
539
540 static int rsa_verify_recover_init(void *vprsactx, void *vrsa)
541 {
542     if (!ossl_prov_is_running())
543         return 0;
544     return rsa_signverify_init(vprsactx, vrsa, EVP_PKEY_OP_VERIFYRECOVER);
545 }
546
547 static int rsa_verify_recover(void *vprsactx,
548                               unsigned char *rout,
549                               size_t *routlen,
550                               size_t routsize,
551                               const unsigned char *sig,
552                               size_t siglen)
553 {
554     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
555     int ret;
556
557     if (!ossl_prov_is_running())
558         return 0;
559
560     if (rout == NULL) {
561         *routlen = RSA_size(prsactx->rsa);
562         return 1;
563     }
564
565     if (prsactx->md != NULL) {
566         switch (prsactx->pad_mode) {
567         case RSA_X931_PADDING:
568             if (!setup_tbuf(prsactx))
569                 return 0;
570             ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
571                                      RSA_X931_PADDING);
572             if (ret < 1) {
573                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
574                 return 0;
575             }
576             ret--;
577             if (prsactx->tbuf[ret] != RSA_X931_hash_id(prsactx->mdnid)) {
578                 ERR_raise(ERR_LIB_PROV, PROV_R_ALGORITHM_MISMATCH);
579                 return 0;
580             }
581             if (ret != EVP_MD_size(prsactx->md)) {
582                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
583                                "Should be %d, but got %d",
584                                EVP_MD_size(prsactx->md), ret);
585                 return 0;
586             }
587
588             *routlen = ret;
589             if (rout != prsactx->tbuf) {
590                 if (routsize < (size_t)ret) {
591                     ERR_raise_data(ERR_LIB_PROV, PROV_R_OUTPUT_BUFFER_TOO_SMALL,
592                                    "buffer size is %d, should be %d",
593                                    routsize, ret);
594                     return 0;
595                 }
596                 memcpy(rout, prsactx->tbuf, ret);
597             }
598             break;
599
600         case RSA_PKCS1_PADDING:
601             {
602                 size_t sltmp;
603
604                 ret = int_rsa_verify(prsactx->mdnid, NULL, 0, rout, &sltmp,
605                                      sig, siglen, prsactx->rsa);
606                 if (ret <= 0) {
607                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
608                     return 0;
609                 }
610                 ret = sltmp;
611             }
612             break;
613
614         default:
615             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
616                            "Only X.931 or PKCS#1 v1.5 padding allowed");
617             return 0;
618         }
619     } else {
620         ret = RSA_public_decrypt(siglen, sig, rout, prsactx->rsa,
621                                  prsactx->pad_mode);
622         if (ret < 0) {
623             ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
624             return 0;
625         }
626     }
627     *routlen = ret;
628     return 1;
629 }
630
631 static int rsa_verify_init(void *vprsactx, void *vrsa)
632 {
633     if (!ossl_prov_is_running())
634         return 0;
635     return rsa_signverify_init(vprsactx, vrsa, EVP_PKEY_OP_VERIFY);
636 }
637
638 static int rsa_verify(void *vprsactx, const unsigned char *sig, size_t siglen,
639                       const unsigned char *tbs, size_t tbslen)
640 {
641     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
642     size_t rslen;
643
644     if (!ossl_prov_is_running())
645         return 0;
646     if (prsactx->md != NULL) {
647         switch (prsactx->pad_mode) {
648         case RSA_PKCS1_PADDING:
649             if (!RSA_verify(prsactx->mdnid, tbs, tbslen, sig, siglen,
650                             prsactx->rsa)) {
651                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
652                 return 0;
653             }
654             return 1;
655         case RSA_X931_PADDING:
656             if (!setup_tbuf(prsactx))
657                 return 0;
658             if (rsa_verify_recover(prsactx, prsactx->tbuf, &rslen, 0,
659                                    sig, siglen) <= 0)
660                 return 0;
661             break;
662         case RSA_PKCS1_PSS_PADDING:
663             {
664                 int ret;
665                 size_t mdsize;
666
667                 /*
668                  * We need to check this for the RSA_verify_PKCS1_PSS_mgf1()
669                  * call
670                  */
671                 mdsize = rsa_get_md_size(prsactx);
672                 if (tbslen != mdsize) {
673                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
674                                    "Should be %d, but got %d",
675                                    mdsize, tbslen);
676                     return 0;
677                 }
678
679                 if (!setup_tbuf(prsactx))
680                     return 0;
681                 ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf,
682                                          prsactx->rsa, RSA_NO_PADDING);
683                 if (ret <= 0) {
684                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
685                     return 0;
686                 }
687                 ret = RSA_verify_PKCS1_PSS_mgf1(prsactx->rsa, tbs,
688                                                 prsactx->md, prsactx->mgf1_md,
689                                                 prsactx->tbuf,
690                                                 prsactx->saltlen);
691                 if (ret <= 0) {
692                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
693                     return 0;
694                 }
695                 return 1;
696             }
697         default:
698             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
699                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
700             return 0;
701         }
702     } else {
703         if (!setup_tbuf(prsactx))
704             return 0;
705         rslen = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
706                                    prsactx->pad_mode);
707         if (rslen == 0) {
708             ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
709             return 0;
710         }
711     }
712
713     if ((rslen != tbslen) || memcmp(tbs, prsactx->tbuf, rslen))
714         return 0;
715
716     return 1;
717 }
718
719 static int rsa_digest_signverify_init(void *vprsactx, const char *mdname,
720                                       void *vrsa, int operation)
721 {
722     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
723
724     if (!ossl_prov_is_running())
725         return 0;
726
727     if (prsactx != NULL)
728         prsactx->flag_allow_md = 0;
729     if (!rsa_signverify_init(vprsactx, vrsa, operation)
730         || !rsa_setup_md(prsactx, mdname, NULL)) /* TODO RL */
731         return 0;
732
733     prsactx->mdctx = EVP_MD_CTX_new();
734     if (prsactx->mdctx == NULL) {
735         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
736         goto error;
737     }
738
739     if (!EVP_DigestInit_ex(prsactx->mdctx, prsactx->md, NULL))
740         goto error;
741
742     return 1;
743
744  error:
745     EVP_MD_CTX_free(prsactx->mdctx);
746     EVP_MD_free(prsactx->md);
747     prsactx->mdctx = NULL;
748     prsactx->md = NULL;
749     return 0;
750 }
751
752 static int rsa_digest_signverify_update(void *vprsactx,
753                                         const unsigned char *data,
754                                         size_t datalen)
755 {
756     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
757
758     if (prsactx == NULL || prsactx->mdctx == NULL)
759         return 0;
760
761     return EVP_DigestUpdate(prsactx->mdctx, data, datalen);
762 }
763
764 static int rsa_digest_sign_init(void *vprsactx, const char *mdname,
765                                 void *vrsa)
766 {
767     if (!ossl_prov_is_running())
768         return 0;
769     return rsa_digest_signverify_init(vprsactx, mdname, vrsa,
770                                       EVP_PKEY_OP_SIGN);
771 }
772
773 static int rsa_digest_sign_final(void *vprsactx, unsigned char *sig,
774                                  size_t *siglen, size_t sigsize)
775 {
776     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
777     unsigned char digest[EVP_MAX_MD_SIZE];
778     unsigned int dlen = 0;
779
780     if (!ossl_prov_is_running() || prsactx == NULL)
781         return 0;
782     prsactx->flag_allow_md = 1;
783     if (prsactx->mdctx == NULL)
784         return 0;
785     /*
786      * If sig is NULL then we're just finding out the sig size. Other fields
787      * are ignored. Defer to rsa_sign.
788      */
789     if (sig != NULL) {
790         /*
791          * The digests used here are all known (see rsa_get_md_nid()), so they
792          * should not exceed the internal buffer size of EVP_MAX_MD_SIZE.
793          */
794         if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
795             return 0;
796     }
797
798     return rsa_sign(vprsactx, sig, siglen, sigsize, digest, (size_t)dlen);
799 }
800
801 static int rsa_digest_verify_init(void *vprsactx, const char *mdname,
802                                   void *vrsa)
803 {
804     if (!ossl_prov_is_running())
805         return 0;
806     return rsa_digest_signverify_init(vprsactx, mdname, vrsa,
807                                       EVP_PKEY_OP_VERIFY);
808 }
809
810 int rsa_digest_verify_final(void *vprsactx, const unsigned char *sig,
811                             size_t siglen)
812 {
813     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
814     unsigned char digest[EVP_MAX_MD_SIZE];
815     unsigned int dlen = 0;
816
817     if (!ossl_prov_is_running())
818         return 0;
819
820     if (prsactx == NULL)
821         return 0;
822     prsactx->flag_allow_md = 1;
823     if (prsactx->mdctx == NULL)
824         return 0;
825
826     /*
827      * The digests used here are all known (see rsa_get_md_nid()), so they
828      * should not exceed the internal buffer size of EVP_MAX_MD_SIZE.
829      */
830     if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
831         return 0;
832
833     return rsa_verify(vprsactx, sig, siglen, digest, (size_t)dlen);
834 }
835
836 static void rsa_freectx(void *vprsactx)
837 {
838     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
839
840     if (prsactx == NULL)
841         return;
842
843     EVP_MD_CTX_free(prsactx->mdctx);
844     EVP_MD_free(prsactx->md);
845     EVP_MD_free(prsactx->mgf1_md);
846     OPENSSL_free(prsactx->propq);
847     free_tbuf(prsactx);
848     RSA_free(prsactx->rsa);
849
850     OPENSSL_clear_free(prsactx, sizeof(*prsactx));
851 }
852
853 static void *rsa_dupctx(void *vprsactx)
854 {
855     PROV_RSA_CTX *srcctx = (PROV_RSA_CTX *)vprsactx;
856     PROV_RSA_CTX *dstctx;
857
858     if (!ossl_prov_is_running())
859         return NULL;
860
861     dstctx = OPENSSL_zalloc(sizeof(*srcctx));
862     if (dstctx == NULL) {
863         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
864         return NULL;
865     }
866
867     *dstctx = *srcctx;
868     dstctx->rsa = NULL;
869     dstctx->md = NULL;
870     dstctx->mdctx = NULL;
871     dstctx->tbuf = NULL;
872
873     if (srcctx->rsa != NULL && !RSA_up_ref(srcctx->rsa))
874         goto err;
875     dstctx->rsa = srcctx->rsa;
876
877     if (srcctx->md != NULL && !EVP_MD_up_ref(srcctx->md))
878         goto err;
879     dstctx->md = srcctx->md;
880
881     if (srcctx->mgf1_md != NULL && !EVP_MD_up_ref(srcctx->mgf1_md))
882         goto err;
883     dstctx->mgf1_md = srcctx->mgf1_md;
884
885     if (srcctx->mdctx != NULL) {
886         dstctx->mdctx = EVP_MD_CTX_new();
887         if (dstctx->mdctx == NULL
888                 || !EVP_MD_CTX_copy_ex(dstctx->mdctx, srcctx->mdctx))
889             goto err;
890     }
891
892     return dstctx;
893  err:
894     rsa_freectx(dstctx);
895     return NULL;
896 }
897
898 static int rsa_get_ctx_params(void *vprsactx, OSSL_PARAM *params)
899 {
900     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
901     OSSL_PARAM *p;
902
903     if (prsactx == NULL || params == NULL)
904         return 0;
905
906     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_ALGORITHM_ID);
907     if (p != NULL
908         && !OSSL_PARAM_set_octet_string(p, prsactx->aid, prsactx->aid_len))
909         return 0;
910
911     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
912     if (p != NULL)
913         switch (p->data_type) {
914         case OSSL_PARAM_INTEGER:
915             if (!OSSL_PARAM_set_int(p, prsactx->pad_mode))
916                 return 0;
917             break;
918         case OSSL_PARAM_UTF8_STRING:
919             {
920                 int i;
921                 const char *word = NULL;
922
923                 for (i = 0; padding_item[i].id != 0; i++) {
924                     if (prsactx->pad_mode == (int)padding_item[i].id) {
925                         word = padding_item[i].ptr;
926                         break;
927                     }
928                 }
929
930                 if (word != NULL) {
931                     if (!OSSL_PARAM_set_utf8_string(p, word))
932                         return 0;
933                 } else {
934                     ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
935                 }
936             }
937             break;
938         default:
939             return 0;
940         }
941
942     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_DIGEST);
943     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mdname))
944         return 0;
945
946     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
947     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mgf1_mdname))
948         return 0;
949
950     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
951     if (p != NULL) {
952         if (p->data_type == OSSL_PARAM_INTEGER) {
953             if (!OSSL_PARAM_set_int(p, prsactx->saltlen))
954                 return 0;
955         } else if (p->data_type == OSSL_PARAM_UTF8_STRING) {
956             const char *value = NULL;
957
958             switch (prsactx->saltlen) {
959             case RSA_PSS_SALTLEN_DIGEST:
960                 value = OSSL_PKEY_RSA_PSS_SALT_LEN_DIGEST;
961                 break;
962             case RSA_PSS_SALTLEN_MAX:
963                 value = OSSL_PKEY_RSA_PSS_SALT_LEN_MAX;
964                 break;
965             case RSA_PSS_SALTLEN_AUTO:
966                 value = OSSL_PKEY_RSA_PSS_SALT_LEN_AUTO;
967                 break;
968             default:
969                 {
970                     int len = BIO_snprintf(p->data, p->data_size, "%d",
971                                            prsactx->saltlen);
972
973                     if (len <= 0)
974                         return 0;
975                     p->return_size = len;
976                     break;
977                 }
978             }
979             if (value != NULL
980                 && !OSSL_PARAM_set_utf8_string(p, value))
981                 return 0;
982         }
983     }
984
985     return 1;
986 }
987
988 static const OSSL_PARAM known_gettable_ctx_params[] = {
989     OSSL_PARAM_octet_string(OSSL_SIGNATURE_PARAM_ALGORITHM_ID, NULL, 0),
990     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
991     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
992     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
993     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
994     OSSL_PARAM_END
995 };
996
997 static const OSSL_PARAM *rsa_gettable_ctx_params(ossl_unused void *vctx)
998 {
999     return known_gettable_ctx_params;
1000 }
1001
1002 static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
1003 {
1004     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1005     const OSSL_PARAM *p;
1006
1007     if (prsactx == NULL || params == NULL)
1008         return 0;
1009
1010     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_DIGEST);
1011     /* Not allowed during certain operations */
1012     if (p != NULL && !prsactx->flag_allow_md)
1013         return 0;
1014     if (p != NULL) {
1015         char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
1016         char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
1017         const OSSL_PARAM *propsp =
1018             OSSL_PARAM_locate_const(params,
1019                                     OSSL_SIGNATURE_PARAM_PROPERTIES);
1020
1021         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
1022             return 0;
1023
1024         if (propsp == NULL)
1025             pmdprops = NULL;
1026         else if (!OSSL_PARAM_get_utf8_string(propsp,
1027                                              &pmdprops, sizeof(mdprops)))
1028             return 0;
1029
1030         if (rsa_pss_restricted(prsactx)) {
1031             /* TODO(3.0) figure out what to do for prsactx->md == NULL */
1032             if (prsactx->md == NULL || EVP_MD_is_a(prsactx->md, mdname))
1033                 return 1;
1034             ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
1035             return 0;
1036         }
1037
1038         /* non-PSS code follows */
1039         if (!rsa_setup_md(prsactx, mdname, pmdprops))
1040             return 0;
1041     }
1042
1043     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
1044     if (p != NULL) {
1045         int pad_mode = 0;
1046         const char *err_extra_text = NULL;
1047
1048         switch (p->data_type) {
1049         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
1050             if (!OSSL_PARAM_get_int(p, &pad_mode))
1051                 return 0;
1052             break;
1053         case OSSL_PARAM_UTF8_STRING:
1054             {
1055                 int i;
1056
1057                 if (p->data == NULL)
1058                     return 0;
1059
1060                 for (i = 0; padding_item[i].id != 0; i++) {
1061                     if (strcmp(p->data, padding_item[i].ptr) == 0) {
1062                         pad_mode = padding_item[i].id;
1063                         break;
1064                     }
1065                 }
1066             }
1067             break;
1068         default:
1069             return 0;
1070         }
1071
1072         switch (pad_mode) {
1073         case RSA_PKCS1_OAEP_PADDING:
1074             /*
1075              * OAEP padding is for asymmetric cipher only so is not compatible
1076              * with signature use.
1077              */
1078             err_extra_text = "OAEP padding not allowed for signing / verifying";
1079             goto bad_pad;
1080         case RSA_PKCS1_PSS_PADDING:
1081             if ((prsactx->operation
1082                  & (EVP_PKEY_OP_SIGN | EVP_PKEY_OP_VERIFY)) == 0) {
1083                 err_extra_text =
1084                     "PSS padding only allowed for sign and verify operations";
1085                 goto bad_pad;
1086             }
1087             if (prsactx->md == NULL
1088                 && !rsa_setup_md(prsactx, RSA_DEFAULT_DIGEST_NAME, NULL)) {
1089                 return 0;
1090             }
1091             break;
1092         case RSA_PKCS1_PADDING:
1093             err_extra_text = "PKCS#1 padding not allowed with RSA-PSS";
1094             goto cont;
1095         case RSA_SSLV23_PADDING:
1096             err_extra_text = "SSLv3 padding not allowed with RSA-PSS";
1097             goto cont;
1098         case RSA_NO_PADDING:
1099             err_extra_text = "No padding not allowed with RSA-PSS";
1100             goto cont;
1101         case RSA_X931_PADDING:
1102             err_extra_text = "X.931 padding not allowed with RSA-PSS";
1103         cont:
1104             if (RSA_test_flags(prsactx->rsa,
1105                                RSA_FLAG_TYPE_MASK) == RSA_FLAG_TYPE_RSA)
1106                 break;
1107             /* FALLTHRU */
1108         default:
1109         bad_pad:
1110             if (err_extra_text == NULL)
1111                 ERR_raise(ERR_LIB_PROV,
1112                           PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE);
1113             else
1114                 ERR_raise_data(ERR_LIB_PROV,
1115                                PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE,
1116                                err_extra_text);
1117             return 0;
1118         }
1119         if (!rsa_check_padding(prsactx->mdnid, pad_mode))
1120             return 0;
1121         prsactx->pad_mode = pad_mode;
1122     }
1123
1124     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
1125     if (p != NULL) {
1126         int saltlen;
1127
1128         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
1129             ERR_raise_data(ERR_LIB_PROV, PROV_R_NOT_SUPPORTED,
1130                            "PSS saltlen can only be specified if "
1131                            "PSS padding has been specified first");
1132             return 0;
1133         }
1134
1135         switch (p->data_type) {
1136         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
1137             if (!OSSL_PARAM_get_int(p, &saltlen))
1138                 return 0;
1139             break;
1140         case OSSL_PARAM_UTF8_STRING:
1141             if (strcmp(p->data, OSSL_PKEY_RSA_PSS_SALT_LEN_DIGEST) == 0)
1142                 saltlen = RSA_PSS_SALTLEN_DIGEST;
1143             else if (strcmp(p->data, OSSL_PKEY_RSA_PSS_SALT_LEN_MAX) == 0)
1144                 saltlen = RSA_PSS_SALTLEN_MAX;
1145             else if (strcmp(p->data, OSSL_PKEY_RSA_PSS_SALT_LEN_AUTO) == 0)
1146                 saltlen = RSA_PSS_SALTLEN_AUTO;
1147             else
1148                 saltlen = atoi(p->data);
1149             break;
1150         default:
1151             return 0;
1152         }
1153
1154         /*
1155          * RSA_PSS_SALTLEN_MAX seems curiously named in this check.
1156          * Contrary to what it's name suggests, it's the currently
1157          * lowest saltlen number possible.
1158          */
1159         if (saltlen < RSA_PSS_SALTLEN_MAX) {
1160             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PSS_SALTLEN);
1161             return 0;
1162         }
1163
1164         if (rsa_pss_restricted(prsactx)) {
1165             switch (saltlen) {
1166             case RSA_PSS_SALTLEN_AUTO:
1167                 if (prsactx->operation == EVP_PKEY_OP_VERIFY) {
1168                     ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PSS_SALTLEN);
1169                     return 0;
1170                 }
1171                 break;
1172             case RSA_PSS_SALTLEN_DIGEST:
1173                 if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
1174                     ERR_raise_data(ERR_LIB_PROV,
1175                                    PROV_R_PSS_SALTLEN_TOO_SMALL,
1176                                    "Should be more than %d, but would be "
1177                                    "set to match digest size (%d)",
1178                                    prsactx->min_saltlen,
1179                                    EVP_MD_size(prsactx->md));
1180                     return 0;
1181                 }
1182                 break;
1183             default:
1184                 if (saltlen >= 0 && saltlen < prsactx->min_saltlen) {
1185                     ERR_raise_data(ERR_LIB_PROV,
1186                                    PROV_R_PSS_SALTLEN_TOO_SMALL,
1187                                    "Should be more than %d, "
1188                                    "but would be set to %d",
1189                                    prsactx->min_saltlen, saltlen);
1190                     return 0;
1191                 }
1192             }
1193         }
1194
1195         prsactx->saltlen = saltlen;
1196     }
1197
1198     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
1199     if (p != NULL) {
1200         char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
1201         char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
1202         const OSSL_PARAM *propsp =
1203             OSSL_PARAM_locate_const(params,
1204                                     OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES);
1205
1206         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
1207             return 0;
1208
1209         if (propsp == NULL)
1210             pmdprops = NULL;
1211         else if (!OSSL_PARAM_get_utf8_string(propsp,
1212                                              &pmdprops, sizeof(mdprops)))
1213             return 0;
1214
1215         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
1216             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_MGF1_MD);
1217             return  0;
1218         }
1219
1220         if (rsa_pss_restricted(prsactx)) {
1221             /* TODO(3.0) figure out what to do for prsactx->mgf1_md == NULL */
1222             if (prsactx->mgf1_md == NULL
1223                 || EVP_MD_is_a(prsactx->mgf1_md, mdname))
1224                 return 1;
1225             ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
1226             return 0;
1227         }
1228
1229         /* non-PSS code follows */
1230         if (!rsa_setup_mgf1_md(prsactx, mdname, pmdprops))
1231             return 0;
1232     }
1233
1234     return 1;
1235 }
1236
1237 static const OSSL_PARAM known_settable_ctx_params[] = {
1238     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
1239     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
1240     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PROPERTIES, NULL, 0),
1241     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
1242     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES, NULL, 0),
1243     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
1244     OSSL_PARAM_END
1245 };
1246
1247 static const OSSL_PARAM *rsa_settable_ctx_params(ossl_unused void *provctx)
1248 {
1249     /*
1250      * TODO(3.0): Should this function return a different set of settable ctx
1251      * params if the ctx is being used for a DigestSign/DigestVerify? In that
1252      * case it is not allowed to set the digest size/digest name because the
1253      * digest is explicitly set as part of the init.
1254      */
1255     return known_settable_ctx_params;
1256 }
1257
1258 static int rsa_get_ctx_md_params(void *vprsactx, OSSL_PARAM *params)
1259 {
1260     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1261
1262     if (prsactx->mdctx == NULL)
1263         return 0;
1264
1265     return EVP_MD_CTX_get_params(prsactx->mdctx, params);
1266 }
1267
1268 static const OSSL_PARAM *rsa_gettable_ctx_md_params(void *vprsactx)
1269 {
1270     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1271
1272     if (prsactx->md == NULL)
1273         return 0;
1274
1275     return EVP_MD_gettable_ctx_params(prsactx->md);
1276 }
1277
1278 static int rsa_set_ctx_md_params(void *vprsactx, const OSSL_PARAM params[])
1279 {
1280     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1281
1282     if (prsactx->mdctx == NULL)
1283         return 0;
1284
1285     return EVP_MD_CTX_set_params(prsactx->mdctx, params);
1286 }
1287
1288 static const OSSL_PARAM *rsa_settable_ctx_md_params(void *vprsactx)
1289 {
1290     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1291
1292     if (prsactx->md == NULL)
1293         return 0;
1294
1295     return EVP_MD_settable_ctx_params(prsactx->md);
1296 }
1297
1298 const OSSL_DISPATCH rsa_signature_functions[] = {
1299     { OSSL_FUNC_SIGNATURE_NEWCTX, (void (*)(void))rsa_newctx },
1300     { OSSL_FUNC_SIGNATURE_SIGN_INIT, (void (*)(void))rsa_sign_init },
1301     { OSSL_FUNC_SIGNATURE_SIGN, (void (*)(void))rsa_sign },
1302     { OSSL_FUNC_SIGNATURE_VERIFY_INIT, (void (*)(void))rsa_verify_init },
1303     { OSSL_FUNC_SIGNATURE_VERIFY, (void (*)(void))rsa_verify },
1304     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER_INIT,
1305       (void (*)(void))rsa_verify_recover_init },
1306     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER,
1307       (void (*)(void))rsa_verify_recover },
1308     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_INIT,
1309       (void (*)(void))rsa_digest_sign_init },
1310     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_UPDATE,
1311       (void (*)(void))rsa_digest_signverify_update },
1312     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_FINAL,
1313       (void (*)(void))rsa_digest_sign_final },
1314     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_INIT,
1315       (void (*)(void))rsa_digest_verify_init },
1316     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_UPDATE,
1317       (void (*)(void))rsa_digest_signverify_update },
1318     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_FINAL,
1319       (void (*)(void))rsa_digest_verify_final },
1320     { OSSL_FUNC_SIGNATURE_FREECTX, (void (*)(void))rsa_freectx },
1321     { OSSL_FUNC_SIGNATURE_DUPCTX, (void (*)(void))rsa_dupctx },
1322     { OSSL_FUNC_SIGNATURE_GET_CTX_PARAMS, (void (*)(void))rsa_get_ctx_params },
1323     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_PARAMS,
1324       (void (*)(void))rsa_gettable_ctx_params },
1325     { OSSL_FUNC_SIGNATURE_SET_CTX_PARAMS, (void (*)(void))rsa_set_ctx_params },
1326     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_PARAMS,
1327       (void (*)(void))rsa_settable_ctx_params },
1328     { OSSL_FUNC_SIGNATURE_GET_CTX_MD_PARAMS,
1329       (void (*)(void))rsa_get_ctx_md_params },
1330     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_MD_PARAMS,
1331       (void (*)(void))rsa_gettable_ctx_md_params },
1332     { OSSL_FUNC_SIGNATURE_SET_CTX_MD_PARAMS,
1333       (void (*)(void))rsa_set_ctx_md_params },
1334     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_MD_PARAMS,
1335       (void (*)(void))rsa_settable_ctx_md_params },
1336     { 0, NULL }
1337 };