848cbd72499fbdc1cba3e23291c3962ba138d023
[openssl.git] / providers / implementations / signature / rsa.c
1 /*
2  * Copyright 2019 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 /*
11  * RSA low level APIs are deprecated for public use, but still ok for
12  * internal use.
13  */
14 #include "internal/deprecated.h"
15
16 #include <string.h>
17 #include <openssl/crypto.h>
18 #include <openssl/core_numbers.h>
19 #include <openssl/core_names.h>
20 #include <openssl/err.h>
21 #include <openssl/rsa.h>
22 #include <openssl/params.h>
23 #include <openssl/evp.h>
24 #include "internal/cryptlib.h"
25 #include "internal/nelem.h"
26 #include "internal/sizes.h"
27 #include "crypto/rsa.h"
28 #include "prov/providercommonerr.h"
29 #include "prov/implementations.h"
30 #include "prov/provider_ctx.h"
31
32 static OSSL_OP_signature_newctx_fn rsa_newctx;
33 static OSSL_OP_signature_sign_init_fn rsa_signature_init;
34 static OSSL_OP_signature_verify_init_fn rsa_signature_init;
35 static OSSL_OP_signature_verify_recover_init_fn rsa_signature_init;
36 static OSSL_OP_signature_sign_fn rsa_sign;
37 static OSSL_OP_signature_verify_fn rsa_verify;
38 static OSSL_OP_signature_verify_recover_fn rsa_verify_recover;
39 static OSSL_OP_signature_digest_sign_init_fn rsa_digest_signverify_init;
40 static OSSL_OP_signature_digest_sign_update_fn rsa_digest_signverify_update;
41 static OSSL_OP_signature_digest_sign_final_fn rsa_digest_sign_final;
42 static OSSL_OP_signature_digest_verify_init_fn rsa_digest_signverify_init;
43 static OSSL_OP_signature_digest_verify_update_fn rsa_digest_signverify_update;
44 static OSSL_OP_signature_digest_verify_final_fn rsa_digest_verify_final;
45 static OSSL_OP_signature_freectx_fn rsa_freectx;
46 static OSSL_OP_signature_dupctx_fn rsa_dupctx;
47 static OSSL_OP_signature_get_ctx_params_fn rsa_get_ctx_params;
48 static OSSL_OP_signature_gettable_ctx_params_fn rsa_gettable_ctx_params;
49 static OSSL_OP_signature_set_ctx_params_fn rsa_set_ctx_params;
50 static OSSL_OP_signature_settable_ctx_params_fn rsa_settable_ctx_params;
51 static OSSL_OP_signature_get_ctx_md_params_fn rsa_get_ctx_md_params;
52 static OSSL_OP_signature_gettable_ctx_md_params_fn rsa_gettable_ctx_md_params;
53 static OSSL_OP_signature_set_ctx_md_params_fn rsa_set_ctx_md_params;
54 static OSSL_OP_signature_settable_ctx_md_params_fn rsa_settable_ctx_md_params;
55
56 static OSSL_ITEM padding_item[] = {
57     { RSA_PKCS1_PADDING,        "pkcs1"  },
58     { RSA_SSLV23_PADDING,       "sslv23" },
59     { RSA_NO_PADDING,           "none"   },
60     { RSA_PKCS1_OAEP_PADDING,   "oaep"   }, /* Correct spelling first */
61     { RSA_PKCS1_OAEP_PADDING,   "oeap"   },
62     { RSA_X931_PADDING,         "x931"   },
63     { RSA_PKCS1_PSS_PADDING,    "pss"    },
64     { 0,                        NULL     }
65 };
66
67 /*
68  * What's passed as an actual key is defined by the KEYMGMT interface.
69  * We happen to know that our KEYMGMT simply passes RSA structures, so
70  * we use that here too.
71  */
72
73 typedef struct {
74     OPENSSL_CTX *libctx;
75     RSA *rsa;
76
77     /*
78      * Flag to determine if the hash function can be changed (1) or not (0)
79      * Because it's dangerous to change during a DigestSign or DigestVerify
80      * operation, this flag is cleared by their Init function, and set again
81      * by their Final function.
82      */
83     unsigned int flag_allow_md : 1;
84
85     /* The Algorithm Identifier of the combined signature agorithm */
86     unsigned char aid[128];
87     size_t  aid_len;
88
89     /* main digest */
90     EVP_MD *md;
91     EVP_MD_CTX *mdctx;
92     int mdnid;
93     char mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
94
95     /* RSA padding mode */
96     int pad_mode;
97     /* message digest for MGF1 */
98     EVP_MD *mgf1_md;
99     char mgf1_mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
100     /* PSS salt length */
101     int saltlen;
102     /* Minimum salt length or -1 if no PSS parameter restriction */
103     int min_saltlen;
104
105     /* Temp buffer */
106     unsigned char *tbuf;
107
108 } PROV_RSA_CTX;
109
110 static size_t rsa_get_md_size(const PROV_RSA_CTX *prsactx)
111 {
112     if (prsactx->md != NULL)
113         return EVP_MD_size(prsactx->md);
114     return 0;
115 }
116
117 static int rsa_get_md_nid(const EVP_MD *md)
118 {
119     /*
120      * Because the RSA library deals with NIDs, we need to translate.
121      * We do so using EVP_MD_is_a(), and therefore need a name to NID
122      * map.
123      */
124     static const OSSL_ITEM name_to_nid[] = {
125         { NID_sha1,      OSSL_DIGEST_NAME_SHA1      },
126         { NID_sha224,    OSSL_DIGEST_NAME_SHA2_224  },
127         { NID_sha256,    OSSL_DIGEST_NAME_SHA2_256  },
128         { NID_sha384,    OSSL_DIGEST_NAME_SHA2_384  },
129         { NID_sha512,    OSSL_DIGEST_NAME_SHA2_512  },
130         { NID_md5,       OSSL_DIGEST_NAME_MD5       },
131         { NID_md5_sha1,  OSSL_DIGEST_NAME_MD5_SHA1  },
132         { NID_md2,       OSSL_DIGEST_NAME_MD2       },
133         { NID_md4,       OSSL_DIGEST_NAME_MD4       },
134         { NID_mdc2,      OSSL_DIGEST_NAME_MDC2      },
135         { NID_ripemd160, OSSL_DIGEST_NAME_RIPEMD160 },
136         { NID_sha3_224,  OSSL_DIGEST_NAME_SHA3_224  },
137         { NID_sha3_256,  OSSL_DIGEST_NAME_SHA3_256  },
138         { NID_sha3_384,  OSSL_DIGEST_NAME_SHA3_384  },
139         { NID_sha3_512,  OSSL_DIGEST_NAME_SHA3_512  },
140     };
141     size_t i;
142     int mdnid = NID_undef;
143
144     if (md == NULL)
145         goto end;
146
147     for (i = 0; i < OSSL_NELEM(name_to_nid); i++) {
148         if (EVP_MD_is_a(md, name_to_nid[i].ptr)) {
149             mdnid = (int)name_to_nid[i].id;
150             break;
151         }
152     }
153
154     if (mdnid == NID_undef)
155         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST);
156
157  end:
158     return mdnid;
159 }
160
161 static int rsa_check_padding(int mdnid, int padding)
162 {
163     if (padding == RSA_NO_PADDING) {
164         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE);
165         return 0;
166     }
167
168     if (padding == RSA_X931_PADDING) {
169         if (RSA_X931_hash_id(mdnid) == -1) {
170             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_X931_DIGEST);
171             return 0;
172         }
173     }
174
175     return 1;
176 }
177
178 static void *rsa_newctx(void *provctx)
179 {
180     PROV_RSA_CTX *prsactx = OPENSSL_zalloc(sizeof(PROV_RSA_CTX));
181
182     if (prsactx == NULL)
183         return NULL;
184
185     prsactx->libctx = PROV_LIBRARY_CONTEXT_OF(provctx);
186     prsactx->flag_allow_md = 1;
187     return prsactx;
188 }
189
190 /* True if PSS parameters are restricted */
191 #define rsa_pss_restricted(prsactx) (prsactx->min_saltlen != -1)
192
193 static int rsa_signature_init(void *vprsactx, void *vrsa)
194 {
195     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
196
197     if (prsactx == NULL || vrsa == NULL || !RSA_up_ref(vrsa))
198         return 0;
199
200     RSA_free(prsactx->rsa);
201     prsactx->rsa = vrsa;
202     if (RSA_get0_pss_params(prsactx->rsa) != NULL)
203         prsactx->pad_mode = RSA_PKCS1_PSS_PADDING;
204     else
205         prsactx->pad_mode = RSA_PKCS1_PADDING;
206     /* Maximum for sign, auto for verify */
207     prsactx->saltlen = RSA_PSS_SALTLEN_AUTO;
208     prsactx->min_saltlen = -1;
209
210     return 1;
211 }
212
213 static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
214                         const char *mdprops)
215 {
216     if (mdname != NULL) {
217         EVP_MD *md = EVP_MD_fetch(ctx->libctx, mdname, mdprops);
218         int md_nid = rsa_get_md_nid(md);
219         size_t algorithmidentifier_len = 0;
220         const unsigned char *algorithmidentifier = NULL;
221
222         if (md == NULL)
223             return 0;
224
225         if (!rsa_check_padding(md_nid, ctx->pad_mode)) {
226             EVP_MD_free(md);
227             return 0;
228         }
229
230         EVP_MD_CTX_free(ctx->mdctx);
231         EVP_MD_free(ctx->md);
232         ctx->md = NULL;
233         ctx->mdctx = NULL;
234         ctx->mdname[0] = '\0';
235         ctx->aid[0] = '\0';
236         ctx->aid_len = 0;
237
238         algorithmidentifier =
239             rsa_algorithmidentifier_encoding(md_nid, &algorithmidentifier_len);
240
241         ctx->md = md;
242         ctx->mdnid = md_nid;
243         OPENSSL_strlcpy(ctx->mdname, mdname, sizeof(ctx->mdname));
244         if (algorithmidentifier != NULL) {
245             memcpy(ctx->aid, algorithmidentifier, algorithmidentifier_len);
246             ctx->aid_len = algorithmidentifier_len;
247         }
248     }
249
250     return 1;
251 }
252
253 static int rsa_setup_mgf1_md(PROV_RSA_CTX *ctx, const char *mdname,
254                              const char *props)
255 {
256     if (ctx->mgf1_mdname[0] != '\0')
257         EVP_MD_free(ctx->mgf1_md);
258
259     if ((ctx->mgf1_md = EVP_MD_fetch(ctx->libctx, mdname, props)) == NULL)
260         return 0;
261     OPENSSL_strlcpy(ctx->mgf1_mdname, mdname, sizeof(ctx->mgf1_mdname));
262
263     return 1;
264 }
265
266 static int setup_tbuf(PROV_RSA_CTX *ctx)
267 {
268     if (ctx->tbuf != NULL)
269         return 1;
270     if ((ctx->tbuf = OPENSSL_malloc(RSA_size(ctx->rsa))) == NULL) {
271         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
272         return 0;
273     }
274     return 1;
275 }
276
277 static void clean_tbuf(PROV_RSA_CTX *ctx)
278 {
279     if (ctx->tbuf != NULL)
280         OPENSSL_cleanse(ctx->tbuf, RSA_size(ctx->rsa));
281 }
282
283 static void free_tbuf(PROV_RSA_CTX *ctx)
284 {
285     OPENSSL_clear_free(ctx->tbuf, RSA_size(ctx->rsa));
286     ctx->tbuf = NULL;
287 }
288
289 static int rsa_sign(void *vprsactx, unsigned char *sig, size_t *siglen,
290                     size_t sigsize, const unsigned char *tbs, size_t tbslen)
291 {
292     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
293     int ret;
294     size_t rsasize = RSA_size(prsactx->rsa);
295     size_t mdsize = rsa_get_md_size(prsactx);
296
297     if (sig == NULL) {
298         *siglen = rsasize;
299         return 1;
300     }
301
302     if (sigsize < (size_t)rsasize)
303         return 0;
304
305     if (mdsize != 0) {
306         if (tbslen != mdsize) {
307             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH);
308             return 0;
309         }
310
311 #ifndef FIPS_MODE
312         if (EVP_MD_is_a(prsactx->md, OSSL_DIGEST_NAME_MDC2)) {
313             unsigned int sltmp;
314
315             if (prsactx->pad_mode != RSA_PKCS1_PADDING) {
316                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
317                                "only PKCS#1 padding supported with MDC2");
318                 return 0;
319             }
320             ret = RSA_sign_ASN1_OCTET_STRING(0, tbs, tbslen, sig, &sltmp,
321                                              prsactx->rsa);
322
323             if (ret <= 0) {
324                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
325                 return 0;
326             }
327             ret = sltmp;
328             goto end;
329         }
330 #endif
331         switch (prsactx->pad_mode) {
332         case RSA_X931_PADDING:
333             if ((size_t)RSA_size(prsactx->rsa) < tbslen + 1) {
334                 ERR_raise(ERR_LIB_PROV, PROV_R_KEY_SIZE_TOO_SMALL);
335                 return 0;
336             }
337             if (!setup_tbuf(prsactx)) {
338                 ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
339                 return 0;
340             }
341             memcpy(prsactx->tbuf, tbs, tbslen);
342             prsactx->tbuf[tbslen] = RSA_X931_hash_id(prsactx->mdnid);
343             ret = RSA_private_encrypt(tbslen + 1, prsactx->tbuf,
344                                       sig, prsactx->rsa, RSA_X931_PADDING);
345             clean_tbuf(prsactx);
346             break;
347
348         case RSA_PKCS1_PADDING:
349             {
350                 unsigned int sltmp;
351
352                 ret = RSA_sign(prsactx->mdnid, tbs, tbslen, sig, &sltmp,
353                                prsactx->rsa);
354                 if (ret <= 0) {
355                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
356                     return 0;
357                 }
358                 ret = sltmp;
359             }
360             break;
361
362         case RSA_PKCS1_PSS_PADDING:
363             /* Check PSS restrictions */
364             if (rsa_pss_restricted(prsactx)) {
365                 switch (prsactx->saltlen) {
366                 case RSA_PSS_SALTLEN_DIGEST:
367                     if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
368                         ERR_raise(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL);
369                         return 0;
370                     }
371                     /* FALLTHRU */
372                 default:
373                     if (prsactx->saltlen >= 0
374                         && prsactx->saltlen < prsactx->min_saltlen) {
375                         ERR_raise(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL);
376                         return 0;
377                     }
378                     break;
379                 }
380             }
381             if (!setup_tbuf(prsactx))
382                 return 0;
383             if (!RSA_padding_add_PKCS1_PSS_mgf1(prsactx->rsa,
384                                                 prsactx->tbuf, tbs,
385                                                 prsactx->md, prsactx->mgf1_md,
386                                                 prsactx->saltlen)) {
387                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
388                 return 0;
389             }
390             ret = RSA_private_encrypt(RSA_size(prsactx->rsa), prsactx->tbuf,
391                                       sig, prsactx->rsa, RSA_NO_PADDING);
392             clean_tbuf(prsactx);
393             break;
394
395         default:
396             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
397                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
398             return 0;
399         }
400     } else {
401         ret = RSA_private_encrypt(tbslen, tbs, sig, prsactx->rsa,
402                                   prsactx->pad_mode);
403     }
404
405 #ifndef FIPS_MODE
406  end:
407 #endif
408     if (ret <= 0) {
409         ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
410         return 0;
411     }
412
413     *siglen = ret;
414     return 1;
415 }
416
417 static int rsa_verify_recover(void *vprsactx,
418                               unsigned char *rout,
419                               size_t *routlen,
420                               size_t routsize,
421                               const unsigned char *sig,
422                               size_t siglen)
423 {
424     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
425     int ret;
426
427     if (rout == NULL) {
428         *routlen = RSA_size(prsactx->rsa);
429         return 1;
430     }
431
432     if (prsactx->md != NULL) {
433         switch (prsactx->pad_mode) {
434         case RSA_X931_PADDING:
435             if (!setup_tbuf(prsactx))
436                 return 0;
437             ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
438                                      RSA_X931_PADDING);
439             if (ret < 1) {
440                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
441                 return 0;
442             }
443             ret--;
444             if (prsactx->tbuf[ret] != RSA_X931_hash_id(prsactx->mdnid)) {
445                 ERR_raise(ERR_LIB_PROV, PROV_R_ALGORITHM_MISMATCH);
446                 return 0;
447             }
448             if (ret != EVP_MD_size(prsactx->md)) {
449                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
450                                "Should be %d, but got %d",
451                                EVP_MD_size(prsactx->md), ret);
452                 return 0;
453             }
454
455             *routlen = ret;
456             if (routsize < (size_t)ret) {
457                 ERR_raise(ERR_LIB_PROV, PROV_R_OUTPUT_BUFFER_TOO_SMALL);
458                 return 0;
459             }
460             memcpy(rout, prsactx->tbuf, ret);
461             break;
462
463         case RSA_PKCS1_PADDING:
464             {
465                 size_t sltmp;
466
467                 ret = int_rsa_verify(prsactx->mdnid, NULL, 0, rout, &sltmp,
468                                      sig, siglen, prsactx->rsa);
469                 if (ret <= 0) {
470                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
471                     return 0;
472                 }
473                 ret = sltmp;
474             }
475             break;
476
477         default:
478             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
479                            "Only X.931 or PKCS#1 v1.5 padding allowed");
480             return 0;
481         }
482     } else {
483         ret = RSA_public_decrypt(siglen, sig, rout, prsactx->rsa,
484                                  prsactx->pad_mode);
485         if (ret < 0) {
486             ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
487             return 0;
488         }
489     }
490     *routlen = ret;
491     return 1;
492 }
493
494 static int rsa_verify(void *vprsactx, const unsigned char *sig, size_t siglen,
495                       const unsigned char *tbs, size_t tbslen)
496 {
497     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
498     size_t rslen;
499
500     if (prsactx->md != NULL) {
501         switch (prsactx->pad_mode) {
502         case RSA_PKCS1_PADDING:
503             if (!RSA_verify(prsactx->mdnid, tbs, tbslen, sig, siglen,
504                             prsactx->rsa)) {
505                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
506                 return 0;
507             }
508             return 1;
509         case RSA_X931_PADDING:
510             if (rsa_verify_recover(prsactx, NULL, &rslen, 0, sig, siglen) <= 0)
511                 return 0;
512             break;
513         case RSA_PKCS1_PSS_PADDING:
514             {
515                 int ret;
516                 size_t mdsize;
517
518                 /* Check PSS restrictions */
519                 if (rsa_pss_restricted(prsactx)) {
520                     switch (prsactx->saltlen) {
521                     case RSA_PSS_SALTLEN_AUTO:
522                         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PSS_SALTLEN);
523                         return 0;
524                     case RSA_PSS_SALTLEN_DIGEST:
525                         if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
526                             ERR_raise(ERR_LIB_PROV,
527                                       PROV_R_PSS_SALTLEN_TOO_SMALL);
528                             return 0;
529                         }
530                         /* FALLTHRU */
531                     default:
532                         if (prsactx->saltlen >= 0
533                             && prsactx->saltlen < prsactx->min_saltlen) {
534                             ERR_raise(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL);
535                             return 0;
536                         }
537                         break;
538                     }
539                 }
540
541                 /*
542                  * We need to check this for the RSA_verify_PKCS1_PSS_mgf1()
543                  * call
544                  */
545                 mdsize = rsa_get_md_size(prsactx);
546                 if (tbslen != mdsize) {
547                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
548                                    "Should be %d, but got %d",
549                                    mdsize, tbslen);
550                     return 0;
551                 }
552
553                 if (!setup_tbuf(prsactx))
554                     return 0;
555                 ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf,
556                                          prsactx->rsa, RSA_NO_PADDING);
557                 if (ret <= 0) {
558                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
559                     return 0;
560                 }
561                 ret = RSA_verify_PKCS1_PSS_mgf1(prsactx->rsa, tbs,
562                                                 prsactx->md, prsactx->mgf1_md,
563                                                 prsactx->tbuf,
564                                                 prsactx->saltlen);
565                 if (ret <= 0) {
566                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
567                     return 0;
568                 }
569                 return 1;
570             }
571         default:
572             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
573                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
574             return 0;
575         }
576     } else {
577         if (!setup_tbuf(prsactx))
578             return 0;
579         rslen = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
580                                    prsactx->pad_mode);
581         if (rslen == 0) {
582             ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
583             return 0;
584         }
585     }
586
587     if ((rslen != tbslen) || memcmp(tbs, prsactx->tbuf, rslen))
588         return 0;
589
590     return 1;
591 }
592
593 static int rsa_digest_signverify_init(void *vprsactx, const char *mdname,
594                                       const char *props, void *vrsa)
595 {
596     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
597
598     prsactx->flag_allow_md = 0;
599     if (!rsa_signature_init(vprsactx, vrsa)
600         || !rsa_setup_md(prsactx, mdname, props))
601         return 0;
602
603     prsactx->mdctx = EVP_MD_CTX_new();
604     if (prsactx->mdctx == NULL)
605         goto error;
606
607     if (!EVP_DigestInit_ex(prsactx->mdctx, prsactx->md, NULL))
608         goto error;
609
610     return 1;
611
612  error:
613     EVP_MD_CTX_free(prsactx->mdctx);
614     EVP_MD_free(prsactx->md);
615     prsactx->mdctx = NULL;
616     prsactx->md = NULL;
617     return 0;
618 }
619
620 int rsa_digest_signverify_update(void *vprsactx, const unsigned char *data,
621                                  size_t datalen)
622 {
623     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
624
625     if (prsactx == NULL || prsactx->mdctx == NULL)
626         return 0;
627
628     return EVP_DigestUpdate(prsactx->mdctx, data, datalen);
629 }
630
631 int rsa_digest_sign_final(void *vprsactx, unsigned char *sig, size_t *siglen,
632                           size_t sigsize)
633 {
634     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
635     unsigned char digest[EVP_MAX_MD_SIZE];
636     unsigned int dlen = 0;
637
638     prsactx->flag_allow_md = 1;
639     if (prsactx == NULL || prsactx->mdctx == NULL)
640         return 0;
641
642     /*
643      * If sig is NULL then we're just finding out the sig size. Other fields
644      * are ignored. Defer to rsa_sign.
645      */
646     if (sig != NULL) {
647         /*
648          * TODO(3.0): There is the possibility that some externally provided
649          * digests exceed EVP_MAX_MD_SIZE. We should probably handle that somehow -
650          * but that problem is much larger than just in RSA.
651          */
652         if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
653             return 0;
654     }
655
656     return rsa_sign(vprsactx, sig, siglen, sigsize, digest, (size_t)dlen);
657 }
658
659
660 int rsa_digest_verify_final(void *vprsactx, const unsigned char *sig,
661                             size_t siglen)
662 {
663     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
664     unsigned char digest[EVP_MAX_MD_SIZE];
665     unsigned int dlen = 0;
666
667     prsactx->flag_allow_md = 1;
668     if (prsactx == NULL || prsactx->mdctx == NULL)
669         return 0;
670
671     /*
672      * TODO(3.0): There is the possibility that some externally provided
673      * digests exceed EVP_MAX_MD_SIZE. We should probably handle that somehow -
674      * but that problem is much larger than just in RSA.
675      */
676     if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
677         return 0;
678
679     return rsa_verify(vprsactx, sig, siglen, digest, (size_t)dlen);
680 }
681
682 static void rsa_freectx(void *vprsactx)
683 {
684     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
685
686     if (prsactx == NULL)
687         return;
688
689     RSA_free(prsactx->rsa);
690     EVP_MD_CTX_free(prsactx->mdctx);
691     EVP_MD_free(prsactx->md);
692     EVP_MD_free(prsactx->mgf1_md);
693     free_tbuf(prsactx);
694
695     OPENSSL_clear_free(prsactx, sizeof(prsactx));
696 }
697
698 static void *rsa_dupctx(void *vprsactx)
699 {
700     PROV_RSA_CTX *srcctx = (PROV_RSA_CTX *)vprsactx;
701     PROV_RSA_CTX *dstctx;
702
703     dstctx = OPENSSL_zalloc(sizeof(*srcctx));
704     if (dstctx == NULL)
705         return NULL;
706
707     *dstctx = *srcctx;
708     dstctx->rsa = NULL;
709     dstctx->md = NULL;
710     dstctx->mdctx = NULL;
711     dstctx->tbuf = NULL;
712
713     if (srcctx->rsa != NULL && !RSA_up_ref(srcctx->rsa))
714         goto err;
715     dstctx->rsa = srcctx->rsa;
716
717     if (srcctx->md != NULL && !EVP_MD_up_ref(srcctx->md))
718         goto err;
719     dstctx->md = srcctx->md;
720
721     if (srcctx->mgf1_md != NULL && !EVP_MD_up_ref(srcctx->mgf1_md))
722         goto err;
723     dstctx->mgf1_md = srcctx->mgf1_md;
724
725     if (srcctx->mdctx != NULL) {
726         dstctx->mdctx = EVP_MD_CTX_new();
727         if (dstctx->mdctx == NULL
728                 || !EVP_MD_CTX_copy_ex(dstctx->mdctx, srcctx->mdctx))
729             goto err;
730     }
731
732     return dstctx;
733  err:
734     rsa_freectx(dstctx);
735     return NULL;
736 }
737
738 static int rsa_get_ctx_params(void *vprsactx, OSSL_PARAM *params)
739 {
740     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
741     OSSL_PARAM *p;
742
743     if (prsactx == NULL || params == NULL)
744         return 0;
745
746     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_ALGORITHM_ID);
747     if (p != NULL
748         && !OSSL_PARAM_set_octet_string(p, prsactx->aid, prsactx->aid_len))
749         return 0;
750
751     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
752     if (p != NULL)
753         switch (p->data_type) {
754         case OSSL_PARAM_INTEGER:
755             if (!OSSL_PARAM_set_int(p, prsactx->pad_mode))
756                 return 0;
757             break;
758         case OSSL_PARAM_UTF8_STRING:
759             {
760                 int i;
761                 const char *word = NULL;
762
763                 for (i = 0; padding_item[i].id != 0; i++) {
764                     if (prsactx->pad_mode == (int)padding_item[i].id) {
765                         word = padding_item[i].ptr;
766                         break;
767                     }
768                 }
769
770                 if (word != NULL) {
771                     if (!OSSL_PARAM_set_utf8_string(p, word))
772                         return 0;
773                 } else {
774                     ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
775                 }
776             }
777             break;
778         default:
779             return 0;
780         }
781
782     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_DIGEST);
783     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mdname))
784         return 0;
785
786     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
787     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mgf1_mdname))
788         return 0;
789
790     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
791     if (p != NULL) {
792         if (p->data_type == OSSL_PARAM_INTEGER) {
793             if (!OSSL_PARAM_set_int(p, prsactx->saltlen))
794                 return 0;
795         } else if (p->data_type == OSSL_PARAM_UTF8_STRING) {
796             switch (prsactx->saltlen) {
797             case RSA_PSS_SALTLEN_DIGEST:
798                 if (!OSSL_PARAM_set_utf8_string(p, "digest"))
799                     return 0;
800                 break;
801             case RSA_PSS_SALTLEN_MAX:
802                 if (!OSSL_PARAM_set_utf8_string(p, "max"))
803                     return 0;
804                 break;
805             case RSA_PSS_SALTLEN_AUTO:
806                 if (!OSSL_PARAM_set_utf8_string(p, "auto"))
807                     return 0;
808                 break;
809             default:
810                 if (BIO_snprintf(p->data, p->data_size, "%d", prsactx->saltlen)
811                     <= 0)
812                     return 0;
813                 break;
814             }
815         }
816     }
817
818     return 1;
819 }
820
821 static const OSSL_PARAM known_gettable_ctx_params[] = {
822     OSSL_PARAM_octet_string(OSSL_SIGNATURE_PARAM_ALGORITHM_ID, NULL, 0),
823     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
824     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
825     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
826     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
827     OSSL_PARAM_END
828 };
829
830 static const OSSL_PARAM *rsa_gettable_ctx_params(void)
831 {
832     return known_gettable_ctx_params;
833 }
834
835 static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
836 {
837     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
838     const OSSL_PARAM *p;
839
840     if (prsactx == NULL || params == NULL)
841         return 0;
842
843     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_DIGEST);
844     /* Not allowed during certain operations */
845     if (p != NULL && !prsactx->flag_allow_md)
846         return 0;
847     if (p != NULL) {
848         char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
849         char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
850         const OSSL_PARAM *propsp =
851             OSSL_PARAM_locate_const(params,
852                                     OSSL_SIGNATURE_PARAM_PROPERTIES);
853
854         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
855             return 0;
856         if (propsp != NULL
857             && !OSSL_PARAM_get_utf8_string(propsp, &pmdprops, sizeof(mdprops)))
858             return 0;
859
860         /* TODO(3.0) PSS check needs more work */
861         if (rsa_pss_restricted(prsactx)) {
862             /* TODO(3.0) figure out what to do for prsactx->md == NULL */
863             if (prsactx->md == NULL || EVP_MD_is_a(prsactx->md, mdname))
864                 return 1;
865             ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
866             return 0;
867         }
868
869         /* non-PSS code follows */
870         if (!rsa_setup_md(prsactx, mdname, mdprops))
871             return 0;
872     }
873
874     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
875     if (p != NULL) {
876         int pad_mode = 0;
877
878         switch (p->data_type) {
879         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
880             if (!OSSL_PARAM_get_int(p, &pad_mode))
881                 return 0;
882             break;
883         case OSSL_PARAM_UTF8_STRING:
884             {
885                 int i;
886
887                 if (p->data == NULL)
888                     return 0;
889
890                 for (i = 0; padding_item[i].id != 0; i++) {
891                     if (strcmp(p->data, padding_item[i].ptr) == 0) {
892                         pad_mode = padding_item[i].id;
893                         break;
894                     }
895                 }
896             }
897             break;
898         default:
899             return 0;
900         }
901
902         switch (pad_mode) {
903         case RSA_PKCS1_OAEP_PADDING:
904             /*
905              * OAEP padding is for asymmetric cipher only so is not compatible
906              * with signature use.
907              */
908             ERR_raise_data(ERR_LIB_PROV,
909                            PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE,
910                            "OAEP padding not allowed for signing / verifying");
911             return 0;
912         case RSA_PKCS1_PSS_PADDING:
913             if (prsactx->mdname[0] == '\0')
914                 rsa_setup_md(prsactx, "SHA1", "");
915             goto cont;
916         case RSA_PKCS1_PADDING:
917         case RSA_SSLV23_PADDING:
918         case RSA_NO_PADDING:
919         case RSA_X931_PADDING:
920             if (RSA_get0_pss_params(prsactx->rsa) != NULL) {
921                 ERR_raise_data(ERR_LIB_PROV,
922                                PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE,
923                                "X.931 padding not allowed with RSA-PSS");
924                 return 0;
925             }
926         cont:
927             if (!rsa_check_padding(prsactx->mdnid, pad_mode))
928                 return 0;
929             break;
930         default:
931             return 0;
932         }
933         prsactx->pad_mode = pad_mode;
934     }
935
936     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
937     if (p != NULL) {
938         int saltlen;
939
940         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
941             ERR_raise_data(ERR_LIB_PROV, PROV_R_NOT_SUPPORTED,
942                            "PSS saltlen can only be specified if "
943                            "PSS padding has been specified first");
944             return 0;
945         }
946
947         switch (p->data_type) {
948         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
949             if (!OSSL_PARAM_get_int(p, &saltlen))
950                 return 0;
951             break;
952         case OSSL_PARAM_UTF8_STRING:
953             if (strcmp(p->data, "digest") == 0)
954                 saltlen = RSA_PSS_SALTLEN_DIGEST;
955             else if (strcmp(p->data, "max") == 0)
956                 saltlen = RSA_PSS_SALTLEN_MAX;
957             else if (strcmp(p->data, "auto") == 0)
958                 saltlen = RSA_PSS_SALTLEN_AUTO;
959             else
960                 saltlen = atoi(p->data);
961             break;
962         default:
963             return 0;
964         }
965
966         /*
967          * RSA_PSS_SALTLEN_MAX seems curiously named in this check.
968          * Contrary to what it's name suggests, it's the currently
969          * lowest saltlen number possible.
970          */
971         if (saltlen < RSA_PSS_SALTLEN_MAX) {
972             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PSS_SALTLEN);
973             return 0;
974         }
975
976         prsactx->saltlen = saltlen;
977     }
978
979     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
980     if (p != NULL) {
981         char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
982         char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
983         const OSSL_PARAM *propsp =
984             OSSL_PARAM_locate_const(params,
985                                     OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES);
986
987         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
988             return 0;
989         if (propsp != NULL
990             && !OSSL_PARAM_get_utf8_string(propsp, &pmdprops, sizeof(mdprops)))
991             return 0;
992
993         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
994             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_MGF1_MD);
995             return  0;
996         }
997
998         /* TODO(3.0) PSS check needs more work */
999         if (rsa_pss_restricted(prsactx)) {
1000             /* TODO(3.0) figure out what to do for prsactx->md == NULL */
1001             if (prsactx->mgf1_md == NULL
1002                 || EVP_MD_is_a(prsactx->mgf1_md, mdname))
1003                 return 1;
1004             ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
1005             return 0;
1006         }
1007
1008         /* non-PSS code follows */
1009         if (!rsa_setup_mgf1_md(prsactx, mdname, mdprops))
1010             return 0;
1011     }
1012
1013     return 1;
1014 }
1015
1016 static const OSSL_PARAM known_settable_ctx_params[] = {
1017     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
1018     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
1019     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PROPERTIES, NULL, 0),
1020     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
1021     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES, NULL, 0),
1022     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
1023     OSSL_PARAM_END
1024 };
1025
1026 static const OSSL_PARAM *rsa_settable_ctx_params(void)
1027 {
1028     /*
1029      * TODO(3.0): Should this function return a different set of settable ctx
1030      * params if the ctx is being used for a DigestSign/DigestVerify? In that
1031      * case it is not allowed to set the digest size/digest name because the
1032      * digest is explicitly set as part of the init.
1033      */
1034     return known_settable_ctx_params;
1035 }
1036
1037 static int rsa_get_ctx_md_params(void *vprsactx, OSSL_PARAM *params)
1038 {
1039     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1040
1041     if (prsactx->mdctx == NULL)
1042         return 0;
1043
1044     return EVP_MD_CTX_get_params(prsactx->mdctx, params);
1045 }
1046
1047 static const OSSL_PARAM *rsa_gettable_ctx_md_params(void *vprsactx)
1048 {
1049     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1050
1051     if (prsactx->md == NULL)
1052         return 0;
1053
1054     return EVP_MD_gettable_ctx_params(prsactx->md);
1055 }
1056
1057 static int rsa_set_ctx_md_params(void *vprsactx, const OSSL_PARAM params[])
1058 {
1059     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1060
1061     if (prsactx->mdctx == NULL)
1062         return 0;
1063
1064     return EVP_MD_CTX_set_params(prsactx->mdctx, params);
1065 }
1066
1067 static const OSSL_PARAM *rsa_settable_ctx_md_params(void *vprsactx)
1068 {
1069     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1070
1071     if (prsactx->md == NULL)
1072         return 0;
1073
1074     return EVP_MD_settable_ctx_params(prsactx->md);
1075 }
1076
1077 const OSSL_DISPATCH rsa_signature_functions[] = {
1078     { OSSL_FUNC_SIGNATURE_NEWCTX, (void (*)(void))rsa_newctx },
1079     { OSSL_FUNC_SIGNATURE_SIGN_INIT, (void (*)(void))rsa_signature_init },
1080     { OSSL_FUNC_SIGNATURE_SIGN, (void (*)(void))rsa_sign },
1081     { OSSL_FUNC_SIGNATURE_VERIFY_INIT, (void (*)(void))rsa_signature_init },
1082     { OSSL_FUNC_SIGNATURE_VERIFY, (void (*)(void))rsa_verify },
1083     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER_INIT, (void (*)(void))rsa_signature_init },
1084     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER, (void (*)(void))rsa_verify_recover },
1085     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_INIT,
1086       (void (*)(void))rsa_digest_signverify_init },
1087     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_UPDATE,
1088       (void (*)(void))rsa_digest_signverify_update },
1089     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_FINAL,
1090       (void (*)(void))rsa_digest_sign_final },
1091     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_INIT,
1092       (void (*)(void))rsa_digest_signverify_init },
1093     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_UPDATE,
1094       (void (*)(void))rsa_digest_signverify_update },
1095     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_FINAL,
1096       (void (*)(void))rsa_digest_verify_final },
1097     { OSSL_FUNC_SIGNATURE_FREECTX, (void (*)(void))rsa_freectx },
1098     { OSSL_FUNC_SIGNATURE_DUPCTX, (void (*)(void))rsa_dupctx },
1099     { OSSL_FUNC_SIGNATURE_GET_CTX_PARAMS, (void (*)(void))rsa_get_ctx_params },
1100     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_PARAMS,
1101       (void (*)(void))rsa_gettable_ctx_params },
1102     { OSSL_FUNC_SIGNATURE_SET_CTX_PARAMS, (void (*)(void))rsa_set_ctx_params },
1103     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_PARAMS,
1104       (void (*)(void))rsa_settable_ctx_params },
1105     { OSSL_FUNC_SIGNATURE_GET_CTX_MD_PARAMS,
1106       (void (*)(void))rsa_get_ctx_md_params },
1107     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_MD_PARAMS,
1108       (void (*)(void))rsa_gettable_ctx_md_params },
1109     { OSSL_FUNC_SIGNATURE_SET_CTX_MD_PARAMS,
1110       (void (*)(void))rsa_set_ctx_md_params },
1111     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_MD_PARAMS,
1112       (void (*)(void))rsa_settable_ctx_md_params },
1113     { 0, NULL }
1114 };