Constify private key decode.
[openssl.git] / crypto / rsa / rsa_ameth.c
1 /*
2  * Copyright 2006-2016 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <stdio.h>
11 #include "internal/cryptlib.h"
12 #include <openssl/asn1t.h>
13 #include <openssl/x509.h>
14 #include <openssl/bn.h>
15 #include <openssl/cms.h>
16 #include "internal/asn1_int.h"
17 #include "internal/evp_int.h"
18 #include "rsa_locl.h"
19
20 #ifndef OPENSSL_NO_CMS
21 static int rsa_cms_sign(CMS_SignerInfo *si);
22 static int rsa_cms_verify(CMS_SignerInfo *si);
23 static int rsa_cms_decrypt(CMS_RecipientInfo *ri);
24 static int rsa_cms_encrypt(CMS_RecipientInfo *ri);
25 #endif
26
27 static int rsa_pub_encode(X509_PUBKEY *pk, const EVP_PKEY *pkey)
28 {
29     unsigned char *penc = NULL;
30     int penclen;
31     penclen = i2d_RSAPublicKey(pkey->pkey.rsa, &penc);
32     if (penclen <= 0)
33         return 0;
34     if (X509_PUBKEY_set0_param(pk, OBJ_nid2obj(EVP_PKEY_RSA),
35                                V_ASN1_NULL, NULL, penc, penclen))
36         return 1;
37
38     OPENSSL_free(penc);
39     return 0;
40 }
41
42 static int rsa_pub_decode(EVP_PKEY *pkey, X509_PUBKEY *pubkey)
43 {
44     const unsigned char *p;
45     int pklen;
46     RSA *rsa = NULL;
47
48     if (!X509_PUBKEY_get0_param(NULL, &p, &pklen, NULL, pubkey))
49         return 0;
50     if ((rsa = d2i_RSAPublicKey(NULL, &p, pklen)) == NULL) {
51         RSAerr(RSA_F_RSA_PUB_DECODE, ERR_R_RSA_LIB);
52         return 0;
53     }
54     EVP_PKEY_assign_RSA(pkey, rsa);
55     return 1;
56 }
57
58 static int rsa_pub_cmp(const EVP_PKEY *a, const EVP_PKEY *b)
59 {
60     if (BN_cmp(b->pkey.rsa->n, a->pkey.rsa->n) != 0
61         || BN_cmp(b->pkey.rsa->e, a->pkey.rsa->e) != 0)
62         return 0;
63     return 1;
64 }
65
66 static int old_rsa_priv_decode(EVP_PKEY *pkey,
67                                const unsigned char **pder, int derlen)
68 {
69     RSA *rsa;
70
71     if ((rsa = d2i_RSAPrivateKey(NULL, pder, derlen)) == NULL) {
72         RSAerr(RSA_F_OLD_RSA_PRIV_DECODE, ERR_R_RSA_LIB);
73         return 0;
74     }
75     EVP_PKEY_assign_RSA(pkey, rsa);
76     return 1;
77 }
78
79 static int old_rsa_priv_encode(const EVP_PKEY *pkey, unsigned char **pder)
80 {
81     return i2d_RSAPrivateKey(pkey->pkey.rsa, pder);
82 }
83
84 static int rsa_priv_encode(PKCS8_PRIV_KEY_INFO *p8, const EVP_PKEY *pkey)
85 {
86     unsigned char *rk = NULL;
87     int rklen;
88     rklen = i2d_RSAPrivateKey(pkey->pkey.rsa, &rk);
89
90     if (rklen <= 0) {
91         RSAerr(RSA_F_RSA_PRIV_ENCODE, ERR_R_MALLOC_FAILURE);
92         return 0;
93     }
94
95     if (!PKCS8_pkey_set0(p8, OBJ_nid2obj(NID_rsaEncryption), 0,
96                          V_ASN1_NULL, NULL, rk, rklen)) {
97         RSAerr(RSA_F_RSA_PRIV_ENCODE, ERR_R_MALLOC_FAILURE);
98         return 0;
99     }
100
101     return 1;
102 }
103
104 static int rsa_priv_decode(EVP_PKEY *pkey, const PKCS8_PRIV_KEY_INFO *p8)
105 {
106     const unsigned char *p;
107     int pklen;
108     if (!PKCS8_pkey_get0(NULL, &p, &pklen, NULL, p8))
109         return 0;
110     return old_rsa_priv_decode(pkey, &p, pklen);
111 }
112
113 static int int_rsa_size(const EVP_PKEY *pkey)
114 {
115     return RSA_size(pkey->pkey.rsa);
116 }
117
118 static int rsa_bits(const EVP_PKEY *pkey)
119 {
120     return BN_num_bits(pkey->pkey.rsa->n);
121 }
122
123 static int rsa_security_bits(const EVP_PKEY *pkey)
124 {
125     return RSA_security_bits(pkey->pkey.rsa);
126 }
127
128 static void int_rsa_free(EVP_PKEY *pkey)
129 {
130     RSA_free(pkey->pkey.rsa);
131 }
132
133 static int do_rsa_print(BIO *bp, const RSA *x, int off, int priv)
134 {
135     char *str;
136     const char *s;
137     int ret = 0, mod_len = 0;
138
139     if (x->n != NULL)
140         mod_len = BN_num_bits(x->n);
141
142     if (!BIO_indent(bp, off, 128))
143         goto err;
144
145     if (priv && x->d) {
146         if (BIO_printf(bp, "Private-Key: (%d bit)\n", mod_len) <= 0)
147             goto err;
148         str = "modulus:";
149         s = "publicExponent:";
150     } else {
151         if (BIO_printf(bp, "Public-Key: (%d bit)\n", mod_len) <= 0)
152             goto err;
153         str = "Modulus:";
154         s = "Exponent:";
155     }
156     if (!ASN1_bn_print(bp, str, x->n, NULL, off))
157         goto err;
158     if (!ASN1_bn_print(bp, s, x->e, NULL, off))
159         goto err;
160     if (priv) {
161         if (!ASN1_bn_print(bp, "privateExponent:", x->d, NULL, off))
162             goto err;
163         if (!ASN1_bn_print(bp, "prime1:", x->p, NULL, off))
164             goto err;
165         if (!ASN1_bn_print(bp, "prime2:", x->q, NULL, off))
166             goto err;
167         if (!ASN1_bn_print(bp, "exponent1:", x->dmp1, NULL, off))
168             goto err;
169         if (!ASN1_bn_print(bp, "exponent2:", x->dmq1, NULL, off))
170             goto err;
171         if (!ASN1_bn_print(bp, "coefficient:", x->iqmp, NULL, off))
172             goto err;
173     }
174     ret = 1;
175  err:
176     return (ret);
177 }
178
179 static int rsa_pub_print(BIO *bp, const EVP_PKEY *pkey, int indent,
180                          ASN1_PCTX *ctx)
181 {
182     return do_rsa_print(bp, pkey->pkey.rsa, indent, 0);
183 }
184
185 static int rsa_priv_print(BIO *bp, const EVP_PKEY *pkey, int indent,
186                           ASN1_PCTX *ctx)
187 {
188     return do_rsa_print(bp, pkey->pkey.rsa, indent, 1);
189 }
190
191 /* Given an MGF1 Algorithm ID decode to an Algorithm Identifier */
192 static X509_ALGOR *rsa_mgf1_decode(X509_ALGOR *alg)
193 {
194     if (alg == NULL)
195         return NULL;
196     if (OBJ_obj2nid(alg->algorithm) != NID_mgf1)
197         return NULL;
198     return ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(X509_ALGOR),
199                                      alg->parameter);
200 }
201
202 static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg,
203                                       X509_ALGOR **pmaskHash)
204 {
205     RSA_PSS_PARAMS *pss;
206
207     *pmaskHash = NULL;
208
209     pss = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_PSS_PARAMS),
210                                     alg->parameter);
211
212     if (!pss)
213         return NULL;
214
215     *pmaskHash = rsa_mgf1_decode(pss->maskGenAlgorithm);
216
217     return pss;
218 }
219
220 static int rsa_pss_param_print(BIO *bp, RSA_PSS_PARAMS *pss,
221                                X509_ALGOR *maskHash, int indent)
222 {
223     int rv = 0;
224     if (!pss) {
225         if (BIO_puts(bp, " (INVALID PSS PARAMETERS)\n") <= 0)
226             return 0;
227         return 1;
228     }
229     if (BIO_puts(bp, "\n") <= 0)
230         goto err;
231     if (!BIO_indent(bp, indent, 128))
232         goto err;
233     if (BIO_puts(bp, "Hash Algorithm: ") <= 0)
234         goto err;
235
236     if (pss->hashAlgorithm) {
237         if (i2a_ASN1_OBJECT(bp, pss->hashAlgorithm->algorithm) <= 0)
238             goto err;
239     } else if (BIO_puts(bp, "sha1 (default)") <= 0)
240         goto err;
241
242     if (BIO_puts(bp, "\n") <= 0)
243         goto err;
244
245     if (!BIO_indent(bp, indent, 128))
246         goto err;
247
248     if (BIO_puts(bp, "Mask Algorithm: ") <= 0)
249         goto err;
250     if (pss->maskGenAlgorithm) {
251         if (i2a_ASN1_OBJECT(bp, pss->maskGenAlgorithm->algorithm) <= 0)
252             goto err;
253         if (BIO_puts(bp, " with ") <= 0)
254             goto err;
255         if (maskHash) {
256             if (i2a_ASN1_OBJECT(bp, maskHash->algorithm) <= 0)
257                 goto err;
258         } else if (BIO_puts(bp, "INVALID") <= 0)
259             goto err;
260     } else if (BIO_puts(bp, "mgf1 with sha1 (default)") <= 0)
261         goto err;
262     BIO_puts(bp, "\n");
263
264     if (!BIO_indent(bp, indent, 128))
265         goto err;
266     if (BIO_puts(bp, "Salt Length: 0x") <= 0)
267         goto err;
268     if (pss->saltLength) {
269         if (i2a_ASN1_INTEGER(bp, pss->saltLength) <= 0)
270             goto err;
271     } else if (BIO_puts(bp, "14 (default)") <= 0)
272         goto err;
273     BIO_puts(bp, "\n");
274
275     if (!BIO_indent(bp, indent, 128))
276         goto err;
277     if (BIO_puts(bp, "Trailer Field: 0x") <= 0)
278         goto err;
279     if (pss->trailerField) {
280         if (i2a_ASN1_INTEGER(bp, pss->trailerField) <= 0)
281             goto err;
282     } else if (BIO_puts(bp, "BC (default)") <= 0)
283         goto err;
284     BIO_puts(bp, "\n");
285
286     rv = 1;
287
288  err:
289     return rv;
290
291 }
292
293 static int rsa_sig_print(BIO *bp, const X509_ALGOR *sigalg,
294                          const ASN1_STRING *sig, int indent, ASN1_PCTX *pctx)
295 {
296     if (OBJ_obj2nid(sigalg->algorithm) == NID_rsassaPss) {
297         int rv;
298         RSA_PSS_PARAMS *pss;
299         X509_ALGOR *maskHash;
300         pss = rsa_pss_decode(sigalg, &maskHash);
301         rv = rsa_pss_param_print(bp, pss, maskHash, indent);
302         RSA_PSS_PARAMS_free(pss);
303         X509_ALGOR_free(maskHash);
304         if (!rv)
305             return 0;
306     } else if (!sig && BIO_puts(bp, "\n") <= 0)
307         return 0;
308     if (sig)
309         return X509_signature_dump(bp, sig, indent);
310     return 1;
311 }
312
313 static int rsa_pkey_ctrl(EVP_PKEY *pkey, int op, long arg1, void *arg2)
314 {
315     X509_ALGOR *alg = NULL;
316     switch (op) {
317
318     case ASN1_PKEY_CTRL_PKCS7_SIGN:
319         if (arg1 == 0)
320             PKCS7_SIGNER_INFO_get0_algs(arg2, NULL, NULL, &alg);
321         break;
322
323     case ASN1_PKEY_CTRL_PKCS7_ENCRYPT:
324         if (arg1 == 0)
325             PKCS7_RECIP_INFO_get0_alg(arg2, &alg);
326         break;
327 #ifndef OPENSSL_NO_CMS
328     case ASN1_PKEY_CTRL_CMS_SIGN:
329         if (arg1 == 0)
330             return rsa_cms_sign(arg2);
331         else if (arg1 == 1)
332             return rsa_cms_verify(arg2);
333         break;
334
335     case ASN1_PKEY_CTRL_CMS_ENVELOPE:
336         if (arg1 == 0)
337             return rsa_cms_encrypt(arg2);
338         else if (arg1 == 1)
339             return rsa_cms_decrypt(arg2);
340         break;
341
342     case ASN1_PKEY_CTRL_CMS_RI_TYPE:
343         *(int *)arg2 = CMS_RECIPINFO_TRANS;
344         return 1;
345 #endif
346
347     case ASN1_PKEY_CTRL_DEFAULT_MD_NID:
348         *(int *)arg2 = NID_sha256;
349         return 1;
350
351     default:
352         return -2;
353
354     }
355
356     if (alg)
357         X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsaEncryption), V_ASN1_NULL, 0);
358
359     return 1;
360
361 }
362
363 /* allocate and set algorithm ID from EVP_MD, default SHA1 */
364 static int rsa_md_to_algor(X509_ALGOR **palg, const EVP_MD *md)
365 {
366     if (EVP_MD_type(md) == NID_sha1)
367         return 1;
368     *palg = X509_ALGOR_new();
369     if (*palg == NULL)
370         return 0;
371     X509_ALGOR_set_md(*palg, md);
372     return 1;
373 }
374
375 /* Allocate and set MGF1 algorithm ID from EVP_MD */
376 static int rsa_md_to_mgf1(X509_ALGOR **palg, const EVP_MD *mgf1md)
377 {
378     X509_ALGOR *algtmp = NULL;
379     ASN1_STRING *stmp = NULL;
380     *palg = NULL;
381     if (EVP_MD_type(mgf1md) == NID_sha1)
382         return 1;
383     /* need to embed algorithm ID inside another */
384     if (!rsa_md_to_algor(&algtmp, mgf1md))
385         goto err;
386     if (!ASN1_item_pack(algtmp, ASN1_ITEM_rptr(X509_ALGOR), &stmp))
387          goto err;
388     *palg = X509_ALGOR_new();
389     if (*palg == NULL)
390         goto err;
391     X509_ALGOR_set0(*palg, OBJ_nid2obj(NID_mgf1), V_ASN1_SEQUENCE, stmp);
392     stmp = NULL;
393  err:
394     ASN1_STRING_free(stmp);
395     X509_ALGOR_free(algtmp);
396     if (*palg)
397         return 1;
398     return 0;
399 }
400
401 /* convert algorithm ID to EVP_MD, default SHA1 */
402 static const EVP_MD *rsa_algor_to_md(X509_ALGOR *alg)
403 {
404     const EVP_MD *md;
405     if (!alg)
406         return EVP_sha1();
407     md = EVP_get_digestbyobj(alg->algorithm);
408     if (md == NULL)
409         RSAerr(RSA_F_RSA_ALGOR_TO_MD, RSA_R_UNKNOWN_DIGEST);
410     return md;
411 }
412
413 /* convert MGF1 algorithm ID to EVP_MD, default SHA1 */
414 static const EVP_MD *rsa_mgf1_to_md(X509_ALGOR *alg, X509_ALGOR *maskHash)
415 {
416     const EVP_MD *md;
417     if (!alg)
418         return EVP_sha1();
419     /* Check mask and lookup mask hash algorithm */
420     if (OBJ_obj2nid(alg->algorithm) != NID_mgf1) {
421         RSAerr(RSA_F_RSA_MGF1_TO_MD, RSA_R_UNSUPPORTED_MASK_ALGORITHM);
422         return NULL;
423     }
424     if (!maskHash) {
425         RSAerr(RSA_F_RSA_MGF1_TO_MD, RSA_R_UNSUPPORTED_MASK_PARAMETER);
426         return NULL;
427     }
428     md = EVP_get_digestbyobj(maskHash->algorithm);
429     if (md == NULL) {
430         RSAerr(RSA_F_RSA_MGF1_TO_MD, RSA_R_UNKNOWN_MASK_DIGEST);
431         return NULL;
432     }
433     return md;
434 }
435
436 /*
437  * Convert EVP_PKEY_CTX is PSS mode into corresponding algorithm parameter,
438  * suitable for setting an AlgorithmIdentifier.
439  */
440
441 static ASN1_STRING *rsa_ctx_to_pss(EVP_PKEY_CTX *pkctx)
442 {
443     const EVP_MD *sigmd, *mgf1md;
444     RSA_PSS_PARAMS *pss = NULL;
445     ASN1_STRING *os = NULL;
446     EVP_PKEY *pk = EVP_PKEY_CTX_get0_pkey(pkctx);
447     int saltlen, rv = 0;
448     if (EVP_PKEY_CTX_get_signature_md(pkctx, &sigmd) <= 0)
449         goto err;
450     if (EVP_PKEY_CTX_get_rsa_mgf1_md(pkctx, &mgf1md) <= 0)
451         goto err;
452     if (!EVP_PKEY_CTX_get_rsa_pss_saltlen(pkctx, &saltlen))
453         goto err;
454     if (saltlen == -1)
455         saltlen = EVP_MD_size(sigmd);
456     else if (saltlen == -2) {
457         saltlen = EVP_PKEY_size(pk) - EVP_MD_size(sigmd) - 2;
458         if (((EVP_PKEY_bits(pk) - 1) & 0x7) == 0)
459             saltlen--;
460     }
461     pss = RSA_PSS_PARAMS_new();
462     if (pss == NULL)
463         goto err;
464     if (saltlen != 20) {
465         pss->saltLength = ASN1_INTEGER_new();
466         if (pss->saltLength == NULL)
467             goto err;
468         if (!ASN1_INTEGER_set(pss->saltLength, saltlen))
469             goto err;
470     }
471     if (!rsa_md_to_algor(&pss->hashAlgorithm, sigmd))
472         goto err;
473     if (!rsa_md_to_mgf1(&pss->maskGenAlgorithm, mgf1md))
474         goto err;
475     /* Finally create string with pss parameter encoding. */
476     if (!ASN1_item_pack(pss, ASN1_ITEM_rptr(RSA_PSS_PARAMS), &os))
477          goto err;
478     rv = 1;
479  err:
480     RSA_PSS_PARAMS_free(pss);
481     if (rv)
482         return os;
483     ASN1_STRING_free(os);
484     return NULL;
485 }
486
487 /*
488  * From PSS AlgorithmIdentifier set public key parameters. If pkey isn't NULL
489  * then the EVP_MD_CTX is setup and initialised. If it is NULL parameters are
490  * passed to pkctx instead.
491  */
492
493 static int rsa_pss_to_ctx(EVP_MD_CTX *ctx, EVP_PKEY_CTX *pkctx,
494                           X509_ALGOR *sigalg, EVP_PKEY *pkey)
495 {
496     int rv = -1;
497     int saltlen;
498     const EVP_MD *mgf1md = NULL, *md = NULL;
499     RSA_PSS_PARAMS *pss;
500     X509_ALGOR *maskHash;
501     /* Sanity check: make sure it is PSS */
502     if (OBJ_obj2nid(sigalg->algorithm) != NID_rsassaPss) {
503         RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_UNSUPPORTED_SIGNATURE_TYPE);
504         return -1;
505     }
506     /* Decode PSS parameters */
507     pss = rsa_pss_decode(sigalg, &maskHash);
508
509     if (pss == NULL) {
510         RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_INVALID_PSS_PARAMETERS);
511         goto err;
512     }
513     mgf1md = rsa_mgf1_to_md(pss->maskGenAlgorithm, maskHash);
514     if (!mgf1md)
515         goto err;
516     md = rsa_algor_to_md(pss->hashAlgorithm);
517     if (!md)
518         goto err;
519
520     if (pss->saltLength) {
521         saltlen = ASN1_INTEGER_get(pss->saltLength);
522
523         /*
524          * Could perform more salt length sanity checks but the main RSA
525          * routines will trap other invalid values anyway.
526          */
527         if (saltlen < 0) {
528             RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_INVALID_SALT_LENGTH);
529             goto err;
530         }
531     } else
532         saltlen = 20;
533
534     /*
535      * low-level routines support only trailer field 0xbc (value 1) and
536      * PKCS#1 says we should reject any other value anyway.
537      */
538     if (pss->trailerField && ASN1_INTEGER_get(pss->trailerField) != 1) {
539         RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_INVALID_TRAILER);
540         goto err;
541     }
542
543     /* We have all parameters now set up context */
544
545     if (pkey) {
546         if (!EVP_DigestVerifyInit(ctx, &pkctx, md, NULL, pkey))
547             goto err;
548     } else {
549         const EVP_MD *checkmd;
550         if (EVP_PKEY_CTX_get_signature_md(pkctx, &checkmd) <= 0)
551             goto err;
552         if (EVP_MD_type(md) != EVP_MD_type(checkmd)) {
553             RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_DIGEST_DOES_NOT_MATCH);
554             goto err;
555         }
556     }
557
558     if (EVP_PKEY_CTX_set_rsa_padding(pkctx, RSA_PKCS1_PSS_PADDING) <= 0)
559         goto err;
560
561     if (EVP_PKEY_CTX_set_rsa_pss_saltlen(pkctx, saltlen) <= 0)
562         goto err;
563
564     if (EVP_PKEY_CTX_set_rsa_mgf1_md(pkctx, mgf1md) <= 0)
565         goto err;
566     /* Carry on */
567     rv = 1;
568
569  err:
570     RSA_PSS_PARAMS_free(pss);
571     X509_ALGOR_free(maskHash);
572     return rv;
573 }
574
575 #ifndef OPENSSL_NO_CMS
576 static int rsa_cms_verify(CMS_SignerInfo *si)
577 {
578     int nid, nid2;
579     X509_ALGOR *alg;
580     EVP_PKEY_CTX *pkctx = CMS_SignerInfo_get0_pkey_ctx(si);
581     CMS_SignerInfo_get0_algs(si, NULL, NULL, NULL, &alg);
582     nid = OBJ_obj2nid(alg->algorithm);
583     if (nid == NID_rsaEncryption)
584         return 1;
585     if (nid == NID_rsassaPss)
586         return rsa_pss_to_ctx(NULL, pkctx, alg, NULL);
587     /* Workaround for some implementation that use a signature OID */
588     if (OBJ_find_sigid_algs(nid, NULL, &nid2)) {
589         if (nid2 == NID_rsaEncryption)
590             return 1;
591     }
592     return 0;
593 }
594 #endif
595
596 /*
597  * Customised RSA item verification routine. This is called when a signature
598  * is encountered requiring special handling. We currently only handle PSS.
599  */
600
601 static int rsa_item_verify(EVP_MD_CTX *ctx, const ASN1_ITEM *it, void *asn,
602                            X509_ALGOR *sigalg, ASN1_BIT_STRING *sig,
603                            EVP_PKEY *pkey)
604 {
605     /* Sanity check: make sure it is PSS */
606     if (OBJ_obj2nid(sigalg->algorithm) != NID_rsassaPss) {
607         RSAerr(RSA_F_RSA_ITEM_VERIFY, RSA_R_UNSUPPORTED_SIGNATURE_TYPE);
608         return -1;
609     }
610     if (rsa_pss_to_ctx(ctx, NULL, sigalg, pkey) > 0) {
611         /* Carry on */
612         return 2;
613     }
614     return -1;
615 }
616
617 #ifndef OPENSSL_NO_CMS
618 static int rsa_cms_sign(CMS_SignerInfo *si)
619 {
620     int pad_mode = RSA_PKCS1_PADDING;
621     X509_ALGOR *alg;
622     EVP_PKEY_CTX *pkctx = CMS_SignerInfo_get0_pkey_ctx(si);
623     ASN1_STRING *os = NULL;
624     CMS_SignerInfo_get0_algs(si, NULL, NULL, NULL, &alg);
625     if (pkctx) {
626         if (EVP_PKEY_CTX_get_rsa_padding(pkctx, &pad_mode) <= 0)
627             return 0;
628     }
629     if (pad_mode == RSA_PKCS1_PADDING) {
630         X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsaEncryption), V_ASN1_NULL, 0);
631         return 1;
632     }
633     /* We don't support it */
634     if (pad_mode != RSA_PKCS1_PSS_PADDING)
635         return 0;
636     os = rsa_ctx_to_pss(pkctx);
637     if (!os)
638         return 0;
639     X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsassaPss), V_ASN1_SEQUENCE, os);
640     return 1;
641 }
642 #endif
643
644 static int rsa_item_sign(EVP_MD_CTX *ctx, const ASN1_ITEM *it, void *asn,
645                          X509_ALGOR *alg1, X509_ALGOR *alg2,
646                          ASN1_BIT_STRING *sig)
647 {
648     int pad_mode;
649     EVP_PKEY_CTX *pkctx = EVP_MD_CTX_pkey_ctx(ctx);
650     if (EVP_PKEY_CTX_get_rsa_padding(pkctx, &pad_mode) <= 0)
651         return 0;
652     if (pad_mode == RSA_PKCS1_PADDING)
653         return 2;
654     if (pad_mode == RSA_PKCS1_PSS_PADDING) {
655         ASN1_STRING *os1 = NULL;
656         os1 = rsa_ctx_to_pss(pkctx);
657         if (!os1)
658             return 0;
659         /* Duplicate parameters if we have to */
660         if (alg2) {
661             ASN1_STRING *os2 = ASN1_STRING_dup(os1);
662             if (!os2) {
663                 ASN1_STRING_free(os1);
664                 return 0;
665             }
666             X509_ALGOR_set0(alg2, OBJ_nid2obj(NID_rsassaPss),
667                             V_ASN1_SEQUENCE, os2);
668         }
669         X509_ALGOR_set0(alg1, OBJ_nid2obj(NID_rsassaPss),
670                         V_ASN1_SEQUENCE, os1);
671         return 3;
672     }
673     return 2;
674 }
675
676 #ifndef OPENSSL_NO_CMS
677 static RSA_OAEP_PARAMS *rsa_oaep_decode(const X509_ALGOR *alg,
678                                         X509_ALGOR **pmaskHash)
679 {
680     RSA_OAEP_PARAMS *pss;
681
682     *pmaskHash = NULL;
683
684     pss = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_OAEP_PARAMS),
685                                     alg->parameter);
686
687     if (!pss)
688         return NULL;
689
690     *pmaskHash = rsa_mgf1_decode(pss->maskGenFunc);
691
692     return pss;
693 }
694
695 static int rsa_cms_decrypt(CMS_RecipientInfo *ri)
696 {
697     EVP_PKEY_CTX *pkctx;
698     X509_ALGOR *cmsalg;
699     int nid;
700     int rv = -1;
701     unsigned char *label = NULL;
702     int labellen = 0;
703     const EVP_MD *mgf1md = NULL, *md = NULL;
704     RSA_OAEP_PARAMS *oaep;
705     X509_ALGOR *maskHash;
706     pkctx = CMS_RecipientInfo_get0_pkey_ctx(ri);
707     if (!pkctx)
708         return 0;
709     if (!CMS_RecipientInfo_ktri_get0_algs(ri, NULL, NULL, &cmsalg))
710         return -1;
711     nid = OBJ_obj2nid(cmsalg->algorithm);
712     if (nid == NID_rsaEncryption)
713         return 1;
714     if (nid != NID_rsaesOaep) {
715         RSAerr(RSA_F_RSA_CMS_DECRYPT, RSA_R_UNSUPPORTED_ENCRYPTION_TYPE);
716         return -1;
717     }
718     /* Decode OAEP parameters */
719     oaep = rsa_oaep_decode(cmsalg, &maskHash);
720
721     if (oaep == NULL) {
722         RSAerr(RSA_F_RSA_CMS_DECRYPT, RSA_R_INVALID_OAEP_PARAMETERS);
723         goto err;
724     }
725
726     mgf1md = rsa_mgf1_to_md(oaep->maskGenFunc, maskHash);
727     if (!mgf1md)
728         goto err;
729     md = rsa_algor_to_md(oaep->hashFunc);
730     if (!md)
731         goto err;
732
733     if (oaep->pSourceFunc) {
734         X509_ALGOR *plab = oaep->pSourceFunc;
735         if (OBJ_obj2nid(plab->algorithm) != NID_pSpecified) {
736             RSAerr(RSA_F_RSA_CMS_DECRYPT, RSA_R_UNSUPPORTED_LABEL_SOURCE);
737             goto err;
738         }
739         if (plab->parameter->type != V_ASN1_OCTET_STRING) {
740             RSAerr(RSA_F_RSA_CMS_DECRYPT, RSA_R_INVALID_LABEL);
741             goto err;
742         }
743
744         label = plab->parameter->value.octet_string->data;
745         /* Stop label being freed when OAEP parameters are freed */
746         plab->parameter->value.octet_string->data = NULL;
747         labellen = plab->parameter->value.octet_string->length;
748     }
749
750     if (EVP_PKEY_CTX_set_rsa_padding(pkctx, RSA_PKCS1_OAEP_PADDING) <= 0)
751         goto err;
752     if (EVP_PKEY_CTX_set_rsa_oaep_md(pkctx, md) <= 0)
753         goto err;
754     if (EVP_PKEY_CTX_set_rsa_mgf1_md(pkctx, mgf1md) <= 0)
755         goto err;
756     if (EVP_PKEY_CTX_set0_rsa_oaep_label(pkctx, label, labellen) <= 0)
757         goto err;
758     /* Carry on */
759     rv = 1;
760
761  err:
762     RSA_OAEP_PARAMS_free(oaep);
763     X509_ALGOR_free(maskHash);
764     return rv;
765 }
766
767 static int rsa_cms_encrypt(CMS_RecipientInfo *ri)
768 {
769     const EVP_MD *md, *mgf1md;
770     RSA_OAEP_PARAMS *oaep = NULL;
771     ASN1_STRING *os = NULL;
772     X509_ALGOR *alg;
773     EVP_PKEY_CTX *pkctx = CMS_RecipientInfo_get0_pkey_ctx(ri);
774     int pad_mode = RSA_PKCS1_PADDING, rv = 0, labellen;
775     unsigned char *label;
776     CMS_RecipientInfo_ktri_get0_algs(ri, NULL, NULL, &alg);
777     if (pkctx) {
778         if (EVP_PKEY_CTX_get_rsa_padding(pkctx, &pad_mode) <= 0)
779             return 0;
780     }
781     if (pad_mode == RSA_PKCS1_PADDING) {
782         X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsaEncryption), V_ASN1_NULL, 0);
783         return 1;
784     }
785     /* Not supported */
786     if (pad_mode != RSA_PKCS1_OAEP_PADDING)
787         return 0;
788     if (EVP_PKEY_CTX_get_rsa_oaep_md(pkctx, &md) <= 0)
789         goto err;
790     if (EVP_PKEY_CTX_get_rsa_mgf1_md(pkctx, &mgf1md) <= 0)
791         goto err;
792     labellen = EVP_PKEY_CTX_get0_rsa_oaep_label(pkctx, &label);
793     if (labellen < 0)
794         goto err;
795     oaep = RSA_OAEP_PARAMS_new();
796     if (oaep == NULL)
797         goto err;
798     if (!rsa_md_to_algor(&oaep->hashFunc, md))
799         goto err;
800     if (!rsa_md_to_mgf1(&oaep->maskGenFunc, mgf1md))
801         goto err;
802     if (labellen > 0) {
803         ASN1_OCTET_STRING *los;
804         oaep->pSourceFunc = X509_ALGOR_new();
805         if (oaep->pSourceFunc == NULL)
806             goto err;
807         los = ASN1_OCTET_STRING_new();
808         if (los == NULL)
809             goto err;
810         if (!ASN1_OCTET_STRING_set(los, label, labellen)) {
811             ASN1_OCTET_STRING_free(los);
812             goto err;
813         }
814         X509_ALGOR_set0(oaep->pSourceFunc, OBJ_nid2obj(NID_pSpecified),
815                         V_ASN1_OCTET_STRING, los);
816     }
817     /* create string with pss parameter encoding. */
818     if (!ASN1_item_pack(oaep, ASN1_ITEM_rptr(RSA_OAEP_PARAMS), &os))
819          goto err;
820     X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsaesOaep), V_ASN1_SEQUENCE, os);
821     os = NULL;
822     rv = 1;
823  err:
824     RSA_OAEP_PARAMS_free(oaep);
825     ASN1_STRING_free(os);
826     return rv;
827 }
828 #endif
829
830 const EVP_PKEY_ASN1_METHOD rsa_asn1_meths[2] = {
831     {
832      EVP_PKEY_RSA,
833      EVP_PKEY_RSA,
834      ASN1_PKEY_SIGPARAM_NULL,
835
836      "RSA",
837      "OpenSSL RSA method",
838
839      rsa_pub_decode,
840      rsa_pub_encode,
841      rsa_pub_cmp,
842      rsa_pub_print,
843
844      rsa_priv_decode,
845      rsa_priv_encode,
846      rsa_priv_print,
847
848      int_rsa_size,
849      rsa_bits,
850      rsa_security_bits,
851
852      0, 0, 0, 0, 0, 0,
853
854      rsa_sig_print,
855      int_rsa_free,
856      rsa_pkey_ctrl,
857      old_rsa_priv_decode,
858      old_rsa_priv_encode,
859      rsa_item_verify,
860      rsa_item_sign},
861
862     {
863      EVP_PKEY_RSA2,
864      EVP_PKEY_RSA,
865      ASN1_PKEY_ALIAS}
866 };