Add fips checks for rsa signatures.
[openssl.git] / providers / implementations / signature / rsa.c
1 /*
2  * Copyright 2019-2020 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 /*
11  * RSA low level APIs are deprecated for public use, but still ok for
12  * internal use.
13  */
14 #include "internal/deprecated.h"
15
16 #include <string.h>
17 #include <openssl/crypto.h>
18 #include <openssl/core_dispatch.h>
19 #include <openssl/core_names.h>
20 #include <openssl/err.h>
21 #include <openssl/rsa.h>
22 #include <openssl/params.h>
23 #include <openssl/evp.h>
24 #include "internal/cryptlib.h"
25 #include "internal/nelem.h"
26 #include "internal/sizes.h"
27 #include "crypto/rsa.h"
28 #include "prov/providercommon.h"
29 #include "prov/providercommonerr.h"
30 #include "prov/implementations.h"
31 #include "prov/provider_ctx.h"
32 #include "prov/der_rsa.h"
33 #include "prov/provider_util.h"
34
35 #define RSA_DEFAULT_DIGEST_NAME OSSL_DIGEST_NAME_SHA1
36
37 static OSSL_FUNC_signature_newctx_fn rsa_newctx;
38 static OSSL_FUNC_signature_sign_init_fn rsa_sign_init;
39 static OSSL_FUNC_signature_verify_init_fn rsa_verify_init;
40 static OSSL_FUNC_signature_verify_recover_init_fn rsa_verify_recover_init;
41 static OSSL_FUNC_signature_sign_fn rsa_sign;
42 static OSSL_FUNC_signature_verify_fn rsa_verify;
43 static OSSL_FUNC_signature_verify_recover_fn rsa_verify_recover;
44 static OSSL_FUNC_signature_digest_sign_init_fn rsa_digest_sign_init;
45 static OSSL_FUNC_signature_digest_sign_update_fn rsa_digest_signverify_update;
46 static OSSL_FUNC_signature_digest_sign_final_fn rsa_digest_sign_final;
47 static OSSL_FUNC_signature_digest_verify_init_fn rsa_digest_verify_init;
48 static OSSL_FUNC_signature_digest_verify_update_fn rsa_digest_signverify_update;
49 static OSSL_FUNC_signature_digest_verify_final_fn rsa_digest_verify_final;
50 static OSSL_FUNC_signature_freectx_fn rsa_freectx;
51 static OSSL_FUNC_signature_dupctx_fn rsa_dupctx;
52 static OSSL_FUNC_signature_get_ctx_params_fn rsa_get_ctx_params;
53 static OSSL_FUNC_signature_gettable_ctx_params_fn rsa_gettable_ctx_params;
54 static OSSL_FUNC_signature_set_ctx_params_fn rsa_set_ctx_params;
55 static OSSL_FUNC_signature_settable_ctx_params_fn rsa_settable_ctx_params;
56 static OSSL_FUNC_signature_get_ctx_md_params_fn rsa_get_ctx_md_params;
57 static OSSL_FUNC_signature_gettable_ctx_md_params_fn rsa_gettable_ctx_md_params;
58 static OSSL_FUNC_signature_set_ctx_md_params_fn rsa_set_ctx_md_params;
59 static OSSL_FUNC_signature_settable_ctx_md_params_fn rsa_settable_ctx_md_params;
60
61 static OSSL_ITEM padding_item[] = {
62     { RSA_PKCS1_PADDING,        OSSL_PKEY_RSA_PAD_MODE_PKCSV15 },
63     { RSA_SSLV23_PADDING,       OSSL_PKEY_RSA_PAD_MODE_SSLV23 },
64     { RSA_NO_PADDING,           OSSL_PKEY_RSA_PAD_MODE_NONE },
65     { RSA_X931_PADDING,         OSSL_PKEY_RSA_PAD_MODE_X931 },
66     { RSA_PKCS1_PSS_PADDING,    OSSL_PKEY_RSA_PAD_MODE_PSS },
67     { 0,                        NULL     }
68 };
69
70 /*
71  * What's passed as an actual key is defined by the KEYMGMT interface.
72  * We happen to know that our KEYMGMT simply passes RSA structures, so
73  * we use that here too.
74  */
75
76 typedef struct {
77     OPENSSL_CTX *libctx;
78     char *propq;
79     RSA *rsa;
80     int operation;
81
82     /*
83      * Flag to determine if the hash function can be changed (1) or not (0)
84      * Because it's dangerous to change during a DigestSign or DigestVerify
85      * operation, this flag is cleared by their Init function, and set again
86      * by their Final function.
87      */
88     unsigned int flag_allow_md : 1;
89
90     /* The Algorithm Identifier of the combined signature algorithm */
91     unsigned char aid_buf[128];
92     unsigned char *aid;
93     size_t  aid_len;
94
95     /* main digest */
96     EVP_MD *md;
97     EVP_MD_CTX *mdctx;
98     int mdnid;
99     char mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
100
101     /* RSA padding mode */
102     int pad_mode;
103     /* message digest for MGF1 */
104     EVP_MD *mgf1_md;
105     char mgf1_mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
106     /* PSS salt length */
107     int saltlen;
108     /* Minimum salt length or -1 if no PSS parameter restriction */
109     int min_saltlen;
110
111     /* Temp buffer */
112     unsigned char *tbuf;
113
114 } PROV_RSA_CTX;
115
116 static size_t rsa_get_md_size(const PROV_RSA_CTX *prsactx)
117 {
118     if (prsactx->md != NULL)
119         return EVP_MD_size(prsactx->md);
120     return 0;
121 }
122
123 static int rsa_get_md_nid_check(const PROV_RSA_CTX *ctx, const EVP_MD *md,
124                                 int sha1_allowed)
125 {
126     int mdnid = NID_undef;
127
128     #ifndef FIPS_MODULE
129     static const OSSL_ITEM name_to_nid[] = {
130         { NID_md5,       OSSL_DIGEST_NAME_MD5       },
131         { NID_md5_sha1,  OSSL_DIGEST_NAME_MD5_SHA1  },
132         { NID_md2,       OSSL_DIGEST_NAME_MD2       },
133         { NID_md4,       OSSL_DIGEST_NAME_MD4       },
134         { NID_mdc2,      OSSL_DIGEST_NAME_MDC2      },
135         { NID_ripemd160, OSSL_DIGEST_NAME_RIPEMD160 },
136     };
137     #endif
138
139     if (md == NULL)
140         goto end;
141
142     mdnid = ossl_prov_digest_get_approved_nid(md, sha1_allowed);
143
144     #ifndef FIPS_MODULE
145     if (mdnid == NID_undef)
146         mdnid = ossl_prov_digest_md_to_nid(md, name_to_nid,
147                                            OSSL_NELEM(name_to_nid));
148     #endif
149     end:
150     return mdnid;
151 }
152
153 static int rsa_get_md_nid(const PROV_RSA_CTX *ctx, const EVP_MD *md)
154 {
155     return rsa_get_md_nid_check(ctx, md, ctx->operation != EVP_PKEY_OP_SIGN);
156 }
157
158 static int rsa_get_md_mgf1_nid(const PROV_RSA_CTX *ctx, const EVP_MD *md)
159 {
160     /* The default for mgf1 is SHA1 - so allow this */
161     return rsa_get_md_nid_check(ctx, md, 1);
162 }
163
164 static int rsa_check_key_size(const PROV_RSA_CTX *prsactx)
165 {
166 #ifdef FIPS_MODULE
167     int sz = RSA_bits(prsactx->rsa);
168
169     return (prsactx->operation == EVP_PKEY_OP_SIGN) ? (sz >= 2048) : (sz >= 1024);
170 #else
171     return 1;
172 #endif
173 }
174
175 static int rsa_check_padding(int mdnid, int padding)
176 {
177     if (padding == RSA_NO_PADDING) {
178         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE);
179         return 0;
180     }
181
182     if (padding == RSA_X931_PADDING) {
183         if (RSA_X931_hash_id(mdnid) == -1) {
184             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_X931_DIGEST);
185             return 0;
186         }
187     }
188
189     return 1;
190 }
191
192 static int rsa_check_parameters(PROV_RSA_CTX *prsactx)
193 {
194     if (prsactx->pad_mode == RSA_PKCS1_PSS_PADDING) {
195         int max_saltlen;
196
197         /* See if minimum salt length exceeds maximum possible */
198         max_saltlen = RSA_size(prsactx->rsa) - EVP_MD_size(prsactx->md);
199         if ((RSA_bits(prsactx->rsa) & 0x7) == 1)
200             max_saltlen--;
201         if (prsactx->min_saltlen < 0 || prsactx->min_saltlen > max_saltlen) {
202             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_SALT_LENGTH);
203             return 0;
204         }
205     }
206     return 1;
207 }
208
209 static void *rsa_newctx(void *provctx, const char *propq)
210 {
211     PROV_RSA_CTX *prsactx = NULL;
212     char *propq_copy = NULL;
213
214     if (!ossl_prov_is_running())
215         return NULL;
216
217     if ((prsactx = OPENSSL_zalloc(sizeof(PROV_RSA_CTX))) == NULL
218         || (propq != NULL
219             && (propq_copy = OPENSSL_strdup(propq)) == NULL)) {
220         OPENSSL_free(prsactx);
221         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
222         return NULL;
223     }
224
225     prsactx->libctx = PROV_LIBRARY_CONTEXT_OF(provctx);
226     prsactx->flag_allow_md = 1;
227     prsactx->propq = propq_copy;
228     return prsactx;
229 }
230
231 /* True if PSS parameters are restricted */
232 #define rsa_pss_restricted(prsactx) (prsactx->min_saltlen != -1)
233
234 static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
235                         const char *mdprops)
236 {
237     if (mdprops == NULL)
238         mdprops = ctx->propq;
239
240     if (mdname != NULL) {
241         WPACKET pkt;
242         EVP_MD *md = EVP_MD_fetch(ctx->libctx, mdname, mdprops);
243         int md_nid = rsa_get_md_nid(ctx, md);
244         size_t mdname_len = strlen(mdname);
245
246         if (md == NULL
247             || md_nid == NID_undef
248             || !rsa_check_padding(md_nid, ctx->pad_mode)
249             || mdname_len >= sizeof(ctx->mdname)) {
250             if (md == NULL)
251                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
252                                "%s could not be fetched", mdname);
253             if (md_nid == NID_undef)
254                 ERR_raise_data(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED,
255                                "digest=%s", mdname);
256             if (mdname_len >= sizeof(ctx->mdname))
257                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
258                                "%s exceeds name buffer length", mdname);
259             EVP_MD_free(md);
260             return 0;
261         }
262
263         EVP_MD_CTX_free(ctx->mdctx);
264         EVP_MD_free(ctx->md);
265
266         /*
267          * TODO(3.0) Should we care about DER writing errors?
268          * All it really means is that for some reason, there's no
269          * AlgorithmIdentifier to be had (consider RSA with MD5-SHA1),
270          * but the operation itself is still valid, just as long as it's
271          * not used to construct anything that needs an AlgorithmIdentifier.
272          */
273         ctx->aid_len = 0;
274         if (WPACKET_init_der(&pkt, ctx->aid_buf, sizeof(ctx->aid_buf))
275             && DER_w_algorithmIdentifier_MDWithRSAEncryption(&pkt, -1, ctx->rsa,
276                                                              md_nid)
277             && WPACKET_finish(&pkt)) {
278             WPACKET_get_total_written(&pkt, &ctx->aid_len);
279             ctx->aid = WPACKET_get_curr(&pkt);
280         }
281         WPACKET_cleanup(&pkt);
282
283         ctx->mdctx = NULL;
284         ctx->md = md;
285         ctx->mdnid = md_nid;
286         OPENSSL_strlcpy(ctx->mdname, mdname, sizeof(ctx->mdname));
287     }
288
289     return 1;
290 }
291
292 static int rsa_setup_mgf1_md(PROV_RSA_CTX *ctx, const char *mdname,
293                              const char *mdprops)
294 {
295     size_t len;
296     EVP_MD *md = NULL;
297
298     if (mdprops == NULL)
299         mdprops = ctx->propq;
300
301     if (ctx->mgf1_mdname[0] != '\0')
302         EVP_MD_free(ctx->mgf1_md);
303
304     if ((md = EVP_MD_fetch(ctx->libctx, mdname, mdprops)) == NULL) {
305         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
306                        "%s could not be fetched", mdname);
307         return 0;
308     }
309     if (rsa_get_md_mgf1_nid(ctx, md) == NID_undef) {
310         ERR_raise_data(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED,
311                        "digest=%s", mdname);
312         EVP_MD_free(md);
313         return 0;
314     }
315     ctx->mgf1_md = md;
316     len = OPENSSL_strlcpy(ctx->mgf1_mdname, mdname, sizeof(ctx->mgf1_mdname));
317     if (len >= sizeof(ctx->mgf1_mdname)) {
318         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
319                        "%s exceeds name buffer length", mdname);
320         return 0;
321     }
322
323     return 1;
324 }
325
326 static int rsa_signverify_init(void *vprsactx, void *vrsa, int operation)
327 {
328     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
329
330     if (!ossl_prov_is_running())
331         return 0;
332
333     if (prsactx == NULL || vrsa == NULL || !RSA_up_ref(vrsa))
334         return 0;
335
336     RSA_free(prsactx->rsa);
337     prsactx->rsa = vrsa;
338     prsactx->operation = operation;
339
340     if (!rsa_check_key_size(prsactx)) {
341         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_KEY_LENGTH);
342         return 0;
343     }
344
345     /* Maximum for sign, auto for verify */
346     prsactx->saltlen = RSA_PSS_SALTLEN_AUTO;
347     prsactx->min_saltlen = -1;
348
349     switch (RSA_test_flags(prsactx->rsa, RSA_FLAG_TYPE_MASK)) {
350     case RSA_FLAG_TYPE_RSA:
351         prsactx->pad_mode = RSA_PKCS1_PADDING;
352         break;
353     case RSA_FLAG_TYPE_RSASSAPSS:
354         prsactx->pad_mode = RSA_PKCS1_PSS_PADDING;
355
356         {
357             const RSA_PSS_PARAMS_30 *pss =
358                 rsa_get0_pss_params_30(prsactx->rsa);
359
360             if (!rsa_pss_params_30_is_unrestricted(pss)) {
361                 int md_nid = rsa_pss_params_30_hashalg(pss);
362                 int mgf1md_nid = rsa_pss_params_30_maskgenhashalg(pss);
363                 int min_saltlen = rsa_pss_params_30_saltlen(pss);
364                 const char *mdname, *mgf1mdname;
365                 size_t len;
366
367                 mdname = rsa_oaeppss_nid2name(md_nid);
368                 mgf1mdname = rsa_oaeppss_nid2name(mgf1md_nid);
369                 prsactx->min_saltlen = min_saltlen;
370
371                 if (mdname == NULL) {
372                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
373                                    "PSS restrictions lack hash algorithm");
374                     return 0;
375                 }
376                 if (mgf1mdname == NULL) {
377                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
378                                    "PSS restrictions lack MGF1 hash algorithm");
379                     return 0;
380                 }
381
382                 len = OPENSSL_strlcpy(prsactx->mdname, mdname,
383                                       sizeof(prsactx->mdname));
384                 if (len >= sizeof(prsactx->mdname)) {
385                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
386                                    "hash algorithm name too long");
387                     return 0;
388                 }
389                 len = OPENSSL_strlcpy(prsactx->mgf1_mdname, mgf1mdname,
390                                       sizeof(prsactx->mgf1_mdname));
391                 if (len >= sizeof(prsactx->mgf1_mdname)) {
392                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
393                                    "MGF1 hash algorithm name too long");
394                     return 0;
395                 }
396                 prsactx->saltlen = min_saltlen;
397
398                 return rsa_setup_md(prsactx, mdname, prsactx->propq)
399                     && rsa_setup_mgf1_md(prsactx, mgf1mdname, prsactx->propq)
400                     && rsa_check_parameters(prsactx);
401             }
402         }
403
404         break;
405     default:
406         ERR_raise(ERR_LIB_RSA, PROV_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
407         return 0;
408     }
409
410     return 1;
411 }
412
413 static int setup_tbuf(PROV_RSA_CTX *ctx)
414 {
415     if (ctx->tbuf != NULL)
416         return 1;
417     if ((ctx->tbuf = OPENSSL_malloc(RSA_size(ctx->rsa))) == NULL) {
418         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
419         return 0;
420     }
421     return 1;
422 }
423
424 static void clean_tbuf(PROV_RSA_CTX *ctx)
425 {
426     if (ctx->tbuf != NULL)
427         OPENSSL_cleanse(ctx->tbuf, RSA_size(ctx->rsa));
428 }
429
430 static void free_tbuf(PROV_RSA_CTX *ctx)
431 {
432     clean_tbuf(ctx);
433     OPENSSL_free(ctx->tbuf);
434     ctx->tbuf = NULL;
435 }
436
437 static int rsa_sign_init(void *vprsactx, void *vrsa)
438 {
439     if (!ossl_prov_is_running())
440         return 0;
441     return rsa_signverify_init(vprsactx, vrsa, EVP_PKEY_OP_SIGN);
442 }
443
444 static int rsa_sign(void *vprsactx, unsigned char *sig, size_t *siglen,
445                     size_t sigsize, const unsigned char *tbs, size_t tbslen)
446 {
447     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
448     int ret;
449     size_t rsasize = RSA_size(prsactx->rsa);
450     size_t mdsize = rsa_get_md_size(prsactx);
451
452     if (!ossl_prov_is_running())
453         return 0;
454
455     if (sig == NULL) {
456         *siglen = rsasize;
457         return 1;
458     }
459
460     if (sigsize < rsasize) {
461         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_SIGNATURE_SIZE,
462                        "is %zu, should be at least %zu", sigsize, rsasize);
463         return 0;
464     }
465
466     if (mdsize != 0) {
467         if (tbslen != mdsize) {
468             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH);
469             return 0;
470         }
471
472 #ifndef FIPS_MODULE
473         if (EVP_MD_is_a(prsactx->md, OSSL_DIGEST_NAME_MDC2)) {
474             unsigned int sltmp;
475
476             if (prsactx->pad_mode != RSA_PKCS1_PADDING) {
477                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
478                                "only PKCS#1 padding supported with MDC2");
479                 return 0;
480             }
481             ret = RSA_sign_ASN1_OCTET_STRING(0, tbs, tbslen, sig, &sltmp,
482                                              prsactx->rsa);
483
484             if (ret <= 0) {
485                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
486                 return 0;
487             }
488             ret = sltmp;
489             goto end;
490         }
491 #endif
492         switch (prsactx->pad_mode) {
493         case RSA_X931_PADDING:
494             if ((size_t)RSA_size(prsactx->rsa) < tbslen + 1) {
495                 ERR_raise_data(ERR_LIB_PROV, PROV_R_KEY_SIZE_TOO_SMALL,
496                                "RSA key size = %d, expected minimum = %d",
497                                RSA_size(prsactx->rsa), tbslen + 1);
498                 return 0;
499             }
500             if (!setup_tbuf(prsactx)) {
501                 ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
502                 return 0;
503             }
504             memcpy(prsactx->tbuf, tbs, tbslen);
505             prsactx->tbuf[tbslen] = RSA_X931_hash_id(prsactx->mdnid);
506             ret = RSA_private_encrypt(tbslen + 1, prsactx->tbuf,
507                                       sig, prsactx->rsa, RSA_X931_PADDING);
508             clean_tbuf(prsactx);
509             break;
510
511         case RSA_PKCS1_PADDING:
512             {
513                 unsigned int sltmp;
514
515                 ret = RSA_sign(prsactx->mdnid, tbs, tbslen, sig, &sltmp,
516                                prsactx->rsa);
517                 if (ret <= 0) {
518                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
519                     return 0;
520                 }
521                 ret = sltmp;
522             }
523             break;
524
525         case RSA_PKCS1_PSS_PADDING:
526             /* Check PSS restrictions */
527             if (rsa_pss_restricted(prsactx)) {
528                 switch (prsactx->saltlen) {
529                 case RSA_PSS_SALTLEN_DIGEST:
530                     if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
531                         ERR_raise_data(ERR_LIB_PROV,
532                                        PROV_R_PSS_SALTLEN_TOO_SMALL,
533                                        "minimum salt length set to %d, "
534                                        "but the digest only gives %d",
535                                        prsactx->min_saltlen,
536                                        EVP_MD_size(prsactx->md));
537                         return 0;
538                     }
539                     /* FALLTHRU */
540                 default:
541                     if (prsactx->saltlen >= 0
542                         && prsactx->saltlen < prsactx->min_saltlen) {
543                         ERR_raise_data(ERR_LIB_PROV,
544                                        PROV_R_PSS_SALTLEN_TOO_SMALL,
545                                        "minimum salt length set to %d, but the"
546                                        "actual salt length is only set to %d",
547                                        prsactx->min_saltlen,
548                                        prsactx->saltlen);
549                         return 0;
550                     }
551                     break;
552                 }
553             }
554             if (!setup_tbuf(prsactx))
555                 return 0;
556             if (!RSA_padding_add_PKCS1_PSS_mgf1(prsactx->rsa,
557                                                 prsactx->tbuf, tbs,
558                                                 prsactx->md, prsactx->mgf1_md,
559                                                 prsactx->saltlen)) {
560                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
561                 return 0;
562             }
563             ret = RSA_private_encrypt(RSA_size(prsactx->rsa), prsactx->tbuf,
564                                       sig, prsactx->rsa, RSA_NO_PADDING);
565             clean_tbuf(prsactx);
566             break;
567
568         default:
569             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
570                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
571             return 0;
572         }
573     } else {
574         ret = RSA_private_encrypt(tbslen, tbs, sig, prsactx->rsa,
575                                   prsactx->pad_mode);
576     }
577
578 #ifndef FIPS_MODULE
579  end:
580 #endif
581     if (ret <= 0) {
582         ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
583         return 0;
584     }
585
586     *siglen = ret;
587     return 1;
588 }
589
590 static int rsa_verify_recover_init(void *vprsactx, void *vrsa)
591 {
592     if (!ossl_prov_is_running())
593         return 0;
594     return rsa_signverify_init(vprsactx, vrsa, EVP_PKEY_OP_VERIFYRECOVER);
595 }
596
597 static int rsa_verify_recover(void *vprsactx,
598                               unsigned char *rout,
599                               size_t *routlen,
600                               size_t routsize,
601                               const unsigned char *sig,
602                               size_t siglen)
603 {
604     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
605     int ret;
606
607     if (!ossl_prov_is_running())
608         return 0;
609
610     if (rout == NULL) {
611         *routlen = RSA_size(prsactx->rsa);
612         return 1;
613     }
614
615     if (prsactx->md != NULL) {
616         switch (prsactx->pad_mode) {
617         case RSA_X931_PADDING:
618             if (!setup_tbuf(prsactx))
619                 return 0;
620             ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
621                                      RSA_X931_PADDING);
622             if (ret < 1) {
623                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
624                 return 0;
625             }
626             ret--;
627             if (prsactx->tbuf[ret] != RSA_X931_hash_id(prsactx->mdnid)) {
628                 ERR_raise(ERR_LIB_PROV, PROV_R_ALGORITHM_MISMATCH);
629                 return 0;
630             }
631             if (ret != EVP_MD_size(prsactx->md)) {
632                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
633                                "Should be %d, but got %d",
634                                EVP_MD_size(prsactx->md), ret);
635                 return 0;
636             }
637
638             *routlen = ret;
639             if (rout != prsactx->tbuf) {
640                 if (routsize < (size_t)ret) {
641                     ERR_raise_data(ERR_LIB_PROV, PROV_R_OUTPUT_BUFFER_TOO_SMALL,
642                                    "buffer size is %d, should be %d",
643                                    routsize, ret);
644                     return 0;
645                 }
646                 memcpy(rout, prsactx->tbuf, ret);
647             }
648             break;
649
650         case RSA_PKCS1_PADDING:
651             {
652                 size_t sltmp;
653
654                 ret = int_rsa_verify(prsactx->mdnid, NULL, 0, rout, &sltmp,
655                                      sig, siglen, prsactx->rsa);
656                 if (ret <= 0) {
657                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
658                     return 0;
659                 }
660                 ret = sltmp;
661             }
662             break;
663
664         default:
665             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
666                            "Only X.931 or PKCS#1 v1.5 padding allowed");
667             return 0;
668         }
669     } else {
670         ret = RSA_public_decrypt(siglen, sig, rout, prsactx->rsa,
671                                  prsactx->pad_mode);
672         if (ret < 0) {
673             ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
674             return 0;
675         }
676     }
677     *routlen = ret;
678     return 1;
679 }
680
681 static int rsa_verify_init(void *vprsactx, void *vrsa)
682 {
683     if (!ossl_prov_is_running())
684         return 0;
685     return rsa_signverify_init(vprsactx, vrsa, EVP_PKEY_OP_VERIFY);
686 }
687
688 static int rsa_verify(void *vprsactx, const unsigned char *sig, size_t siglen,
689                       const unsigned char *tbs, size_t tbslen)
690 {
691     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
692     size_t rslen;
693
694     if (!ossl_prov_is_running())
695         return 0;
696     if (prsactx->md != NULL) {
697         switch (prsactx->pad_mode) {
698         case RSA_PKCS1_PADDING:
699             if (!RSA_verify(prsactx->mdnid, tbs, tbslen, sig, siglen,
700                             prsactx->rsa)) {
701                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
702                 return 0;
703             }
704             return 1;
705         case RSA_X931_PADDING:
706             if (!setup_tbuf(prsactx))
707                 return 0;
708             if (rsa_verify_recover(prsactx, prsactx->tbuf, &rslen, 0,
709                                    sig, siglen) <= 0)
710                 return 0;
711             break;
712         case RSA_PKCS1_PSS_PADDING:
713             {
714                 int ret;
715                 size_t mdsize;
716
717                 /*
718                  * We need to check this for the RSA_verify_PKCS1_PSS_mgf1()
719                  * call
720                  */
721                 mdsize = rsa_get_md_size(prsactx);
722                 if (tbslen != mdsize) {
723                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
724                                    "Should be %d, but got %d",
725                                    mdsize, tbslen);
726                     return 0;
727                 }
728
729                 if (!setup_tbuf(prsactx))
730                     return 0;
731                 ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf,
732                                          prsactx->rsa, RSA_NO_PADDING);
733                 if (ret <= 0) {
734                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
735                     return 0;
736                 }
737                 ret = RSA_verify_PKCS1_PSS_mgf1(prsactx->rsa, tbs,
738                                                 prsactx->md, prsactx->mgf1_md,
739                                                 prsactx->tbuf,
740                                                 prsactx->saltlen);
741                 if (ret <= 0) {
742                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
743                     return 0;
744                 }
745                 return 1;
746             }
747         default:
748             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
749                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
750             return 0;
751         }
752     } else {
753         if (!setup_tbuf(prsactx))
754             return 0;
755         rslen = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
756                                    prsactx->pad_mode);
757         if (rslen == 0) {
758             ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
759             return 0;
760         }
761     }
762
763     if ((rslen != tbslen) || memcmp(tbs, prsactx->tbuf, rslen))
764         return 0;
765
766     return 1;
767 }
768
769 static int rsa_digest_signverify_init(void *vprsactx, const char *mdname,
770                                       void *vrsa, int operation)
771 {
772     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
773
774     if (!ossl_prov_is_running())
775         return 0;
776
777     if (prsactx != NULL)
778         prsactx->flag_allow_md = 0;
779     if (!rsa_signverify_init(vprsactx, vrsa, operation)
780         || !rsa_setup_md(prsactx, mdname, NULL)) /* TODO RL */
781         return 0;
782
783     prsactx->mdctx = EVP_MD_CTX_new();
784     if (prsactx->mdctx == NULL) {
785         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
786         goto error;
787     }
788
789     if (!EVP_DigestInit_ex(prsactx->mdctx, prsactx->md, NULL))
790         goto error;
791
792     return 1;
793
794  error:
795     EVP_MD_CTX_free(prsactx->mdctx);
796     EVP_MD_free(prsactx->md);
797     prsactx->mdctx = NULL;
798     prsactx->md = NULL;
799     return 0;
800 }
801
802 static int rsa_digest_signverify_update(void *vprsactx,
803                                         const unsigned char *data,
804                                         size_t datalen)
805 {
806     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
807
808     if (prsactx == NULL || prsactx->mdctx == NULL)
809         return 0;
810
811     return EVP_DigestUpdate(prsactx->mdctx, data, datalen);
812 }
813
814 static int rsa_digest_sign_init(void *vprsactx, const char *mdname,
815                                 void *vrsa)
816 {
817     if (!ossl_prov_is_running())
818         return 0;
819     return rsa_digest_signverify_init(vprsactx, mdname, vrsa,
820                                       EVP_PKEY_OP_SIGN);
821 }
822
823 static int rsa_digest_sign_final(void *vprsactx, unsigned char *sig,
824                                  size_t *siglen, size_t sigsize)
825 {
826     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
827     unsigned char digest[EVP_MAX_MD_SIZE];
828     unsigned int dlen = 0;
829
830     if (!ossl_prov_is_running() || prsactx == NULL)
831         return 0;
832     prsactx->flag_allow_md = 1;
833     if (prsactx->mdctx == NULL)
834         return 0;
835     /*
836      * If sig is NULL then we're just finding out the sig size. Other fields
837      * are ignored. Defer to rsa_sign.
838      */
839     if (sig != NULL) {
840         /*
841          * The digests used here are all known (see rsa_get_md_nid()), so they
842          * should not exceed the internal buffer size of EVP_MAX_MD_SIZE.
843          */
844         if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
845             return 0;
846     }
847
848     return rsa_sign(vprsactx, sig, siglen, sigsize, digest, (size_t)dlen);
849 }
850
851 static int rsa_digest_verify_init(void *vprsactx, const char *mdname,
852                                   void *vrsa)
853 {
854     if (!ossl_prov_is_running())
855         return 0;
856     return rsa_digest_signverify_init(vprsactx, mdname, vrsa,
857                                       EVP_PKEY_OP_VERIFY);
858 }
859
860 int rsa_digest_verify_final(void *vprsactx, const unsigned char *sig,
861                             size_t siglen)
862 {
863     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
864     unsigned char digest[EVP_MAX_MD_SIZE];
865     unsigned int dlen = 0;
866
867     if (!ossl_prov_is_running())
868         return 0;
869
870     if (prsactx == NULL)
871         return 0;
872     prsactx->flag_allow_md = 1;
873     if (prsactx->mdctx == NULL)
874         return 0;
875
876     /*
877      * The digests used here are all known (see rsa_get_md_nid()), so they
878      * should not exceed the internal buffer size of EVP_MAX_MD_SIZE.
879      */
880     if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
881         return 0;
882
883     return rsa_verify(vprsactx, sig, siglen, digest, (size_t)dlen);
884 }
885
886 static void rsa_freectx(void *vprsactx)
887 {
888     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
889
890     if (prsactx == NULL)
891         return;
892
893     EVP_MD_CTX_free(prsactx->mdctx);
894     EVP_MD_free(prsactx->md);
895     EVP_MD_free(prsactx->mgf1_md);
896     OPENSSL_free(prsactx->propq);
897     free_tbuf(prsactx);
898     RSA_free(prsactx->rsa);
899
900     OPENSSL_clear_free(prsactx, sizeof(*prsactx));
901 }
902
903 static void *rsa_dupctx(void *vprsactx)
904 {
905     PROV_RSA_CTX *srcctx = (PROV_RSA_CTX *)vprsactx;
906     PROV_RSA_CTX *dstctx;
907
908     if (!ossl_prov_is_running())
909         return NULL;
910
911     dstctx = OPENSSL_zalloc(sizeof(*srcctx));
912     if (dstctx == NULL) {
913         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
914         return NULL;
915     }
916
917     *dstctx = *srcctx;
918     dstctx->rsa = NULL;
919     dstctx->md = NULL;
920     dstctx->mdctx = NULL;
921     dstctx->tbuf = NULL;
922
923     if (srcctx->rsa != NULL && !RSA_up_ref(srcctx->rsa))
924         goto err;
925     dstctx->rsa = srcctx->rsa;
926
927     if (srcctx->md != NULL && !EVP_MD_up_ref(srcctx->md))
928         goto err;
929     dstctx->md = srcctx->md;
930
931     if (srcctx->mgf1_md != NULL && !EVP_MD_up_ref(srcctx->mgf1_md))
932         goto err;
933     dstctx->mgf1_md = srcctx->mgf1_md;
934
935     if (srcctx->mdctx != NULL) {
936         dstctx->mdctx = EVP_MD_CTX_new();
937         if (dstctx->mdctx == NULL
938                 || !EVP_MD_CTX_copy_ex(dstctx->mdctx, srcctx->mdctx))
939             goto err;
940     }
941
942     return dstctx;
943  err:
944     rsa_freectx(dstctx);
945     return NULL;
946 }
947
948 static int rsa_get_ctx_params(void *vprsactx, OSSL_PARAM *params)
949 {
950     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
951     OSSL_PARAM *p;
952
953     if (prsactx == NULL || params == NULL)
954         return 0;
955
956     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_ALGORITHM_ID);
957     if (p != NULL
958         && !OSSL_PARAM_set_octet_string(p, prsactx->aid, prsactx->aid_len))
959         return 0;
960
961     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
962     if (p != NULL)
963         switch (p->data_type) {
964         case OSSL_PARAM_INTEGER:
965             if (!OSSL_PARAM_set_int(p, prsactx->pad_mode))
966                 return 0;
967             break;
968         case OSSL_PARAM_UTF8_STRING:
969             {
970                 int i;
971                 const char *word = NULL;
972
973                 for (i = 0; padding_item[i].id != 0; i++) {
974                     if (prsactx->pad_mode == (int)padding_item[i].id) {
975                         word = padding_item[i].ptr;
976                         break;
977                     }
978                 }
979
980                 if (word != NULL) {
981                     if (!OSSL_PARAM_set_utf8_string(p, word))
982                         return 0;
983                 } else {
984                     ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
985                 }
986             }
987             break;
988         default:
989             return 0;
990         }
991
992     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_DIGEST);
993     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mdname))
994         return 0;
995
996     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
997     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mgf1_mdname))
998         return 0;
999
1000     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
1001     if (p != NULL) {
1002         if (p->data_type == OSSL_PARAM_INTEGER) {
1003             if (!OSSL_PARAM_set_int(p, prsactx->saltlen))
1004                 return 0;
1005         } else if (p->data_type == OSSL_PARAM_UTF8_STRING) {
1006             const char *value = NULL;
1007
1008             switch (prsactx->saltlen) {
1009             case RSA_PSS_SALTLEN_DIGEST:
1010                 value = OSSL_PKEY_RSA_PSS_SALT_LEN_DIGEST;
1011                 break;
1012             case RSA_PSS_SALTLEN_MAX:
1013                 value = OSSL_PKEY_RSA_PSS_SALT_LEN_MAX;
1014                 break;
1015             case RSA_PSS_SALTLEN_AUTO:
1016                 value = OSSL_PKEY_RSA_PSS_SALT_LEN_AUTO;
1017                 break;
1018             default:
1019                 {
1020                     int len = BIO_snprintf(p->data, p->data_size, "%d",
1021                                            prsactx->saltlen);
1022
1023                     if (len <= 0)
1024                         return 0;
1025                     p->return_size = len;
1026                     break;
1027                 }
1028             }
1029             if (value != NULL
1030                 && !OSSL_PARAM_set_utf8_string(p, value))
1031                 return 0;
1032         }
1033     }
1034
1035     return 1;
1036 }
1037
1038 static const OSSL_PARAM known_gettable_ctx_params[] = {
1039     OSSL_PARAM_octet_string(OSSL_SIGNATURE_PARAM_ALGORITHM_ID, NULL, 0),
1040     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
1041     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
1042     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
1043     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
1044     OSSL_PARAM_END
1045 };
1046
1047 static const OSSL_PARAM *rsa_gettable_ctx_params(ossl_unused void *vctx)
1048 {
1049     return known_gettable_ctx_params;
1050 }
1051
1052 static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
1053 {
1054     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1055     const OSSL_PARAM *p;
1056
1057     if (prsactx == NULL || params == NULL)
1058         return 0;
1059
1060     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_DIGEST);
1061     /* Not allowed during certain operations */
1062     if (p != NULL && !prsactx->flag_allow_md)
1063         return 0;
1064     if (p != NULL) {
1065         char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
1066         char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
1067         const OSSL_PARAM *propsp =
1068             OSSL_PARAM_locate_const(params,
1069                                     OSSL_SIGNATURE_PARAM_PROPERTIES);
1070
1071         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
1072             return 0;
1073
1074         if (propsp == NULL)
1075             pmdprops = NULL;
1076         else if (!OSSL_PARAM_get_utf8_string(propsp,
1077                                              &pmdprops, sizeof(mdprops)))
1078             return 0;
1079
1080         if (rsa_pss_restricted(prsactx)) {
1081             /* TODO(3.0) figure out what to do for prsactx->md == NULL */
1082             if (prsactx->md == NULL || EVP_MD_is_a(prsactx->md, mdname))
1083                 return 1;
1084             ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
1085             return 0;
1086         }
1087
1088         /* non-PSS code follows */
1089         if (!rsa_setup_md(prsactx, mdname, pmdprops))
1090             return 0;
1091     }
1092
1093     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
1094     if (p != NULL) {
1095         int pad_mode = 0;
1096         const char *err_extra_text = NULL;
1097
1098         switch (p->data_type) {
1099         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
1100             if (!OSSL_PARAM_get_int(p, &pad_mode))
1101                 return 0;
1102             break;
1103         case OSSL_PARAM_UTF8_STRING:
1104             {
1105                 int i;
1106
1107                 if (p->data == NULL)
1108                     return 0;
1109
1110                 for (i = 0; padding_item[i].id != 0; i++) {
1111                     if (strcmp(p->data, padding_item[i].ptr) == 0) {
1112                         pad_mode = padding_item[i].id;
1113                         break;
1114                     }
1115                 }
1116             }
1117             break;
1118         default:
1119             return 0;
1120         }
1121
1122         switch (pad_mode) {
1123         case RSA_PKCS1_OAEP_PADDING:
1124             /*
1125              * OAEP padding is for asymmetric cipher only so is not compatible
1126              * with signature use.
1127              */
1128             err_extra_text = "OAEP padding not allowed for signing / verifying";
1129             goto bad_pad;
1130         case RSA_PKCS1_PSS_PADDING:
1131             if ((prsactx->operation
1132                  & (EVP_PKEY_OP_SIGN | EVP_PKEY_OP_VERIFY)) == 0) {
1133                 err_extra_text =
1134                     "PSS padding only allowed for sign and verify operations";
1135                 goto bad_pad;
1136             }
1137             if (prsactx->md == NULL
1138                 && !rsa_setup_md(prsactx, RSA_DEFAULT_DIGEST_NAME, NULL)) {
1139                 return 0;
1140             }
1141             break;
1142         case RSA_PKCS1_PADDING:
1143             err_extra_text = "PKCS#1 padding not allowed with RSA-PSS";
1144             goto cont;
1145         case RSA_SSLV23_PADDING:
1146             err_extra_text = "SSLv3 padding not allowed with RSA-PSS";
1147             goto cont;
1148         case RSA_NO_PADDING:
1149             err_extra_text = "No padding not allowed with RSA-PSS";
1150             goto cont;
1151         case RSA_X931_PADDING:
1152             err_extra_text = "X.931 padding not allowed with RSA-PSS";
1153         cont:
1154             if (RSA_test_flags(prsactx->rsa,
1155                                RSA_FLAG_TYPE_MASK) == RSA_FLAG_TYPE_RSA)
1156                 break;
1157             /* FALLTHRU */
1158         default:
1159         bad_pad:
1160             if (err_extra_text == NULL)
1161                 ERR_raise(ERR_LIB_PROV,
1162                           PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE);
1163             else
1164                 ERR_raise_data(ERR_LIB_PROV,
1165                                PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE,
1166                                err_extra_text);
1167             return 0;
1168         }
1169         if (!rsa_check_padding(prsactx->mdnid, pad_mode))
1170             return 0;
1171         prsactx->pad_mode = pad_mode;
1172     }
1173
1174     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
1175     if (p != NULL) {
1176         int saltlen;
1177
1178         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
1179             ERR_raise_data(ERR_LIB_PROV, PROV_R_NOT_SUPPORTED,
1180                            "PSS saltlen can only be specified if "
1181                            "PSS padding has been specified first");
1182             return 0;
1183         }
1184
1185         switch (p->data_type) {
1186         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
1187             if (!OSSL_PARAM_get_int(p, &saltlen))
1188                 return 0;
1189             break;
1190         case OSSL_PARAM_UTF8_STRING:
1191             if (strcmp(p->data, OSSL_PKEY_RSA_PSS_SALT_LEN_DIGEST) == 0)
1192                 saltlen = RSA_PSS_SALTLEN_DIGEST;
1193             else if (strcmp(p->data, OSSL_PKEY_RSA_PSS_SALT_LEN_MAX) == 0)
1194                 saltlen = RSA_PSS_SALTLEN_MAX;
1195             else if (strcmp(p->data, OSSL_PKEY_RSA_PSS_SALT_LEN_AUTO) == 0)
1196                 saltlen = RSA_PSS_SALTLEN_AUTO;
1197             else
1198                 saltlen = atoi(p->data);
1199             break;
1200         default:
1201             return 0;
1202         }
1203
1204         /*
1205          * RSA_PSS_SALTLEN_MAX seems curiously named in this check.
1206          * Contrary to what it's name suggests, it's the currently
1207          * lowest saltlen number possible.
1208          */
1209         if (saltlen < RSA_PSS_SALTLEN_MAX) {
1210             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PSS_SALTLEN);
1211             return 0;
1212         }
1213
1214         if (rsa_pss_restricted(prsactx)) {
1215             switch (saltlen) {
1216             case RSA_PSS_SALTLEN_AUTO:
1217                 if (prsactx->operation == EVP_PKEY_OP_VERIFY) {
1218                     ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PSS_SALTLEN);
1219                     return 0;
1220                 }
1221                 break;
1222             case RSA_PSS_SALTLEN_DIGEST:
1223                 if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
1224                     ERR_raise_data(ERR_LIB_PROV,
1225                                    PROV_R_PSS_SALTLEN_TOO_SMALL,
1226                                    "Should be more than %d, but would be "
1227                                    "set to match digest size (%d)",
1228                                    prsactx->min_saltlen,
1229                                    EVP_MD_size(prsactx->md));
1230                     return 0;
1231                 }
1232                 break;
1233             default:
1234                 if (saltlen >= 0 && saltlen < prsactx->min_saltlen) {
1235                     ERR_raise_data(ERR_LIB_PROV,
1236                                    PROV_R_PSS_SALTLEN_TOO_SMALL,
1237                                    "Should be more than %d, "
1238                                    "but would be set to %d",
1239                                    prsactx->min_saltlen, saltlen);
1240                     return 0;
1241                 }
1242             }
1243         }
1244
1245         prsactx->saltlen = saltlen;
1246     }
1247
1248     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
1249     if (p != NULL) {
1250         char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
1251         char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
1252         const OSSL_PARAM *propsp =
1253             OSSL_PARAM_locate_const(params,
1254                                     OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES);
1255
1256         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
1257             return 0;
1258
1259         if (propsp == NULL)
1260             pmdprops = NULL;
1261         else if (!OSSL_PARAM_get_utf8_string(propsp,
1262                                              &pmdprops, sizeof(mdprops)))
1263             return 0;
1264
1265         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
1266             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_MGF1_MD);
1267             return  0;
1268         }
1269
1270         if (rsa_pss_restricted(prsactx)) {
1271             /* TODO(3.0) figure out what to do for prsactx->mgf1_md == NULL */
1272             if (prsactx->mgf1_md == NULL
1273                 || EVP_MD_is_a(prsactx->mgf1_md, mdname))
1274                 return 1;
1275             ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
1276             return 0;
1277         }
1278
1279         /* non-PSS code follows */
1280         if (!rsa_setup_mgf1_md(prsactx, mdname, pmdprops))
1281             return 0;
1282     }
1283
1284     return 1;
1285 }
1286
1287 static const OSSL_PARAM known_settable_ctx_params[] = {
1288     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
1289     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
1290     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PROPERTIES, NULL, 0),
1291     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
1292     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES, NULL, 0),
1293     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
1294     OSSL_PARAM_END
1295 };
1296
1297 static const OSSL_PARAM *rsa_settable_ctx_params(void *provctx)
1298 {
1299     /*
1300      * TODO(3.0): Should this function return a different set of settable ctx
1301      * params if the ctx is being used for a DigestSign/DigestVerify? In that
1302      * case it is not allowed to set the digest size/digest name because the
1303      * digest is explicitly set as part of the init.
1304      */
1305     return known_settable_ctx_params;
1306 }
1307
1308 static int rsa_get_ctx_md_params(void *vprsactx, OSSL_PARAM *params)
1309 {
1310     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1311
1312     if (prsactx->mdctx == NULL)
1313         return 0;
1314
1315     return EVP_MD_CTX_get_params(prsactx->mdctx, params);
1316 }
1317
1318 static const OSSL_PARAM *rsa_gettable_ctx_md_params(void *vprsactx)
1319 {
1320     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1321
1322     if (prsactx->md == NULL)
1323         return 0;
1324
1325     return EVP_MD_gettable_ctx_params(prsactx->md);
1326 }
1327
1328 static int rsa_set_ctx_md_params(void *vprsactx, const OSSL_PARAM params[])
1329 {
1330     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1331
1332     if (prsactx->mdctx == NULL)
1333         return 0;
1334
1335     return EVP_MD_CTX_set_params(prsactx->mdctx, params);
1336 }
1337
1338 static const OSSL_PARAM *rsa_settable_ctx_md_params(void *vprsactx)
1339 {
1340     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1341
1342     if (prsactx->md == NULL)
1343         return 0;
1344
1345     return EVP_MD_settable_ctx_params(prsactx->md);
1346 }
1347
1348 const OSSL_DISPATCH rsa_signature_functions[] = {
1349     { OSSL_FUNC_SIGNATURE_NEWCTX, (void (*)(void))rsa_newctx },
1350     { OSSL_FUNC_SIGNATURE_SIGN_INIT, (void (*)(void))rsa_sign_init },
1351     { OSSL_FUNC_SIGNATURE_SIGN, (void (*)(void))rsa_sign },
1352     { OSSL_FUNC_SIGNATURE_VERIFY_INIT, (void (*)(void))rsa_verify_init },
1353     { OSSL_FUNC_SIGNATURE_VERIFY, (void (*)(void))rsa_verify },
1354     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER_INIT,
1355       (void (*)(void))rsa_verify_recover_init },
1356     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER,
1357       (void (*)(void))rsa_verify_recover },
1358     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_INIT,
1359       (void (*)(void))rsa_digest_sign_init },
1360     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_UPDATE,
1361       (void (*)(void))rsa_digest_signverify_update },
1362     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_FINAL,
1363       (void (*)(void))rsa_digest_sign_final },
1364     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_INIT,
1365       (void (*)(void))rsa_digest_verify_init },
1366     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_UPDATE,
1367       (void (*)(void))rsa_digest_signverify_update },
1368     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_FINAL,
1369       (void (*)(void))rsa_digest_verify_final },
1370     { OSSL_FUNC_SIGNATURE_FREECTX, (void (*)(void))rsa_freectx },
1371     { OSSL_FUNC_SIGNATURE_DUPCTX, (void (*)(void))rsa_dupctx },
1372     { OSSL_FUNC_SIGNATURE_GET_CTX_PARAMS, (void (*)(void))rsa_get_ctx_params },
1373     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_PARAMS,
1374       (void (*)(void))rsa_gettable_ctx_params },
1375     { OSSL_FUNC_SIGNATURE_SET_CTX_PARAMS, (void (*)(void))rsa_set_ctx_params },
1376     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_PARAMS,
1377       (void (*)(void))rsa_settable_ctx_params },
1378     { OSSL_FUNC_SIGNATURE_GET_CTX_MD_PARAMS,
1379       (void (*)(void))rsa_get_ctx_md_params },
1380     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_MD_PARAMS,
1381       (void (*)(void))rsa_gettable_ctx_md_params },
1382     { OSSL_FUNC_SIGNATURE_SET_CTX_MD_PARAMS,
1383       (void (*)(void))rsa_set_ctx_md_params },
1384     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_MD_PARAMS,
1385       (void (*)(void))rsa_settable_ctx_md_params },
1386     { 0, NULL }
1387 };