changes: add note about application output formatting differences.
[openssl.git] / providers / implementations / signature / rsa_sig.c
1 /*
2  * Copyright 2019-2021 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 /*
11  * RSA low level APIs are deprecated for public use, but still ok for
12  * internal use.
13  */
14 #include "internal/deprecated.h"
15
16 #include "e_os.h" /* strcasecmp */
17 #include <string.h>
18 #include <openssl/crypto.h>
19 #include <openssl/core_dispatch.h>
20 #include <openssl/core_names.h>
21 #include <openssl/err.h>
22 #include <openssl/rsa.h>
23 #include <openssl/params.h>
24 #include <openssl/evp.h>
25 #include <openssl/proverr.h>
26 #include "internal/cryptlib.h"
27 #include "internal/nelem.h"
28 #include "internal/sizes.h"
29 #include "crypto/rsa.h"
30 #include "prov/providercommon.h"
31 #include "prov/implementations.h"
32 #include "prov/provider_ctx.h"
33 #include "prov/der_rsa.h"
34 #include "prov/securitycheck.h"
35
36 #define RSA_DEFAULT_DIGEST_NAME OSSL_DIGEST_NAME_SHA1
37
38 static OSSL_FUNC_signature_newctx_fn rsa_newctx;
39 static OSSL_FUNC_signature_sign_init_fn rsa_sign_init;
40 static OSSL_FUNC_signature_verify_init_fn rsa_verify_init;
41 static OSSL_FUNC_signature_verify_recover_init_fn rsa_verify_recover_init;
42 static OSSL_FUNC_signature_sign_fn rsa_sign;
43 static OSSL_FUNC_signature_verify_fn rsa_verify;
44 static OSSL_FUNC_signature_verify_recover_fn rsa_verify_recover;
45 static OSSL_FUNC_signature_digest_sign_init_fn rsa_digest_sign_init;
46 static OSSL_FUNC_signature_digest_sign_update_fn rsa_digest_signverify_update;
47 static OSSL_FUNC_signature_digest_sign_final_fn rsa_digest_sign_final;
48 static OSSL_FUNC_signature_digest_verify_init_fn rsa_digest_verify_init;
49 static OSSL_FUNC_signature_digest_verify_update_fn rsa_digest_signverify_update;
50 static OSSL_FUNC_signature_digest_verify_final_fn rsa_digest_verify_final;
51 static OSSL_FUNC_signature_freectx_fn rsa_freectx;
52 static OSSL_FUNC_signature_dupctx_fn rsa_dupctx;
53 static OSSL_FUNC_signature_get_ctx_params_fn rsa_get_ctx_params;
54 static OSSL_FUNC_signature_gettable_ctx_params_fn rsa_gettable_ctx_params;
55 static OSSL_FUNC_signature_set_ctx_params_fn rsa_set_ctx_params;
56 static OSSL_FUNC_signature_settable_ctx_params_fn rsa_settable_ctx_params;
57 static OSSL_FUNC_signature_get_ctx_md_params_fn rsa_get_ctx_md_params;
58 static OSSL_FUNC_signature_gettable_ctx_md_params_fn rsa_gettable_ctx_md_params;
59 static OSSL_FUNC_signature_set_ctx_md_params_fn rsa_set_ctx_md_params;
60 static OSSL_FUNC_signature_settable_ctx_md_params_fn rsa_settable_ctx_md_params;
61
62 static OSSL_ITEM padding_item[] = {
63     { RSA_PKCS1_PADDING,        OSSL_PKEY_RSA_PAD_MODE_PKCSV15 },
64     { RSA_NO_PADDING,           OSSL_PKEY_RSA_PAD_MODE_NONE },
65     { RSA_X931_PADDING,         OSSL_PKEY_RSA_PAD_MODE_X931 },
66     { RSA_PKCS1_PSS_PADDING,    OSSL_PKEY_RSA_PAD_MODE_PSS },
67     { 0,                        NULL     }
68 };
69
70 /*
71  * What's passed as an actual key is defined by the KEYMGMT interface.
72  * We happen to know that our KEYMGMT simply passes RSA structures, so
73  * we use that here too.
74  */
75
76 typedef struct {
77     OSSL_LIB_CTX *libctx;
78     char *propq;
79     RSA *rsa;
80     int operation;
81
82     /*
83      * Flag to determine if the hash function can be changed (1) or not (0)
84      * Because it's dangerous to change during a DigestSign or DigestVerify
85      * operation, this flag is cleared by their Init function, and set again
86      * by their Final function.
87      */
88     unsigned int flag_allow_md : 1;
89     unsigned int mgf1_md_set : 1;
90
91     /* main digest */
92     EVP_MD *md;
93     EVP_MD_CTX *mdctx;
94     int mdnid;
95     char mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
96
97     /* RSA padding mode */
98     int pad_mode;
99     /* message digest for MGF1 */
100     EVP_MD *mgf1_md;
101     int mgf1_mdnid;
102     char mgf1_mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
103     /* PSS salt length */
104     int saltlen;
105     /* Minimum salt length or -1 if no PSS parameter restriction */
106     int min_saltlen;
107
108     /* Temp buffer */
109     unsigned char *tbuf;
110
111 } PROV_RSA_CTX;
112
113 /* True if PSS parameters are restricted */
114 #define rsa_pss_restricted(prsactx) (prsactx->min_saltlen != -1)
115
116 static size_t rsa_get_md_size(const PROV_RSA_CTX *prsactx)
117 {
118     if (prsactx->md != NULL)
119         return EVP_MD_size(prsactx->md);
120     return 0;
121 }
122
123 static int rsa_check_padding(const PROV_RSA_CTX *prsactx,
124                              const char *mdname, const char *mgf1_mdname,
125                              int mdnid)
126 {
127     switch(prsactx->pad_mode) {
128         case RSA_NO_PADDING:
129             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE);
130             return 0;
131         case RSA_X931_PADDING:
132             if (RSA_X931_hash_id(mdnid) == -1) {
133                 ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_X931_DIGEST);
134                 return 0;
135             }
136             break;
137         case RSA_PKCS1_PSS_PADDING:
138             if (rsa_pss_restricted(prsactx))
139                 if ((mdname != NULL && !EVP_MD_is_a(prsactx->md, mdname))
140                     || (mgf1_mdname != NULL
141                         && !EVP_MD_is_a(prsactx->mgf1_md, mgf1_mdname))) {
142                     ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
143                     return 0;
144                 }
145             break;
146         default:
147             break;
148     }
149
150     return 1;
151 }
152
153 static int rsa_check_parameters(PROV_RSA_CTX *prsactx, int min_saltlen)
154 {
155     if (prsactx->pad_mode == RSA_PKCS1_PSS_PADDING) {
156         int max_saltlen;
157
158         /* See if minimum salt length exceeds maximum possible */
159         max_saltlen = RSA_size(prsactx->rsa) - EVP_MD_size(prsactx->md);
160         if ((RSA_bits(prsactx->rsa) & 0x7) == 1)
161             max_saltlen--;
162         if (min_saltlen < 0 || min_saltlen > max_saltlen) {
163             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_SALT_LENGTH);
164             return 0;
165         }
166         prsactx->min_saltlen = min_saltlen;
167     }
168     return 1;
169 }
170
171 static void *rsa_newctx(void *provctx, const char *propq)
172 {
173     PROV_RSA_CTX *prsactx = NULL;
174     char *propq_copy = NULL;
175
176     if (!ossl_prov_is_running())
177         return NULL;
178
179     if ((prsactx = OPENSSL_zalloc(sizeof(PROV_RSA_CTX))) == NULL
180         || (propq != NULL
181             && (propq_copy = OPENSSL_strdup(propq)) == NULL)) {
182         OPENSSL_free(prsactx);
183         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
184         return NULL;
185     }
186
187     prsactx->libctx = PROV_LIBCTX_OF(provctx);
188     prsactx->flag_allow_md = 1;
189     prsactx->propq = propq_copy;
190     return prsactx;
191 }
192
193 static int rsa_pss_compute_saltlen(PROV_RSA_CTX *ctx)
194 {
195     int saltlen = ctx->saltlen;
196  
197     if (saltlen == RSA_PSS_SALTLEN_DIGEST) {
198         saltlen = EVP_MD_size(ctx->md);
199     } else if (saltlen == RSA_PSS_SALTLEN_AUTO || saltlen == RSA_PSS_SALTLEN_MAX) {
200         saltlen = RSA_size(ctx->rsa) - EVP_MD_size(ctx->md) - 2;
201         if ((RSA_bits(ctx->rsa) & 0x7) == 1)
202             saltlen--;
203     }
204     if (saltlen < 0) {
205         ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
206         return -1;
207     } else if (saltlen < ctx->min_saltlen) {
208         ERR_raise_data(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL,
209                        "minimum salt length: %d, actual salt length: %d",
210                        ctx->min_saltlen, saltlen);
211         return -1;
212     }
213     return saltlen;
214 }
215
216 static unsigned char *rsa_generate_signature_aid(PROV_RSA_CTX *ctx,
217                                                  unsigned char *aid_buf,
218                                                  size_t buf_len,
219                                                  size_t *aid_len)
220 {
221     WPACKET pkt;
222     unsigned char *aid = NULL;
223     int saltlen;
224     RSA_PSS_PARAMS_30 pss_params;
225     int ret;
226
227     if (!WPACKET_init_der(&pkt, aid_buf, buf_len)) {
228         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
229         return NULL;
230     }
231
232     switch(ctx->pad_mode) {
233     case RSA_PKCS1_PADDING:
234         ret = ossl_DER_w_algorithmIdentifier_MDWithRSAEncryption(&pkt, -1,
235                                                                  ctx->mdnid);
236
237         if (ret > 0) {
238             break;
239         } else if (ret == 0) {
240             ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
241             goto cleanup;
242         }
243         ERR_raise_data(ERR_LIB_PROV, ERR_R_UNSUPPORTED,
244                        "Algorithm ID generation - md NID: %d",
245                        ctx->mdnid);
246         goto cleanup;
247     case RSA_PKCS1_PSS_PADDING:
248         saltlen = rsa_pss_compute_saltlen(ctx);
249         if (saltlen < 0)
250             goto cleanup;
251         if (!ossl_rsa_pss_params_30_set_defaults(&pss_params)
252             || !ossl_rsa_pss_params_30_set_hashalg(&pss_params, ctx->mdnid)
253             || !ossl_rsa_pss_params_30_set_maskgenhashalg(&pss_params,
254                                                           ctx->mgf1_mdnid)
255             || !ossl_rsa_pss_params_30_set_saltlen(&pss_params, saltlen)
256             || !ossl_DER_w_algorithmIdentifier_RSA_PSS(&pkt, -1,
257                                                        RSA_FLAG_TYPE_RSASSAPSS,
258                                                        &pss_params)) {
259             ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
260             goto cleanup;
261         }
262         break;
263     default:
264         ERR_raise_data(ERR_LIB_PROV, ERR_R_UNSUPPORTED,
265                        "Algorithm ID generation - pad mode: %d",
266                        ctx->pad_mode);
267         goto cleanup;
268     }
269     if (WPACKET_finish(&pkt)) {
270         WPACKET_get_total_written(&pkt, aid_len);
271         aid = WPACKET_get_curr(&pkt);
272     }
273  cleanup:
274     WPACKET_cleanup(&pkt);
275     return aid;
276 }
277
278 static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
279                         const char *mdprops)
280 {
281     if (mdprops == NULL)
282         mdprops = ctx->propq;
283
284     if (mdname != NULL) {
285         EVP_MD *md = EVP_MD_fetch(ctx->libctx, mdname, mdprops);
286         int sha1_allowed = (ctx->operation != EVP_PKEY_OP_SIGN);
287         int md_nid = ossl_digest_rsa_sign_get_md_nid(ctx->libctx, md,
288                                                      sha1_allowed);
289         size_t mdname_len = strlen(mdname);
290
291         if (md == NULL
292             || md_nid == NID_undef
293             || !rsa_check_padding(ctx, mdname, NULL, md_nid)
294             || mdname_len >= sizeof(ctx->mdname)) {
295             if (md == NULL)
296                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
297                                "%s could not be fetched", mdname);
298             if (md_nid == NID_undef)
299                 ERR_raise_data(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED,
300                                "digest=%s", mdname);
301             if (mdname_len >= sizeof(ctx->mdname))
302                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
303                                "%s exceeds name buffer length", mdname);
304             EVP_MD_free(md);
305             return 0;
306         }
307
308         if (!ctx->mgf1_md_set) {
309             if (!EVP_MD_up_ref(md)) {
310                 EVP_MD_free(md);
311                 return 0;
312             }
313             EVP_MD_free(ctx->mgf1_md);
314             ctx->mgf1_md = md;
315             ctx->mgf1_mdnid = md_nid;
316             OPENSSL_strlcpy(ctx->mgf1_mdname, mdname, sizeof(ctx->mgf1_mdname));
317         }
318
319         EVP_MD_CTX_free(ctx->mdctx);
320         EVP_MD_free(ctx->md);
321
322         ctx->mdctx = NULL;
323         ctx->md = md;
324         ctx->mdnid = md_nid;
325         OPENSSL_strlcpy(ctx->mdname, mdname, sizeof(ctx->mdname));
326     }
327
328     return 1;
329 }
330
331 static int rsa_setup_mgf1_md(PROV_RSA_CTX *ctx, const char *mdname,
332                              const char *mdprops)
333 {
334     size_t len;
335     EVP_MD *md = NULL;
336     int mdnid;
337
338     if (mdprops == NULL)
339         mdprops = ctx->propq;
340
341     if ((md = EVP_MD_fetch(ctx->libctx, mdname, mdprops)) == NULL) {
342         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
343                        "%s could not be fetched", mdname);
344         return 0;
345     }
346     /* The default for mgf1 is SHA1 - so allow SHA1 */
347     if ((mdnid = ossl_digest_rsa_sign_get_md_nid(ctx->libctx, md, 1)) == NID_undef
348         || !rsa_check_padding(ctx, NULL, mdname, mdnid)) {
349         if (mdnid == NID_undef)
350             ERR_raise_data(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED,
351                            "digest=%s", mdname);
352         EVP_MD_free(md);
353         return 0;
354     }
355     len = OPENSSL_strlcpy(ctx->mgf1_mdname, mdname, sizeof(ctx->mgf1_mdname));
356     if (len >= sizeof(ctx->mgf1_mdname)) {
357         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
358                        "%s exceeds name buffer length", mdname);
359         EVP_MD_free(md);
360         return 0;
361     }
362
363     EVP_MD_free(ctx->mgf1_md);
364     ctx->mgf1_md = md;
365     ctx->mgf1_mdnid = mdnid;
366     ctx->mgf1_md_set = 1;
367     return 1;
368 }
369
370 static int rsa_signverify_init(void *vprsactx, void *vrsa,
371                                const OSSL_PARAM params[], int operation)
372 {
373     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
374
375     if (!ossl_prov_is_running())
376         return 0;
377
378     if (prsactx == NULL || vrsa == NULL)
379         return 0;
380
381     if (!ossl_rsa_check_key(prsactx->libctx, vrsa, operation))
382         return 0;
383
384     if (!RSA_up_ref(vrsa))
385         return 0;
386     RSA_free(prsactx->rsa);
387     prsactx->rsa = vrsa;
388     prsactx->operation = operation;
389
390     if (!rsa_set_ctx_params(prsactx, params))
391         return 0;
392
393     /* Maximum for sign, auto for verify */
394     prsactx->saltlen = RSA_PSS_SALTLEN_AUTO;
395     prsactx->min_saltlen = -1;
396
397     switch (RSA_test_flags(prsactx->rsa, RSA_FLAG_TYPE_MASK)) {
398     case RSA_FLAG_TYPE_RSA:
399         prsactx->pad_mode = RSA_PKCS1_PADDING;
400         break;
401     case RSA_FLAG_TYPE_RSASSAPSS:
402         prsactx->pad_mode = RSA_PKCS1_PSS_PADDING;
403
404         {
405             const RSA_PSS_PARAMS_30 *pss =
406                 ossl_rsa_get0_pss_params_30(prsactx->rsa);
407
408             if (!ossl_rsa_pss_params_30_is_unrestricted(pss)) {
409                 int md_nid = ossl_rsa_pss_params_30_hashalg(pss);
410                 int mgf1md_nid = ossl_rsa_pss_params_30_maskgenhashalg(pss);
411                 int min_saltlen = ossl_rsa_pss_params_30_saltlen(pss);
412                 const char *mdname, *mgf1mdname;
413                 size_t len;
414
415                 mdname = ossl_rsa_oaeppss_nid2name(md_nid);
416                 mgf1mdname = ossl_rsa_oaeppss_nid2name(mgf1md_nid);
417
418                 if (mdname == NULL) {
419                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
420                                    "PSS restrictions lack hash algorithm");
421                     return 0;
422                 }
423                 if (mgf1mdname == NULL) {
424                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
425                                    "PSS restrictions lack MGF1 hash algorithm");
426                     return 0;
427                 }
428
429                 len = OPENSSL_strlcpy(prsactx->mdname, mdname,
430                                       sizeof(prsactx->mdname));
431                 if (len >= sizeof(prsactx->mdname)) {
432                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
433                                    "hash algorithm name too long");
434                     return 0;
435                 }
436                 len = OPENSSL_strlcpy(prsactx->mgf1_mdname, mgf1mdname,
437                                       sizeof(prsactx->mgf1_mdname));
438                 if (len >= sizeof(prsactx->mgf1_mdname)) {
439                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
440                                    "MGF1 hash algorithm name too long");
441                     return 0;
442                 }
443                 prsactx->saltlen = min_saltlen;
444
445                 /* call rsa_setup_mgf1_md before rsa_setup_md to avoid duplication */
446                 return rsa_setup_mgf1_md(prsactx, mgf1mdname, prsactx->propq)
447                     && rsa_setup_md(prsactx, mdname, prsactx->propq)
448                     && rsa_check_parameters(prsactx, min_saltlen);
449             }
450         }
451
452         break;
453     default:
454         ERR_raise(ERR_LIB_RSA, PROV_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
455         return 0;
456     }
457
458     return 1;
459 }
460
461 static int setup_tbuf(PROV_RSA_CTX *ctx)
462 {
463     if (ctx->tbuf != NULL)
464         return 1;
465     if ((ctx->tbuf = OPENSSL_malloc(RSA_size(ctx->rsa))) == NULL) {
466         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
467         return 0;
468     }
469     return 1;
470 }
471
472 static void clean_tbuf(PROV_RSA_CTX *ctx)
473 {
474     if (ctx->tbuf != NULL)
475         OPENSSL_cleanse(ctx->tbuf, RSA_size(ctx->rsa));
476 }
477
478 static void free_tbuf(PROV_RSA_CTX *ctx)
479 {
480     clean_tbuf(ctx);
481     OPENSSL_free(ctx->tbuf);
482     ctx->tbuf = NULL;
483 }
484
485 static int rsa_sign_init(void *vprsactx, void *vrsa, const OSSL_PARAM params[])
486 {
487     if (!ossl_prov_is_running())
488         return 0;
489     return rsa_signverify_init(vprsactx, vrsa, params, EVP_PKEY_OP_SIGN);
490 }
491
492 static int rsa_sign(void *vprsactx, unsigned char *sig, size_t *siglen,
493                     size_t sigsize, const unsigned char *tbs, size_t tbslen)
494 {
495     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
496     int ret;
497     size_t rsasize = RSA_size(prsactx->rsa);
498     size_t mdsize = rsa_get_md_size(prsactx);
499
500     if (!ossl_prov_is_running())
501         return 0;
502
503     if (sig == NULL) {
504         *siglen = rsasize;
505         return 1;
506     }
507
508     if (sigsize < rsasize) {
509         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_SIGNATURE_SIZE,
510                        "is %zu, should be at least %zu", sigsize, rsasize);
511         return 0;
512     }
513
514     if (mdsize != 0) {
515         if (tbslen != mdsize) {
516             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH);
517             return 0;
518         }
519
520 #ifndef FIPS_MODULE
521         if (EVP_MD_is_a(prsactx->md, OSSL_DIGEST_NAME_MDC2)) {
522             unsigned int sltmp;
523
524             if (prsactx->pad_mode != RSA_PKCS1_PADDING) {
525                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
526                                "only PKCS#1 padding supported with MDC2");
527                 return 0;
528             }
529             ret = RSA_sign_ASN1_OCTET_STRING(0, tbs, tbslen, sig, &sltmp,
530                                              prsactx->rsa);
531
532             if (ret <= 0) {
533                 ERR_raise(ERR_LIB_PROV, ERR_R_RSA_LIB);
534                 return 0;
535             }
536             ret = sltmp;
537             goto end;
538         }
539 #endif
540         switch (prsactx->pad_mode) {
541         case RSA_X931_PADDING:
542             if ((size_t)RSA_size(prsactx->rsa) < tbslen + 1) {
543                 ERR_raise_data(ERR_LIB_PROV, PROV_R_KEY_SIZE_TOO_SMALL,
544                                "RSA key size = %d, expected minimum = %d",
545                                RSA_size(prsactx->rsa), tbslen + 1);
546                 return 0;
547             }
548             if (!setup_tbuf(prsactx)) {
549                 ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
550                 return 0;
551             }
552             memcpy(prsactx->tbuf, tbs, tbslen);
553             prsactx->tbuf[tbslen] = RSA_X931_hash_id(prsactx->mdnid);
554             ret = RSA_private_encrypt(tbslen + 1, prsactx->tbuf,
555                                       sig, prsactx->rsa, RSA_X931_PADDING);
556             clean_tbuf(prsactx);
557             break;
558
559         case RSA_PKCS1_PADDING:
560             {
561                 unsigned int sltmp;
562
563                 ret = RSA_sign(prsactx->mdnid, tbs, tbslen, sig, &sltmp,
564                                prsactx->rsa);
565                 if (ret <= 0) {
566                     ERR_raise(ERR_LIB_PROV, ERR_R_RSA_LIB);
567                     return 0;
568                 }
569                 ret = sltmp;
570             }
571             break;
572
573         case RSA_PKCS1_PSS_PADDING:
574             /* Check PSS restrictions */
575             if (rsa_pss_restricted(prsactx)) {
576                 switch (prsactx->saltlen) {
577                 case RSA_PSS_SALTLEN_DIGEST:
578                     if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
579                         ERR_raise_data(ERR_LIB_PROV,
580                                        PROV_R_PSS_SALTLEN_TOO_SMALL,
581                                        "minimum salt length set to %d, "
582                                        "but the digest only gives %d",
583                                        prsactx->min_saltlen,
584                                        EVP_MD_size(prsactx->md));
585                         return 0;
586                     }
587                     /* FALLTHRU */
588                 default:
589                     if (prsactx->saltlen >= 0
590                         && prsactx->saltlen < prsactx->min_saltlen) {
591                         ERR_raise_data(ERR_LIB_PROV,
592                                        PROV_R_PSS_SALTLEN_TOO_SMALL,
593                                        "minimum salt length set to %d, but the"
594                                        "actual salt length is only set to %d",
595                                        prsactx->min_saltlen,
596                                        prsactx->saltlen);
597                         return 0;
598                     }
599                     break;
600                 }
601             }
602             if (!setup_tbuf(prsactx))
603                 return 0;
604             if (!RSA_padding_add_PKCS1_PSS_mgf1(prsactx->rsa,
605                                                 prsactx->tbuf, tbs,
606                                                 prsactx->md, prsactx->mgf1_md,
607                                                 prsactx->saltlen)) {
608                 ERR_raise(ERR_LIB_PROV, ERR_R_RSA_LIB);
609                 return 0;
610             }
611             ret = RSA_private_encrypt(RSA_size(prsactx->rsa), prsactx->tbuf,
612                                       sig, prsactx->rsa, RSA_NO_PADDING);
613             clean_tbuf(prsactx);
614             break;
615
616         default:
617             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
618                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
619             return 0;
620         }
621     } else {
622         ret = RSA_private_encrypt(tbslen, tbs, sig, prsactx->rsa,
623                                   prsactx->pad_mode);
624     }
625
626 #ifndef FIPS_MODULE
627  end:
628 #endif
629     if (ret <= 0) {
630         ERR_raise(ERR_LIB_PROV, ERR_R_RSA_LIB);
631         return 0;
632     }
633
634     *siglen = ret;
635     return 1;
636 }
637
638 static int rsa_verify_recover_init(void *vprsactx, void *vrsa,
639                                    const OSSL_PARAM params[])
640 {
641     if (!ossl_prov_is_running())
642         return 0;
643     return rsa_signverify_init(vprsactx, vrsa, params,
644                                EVP_PKEY_OP_VERIFYRECOVER);
645 }
646
647 static int rsa_verify_recover(void *vprsactx,
648                               unsigned char *rout,
649                               size_t *routlen,
650                               size_t routsize,
651                               const unsigned char *sig,
652                               size_t siglen)
653 {
654     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
655     int ret;
656
657     if (!ossl_prov_is_running())
658         return 0;
659
660     if (rout == NULL) {
661         *routlen = RSA_size(prsactx->rsa);
662         return 1;
663     }
664
665     if (prsactx->md != NULL) {
666         switch (prsactx->pad_mode) {
667         case RSA_X931_PADDING:
668             if (!setup_tbuf(prsactx))
669                 return 0;
670             ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
671                                      RSA_X931_PADDING);
672             if (ret < 1) {
673                 ERR_raise(ERR_LIB_PROV, ERR_R_RSA_LIB);
674                 return 0;
675             }
676             ret--;
677             if (prsactx->tbuf[ret] != RSA_X931_hash_id(prsactx->mdnid)) {
678                 ERR_raise(ERR_LIB_PROV, PROV_R_ALGORITHM_MISMATCH);
679                 return 0;
680             }
681             if (ret != EVP_MD_size(prsactx->md)) {
682                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
683                                "Should be %d, but got %d",
684                                EVP_MD_size(prsactx->md), ret);
685                 return 0;
686             }
687
688             *routlen = ret;
689             if (rout != prsactx->tbuf) {
690                 if (routsize < (size_t)ret) {
691                     ERR_raise_data(ERR_LIB_PROV, PROV_R_OUTPUT_BUFFER_TOO_SMALL,
692                                    "buffer size is %d, should be %d",
693                                    routsize, ret);
694                     return 0;
695                 }
696                 memcpy(rout, prsactx->tbuf, ret);
697             }
698             break;
699
700         case RSA_PKCS1_PADDING:
701             {
702                 size_t sltmp;
703
704                 ret = ossl_rsa_verify(prsactx->mdnid, NULL, 0, rout, &sltmp,
705                                       sig, siglen, prsactx->rsa);
706                 if (ret <= 0) {
707                     ERR_raise(ERR_LIB_PROV, ERR_R_RSA_LIB);
708                     return 0;
709                 }
710                 ret = sltmp;
711             }
712             break;
713
714         default:
715             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
716                            "Only X.931 or PKCS#1 v1.5 padding allowed");
717             return 0;
718         }
719     } else {
720         ret = RSA_public_decrypt(siglen, sig, rout, prsactx->rsa,
721                                  prsactx->pad_mode);
722         if (ret < 0) {
723             ERR_raise(ERR_LIB_PROV, ERR_R_RSA_LIB);
724             return 0;
725         }
726     }
727     *routlen = ret;
728     return 1;
729 }
730
731 static int rsa_verify_init(void *vprsactx, void *vrsa,
732                            const OSSL_PARAM params[])
733 {
734     if (!ossl_prov_is_running())
735         return 0;
736     return rsa_signverify_init(vprsactx, vrsa, params, EVP_PKEY_OP_VERIFY);
737 }
738
739 static int rsa_verify(void *vprsactx, const unsigned char *sig, size_t siglen,
740                       const unsigned char *tbs, size_t tbslen)
741 {
742     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
743     size_t rslen;
744
745     if (!ossl_prov_is_running())
746         return 0;
747     if (prsactx->md != NULL) {
748         switch (prsactx->pad_mode) {
749         case RSA_PKCS1_PADDING:
750             if (!RSA_verify(prsactx->mdnid, tbs, tbslen, sig, siglen,
751                             prsactx->rsa)) {
752                 ERR_raise(ERR_LIB_PROV, ERR_R_RSA_LIB);
753                 return 0;
754             }
755             return 1;
756         case RSA_X931_PADDING:
757             if (!setup_tbuf(prsactx))
758                 return 0;
759             if (rsa_verify_recover(prsactx, prsactx->tbuf, &rslen, 0,
760                                    sig, siglen) <= 0)
761                 return 0;
762             break;
763         case RSA_PKCS1_PSS_PADDING:
764             {
765                 int ret;
766                 size_t mdsize;
767
768                 /*
769                  * We need to check this for the RSA_verify_PKCS1_PSS_mgf1()
770                  * call
771                  */
772                 mdsize = rsa_get_md_size(prsactx);
773                 if (tbslen != mdsize) {
774                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
775                                    "Should be %d, but got %d",
776                                    mdsize, tbslen);
777                     return 0;
778                 }
779
780                 if (!setup_tbuf(prsactx))
781                     return 0;
782                 ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf,
783                                          prsactx->rsa, RSA_NO_PADDING);
784                 if (ret <= 0) {
785                     ERR_raise(ERR_LIB_PROV, ERR_R_RSA_LIB);
786                     return 0;
787                 }
788                 ret = RSA_verify_PKCS1_PSS_mgf1(prsactx->rsa, tbs,
789                                                 prsactx->md, prsactx->mgf1_md,
790                                                 prsactx->tbuf,
791                                                 prsactx->saltlen);
792                 if (ret <= 0) {
793                     ERR_raise(ERR_LIB_PROV, ERR_R_RSA_LIB);
794                     return 0;
795                 }
796                 return 1;
797             }
798         default:
799             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
800                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
801             return 0;
802         }
803     } else {
804         if (!setup_tbuf(prsactx))
805             return 0;
806         rslen = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
807                                    prsactx->pad_mode);
808         if (rslen == 0) {
809             ERR_raise(ERR_LIB_PROV, ERR_R_RSA_LIB);
810             return 0;
811         }
812     }
813
814     if ((rslen != tbslen) || memcmp(tbs, prsactx->tbuf, rslen))
815         return 0;
816
817     return 1;
818 }
819
820 static int rsa_digest_signverify_init(void *vprsactx, const char *mdname,
821                                       void *vrsa, const OSSL_PARAM params[],
822                                       int operation)
823 {
824     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
825
826     if (!ossl_prov_is_running())
827         return 0;
828
829     if (prsactx != NULL)
830         prsactx->flag_allow_md = 0;
831     if (!rsa_signverify_init(vprsactx, vrsa, params, operation))
832         return 0;
833     if (mdname != NULL
834         /* was rsa_setup_md already called in rsa_signverify_init()? */
835         && (mdname[0] == '\0' || strcasecmp(prsactx->mdname, mdname) != 0)
836         && !rsa_setup_md(prsactx, mdname, prsactx->propq))
837         return 0;
838
839     prsactx->mdctx = EVP_MD_CTX_new();
840     if (prsactx->mdctx == NULL) {
841         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
842         goto error;
843     }
844
845     if (!EVP_DigestInit_ex2(prsactx->mdctx, prsactx->md, params))
846         goto error;
847
848     return 1;
849
850  error:
851     EVP_MD_CTX_free(prsactx->mdctx);
852     EVP_MD_free(prsactx->md);
853     prsactx->mdctx = NULL;
854     prsactx->md = NULL;
855     return 0;
856 }
857
858 static int rsa_digest_signverify_update(void *vprsactx,
859                                         const unsigned char *data,
860                                         size_t datalen)
861 {
862     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
863
864     if (prsactx == NULL || prsactx->mdctx == NULL)
865         return 0;
866
867     return EVP_DigestUpdate(prsactx->mdctx, data, datalen);
868 }
869
870 static int rsa_digest_sign_init(void *vprsactx, const char *mdname,
871                                 void *vrsa, const OSSL_PARAM params[])
872 {
873     if (!ossl_prov_is_running())
874         return 0;
875     return rsa_digest_signverify_init(vprsactx, mdname, vrsa,
876                                       params, EVP_PKEY_OP_SIGN);
877 }
878
879 static int rsa_digest_sign_final(void *vprsactx, unsigned char *sig,
880                                  size_t *siglen, size_t sigsize)
881 {
882     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
883     unsigned char digest[EVP_MAX_MD_SIZE];
884     unsigned int dlen = 0;
885
886     if (!ossl_prov_is_running() || prsactx == NULL)
887         return 0;
888     prsactx->flag_allow_md = 1;
889     if (prsactx->mdctx == NULL)
890         return 0;
891     /*
892      * If sig is NULL then we're just finding out the sig size. Other fields
893      * are ignored. Defer to rsa_sign.
894      */
895     if (sig != NULL) {
896         /*
897          * The digests used here are all known (see rsa_get_md_nid()), so they
898          * should not exceed the internal buffer size of EVP_MAX_MD_SIZE.
899          */
900         if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
901             return 0;
902     }
903
904     return rsa_sign(vprsactx, sig, siglen, sigsize, digest, (size_t)dlen);
905 }
906
907 static int rsa_digest_verify_init(void *vprsactx, const char *mdname,
908                                   void *vrsa, const OSSL_PARAM params[])
909 {
910     if (!ossl_prov_is_running())
911         return 0;
912     return rsa_digest_signverify_init(vprsactx, mdname, vrsa,
913                                       params, EVP_PKEY_OP_VERIFY);
914 }
915
916 int rsa_digest_verify_final(void *vprsactx, const unsigned char *sig,
917                             size_t siglen)
918 {
919     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
920     unsigned char digest[EVP_MAX_MD_SIZE];
921     unsigned int dlen = 0;
922
923     if (!ossl_prov_is_running())
924         return 0;
925
926     if (prsactx == NULL)
927         return 0;
928     prsactx->flag_allow_md = 1;
929     if (prsactx->mdctx == NULL)
930         return 0;
931
932     /*
933      * The digests used here are all known (see rsa_get_md_nid()), so they
934      * should not exceed the internal buffer size of EVP_MAX_MD_SIZE.
935      */
936     if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
937         return 0;
938
939     return rsa_verify(vprsactx, sig, siglen, digest, (size_t)dlen);
940 }
941
942 static void rsa_freectx(void *vprsactx)
943 {
944     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
945
946     if (prsactx == NULL)
947         return;
948
949     EVP_MD_CTX_free(prsactx->mdctx);
950     EVP_MD_free(prsactx->md);
951     EVP_MD_free(prsactx->mgf1_md);
952     OPENSSL_free(prsactx->propq);
953     free_tbuf(prsactx);
954     RSA_free(prsactx->rsa);
955
956     OPENSSL_clear_free(prsactx, sizeof(*prsactx));
957 }
958
959 static void *rsa_dupctx(void *vprsactx)
960 {
961     PROV_RSA_CTX *srcctx = (PROV_RSA_CTX *)vprsactx;
962     PROV_RSA_CTX *dstctx;
963
964     if (!ossl_prov_is_running())
965         return NULL;
966
967     dstctx = OPENSSL_zalloc(sizeof(*srcctx));
968     if (dstctx == NULL) {
969         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
970         return NULL;
971     }
972
973     *dstctx = *srcctx;
974     dstctx->rsa = NULL;
975     dstctx->md = NULL;
976     dstctx->mdctx = NULL;
977     dstctx->tbuf = NULL;
978     dstctx->propq = NULL;
979
980     if (srcctx->rsa != NULL && !RSA_up_ref(srcctx->rsa))
981         goto err;
982     dstctx->rsa = srcctx->rsa;
983
984     if (srcctx->md != NULL && !EVP_MD_up_ref(srcctx->md))
985         goto err;
986     dstctx->md = srcctx->md;
987
988     if (srcctx->mgf1_md != NULL && !EVP_MD_up_ref(srcctx->mgf1_md))
989         goto err;
990     dstctx->mgf1_md = srcctx->mgf1_md;
991
992     if (srcctx->mdctx != NULL) {
993         dstctx->mdctx = EVP_MD_CTX_new();
994         if (dstctx->mdctx == NULL
995                 || !EVP_MD_CTX_copy_ex(dstctx->mdctx, srcctx->mdctx))
996             goto err;
997     }
998
999     if (srcctx->propq != NULL) {
1000         dstctx->propq = OPENSSL_strdup(srcctx->propq);
1001         if (dstctx->propq == NULL)
1002             goto err;
1003     }
1004
1005     return dstctx;
1006  err:
1007     rsa_freectx(dstctx);
1008     return NULL;
1009 }
1010
1011 static int rsa_get_ctx_params(void *vprsactx, OSSL_PARAM *params)
1012 {
1013     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1014     OSSL_PARAM *p;
1015
1016     if (prsactx == NULL)
1017         return 0;
1018
1019     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_ALGORITHM_ID);
1020     if (p != NULL) {
1021         /* The Algorithm Identifier of the combined signature algorithm */
1022         unsigned char aid_buf[128];
1023         unsigned char *aid;
1024         size_t  aid_len;
1025
1026         aid = rsa_generate_signature_aid(prsactx, aid_buf,
1027                                          sizeof(aid_buf), &aid_len);
1028         if (aid == NULL || !OSSL_PARAM_set_octet_string(p, aid, aid_len))
1029             return 0;
1030     }
1031
1032     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
1033     if (p != NULL)
1034         switch (p->data_type) {
1035         case OSSL_PARAM_INTEGER:
1036             if (!OSSL_PARAM_set_int(p, prsactx->pad_mode))
1037                 return 0;
1038             break;
1039         case OSSL_PARAM_UTF8_STRING:
1040             {
1041                 int i;
1042                 const char *word = NULL;
1043
1044                 for (i = 0; padding_item[i].id != 0; i++) {
1045                     if (prsactx->pad_mode == (int)padding_item[i].id) {
1046                         word = padding_item[i].ptr;
1047                         break;
1048                     }
1049                 }
1050
1051                 if (word != NULL) {
1052                     if (!OSSL_PARAM_set_utf8_string(p, word))
1053                         return 0;
1054                 } else {
1055                     ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
1056                 }
1057             }
1058             break;
1059         default:
1060             return 0;
1061         }
1062
1063     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_DIGEST);
1064     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mdname))
1065         return 0;
1066
1067     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
1068     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mgf1_mdname))
1069         return 0;
1070
1071     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
1072     if (p != NULL) {
1073         if (p->data_type == OSSL_PARAM_INTEGER) {
1074             if (!OSSL_PARAM_set_int(p, prsactx->saltlen))
1075                 return 0;
1076         } else if (p->data_type == OSSL_PARAM_UTF8_STRING) {
1077             const char *value = NULL;
1078
1079             switch (prsactx->saltlen) {
1080             case RSA_PSS_SALTLEN_DIGEST:
1081                 value = OSSL_PKEY_RSA_PSS_SALT_LEN_DIGEST;
1082                 break;
1083             case RSA_PSS_SALTLEN_MAX:
1084                 value = OSSL_PKEY_RSA_PSS_SALT_LEN_MAX;
1085                 break;
1086             case RSA_PSS_SALTLEN_AUTO:
1087                 value = OSSL_PKEY_RSA_PSS_SALT_LEN_AUTO;
1088                 break;
1089             default:
1090                 {
1091                     int len = BIO_snprintf(p->data, p->data_size, "%d",
1092                                            prsactx->saltlen);
1093
1094                     if (len <= 0)
1095                         return 0;
1096                     p->return_size = len;
1097                     break;
1098                 }
1099             }
1100             if (value != NULL
1101                 && !OSSL_PARAM_set_utf8_string(p, value))
1102                 return 0;
1103         }
1104     }
1105
1106     return 1;
1107 }
1108
1109 static const OSSL_PARAM known_gettable_ctx_params[] = {
1110     OSSL_PARAM_octet_string(OSSL_SIGNATURE_PARAM_ALGORITHM_ID, NULL, 0),
1111     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
1112     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
1113     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
1114     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
1115     OSSL_PARAM_END
1116 };
1117
1118 static const OSSL_PARAM *rsa_gettable_ctx_params(ossl_unused void *vprsactx,
1119                                                  ossl_unused void *provctx)
1120 {
1121     return known_gettable_ctx_params;
1122 }
1123
1124 static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
1125 {
1126     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1127     const OSSL_PARAM *p;
1128     int pad_mode;
1129     int saltlen;
1130     char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = NULL;
1131     char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = NULL;
1132     char mgf1mdname[OSSL_MAX_NAME_SIZE] = "", *pmgf1mdname = NULL;
1133     char mgf1mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmgf1mdprops = NULL;
1134
1135     if (prsactx == NULL)
1136         return 0;
1137     if (params == NULL)
1138         return 1;
1139
1140     pad_mode = prsactx->pad_mode;
1141     saltlen = prsactx->saltlen;
1142
1143     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_DIGEST);
1144     /* Not allowed during certain operations */
1145     if (p != NULL && !prsactx->flag_allow_md)
1146         return 0;
1147     if (p != NULL) {
1148         const OSSL_PARAM *propsp =
1149             OSSL_PARAM_locate_const(params,
1150                                     OSSL_SIGNATURE_PARAM_PROPERTIES);
1151
1152         pmdname = mdname;
1153         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
1154             return 0;
1155
1156         if (propsp != NULL) {
1157             pmdprops = mdprops;
1158             if (!OSSL_PARAM_get_utf8_string(propsp,
1159                                             &pmdprops, sizeof(mdprops)))
1160                 return 0;
1161         }
1162     }
1163
1164     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
1165     if (p != NULL) {
1166         const char *err_extra_text = NULL;
1167
1168         switch (p->data_type) {
1169         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
1170             if (!OSSL_PARAM_get_int(p, &pad_mode))
1171                 return 0;
1172             break;
1173         case OSSL_PARAM_UTF8_STRING:
1174             {
1175                 int i;
1176
1177                 if (p->data == NULL)
1178                     return 0;
1179
1180                 for (i = 0; padding_item[i].id != 0; i++) {
1181                     if (strcmp(p->data, padding_item[i].ptr) == 0) {
1182                         pad_mode = padding_item[i].id;
1183                         break;
1184                     }
1185                 }
1186             }
1187             break;
1188         default:
1189             return 0;
1190         }
1191
1192         switch (pad_mode) {
1193         case RSA_PKCS1_OAEP_PADDING:
1194             /*
1195              * OAEP padding is for asymmetric cipher only so is not compatible
1196              * with signature use.
1197              */
1198             err_extra_text = "OAEP padding not allowed for signing / verifying";
1199             goto bad_pad;
1200         case RSA_PKCS1_PSS_PADDING:
1201             if ((prsactx->operation
1202                  & (EVP_PKEY_OP_SIGN | EVP_PKEY_OP_VERIFY)) == 0) {
1203                 err_extra_text =
1204                     "PSS padding only allowed for sign and verify operations";
1205                 goto bad_pad;
1206             }
1207             break;
1208         case RSA_PKCS1_PADDING:
1209             err_extra_text = "PKCS#1 padding not allowed with RSA-PSS";
1210             goto cont;
1211         case RSA_NO_PADDING:
1212             err_extra_text = "No padding not allowed with RSA-PSS";
1213             goto cont;
1214         case RSA_X931_PADDING:
1215             err_extra_text = "X.931 padding not allowed with RSA-PSS";
1216         cont:
1217             if (RSA_test_flags(prsactx->rsa,
1218                                RSA_FLAG_TYPE_MASK) == RSA_FLAG_TYPE_RSA)
1219                 break;
1220             /* FALLTHRU */
1221         default:
1222         bad_pad:
1223             if (err_extra_text == NULL)
1224                 ERR_raise(ERR_LIB_PROV,
1225                           PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE);
1226             else
1227                 ERR_raise_data(ERR_LIB_PROV,
1228                                PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE,
1229                                err_extra_text);
1230             return 0;
1231         }
1232     }
1233
1234     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
1235     if (p != NULL) {
1236         if (pad_mode != RSA_PKCS1_PSS_PADDING) {
1237             ERR_raise_data(ERR_LIB_PROV, PROV_R_NOT_SUPPORTED,
1238                            "PSS saltlen can only be specified if "
1239                            "PSS padding has been specified first");
1240             return 0;
1241         }
1242
1243         switch (p->data_type) {
1244         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
1245             if (!OSSL_PARAM_get_int(p, &saltlen))
1246                 return 0;
1247             break;
1248         case OSSL_PARAM_UTF8_STRING:
1249             if (strcmp(p->data, OSSL_PKEY_RSA_PSS_SALT_LEN_DIGEST) == 0)
1250                 saltlen = RSA_PSS_SALTLEN_DIGEST;
1251             else if (strcmp(p->data, OSSL_PKEY_RSA_PSS_SALT_LEN_MAX) == 0)
1252                 saltlen = RSA_PSS_SALTLEN_MAX;
1253             else if (strcmp(p->data, OSSL_PKEY_RSA_PSS_SALT_LEN_AUTO) == 0)
1254                 saltlen = RSA_PSS_SALTLEN_AUTO;
1255             else
1256                 saltlen = atoi(p->data);
1257             break;
1258         default:
1259             return 0;
1260         }
1261
1262         /*
1263          * RSA_PSS_SALTLEN_MAX seems curiously named in this check.
1264          * Contrary to what it's name suggests, it's the currently
1265          * lowest saltlen number possible.
1266          */
1267         if (saltlen < RSA_PSS_SALTLEN_MAX) {
1268             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_SALT_LENGTH);
1269             return 0;
1270         }
1271
1272         if (rsa_pss_restricted(prsactx)) {
1273             switch (saltlen) {
1274             case RSA_PSS_SALTLEN_AUTO:
1275                 if (prsactx->operation == EVP_PKEY_OP_VERIFY) {
1276                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_SALT_LENGTH,
1277                                    "Cannot use autodetected salt length");
1278                     return 0;
1279                 }
1280                 break;
1281             case RSA_PSS_SALTLEN_DIGEST:
1282                 if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
1283                     ERR_raise_data(ERR_LIB_PROV,
1284                                    PROV_R_PSS_SALTLEN_TOO_SMALL,
1285                                    "Should be more than %d, but would be "
1286                                    "set to match digest size (%d)",
1287                                    prsactx->min_saltlen,
1288                                    EVP_MD_size(prsactx->md));
1289                     return 0;
1290                 }
1291                 break;
1292             default:
1293                 if (saltlen >= 0 && saltlen < prsactx->min_saltlen) {
1294                     ERR_raise_data(ERR_LIB_PROV,
1295                                    PROV_R_PSS_SALTLEN_TOO_SMALL,
1296                                    "Should be more than %d, "
1297                                    "but would be set to %d",
1298                                    prsactx->min_saltlen, saltlen);
1299                     return 0;
1300                 }
1301             }
1302         }
1303     }
1304
1305     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
1306     if (p != NULL) {
1307         const OSSL_PARAM *propsp =
1308             OSSL_PARAM_locate_const(params,
1309                                     OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES);
1310
1311         pmgf1mdname = mgf1mdname;
1312         if (!OSSL_PARAM_get_utf8_string(p, &pmgf1mdname, sizeof(mgf1mdname)))
1313             return 0;
1314
1315         if (propsp != NULL) {
1316             pmgf1mdprops = mgf1mdprops;
1317             if (!OSSL_PARAM_get_utf8_string(propsp,
1318                                             &pmgf1mdprops, sizeof(mgf1mdprops)))
1319                 return 0;
1320         }
1321
1322         if (pad_mode != RSA_PKCS1_PSS_PADDING) {
1323             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_MGF1_MD);
1324             return  0;
1325         }
1326     }
1327
1328     prsactx->saltlen = saltlen;
1329     prsactx->pad_mode = pad_mode;
1330
1331     if (prsactx->md == NULL && pmdname == NULL
1332         && pad_mode == RSA_PKCS1_PSS_PADDING)
1333         pmdname = RSA_DEFAULT_DIGEST_NAME;
1334
1335     if (pmgf1mdname != NULL
1336         && !rsa_setup_mgf1_md(prsactx, pmgf1mdname, pmgf1mdprops))
1337         return 0;
1338
1339     if (pmdname != NULL) {
1340         if (!rsa_setup_md(prsactx, pmdname, pmdprops))
1341             return 0;
1342     } else {
1343         if (!rsa_check_padding(prsactx, NULL, NULL, prsactx->mdnid))
1344             return 0;
1345     }
1346     return 1;
1347 }
1348
1349 static const OSSL_PARAM settable_ctx_params[] = {
1350     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
1351     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PROPERTIES, NULL, 0),
1352     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
1353     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
1354     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES, NULL, 0),
1355     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
1356     OSSL_PARAM_END
1357 };
1358
1359 static const OSSL_PARAM settable_ctx_params_no_digest[] = {
1360     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
1361     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
1362     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES, NULL, 0),
1363     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
1364     OSSL_PARAM_END
1365 };
1366
1367 static const OSSL_PARAM *rsa_settable_ctx_params(void *vprsactx,
1368                                                  ossl_unused void *provctx)
1369 {
1370     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1371
1372     if (prsactx != NULL && !prsactx->flag_allow_md)
1373         return settable_ctx_params_no_digest;
1374     return settable_ctx_params;
1375 }
1376
1377 static int rsa_get_ctx_md_params(void *vprsactx, OSSL_PARAM *params)
1378 {
1379     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1380
1381     if (prsactx->mdctx == NULL)
1382         return 0;
1383
1384     return EVP_MD_CTX_get_params(prsactx->mdctx, params);
1385 }
1386
1387 static const OSSL_PARAM *rsa_gettable_ctx_md_params(void *vprsactx)
1388 {
1389     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1390
1391     if (prsactx->md == NULL)
1392         return 0;
1393
1394     return EVP_MD_gettable_ctx_params(prsactx->md);
1395 }
1396
1397 static int rsa_set_ctx_md_params(void *vprsactx, const OSSL_PARAM params[])
1398 {
1399     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1400
1401     if (prsactx->mdctx == NULL)
1402         return 0;
1403
1404     return EVP_MD_CTX_set_params(prsactx->mdctx, params);
1405 }
1406
1407 static const OSSL_PARAM *rsa_settable_ctx_md_params(void *vprsactx)
1408 {
1409     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1410
1411     if (prsactx->md == NULL)
1412         return 0;
1413
1414     return EVP_MD_settable_ctx_params(prsactx->md);
1415 }
1416
1417 const OSSL_DISPATCH ossl_rsa_signature_functions[] = {
1418     { OSSL_FUNC_SIGNATURE_NEWCTX, (void (*)(void))rsa_newctx },
1419     { OSSL_FUNC_SIGNATURE_SIGN_INIT, (void (*)(void))rsa_sign_init },
1420     { OSSL_FUNC_SIGNATURE_SIGN, (void (*)(void))rsa_sign },
1421     { OSSL_FUNC_SIGNATURE_VERIFY_INIT, (void (*)(void))rsa_verify_init },
1422     { OSSL_FUNC_SIGNATURE_VERIFY, (void (*)(void))rsa_verify },
1423     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER_INIT,
1424       (void (*)(void))rsa_verify_recover_init },
1425     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER,
1426       (void (*)(void))rsa_verify_recover },
1427     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_INIT,
1428       (void (*)(void))rsa_digest_sign_init },
1429     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_UPDATE,
1430       (void (*)(void))rsa_digest_signverify_update },
1431     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_FINAL,
1432       (void (*)(void))rsa_digest_sign_final },
1433     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_INIT,
1434       (void (*)(void))rsa_digest_verify_init },
1435     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_UPDATE,
1436       (void (*)(void))rsa_digest_signverify_update },
1437     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_FINAL,
1438       (void (*)(void))rsa_digest_verify_final },
1439     { OSSL_FUNC_SIGNATURE_FREECTX, (void (*)(void))rsa_freectx },
1440     { OSSL_FUNC_SIGNATURE_DUPCTX, (void (*)(void))rsa_dupctx },
1441     { OSSL_FUNC_SIGNATURE_GET_CTX_PARAMS, (void (*)(void))rsa_get_ctx_params },
1442     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_PARAMS,
1443       (void (*)(void))rsa_gettable_ctx_params },
1444     { OSSL_FUNC_SIGNATURE_SET_CTX_PARAMS, (void (*)(void))rsa_set_ctx_params },
1445     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_PARAMS,
1446       (void (*)(void))rsa_settable_ctx_params },
1447     { OSSL_FUNC_SIGNATURE_GET_CTX_MD_PARAMS,
1448       (void (*)(void))rsa_get_ctx_md_params },
1449     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_MD_PARAMS,
1450       (void (*)(void))rsa_gettable_ctx_md_params },
1451     { OSSL_FUNC_SIGNATURE_SET_CTX_MD_PARAMS,
1452       (void (*)(void))rsa_set_ctx_md_params },
1453     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_MD_PARAMS,
1454       (void (*)(void))rsa_settable_ctx_md_params },
1455     { 0, NULL }
1456 };