6b0f55a19aa02f240263bae147341206276f9c9e
[openssl.git] / providers / implementations / signature / rsa.c
1 /*
2  * Copyright 2019 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 /*
11  * RSA low level APIs are deprecated for public use, but still ok for
12  * internal use.
13  */
14 #include "internal/deprecated.h"
15
16 #include <string.h>
17 #include <openssl/crypto.h>
18 #include <openssl/core_numbers.h>
19 #include <openssl/core_names.h>
20 #include <openssl/err.h>
21 #include <openssl/rsa.h>
22 #include <openssl/params.h>
23 #include <openssl/evp.h>
24 #include "internal/cryptlib.h"
25 #include "internal/nelem.h"
26 #include "internal/sizes.h"
27 #include "crypto/rsa.h"
28 #include "prov/providercommonerr.h"
29 #include "prov/implementations.h"
30 #include "prov/provider_ctx.h"
31
32 static OSSL_OP_signature_newctx_fn rsa_newctx;
33 static OSSL_OP_signature_sign_init_fn rsa_signature_init;
34 static OSSL_OP_signature_verify_init_fn rsa_signature_init;
35 static OSSL_OP_signature_verify_recover_init_fn rsa_signature_init;
36 static OSSL_OP_signature_sign_fn rsa_sign;
37 static OSSL_OP_signature_verify_fn rsa_verify;
38 static OSSL_OP_signature_verify_recover_fn rsa_verify_recover;
39 static OSSL_OP_signature_digest_sign_init_fn rsa_digest_signverify_init;
40 static OSSL_OP_signature_digest_sign_update_fn rsa_digest_signverify_update;
41 static OSSL_OP_signature_digest_sign_final_fn rsa_digest_sign_final;
42 static OSSL_OP_signature_digest_verify_init_fn rsa_digest_signverify_init;
43 static OSSL_OP_signature_digest_verify_update_fn rsa_digest_signverify_update;
44 static OSSL_OP_signature_digest_verify_final_fn rsa_digest_verify_final;
45 static OSSL_OP_signature_freectx_fn rsa_freectx;
46 static OSSL_OP_signature_dupctx_fn rsa_dupctx;
47 static OSSL_OP_signature_get_ctx_params_fn rsa_get_ctx_params;
48 static OSSL_OP_signature_gettable_ctx_params_fn rsa_gettable_ctx_params;
49 static OSSL_OP_signature_set_ctx_params_fn rsa_set_ctx_params;
50 static OSSL_OP_signature_settable_ctx_params_fn rsa_settable_ctx_params;
51 static OSSL_OP_signature_get_ctx_md_params_fn rsa_get_ctx_md_params;
52 static OSSL_OP_signature_gettable_ctx_md_params_fn rsa_gettable_ctx_md_params;
53 static OSSL_OP_signature_set_ctx_md_params_fn rsa_set_ctx_md_params;
54 static OSSL_OP_signature_settable_ctx_md_params_fn rsa_settable_ctx_md_params;
55
56 static OSSL_ITEM padding_item[] = {
57     { RSA_PKCS1_PADDING,        "pkcs1"  },
58     { RSA_SSLV23_PADDING,       "sslv23" },
59     { RSA_NO_PADDING,           "none"   },
60     { RSA_PKCS1_OAEP_PADDING,   "oaep"   }, /* Correct spelling first */
61     { RSA_PKCS1_OAEP_PADDING,   "oeap"   },
62     { RSA_X931_PADDING,         "x931"   },
63     { RSA_PKCS1_PSS_PADDING,    "pss"    },
64     { 0,                        NULL     }
65 };
66
67 /*
68  * What's passed as an actual key is defined by the KEYMGMT interface.
69  * We happen to know that our KEYMGMT simply passes RSA structures, so
70  * we use that here too.
71  */
72
73 typedef struct {
74     OPENSSL_CTX *libctx;
75     RSA *rsa;
76
77     /*
78      * Flag to determine if the hash function can be changed (1) or not (0)
79      * Because it's dangerous to change during a DigestSign or DigestVerify
80      * operation, this flag is cleared by their Init function, and set again
81      * by their Final function.
82      */
83     unsigned int flag_allow_md : 1;
84
85     /* The Algorithm Identifier of the combined signature agorithm */
86     unsigned char aid[128];
87     size_t  aid_len;
88
89     /* main digest */
90     EVP_MD *md;
91     EVP_MD_CTX *mdctx;
92     int mdnid;
93     char mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
94
95     /* RSA padding mode */
96     int pad_mode;
97     /* message digest for MGF1 */
98     EVP_MD *mgf1_md;
99     char mgf1_mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
100     /* PSS salt length */
101     int saltlen;
102     /* Minimum salt length or -1 if no PSS parameter restriction */
103     int min_saltlen;
104
105     /* Temp buffer */
106     unsigned char *tbuf;
107
108 } PROV_RSA_CTX;
109
110 static size_t rsa_get_md_size(const PROV_RSA_CTX *prsactx)
111 {
112     if (prsactx->md != NULL)
113         return EVP_MD_size(prsactx->md);
114     return 0;
115 }
116
117 static int rsa_get_md_nid(const EVP_MD *md)
118 {
119     /*
120      * Because the RSA library deals with NIDs, we need to translate.
121      * We do so using EVP_MD_is_a(), and therefore need a name to NID
122      * map.
123      */
124     static const OSSL_ITEM name_to_nid[] = {
125         { NID_sha1,      OSSL_DIGEST_NAME_SHA1      },
126         { NID_sha224,    OSSL_DIGEST_NAME_SHA2_224  },
127         { NID_sha256,    OSSL_DIGEST_NAME_SHA2_256  },
128         { NID_sha384,    OSSL_DIGEST_NAME_SHA2_384  },
129         { NID_sha512,    OSSL_DIGEST_NAME_SHA2_512  },
130         { NID_md5,       OSSL_DIGEST_NAME_MD5       },
131         { NID_md5_sha1,  OSSL_DIGEST_NAME_MD5_SHA1  },
132         { NID_md2,       OSSL_DIGEST_NAME_MD2       },
133         { NID_md4,       OSSL_DIGEST_NAME_MD4       },
134         { NID_mdc2,      OSSL_DIGEST_NAME_MDC2      },
135         { NID_ripemd160, OSSL_DIGEST_NAME_RIPEMD160 },
136         { NID_sha3_224,  OSSL_DIGEST_NAME_SHA3_224  },
137         { NID_sha3_256,  OSSL_DIGEST_NAME_SHA3_256  },
138         { NID_sha3_384,  OSSL_DIGEST_NAME_SHA3_384  },
139         { NID_sha3_512,  OSSL_DIGEST_NAME_SHA3_512  },
140     };
141     size_t i;
142     int mdnid = NID_undef;
143
144     if (md == NULL)
145         goto end;
146
147     for (i = 0; i < OSSL_NELEM(name_to_nid); i++) {
148         if (EVP_MD_is_a(md, name_to_nid[i].ptr)) {
149             mdnid = (int)name_to_nid[i].id;
150             break;
151         }
152     }
153
154     if (mdnid == NID_undef)
155         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST);
156
157  end:
158     return mdnid;
159 }
160
161 static int rsa_check_padding(int mdnid, int padding)
162 {
163     if (padding == RSA_NO_PADDING) {
164         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE);
165         return 0;
166     }
167
168     if (padding == RSA_X931_PADDING) {
169         if (RSA_X931_hash_id(mdnid) == -1) {
170             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_X931_DIGEST);
171             return 0;
172         }
173     }
174
175     return 1;
176 }
177
178 static void *rsa_newctx(void *provctx)
179 {
180     PROV_RSA_CTX *prsactx = OPENSSL_zalloc(sizeof(PROV_RSA_CTX));
181
182     if (prsactx == NULL)
183         return NULL;
184
185     prsactx->libctx = PROV_LIBRARY_CONTEXT_OF(provctx);
186     prsactx->flag_allow_md = 1;
187     return prsactx;
188 }
189
190 /* True if PSS parameters are restricted */
191 #define rsa_pss_restricted(prsactx) (prsactx->min_saltlen != -1)
192
193 static int rsa_signature_init(void *vprsactx, void *vrsa)
194 {
195     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
196
197     if (prsactx == NULL || vrsa == NULL || !RSA_up_ref(vrsa))
198         return 0;
199
200     RSA_free(prsactx->rsa);
201     prsactx->rsa = vrsa;
202     if (RSA_get0_pss_params(prsactx->rsa) != NULL)
203         prsactx->pad_mode = RSA_PKCS1_PSS_PADDING;
204     else
205         prsactx->pad_mode = RSA_PKCS1_PADDING;
206     /* Maximum for sign, auto for verify */
207     prsactx->saltlen = RSA_PSS_SALTLEN_AUTO;
208     prsactx->min_saltlen = -1;
209
210     return 1;
211 }
212
213 static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
214                         const char *mdprops)
215 {
216     if (mdname != NULL) {
217         EVP_MD *md = EVP_MD_fetch(ctx->libctx, mdname, mdprops);
218         int md_nid = rsa_get_md_nid(md);
219         size_t algorithmidentifier_len = 0;
220         const unsigned char *algorithmidentifier = NULL;
221
222         if (md == NULL)
223             return 0;
224
225         if (!rsa_check_padding(md_nid, ctx->pad_mode)) {
226             EVP_MD_free(md);
227             return 0;
228         }
229
230         EVP_MD_CTX_free(ctx->mdctx);
231         EVP_MD_free(ctx->md);
232         ctx->md = NULL;
233         ctx->mdctx = NULL;
234         ctx->mdname[0] = '\0';
235         ctx->aid[0] = '\0';
236         ctx->aid_len = 0;
237
238         algorithmidentifier =
239             rsa_algorithmidentifier_encoding(md_nid, &algorithmidentifier_len);
240
241         ctx->md = md;
242         ctx->mdnid = md_nid;
243         OPENSSL_strlcpy(ctx->mdname, mdname, sizeof(ctx->mdname));
244         if (algorithmidentifier != NULL) {
245             memcpy(ctx->aid, algorithmidentifier, algorithmidentifier_len);
246             ctx->aid_len = algorithmidentifier_len;
247         }
248     }
249
250     return 1;
251 }
252
253 static int rsa_setup_mgf1_md(PROV_RSA_CTX *ctx, const char *mdname,
254                              const char *props)
255 {
256     if (ctx->mgf1_mdname[0] != '\0')
257         EVP_MD_free(ctx->mgf1_md);
258
259     if ((ctx->mgf1_md = EVP_MD_fetch(ctx->libctx, mdname, props)) == NULL)
260         return 0;
261     OPENSSL_strlcpy(ctx->mgf1_mdname, mdname, sizeof(ctx->mgf1_mdname));
262
263     return 1;
264 }
265
266 static int setup_tbuf(PROV_RSA_CTX *ctx)
267 {
268     if (ctx->tbuf != NULL)
269         return 1;
270     if ((ctx->tbuf = OPENSSL_malloc(RSA_size(ctx->rsa))) == NULL) {
271         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
272         return 0;
273     }
274     return 1;
275 }
276
277 static void clean_tbuf(PROV_RSA_CTX *ctx)
278 {
279     if (ctx->tbuf != NULL)
280         OPENSSL_cleanse(ctx->tbuf, RSA_size(ctx->rsa));
281 }
282
283 static void free_tbuf(PROV_RSA_CTX *ctx)
284 {
285     OPENSSL_clear_free(ctx->tbuf, RSA_size(ctx->rsa));
286     ctx->tbuf = NULL;
287 }
288
289 static int rsa_sign(void *vprsactx, unsigned char *sig, size_t *siglen,
290                     size_t sigsize, const unsigned char *tbs, size_t tbslen)
291 {
292     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
293     int ret;
294     size_t rsasize = RSA_size(prsactx->rsa);
295     size_t mdsize = rsa_get_md_size(prsactx);
296
297     if (sig == NULL) {
298         *siglen = rsasize;
299         return 1;
300     }
301
302     if (sigsize < (size_t)rsasize)
303         return 0;
304
305     if (mdsize != 0) {
306         if (tbslen != mdsize) {
307             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH);
308             return 0;
309         }
310
311 #ifndef FIPS_MODE
312         if (EVP_MD_is_a(prsactx->md, OSSL_DIGEST_NAME_MDC2)) {
313             unsigned int sltmp;
314
315             if (prsactx->pad_mode != RSA_PKCS1_PADDING) {
316                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
317                                "only PKCS#1 padding supported with MDC2");
318                 return 0;
319             }
320             ret = RSA_sign_ASN1_OCTET_STRING(0, tbs, tbslen, sig, &sltmp,
321                                              prsactx->rsa);
322
323             if (ret <= 0) {
324                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
325                 return 0;
326             }
327             ret = sltmp;
328             goto end;
329         }
330 #endif
331
332         switch (prsactx->pad_mode) {
333         case RSA_X931_PADDING:
334             if ((size_t)RSA_size(prsactx->rsa) < tbslen + 1) {
335                 ERR_raise(ERR_LIB_PROV, PROV_R_KEY_SIZE_TOO_SMALL);
336                 return 0;
337             }
338             if (!setup_tbuf(prsactx)) {
339                 ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
340                 return 0;
341             }
342             memcpy(prsactx->tbuf, tbs, tbslen);
343             prsactx->tbuf[tbslen] = RSA_X931_hash_id(prsactx->mdnid);
344             ret = RSA_private_encrypt(tbslen + 1, prsactx->tbuf,
345                                       sig, prsactx->rsa, RSA_X931_PADDING);
346             clean_tbuf(prsactx);
347             break;
348
349         case RSA_PKCS1_PADDING:
350             {
351                 unsigned int sltmp;
352
353                 ret = RSA_sign(prsactx->mdnid, tbs, tbslen, sig, &sltmp,
354                                prsactx->rsa);
355                 if (ret <= 0) {
356                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
357                     return 0;
358                 }
359                 ret = sltmp;
360             }
361             break;
362
363         case RSA_PKCS1_PSS_PADDING:
364             /* Check PSS restrictions */
365             if (rsa_pss_restricted(prsactx)) {
366                 switch (prsactx->saltlen) {
367                 case RSA_PSS_SALTLEN_DIGEST:
368                     if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
369                         ERR_raise(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL);
370                         return 0;
371                     }
372                     /* FALLTHRU */
373                 default:
374                     if (prsactx->saltlen >= 0
375                         && prsactx->saltlen < prsactx->min_saltlen) {
376                         ERR_raise(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL);
377                         return 0;
378                     }
379                     break;
380                 }
381             }
382             if (!setup_tbuf(prsactx))
383                 return 0;
384             if (!RSA_padding_add_PKCS1_PSS_mgf1(prsactx->rsa,
385                                                 prsactx->tbuf, tbs,
386                                                 prsactx->md, prsactx->mgf1_md,
387                                                 prsactx->saltlen)) {
388                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
389                 return 0;
390             }
391             ret = RSA_private_encrypt(RSA_size(prsactx->rsa), prsactx->tbuf,
392                                       sig, prsactx->rsa, RSA_NO_PADDING);
393             clean_tbuf(prsactx);
394             break;
395
396         default:
397             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
398                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
399             return 0;
400         }
401     } else {
402         ret = RSA_private_encrypt(tbslen, tbs, sig, prsactx->rsa,
403                                   prsactx->pad_mode);
404     }
405
406 #ifndef FIPS_MODE
407  end:
408 #endif
409     if (ret <= 0) {
410         ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
411         return 0;
412     }
413
414     *siglen = ret;
415     return 1;
416 }
417
418 static int rsa_verify_recover(void *vprsactx,
419                               unsigned char *rout,
420                               size_t *routlen,
421                               size_t routsize,
422                               const unsigned char *sig,
423                               size_t siglen)
424 {
425     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
426     int ret;
427
428     if (rout == NULL) {
429         *routlen = RSA_size(prsactx->rsa);
430         return 1;
431     }
432
433     if (prsactx->md != NULL) {
434         switch (prsactx->pad_mode) {
435         case RSA_X931_PADDING:
436             if (!setup_tbuf(prsactx))
437                 return 0;
438             ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
439                                      RSA_X931_PADDING);
440             if (ret < 1) {
441                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
442                 return 0;
443             }
444             ret--;
445             if (prsactx->tbuf[ret] != RSA_X931_hash_id(prsactx->mdnid)) {
446                 ERR_raise(ERR_LIB_PROV, PROV_R_ALGORITHM_MISMATCH);
447                 return 0;
448             }
449             if (ret != EVP_MD_size(prsactx->md)) {
450                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
451                                "Should be %d, but got %d",
452                                EVP_MD_size(prsactx->md), ret);
453                 return 0;
454             }
455
456             *routlen = ret;
457             if (routsize < (size_t)ret) {
458                 ERR_raise(ERR_LIB_PROV, PROV_R_OUTPUT_BUFFER_TOO_SMALL);
459                 return 0;
460             }
461             memcpy(rout, prsactx->tbuf, ret);
462             break;
463
464         case RSA_PKCS1_PADDING:
465             {
466                 size_t sltmp;
467
468                 ret = int_rsa_verify(prsactx->mdnid, NULL, 0, rout, &sltmp,
469                                      sig, siglen, prsactx->rsa);
470                 if (ret <= 0) {
471                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
472                     return 0;
473                 }
474                 ret = sltmp;
475             }
476             break;
477
478         default:
479             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
480                            "Only X.931 or PKCS#1 v1.5 padding allowed");
481             return 0;
482         }
483     } else {
484         ret = RSA_public_decrypt(siglen, sig, rout, prsactx->rsa,
485                                  prsactx->pad_mode);
486         if (ret < 0) {
487             ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
488             return 0;
489         }
490     }
491     *routlen = ret;
492     return 1;
493 }
494
495 static int rsa_verify(void *vprsactx, const unsigned char *sig, size_t siglen,
496                       const unsigned char *tbs, size_t tbslen)
497 {
498     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
499     size_t rslen;
500
501     if (prsactx->md != NULL) {
502         switch (prsactx->pad_mode) {
503         case RSA_PKCS1_PADDING:
504             if (!RSA_verify(prsactx->mdnid, tbs, tbslen, sig, siglen,
505                             prsactx->rsa)) {
506                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
507                 return 0;
508             }
509             return 1;
510         case RSA_X931_PADDING:
511             if (rsa_verify_recover(prsactx, NULL, &rslen, 0, sig, siglen) <= 0)
512                 return 0;
513             break;
514         case RSA_PKCS1_PSS_PADDING:
515             {
516                 int ret;
517                 size_t mdsize;
518
519                 /* Check PSS restrictions */
520                 if (rsa_pss_restricted(prsactx)) {
521                     switch (prsactx->saltlen) {
522                     case RSA_PSS_SALTLEN_AUTO:
523                         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PSS_SALTLEN);
524                         return 0;
525                     case RSA_PSS_SALTLEN_DIGEST:
526                         if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
527                             ERR_raise(ERR_LIB_PROV,
528                                       PROV_R_PSS_SALTLEN_TOO_SMALL);
529                             return 0;
530                         }
531                         /* FALLTHRU */
532                     default:
533                         if (prsactx->saltlen >= 0
534                             && prsactx->saltlen < prsactx->min_saltlen) {
535                             ERR_raise(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL);
536                             return 0;
537                         }
538                         break;
539                     }
540                 }
541
542                 /*
543                  * We need to check this for the RSA_verify_PKCS1_PSS_mgf1()
544                  * call
545                  */
546                 mdsize = rsa_get_md_size(prsactx);
547                 if (tbslen != mdsize) {
548                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
549                                    "Should be %d, but got %d",
550                                    mdsize, tbslen);
551                     return 0;
552                 }
553
554                 if (!setup_tbuf(prsactx))
555                     return 0;
556                 ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf,
557                                          prsactx->rsa, RSA_NO_PADDING);
558                 if (ret <= 0) {
559                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
560                     return 0;
561                 }
562                 ret = RSA_verify_PKCS1_PSS_mgf1(prsactx->rsa, tbs,
563                                                 prsactx->md, prsactx->mgf1_md,
564                                                 prsactx->tbuf,
565                                                 prsactx->saltlen);
566                 if (ret <= 0) {
567                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
568                     return 0;
569                 }
570                 return 1;
571             }
572         default:
573             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
574                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
575             return 0;
576         }
577     } else {
578         if (!setup_tbuf(prsactx))
579             return 0;
580         rslen = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
581                                    prsactx->pad_mode);
582         if (rslen == 0) {
583             ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
584             return 0;
585         }
586     }
587
588     if ((rslen != tbslen) || memcmp(tbs, prsactx->tbuf, rslen))
589         return 0;
590
591     return 1;
592 }
593
594 static int rsa_digest_signverify_init(void *vprsactx, const char *mdname,
595                                       const char *props, void *vrsa)
596 {
597     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
598
599     prsactx->flag_allow_md = 0;
600     if (!rsa_signature_init(vprsactx, vrsa)
601         || !rsa_setup_md(prsactx, mdname, props))
602         return 0;
603
604     prsactx->mdctx = EVP_MD_CTX_new();
605     if (prsactx->mdctx == NULL)
606         goto error;
607
608     if (!EVP_DigestInit_ex(prsactx->mdctx, prsactx->md, NULL))
609         goto error;
610
611     return 1;
612
613  error:
614     EVP_MD_CTX_free(prsactx->mdctx);
615     EVP_MD_free(prsactx->md);
616     prsactx->mdctx = NULL;
617     prsactx->md = NULL;
618     return 0;
619 }
620
621 int rsa_digest_signverify_update(void *vprsactx, const unsigned char *data,
622                                  size_t datalen)
623 {
624     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
625
626     if (prsactx == NULL || prsactx->mdctx == NULL)
627         return 0;
628
629     return EVP_DigestUpdate(prsactx->mdctx, data, datalen);
630 }
631
632 int rsa_digest_sign_final(void *vprsactx, unsigned char *sig, size_t *siglen,
633                           size_t sigsize)
634 {
635     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
636     unsigned char digest[EVP_MAX_MD_SIZE];
637     unsigned int dlen = 0;
638
639     prsactx->flag_allow_md = 1;
640     if (prsactx == NULL || prsactx->mdctx == NULL)
641         return 0;
642
643     /*
644      * If sig is NULL then we're just finding out the sig size. Other fields
645      * are ignored. Defer to rsa_sign.
646      */
647     if (sig != NULL) {
648         /*
649          * TODO(3.0): There is the possibility that some externally provided
650          * digests exceed EVP_MAX_MD_SIZE. We should probably handle that somehow -
651          * but that problem is much larger than just in RSA.
652          */
653         if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
654             return 0;
655     }
656
657     return rsa_sign(vprsactx, sig, siglen, sigsize, digest, (size_t)dlen);
658 }
659
660
661 int rsa_digest_verify_final(void *vprsactx, const unsigned char *sig,
662                             size_t siglen)
663 {
664     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
665     unsigned char digest[EVP_MAX_MD_SIZE];
666     unsigned int dlen = 0;
667
668     prsactx->flag_allow_md = 1;
669     if (prsactx == NULL || prsactx->mdctx == NULL)
670         return 0;
671
672     /*
673      * TODO(3.0): There is the possibility that some externally provided
674      * digests exceed EVP_MAX_MD_SIZE. We should probably handle that somehow -
675      * but that problem is much larger than just in RSA.
676      */
677     if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
678         return 0;
679
680     return rsa_verify(vprsactx, sig, siglen, digest, (size_t)dlen);
681 }
682
683 static void rsa_freectx(void *vprsactx)
684 {
685     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
686
687     if (prsactx == NULL)
688         return;
689
690     RSA_free(prsactx->rsa);
691     EVP_MD_CTX_free(prsactx->mdctx);
692     EVP_MD_free(prsactx->md);
693     EVP_MD_free(prsactx->mgf1_md);
694     free_tbuf(prsactx);
695
696     OPENSSL_clear_free(prsactx, sizeof(prsactx));
697 }
698
699 static void *rsa_dupctx(void *vprsactx)
700 {
701     PROV_RSA_CTX *srcctx = (PROV_RSA_CTX *)vprsactx;
702     PROV_RSA_CTX *dstctx;
703
704     dstctx = OPENSSL_zalloc(sizeof(*srcctx));
705     if (dstctx == NULL)
706         return NULL;
707
708     *dstctx = *srcctx;
709     dstctx->rsa = NULL;
710     dstctx->md = NULL;
711     dstctx->mdctx = NULL;
712     dstctx->tbuf = NULL;
713
714     if (srcctx->rsa != NULL && !RSA_up_ref(srcctx->rsa))
715         goto err;
716     dstctx->rsa = srcctx->rsa;
717
718     if (srcctx->md != NULL && !EVP_MD_up_ref(srcctx->md))
719         goto err;
720     dstctx->md = srcctx->md;
721
722     if (srcctx->mgf1_md != NULL && !EVP_MD_up_ref(srcctx->mgf1_md))
723         goto err;
724     dstctx->mgf1_md = srcctx->mgf1_md;
725
726     if (srcctx->mdctx != NULL) {
727         dstctx->mdctx = EVP_MD_CTX_new();
728         if (dstctx->mdctx == NULL
729                 || !EVP_MD_CTX_copy_ex(dstctx->mdctx, srcctx->mdctx))
730             goto err;
731     }
732
733     return dstctx;
734  err:
735     rsa_freectx(dstctx);
736     return NULL;
737 }
738
739 static int rsa_get_ctx_params(void *vprsactx, OSSL_PARAM *params)
740 {
741     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
742     OSSL_PARAM *p;
743
744     if (prsactx == NULL || params == NULL)
745         return 0;
746
747     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_ALGORITHM_ID);
748     if (p != NULL
749         && !OSSL_PARAM_set_octet_string(p, prsactx->aid, prsactx->aid_len))
750         return 0;
751
752     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
753     if (p != NULL)
754         switch (p->data_type) {
755         case OSSL_PARAM_INTEGER:
756             if (!OSSL_PARAM_set_int(p, prsactx->pad_mode))
757                 return 0;
758             break;
759         case OSSL_PARAM_UTF8_STRING:
760             {
761                 int i;
762                 const char *word = NULL;
763
764                 for (i = 0; padding_item[i].id != 0; i++) {
765                     if (prsactx->pad_mode == (int)padding_item[i].id) {
766                         word = padding_item[i].ptr;
767                         break;
768                     }
769                 }
770
771                 if (word != NULL) {
772                     if (!OSSL_PARAM_set_utf8_string(p, word))
773                         return 0;
774                 } else {
775                     ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
776                 }
777             }
778             break;
779         default:
780             return 0;
781         }
782
783     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_DIGEST);
784     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mdname))
785         return 0;
786
787     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
788     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mgf1_mdname))
789         return 0;
790
791     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
792     if (p != NULL) {
793         if (p->data_type == OSSL_PARAM_INTEGER) {
794             if (!OSSL_PARAM_set_int(p, prsactx->saltlen))
795                 return 0;
796         } else if (p->data_type == OSSL_PARAM_UTF8_STRING) {
797             switch (prsactx->saltlen) {
798             case RSA_PSS_SALTLEN_DIGEST:
799                 if (!OSSL_PARAM_set_utf8_string(p, "digest"))
800                     return 0;
801                 break;
802             case RSA_PSS_SALTLEN_MAX:
803                 if (!OSSL_PARAM_set_utf8_string(p, "max"))
804                     return 0;
805                 break;
806             case RSA_PSS_SALTLEN_AUTO:
807                 if (!OSSL_PARAM_set_utf8_string(p, "auto"))
808                     return 0;
809                 break;
810             default:
811                 if (BIO_snprintf(p->data, p->data_size, "%d", prsactx->saltlen)
812                     <= 0)
813                     return 0;
814                 break;
815             }
816         }
817     }
818
819     return 1;
820 }
821
822 static const OSSL_PARAM known_gettable_ctx_params[] = {
823     OSSL_PARAM_octet_string(OSSL_SIGNATURE_PARAM_ALGORITHM_ID, NULL, 0),
824     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
825     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
826     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
827     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
828     OSSL_PARAM_END
829 };
830
831 static const OSSL_PARAM *rsa_gettable_ctx_params(void)
832 {
833     return known_gettable_ctx_params;
834 }
835
836 static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
837 {
838     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
839     const OSSL_PARAM *p;
840
841     if (prsactx == NULL || params == NULL)
842         return 0;
843
844     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_DIGEST);
845     /* Not allowed during certain operations */
846     if (p != NULL && !prsactx->flag_allow_md)
847         return 0;
848     if (p != NULL) {
849         char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
850         char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
851         const OSSL_PARAM *propsp =
852             OSSL_PARAM_locate_const(params,
853                                     OSSL_SIGNATURE_PARAM_PROPERTIES);
854
855         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
856             return 0;
857         if (propsp != NULL
858             && !OSSL_PARAM_get_utf8_string(propsp, &pmdprops, sizeof(mdprops)))
859             return 0;
860
861         /* TODO(3.0) PSS check needs more work */
862         if (rsa_pss_restricted(prsactx)) {
863             /* TODO(3.0) figure out what to do for prsactx->md == NULL */
864             if (prsactx->md == NULL || EVP_MD_is_a(prsactx->md, mdname))
865                 return 1;
866             ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
867             return 0;
868         }
869
870         /* non-PSS code follows */
871         if (!rsa_setup_md(prsactx, mdname, mdprops))
872             return 0;
873     }
874
875     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
876     if (p != NULL) {
877         int pad_mode = 0;
878
879         switch (p->data_type) {
880         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
881             if (!OSSL_PARAM_get_int(p, &pad_mode))
882                 return 0;
883             break;
884         case OSSL_PARAM_UTF8_STRING:
885             {
886                 int i;
887
888                 if (p->data == NULL)
889                     return 0;
890
891                 for (i = 0; padding_item[i].id != 0; i++) {
892                     if (strcmp(p->data, padding_item[i].ptr) == 0) {
893                         pad_mode = padding_item[i].id;
894                         break;
895                     }
896                 }
897             }
898             break;
899         default:
900             return 0;
901         }
902
903         switch (pad_mode) {
904         case RSA_PKCS1_OAEP_PADDING:
905             /*
906              * OAEP padding is for asymmetric cipher only so is not compatible
907              * with signature use.
908              */
909             ERR_raise_data(ERR_LIB_PROV,
910                            PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE,
911                            "OAEP padding not allowed for signing / verifying");
912             return 0;
913         case RSA_PKCS1_PSS_PADDING:
914             if (prsactx->mdname[0] == '\0')
915                 rsa_setup_md(prsactx, "SHA1", "");
916             goto cont;
917         case RSA_PKCS1_PADDING:
918         case RSA_SSLV23_PADDING:
919         case RSA_NO_PADDING:
920         case RSA_X931_PADDING:
921             if (RSA_get0_pss_params(prsactx->rsa) != NULL) {
922                 ERR_raise_data(ERR_LIB_PROV,
923                                PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE,
924                                "X.931 padding not allowed with RSA-PSS");
925                 return 0;
926             }
927         cont:
928             if (!rsa_check_padding(prsactx->mdnid, pad_mode))
929                 return 0;
930             break;
931         default:
932             return 0;
933         }
934         prsactx->pad_mode = pad_mode;
935     }
936
937     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
938     if (p != NULL) {
939         int saltlen;
940
941         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
942             ERR_raise_data(ERR_LIB_PROV, PROV_R_NOT_SUPPORTED,
943                            "PSS saltlen can only be specified if "
944                            "PSS padding has been specified first");
945             return 0;
946         }
947
948         switch (p->data_type) {
949         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
950             if (!OSSL_PARAM_get_int(p, &saltlen))
951                 return 0;
952             break;
953         case OSSL_PARAM_UTF8_STRING:
954             if (strcmp(p->data, "digest") == 0)
955                 saltlen = RSA_PSS_SALTLEN_DIGEST;
956             else if (strcmp(p->data, "max") == 0)
957                 saltlen = RSA_PSS_SALTLEN_MAX;
958             else if (strcmp(p->data, "auto") == 0)
959                 saltlen = RSA_PSS_SALTLEN_AUTO;
960             else
961                 saltlen = atoi(p->data);
962             break;
963         default:
964             return 0;
965         }
966
967         /*
968          * RSA_PSS_SALTLEN_MAX seems curiously named in this check.
969          * Contrary to what it's name suggests, it's the currently
970          * lowest saltlen number possible.
971          */
972         if (saltlen < RSA_PSS_SALTLEN_MAX) {
973             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PSS_SALTLEN);
974             return 0;
975         }
976
977         prsactx->saltlen = saltlen;
978     }
979
980     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
981     if (p != NULL) {
982         char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
983         char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
984         const OSSL_PARAM *propsp =
985             OSSL_PARAM_locate_const(params,
986                                     OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES);
987
988         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
989             return 0;
990         if (propsp != NULL
991             && !OSSL_PARAM_get_utf8_string(propsp, &pmdprops, sizeof(mdprops)))
992             return 0;
993
994         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
995             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_MGF1_MD);
996             return  0;
997         }
998
999         /* TODO(3.0) PSS check needs more work */
1000         if (rsa_pss_restricted(prsactx)) {
1001             /* TODO(3.0) figure out what to do for prsactx->md == NULL */
1002             if (prsactx->mgf1_md == NULL
1003                 || EVP_MD_is_a(prsactx->mgf1_md, mdname))
1004                 return 1;
1005             ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
1006             return 0;
1007         }
1008
1009         /* non-PSS code follows */
1010         if (!rsa_setup_mgf1_md(prsactx, mdname, mdprops))
1011             return 0;
1012     }
1013
1014     return 1;
1015 }
1016
1017 static const OSSL_PARAM known_settable_ctx_params[] = {
1018     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
1019     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
1020     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PROPERTIES, NULL, 0),
1021     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
1022     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES, NULL, 0),
1023     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
1024     OSSL_PARAM_END
1025 };
1026
1027 static const OSSL_PARAM *rsa_settable_ctx_params(void)
1028 {
1029     /*
1030      * TODO(3.0): Should this function return a different set of settable ctx
1031      * params if the ctx is being used for a DigestSign/DigestVerify? In that
1032      * case it is not allowed to set the digest size/digest name because the
1033      * digest is explicitly set as part of the init.
1034      */
1035     return known_settable_ctx_params;
1036 }
1037
1038 static int rsa_get_ctx_md_params(void *vprsactx, OSSL_PARAM *params)
1039 {
1040     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1041
1042     if (prsactx->mdctx == NULL)
1043         return 0;
1044
1045     return EVP_MD_CTX_get_params(prsactx->mdctx, params);
1046 }
1047
1048 static const OSSL_PARAM *rsa_gettable_ctx_md_params(void *vprsactx)
1049 {
1050     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1051
1052     if (prsactx->md == NULL)
1053         return 0;
1054
1055     return EVP_MD_gettable_ctx_params(prsactx->md);
1056 }
1057
1058 static int rsa_set_ctx_md_params(void *vprsactx, const OSSL_PARAM params[])
1059 {
1060     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1061
1062     if (prsactx->mdctx == NULL)
1063         return 0;
1064
1065     return EVP_MD_CTX_set_params(prsactx->mdctx, params);
1066 }
1067
1068 static const OSSL_PARAM *rsa_settable_ctx_md_params(void *vprsactx)
1069 {
1070     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1071
1072     if (prsactx->md == NULL)
1073         return 0;
1074
1075     return EVP_MD_settable_ctx_params(prsactx->md);
1076 }
1077
1078 const OSSL_DISPATCH rsa_signature_functions[] = {
1079     { OSSL_FUNC_SIGNATURE_NEWCTX, (void (*)(void))rsa_newctx },
1080     { OSSL_FUNC_SIGNATURE_SIGN_INIT, (void (*)(void))rsa_signature_init },
1081     { OSSL_FUNC_SIGNATURE_SIGN, (void (*)(void))rsa_sign },
1082     { OSSL_FUNC_SIGNATURE_VERIFY_INIT, (void (*)(void))rsa_signature_init },
1083     { OSSL_FUNC_SIGNATURE_VERIFY, (void (*)(void))rsa_verify },
1084     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER_INIT, (void (*)(void))rsa_signature_init },
1085     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER, (void (*)(void))rsa_verify_recover },
1086     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_INIT,
1087       (void (*)(void))rsa_digest_signverify_init },
1088     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_UPDATE,
1089       (void (*)(void))rsa_digest_signverify_update },
1090     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_FINAL,
1091       (void (*)(void))rsa_digest_sign_final },
1092     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_INIT,
1093       (void (*)(void))rsa_digest_signverify_init },
1094     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_UPDATE,
1095       (void (*)(void))rsa_digest_signverify_update },
1096     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_FINAL,
1097       (void (*)(void))rsa_digest_verify_final },
1098     { OSSL_FUNC_SIGNATURE_FREECTX, (void (*)(void))rsa_freectx },
1099     { OSSL_FUNC_SIGNATURE_DUPCTX, (void (*)(void))rsa_dupctx },
1100     { OSSL_FUNC_SIGNATURE_GET_CTX_PARAMS, (void (*)(void))rsa_get_ctx_params },
1101     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_PARAMS,
1102       (void (*)(void))rsa_gettable_ctx_params },
1103     { OSSL_FUNC_SIGNATURE_SET_CTX_PARAMS, (void (*)(void))rsa_set_ctx_params },
1104     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_PARAMS,
1105       (void (*)(void))rsa_settable_ctx_params },
1106     { OSSL_FUNC_SIGNATURE_GET_CTX_MD_PARAMS,
1107       (void (*)(void))rsa_get_ctx_md_params },
1108     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_MD_PARAMS,
1109       (void (*)(void))rsa_gettable_ctx_md_params },
1110     { OSSL_FUNC_SIGNATURE_SET_CTX_MD_PARAMS,
1111       (void (*)(void))rsa_set_ctx_md_params },
1112     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_MD_PARAMS,
1113       (void (*)(void))rsa_settable_ctx_md_params },
1114     { 0, NULL }
1115 };