Remove pkey_downgrade from PKCS7 code
[openssl.git] / crypto / rsa / rsa_ameth.c
1 /*
2  * Copyright 2006-2020 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 /*
11  * RSA low level APIs are deprecated for public use, but still ok for
12  * internal use.
13  */
14 #include "internal/deprecated.h"
15
16 #include <stdio.h>
17 #include "internal/cryptlib.h"
18 #include <openssl/asn1t.h>
19 #include <openssl/x509.h>
20 #include <openssl/bn.h>
21 #include <openssl/core_names.h>
22 #include <openssl/param_build.h>
23 #include "crypto/asn1.h"
24 #include "crypto/evp.h"
25 #include "crypto/rsa.h"
26 #include "rsa_local.h"
27
28 static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg);
29 static int rsa_sync_to_pss_params_30(RSA *rsa);
30
31 /* Set any parameters associated with pkey */
32 static int rsa_param_encode(const EVP_PKEY *pkey,
33                             ASN1_STRING **pstr, int *pstrtype)
34 {
35     const RSA *rsa = pkey->pkey.rsa;
36
37     *pstr = NULL;
38     /* If RSA it's just NULL type */
39     if (RSA_test_flags(rsa, RSA_FLAG_TYPE_MASK) != RSA_FLAG_TYPE_RSASSAPSS) {
40         *pstrtype = V_ASN1_NULL;
41         return 1;
42     }
43     /* If no PSS parameters we omit parameters entirely */
44     if (rsa->pss == NULL) {
45         *pstrtype = V_ASN1_UNDEF;
46         return 1;
47     }
48     /* Encode PSS parameters */
49     if (ASN1_item_pack(rsa->pss, ASN1_ITEM_rptr(RSA_PSS_PARAMS), pstr) == NULL)
50         return 0;
51
52     *pstrtype = V_ASN1_SEQUENCE;
53     return 1;
54 }
55 /* Decode any parameters and set them in RSA structure */
56 static int rsa_param_decode(RSA *rsa, const X509_ALGOR *alg)
57 {
58     const ASN1_OBJECT *algoid;
59     const void *algp;
60     int algptype;
61
62     X509_ALGOR_get0(&algoid, &algptype, &algp, alg);
63     if (OBJ_obj2nid(algoid) != EVP_PKEY_RSA_PSS)
64         return 1;
65     if (algptype == V_ASN1_UNDEF)
66         return 1;
67     if (algptype != V_ASN1_SEQUENCE) {
68         ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_PSS_PARAMETERS);
69         return 0;
70     }
71     rsa->pss = rsa_pss_decode(alg);
72     if (rsa->pss == NULL)
73         return 0;
74     if (!rsa_sync_to_pss_params_30(rsa))
75         return 0;
76     return 1;
77 }
78
79 static int rsa_pub_encode(X509_PUBKEY *pk, const EVP_PKEY *pkey)
80 {
81     unsigned char *penc = NULL;
82     int penclen;
83     ASN1_STRING *str;
84     int strtype;
85
86     if (!rsa_param_encode(pkey, &str, &strtype))
87         return 0;
88     penclen = i2d_RSAPublicKey(pkey->pkey.rsa, &penc);
89     if (penclen <= 0)
90         return 0;
91     if (X509_PUBKEY_set0_param(pk, OBJ_nid2obj(pkey->ameth->pkey_id),
92                                strtype, str, penc, penclen))
93         return 1;
94
95     OPENSSL_free(penc);
96     return 0;
97 }
98
99 static int rsa_pub_decode(EVP_PKEY *pkey, const X509_PUBKEY *pubkey)
100 {
101     const unsigned char *p;
102     int pklen;
103     X509_ALGOR *alg;
104     RSA *rsa = NULL;
105
106     if (!X509_PUBKEY_get0_param(NULL, &p, &pklen, &alg, pubkey))
107         return 0;
108     if ((rsa = d2i_RSAPublicKey(NULL, &p, pklen)) == NULL)
109         return 0;
110     if (!rsa_param_decode(rsa, alg)) {
111         RSA_free(rsa);
112         return 0;
113     }
114
115     RSA_clear_flags(rsa, RSA_FLAG_TYPE_MASK);
116     switch (pkey->ameth->pkey_id) {
117     case EVP_PKEY_RSA:
118         RSA_set_flags(rsa, RSA_FLAG_TYPE_RSA);
119         break;
120     case EVP_PKEY_RSA_PSS:
121         RSA_set_flags(rsa, RSA_FLAG_TYPE_RSASSAPSS);
122         break;
123     default:
124         /* Leave the type bits zero */
125         break;
126     }
127
128     if (!EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa)) {
129         RSA_free(rsa);
130         return 0;
131     }
132     return 1;
133 }
134
135 static int rsa_pub_cmp(const EVP_PKEY *a, const EVP_PKEY *b)
136 {
137     /*
138      * Don't check the public/private key, this is mostly for smart
139      * cards.
140      */
141     if (((RSA_flags(a->pkey.rsa) & RSA_METHOD_FLAG_NO_CHECK))
142             || (RSA_flags(b->pkey.rsa) & RSA_METHOD_FLAG_NO_CHECK)) {
143         return 1;
144     }
145
146     if (BN_cmp(b->pkey.rsa->n, a->pkey.rsa->n) != 0
147         || BN_cmp(b->pkey.rsa->e, a->pkey.rsa->e) != 0)
148         return 0;
149     return 1;
150 }
151
152 static int old_rsa_priv_decode(EVP_PKEY *pkey,
153                                const unsigned char **pder, int derlen)
154 {
155     RSA *rsa;
156
157     if ((rsa = d2i_RSAPrivateKey(NULL, pder, derlen)) == NULL)
158         return 0;
159     EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa);
160     return 1;
161 }
162
163 static int old_rsa_priv_encode(const EVP_PKEY *pkey, unsigned char **pder)
164 {
165     return i2d_RSAPrivateKey(pkey->pkey.rsa, pder);
166 }
167
168 static int rsa_priv_encode(PKCS8_PRIV_KEY_INFO *p8, const EVP_PKEY *pkey)
169 {
170     unsigned char *rk = NULL;
171     int rklen;
172     ASN1_STRING *str;
173     int strtype;
174
175     if (!rsa_param_encode(pkey, &str, &strtype))
176         return 0;
177     rklen = i2d_RSAPrivateKey(pkey->pkey.rsa, &rk);
178
179     if (rklen <= 0) {
180         ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
181         ASN1_STRING_free(str);
182         return 0;
183     }
184
185     if (!PKCS8_pkey_set0(p8, OBJ_nid2obj(pkey->ameth->pkey_id), 0,
186                          strtype, str, rk, rklen)) {
187         ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
188         ASN1_STRING_free(str);
189         return 0;
190     }
191
192     return 1;
193 }
194
195 static int rsa_priv_decode(EVP_PKEY *pkey, const PKCS8_PRIV_KEY_INFO *p8)
196 {
197     const unsigned char *p;
198     RSA *rsa;
199     int pklen;
200     const X509_ALGOR *alg;
201
202     if (!PKCS8_pkey_get0(NULL, &p, &pklen, &alg, p8))
203         return 0;
204     rsa = d2i_RSAPrivateKey(NULL, &p, pklen);
205     if (rsa == NULL) {
206         ERR_raise(ERR_LIB_RSA, ERR_R_RSA_LIB);
207         return 0;
208     }
209     if (!rsa_param_decode(rsa, alg)) {
210         RSA_free(rsa);
211         return 0;
212     }
213
214     RSA_clear_flags(rsa, RSA_FLAG_TYPE_MASK);
215     switch (pkey->ameth->pkey_id) {
216     case EVP_PKEY_RSA:
217         RSA_set_flags(rsa, RSA_FLAG_TYPE_RSA);
218         break;
219     case EVP_PKEY_RSA_PSS:
220         RSA_set_flags(rsa, RSA_FLAG_TYPE_RSASSAPSS);
221         break;
222     default:
223         /* Leave the type bits zero */
224         break;
225     }
226
227     EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa);
228     return 1;
229 }
230
231 static int int_rsa_size(const EVP_PKEY *pkey)
232 {
233     return RSA_size(pkey->pkey.rsa);
234 }
235
236 static int rsa_bits(const EVP_PKEY *pkey)
237 {
238     return BN_num_bits(pkey->pkey.rsa->n);
239 }
240
241 static int rsa_security_bits(const EVP_PKEY *pkey)
242 {
243     return RSA_security_bits(pkey->pkey.rsa);
244 }
245
246 static void int_rsa_free(EVP_PKEY *pkey)
247 {
248     RSA_free(pkey->pkey.rsa);
249 }
250
251 static int rsa_pss_param_print(BIO *bp, int pss_key, RSA_PSS_PARAMS *pss,
252                                int indent)
253 {
254     int rv = 0;
255     X509_ALGOR *maskHash = NULL;
256
257     if (!BIO_indent(bp, indent, 128))
258         goto err;
259     if (pss_key) {
260         if (pss == NULL) {
261             if (BIO_puts(bp, "No PSS parameter restrictions\n") <= 0)
262                 return 0;
263             return 1;
264         } else {
265             if (BIO_puts(bp, "PSS parameter restrictions:") <= 0)
266                 return 0;
267         }
268     } else if (pss == NULL) {
269         if (BIO_puts(bp,"(INVALID PSS PARAMETERS)\n") <= 0)
270             return 0;
271         return 1;
272     }
273     if (BIO_puts(bp, "\n") <= 0)
274         goto err;
275     if (pss_key)
276         indent += 2;
277     if (!BIO_indent(bp, indent, 128))
278         goto err;
279     if (BIO_puts(bp, "Hash Algorithm: ") <= 0)
280         goto err;
281
282     if (pss->hashAlgorithm) {
283         if (i2a_ASN1_OBJECT(bp, pss->hashAlgorithm->algorithm) <= 0)
284             goto err;
285     } else if (BIO_puts(bp, "sha1 (default)") <= 0) {
286         goto err;
287     }
288
289     if (BIO_puts(bp, "\n") <= 0)
290         goto err;
291
292     if (!BIO_indent(bp, indent, 128))
293         goto err;
294
295     if (BIO_puts(bp, "Mask Algorithm: ") <= 0)
296         goto err;
297     if (pss->maskGenAlgorithm) {
298         if (i2a_ASN1_OBJECT(bp, pss->maskGenAlgorithm->algorithm) <= 0)
299             goto err;
300         if (BIO_puts(bp, " with ") <= 0)
301             goto err;
302         maskHash = x509_algor_mgf1_decode(pss->maskGenAlgorithm);
303         if (maskHash != NULL) {
304             if (i2a_ASN1_OBJECT(bp, maskHash->algorithm) <= 0)
305                 goto err;
306         } else if (BIO_puts(bp, "INVALID") <= 0) {
307             goto err;
308         }
309     } else if (BIO_puts(bp, "mgf1 with sha1 (default)") <= 0) {
310         goto err;
311     }
312     BIO_puts(bp, "\n");
313
314     if (!BIO_indent(bp, indent, 128))
315         goto err;
316     if (BIO_printf(bp, "%s Salt Length: 0x", pss_key ? "Minimum" : "") <= 0)
317         goto err;
318     if (pss->saltLength) {
319         if (i2a_ASN1_INTEGER(bp, pss->saltLength) <= 0)
320             goto err;
321     } else if (BIO_puts(bp, "14 (default)") <= 0) {
322         goto err;
323     }
324     BIO_puts(bp, "\n");
325
326     if (!BIO_indent(bp, indent, 128))
327         goto err;
328     if (BIO_puts(bp, "Trailer Field: 0x") <= 0)
329         goto err;
330     if (pss->trailerField) {
331         if (i2a_ASN1_INTEGER(bp, pss->trailerField) <= 0)
332             goto err;
333     } else if (BIO_puts(bp, "BC (default)") <= 0) {
334         goto err;
335     }
336     BIO_puts(bp, "\n");
337
338     rv = 1;
339
340  err:
341     X509_ALGOR_free(maskHash);
342     return rv;
343
344 }
345
346 static int pkey_rsa_print(BIO *bp, const EVP_PKEY *pkey, int off, int priv)
347 {
348     const RSA *x = pkey->pkey.rsa;
349     char *str;
350     const char *s;
351     int ret = 0, mod_len = 0, ex_primes;
352
353     if (x->n != NULL)
354         mod_len = BN_num_bits(x->n);
355     ex_primes = sk_RSA_PRIME_INFO_num(x->prime_infos);
356
357     if (!BIO_indent(bp, off, 128))
358         goto err;
359
360     if (BIO_printf(bp, "%s ", pkey_is_pss(pkey) ?  "RSA-PSS" : "RSA") <= 0)
361         goto err;
362
363     if (priv && x->d) {
364         if (BIO_printf(bp, "Private-Key: (%d bit, %d primes)\n",
365                        mod_len, ex_primes <= 0 ? 2 : ex_primes + 2) <= 0)
366             goto err;
367         str = "modulus:";
368         s = "publicExponent:";
369     } else {
370         if (BIO_printf(bp, "Public-Key: (%d bit)\n", mod_len) <= 0)
371             goto err;
372         str = "Modulus:";
373         s = "Exponent:";
374     }
375     if (!ASN1_bn_print(bp, str, x->n, NULL, off))
376         goto err;
377     if (!ASN1_bn_print(bp, s, x->e, NULL, off))
378         goto err;
379     if (priv) {
380         int i;
381
382         if (!ASN1_bn_print(bp, "privateExponent:", x->d, NULL, off))
383             goto err;
384         if (!ASN1_bn_print(bp, "prime1:", x->p, NULL, off))
385             goto err;
386         if (!ASN1_bn_print(bp, "prime2:", x->q, NULL, off))
387             goto err;
388         if (!ASN1_bn_print(bp, "exponent1:", x->dmp1, NULL, off))
389             goto err;
390         if (!ASN1_bn_print(bp, "exponent2:", x->dmq1, NULL, off))
391             goto err;
392         if (!ASN1_bn_print(bp, "coefficient:", x->iqmp, NULL, off))
393             goto err;
394         for (i = 0; i < sk_RSA_PRIME_INFO_num(x->prime_infos); i++) {
395             /* print multi-prime info */
396             BIGNUM *bn = NULL;
397             RSA_PRIME_INFO *pinfo;
398             int j;
399
400             pinfo = sk_RSA_PRIME_INFO_value(x->prime_infos, i);
401             for (j = 0; j < 3; j++) {
402                 if (!BIO_indent(bp, off, 128))
403                     goto err;
404                 switch (j) {
405                 case 0:
406                     if (BIO_printf(bp, "prime%d:", i + 3) <= 0)
407                         goto err;
408                     bn = pinfo->r;
409                     break;
410                 case 1:
411                     if (BIO_printf(bp, "exponent%d:", i + 3) <= 0)
412                         goto err;
413                     bn = pinfo->d;
414                     break;
415                 case 2:
416                     if (BIO_printf(bp, "coefficient%d:", i + 3) <= 0)
417                         goto err;
418                     bn = pinfo->t;
419                     break;
420                 default:
421                     break;
422                 }
423                 if (!ASN1_bn_print(bp, "", bn, NULL, off))
424                     goto err;
425             }
426         }
427     }
428     if (pkey_is_pss(pkey) && !rsa_pss_param_print(bp, 1, x->pss, off))
429         goto err;
430     ret = 1;
431  err:
432     return ret;
433 }
434
435 static int rsa_pub_print(BIO *bp, const EVP_PKEY *pkey, int indent,
436                          ASN1_PCTX *ctx)
437 {
438     return pkey_rsa_print(bp, pkey, indent, 0);
439 }
440
441 static int rsa_priv_print(BIO *bp, const EVP_PKEY *pkey, int indent,
442                           ASN1_PCTX *ctx)
443 {
444     return pkey_rsa_print(bp, pkey, indent, 1);
445 }
446
447 static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg)
448 {
449     RSA_PSS_PARAMS *pss;
450
451     pss = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_PSS_PARAMS),
452                                     alg->parameter);
453
454     if (pss == NULL)
455         return NULL;
456
457     if (pss->maskGenAlgorithm != NULL) {
458         pss->maskHash = x509_algor_mgf1_decode(pss->maskGenAlgorithm);
459         if (pss->maskHash == NULL) {
460             RSA_PSS_PARAMS_free(pss);
461             return NULL;
462         }
463     }
464
465     return pss;
466 }
467
468 static int rsa_sig_print(BIO *bp, const X509_ALGOR *sigalg,
469                          const ASN1_STRING *sig, int indent, ASN1_PCTX *pctx)
470 {
471     if (OBJ_obj2nid(sigalg->algorithm) == EVP_PKEY_RSA_PSS) {
472         int rv;
473         RSA_PSS_PARAMS *pss = rsa_pss_decode(sigalg);
474
475         rv = rsa_pss_param_print(bp, 0, pss, indent);
476         RSA_PSS_PARAMS_free(pss);
477         if (!rv)
478             return 0;
479     } else if (BIO_puts(bp, "\n") <= 0) {
480         return 0;
481     }
482     if (sig)
483         return X509_signature_dump(bp, sig, indent);
484     return 1;
485 }
486
487 static int rsa_pkey_ctrl(EVP_PKEY *pkey, int op, long arg1, void *arg2)
488 {
489     X509_ALGOR *alg = NULL;
490     const EVP_MD *md;
491     const EVP_MD *mgf1md;
492     int min_saltlen;
493
494     switch (op) {
495     case ASN1_PKEY_CTRL_DEFAULT_MD_NID:
496         if (pkey->pkey.rsa->pss != NULL) {
497             if (!rsa_pss_get_param(pkey->pkey.rsa->pss, &md, &mgf1md,
498                                    &min_saltlen)) {
499                 ERR_raise(ERR_LIB_RSA, ERR_R_INTERNAL_ERROR);
500                 return 0;
501             }
502             *(int *)arg2 = EVP_MD_type(md);
503             /* Return of 2 indicates this MD is mandatory */
504             return 2;
505         }
506         *(int *)arg2 = NID_sha256;
507         return 1;
508
509     default:
510         return -2;
511
512     }
513
514     if (alg)
515         X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsaEncryption), V_ASN1_NULL, 0);
516
517     return 1;
518
519 }
520
521 /*
522  * Convert EVP_PKEY_CTX in PSS mode into corresponding algorithm parameter,
523  * suitable for setting an AlgorithmIdentifier.
524  */
525
526 static RSA_PSS_PARAMS *rsa_ctx_to_pss(EVP_PKEY_CTX *pkctx)
527 {
528     const EVP_MD *sigmd, *mgf1md;
529     EVP_PKEY *pk = EVP_PKEY_CTX_get0_pkey(pkctx);
530     int saltlen;
531
532     if (EVP_PKEY_CTX_get_signature_md(pkctx, &sigmd) <= 0)
533         return NULL;
534     if (EVP_PKEY_CTX_get_rsa_mgf1_md(pkctx, &mgf1md) <= 0)
535         return NULL;
536     if (!EVP_PKEY_CTX_get_rsa_pss_saltlen(pkctx, &saltlen))
537         return NULL;
538     if (saltlen == -1) {
539         saltlen = EVP_MD_size(sigmd);
540     } else if (saltlen == -2 || saltlen == -3) {
541         saltlen = EVP_PKEY_size(pk) - EVP_MD_size(sigmd) - 2;
542         if ((EVP_PKEY_bits(pk) & 0x7) == 1)
543             saltlen--;
544         if (saltlen < 0)
545             return NULL;
546     }
547
548     return rsa_pss_params_create(sigmd, mgf1md, saltlen);
549 }
550
551 RSA_PSS_PARAMS *rsa_pss_params_create(const EVP_MD *sigmd,
552                                       const EVP_MD *mgf1md, int saltlen)
553 {
554     RSA_PSS_PARAMS *pss = RSA_PSS_PARAMS_new();
555
556     if (pss == NULL)
557         goto err;
558     if (saltlen != 20) {
559         pss->saltLength = ASN1_INTEGER_new();
560         if (pss->saltLength == NULL)
561             goto err;
562         if (!ASN1_INTEGER_set(pss->saltLength, saltlen))
563             goto err;
564     }
565     if (!x509_algor_new_from_md(&pss->hashAlgorithm, sigmd))
566         goto err;
567     if (mgf1md == NULL)
568         mgf1md = sigmd;
569     if (!x509_algor_md_to_mgf1(&pss->maskGenAlgorithm, mgf1md))
570         goto err;
571     if (!x509_algor_new_from_md(&pss->maskHash, mgf1md))
572         goto err;
573     return pss;
574  err:
575     RSA_PSS_PARAMS_free(pss);
576     return NULL;
577 }
578
579 ASN1_STRING *ossl_rsa_ctx_to_pss_string(EVP_PKEY_CTX *pkctx)
580 {
581     RSA_PSS_PARAMS *pss = rsa_ctx_to_pss(pkctx);
582     ASN1_STRING *os;
583
584     if (pss == NULL)
585         return NULL;
586
587     os = ASN1_item_pack(pss, ASN1_ITEM_rptr(RSA_PSS_PARAMS), NULL);
588     RSA_PSS_PARAMS_free(pss);
589     return os;
590 }
591
592 /*
593  * From PSS AlgorithmIdentifier set public key parameters. If pkey isn't NULL
594  * then the EVP_MD_CTX is setup and initialised. If it is NULL parameters are
595  * passed to pkctx instead.
596  */
597
598 int ossl_rsa_pss_to_ctx(EVP_MD_CTX *ctx, EVP_PKEY_CTX *pkctx,
599                         const X509_ALGOR *sigalg, EVP_PKEY *pkey)
600 {
601     int rv = -1;
602     int saltlen;
603     const EVP_MD *mgf1md = NULL, *md = NULL;
604     RSA_PSS_PARAMS *pss;
605
606     /* Sanity check: make sure it is PSS */
607     if (OBJ_obj2nid(sigalg->algorithm) != EVP_PKEY_RSA_PSS) {
608         ERR_raise(ERR_LIB_RSA, RSA_R_UNSUPPORTED_SIGNATURE_TYPE);
609         return -1;
610     }
611     /* Decode PSS parameters */
612     pss = rsa_pss_decode(sigalg);
613
614     if (!rsa_pss_get_param(pss, &md, &mgf1md, &saltlen)) {
615         ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_PSS_PARAMETERS);
616         goto err;
617     }
618
619     /* We have all parameters now set up context */
620     if (pkey) {
621         if (!EVP_DigestVerifyInit(ctx, &pkctx, md, NULL, pkey))
622             goto err;
623     } else {
624         const EVP_MD *checkmd;
625         if (EVP_PKEY_CTX_get_signature_md(pkctx, &checkmd) <= 0)
626             goto err;
627         if (EVP_MD_type(md) != EVP_MD_type(checkmd)) {
628             ERR_raise(ERR_LIB_RSA, RSA_R_DIGEST_DOES_NOT_MATCH);
629             goto err;
630         }
631     }
632
633     if (EVP_PKEY_CTX_set_rsa_padding(pkctx, RSA_PKCS1_PSS_PADDING) <= 0)
634         goto err;
635
636     if (EVP_PKEY_CTX_set_rsa_pss_saltlen(pkctx, saltlen) <= 0)
637         goto err;
638
639     if (EVP_PKEY_CTX_set_rsa_mgf1_md(pkctx, mgf1md) <= 0)
640         goto err;
641     /* Carry on */
642     rv = 1;
643
644  err:
645     RSA_PSS_PARAMS_free(pss);
646     return rv;
647 }
648
649 static int rsa_pss_verify_param(const EVP_MD **pmd, const EVP_MD **pmgf1md,
650                                 int *psaltlen, int *ptrailerField)
651 {
652     if (psaltlen != NULL && *psaltlen < 0) {
653         ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_SALT_LENGTH);
654         return 0;
655     }
656     /*
657      * low-level routines support only trailer field 0xbc (value 1) and
658      * PKCS#1 says we should reject any other value anyway.
659      */
660     if (ptrailerField != NULL && *ptrailerField != 1) {
661         ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_TRAILER);
662         return 0;
663     }
664     return 1;
665 }
666
667 static int rsa_pss_get_param_unverified(const RSA_PSS_PARAMS *pss,
668                                         const EVP_MD **pmd,
669                                         const EVP_MD **pmgf1md,
670                                         int *psaltlen, int *ptrailerField)
671 {
672     RSA_PSS_PARAMS_30 pss_params;
673
674     /* Get the defaults from the ONE place */
675     (void)ossl_rsa_pss_params_30_set_defaults(&pss_params);
676
677     if (pss == NULL)
678         return 0;
679     *pmd = x509_algor_get_md(pss->hashAlgorithm);
680     if (*pmd == NULL)
681         return 0;
682     *pmgf1md = x509_algor_get_md(pss->maskHash);
683     if (*pmgf1md == NULL)
684         return 0;
685     if (pss->saltLength)
686         *psaltlen = ASN1_INTEGER_get(pss->saltLength);
687     else
688         *psaltlen = ossl_rsa_pss_params_30_saltlen(&pss_params);
689     if (pss->trailerField)
690         *ptrailerField = ASN1_INTEGER_get(pss->trailerField);
691     else
692         *ptrailerField = ossl_rsa_pss_params_30_trailerfield(&pss_params);;
693
694     return 1;
695 }
696
697 int rsa_pss_get_param(const RSA_PSS_PARAMS *pss, const EVP_MD **pmd,
698                       const EVP_MD **pmgf1md, int *psaltlen)
699 {
700     /*
701      * Callers do not care about the trailer field, and yet, we must
702      * pass it from get_param to verify_param, since the latter checks
703      * its value.
704      *
705      * When callers start caring, it's a simple thing to add another
706      * argument to this function.
707      */
708     int trailerField = 0;
709
710     return rsa_pss_get_param_unverified(pss, pmd, pmgf1md, psaltlen,
711                                         &trailerField)
712         && rsa_pss_verify_param(pmd, pmgf1md, psaltlen, &trailerField);
713 }
714
715 static int rsa_sync_to_pss_params_30(RSA *rsa)
716 {
717     if (rsa != NULL && rsa->pss != NULL) {
718         const EVP_MD *md = NULL, *mgf1md = NULL;
719         int md_nid, mgf1md_nid, saltlen, trailerField;
720         RSA_PSS_PARAMS_30 pss_params;
721
722         /*
723          * We don't care about the validity of the fields here, we just
724          * want to synchronise values.  Verifying here makes it impossible
725          * to even read a key with invalid values, making it hard to test
726          * a bad situation.
727          *
728          * Other routines use rsa_pss_get_param(), so the values will be
729          * checked, eventually.
730          */
731         if (!rsa_pss_get_param_unverified(rsa->pss, &md, &mgf1md,
732                                           &saltlen, &trailerField))
733             return 0;
734         md_nid = EVP_MD_type(md);
735         mgf1md_nid = EVP_MD_type(mgf1md);
736         if (!ossl_rsa_pss_params_30_set_defaults(&pss_params)
737             || !ossl_rsa_pss_params_30_set_hashalg(&pss_params, md_nid)
738             || !ossl_rsa_pss_params_30_set_maskgenhashalg(&pss_params,
739                                                           mgf1md_nid)
740             || !ossl_rsa_pss_params_30_set_saltlen(&pss_params, saltlen)
741             || !ossl_rsa_pss_params_30_set_trailerfield(&pss_params,
742                                                         trailerField))
743             return 0;
744         rsa->pss_params = pss_params;
745     }
746     return 1;
747 }
748
749 /*
750  * Customised RSA item verification routine. This is called when a signature
751  * is encountered requiring special handling. We currently only handle PSS.
752  */
753
754 static int rsa_item_verify(EVP_MD_CTX *ctx, const ASN1_ITEM *it,
755                            const void *asn, const X509_ALGOR *sigalg,
756                            const ASN1_BIT_STRING *sig, EVP_PKEY *pkey)
757 {
758     /* Sanity check: make sure it is PSS */
759     if (OBJ_obj2nid(sigalg->algorithm) != EVP_PKEY_RSA_PSS) {
760         ERR_raise(ERR_LIB_RSA, RSA_R_UNSUPPORTED_SIGNATURE_TYPE);
761         return -1;
762     }
763     if (ossl_rsa_pss_to_ctx(ctx, NULL, sigalg, pkey) > 0) {
764         /* Carry on */
765         return 2;
766     }
767     return -1;
768 }
769
770 static int rsa_item_sign(EVP_MD_CTX *ctx, const ASN1_ITEM *it, const void *asn,
771                          X509_ALGOR *alg1, X509_ALGOR *alg2,
772                          ASN1_BIT_STRING *sig)
773 {
774     int pad_mode;
775     EVP_PKEY_CTX *pkctx = EVP_MD_CTX_pkey_ctx(ctx);
776
777     if (EVP_PKEY_CTX_get_rsa_padding(pkctx, &pad_mode) <= 0)
778         return 0;
779     if (pad_mode == RSA_PKCS1_PADDING)
780         return 2;
781     if (pad_mode == RSA_PKCS1_PSS_PADDING) {
782         ASN1_STRING *os1 = NULL;
783         os1 = ossl_rsa_ctx_to_pss_string(pkctx);
784         if (!os1)
785             return 0;
786         /* Duplicate parameters if we have to */
787         if (alg2) {
788             ASN1_STRING *os2 = ASN1_STRING_dup(os1);
789             if (!os2) {
790                 ASN1_STRING_free(os1);
791                 return 0;
792             }
793             X509_ALGOR_set0(alg2, OBJ_nid2obj(EVP_PKEY_RSA_PSS),
794                             V_ASN1_SEQUENCE, os2);
795         }
796         X509_ALGOR_set0(alg1, OBJ_nid2obj(EVP_PKEY_RSA_PSS),
797                         V_ASN1_SEQUENCE, os1);
798         return 3;
799     }
800     return 2;
801 }
802
803 static int rsa_sig_info_set(X509_SIG_INFO *siginf, const X509_ALGOR *sigalg,
804                             const ASN1_STRING *sig)
805 {
806     int rv = 0;
807     int mdnid, saltlen;
808     uint32_t flags;
809     const EVP_MD *mgf1md = NULL, *md = NULL;
810     RSA_PSS_PARAMS *pss;
811     int secbits;
812
813     /* Sanity check: make sure it is PSS */
814     if (OBJ_obj2nid(sigalg->algorithm) != EVP_PKEY_RSA_PSS)
815         return 0;
816     /* Decode PSS parameters */
817     pss = rsa_pss_decode(sigalg);
818     if (!rsa_pss_get_param(pss, &md, &mgf1md, &saltlen))
819         goto err;
820     mdnid = EVP_MD_type(md);
821     /*
822      * For TLS need SHA256, SHA384 or SHA512, digest and MGF1 digest must
823      * match and salt length must equal digest size
824      */
825     if ((mdnid == NID_sha256 || mdnid == NID_sha384 || mdnid == NID_sha512)
826             && mdnid == EVP_MD_type(mgf1md) && saltlen == EVP_MD_size(md))
827         flags = X509_SIG_INFO_TLS;
828     else
829         flags = 0;
830     /* Note: security bits half number of digest bits */
831     secbits = EVP_MD_size(md) * 4;
832     /*
833      * SHA1 and MD5 are known to be broken. Reduce security bits so that
834      * they're no longer accepted at security level 1. The real values don't
835      * really matter as long as they're lower than 80, which is our security
836      * level 1.
837      * https://eprint.iacr.org/2020/014 puts a chosen-prefix attack for SHA1 at
838      * 2^63.4
839      * https://documents.epfl.ch/users/l/le/lenstra/public/papers/lat.pdf
840      * puts a chosen-prefix attack for MD5 at 2^39.
841      */
842     if (mdnid == NID_sha1)
843         secbits = 64;
844     else if (mdnid == NID_md5_sha1)
845         secbits = 68;
846     else if (mdnid == NID_md5)
847         secbits = 39;
848     X509_SIG_INFO_set(siginf, mdnid, EVP_PKEY_RSA_PSS, secbits,
849                       flags);
850     rv = 1;
851     err:
852     RSA_PSS_PARAMS_free(pss);
853     return rv;
854 }
855
856 static int rsa_pkey_check(const EVP_PKEY *pkey)
857 {
858     return RSA_check_key_ex(pkey->pkey.rsa, NULL);
859 }
860
861 static size_t rsa_pkey_dirty_cnt(const EVP_PKEY *pkey)
862 {
863     return pkey->pkey.rsa->dirty_cnt;
864 }
865
866 /*
867  * For the moment, we trust the call path, where keys going through
868  * rsa_pkey_export_to() match a KEYMGMT for the "RSA" keytype, while
869  * keys going through rsa_pss_pkey_export_to() match a KEYMGMT for the
870  * "RSA-PSS" keytype.
871  * TODO(3.0) Investigate whether we should simply continue to trust the
872  * call path, or if we should strengthen this function by checking that
873  * |rsa_type| matches the RSA key subtype.  The latter requires ensuring
874  * that the type flag for the RSA key is properly set by other functions
875  * in this file.
876  */
877 static int rsa_int_export_to(const EVP_PKEY *from, int rsa_type,
878                              void *to_keydata, EVP_KEYMGMT *to_keymgmt,
879                              OSSL_LIB_CTX *libctx, const char *propq)
880 {
881     RSA *rsa = from->pkey.rsa;
882     OSSL_PARAM_BLD *tmpl = OSSL_PARAM_BLD_new();
883     OSSL_PARAM *params = NULL;
884     int selection = 0;
885     int rv = 0;
886
887     if (tmpl == NULL)
888         return 0;
889     /*
890      * If the RSA method is foreign, then we can't be sure of anything, and
891      * can therefore not export or pretend to export.
892      */
893     if (RSA_get_method(rsa) != RSA_PKCS1_OpenSSL())
894         goto err;
895
896     /* Public parameters must always be present */
897     if (RSA_get0_n(rsa) == NULL || RSA_get0_e(rsa) == NULL)
898         goto err;
899
900     if (!ossl_rsa_todata(rsa, tmpl, NULL))
901         goto err;
902
903     selection |= OSSL_KEYMGMT_SELECT_PUBLIC_KEY;
904     if (RSA_get0_d(rsa) != NULL)
905         selection |= OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
906
907     if (rsa->pss != NULL) {
908         const EVP_MD *md = NULL, *mgf1md = NULL;
909         int md_nid, mgf1md_nid, saltlen, trailerfield;
910         RSA_PSS_PARAMS_30 pss_params;
911
912         if (!rsa_pss_get_param_unverified(rsa->pss, &md, &mgf1md,
913                                           &saltlen, &trailerfield))
914             goto err;
915         md_nid = EVP_MD_type(md);
916         mgf1md_nid = EVP_MD_type(mgf1md);
917         if (!ossl_rsa_pss_params_30_set_defaults(&pss_params)
918             || !ossl_rsa_pss_params_30_set_hashalg(&pss_params, md_nid)
919             || !ossl_rsa_pss_params_30_set_maskgenhashalg(&pss_params,
920                                                           mgf1md_nid)
921             || !ossl_rsa_pss_params_30_set_saltlen(&pss_params, saltlen)
922             || !ossl_rsa_pss_params_30_todata(&pss_params, tmpl, NULL))
923             goto err;
924         selection |= OSSL_KEYMGMT_SELECT_OTHER_PARAMETERS;
925     }
926
927     if ((params = OSSL_PARAM_BLD_to_param(tmpl)) == NULL)
928         goto err;
929
930     /* We export, the provider imports */
931     rv = evp_keymgmt_import(to_keymgmt, to_keydata, selection, params);
932
933  err:
934     OSSL_PARAM_BLD_free_params(params);
935     OSSL_PARAM_BLD_free(tmpl);
936     return rv;
937 }
938
939 static int rsa_int_import_from(const OSSL_PARAM params[], void *vpctx,
940                                int rsa_type)
941 {
942     EVP_PKEY_CTX *pctx = vpctx;
943     EVP_PKEY *pkey = EVP_PKEY_CTX_get0_pkey(pctx);
944     RSA *rsa = ossl_rsa_new_with_ctx(pctx->libctx);
945     RSA_PSS_PARAMS_30 rsa_pss_params = { 0, };
946     int ok = 0;
947
948     if (rsa == NULL) {
949         ERR_raise(ERR_LIB_DH, ERR_R_MALLOC_FAILURE);
950         return 0;
951     }
952
953     RSA_clear_flags(rsa, RSA_FLAG_TYPE_MASK);
954     RSA_set_flags(rsa, rsa_type);
955
956     if (!ossl_rsa_pss_params_30_fromdata(&rsa_pss_params, params, pctx->libctx))
957         goto err;
958
959     switch (rsa_type) {
960     case RSA_FLAG_TYPE_RSA:
961         /*
962          * Were PSS parameters filled in?
963          * In that case, something's wrong
964          */
965         if (!ossl_rsa_pss_params_30_is_unrestricted(&rsa_pss_params))
966             goto err;
967         break;
968     case RSA_FLAG_TYPE_RSASSAPSS:
969         /*
970          * Were PSS parameters filled in?  In that case, create the old
971          * RSA_PSS_PARAMS structure.  Otherwise, this is an unrestricted key.
972          */
973         if (!ossl_rsa_pss_params_30_is_unrestricted(&rsa_pss_params)) {
974             /* Create the older RSA_PSS_PARAMS from RSA_PSS_PARAMS_30 data */
975             int mdnid = ossl_rsa_pss_params_30_hashalg(&rsa_pss_params);
976             int mgf1mdnid = ossl_rsa_pss_params_30_maskgenhashalg(&rsa_pss_params);
977             int saltlen = ossl_rsa_pss_params_30_saltlen(&rsa_pss_params);
978             const EVP_MD *md = EVP_get_digestbynid(mdnid);
979             const EVP_MD *mgf1md = EVP_get_digestbynid(mgf1mdnid);
980
981             if ((rsa->pss = rsa_pss_params_create(md, mgf1md, saltlen)) == NULL)
982                 goto err;
983         }
984         break;
985     default:
986         /* RSA key sub-types we don't know how to handle yet */
987         goto err;
988     }
989
990     if (!ossl_rsa_fromdata(rsa, params))
991         goto err;
992
993     switch (rsa_type) {
994     case RSA_FLAG_TYPE_RSA:
995         ok = EVP_PKEY_assign_RSA(pkey, rsa);
996         break;
997     case RSA_FLAG_TYPE_RSASSAPSS:
998         ok = EVP_PKEY_assign(pkey, EVP_PKEY_RSA_PSS, rsa);
999         break;
1000     }
1001
1002  err:
1003     if (!ok)
1004         RSA_free(rsa);
1005     return ok;
1006 }
1007
1008 static int rsa_pkey_export_to(const EVP_PKEY *from, void *to_keydata,
1009                               EVP_KEYMGMT *to_keymgmt, OSSL_LIB_CTX *libctx,
1010                               const char *propq)
1011 {
1012     return rsa_int_export_to(from, RSA_FLAG_TYPE_RSA, to_keydata,
1013                              to_keymgmt, libctx, propq);
1014 }
1015
1016 static int rsa_pss_pkey_export_to(const EVP_PKEY *from, void *to_keydata,
1017                                   EVP_KEYMGMT *to_keymgmt, OSSL_LIB_CTX *libctx,
1018                                   const char *propq)
1019 {
1020     return rsa_int_export_to(from, RSA_FLAG_TYPE_RSASSAPSS, to_keydata,
1021                              to_keymgmt, libctx, propq);
1022 }
1023
1024 static int rsa_pkey_import_from(const OSSL_PARAM params[], void *vpctx)
1025 {
1026     return rsa_int_import_from(params, vpctx, RSA_FLAG_TYPE_RSA);
1027 }
1028
1029 static int rsa_pss_pkey_import_from(const OSSL_PARAM params[], void *vpctx)
1030 {
1031     return rsa_int_import_from(params, vpctx, RSA_FLAG_TYPE_RSASSAPSS);
1032 }
1033
1034 const EVP_PKEY_ASN1_METHOD rsa_asn1_meths[2] = {
1035     {
1036      EVP_PKEY_RSA,
1037      EVP_PKEY_RSA,
1038      ASN1_PKEY_SIGPARAM_NULL,
1039
1040      "RSA",
1041      "OpenSSL RSA method",
1042
1043      rsa_pub_decode,
1044      rsa_pub_encode,
1045      rsa_pub_cmp,
1046      rsa_pub_print,
1047
1048      rsa_priv_decode,
1049      rsa_priv_encode,
1050      rsa_priv_print,
1051
1052      int_rsa_size,
1053      rsa_bits,
1054      rsa_security_bits,
1055
1056      0, 0, 0, 0, 0, 0,
1057
1058      rsa_sig_print,
1059      int_rsa_free,
1060      rsa_pkey_ctrl,
1061      old_rsa_priv_decode,
1062      old_rsa_priv_encode,
1063      rsa_item_verify,
1064      rsa_item_sign,
1065      rsa_sig_info_set,
1066      rsa_pkey_check,
1067
1068      0, 0,
1069      0, 0, 0, 0,
1070
1071      rsa_pkey_dirty_cnt,
1072      rsa_pkey_export_to,
1073      rsa_pkey_import_from
1074     },
1075
1076     {
1077      EVP_PKEY_RSA2,
1078      EVP_PKEY_RSA,
1079      ASN1_PKEY_ALIAS}
1080 };
1081
1082 const EVP_PKEY_ASN1_METHOD rsa_pss_asn1_meth = {
1083      EVP_PKEY_RSA_PSS,
1084      EVP_PKEY_RSA_PSS,
1085      ASN1_PKEY_SIGPARAM_NULL,
1086
1087      "RSA-PSS",
1088      "OpenSSL RSA-PSS method",
1089
1090      rsa_pub_decode,
1091      rsa_pub_encode,
1092      rsa_pub_cmp,
1093      rsa_pub_print,
1094
1095      rsa_priv_decode,
1096      rsa_priv_encode,
1097      rsa_priv_print,
1098
1099      int_rsa_size,
1100      rsa_bits,
1101      rsa_security_bits,
1102
1103      0, 0, 0, 0, 0, 0,
1104
1105      rsa_sig_print,
1106      int_rsa_free,
1107      rsa_pkey_ctrl,
1108      0, 0,
1109      rsa_item_verify,
1110      rsa_item_sign,
1111      0,
1112      rsa_pkey_check,
1113
1114      0, 0,
1115      0, 0, 0, 0,
1116
1117      rsa_pkey_dirty_cnt,
1118      rsa_pss_pkey_export_to,
1119      rsa_pss_pkey_import_from
1120 };