Split PSS parameter creation.
[openssl.git] / crypto / rsa / rsa_ameth.c
1 /*
2  * Copyright 2006-2016 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <stdio.h>
11 #include "internal/cryptlib.h"
12 #include <openssl/asn1t.h>
13 #include <openssl/x509.h>
14 #include <openssl/bn.h>
15 #include <openssl/cms.h>
16 #include "internal/asn1_int.h"
17 #include "internal/evp_int.h"
18 #include "rsa_locl.h"
19
20 #ifndef OPENSSL_NO_CMS
21 static int rsa_cms_sign(CMS_SignerInfo *si);
22 static int rsa_cms_verify(CMS_SignerInfo *si);
23 static int rsa_cms_decrypt(CMS_RecipientInfo *ri);
24 static int rsa_cms_encrypt(CMS_RecipientInfo *ri);
25 #endif
26
27 static int rsa_pub_encode(X509_PUBKEY *pk, const EVP_PKEY *pkey)
28 {
29     unsigned char *penc = NULL;
30     int penclen;
31     penclen = i2d_RSAPublicKey(pkey->pkey.rsa, &penc);
32     if (penclen <= 0)
33         return 0;
34     if (X509_PUBKEY_set0_param(pk, OBJ_nid2obj(EVP_PKEY_RSA),
35                                V_ASN1_NULL, NULL, penc, penclen))
36         return 1;
37
38     OPENSSL_free(penc);
39     return 0;
40 }
41
42 static int rsa_pub_decode(EVP_PKEY *pkey, X509_PUBKEY *pubkey)
43 {
44     const unsigned char *p;
45     int pklen;
46     RSA *rsa = NULL;
47
48     if (!X509_PUBKEY_get0_param(NULL, &p, &pklen, NULL, pubkey))
49         return 0;
50     if ((rsa = d2i_RSAPublicKey(NULL, &p, pklen)) == NULL) {
51         RSAerr(RSA_F_RSA_PUB_DECODE, ERR_R_RSA_LIB);
52         return 0;
53     }
54     EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa);
55     return 1;
56 }
57
58 static int rsa_pub_cmp(const EVP_PKEY *a, const EVP_PKEY *b)
59 {
60     if (BN_cmp(b->pkey.rsa->n, a->pkey.rsa->n) != 0
61         || BN_cmp(b->pkey.rsa->e, a->pkey.rsa->e) != 0)
62         return 0;
63     return 1;
64 }
65
66 static int old_rsa_priv_decode(EVP_PKEY *pkey,
67                                const unsigned char **pder, int derlen)
68 {
69     RSA *rsa;
70
71     if ((rsa = d2i_RSAPrivateKey(NULL, pder, derlen)) == NULL) {
72         RSAerr(RSA_F_OLD_RSA_PRIV_DECODE, ERR_R_RSA_LIB);
73         return 0;
74     }
75     EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa);
76     return 1;
77 }
78
79 static int old_rsa_priv_encode(const EVP_PKEY *pkey, unsigned char **pder)
80 {
81     return i2d_RSAPrivateKey(pkey->pkey.rsa, pder);
82 }
83
84 static int rsa_priv_encode(PKCS8_PRIV_KEY_INFO *p8, const EVP_PKEY *pkey)
85 {
86     unsigned char *rk = NULL;
87     int rklen;
88     rklen = i2d_RSAPrivateKey(pkey->pkey.rsa, &rk);
89
90     if (rklen <= 0) {
91         RSAerr(RSA_F_RSA_PRIV_ENCODE, ERR_R_MALLOC_FAILURE);
92         return 0;
93     }
94
95     if (!PKCS8_pkey_set0(p8, OBJ_nid2obj(pkey->ameth->pkey_id), 0,
96                          V_ASN1_NULL, NULL, rk, rklen)) {
97         RSAerr(RSA_F_RSA_PRIV_ENCODE, ERR_R_MALLOC_FAILURE);
98         return 0;
99     }
100
101     return 1;
102 }
103
104 static int rsa_priv_decode(EVP_PKEY *pkey, const PKCS8_PRIV_KEY_INFO *p8)
105 {
106     const unsigned char *p;
107     int pklen;
108     if (!PKCS8_pkey_get0(NULL, &p, &pklen, NULL, p8))
109         return 0;
110     return old_rsa_priv_decode(pkey, &p, pklen);
111 }
112
113 static int int_rsa_size(const EVP_PKEY *pkey)
114 {
115     return RSA_size(pkey->pkey.rsa);
116 }
117
118 static int rsa_bits(const EVP_PKEY *pkey)
119 {
120     return BN_num_bits(pkey->pkey.rsa->n);
121 }
122
123 static int rsa_security_bits(const EVP_PKEY *pkey)
124 {
125     return RSA_security_bits(pkey->pkey.rsa);
126 }
127
128 static void int_rsa_free(EVP_PKEY *pkey)
129 {
130     RSA_free(pkey->pkey.rsa);
131 }
132
133 static int do_rsa_print(BIO *bp, const RSA *x, int off, int priv)
134 {
135     char *str;
136     const char *s;
137     int ret = 0, mod_len = 0;
138
139     if (x->n != NULL)
140         mod_len = BN_num_bits(x->n);
141
142     if (!BIO_indent(bp, off, 128))
143         goto err;
144
145     if (priv && x->d) {
146         if (BIO_printf(bp, "Private-Key: (%d bit)\n", mod_len) <= 0)
147             goto err;
148         str = "modulus:";
149         s = "publicExponent:";
150     } else {
151         if (BIO_printf(bp, "Public-Key: (%d bit)\n", mod_len) <= 0)
152             goto err;
153         str = "Modulus:";
154         s = "Exponent:";
155     }
156     if (!ASN1_bn_print(bp, str, x->n, NULL, off))
157         goto err;
158     if (!ASN1_bn_print(bp, s, x->e, NULL, off))
159         goto err;
160     if (priv) {
161         if (!ASN1_bn_print(bp, "privateExponent:", x->d, NULL, off))
162             goto err;
163         if (!ASN1_bn_print(bp, "prime1:", x->p, NULL, off))
164             goto err;
165         if (!ASN1_bn_print(bp, "prime2:", x->q, NULL, off))
166             goto err;
167         if (!ASN1_bn_print(bp, "exponent1:", x->dmp1, NULL, off))
168             goto err;
169         if (!ASN1_bn_print(bp, "exponent2:", x->dmq1, NULL, off))
170             goto err;
171         if (!ASN1_bn_print(bp, "coefficient:", x->iqmp, NULL, off))
172             goto err;
173     }
174     ret = 1;
175  err:
176     return (ret);
177 }
178
179 static int rsa_pub_print(BIO *bp, const EVP_PKEY *pkey, int indent,
180                          ASN1_PCTX *ctx)
181 {
182     return do_rsa_print(bp, pkey->pkey.rsa, indent, 0);
183 }
184
185 static int rsa_priv_print(BIO *bp, const EVP_PKEY *pkey, int indent,
186                           ASN1_PCTX *ctx)
187 {
188     return do_rsa_print(bp, pkey->pkey.rsa, indent, 1);
189 }
190
191 static X509_ALGOR *rsa_mgf1_decode(X509_ALGOR *alg)
192 {
193     if (OBJ_obj2nid(alg->algorithm) != NID_mgf1)
194         return NULL;
195     return ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(X509_ALGOR),
196                                      alg->parameter);
197 }
198
199 static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg)
200 {
201     RSA_PSS_PARAMS *pss;
202
203     pss = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_PSS_PARAMS),
204                                     alg->parameter);
205
206     if (pss == NULL)
207         return NULL;
208
209     if (pss->maskGenAlgorithm != NULL) {
210         pss->maskHash = rsa_mgf1_decode(pss->maskGenAlgorithm);
211         if (pss->maskHash == NULL) {
212             RSA_PSS_PARAMS_free(pss);
213             return NULL;
214         }
215     }
216
217     return pss;
218 }
219
220 static int rsa_pss_param_print(BIO *bp, RSA_PSS_PARAMS *pss, int indent)
221 {
222     int rv = 0;
223     if (!pss) {
224         if (BIO_puts(bp, " (INVALID PSS PARAMETERS)\n") <= 0)
225             return 0;
226         return 1;
227     }
228     if (BIO_puts(bp, "\n") <= 0)
229         goto err;
230     if (!BIO_indent(bp, indent, 128))
231         goto err;
232     if (BIO_puts(bp, "Hash Algorithm: ") <= 0)
233         goto err;
234
235     if (pss->hashAlgorithm) {
236         if (i2a_ASN1_OBJECT(bp, pss->hashAlgorithm->algorithm) <= 0)
237             goto err;
238     } else if (BIO_puts(bp, "sha1 (default)") <= 0)
239         goto err;
240
241     if (BIO_puts(bp, "\n") <= 0)
242         goto err;
243
244     if (!BIO_indent(bp, indent, 128))
245         goto err;
246
247     if (BIO_puts(bp, "Mask Algorithm: ") <= 0)
248         goto err;
249     if (pss->maskGenAlgorithm) {
250         if (i2a_ASN1_OBJECT(bp, pss->maskGenAlgorithm->algorithm) <= 0)
251             goto err;
252         if (BIO_puts(bp, " with ") <= 0)
253             goto err;
254         if (pss->maskHash) {
255             if (i2a_ASN1_OBJECT(bp, pss->maskHash->algorithm) <= 0)
256                 goto err;
257         } else if (BIO_puts(bp, "INVALID") <= 0)
258             goto err;
259     } else if (BIO_puts(bp, "mgf1 with sha1 (default)") <= 0)
260         goto err;
261     BIO_puts(bp, "\n");
262
263     if (!BIO_indent(bp, indent, 128))
264         goto err;
265     if (BIO_puts(bp, "Salt Length: 0x") <= 0)
266         goto err;
267     if (pss->saltLength) {
268         if (i2a_ASN1_INTEGER(bp, pss->saltLength) <= 0)
269             goto err;
270     } else if (BIO_puts(bp, "14 (default)") <= 0)
271         goto err;
272     BIO_puts(bp, "\n");
273
274     if (!BIO_indent(bp, indent, 128))
275         goto err;
276     if (BIO_puts(bp, "Trailer Field: 0x") <= 0)
277         goto err;
278     if (pss->trailerField) {
279         if (i2a_ASN1_INTEGER(bp, pss->trailerField) <= 0)
280             goto err;
281     } else if (BIO_puts(bp, "BC (default)") <= 0)
282         goto err;
283     BIO_puts(bp, "\n");
284
285     rv = 1;
286
287  err:
288     return rv;
289
290 }
291
292 static int rsa_sig_print(BIO *bp, const X509_ALGOR *sigalg,
293                          const ASN1_STRING *sig, int indent, ASN1_PCTX *pctx)
294 {
295     if (OBJ_obj2nid(sigalg->algorithm) == EVP_PKEY_RSA_PSS) {
296         int rv;
297         RSA_PSS_PARAMS *pss;
298         pss = rsa_pss_decode(sigalg);
299         rv = rsa_pss_param_print(bp, pss, indent);
300         RSA_PSS_PARAMS_free(pss);
301         if (!rv)
302             return 0;
303     } else if (!sig && BIO_puts(bp, "\n") <= 0)
304         return 0;
305     if (sig)
306         return X509_signature_dump(bp, sig, indent);
307     return 1;
308 }
309
310 static int rsa_pkey_ctrl(EVP_PKEY *pkey, int op, long arg1, void *arg2)
311 {
312     X509_ALGOR *alg = NULL;
313     switch (op) {
314
315     case ASN1_PKEY_CTRL_PKCS7_SIGN:
316         if (arg1 == 0)
317             PKCS7_SIGNER_INFO_get0_algs(arg2, NULL, NULL, &alg);
318         break;
319
320     case ASN1_PKEY_CTRL_PKCS7_ENCRYPT:
321         if (arg1 == 0)
322             PKCS7_RECIP_INFO_get0_alg(arg2, &alg);
323         break;
324 #ifndef OPENSSL_NO_CMS
325     case ASN1_PKEY_CTRL_CMS_SIGN:
326         if (arg1 == 0)
327             return rsa_cms_sign(arg2);
328         else if (arg1 == 1)
329             return rsa_cms_verify(arg2);
330         break;
331
332     case ASN1_PKEY_CTRL_CMS_ENVELOPE:
333         if (arg1 == 0)
334             return rsa_cms_encrypt(arg2);
335         else if (arg1 == 1)
336             return rsa_cms_decrypt(arg2);
337         break;
338
339     case ASN1_PKEY_CTRL_CMS_RI_TYPE:
340         *(int *)arg2 = CMS_RECIPINFO_TRANS;
341         return 1;
342 #endif
343
344     case ASN1_PKEY_CTRL_DEFAULT_MD_NID:
345         *(int *)arg2 = NID_sha256;
346         return 1;
347
348     default:
349         return -2;
350
351     }
352
353     if (alg)
354         X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsaEncryption), V_ASN1_NULL, 0);
355
356     return 1;
357
358 }
359
360 /* allocate and set algorithm ID from EVP_MD, default SHA1 */
361 static int rsa_md_to_algor(X509_ALGOR **palg, const EVP_MD *md)
362 {
363     if (md == NULL || EVP_MD_type(md) == NID_sha1)
364         return 1;
365     *palg = X509_ALGOR_new();
366     if (*palg == NULL)
367         return 0;
368     X509_ALGOR_set_md(*palg, md);
369     return 1;
370 }
371
372 /* Allocate and set MGF1 algorithm ID from EVP_MD */
373 static int rsa_md_to_mgf1(X509_ALGOR **palg, const EVP_MD *mgf1md)
374 {
375     X509_ALGOR *algtmp = NULL;
376     ASN1_STRING *stmp = NULL;
377     *palg = NULL;
378     if (mgf1md == NULL || EVP_MD_type(mgf1md) == NID_sha1)
379         return 1;
380     /* need to embed algorithm ID inside another */
381     if (!rsa_md_to_algor(&algtmp, mgf1md))
382         goto err;
383     if (!ASN1_item_pack(algtmp, ASN1_ITEM_rptr(X509_ALGOR), &stmp))
384          goto err;
385     *palg = X509_ALGOR_new();
386     if (*palg == NULL)
387         goto err;
388     X509_ALGOR_set0(*palg, OBJ_nid2obj(NID_mgf1), V_ASN1_SEQUENCE, stmp);
389     stmp = NULL;
390  err:
391     ASN1_STRING_free(stmp);
392     X509_ALGOR_free(algtmp);
393     if (*palg)
394         return 1;
395     return 0;
396 }
397
398 /* convert algorithm ID to EVP_MD, default SHA1 */
399 static const EVP_MD *rsa_algor_to_md(X509_ALGOR *alg)
400 {
401     const EVP_MD *md;
402     if (!alg)
403         return EVP_sha1();
404     md = EVP_get_digestbyobj(alg->algorithm);
405     if (md == NULL)
406         RSAerr(RSA_F_RSA_ALGOR_TO_MD, RSA_R_UNKNOWN_DIGEST);
407     return md;
408 }
409
410 /*
411  * Convert EVP_PKEY_CTX in PSS mode into corresponding algorithm parameter,
412  * suitable for setting an AlgorithmIdentifier.
413  */
414
415 static RSA_PSS_PARAMS *rsa_ctx_to_pss(EVP_PKEY_CTX *pkctx)
416 {
417     const EVP_MD *sigmd, *mgf1md;
418     EVP_PKEY *pk = EVP_PKEY_CTX_get0_pkey(pkctx);
419     int saltlen;
420     if (EVP_PKEY_CTX_get_signature_md(pkctx, &sigmd) <= 0)
421         return NULL;
422     if (EVP_PKEY_CTX_get_rsa_mgf1_md(pkctx, &mgf1md) <= 0)
423         return NULL;
424     if (!EVP_PKEY_CTX_get_rsa_pss_saltlen(pkctx, &saltlen))
425         return NULL;
426     if (saltlen == -1)
427         saltlen = EVP_MD_size(sigmd);
428     else if (saltlen == -2) {
429         saltlen = EVP_PKEY_size(pk) - EVP_MD_size(sigmd) - 2;
430         if (((EVP_PKEY_bits(pk) - 1) & 0x7) == 0)
431             saltlen--;
432     }
433
434     return rsa_pss_params_create(sigmd, mgf1md, saltlen);
435 }
436
437 RSA_PSS_PARAMS *rsa_pss_params_create(const EVP_MD *sigmd,
438                                       const EVP_MD *mgf1md, int saltlen)
439 {
440     RSA_PSS_PARAMS *pss = RSA_PSS_PARAMS_new();
441     if (pss == NULL)
442         goto err;
443     if (saltlen != 20) {
444         pss->saltLength = ASN1_INTEGER_new();
445         if (pss->saltLength == NULL)
446             goto err;
447         if (!ASN1_INTEGER_set(pss->saltLength, saltlen))
448             goto err;
449     }
450     if (!rsa_md_to_algor(&pss->hashAlgorithm, sigmd))
451         goto err;
452     if (mgf1md == NULL)
453             mgf1md = sigmd;
454     if (!rsa_md_to_mgf1(&pss->maskGenAlgorithm, mgf1md))
455         goto err;
456     return pss;
457  err:
458     RSA_PSS_PARAMS_free(pss);
459     return NULL;
460 }
461
462 static ASN1_STRING *rsa_ctx_to_pss_string(EVP_PKEY_CTX *pkctx)
463 {
464     RSA_PSS_PARAMS *pss = rsa_ctx_to_pss(pkctx);
465     ASN1_STRING *os = NULL;
466     if (pss == NULL)
467         return NULL;
468
469     if (!ASN1_item_pack(pss, ASN1_ITEM_rptr(RSA_PSS_PARAMS), &os)) {
470         ASN1_STRING_free(os);
471         os = NULL;
472     }
473     RSA_PSS_PARAMS_free(pss);
474     return os;
475 }
476
477 /*
478  * From PSS AlgorithmIdentifier set public key parameters. If pkey isn't NULL
479  * then the EVP_MD_CTX is setup and initialised. If it is NULL parameters are
480  * passed to pkctx instead.
481  */
482
483 static int rsa_pss_to_ctx(EVP_MD_CTX *ctx, EVP_PKEY_CTX *pkctx,
484                           X509_ALGOR *sigalg, EVP_PKEY *pkey)
485 {
486     int rv = -1;
487     int saltlen;
488     const EVP_MD *mgf1md = NULL, *md = NULL;
489     RSA_PSS_PARAMS *pss;
490     /* Sanity check: make sure it is PSS */
491     if (OBJ_obj2nid(sigalg->algorithm) != EVP_PKEY_RSA_PSS) {
492         RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_UNSUPPORTED_SIGNATURE_TYPE);
493         return -1;
494     }
495     /* Decode PSS parameters */
496     pss = rsa_pss_decode(sigalg);
497
498     if (pss == NULL) {
499         RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_INVALID_PSS_PARAMETERS);
500         goto err;
501     }
502     mgf1md = rsa_algor_to_md(pss->maskHash);
503     if (!mgf1md)
504         goto err;
505     md = rsa_algor_to_md(pss->hashAlgorithm);
506     if (!md)
507         goto err;
508
509     if (pss->saltLength) {
510         saltlen = ASN1_INTEGER_get(pss->saltLength);
511
512         /*
513          * Could perform more salt length sanity checks but the main RSA
514          * routines will trap other invalid values anyway.
515          */
516         if (saltlen < 0) {
517             RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_INVALID_SALT_LENGTH);
518             goto err;
519         }
520     } else
521         saltlen = 20;
522
523     /*
524      * low-level routines support only trailer field 0xbc (value 1) and
525      * PKCS#1 says we should reject any other value anyway.
526      */
527     if (pss->trailerField && ASN1_INTEGER_get(pss->trailerField) != 1) {
528         RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_INVALID_TRAILER);
529         goto err;
530     }
531
532     /* We have all parameters now set up context */
533
534     if (pkey) {
535         if (!EVP_DigestVerifyInit(ctx, &pkctx, md, NULL, pkey))
536             goto err;
537     } else {
538         const EVP_MD *checkmd;
539         if (EVP_PKEY_CTX_get_signature_md(pkctx, &checkmd) <= 0)
540             goto err;
541         if (EVP_MD_type(md) != EVP_MD_type(checkmd)) {
542             RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_DIGEST_DOES_NOT_MATCH);
543             goto err;
544         }
545     }
546
547     if (EVP_PKEY_CTX_set_rsa_padding(pkctx, RSA_PKCS1_PSS_PADDING) <= 0)
548         goto err;
549
550     if (EVP_PKEY_CTX_set_rsa_pss_saltlen(pkctx, saltlen) <= 0)
551         goto err;
552
553     if (EVP_PKEY_CTX_set_rsa_mgf1_md(pkctx, mgf1md) <= 0)
554         goto err;
555     /* Carry on */
556     rv = 1;
557
558  err:
559     RSA_PSS_PARAMS_free(pss);
560     return rv;
561 }
562
563 #ifndef OPENSSL_NO_CMS
564 static int rsa_cms_verify(CMS_SignerInfo *si)
565 {
566     int nid, nid2;
567     X509_ALGOR *alg;
568     EVP_PKEY_CTX *pkctx = CMS_SignerInfo_get0_pkey_ctx(si);
569     CMS_SignerInfo_get0_algs(si, NULL, NULL, NULL, &alg);
570     nid = OBJ_obj2nid(alg->algorithm);
571     if (nid == NID_rsaEncryption)
572         return 1;
573     if (nid == EVP_PKEY_RSA_PSS)
574         return rsa_pss_to_ctx(NULL, pkctx, alg, NULL);
575     /* Workaround for some implementation that use a signature OID */
576     if (OBJ_find_sigid_algs(nid, NULL, &nid2)) {
577         if (nid2 == NID_rsaEncryption)
578             return 1;
579     }
580     return 0;
581 }
582 #endif
583
584 /*
585  * Customised RSA item verification routine. This is called when a signature
586  * is encountered requiring special handling. We currently only handle PSS.
587  */
588
589 static int rsa_item_verify(EVP_MD_CTX *ctx, const ASN1_ITEM *it, void *asn,
590                            X509_ALGOR *sigalg, ASN1_BIT_STRING *sig,
591                            EVP_PKEY *pkey)
592 {
593     /* Sanity check: make sure it is PSS */
594     if (OBJ_obj2nid(sigalg->algorithm) != EVP_PKEY_RSA_PSS) {
595         RSAerr(RSA_F_RSA_ITEM_VERIFY, RSA_R_UNSUPPORTED_SIGNATURE_TYPE);
596         return -1;
597     }
598     if (rsa_pss_to_ctx(ctx, NULL, sigalg, pkey) > 0) {
599         /* Carry on */
600         return 2;
601     }
602     return -1;
603 }
604
605 #ifndef OPENSSL_NO_CMS
606 static int rsa_cms_sign(CMS_SignerInfo *si)
607 {
608     int pad_mode = RSA_PKCS1_PADDING;
609     X509_ALGOR *alg;
610     EVP_PKEY_CTX *pkctx = CMS_SignerInfo_get0_pkey_ctx(si);
611     ASN1_STRING *os = NULL;
612     CMS_SignerInfo_get0_algs(si, NULL, NULL, NULL, &alg);
613     if (pkctx) {
614         if (EVP_PKEY_CTX_get_rsa_padding(pkctx, &pad_mode) <= 0)
615             return 0;
616     }
617     if (pad_mode == RSA_PKCS1_PADDING) {
618         X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsaEncryption), V_ASN1_NULL, 0);
619         return 1;
620     }
621     /* We don't support it */
622     if (pad_mode != RSA_PKCS1_PSS_PADDING)
623         return 0;
624     os = rsa_ctx_to_pss_string(pkctx);
625     if (!os)
626         return 0;
627     X509_ALGOR_set0(alg, OBJ_nid2obj(EVP_PKEY_RSA_PSS), V_ASN1_SEQUENCE, os);
628     return 1;
629 }
630 #endif
631
632 static int rsa_item_sign(EVP_MD_CTX *ctx, const ASN1_ITEM *it, void *asn,
633                          X509_ALGOR *alg1, X509_ALGOR *alg2,
634                          ASN1_BIT_STRING *sig)
635 {
636     int pad_mode;
637     EVP_PKEY_CTX *pkctx = EVP_MD_CTX_pkey_ctx(ctx);
638     if (EVP_PKEY_CTX_get_rsa_padding(pkctx, &pad_mode) <= 0)
639         return 0;
640     if (pad_mode == RSA_PKCS1_PADDING)
641         return 2;
642     if (pad_mode == RSA_PKCS1_PSS_PADDING) {
643         ASN1_STRING *os1 = NULL;
644         os1 = rsa_ctx_to_pss_string(pkctx);
645         if (!os1)
646             return 0;
647         /* Duplicate parameters if we have to */
648         if (alg2) {
649             ASN1_STRING *os2 = ASN1_STRING_dup(os1);
650             if (!os2) {
651                 ASN1_STRING_free(os1);
652                 return 0;
653             }
654             X509_ALGOR_set0(alg2, OBJ_nid2obj(EVP_PKEY_RSA_PSS),
655                             V_ASN1_SEQUENCE, os2);
656         }
657         X509_ALGOR_set0(alg1, OBJ_nid2obj(EVP_PKEY_RSA_PSS),
658                         V_ASN1_SEQUENCE, os1);
659         return 3;
660     }
661     return 2;
662 }
663
664 #ifndef OPENSSL_NO_CMS
665 static RSA_OAEP_PARAMS *rsa_oaep_decode(const X509_ALGOR *alg)
666 {
667     RSA_OAEP_PARAMS *oaep;
668
669     oaep = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_OAEP_PARAMS),
670                                     alg->parameter);
671
672     if (oaep == NULL)
673         return NULL;
674
675     if (oaep->maskGenFunc != NULL) {
676         oaep->maskHash = rsa_mgf1_decode(oaep->maskGenFunc);
677         if (oaep->maskHash == NULL) {
678             RSA_OAEP_PARAMS_free(oaep);
679             return NULL;
680         }
681     }
682     return oaep;
683 }
684
685 static int rsa_cms_decrypt(CMS_RecipientInfo *ri)
686 {
687     EVP_PKEY_CTX *pkctx;
688     X509_ALGOR *cmsalg;
689     int nid;
690     int rv = -1;
691     unsigned char *label = NULL;
692     int labellen = 0;
693     const EVP_MD *mgf1md = NULL, *md = NULL;
694     RSA_OAEP_PARAMS *oaep;
695     pkctx = CMS_RecipientInfo_get0_pkey_ctx(ri);
696     if (!pkctx)
697         return 0;
698     if (!CMS_RecipientInfo_ktri_get0_algs(ri, NULL, NULL, &cmsalg))
699         return -1;
700     nid = OBJ_obj2nid(cmsalg->algorithm);
701     if (nid == NID_rsaEncryption)
702         return 1;
703     if (nid != NID_rsaesOaep) {
704         RSAerr(RSA_F_RSA_CMS_DECRYPT, RSA_R_UNSUPPORTED_ENCRYPTION_TYPE);
705         return -1;
706     }
707     /* Decode OAEP parameters */
708     oaep = rsa_oaep_decode(cmsalg);
709
710     if (oaep == NULL) {
711         RSAerr(RSA_F_RSA_CMS_DECRYPT, RSA_R_INVALID_OAEP_PARAMETERS);
712         goto err;
713     }
714
715     mgf1md = rsa_algor_to_md(oaep->maskHash);
716     if (!mgf1md)
717         goto err;
718     md = rsa_algor_to_md(oaep->hashFunc);
719     if (!md)
720         goto err;
721
722     if (oaep->pSourceFunc) {
723         X509_ALGOR *plab = oaep->pSourceFunc;
724         if (OBJ_obj2nid(plab->algorithm) != NID_pSpecified) {
725             RSAerr(RSA_F_RSA_CMS_DECRYPT, RSA_R_UNSUPPORTED_LABEL_SOURCE);
726             goto err;
727         }
728         if (plab->parameter->type != V_ASN1_OCTET_STRING) {
729             RSAerr(RSA_F_RSA_CMS_DECRYPT, RSA_R_INVALID_LABEL);
730             goto err;
731         }
732
733         label = plab->parameter->value.octet_string->data;
734         /* Stop label being freed when OAEP parameters are freed */
735         plab->parameter->value.octet_string->data = NULL;
736         labellen = plab->parameter->value.octet_string->length;
737     }
738
739     if (EVP_PKEY_CTX_set_rsa_padding(pkctx, RSA_PKCS1_OAEP_PADDING) <= 0)
740         goto err;
741     if (EVP_PKEY_CTX_set_rsa_oaep_md(pkctx, md) <= 0)
742         goto err;
743     if (EVP_PKEY_CTX_set_rsa_mgf1_md(pkctx, mgf1md) <= 0)
744         goto err;
745     if (EVP_PKEY_CTX_set0_rsa_oaep_label(pkctx, label, labellen) <= 0)
746         goto err;
747     /* Carry on */
748     rv = 1;
749
750  err:
751     RSA_OAEP_PARAMS_free(oaep);
752     return rv;
753 }
754
755 static int rsa_cms_encrypt(CMS_RecipientInfo *ri)
756 {
757     const EVP_MD *md, *mgf1md;
758     RSA_OAEP_PARAMS *oaep = NULL;
759     ASN1_STRING *os = NULL;
760     X509_ALGOR *alg;
761     EVP_PKEY_CTX *pkctx = CMS_RecipientInfo_get0_pkey_ctx(ri);
762     int pad_mode = RSA_PKCS1_PADDING, rv = 0, labellen;
763     unsigned char *label;
764     CMS_RecipientInfo_ktri_get0_algs(ri, NULL, NULL, &alg);
765     if (pkctx) {
766         if (EVP_PKEY_CTX_get_rsa_padding(pkctx, &pad_mode) <= 0)
767             return 0;
768     }
769     if (pad_mode == RSA_PKCS1_PADDING) {
770         X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsaEncryption), V_ASN1_NULL, 0);
771         return 1;
772     }
773     /* Not supported */
774     if (pad_mode != RSA_PKCS1_OAEP_PADDING)
775         return 0;
776     if (EVP_PKEY_CTX_get_rsa_oaep_md(pkctx, &md) <= 0)
777         goto err;
778     if (EVP_PKEY_CTX_get_rsa_mgf1_md(pkctx, &mgf1md) <= 0)
779         goto err;
780     labellen = EVP_PKEY_CTX_get0_rsa_oaep_label(pkctx, &label);
781     if (labellen < 0)
782         goto err;
783     oaep = RSA_OAEP_PARAMS_new();
784     if (oaep == NULL)
785         goto err;
786     if (!rsa_md_to_algor(&oaep->hashFunc, md))
787         goto err;
788     if (!rsa_md_to_mgf1(&oaep->maskGenFunc, mgf1md))
789         goto err;
790     if (labellen > 0) {
791         ASN1_OCTET_STRING *los;
792         oaep->pSourceFunc = X509_ALGOR_new();
793         if (oaep->pSourceFunc == NULL)
794             goto err;
795         los = ASN1_OCTET_STRING_new();
796         if (los == NULL)
797             goto err;
798         if (!ASN1_OCTET_STRING_set(los, label, labellen)) {
799             ASN1_OCTET_STRING_free(los);
800             goto err;
801         }
802         X509_ALGOR_set0(oaep->pSourceFunc, OBJ_nid2obj(NID_pSpecified),
803                         V_ASN1_OCTET_STRING, los);
804     }
805     /* create string with pss parameter encoding. */
806     if (!ASN1_item_pack(oaep, ASN1_ITEM_rptr(RSA_OAEP_PARAMS), &os))
807          goto err;
808     X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsaesOaep), V_ASN1_SEQUENCE, os);
809     os = NULL;
810     rv = 1;
811  err:
812     RSA_OAEP_PARAMS_free(oaep);
813     ASN1_STRING_free(os);
814     return rv;
815 }
816 #endif
817
818 const EVP_PKEY_ASN1_METHOD rsa_asn1_meths[2] = {
819     {
820      EVP_PKEY_RSA,
821      EVP_PKEY_RSA,
822      ASN1_PKEY_SIGPARAM_NULL,
823
824      "RSA",
825      "OpenSSL RSA method",
826
827      rsa_pub_decode,
828      rsa_pub_encode,
829      rsa_pub_cmp,
830      rsa_pub_print,
831
832      rsa_priv_decode,
833      rsa_priv_encode,
834      rsa_priv_print,
835
836      int_rsa_size,
837      rsa_bits,
838      rsa_security_bits,
839
840      0, 0, 0, 0, 0, 0,
841
842      rsa_sig_print,
843      int_rsa_free,
844      rsa_pkey_ctrl,
845      old_rsa_priv_decode,
846      old_rsa_priv_encode,
847      rsa_item_verify,
848      rsa_item_sign},
849
850     {
851      EVP_PKEY_RSA2,
852      EVP_PKEY_RSA,
853      ASN1_PKEY_ALIAS}
854 };
855
856 const EVP_PKEY_ASN1_METHOD rsa_pss_asn1_meth = {
857      EVP_PKEY_RSA_PSS,
858      EVP_PKEY_RSA_PSS,
859      ASN1_PKEY_SIGPARAM_NULL,
860
861      "RSA-PSS",
862      "OpenSSL RSA-PSS method",
863
864      rsa_pub_decode,
865      rsa_pub_encode,
866      rsa_pub_cmp,
867      rsa_pub_print,
868
869      rsa_priv_decode,
870      rsa_priv_encode,
871      rsa_priv_print,
872
873      int_rsa_size,
874      rsa_bits,
875      rsa_security_bits,
876
877      0, 0, 0, 0, 0, 0,
878
879      rsa_sig_print,
880      int_rsa_free,
881      rsa_pkey_ctrl,
882      0, 0,
883      rsa_item_verify,
884      rsa_item_sign,
885 };