PSS parameter encode and decode.
[openssl.git] / crypto / rsa / rsa_ameth.c
1 /*
2  * Copyright 2006-2016 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <stdio.h>
11 #include "internal/cryptlib.h"
12 #include <openssl/asn1t.h>
13 #include <openssl/x509.h>
14 #include <openssl/bn.h>
15 #include <openssl/cms.h>
16 #include "internal/asn1_int.h"
17 #include "internal/evp_int.h"
18 #include "rsa_locl.h"
19
20 #ifndef OPENSSL_NO_CMS
21 static int rsa_cms_sign(CMS_SignerInfo *si);
22 static int rsa_cms_verify(CMS_SignerInfo *si);
23 static int rsa_cms_decrypt(CMS_RecipientInfo *ri);
24 static int rsa_cms_encrypt(CMS_RecipientInfo *ri);
25 #endif
26
27 /* Set any parameters associated with pkey */
28 static int rsa_param_encode(const EVP_PKEY *pkey,
29                             ASN1_STRING **pstr, int *pstrtype)
30 {
31     const RSA *rsa = pkey->pkey.rsa;
32     *pstr = NULL;
33     /* If RSA it's just NULL type */
34     if (pkey->ameth->pkey_id == EVP_PKEY_RSA) {
35         *pstrtype = V_ASN1_NULL;
36         return 1;
37     }
38     /* If no PSS parameters we omit parameters entirely */
39     if (rsa->pss == NULL) {
40         *pstrtype = V_ASN1_UNDEF;
41         return 1;
42     }
43     /* Encode PSS parameters */
44     if (!ASN1_item_pack(rsa->pss, ASN1_ITEM_rptr(RSA_PSS_PARAMS), pstr)) {
45         ASN1_STRING_free(*pstr);
46         *pstr = NULL;
47         return 0;
48     }
49
50     *pstrtype = V_ASN1_SEQUENCE;
51     return 1;
52 }
53 /* Decode any parameters and set them in RSA structure */
54 static int rsa_param_decode(RSA *rsa, const X509_ALGOR *alg)
55 {
56     const ASN1_OBJECT *algoid;
57     const void *algp;
58     int algptype;
59
60     X509_ALGOR_get0(&algoid, &algptype, &algp, alg);
61     if (OBJ_obj2nid(algoid) == EVP_PKEY_RSA)
62         return 1;
63     if (algptype == V_ASN1_UNDEF)
64         return 1;
65     if (algptype != V_ASN1_SEQUENCE)
66         return 0;
67     rsa->pss = ASN1_item_unpack(algp, ASN1_ITEM_rptr(RSA_PSS_PARAMS));
68     if (rsa->pss == NULL)
69         return 0;
70     return 1;
71 }
72
73 static int rsa_pub_encode(X509_PUBKEY *pk, const EVP_PKEY *pkey)
74 {
75     unsigned char *penc = NULL;
76     int penclen;
77     ASN1_STRING *str;
78     int strtype;
79     if (!rsa_param_encode(pkey, &str, &strtype))
80         return 0;
81     penclen = i2d_RSAPublicKey(pkey->pkey.rsa, &penc);
82     if (penclen <= 0)
83         return 0;
84     if (X509_PUBKEY_set0_param(pk, OBJ_nid2obj(pkey->ameth->pkey_id),
85                                strtype, str, penc, penclen))
86         return 1;
87
88     OPENSSL_free(penc);
89     return 0;
90 }
91
92 static int rsa_pub_decode(EVP_PKEY *pkey, X509_PUBKEY *pubkey)
93 {
94     const unsigned char *p;
95     int pklen;
96     X509_ALGOR *alg;
97     RSA *rsa = NULL;
98
99     if (!X509_PUBKEY_get0_param(NULL, &p, &pklen, &alg, pubkey))
100         return 0;
101     if ((rsa = d2i_RSAPublicKey(NULL, &p, pklen)) == NULL) {
102         RSAerr(RSA_F_RSA_PUB_DECODE, ERR_R_RSA_LIB);
103         return 0;
104     }
105     if (!rsa_param_decode(rsa, alg)) {
106         RSA_free(rsa);
107         return 0;
108     }
109     EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa);
110     return 1;
111 }
112
113 static int rsa_pub_cmp(const EVP_PKEY *a, const EVP_PKEY *b)
114 {
115     if (BN_cmp(b->pkey.rsa->n, a->pkey.rsa->n) != 0
116         || BN_cmp(b->pkey.rsa->e, a->pkey.rsa->e) != 0)
117         return 0;
118     return 1;
119 }
120
121 static int old_rsa_priv_decode(EVP_PKEY *pkey,
122                                const unsigned char **pder, int derlen)
123 {
124     RSA *rsa;
125
126     if ((rsa = d2i_RSAPrivateKey(NULL, pder, derlen)) == NULL) {
127         RSAerr(RSA_F_OLD_RSA_PRIV_DECODE, ERR_R_RSA_LIB);
128         return 0;
129     }
130     EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa);
131     return 1;
132 }
133
134 static int old_rsa_priv_encode(const EVP_PKEY *pkey, unsigned char **pder)
135 {
136     return i2d_RSAPrivateKey(pkey->pkey.rsa, pder);
137 }
138
139 static int rsa_priv_encode(PKCS8_PRIV_KEY_INFO *p8, const EVP_PKEY *pkey)
140 {
141     unsigned char *rk = NULL;
142     int rklen;
143     ASN1_STRING *str;
144     int strtype;
145     if (!rsa_param_encode(pkey, &str, &strtype))
146         return 0;
147     rklen = i2d_RSAPrivateKey(pkey->pkey.rsa, &rk);
148
149     if (rklen <= 0) {
150         RSAerr(RSA_F_RSA_PRIV_ENCODE, ERR_R_MALLOC_FAILURE);
151         return 0;
152     }
153
154     if (!PKCS8_pkey_set0(p8, OBJ_nid2obj(pkey->ameth->pkey_id), 0,
155                          strtype, str, rk, rklen)) {
156         RSAerr(RSA_F_RSA_PRIV_ENCODE, ERR_R_MALLOC_FAILURE);
157         return 0;
158     }
159
160     return 1;
161 }
162
163 static int rsa_priv_decode(EVP_PKEY *pkey, const PKCS8_PRIV_KEY_INFO *p8)
164 {
165     const unsigned char *p;
166     RSA *rsa;
167     int pklen;
168     const X509_ALGOR *alg;
169
170     if (!PKCS8_pkey_get0(NULL, &p, &pklen, &alg, p8))
171         return 0;
172     rsa = d2i_RSAPrivateKey(NULL, &p, pklen);
173     if (rsa == NULL) {
174         RSAerr(RSA_F_RSA_PRIV_DECODE, ERR_R_RSA_LIB);
175         return 0;
176     }
177     if (!rsa_param_decode(rsa, alg)) {
178         RSA_free(rsa);
179         return 0;
180     }
181     EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa);
182     return 1;
183 }
184
185 static int int_rsa_size(const EVP_PKEY *pkey)
186 {
187     return RSA_size(pkey->pkey.rsa);
188 }
189
190 static int rsa_bits(const EVP_PKEY *pkey)
191 {
192     return BN_num_bits(pkey->pkey.rsa->n);
193 }
194
195 static int rsa_security_bits(const EVP_PKEY *pkey)
196 {
197     return RSA_security_bits(pkey->pkey.rsa);
198 }
199
200 static void int_rsa_free(EVP_PKEY *pkey)
201 {
202     RSA_free(pkey->pkey.rsa);
203 }
204
205 static int do_rsa_print(BIO *bp, const RSA *x, int off, int priv)
206 {
207     char *str;
208     const char *s;
209     int ret = 0, mod_len = 0;
210
211     if (x->n != NULL)
212         mod_len = BN_num_bits(x->n);
213
214     if (!BIO_indent(bp, off, 128))
215         goto err;
216
217     if (priv && x->d) {
218         if (BIO_printf(bp, "Private-Key: (%d bit)\n", mod_len) <= 0)
219             goto err;
220         str = "modulus:";
221         s = "publicExponent:";
222     } else {
223         if (BIO_printf(bp, "Public-Key: (%d bit)\n", mod_len) <= 0)
224             goto err;
225         str = "Modulus:";
226         s = "Exponent:";
227     }
228     if (!ASN1_bn_print(bp, str, x->n, NULL, off))
229         goto err;
230     if (!ASN1_bn_print(bp, s, x->e, NULL, off))
231         goto err;
232     if (priv) {
233         if (!ASN1_bn_print(bp, "privateExponent:", x->d, NULL, off))
234             goto err;
235         if (!ASN1_bn_print(bp, "prime1:", x->p, NULL, off))
236             goto err;
237         if (!ASN1_bn_print(bp, "prime2:", x->q, NULL, off))
238             goto err;
239         if (!ASN1_bn_print(bp, "exponent1:", x->dmp1, NULL, off))
240             goto err;
241         if (!ASN1_bn_print(bp, "exponent2:", x->dmq1, NULL, off))
242             goto err;
243         if (!ASN1_bn_print(bp, "coefficient:", x->iqmp, NULL, off))
244             goto err;
245     }
246     ret = 1;
247  err:
248     return (ret);
249 }
250
251 static int rsa_pub_print(BIO *bp, const EVP_PKEY *pkey, int indent,
252                          ASN1_PCTX *ctx)
253 {
254     return do_rsa_print(bp, pkey->pkey.rsa, indent, 0);
255 }
256
257 static int rsa_priv_print(BIO *bp, const EVP_PKEY *pkey, int indent,
258                           ASN1_PCTX *ctx)
259 {
260     return do_rsa_print(bp, pkey->pkey.rsa, indent, 1);
261 }
262
263 static X509_ALGOR *rsa_mgf1_decode(X509_ALGOR *alg)
264 {
265     if (OBJ_obj2nid(alg->algorithm) != NID_mgf1)
266         return NULL;
267     return ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(X509_ALGOR),
268                                      alg->parameter);
269 }
270
271 static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg)
272 {
273     RSA_PSS_PARAMS *pss;
274
275     pss = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_PSS_PARAMS),
276                                     alg->parameter);
277
278     if (pss == NULL)
279         return NULL;
280
281     if (pss->maskGenAlgorithm != NULL) {
282         pss->maskHash = rsa_mgf1_decode(pss->maskGenAlgorithm);
283         if (pss->maskHash == NULL) {
284             RSA_PSS_PARAMS_free(pss);
285             return NULL;
286         }
287     }
288
289     return pss;
290 }
291
292 static int rsa_pss_param_print(BIO *bp, RSA_PSS_PARAMS *pss, int indent)
293 {
294     int rv = 0;
295     if (!pss) {
296         if (BIO_puts(bp, " (INVALID PSS PARAMETERS)\n") <= 0)
297             return 0;
298         return 1;
299     }
300     if (BIO_puts(bp, "\n") <= 0)
301         goto err;
302     if (!BIO_indent(bp, indent, 128))
303         goto err;
304     if (BIO_puts(bp, "Hash Algorithm: ") <= 0)
305         goto err;
306
307     if (pss->hashAlgorithm) {
308         if (i2a_ASN1_OBJECT(bp, pss->hashAlgorithm->algorithm) <= 0)
309             goto err;
310     } else if (BIO_puts(bp, "sha1 (default)") <= 0)
311         goto err;
312
313     if (BIO_puts(bp, "\n") <= 0)
314         goto err;
315
316     if (!BIO_indent(bp, indent, 128))
317         goto err;
318
319     if (BIO_puts(bp, "Mask Algorithm: ") <= 0)
320         goto err;
321     if (pss->maskGenAlgorithm) {
322         if (i2a_ASN1_OBJECT(bp, pss->maskGenAlgorithm->algorithm) <= 0)
323             goto err;
324         if (BIO_puts(bp, " with ") <= 0)
325             goto err;
326         if (pss->maskHash) {
327             if (i2a_ASN1_OBJECT(bp, pss->maskHash->algorithm) <= 0)
328                 goto err;
329         } else if (BIO_puts(bp, "INVALID") <= 0)
330             goto err;
331     } else if (BIO_puts(bp, "mgf1 with sha1 (default)") <= 0)
332         goto err;
333     BIO_puts(bp, "\n");
334
335     if (!BIO_indent(bp, indent, 128))
336         goto err;
337     if (BIO_puts(bp, "Salt Length: 0x") <= 0)
338         goto err;
339     if (pss->saltLength) {
340         if (i2a_ASN1_INTEGER(bp, pss->saltLength) <= 0)
341             goto err;
342     } else if (BIO_puts(bp, "14 (default)") <= 0)
343         goto err;
344     BIO_puts(bp, "\n");
345
346     if (!BIO_indent(bp, indent, 128))
347         goto err;
348     if (BIO_puts(bp, "Trailer Field: 0x") <= 0)
349         goto err;
350     if (pss->trailerField) {
351         if (i2a_ASN1_INTEGER(bp, pss->trailerField) <= 0)
352             goto err;
353     } else if (BIO_puts(bp, "BC (default)") <= 0)
354         goto err;
355     BIO_puts(bp, "\n");
356
357     rv = 1;
358
359  err:
360     return rv;
361
362 }
363
364 static int rsa_sig_print(BIO *bp, const X509_ALGOR *sigalg,
365                          const ASN1_STRING *sig, int indent, ASN1_PCTX *pctx)
366 {
367     if (OBJ_obj2nid(sigalg->algorithm) == EVP_PKEY_RSA_PSS) {
368         int rv;
369         RSA_PSS_PARAMS *pss;
370         pss = rsa_pss_decode(sigalg);
371         rv = rsa_pss_param_print(bp, pss, indent);
372         RSA_PSS_PARAMS_free(pss);
373         if (!rv)
374             return 0;
375     } else if (!sig && BIO_puts(bp, "\n") <= 0)
376         return 0;
377     if (sig)
378         return X509_signature_dump(bp, sig, indent);
379     return 1;
380 }
381
382 static int rsa_pkey_ctrl(EVP_PKEY *pkey, int op, long arg1, void *arg2)
383 {
384     X509_ALGOR *alg = NULL;
385     switch (op) {
386
387     case ASN1_PKEY_CTRL_PKCS7_SIGN:
388         if (arg1 == 0)
389             PKCS7_SIGNER_INFO_get0_algs(arg2, NULL, NULL, &alg);
390         break;
391
392     case ASN1_PKEY_CTRL_PKCS7_ENCRYPT:
393         if (arg1 == 0)
394             PKCS7_RECIP_INFO_get0_alg(arg2, &alg);
395         break;
396 #ifndef OPENSSL_NO_CMS
397     case ASN1_PKEY_CTRL_CMS_SIGN:
398         if (arg1 == 0)
399             return rsa_cms_sign(arg2);
400         else if (arg1 == 1)
401             return rsa_cms_verify(arg2);
402         break;
403
404     case ASN1_PKEY_CTRL_CMS_ENVELOPE:
405         if (arg1 == 0)
406             return rsa_cms_encrypt(arg2);
407         else if (arg1 == 1)
408             return rsa_cms_decrypt(arg2);
409         break;
410
411     case ASN1_PKEY_CTRL_CMS_RI_TYPE:
412         *(int *)arg2 = CMS_RECIPINFO_TRANS;
413         return 1;
414 #endif
415
416     case ASN1_PKEY_CTRL_DEFAULT_MD_NID:
417         *(int *)arg2 = NID_sha256;
418         return 1;
419
420     default:
421         return -2;
422
423     }
424
425     if (alg)
426         X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsaEncryption), V_ASN1_NULL, 0);
427
428     return 1;
429
430 }
431
432 /* allocate and set algorithm ID from EVP_MD, default SHA1 */
433 static int rsa_md_to_algor(X509_ALGOR **palg, const EVP_MD *md)
434 {
435     if (md == NULL || EVP_MD_type(md) == NID_sha1)
436         return 1;
437     *palg = X509_ALGOR_new();
438     if (*palg == NULL)
439         return 0;
440     X509_ALGOR_set_md(*palg, md);
441     return 1;
442 }
443
444 /* Allocate and set MGF1 algorithm ID from EVP_MD */
445 static int rsa_md_to_mgf1(X509_ALGOR **palg, const EVP_MD *mgf1md)
446 {
447     X509_ALGOR *algtmp = NULL;
448     ASN1_STRING *stmp = NULL;
449     *palg = NULL;
450     if (mgf1md == NULL || EVP_MD_type(mgf1md) == NID_sha1)
451         return 1;
452     /* need to embed algorithm ID inside another */
453     if (!rsa_md_to_algor(&algtmp, mgf1md))
454         goto err;
455     if (!ASN1_item_pack(algtmp, ASN1_ITEM_rptr(X509_ALGOR), &stmp))
456          goto err;
457     *palg = X509_ALGOR_new();
458     if (*palg == NULL)
459         goto err;
460     X509_ALGOR_set0(*palg, OBJ_nid2obj(NID_mgf1), V_ASN1_SEQUENCE, stmp);
461     stmp = NULL;
462  err:
463     ASN1_STRING_free(stmp);
464     X509_ALGOR_free(algtmp);
465     if (*palg)
466         return 1;
467     return 0;
468 }
469
470 /* convert algorithm ID to EVP_MD, default SHA1 */
471 static const EVP_MD *rsa_algor_to_md(X509_ALGOR *alg)
472 {
473     const EVP_MD *md;
474     if (!alg)
475         return EVP_sha1();
476     md = EVP_get_digestbyobj(alg->algorithm);
477     if (md == NULL)
478         RSAerr(RSA_F_RSA_ALGOR_TO_MD, RSA_R_UNKNOWN_DIGEST);
479     return md;
480 }
481
482 /*
483  * Convert EVP_PKEY_CTX in PSS mode into corresponding algorithm parameter,
484  * suitable for setting an AlgorithmIdentifier.
485  */
486
487 static RSA_PSS_PARAMS *rsa_ctx_to_pss(EVP_PKEY_CTX *pkctx)
488 {
489     const EVP_MD *sigmd, *mgf1md;
490     EVP_PKEY *pk = EVP_PKEY_CTX_get0_pkey(pkctx);
491     int saltlen;
492     if (EVP_PKEY_CTX_get_signature_md(pkctx, &sigmd) <= 0)
493         return NULL;
494     if (EVP_PKEY_CTX_get_rsa_mgf1_md(pkctx, &mgf1md) <= 0)
495         return NULL;
496     if (!EVP_PKEY_CTX_get_rsa_pss_saltlen(pkctx, &saltlen))
497         return NULL;
498     if (saltlen == -1)
499         saltlen = EVP_MD_size(sigmd);
500     else if (saltlen == -2) {
501         saltlen = EVP_PKEY_size(pk) - EVP_MD_size(sigmd) - 2;
502         if (((EVP_PKEY_bits(pk) - 1) & 0x7) == 0)
503             saltlen--;
504     }
505
506     return rsa_pss_params_create(sigmd, mgf1md, saltlen);
507 }
508
509 RSA_PSS_PARAMS *rsa_pss_params_create(const EVP_MD *sigmd,
510                                       const EVP_MD *mgf1md, int saltlen)
511 {
512     RSA_PSS_PARAMS *pss = RSA_PSS_PARAMS_new();
513     if (pss == NULL)
514         goto err;
515     if (saltlen != 20) {
516         pss->saltLength = ASN1_INTEGER_new();
517         if (pss->saltLength == NULL)
518             goto err;
519         if (!ASN1_INTEGER_set(pss->saltLength, saltlen))
520             goto err;
521     }
522     if (!rsa_md_to_algor(&pss->hashAlgorithm, sigmd))
523         goto err;
524     if (mgf1md == NULL)
525             mgf1md = sigmd;
526     if (!rsa_md_to_mgf1(&pss->maskGenAlgorithm, mgf1md))
527         goto err;
528     return pss;
529  err:
530     RSA_PSS_PARAMS_free(pss);
531     return NULL;
532 }
533
534 static ASN1_STRING *rsa_ctx_to_pss_string(EVP_PKEY_CTX *pkctx)
535 {
536     RSA_PSS_PARAMS *pss = rsa_ctx_to_pss(pkctx);
537     ASN1_STRING *os = NULL;
538     if (pss == NULL)
539         return NULL;
540
541     if (!ASN1_item_pack(pss, ASN1_ITEM_rptr(RSA_PSS_PARAMS), &os)) {
542         ASN1_STRING_free(os);
543         os = NULL;
544     }
545     RSA_PSS_PARAMS_free(pss);
546     return os;
547 }
548
549 /*
550  * From PSS AlgorithmIdentifier set public key parameters. If pkey isn't NULL
551  * then the EVP_MD_CTX is setup and initialised. If it is NULL parameters are
552  * passed to pkctx instead.
553  */
554
555 static int rsa_pss_to_ctx(EVP_MD_CTX *ctx, EVP_PKEY_CTX *pkctx,
556                           X509_ALGOR *sigalg, EVP_PKEY *pkey)
557 {
558     int rv = -1;
559     int saltlen;
560     const EVP_MD *mgf1md = NULL, *md = NULL;
561     RSA_PSS_PARAMS *pss;
562     /* Sanity check: make sure it is PSS */
563     if (OBJ_obj2nid(sigalg->algorithm) != EVP_PKEY_RSA_PSS) {
564         RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_UNSUPPORTED_SIGNATURE_TYPE);
565         return -1;
566     }
567     /* Decode PSS parameters */
568     pss = rsa_pss_decode(sigalg);
569
570     if (pss == NULL) {
571         RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_INVALID_PSS_PARAMETERS);
572         goto err;
573     }
574     mgf1md = rsa_algor_to_md(pss->maskHash);
575     if (!mgf1md)
576         goto err;
577     md = rsa_algor_to_md(pss->hashAlgorithm);
578     if (!md)
579         goto err;
580
581     if (pss->saltLength) {
582         saltlen = ASN1_INTEGER_get(pss->saltLength);
583
584         /*
585          * Could perform more salt length sanity checks but the main RSA
586          * routines will trap other invalid values anyway.
587          */
588         if (saltlen < 0) {
589             RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_INVALID_SALT_LENGTH);
590             goto err;
591         }
592     } else
593         saltlen = 20;
594
595     /*
596      * low-level routines support only trailer field 0xbc (value 1) and
597      * PKCS#1 says we should reject any other value anyway.
598      */
599     if (pss->trailerField && ASN1_INTEGER_get(pss->trailerField) != 1) {
600         RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_INVALID_TRAILER);
601         goto err;
602     }
603
604     /* We have all parameters now set up context */
605
606     if (pkey) {
607         if (!EVP_DigestVerifyInit(ctx, &pkctx, md, NULL, pkey))
608             goto err;
609     } else {
610         const EVP_MD *checkmd;
611         if (EVP_PKEY_CTX_get_signature_md(pkctx, &checkmd) <= 0)
612             goto err;
613         if (EVP_MD_type(md) != EVP_MD_type(checkmd)) {
614             RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_DIGEST_DOES_NOT_MATCH);
615             goto err;
616         }
617     }
618
619     if (EVP_PKEY_CTX_set_rsa_padding(pkctx, RSA_PKCS1_PSS_PADDING) <= 0)
620         goto err;
621
622     if (EVP_PKEY_CTX_set_rsa_pss_saltlen(pkctx, saltlen) <= 0)
623         goto err;
624
625     if (EVP_PKEY_CTX_set_rsa_mgf1_md(pkctx, mgf1md) <= 0)
626         goto err;
627     /* Carry on */
628     rv = 1;
629
630  err:
631     RSA_PSS_PARAMS_free(pss);
632     return rv;
633 }
634
635 #ifndef OPENSSL_NO_CMS
636 static int rsa_cms_verify(CMS_SignerInfo *si)
637 {
638     int nid, nid2;
639     X509_ALGOR *alg;
640     EVP_PKEY_CTX *pkctx = CMS_SignerInfo_get0_pkey_ctx(si);
641     CMS_SignerInfo_get0_algs(si, NULL, NULL, NULL, &alg);
642     nid = OBJ_obj2nid(alg->algorithm);
643     if (nid == NID_rsaEncryption)
644         return 1;
645     if (nid == EVP_PKEY_RSA_PSS)
646         return rsa_pss_to_ctx(NULL, pkctx, alg, NULL);
647     /* Workaround for some implementation that use a signature OID */
648     if (OBJ_find_sigid_algs(nid, NULL, &nid2)) {
649         if (nid2 == NID_rsaEncryption)
650             return 1;
651     }
652     return 0;
653 }
654 #endif
655
656 /*
657  * Customised RSA item verification routine. This is called when a signature
658  * is encountered requiring special handling. We currently only handle PSS.
659  */
660
661 static int rsa_item_verify(EVP_MD_CTX *ctx, const ASN1_ITEM *it, void *asn,
662                            X509_ALGOR *sigalg, ASN1_BIT_STRING *sig,
663                            EVP_PKEY *pkey)
664 {
665     /* Sanity check: make sure it is PSS */
666     if (OBJ_obj2nid(sigalg->algorithm) != EVP_PKEY_RSA_PSS) {
667         RSAerr(RSA_F_RSA_ITEM_VERIFY, RSA_R_UNSUPPORTED_SIGNATURE_TYPE);
668         return -1;
669     }
670     if (rsa_pss_to_ctx(ctx, NULL, sigalg, pkey) > 0) {
671         /* Carry on */
672         return 2;
673     }
674     return -1;
675 }
676
677 #ifndef OPENSSL_NO_CMS
678 static int rsa_cms_sign(CMS_SignerInfo *si)
679 {
680     int pad_mode = RSA_PKCS1_PADDING;
681     X509_ALGOR *alg;
682     EVP_PKEY_CTX *pkctx = CMS_SignerInfo_get0_pkey_ctx(si);
683     ASN1_STRING *os = NULL;
684     CMS_SignerInfo_get0_algs(si, NULL, NULL, NULL, &alg);
685     if (pkctx) {
686         if (EVP_PKEY_CTX_get_rsa_padding(pkctx, &pad_mode) <= 0)
687             return 0;
688     }
689     if (pad_mode == RSA_PKCS1_PADDING) {
690         X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsaEncryption), V_ASN1_NULL, 0);
691         return 1;
692     }
693     /* We don't support it */
694     if (pad_mode != RSA_PKCS1_PSS_PADDING)
695         return 0;
696     os = rsa_ctx_to_pss_string(pkctx);
697     if (!os)
698         return 0;
699     X509_ALGOR_set0(alg, OBJ_nid2obj(EVP_PKEY_RSA_PSS), V_ASN1_SEQUENCE, os);
700     return 1;
701 }
702 #endif
703
704 static int rsa_item_sign(EVP_MD_CTX *ctx, const ASN1_ITEM *it, void *asn,
705                          X509_ALGOR *alg1, X509_ALGOR *alg2,
706                          ASN1_BIT_STRING *sig)
707 {
708     int pad_mode;
709     EVP_PKEY_CTX *pkctx = EVP_MD_CTX_pkey_ctx(ctx);
710     if (EVP_PKEY_CTX_get_rsa_padding(pkctx, &pad_mode) <= 0)
711         return 0;
712     if (pad_mode == RSA_PKCS1_PADDING)
713         return 2;
714     if (pad_mode == RSA_PKCS1_PSS_PADDING) {
715         ASN1_STRING *os1 = NULL;
716         os1 = rsa_ctx_to_pss_string(pkctx);
717         if (!os1)
718             return 0;
719         /* Duplicate parameters if we have to */
720         if (alg2) {
721             ASN1_STRING *os2 = ASN1_STRING_dup(os1);
722             if (!os2) {
723                 ASN1_STRING_free(os1);
724                 return 0;
725             }
726             X509_ALGOR_set0(alg2, OBJ_nid2obj(EVP_PKEY_RSA_PSS),
727                             V_ASN1_SEQUENCE, os2);
728         }
729         X509_ALGOR_set0(alg1, OBJ_nid2obj(EVP_PKEY_RSA_PSS),
730                         V_ASN1_SEQUENCE, os1);
731         return 3;
732     }
733     return 2;
734 }
735
736 #ifndef OPENSSL_NO_CMS
737 static RSA_OAEP_PARAMS *rsa_oaep_decode(const X509_ALGOR *alg)
738 {
739     RSA_OAEP_PARAMS *oaep;
740
741     oaep = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_OAEP_PARAMS),
742                                     alg->parameter);
743
744     if (oaep == NULL)
745         return NULL;
746
747     if (oaep->maskGenFunc != NULL) {
748         oaep->maskHash = rsa_mgf1_decode(oaep->maskGenFunc);
749         if (oaep->maskHash == NULL) {
750             RSA_OAEP_PARAMS_free(oaep);
751             return NULL;
752         }
753     }
754     return oaep;
755 }
756
757 static int rsa_cms_decrypt(CMS_RecipientInfo *ri)
758 {
759     EVP_PKEY_CTX *pkctx;
760     X509_ALGOR *cmsalg;
761     int nid;
762     int rv = -1;
763     unsigned char *label = NULL;
764     int labellen = 0;
765     const EVP_MD *mgf1md = NULL, *md = NULL;
766     RSA_OAEP_PARAMS *oaep;
767     pkctx = CMS_RecipientInfo_get0_pkey_ctx(ri);
768     if (!pkctx)
769         return 0;
770     if (!CMS_RecipientInfo_ktri_get0_algs(ri, NULL, NULL, &cmsalg))
771         return -1;
772     nid = OBJ_obj2nid(cmsalg->algorithm);
773     if (nid == NID_rsaEncryption)
774         return 1;
775     if (nid != NID_rsaesOaep) {
776         RSAerr(RSA_F_RSA_CMS_DECRYPT, RSA_R_UNSUPPORTED_ENCRYPTION_TYPE);
777         return -1;
778     }
779     /* Decode OAEP parameters */
780     oaep = rsa_oaep_decode(cmsalg);
781
782     if (oaep == NULL) {
783         RSAerr(RSA_F_RSA_CMS_DECRYPT, RSA_R_INVALID_OAEP_PARAMETERS);
784         goto err;
785     }
786
787     mgf1md = rsa_algor_to_md(oaep->maskHash);
788     if (!mgf1md)
789         goto err;
790     md = rsa_algor_to_md(oaep->hashFunc);
791     if (!md)
792         goto err;
793
794     if (oaep->pSourceFunc) {
795         X509_ALGOR *plab = oaep->pSourceFunc;
796         if (OBJ_obj2nid(plab->algorithm) != NID_pSpecified) {
797             RSAerr(RSA_F_RSA_CMS_DECRYPT, RSA_R_UNSUPPORTED_LABEL_SOURCE);
798             goto err;
799         }
800         if (plab->parameter->type != V_ASN1_OCTET_STRING) {
801             RSAerr(RSA_F_RSA_CMS_DECRYPT, RSA_R_INVALID_LABEL);
802             goto err;
803         }
804
805         label = plab->parameter->value.octet_string->data;
806         /* Stop label being freed when OAEP parameters are freed */
807         plab->parameter->value.octet_string->data = NULL;
808         labellen = plab->parameter->value.octet_string->length;
809     }
810
811     if (EVP_PKEY_CTX_set_rsa_padding(pkctx, RSA_PKCS1_OAEP_PADDING) <= 0)
812         goto err;
813     if (EVP_PKEY_CTX_set_rsa_oaep_md(pkctx, md) <= 0)
814         goto err;
815     if (EVP_PKEY_CTX_set_rsa_mgf1_md(pkctx, mgf1md) <= 0)
816         goto err;
817     if (EVP_PKEY_CTX_set0_rsa_oaep_label(pkctx, label, labellen) <= 0)
818         goto err;
819     /* Carry on */
820     rv = 1;
821
822  err:
823     RSA_OAEP_PARAMS_free(oaep);
824     return rv;
825 }
826
827 static int rsa_cms_encrypt(CMS_RecipientInfo *ri)
828 {
829     const EVP_MD *md, *mgf1md;
830     RSA_OAEP_PARAMS *oaep = NULL;
831     ASN1_STRING *os = NULL;
832     X509_ALGOR *alg;
833     EVP_PKEY_CTX *pkctx = CMS_RecipientInfo_get0_pkey_ctx(ri);
834     int pad_mode = RSA_PKCS1_PADDING, rv = 0, labellen;
835     unsigned char *label;
836     CMS_RecipientInfo_ktri_get0_algs(ri, NULL, NULL, &alg);
837     if (pkctx) {
838         if (EVP_PKEY_CTX_get_rsa_padding(pkctx, &pad_mode) <= 0)
839             return 0;
840     }
841     if (pad_mode == RSA_PKCS1_PADDING) {
842         X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsaEncryption), V_ASN1_NULL, 0);
843         return 1;
844     }
845     /* Not supported */
846     if (pad_mode != RSA_PKCS1_OAEP_PADDING)
847         return 0;
848     if (EVP_PKEY_CTX_get_rsa_oaep_md(pkctx, &md) <= 0)
849         goto err;
850     if (EVP_PKEY_CTX_get_rsa_mgf1_md(pkctx, &mgf1md) <= 0)
851         goto err;
852     labellen = EVP_PKEY_CTX_get0_rsa_oaep_label(pkctx, &label);
853     if (labellen < 0)
854         goto err;
855     oaep = RSA_OAEP_PARAMS_new();
856     if (oaep == NULL)
857         goto err;
858     if (!rsa_md_to_algor(&oaep->hashFunc, md))
859         goto err;
860     if (!rsa_md_to_mgf1(&oaep->maskGenFunc, mgf1md))
861         goto err;
862     if (labellen > 0) {
863         ASN1_OCTET_STRING *los;
864         oaep->pSourceFunc = X509_ALGOR_new();
865         if (oaep->pSourceFunc == NULL)
866             goto err;
867         los = ASN1_OCTET_STRING_new();
868         if (los == NULL)
869             goto err;
870         if (!ASN1_OCTET_STRING_set(los, label, labellen)) {
871             ASN1_OCTET_STRING_free(los);
872             goto err;
873         }
874         X509_ALGOR_set0(oaep->pSourceFunc, OBJ_nid2obj(NID_pSpecified),
875                         V_ASN1_OCTET_STRING, los);
876     }
877     /* create string with pss parameter encoding. */
878     if (!ASN1_item_pack(oaep, ASN1_ITEM_rptr(RSA_OAEP_PARAMS), &os))
879          goto err;
880     X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsaesOaep), V_ASN1_SEQUENCE, os);
881     os = NULL;
882     rv = 1;
883  err:
884     RSA_OAEP_PARAMS_free(oaep);
885     ASN1_STRING_free(os);
886     return rv;
887 }
888 #endif
889
890 const EVP_PKEY_ASN1_METHOD rsa_asn1_meths[2] = {
891     {
892      EVP_PKEY_RSA,
893      EVP_PKEY_RSA,
894      ASN1_PKEY_SIGPARAM_NULL,
895
896      "RSA",
897      "OpenSSL RSA method",
898
899      rsa_pub_decode,
900      rsa_pub_encode,
901      rsa_pub_cmp,
902      rsa_pub_print,
903
904      rsa_priv_decode,
905      rsa_priv_encode,
906      rsa_priv_print,
907
908      int_rsa_size,
909      rsa_bits,
910      rsa_security_bits,
911
912      0, 0, 0, 0, 0, 0,
913
914      rsa_sig_print,
915      int_rsa_free,
916      rsa_pkey_ctrl,
917      old_rsa_priv_decode,
918      old_rsa_priv_encode,
919      rsa_item_verify,
920      rsa_item_sign},
921
922     {
923      EVP_PKEY_RSA2,
924      EVP_PKEY_RSA,
925      ASN1_PKEY_ALIAS}
926 };
927
928 const EVP_PKEY_ASN1_METHOD rsa_pss_asn1_meth = {
929      EVP_PKEY_RSA_PSS,
930      EVP_PKEY_RSA_PSS,
931      ASN1_PKEY_SIGPARAM_NULL,
932
933      "RSA-PSS",
934      "OpenSSL RSA-PSS method",
935
936      rsa_pub_decode,
937      rsa_pub_encode,
938      rsa_pub_cmp,
939      rsa_pub_print,
940
941      rsa_priv_decode,
942      rsa_priv_encode,
943      rsa_priv_print,
944
945      int_rsa_size,
946      rsa_bits,
947      rsa_security_bits,
948
949      0, 0, 0, 0, 0, 0,
950
951      rsa_sig_print,
952      int_rsa_free,
953      rsa_pkey_ctrl,
954      0, 0,
955      rsa_item_verify,
956      rsa_item_sign,
957 };