0e3885ec1dcfed8bcfecebb5377978c6eca30189
[openssl.git] / providers / implementations / signature / rsa.c
1 /*
2  * Copyright 2019-2020 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 /*
11  * RSA low level APIs are deprecated for public use, but still ok for
12  * internal use.
13  */
14 #include "internal/deprecated.h"
15
16 #include <string.h>
17 #include <openssl/crypto.h>
18 #include <openssl/core_numbers.h>
19 #include <openssl/core_names.h>
20 #include <openssl/err.h>
21 #include <openssl/rsa.h>
22 #include <openssl/params.h>
23 #include <openssl/evp.h>
24 #include "internal/cryptlib.h"
25 #include "internal/nelem.h"
26 #include "internal/sizes.h"
27 #include "crypto/rsa.h"
28 #include "prov/providercommonerr.h"
29 #include "prov/implementations.h"
30 #include "prov/provider_ctx.h"
31 #include "prov/der_rsa.h"
32
33 static OSSL_OP_signature_newctx_fn rsa_newctx;
34 static OSSL_OP_signature_sign_init_fn rsa_sign_init;
35 static OSSL_OP_signature_verify_init_fn rsa_verify_init;
36 static OSSL_OP_signature_verify_recover_init_fn rsa_verify_recover_init;
37 static OSSL_OP_signature_sign_fn rsa_sign;
38 static OSSL_OP_signature_verify_fn rsa_verify;
39 static OSSL_OP_signature_verify_recover_fn rsa_verify_recover;
40 static OSSL_OP_signature_digest_sign_init_fn rsa_digest_sign_init;
41 static OSSL_OP_signature_digest_sign_update_fn rsa_digest_signverify_update;
42 static OSSL_OP_signature_digest_sign_final_fn rsa_digest_sign_final;
43 static OSSL_OP_signature_digest_verify_init_fn rsa_digest_verify_init;
44 static OSSL_OP_signature_digest_verify_update_fn rsa_digest_signverify_update;
45 static OSSL_OP_signature_digest_verify_final_fn rsa_digest_verify_final;
46 static OSSL_OP_signature_freectx_fn rsa_freectx;
47 static OSSL_OP_signature_dupctx_fn rsa_dupctx;
48 static OSSL_OP_signature_get_ctx_params_fn rsa_get_ctx_params;
49 static OSSL_OP_signature_gettable_ctx_params_fn rsa_gettable_ctx_params;
50 static OSSL_OP_signature_set_ctx_params_fn rsa_set_ctx_params;
51 static OSSL_OP_signature_settable_ctx_params_fn rsa_settable_ctx_params;
52 static OSSL_OP_signature_get_ctx_md_params_fn rsa_get_ctx_md_params;
53 static OSSL_OP_signature_gettable_ctx_md_params_fn rsa_gettable_ctx_md_params;
54 static OSSL_OP_signature_set_ctx_md_params_fn rsa_set_ctx_md_params;
55 static OSSL_OP_signature_settable_ctx_md_params_fn rsa_settable_ctx_md_params;
56
57 static OSSL_ITEM padding_item[] = {
58     { RSA_PKCS1_PADDING,        "pkcs1"  },
59     { RSA_SSLV23_PADDING,       "sslv23" },
60     { RSA_NO_PADDING,           "none"   },
61     { RSA_PKCS1_OAEP_PADDING,   "oaep"   }, /* Correct spelling first */
62     { RSA_PKCS1_OAEP_PADDING,   "oeap"   },
63     { RSA_X931_PADDING,         "x931"   },
64     { RSA_PKCS1_PSS_PADDING,    "pss"    },
65     { 0,                        NULL     }
66 };
67
68 /*
69  * What's passed as an actual key is defined by the KEYMGMT interface.
70  * We happen to know that our KEYMGMT simply passes RSA structures, so
71  * we use that here too.
72  */
73
74 typedef struct {
75     OPENSSL_CTX *libctx;
76     char *propq;
77     RSA *rsa;
78     int operation;
79
80     /*
81      * Flag to determine if the hash function can be changed (1) or not (0)
82      * Because it's dangerous to change during a DigestSign or DigestVerify
83      * operation, this flag is cleared by their Init function, and set again
84      * by their Final function.
85      */
86     unsigned int flag_allow_md : 1;
87
88     /* The Algorithm Identifier of the combined signature agorithm */
89     unsigned char aid_buf[128];
90     unsigned char *aid;
91     size_t  aid_len;
92
93     /* main digest */
94     EVP_MD *md;
95     EVP_MD_CTX *mdctx;
96     int mdnid;
97     char mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
98
99     /* RSA padding mode */
100     int pad_mode;
101     /* message digest for MGF1 */
102     EVP_MD *mgf1_md;
103     char mgf1_mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
104     /* PSS salt length */
105     int saltlen;
106     /* Minimum salt length or -1 if no PSS parameter restriction */
107     int min_saltlen;
108
109     /* Temp buffer */
110     unsigned char *tbuf;
111
112 } PROV_RSA_CTX;
113
114 static size_t rsa_get_md_size(const PROV_RSA_CTX *prsactx)
115 {
116     if (prsactx->md != NULL)
117         return EVP_MD_size(prsactx->md);
118     return 0;
119 }
120
121 static int rsa_get_md_nid(const EVP_MD *md)
122 {
123     /*
124      * Because the RSA library deals with NIDs, we need to translate.
125      * We do so using EVP_MD_is_a(), and therefore need a name to NID
126      * map.
127      */
128     static const OSSL_ITEM name_to_nid[] = {
129         { NID_sha1,      OSSL_DIGEST_NAME_SHA1      },
130         { NID_sha224,    OSSL_DIGEST_NAME_SHA2_224  },
131         { NID_sha256,    OSSL_DIGEST_NAME_SHA2_256  },
132         { NID_sha384,    OSSL_DIGEST_NAME_SHA2_384  },
133         { NID_sha512,    OSSL_DIGEST_NAME_SHA2_512  },
134         { NID_sha512_224, OSSL_DIGEST_NAME_SHA2_512_224 },
135         { NID_sha512_256, OSSL_DIGEST_NAME_SHA2_512_256 },
136         { NID_md5,       OSSL_DIGEST_NAME_MD5       },
137         { NID_md5_sha1,  OSSL_DIGEST_NAME_MD5_SHA1  },
138         { NID_md2,       OSSL_DIGEST_NAME_MD2       },
139         { NID_md4,       OSSL_DIGEST_NAME_MD4       },
140         { NID_mdc2,      OSSL_DIGEST_NAME_MDC2      },
141         { NID_ripemd160, OSSL_DIGEST_NAME_RIPEMD160 },
142         { NID_sha3_224,  OSSL_DIGEST_NAME_SHA3_224  },
143         { NID_sha3_256,  OSSL_DIGEST_NAME_SHA3_256  },
144         { NID_sha3_384,  OSSL_DIGEST_NAME_SHA3_384  },
145         { NID_sha3_512,  OSSL_DIGEST_NAME_SHA3_512  },
146     };
147     size_t i;
148     int mdnid = NID_undef;
149
150     if (md == NULL)
151         goto end;
152
153     for (i = 0; i < OSSL_NELEM(name_to_nid); i++) {
154         if (EVP_MD_is_a(md, name_to_nid[i].ptr)) {
155             mdnid = (int)name_to_nid[i].id;
156             break;
157         }
158     }
159
160  end:
161     return mdnid;
162 }
163
164 static int rsa_check_padding(int mdnid, int padding)
165 {
166     if (padding == RSA_NO_PADDING) {
167         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE);
168         return 0;
169     }
170
171     if (padding == RSA_X931_PADDING) {
172         if (RSA_X931_hash_id(mdnid) == -1) {
173             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_X931_DIGEST);
174             return 0;
175         }
176     }
177
178     return 1;
179 }
180
181 static int rsa_check_parameters(EVP_MD *md, PROV_RSA_CTX *prsactx)
182 {
183     if (prsactx->pad_mode == RSA_PKCS1_PSS_PADDING) {
184         int max_saltlen;
185
186         /* See if minimum salt length exceeds maximum possible */
187         max_saltlen = RSA_size(prsactx->rsa) - EVP_MD_size(md);
188         if ((RSA_bits(prsactx->rsa) & 0x7) == 1)
189             max_saltlen--;
190         if (prsactx->min_saltlen > max_saltlen) {
191             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_SALT_LENGTH);
192             return 0;
193         }
194     }
195     return 1;
196 }
197
198 static void *rsa_newctx(void *provctx, const char *propq)
199 {
200     PROV_RSA_CTX *prsactx = NULL;
201     char *propq_copy = NULL;
202
203     if ((prsactx = OPENSSL_zalloc(sizeof(PROV_RSA_CTX))) == NULL
204         || (propq != NULL
205             && (propq_copy = OPENSSL_strdup(propq)) == NULL)) {
206         OPENSSL_free(prsactx);
207         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
208         return NULL;
209     }
210
211     prsactx->libctx = PROV_LIBRARY_CONTEXT_OF(provctx);
212     prsactx->flag_allow_md = 1;
213     prsactx->propq = propq_copy;
214     return prsactx;
215 }
216
217 /* True if PSS parameters are restricted */
218 #define rsa_pss_restricted(prsactx) (prsactx->min_saltlen != -1)
219
220 static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
221                         const char *mdprops)
222 {
223     if (mdprops == NULL)
224         mdprops = ctx->propq;
225
226     if (mdname != NULL) {
227         EVP_MD *md = EVP_MD_fetch(ctx->libctx, mdname, mdprops);
228         int md_nid = rsa_get_md_nid(md);
229         WPACKET pkt;
230         size_t mdname_len = strlen(mdname);
231
232         if (md == NULL
233             || md_nid == NID_undef
234             || !rsa_check_padding(md_nid, ctx->pad_mode)
235             || !rsa_check_parameters(md, ctx)
236             || mdname_len >= sizeof(ctx->mdname)) {
237             if (md == NULL)
238                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
239                                "%s could not be fetched", mdname);
240             if (md_nid == NID_undef)
241                 ERR_raise_data(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED,
242                                "digest=%s", mdname);
243             if (mdname_len >= sizeof(ctx->mdname))
244                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
245                                "%s exceeds name buffer length", mdname);
246             EVP_MD_free(md);
247             return 0;
248         }
249
250         EVP_MD_CTX_free(ctx->mdctx);
251         EVP_MD_free(ctx->md);
252
253         /*
254          * TODO(3.0) Should we care about DER writing errors?
255          * All it really means is that for some reason, there's no
256          * AlgorithmIdentifier to be had (consider RSA with MD5-SHA1),
257          * but the operation itself is still valid, just as long as it's
258          * not used to construct anything that needs an AlgorithmIdentifier.
259          */
260         ctx->aid_len = 0;
261         if (WPACKET_init_der(&pkt, ctx->aid_buf, sizeof(ctx->aid_buf))
262             && DER_w_algorithmIdentifier_MDWithRSAEncryption(&pkt, -1, ctx->rsa,
263                                                              md_nid)
264             && WPACKET_finish(&pkt)) {
265             WPACKET_get_total_written(&pkt, &ctx->aid_len);
266             ctx->aid = WPACKET_get_curr(&pkt);
267         }
268         WPACKET_cleanup(&pkt);
269
270         ctx->mdctx = NULL;
271         ctx->md = md;
272         ctx->mdnid = md_nid;
273         OPENSSL_strlcpy(ctx->mdname, mdname, sizeof(ctx->mdname));
274     }
275
276     return 1;
277 }
278
279 static int rsa_setup_mgf1_md(PROV_RSA_CTX *ctx, const char *mdname,
280                              const char *mdprops)
281 {
282     size_t len;
283
284     if (mdprops == NULL)
285         mdprops = ctx->propq;
286
287     if (ctx->mgf1_mdname[0] != '\0')
288         EVP_MD_free(ctx->mgf1_md);
289
290     if ((ctx->mgf1_md = EVP_MD_fetch(ctx->libctx, mdname, mdprops)) == NULL) {
291         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
292                        "%s could not be fetched", mdname);
293         return 0;
294     }
295     len = OPENSSL_strlcpy(ctx->mgf1_mdname, mdname, sizeof(ctx->mgf1_mdname));
296     if (len >= sizeof(ctx->mgf1_mdname)) {
297         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
298                        "%s exceeds name buffer length", mdname);
299         return 0;
300     }
301
302     return 1;
303 }
304
305 static int rsa_signature_init(void *vprsactx, void *vrsa, int operation)
306 {
307     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
308
309     if (prsactx == NULL || vrsa == NULL || !RSA_up_ref(vrsa))
310         return 0;
311
312     RSA_free(prsactx->rsa);
313     prsactx->rsa = vrsa;
314     prsactx->operation = operation;
315
316     /* Maximum for sign, auto for verify */
317     prsactx->saltlen = RSA_PSS_SALTLEN_AUTO;
318     prsactx->min_saltlen = -1;
319
320     switch (RSA_test_flags(prsactx->rsa, RSA_FLAG_TYPE_MASK)) {
321     case RSA_FLAG_TYPE_RSA:
322         prsactx->pad_mode = RSA_PKCS1_PADDING;
323         break;
324     case RSA_FLAG_TYPE_RSASSAPSS:
325         prsactx->pad_mode = RSA_PKCS1_PSS_PADDING;
326
327         {
328             const RSA_PSS_PARAMS_30 *pss =
329                 rsa_get0_pss_params_30(prsactx->rsa);
330
331             if (!rsa_pss_params_30_is_unrestricted(pss)) {
332                 int md_nid = rsa_pss_params_30_hashalg(pss);
333                 int mgf1md_nid = rsa_pss_params_30_maskgenhashalg(pss);
334                 int min_saltlen = rsa_pss_params_30_saltlen(pss);
335                 const char *mdname, *mgf1mdname;
336                 size_t len;
337
338                 mdname = rsa_oaeppss_nid2name(md_nid);
339                 mgf1mdname = rsa_oaeppss_nid2name(mgf1md_nid);
340                 prsactx->min_saltlen = min_saltlen;
341
342                 if (mdname == NULL) {
343                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
344                                    "PSS restrictions lack hash algorithm");
345                     return 0;
346                 }
347                 if (mgf1mdname == NULL) {
348                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
349                                    "PSS restrictions lack MGF1 hash algorithm");
350                     return 0;
351                 }
352
353                 len = OPENSSL_strlcpy(prsactx->mdname, mdname,
354                                       sizeof(prsactx->mdname));
355                 if (len >= sizeof(prsactx->mdname)) {
356                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
357                                    "hash algorithm name too long");
358                     return 0;
359                 }
360                 len = OPENSSL_strlcpy(prsactx->mgf1_mdname, mgf1mdname,
361                                       sizeof(prsactx->mgf1_mdname));
362                 if (len >= sizeof(prsactx->mgf1_mdname)) {
363                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
364                                    "MGF1 hash algorithm name too long");
365                     return 0;
366                 }
367                 prsactx->saltlen = min_saltlen;
368
369                 return rsa_setup_md(prsactx, mdname, prsactx->propq)
370                     && rsa_setup_mgf1_md(prsactx, mgf1mdname, prsactx->propq);
371             }
372         }
373
374         break;
375     default:
376         ERR_raise(ERR_LIB_RSA, PROV_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
377         return 0;
378     }
379
380     return 1;
381 }
382
383 static int setup_tbuf(PROV_RSA_CTX *ctx)
384 {
385     if (ctx->tbuf != NULL)
386         return 1;
387     if ((ctx->tbuf = OPENSSL_malloc(RSA_size(ctx->rsa))) == NULL) {
388         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
389         return 0;
390     }
391     return 1;
392 }
393
394 static void clean_tbuf(PROV_RSA_CTX *ctx)
395 {
396     if (ctx->tbuf != NULL)
397         OPENSSL_cleanse(ctx->tbuf, RSA_size(ctx->rsa));
398 }
399
400 static void free_tbuf(PROV_RSA_CTX *ctx)
401 {
402     clean_tbuf(ctx);
403     OPENSSL_free(ctx->tbuf);
404     ctx->tbuf = NULL;
405 }
406
407 static int rsa_sign_init(void *vprsactx, void *vrsa)
408 {
409     return rsa_signature_init(vprsactx, vrsa, EVP_PKEY_OP_SIGN);
410 }
411
412 static int rsa_sign(void *vprsactx, unsigned char *sig, size_t *siglen,
413                     size_t sigsize, const unsigned char *tbs, size_t tbslen)
414 {
415     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
416     int ret;
417     size_t rsasize = RSA_size(prsactx->rsa);
418     size_t mdsize = rsa_get_md_size(prsactx);
419
420     if (sig == NULL) {
421         *siglen = rsasize;
422         return 1;
423     }
424
425     if (sigsize < rsasize) {
426         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_SIGNATURE_SIZE,
427                        "is %zu, should be at least %zu", sigsize, rsasize);
428         return 0;
429     }
430
431     if (mdsize != 0) {
432         if (tbslen != mdsize) {
433             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH);
434             return 0;
435         }
436
437 #ifndef FIPS_MODULE
438         if (EVP_MD_is_a(prsactx->md, OSSL_DIGEST_NAME_MDC2)) {
439             unsigned int sltmp;
440
441             if (prsactx->pad_mode != RSA_PKCS1_PADDING) {
442                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
443                                "only PKCS#1 padding supported with MDC2");
444                 return 0;
445             }
446             ret = RSA_sign_ASN1_OCTET_STRING(0, tbs, tbslen, sig, &sltmp,
447                                              prsactx->rsa);
448
449             if (ret <= 0) {
450                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
451                 return 0;
452             }
453             ret = sltmp;
454             goto end;
455         }
456 #endif
457         switch (prsactx->pad_mode) {
458         case RSA_X931_PADDING:
459             if ((size_t)RSA_size(prsactx->rsa) < tbslen + 1) {
460                 ERR_raise_data(ERR_LIB_PROV, PROV_R_KEY_SIZE_TOO_SMALL,
461                                "RSA key size = %d, expected minimum = %d",
462                                RSA_size(prsactx->rsa), tbslen + 1);
463                 return 0;
464             }
465             if (!setup_tbuf(prsactx)) {
466                 ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
467                 return 0;
468             }
469             memcpy(prsactx->tbuf, tbs, tbslen);
470             prsactx->tbuf[tbslen] = RSA_X931_hash_id(prsactx->mdnid);
471             ret = RSA_private_encrypt(tbslen + 1, prsactx->tbuf,
472                                       sig, prsactx->rsa, RSA_X931_PADDING);
473             clean_tbuf(prsactx);
474             break;
475
476         case RSA_PKCS1_PADDING:
477             {
478                 unsigned int sltmp;
479
480                 ret = RSA_sign(prsactx->mdnid, tbs, tbslen, sig, &sltmp,
481                                prsactx->rsa);
482                 if (ret <= 0) {
483                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
484                     return 0;
485                 }
486                 ret = sltmp;
487             }
488             break;
489
490         case RSA_PKCS1_PSS_PADDING:
491             /* Check PSS restrictions */
492             if (rsa_pss_restricted(prsactx)) {
493                 switch (prsactx->saltlen) {
494                 case RSA_PSS_SALTLEN_DIGEST:
495                     if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
496                         ERR_raise_data(ERR_LIB_PROV,
497                                        PROV_R_PSS_SALTLEN_TOO_SMALL,
498                                        "minimum salt length set to %d, "
499                                        "but the digest only gives %d",
500                                        prsactx->min_saltlen,
501                                        EVP_MD_size(prsactx->md));
502                         return 0;
503                     }
504                     /* FALLTHRU */
505                 default:
506                     if (prsactx->saltlen >= 0
507                         && prsactx->saltlen < prsactx->min_saltlen) {
508                         ERR_raise_data(ERR_LIB_PROV,
509                                        PROV_R_PSS_SALTLEN_TOO_SMALL,
510                                        "minimum salt length set to %d, but the"
511                                        "actual salt length is only set to %d",
512                                        prsactx->min_saltlen,
513                                        prsactx->saltlen);
514                         return 0;
515                     }
516                     break;
517                 }
518             }
519             if (!setup_tbuf(prsactx))
520                 return 0;
521             if (!RSA_padding_add_PKCS1_PSS_mgf1(prsactx->rsa,
522                                                 prsactx->tbuf, tbs,
523                                                 prsactx->md, prsactx->mgf1_md,
524                                                 prsactx->saltlen)) {
525                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
526                 return 0;
527             }
528             ret = RSA_private_encrypt(RSA_size(prsactx->rsa), prsactx->tbuf,
529                                       sig, prsactx->rsa, RSA_NO_PADDING);
530             clean_tbuf(prsactx);
531             break;
532
533         default:
534             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
535                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
536             return 0;
537         }
538     } else {
539         ret = RSA_private_encrypt(tbslen, tbs, sig, prsactx->rsa,
540                                   prsactx->pad_mode);
541     }
542
543 #ifndef FIPS_MODULE
544  end:
545 #endif
546     if (ret <= 0) {
547         ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
548         return 0;
549     }
550
551     *siglen = ret;
552     return 1;
553 }
554
555 static int rsa_verify_recover_init(void *vprsactx, void *vrsa)
556 {
557     return rsa_signature_init(vprsactx, vrsa, EVP_PKEY_OP_VERIFYRECOVER);
558 }
559
560 static int rsa_verify_recover(void *vprsactx,
561                               unsigned char *rout,
562                               size_t *routlen,
563                               size_t routsize,
564                               const unsigned char *sig,
565                               size_t siglen)
566 {
567     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
568     int ret;
569
570     if (rout == NULL) {
571         *routlen = RSA_size(prsactx->rsa);
572         return 1;
573     }
574
575     if (prsactx->md != NULL) {
576         switch (prsactx->pad_mode) {
577         case RSA_X931_PADDING:
578             if (!setup_tbuf(prsactx))
579                 return 0;
580             ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
581                                      RSA_X931_PADDING);
582             if (ret < 1) {
583                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
584                 return 0;
585             }
586             ret--;
587             if (prsactx->tbuf[ret] != RSA_X931_hash_id(prsactx->mdnid)) {
588                 ERR_raise(ERR_LIB_PROV, PROV_R_ALGORITHM_MISMATCH);
589                 return 0;
590             }
591             if (ret != EVP_MD_size(prsactx->md)) {
592                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
593                                "Should be %d, but got %d",
594                                EVP_MD_size(prsactx->md), ret);
595                 return 0;
596             }
597
598             *routlen = ret;
599             if (routsize < (size_t)ret) {
600                 ERR_raise_data(ERR_LIB_PROV, PROV_R_OUTPUT_BUFFER_TOO_SMALL,
601                                "buffer size is %d, should be %d",
602                                routsize, ret);
603                 return 0;
604             }
605             memcpy(rout, prsactx->tbuf, ret);
606             break;
607
608         case RSA_PKCS1_PADDING:
609             {
610                 size_t sltmp;
611
612                 ret = int_rsa_verify(prsactx->mdnid, NULL, 0, rout, &sltmp,
613                                      sig, siglen, prsactx->rsa);
614                 if (ret <= 0) {
615                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
616                     return 0;
617                 }
618                 ret = sltmp;
619             }
620             break;
621
622         default:
623             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
624                            "Only X.931 or PKCS#1 v1.5 padding allowed");
625             return 0;
626         }
627     } else {
628         ret = RSA_public_decrypt(siglen, sig, rout, prsactx->rsa,
629                                  prsactx->pad_mode);
630         if (ret < 0) {
631             ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
632             return 0;
633         }
634     }
635     *routlen = ret;
636     return 1;
637 }
638
639 static int rsa_verify_init(void *vprsactx, void *vrsa)
640 {
641     return rsa_signature_init(vprsactx, vrsa, EVP_PKEY_OP_VERIFY);
642 }
643
644 static int rsa_verify(void *vprsactx, const unsigned char *sig, size_t siglen,
645                       const unsigned char *tbs, size_t tbslen)
646 {
647     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
648     size_t rslen;
649
650     if (prsactx->md != NULL) {
651         switch (prsactx->pad_mode) {
652         case RSA_PKCS1_PADDING:
653             if (!RSA_verify(prsactx->mdnid, tbs, tbslen, sig, siglen,
654                             prsactx->rsa)) {
655                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
656                 return 0;
657             }
658             return 1;
659         case RSA_X931_PADDING:
660             if (rsa_verify_recover(prsactx, NULL, &rslen, 0, sig, siglen) <= 0)
661                 return 0;
662             break;
663         case RSA_PKCS1_PSS_PADDING:
664             {
665                 int ret;
666                 size_t mdsize;
667
668                 /*
669                  * We need to check this for the RSA_verify_PKCS1_PSS_mgf1()
670                  * call
671                  */
672                 mdsize = rsa_get_md_size(prsactx);
673                 if (tbslen != mdsize) {
674                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
675                                    "Should be %d, but got %d",
676                                    mdsize, tbslen);
677                     return 0;
678                 }
679
680                 if (!setup_tbuf(prsactx))
681                     return 0;
682                 ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf,
683                                          prsactx->rsa, RSA_NO_PADDING);
684                 if (ret <= 0) {
685                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
686                     return 0;
687                 }
688                 ret = RSA_verify_PKCS1_PSS_mgf1(prsactx->rsa, tbs,
689                                                 prsactx->md, prsactx->mgf1_md,
690                                                 prsactx->tbuf,
691                                                 prsactx->saltlen);
692                 if (ret <= 0) {
693                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
694                     return 0;
695                 }
696                 return 1;
697             }
698         default:
699             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
700                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
701             return 0;
702         }
703     } else {
704         if (!setup_tbuf(prsactx))
705             return 0;
706         rslen = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
707                                    prsactx->pad_mode);
708         if (rslen == 0) {
709             ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
710             return 0;
711         }
712     }
713
714     if ((rslen != tbslen) || memcmp(tbs, prsactx->tbuf, rslen))
715         return 0;
716
717     return 1;
718 }
719
720 static int rsa_digest_signverify_init(void *vprsactx, const char *mdname,
721                                       void *vrsa, int operation)
722 {
723     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
724
725     prsactx->flag_allow_md = 0;
726     if (!rsa_signature_init(vprsactx, vrsa, operation)
727         || !rsa_setup_md(prsactx, mdname, NULL)) /* TODO RL */
728         return 0;
729
730     prsactx->mdctx = EVP_MD_CTX_new();
731     if (prsactx->mdctx == NULL) {
732         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
733         goto error;
734     }
735
736     if (!EVP_DigestInit_ex(prsactx->mdctx, prsactx->md, NULL))
737         goto error;
738
739     return 1;
740
741  error:
742     EVP_MD_CTX_free(prsactx->mdctx);
743     EVP_MD_free(prsactx->md);
744     prsactx->mdctx = NULL;
745     prsactx->md = NULL;
746     return 0;
747 }
748
749 static int rsa_digest_signverify_update(void *vprsactx,
750                                         const unsigned char *data,
751                                         size_t datalen)
752 {
753     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
754
755     if (prsactx == NULL || prsactx->mdctx == NULL)
756         return 0;
757
758     return EVP_DigestUpdate(prsactx->mdctx, data, datalen);
759 }
760
761 static int rsa_digest_sign_init(void *vprsactx, const char *mdname,
762                                 void *vrsa)
763 {
764     return rsa_digest_signverify_init(vprsactx, mdname, vrsa,
765                                       EVP_PKEY_OP_SIGN);
766 }
767
768 static int rsa_digest_sign_final(void *vprsactx, unsigned char *sig,
769                                  size_t *siglen, size_t sigsize)
770 {
771     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
772     unsigned char digest[EVP_MAX_MD_SIZE];
773     unsigned int dlen = 0;
774
775     prsactx->flag_allow_md = 1;
776     if (prsactx == NULL || prsactx->mdctx == NULL)
777         return 0;
778
779     /*
780      * If sig is NULL then we're just finding out the sig size. Other fields
781      * are ignored. Defer to rsa_sign.
782      */
783     if (sig != NULL) {
784         /*
785          * TODO(3.0): There is the possibility that some externally provided
786          * digests exceed EVP_MAX_MD_SIZE. We should probably handle that somehow -
787          * but that problem is much larger than just in RSA.
788          */
789         if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
790             return 0;
791     }
792
793     return rsa_sign(vprsactx, sig, siglen, sigsize, digest, (size_t)dlen);
794 }
795
796 static int rsa_digest_verify_init(void *vprsactx, const char *mdname,
797                                   void *vrsa)
798 {
799     return rsa_digest_signverify_init(vprsactx, mdname, vrsa,
800                                       EVP_PKEY_OP_VERIFY);
801 }
802
803 int rsa_digest_verify_final(void *vprsactx, const unsigned char *sig,
804                             size_t siglen)
805 {
806     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
807     unsigned char digest[EVP_MAX_MD_SIZE];
808     unsigned int dlen = 0;
809
810     prsactx->flag_allow_md = 1;
811     if (prsactx == NULL || prsactx->mdctx == NULL)
812         return 0;
813
814     /*
815      * TODO(3.0): There is the possibility that some externally provided
816      * digests exceed EVP_MAX_MD_SIZE. We should probably handle that somehow -
817      * but that problem is much larger than just in RSA.
818      */
819     if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
820         return 0;
821
822     return rsa_verify(vprsactx, sig, siglen, digest, (size_t)dlen);
823 }
824
825 static void rsa_freectx(void *vprsactx)
826 {
827     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
828
829     if (prsactx == NULL)
830         return;
831
832     RSA_free(prsactx->rsa);
833     EVP_MD_CTX_free(prsactx->mdctx);
834     EVP_MD_free(prsactx->md);
835     EVP_MD_free(prsactx->mgf1_md);
836     OPENSSL_free(prsactx->propq);
837     free_tbuf(prsactx);
838
839     OPENSSL_clear_free(prsactx, sizeof(prsactx));
840 }
841
842 static void *rsa_dupctx(void *vprsactx)
843 {
844     PROV_RSA_CTX *srcctx = (PROV_RSA_CTX *)vprsactx;
845     PROV_RSA_CTX *dstctx;
846
847     dstctx = OPENSSL_zalloc(sizeof(*srcctx));
848     if (dstctx == NULL) {
849         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
850         return NULL;
851     }
852
853     *dstctx = *srcctx;
854     dstctx->rsa = NULL;
855     dstctx->md = NULL;
856     dstctx->mdctx = NULL;
857     dstctx->tbuf = NULL;
858
859     if (srcctx->rsa != NULL && !RSA_up_ref(srcctx->rsa))
860         goto err;
861     dstctx->rsa = srcctx->rsa;
862
863     if (srcctx->md != NULL && !EVP_MD_up_ref(srcctx->md))
864         goto err;
865     dstctx->md = srcctx->md;
866
867     if (srcctx->mgf1_md != NULL && !EVP_MD_up_ref(srcctx->mgf1_md))
868         goto err;
869     dstctx->mgf1_md = srcctx->mgf1_md;
870
871     if (srcctx->mdctx != NULL) {
872         dstctx->mdctx = EVP_MD_CTX_new();
873         if (dstctx->mdctx == NULL
874                 || !EVP_MD_CTX_copy_ex(dstctx->mdctx, srcctx->mdctx))
875             goto err;
876     }
877
878     return dstctx;
879  err:
880     rsa_freectx(dstctx);
881     return NULL;
882 }
883
884 static int rsa_get_ctx_params(void *vprsactx, OSSL_PARAM *params)
885 {
886     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
887     OSSL_PARAM *p;
888
889     if (prsactx == NULL || params == NULL)
890         return 0;
891
892     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_ALGORITHM_ID);
893     if (p != NULL
894         && !OSSL_PARAM_set_octet_string(p, prsactx->aid, prsactx->aid_len))
895         return 0;
896
897     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
898     if (p != NULL)
899         switch (p->data_type) {
900         case OSSL_PARAM_INTEGER:
901             if (!OSSL_PARAM_set_int(p, prsactx->pad_mode))
902                 return 0;
903             break;
904         case OSSL_PARAM_UTF8_STRING:
905             {
906                 int i;
907                 const char *word = NULL;
908
909                 for (i = 0; padding_item[i].id != 0; i++) {
910                     if (prsactx->pad_mode == (int)padding_item[i].id) {
911                         word = padding_item[i].ptr;
912                         break;
913                     }
914                 }
915
916                 if (word != NULL) {
917                     if (!OSSL_PARAM_set_utf8_string(p, word))
918                         return 0;
919                 } else {
920                     ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
921                 }
922             }
923             break;
924         default:
925             return 0;
926         }
927
928     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_DIGEST);
929     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mdname))
930         return 0;
931
932     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
933     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mgf1_mdname))
934         return 0;
935
936     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
937     if (p != NULL) {
938         if (p->data_type == OSSL_PARAM_INTEGER) {
939             if (!OSSL_PARAM_set_int(p, prsactx->saltlen))
940                 return 0;
941         } else if (p->data_type == OSSL_PARAM_UTF8_STRING) {
942             switch (prsactx->saltlen) {
943             case RSA_PSS_SALTLEN_DIGEST:
944                 if (!OSSL_PARAM_set_utf8_string(p, "digest"))
945                     return 0;
946                 break;
947             case RSA_PSS_SALTLEN_MAX:
948                 if (!OSSL_PARAM_set_utf8_string(p, "max"))
949                     return 0;
950                 break;
951             case RSA_PSS_SALTLEN_AUTO:
952                 if (!OSSL_PARAM_set_utf8_string(p, "auto"))
953                     return 0;
954                 break;
955             default:
956                 if (BIO_snprintf(p->data, p->data_size, "%d", prsactx->saltlen)
957                     <= 0)
958                     return 0;
959                 break;
960             }
961         }
962     }
963
964     return 1;
965 }
966
967 static const OSSL_PARAM known_gettable_ctx_params[] = {
968     OSSL_PARAM_octet_string(OSSL_SIGNATURE_PARAM_ALGORITHM_ID, NULL, 0),
969     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
970     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
971     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
972     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
973     OSSL_PARAM_END
974 };
975
976 static const OSSL_PARAM *rsa_gettable_ctx_params(void)
977 {
978     return known_gettable_ctx_params;
979 }
980
981 static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
982 {
983     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
984     const OSSL_PARAM *p;
985
986     if (prsactx == NULL || params == NULL)
987         return 0;
988
989     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_DIGEST);
990     /* Not allowed during certain operations */
991     if (p != NULL && !prsactx->flag_allow_md)
992         return 0;
993     if (p != NULL) {
994         char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
995         char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
996         const OSSL_PARAM *propsp =
997             OSSL_PARAM_locate_const(params,
998                                     OSSL_SIGNATURE_PARAM_PROPERTIES);
999
1000         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
1001             return 0;
1002
1003         if (propsp == NULL)
1004             pmdprops = NULL;
1005         else if (!OSSL_PARAM_get_utf8_string(propsp,
1006                                              &pmdprops, sizeof(mdprops)))
1007             return 0;
1008
1009         if (rsa_pss_restricted(prsactx)) {
1010             /* TODO(3.0) figure out what to do for prsactx->md == NULL */
1011             if (prsactx->md == NULL || EVP_MD_is_a(prsactx->md, mdname))
1012                 return 1;
1013             ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
1014             return 0;
1015         }
1016
1017         /* non-PSS code follows */
1018         if (!rsa_setup_md(prsactx, mdname, pmdprops))
1019             return 0;
1020     }
1021
1022     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
1023     if (p != NULL) {
1024         int pad_mode = 0;
1025         const char *err_extra_text = NULL;
1026
1027         switch (p->data_type) {
1028         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
1029             if (!OSSL_PARAM_get_int(p, &pad_mode))
1030                 return 0;
1031             break;
1032         case OSSL_PARAM_UTF8_STRING:
1033             {
1034                 int i;
1035
1036                 if (p->data == NULL)
1037                     return 0;
1038
1039                 for (i = 0; padding_item[i].id != 0; i++) {
1040                     if (strcmp(p->data, padding_item[i].ptr) == 0) {
1041                         pad_mode = padding_item[i].id;
1042                         break;
1043                     }
1044                 }
1045             }
1046             break;
1047         default:
1048             return 0;
1049         }
1050
1051         switch (pad_mode) {
1052         case RSA_PKCS1_OAEP_PADDING:
1053             /*
1054              * OAEP padding is for asymmetric cipher only so is not compatible
1055              * with signature use.
1056              */
1057             err_extra_text = "OAEP padding not allowed for signing / verifying";
1058             goto bad_pad;
1059         case RSA_PKCS1_PSS_PADDING:
1060             if ((prsactx->operation
1061                  & (EVP_PKEY_OP_SIGN | EVP_PKEY_OP_VERIFY)) == 0) {
1062                 err_extra_text =
1063                     "PSS padding only allowed for sign and verify operations";
1064                 goto bad_pad;
1065             }
1066             if (prsactx->md == NULL
1067                 && !rsa_setup_md(prsactx, OSSL_DIGEST_NAME_SHA1, NULL)) {
1068                 return 0;
1069             }
1070             break;
1071         case RSA_PKCS1_PADDING:
1072             err_extra_text = "PKCS#1 padding not allowed with RSA-PSS";
1073             goto cont;
1074         case RSA_SSLV23_PADDING:
1075             err_extra_text = "SSLv3 padding not allowed with RSA-PSS";
1076             goto cont;
1077         case RSA_NO_PADDING:
1078             err_extra_text = "No padding not allowed with RSA-PSS";
1079             goto cont;
1080         case RSA_X931_PADDING:
1081             err_extra_text = "X.931 padding not allowed with RSA-PSS";
1082         cont:
1083             if (RSA_test_flags(prsactx->rsa,
1084                                RSA_FLAG_TYPE_MASK) == RSA_FLAG_TYPE_RSA)
1085                 break;
1086             /* FALLTHRU */
1087         default:
1088         bad_pad:
1089             if (err_extra_text == NULL)
1090                 ERR_raise(ERR_LIB_PROV,
1091                           PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE);
1092             else
1093                 ERR_raise_data(ERR_LIB_PROV,
1094                                PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE,
1095                                err_extra_text);
1096             return 0;
1097         }
1098         if (!rsa_check_padding(prsactx->mdnid, pad_mode))
1099             return 0;
1100         prsactx->pad_mode = pad_mode;
1101     }
1102
1103     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
1104     if (p != NULL) {
1105         int saltlen;
1106
1107         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
1108             ERR_raise_data(ERR_LIB_PROV, PROV_R_NOT_SUPPORTED,
1109                            "PSS saltlen can only be specified if "
1110                            "PSS padding has been specified first");
1111             return 0;
1112         }
1113
1114         switch (p->data_type) {
1115         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
1116             if (!OSSL_PARAM_get_int(p, &saltlen))
1117                 return 0;
1118             break;
1119         case OSSL_PARAM_UTF8_STRING:
1120             if (strcmp(p->data, "digest") == 0)
1121                 saltlen = RSA_PSS_SALTLEN_DIGEST;
1122             else if (strcmp(p->data, "max") == 0)
1123                 saltlen = RSA_PSS_SALTLEN_MAX;
1124             else if (strcmp(p->data, "auto") == 0)
1125                 saltlen = RSA_PSS_SALTLEN_AUTO;
1126             else
1127                 saltlen = atoi(p->data);
1128             break;
1129         default:
1130             return 0;
1131         }
1132
1133         /*
1134          * RSA_PSS_SALTLEN_MAX seems curiously named in this check.
1135          * Contrary to what it's name suggests, it's the currently
1136          * lowest saltlen number possible.
1137          */
1138         if (saltlen < RSA_PSS_SALTLEN_MAX) {
1139             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PSS_SALTLEN);
1140             return 0;
1141         }
1142
1143         if (rsa_pss_restricted(prsactx)) {
1144             switch (prsactx->saltlen) {
1145             case RSA_PSS_SALTLEN_AUTO:
1146                 if (prsactx->operation == EVP_PKEY_OP_VERIFY) {
1147                     ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PSS_SALTLEN);
1148                     return 0;
1149                 }
1150                 break;
1151             case RSA_PSS_SALTLEN_DIGEST:
1152                 if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
1153                     ERR_raise_data(ERR_LIB_PROV,
1154                                    PROV_R_PSS_SALTLEN_TOO_SMALL,
1155                                    "Should be more than %d, but would be "
1156                                    "set to match digest size (%d)",
1157                                    prsactx->min_saltlen,
1158                                    EVP_MD_size(prsactx->md));
1159                     return 0;
1160                 }
1161                 /* FALLTHRU */
1162             default:
1163                 if (saltlen >= 0 && saltlen < prsactx->min_saltlen) {
1164                     ERR_raise_data(ERR_LIB_PROV,
1165                                    PROV_R_PSS_SALTLEN_TOO_SMALL,
1166                                    "Should be more than %d, "
1167                                    "but would be set to %d",
1168                                    prsactx->min_saltlen, saltlen);
1169                     return 0;
1170                 }
1171             }
1172         }
1173
1174         prsactx->saltlen = saltlen;
1175     }
1176
1177     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
1178     if (p != NULL) {
1179         char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
1180         char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
1181         const OSSL_PARAM *propsp =
1182             OSSL_PARAM_locate_const(params,
1183                                     OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES);
1184
1185         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
1186             return 0;
1187
1188         if (propsp == NULL)
1189             pmdprops = NULL;
1190         else if (!OSSL_PARAM_get_utf8_string(propsp,
1191                                              &pmdprops, sizeof(mdprops)))
1192             return 0;
1193
1194         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
1195             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_MGF1_MD);
1196             return  0;
1197         }
1198
1199         if (rsa_pss_restricted(prsactx)) {
1200             /* TODO(3.0) figure out what to do for prsactx->mgf1_md == NULL */
1201             if (prsactx->mgf1_md == NULL
1202                 || EVP_MD_is_a(prsactx->mgf1_md, mdname))
1203                 return 1;
1204             ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
1205             return 0;
1206         }
1207
1208         /* non-PSS code follows */
1209         if (!rsa_setup_mgf1_md(prsactx, mdname, pmdprops))
1210             return 0;
1211     }
1212
1213     return 1;
1214 }
1215
1216 static const OSSL_PARAM known_settable_ctx_params[] = {
1217     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
1218     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
1219     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PROPERTIES, NULL, 0),
1220     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
1221     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES, NULL, 0),
1222     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
1223     OSSL_PARAM_END
1224 };
1225
1226 static const OSSL_PARAM *rsa_settable_ctx_params(void)
1227 {
1228     /*
1229      * TODO(3.0): Should this function return a different set of settable ctx
1230      * params if the ctx is being used for a DigestSign/DigestVerify? In that
1231      * case it is not allowed to set the digest size/digest name because the
1232      * digest is explicitly set as part of the init.
1233      */
1234     return known_settable_ctx_params;
1235 }
1236
1237 static int rsa_get_ctx_md_params(void *vprsactx, OSSL_PARAM *params)
1238 {
1239     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1240
1241     if (prsactx->mdctx == NULL)
1242         return 0;
1243
1244     return EVP_MD_CTX_get_params(prsactx->mdctx, params);
1245 }
1246
1247 static const OSSL_PARAM *rsa_gettable_ctx_md_params(void *vprsactx)
1248 {
1249     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1250
1251     if (prsactx->md == NULL)
1252         return 0;
1253
1254     return EVP_MD_gettable_ctx_params(prsactx->md);
1255 }
1256
1257 static int rsa_set_ctx_md_params(void *vprsactx, const OSSL_PARAM params[])
1258 {
1259     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1260
1261     if (prsactx->mdctx == NULL)
1262         return 0;
1263
1264     return EVP_MD_CTX_set_params(prsactx->mdctx, params);
1265 }
1266
1267 static const OSSL_PARAM *rsa_settable_ctx_md_params(void *vprsactx)
1268 {
1269     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1270
1271     if (prsactx->md == NULL)
1272         return 0;
1273
1274     return EVP_MD_settable_ctx_params(prsactx->md);
1275 }
1276
1277 const OSSL_DISPATCH rsa_signature_functions[] = {
1278     { OSSL_FUNC_SIGNATURE_NEWCTX, (void (*)(void))rsa_newctx },
1279     { OSSL_FUNC_SIGNATURE_SIGN_INIT, (void (*)(void))rsa_sign_init },
1280     { OSSL_FUNC_SIGNATURE_SIGN, (void (*)(void))rsa_sign },
1281     { OSSL_FUNC_SIGNATURE_VERIFY_INIT, (void (*)(void))rsa_verify_init },
1282     { OSSL_FUNC_SIGNATURE_VERIFY, (void (*)(void))rsa_verify },
1283     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER_INIT,
1284       (void (*)(void))rsa_verify_recover_init },
1285     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER,
1286       (void (*)(void))rsa_verify_recover },
1287     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_INIT,
1288       (void (*)(void))rsa_digest_sign_init },
1289     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_UPDATE,
1290       (void (*)(void))rsa_digest_signverify_update },
1291     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_FINAL,
1292       (void (*)(void))rsa_digest_sign_final },
1293     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_INIT,
1294       (void (*)(void))rsa_digest_verify_init },
1295     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_UPDATE,
1296       (void (*)(void))rsa_digest_signverify_update },
1297     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_FINAL,
1298       (void (*)(void))rsa_digest_verify_final },
1299     { OSSL_FUNC_SIGNATURE_FREECTX, (void (*)(void))rsa_freectx },
1300     { OSSL_FUNC_SIGNATURE_DUPCTX, (void (*)(void))rsa_dupctx },
1301     { OSSL_FUNC_SIGNATURE_GET_CTX_PARAMS, (void (*)(void))rsa_get_ctx_params },
1302     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_PARAMS,
1303       (void (*)(void))rsa_gettable_ctx_params },
1304     { OSSL_FUNC_SIGNATURE_SET_CTX_PARAMS, (void (*)(void))rsa_set_ctx_params },
1305     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_PARAMS,
1306       (void (*)(void))rsa_settable_ctx_params },
1307     { OSSL_FUNC_SIGNATURE_GET_CTX_MD_PARAMS,
1308       (void (*)(void))rsa_get_ctx_md_params },
1309     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_MD_PARAMS,
1310       (void (*)(void))rsa_gettable_ctx_md_params },
1311     { OSSL_FUNC_SIGNATURE_SET_CTX_MD_PARAMS,
1312       (void (*)(void))rsa_set_ctx_md_params },
1313     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_MD_PARAMS,
1314       (void (*)(void))rsa_settable_ctx_md_params },
1315     { 0, NULL }
1316 };