Update copyright year
[openssl.git] / providers / implementations / signature / rsa.c
1 /*
2  * Copyright 2019-2020 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 /*
11  * RSA low level APIs are deprecated for public use, but still ok for
12  * internal use.
13  */
14 #include "internal/deprecated.h"
15
16 #include <string.h>
17 #include <openssl/crypto.h>
18 #include <openssl/core_numbers.h>
19 #include <openssl/core_names.h>
20 #include <openssl/err.h>
21 #include <openssl/rsa.h>
22 #include <openssl/params.h>
23 #include <openssl/evp.h>
24 #include "internal/cryptlib.h"
25 #include "internal/nelem.h"
26 #include "internal/sizes.h"
27 #include "crypto/rsa.h"
28 #include "prov/providercommonerr.h"
29 #include "prov/implementations.h"
30 #include "prov/provider_ctx.h"
31 #include "prov/der_rsa.h"
32
33 static OSSL_OP_signature_newctx_fn rsa_newctx;
34 static OSSL_OP_signature_sign_init_fn rsa_signature_init;
35 static OSSL_OP_signature_verify_init_fn rsa_signature_init;
36 static OSSL_OP_signature_verify_recover_init_fn rsa_signature_init;
37 static OSSL_OP_signature_sign_fn rsa_sign;
38 static OSSL_OP_signature_verify_fn rsa_verify;
39 static OSSL_OP_signature_verify_recover_fn rsa_verify_recover;
40 static OSSL_OP_signature_digest_sign_init_fn rsa_digest_signverify_init;
41 static OSSL_OP_signature_digest_sign_update_fn rsa_digest_signverify_update;
42 static OSSL_OP_signature_digest_sign_final_fn rsa_digest_sign_final;
43 static OSSL_OP_signature_digest_verify_init_fn rsa_digest_signverify_init;
44 static OSSL_OP_signature_digest_verify_update_fn rsa_digest_signverify_update;
45 static OSSL_OP_signature_digest_verify_final_fn rsa_digest_verify_final;
46 static OSSL_OP_signature_freectx_fn rsa_freectx;
47 static OSSL_OP_signature_dupctx_fn rsa_dupctx;
48 static OSSL_OP_signature_get_ctx_params_fn rsa_get_ctx_params;
49 static OSSL_OP_signature_gettable_ctx_params_fn rsa_gettable_ctx_params;
50 static OSSL_OP_signature_set_ctx_params_fn rsa_set_ctx_params;
51 static OSSL_OP_signature_settable_ctx_params_fn rsa_settable_ctx_params;
52 static OSSL_OP_signature_get_ctx_md_params_fn rsa_get_ctx_md_params;
53 static OSSL_OP_signature_gettable_ctx_md_params_fn rsa_gettable_ctx_md_params;
54 static OSSL_OP_signature_set_ctx_md_params_fn rsa_set_ctx_md_params;
55 static OSSL_OP_signature_settable_ctx_md_params_fn rsa_settable_ctx_md_params;
56
57 static OSSL_ITEM padding_item[] = {
58     { RSA_PKCS1_PADDING,        "pkcs1"  },
59     { RSA_SSLV23_PADDING,       "sslv23" },
60     { RSA_NO_PADDING,           "none"   },
61     { RSA_PKCS1_OAEP_PADDING,   "oaep"   }, /* Correct spelling first */
62     { RSA_PKCS1_OAEP_PADDING,   "oeap"   },
63     { RSA_X931_PADDING,         "x931"   },
64     { RSA_PKCS1_PSS_PADDING,    "pss"    },
65     { 0,                        NULL     }
66 };
67
68 /*
69  * What's passed as an actual key is defined by the KEYMGMT interface.
70  * We happen to know that our KEYMGMT simply passes RSA structures, so
71  * we use that here too.
72  */
73
74 typedef struct {
75     OPENSSL_CTX *libctx;
76     RSA *rsa;
77
78     /*
79      * Flag to determine if the hash function can be changed (1) or not (0)
80      * Because it's dangerous to change during a DigestSign or DigestVerify
81      * operation, this flag is cleared by their Init function, and set again
82      * by their Final function.
83      */
84     unsigned int flag_allow_md : 1;
85
86     /* The Algorithm Identifier of the combined signature agorithm */
87     unsigned char aid_buf[128];
88     unsigned char *aid;
89     size_t  aid_len;
90
91     /* main digest */
92     EVP_MD *md;
93     EVP_MD_CTX *mdctx;
94     int mdnid;
95     char mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
96
97     /* RSA padding mode */
98     int pad_mode;
99     /* message digest for MGF1 */
100     EVP_MD *mgf1_md;
101     char mgf1_mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
102     /* PSS salt length */
103     int saltlen;
104     /* Minimum salt length or -1 if no PSS parameter restriction */
105     int min_saltlen;
106
107     /* Temp buffer */
108     unsigned char *tbuf;
109
110 } PROV_RSA_CTX;
111
112 static size_t rsa_get_md_size(const PROV_RSA_CTX *prsactx)
113 {
114     if (prsactx->md != NULL)
115         return EVP_MD_size(prsactx->md);
116     return 0;
117 }
118
119 static int rsa_get_md_nid(const EVP_MD *md)
120 {
121     /*
122      * Because the RSA library deals with NIDs, we need to translate.
123      * We do so using EVP_MD_is_a(), and therefore need a name to NID
124      * map.
125      */
126     static const OSSL_ITEM name_to_nid[] = {
127         { NID_sha1,      OSSL_DIGEST_NAME_SHA1      },
128         { NID_sha224,    OSSL_DIGEST_NAME_SHA2_224  },
129         { NID_sha256,    OSSL_DIGEST_NAME_SHA2_256  },
130         { NID_sha384,    OSSL_DIGEST_NAME_SHA2_384  },
131         { NID_sha512,    OSSL_DIGEST_NAME_SHA2_512  },
132         { NID_md5,       OSSL_DIGEST_NAME_MD5       },
133         { NID_md5_sha1,  OSSL_DIGEST_NAME_MD5_SHA1  },
134         { NID_md2,       OSSL_DIGEST_NAME_MD2       },
135         { NID_md4,       OSSL_DIGEST_NAME_MD4       },
136         { NID_mdc2,      OSSL_DIGEST_NAME_MDC2      },
137         { NID_ripemd160, OSSL_DIGEST_NAME_RIPEMD160 },
138         { NID_sha3_224,  OSSL_DIGEST_NAME_SHA3_224  },
139         { NID_sha3_256,  OSSL_DIGEST_NAME_SHA3_256  },
140         { NID_sha3_384,  OSSL_DIGEST_NAME_SHA3_384  },
141         { NID_sha3_512,  OSSL_DIGEST_NAME_SHA3_512  },
142     };
143     size_t i;
144     int mdnid = NID_undef;
145
146     if (md == NULL)
147         goto end;
148
149     for (i = 0; i < OSSL_NELEM(name_to_nid); i++) {
150         if (EVP_MD_is_a(md, name_to_nid[i].ptr)) {
151             mdnid = (int)name_to_nid[i].id;
152             break;
153         }
154     }
155
156     if (mdnid == NID_undef)
157         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST);
158
159  end:
160     return mdnid;
161 }
162
163 static int rsa_check_padding(int mdnid, int padding)
164 {
165     if (padding == RSA_NO_PADDING) {
166         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE);
167         return 0;
168     }
169
170     if (padding == RSA_X931_PADDING) {
171         if (RSA_X931_hash_id(mdnid) == -1) {
172             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_X931_DIGEST);
173             return 0;
174         }
175     }
176
177     return 1;
178 }
179
180 static void *rsa_newctx(void *provctx)
181 {
182     PROV_RSA_CTX *prsactx = OPENSSL_zalloc(sizeof(PROV_RSA_CTX));
183
184     if (prsactx == NULL)
185         return NULL;
186
187     prsactx->libctx = PROV_LIBRARY_CONTEXT_OF(provctx);
188     prsactx->flag_allow_md = 1;
189     return prsactx;
190 }
191
192 /* True if PSS parameters are restricted */
193 #define rsa_pss_restricted(prsactx) (prsactx->min_saltlen != -1)
194
195 static int rsa_signature_init(void *vprsactx, void *vrsa)
196 {
197     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
198
199     if (prsactx == NULL || vrsa == NULL || !RSA_up_ref(vrsa))
200         return 0;
201
202     RSA_free(prsactx->rsa);
203     prsactx->rsa = vrsa;
204     if (RSA_get0_pss_params(prsactx->rsa) != NULL)
205         prsactx->pad_mode = RSA_PKCS1_PSS_PADDING;
206     else
207         prsactx->pad_mode = RSA_PKCS1_PADDING;
208     /* Maximum for sign, auto for verify */
209     prsactx->saltlen = RSA_PSS_SALTLEN_AUTO;
210     prsactx->min_saltlen = -1;
211
212     return 1;
213 }
214
215 static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
216                         const char *mdprops)
217 {
218     if (mdname != NULL) {
219         EVP_MD *md = EVP_MD_fetch(ctx->libctx, mdname, mdprops);
220         int md_nid = rsa_get_md_nid(md);
221         WPACKET pkt;
222
223         if (md == NULL
224             || md_nid == NID_undef
225             || !rsa_check_padding(md_nid, ctx->pad_mode)) {
226             EVP_MD_free(md);
227             return 0;
228         }
229
230         EVP_MD_CTX_free(ctx->mdctx);
231         EVP_MD_free(ctx->md);
232
233         /*
234          * TODO(3.0) Should we care about DER writing errors?
235          * All it really means is that for some reason, there's no
236          * AlgorithmIdentifier to be had (consider RSA with MD5-SHA1),
237          * but the operation itself is still valid, just as long as it's
238          * not used to construct anything that needs an AlgorithmIdentifier.
239          */
240         ctx->aid_len = 0;
241         if (WPACKET_init_der(&pkt, ctx->aid_buf, sizeof(ctx->aid_buf))
242             && DER_w_algorithmIdentifier_RSA_with(&pkt, -1, ctx->rsa, md_nid)
243             && WPACKET_finish(&pkt)) {
244             WPACKET_get_total_written(&pkt, &ctx->aid_len);
245             ctx->aid = WPACKET_get_curr(&pkt);
246         }
247         WPACKET_cleanup(&pkt);
248
249         ctx->mdctx = NULL;
250         ctx->md = md;
251         ctx->mdnid = md_nid;
252         OPENSSL_strlcpy(ctx->mdname, mdname, sizeof(ctx->mdname));
253     }
254
255     return 1;
256 }
257
258 static int rsa_setup_mgf1_md(PROV_RSA_CTX *ctx, const char *mdname,
259                              const char *props)
260 {
261     if (ctx->mgf1_mdname[0] != '\0')
262         EVP_MD_free(ctx->mgf1_md);
263
264     if ((ctx->mgf1_md = EVP_MD_fetch(ctx->libctx, mdname, props)) == NULL)
265         return 0;
266     OPENSSL_strlcpy(ctx->mgf1_mdname, mdname, sizeof(ctx->mgf1_mdname));
267
268     return 1;
269 }
270
271 static int setup_tbuf(PROV_RSA_CTX *ctx)
272 {
273     if (ctx->tbuf != NULL)
274         return 1;
275     if ((ctx->tbuf = OPENSSL_malloc(RSA_size(ctx->rsa))) == NULL) {
276         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
277         return 0;
278     }
279     return 1;
280 }
281
282 static void clean_tbuf(PROV_RSA_CTX *ctx)
283 {
284     if (ctx->tbuf != NULL)
285         OPENSSL_cleanse(ctx->tbuf, RSA_size(ctx->rsa));
286 }
287
288 static void free_tbuf(PROV_RSA_CTX *ctx)
289 {
290     OPENSSL_clear_free(ctx->tbuf, RSA_size(ctx->rsa));
291     ctx->tbuf = NULL;
292 }
293
294 static int rsa_sign(void *vprsactx, unsigned char *sig, size_t *siglen,
295                     size_t sigsize, const unsigned char *tbs, size_t tbslen)
296 {
297     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
298     int ret;
299     size_t rsasize = RSA_size(prsactx->rsa);
300     size_t mdsize = rsa_get_md_size(prsactx);
301
302     if (sig == NULL) {
303         *siglen = rsasize;
304         return 1;
305     }
306
307     if (sigsize < (size_t)rsasize)
308         return 0;
309
310     if (mdsize != 0) {
311         if (tbslen != mdsize) {
312             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH);
313             return 0;
314         }
315
316 #ifndef FIPS_MODE
317         if (EVP_MD_is_a(prsactx->md, OSSL_DIGEST_NAME_MDC2)) {
318             unsigned int sltmp;
319
320             if (prsactx->pad_mode != RSA_PKCS1_PADDING) {
321                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
322                                "only PKCS#1 padding supported with MDC2");
323                 return 0;
324             }
325             ret = RSA_sign_ASN1_OCTET_STRING(0, tbs, tbslen, sig, &sltmp,
326                                              prsactx->rsa);
327
328             if (ret <= 0) {
329                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
330                 return 0;
331             }
332             ret = sltmp;
333             goto end;
334         }
335 #endif
336         switch (prsactx->pad_mode) {
337         case RSA_X931_PADDING:
338             if ((size_t)RSA_size(prsactx->rsa) < tbslen + 1) {
339                 ERR_raise(ERR_LIB_PROV, PROV_R_KEY_SIZE_TOO_SMALL);
340                 return 0;
341             }
342             if (!setup_tbuf(prsactx)) {
343                 ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
344                 return 0;
345             }
346             memcpy(prsactx->tbuf, tbs, tbslen);
347             prsactx->tbuf[tbslen] = RSA_X931_hash_id(prsactx->mdnid);
348             ret = RSA_private_encrypt(tbslen + 1, prsactx->tbuf,
349                                       sig, prsactx->rsa, RSA_X931_PADDING);
350             clean_tbuf(prsactx);
351             break;
352
353         case RSA_PKCS1_PADDING:
354             {
355                 unsigned int sltmp;
356
357                 ret = RSA_sign(prsactx->mdnid, tbs, tbslen, sig, &sltmp,
358                                prsactx->rsa);
359                 if (ret <= 0) {
360                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
361                     return 0;
362                 }
363                 ret = sltmp;
364             }
365             break;
366
367         case RSA_PKCS1_PSS_PADDING:
368             /* Check PSS restrictions */
369             if (rsa_pss_restricted(prsactx)) {
370                 switch (prsactx->saltlen) {
371                 case RSA_PSS_SALTLEN_DIGEST:
372                     if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
373                         ERR_raise(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL);
374                         return 0;
375                     }
376                     /* FALLTHRU */
377                 default:
378                     if (prsactx->saltlen >= 0
379                         && prsactx->saltlen < prsactx->min_saltlen) {
380                         ERR_raise(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL);
381                         return 0;
382                     }
383                     break;
384                 }
385             }
386             if (!setup_tbuf(prsactx))
387                 return 0;
388             if (!RSA_padding_add_PKCS1_PSS_mgf1(prsactx->rsa,
389                                                 prsactx->tbuf, tbs,
390                                                 prsactx->md, prsactx->mgf1_md,
391                                                 prsactx->saltlen)) {
392                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
393                 return 0;
394             }
395             ret = RSA_private_encrypt(RSA_size(prsactx->rsa), prsactx->tbuf,
396                                       sig, prsactx->rsa, RSA_NO_PADDING);
397             clean_tbuf(prsactx);
398             break;
399
400         default:
401             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
402                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
403             return 0;
404         }
405     } else {
406         ret = RSA_private_encrypt(tbslen, tbs, sig, prsactx->rsa,
407                                   prsactx->pad_mode);
408     }
409
410 #ifndef FIPS_MODE
411  end:
412 #endif
413     if (ret <= 0) {
414         ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
415         return 0;
416     }
417
418     *siglen = ret;
419     return 1;
420 }
421
422 static int rsa_verify_recover(void *vprsactx,
423                               unsigned char *rout,
424                               size_t *routlen,
425                               size_t routsize,
426                               const unsigned char *sig,
427                               size_t siglen)
428 {
429     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
430     int ret;
431
432     if (rout == NULL) {
433         *routlen = RSA_size(prsactx->rsa);
434         return 1;
435     }
436
437     if (prsactx->md != NULL) {
438         switch (prsactx->pad_mode) {
439         case RSA_X931_PADDING:
440             if (!setup_tbuf(prsactx))
441                 return 0;
442             ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
443                                      RSA_X931_PADDING);
444             if (ret < 1) {
445                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
446                 return 0;
447             }
448             ret--;
449             if (prsactx->tbuf[ret] != RSA_X931_hash_id(prsactx->mdnid)) {
450                 ERR_raise(ERR_LIB_PROV, PROV_R_ALGORITHM_MISMATCH);
451                 return 0;
452             }
453             if (ret != EVP_MD_size(prsactx->md)) {
454                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
455                                "Should be %d, but got %d",
456                                EVP_MD_size(prsactx->md), ret);
457                 return 0;
458             }
459
460             *routlen = ret;
461             if (routsize < (size_t)ret) {
462                 ERR_raise(ERR_LIB_PROV, PROV_R_OUTPUT_BUFFER_TOO_SMALL);
463                 return 0;
464             }
465             memcpy(rout, prsactx->tbuf, ret);
466             break;
467
468         case RSA_PKCS1_PADDING:
469             {
470                 size_t sltmp;
471
472                 ret = int_rsa_verify(prsactx->mdnid, NULL, 0, rout, &sltmp,
473                                      sig, siglen, prsactx->rsa);
474                 if (ret <= 0) {
475                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
476                     return 0;
477                 }
478                 ret = sltmp;
479             }
480             break;
481
482         default:
483             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
484                            "Only X.931 or PKCS#1 v1.5 padding allowed");
485             return 0;
486         }
487     } else {
488         ret = RSA_public_decrypt(siglen, sig, rout, prsactx->rsa,
489                                  prsactx->pad_mode);
490         if (ret < 0) {
491             ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
492             return 0;
493         }
494     }
495     *routlen = ret;
496     return 1;
497 }
498
499 static int rsa_verify(void *vprsactx, const unsigned char *sig, size_t siglen,
500                       const unsigned char *tbs, size_t tbslen)
501 {
502     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
503     size_t rslen;
504
505     if (prsactx->md != NULL) {
506         switch (prsactx->pad_mode) {
507         case RSA_PKCS1_PADDING:
508             if (!RSA_verify(prsactx->mdnid, tbs, tbslen, sig, siglen,
509                             prsactx->rsa)) {
510                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
511                 return 0;
512             }
513             return 1;
514         case RSA_X931_PADDING:
515             if (rsa_verify_recover(prsactx, NULL, &rslen, 0, sig, siglen) <= 0)
516                 return 0;
517             break;
518         case RSA_PKCS1_PSS_PADDING:
519             {
520                 int ret;
521                 size_t mdsize;
522
523                 /* Check PSS restrictions */
524                 if (rsa_pss_restricted(prsactx)) {
525                     switch (prsactx->saltlen) {
526                     case RSA_PSS_SALTLEN_AUTO:
527                         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PSS_SALTLEN);
528                         return 0;
529                     case RSA_PSS_SALTLEN_DIGEST:
530                         if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
531                             ERR_raise(ERR_LIB_PROV,
532                                       PROV_R_PSS_SALTLEN_TOO_SMALL);
533                             return 0;
534                         }
535                         /* FALLTHRU */
536                     default:
537                         if (prsactx->saltlen >= 0
538                             && prsactx->saltlen < prsactx->min_saltlen) {
539                             ERR_raise(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL);
540                             return 0;
541                         }
542                         break;
543                     }
544                 }
545
546                 /*
547                  * We need to check this for the RSA_verify_PKCS1_PSS_mgf1()
548                  * call
549                  */
550                 mdsize = rsa_get_md_size(prsactx);
551                 if (tbslen != mdsize) {
552                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
553                                    "Should be %d, but got %d",
554                                    mdsize, tbslen);
555                     return 0;
556                 }
557
558                 if (!setup_tbuf(prsactx))
559                     return 0;
560                 ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf,
561                                          prsactx->rsa, RSA_NO_PADDING);
562                 if (ret <= 0) {
563                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
564                     return 0;
565                 }
566                 ret = RSA_verify_PKCS1_PSS_mgf1(prsactx->rsa, tbs,
567                                                 prsactx->md, prsactx->mgf1_md,
568                                                 prsactx->tbuf,
569                                                 prsactx->saltlen);
570                 if (ret <= 0) {
571                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
572                     return 0;
573                 }
574                 return 1;
575             }
576         default:
577             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
578                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
579             return 0;
580         }
581     } else {
582         if (!setup_tbuf(prsactx))
583             return 0;
584         rslen = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
585                                    prsactx->pad_mode);
586         if (rslen == 0) {
587             ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
588             return 0;
589         }
590     }
591
592     if ((rslen != tbslen) || memcmp(tbs, prsactx->tbuf, rslen))
593         return 0;
594
595     return 1;
596 }
597
598 static int rsa_digest_signverify_init(void *vprsactx, const char *mdname,
599                                       const char *props, void *vrsa)
600 {
601     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
602
603     prsactx->flag_allow_md = 0;
604     if (!rsa_signature_init(vprsactx, vrsa)
605         || !rsa_setup_md(prsactx, mdname, props))
606         return 0;
607
608     prsactx->mdctx = EVP_MD_CTX_new();
609     if (prsactx->mdctx == NULL)
610         goto error;
611
612     if (!EVP_DigestInit_ex(prsactx->mdctx, prsactx->md, NULL))
613         goto error;
614
615     return 1;
616
617  error:
618     EVP_MD_CTX_free(prsactx->mdctx);
619     EVP_MD_free(prsactx->md);
620     prsactx->mdctx = NULL;
621     prsactx->md = NULL;
622     return 0;
623 }
624
625 int rsa_digest_signverify_update(void *vprsactx, const unsigned char *data,
626                                  size_t datalen)
627 {
628     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
629
630     if (prsactx == NULL || prsactx->mdctx == NULL)
631         return 0;
632
633     return EVP_DigestUpdate(prsactx->mdctx, data, datalen);
634 }
635
636 int rsa_digest_sign_final(void *vprsactx, unsigned char *sig, size_t *siglen,
637                           size_t sigsize)
638 {
639     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
640     unsigned char digest[EVP_MAX_MD_SIZE];
641     unsigned int dlen = 0;
642
643     prsactx->flag_allow_md = 1;
644     if (prsactx == NULL || prsactx->mdctx == NULL)
645         return 0;
646
647     /*
648      * If sig is NULL then we're just finding out the sig size. Other fields
649      * are ignored. Defer to rsa_sign.
650      */
651     if (sig != NULL) {
652         /*
653          * TODO(3.0): There is the possibility that some externally provided
654          * digests exceed EVP_MAX_MD_SIZE. We should probably handle that somehow -
655          * but that problem is much larger than just in RSA.
656          */
657         if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
658             return 0;
659     }
660
661     return rsa_sign(vprsactx, sig, siglen, sigsize, digest, (size_t)dlen);
662 }
663
664
665 int rsa_digest_verify_final(void *vprsactx, const unsigned char *sig,
666                             size_t siglen)
667 {
668     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
669     unsigned char digest[EVP_MAX_MD_SIZE];
670     unsigned int dlen = 0;
671
672     prsactx->flag_allow_md = 1;
673     if (prsactx == NULL || prsactx->mdctx == NULL)
674         return 0;
675
676     /*
677      * TODO(3.0): There is the possibility that some externally provided
678      * digests exceed EVP_MAX_MD_SIZE. We should probably handle that somehow -
679      * but that problem is much larger than just in RSA.
680      */
681     if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
682         return 0;
683
684     return rsa_verify(vprsactx, sig, siglen, digest, (size_t)dlen);
685 }
686
687 static void rsa_freectx(void *vprsactx)
688 {
689     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
690
691     if (prsactx == NULL)
692         return;
693
694     RSA_free(prsactx->rsa);
695     EVP_MD_CTX_free(prsactx->mdctx);
696     EVP_MD_free(prsactx->md);
697     EVP_MD_free(prsactx->mgf1_md);
698     free_tbuf(prsactx);
699
700     OPENSSL_clear_free(prsactx, sizeof(prsactx));
701 }
702
703 static void *rsa_dupctx(void *vprsactx)
704 {
705     PROV_RSA_CTX *srcctx = (PROV_RSA_CTX *)vprsactx;
706     PROV_RSA_CTX *dstctx;
707
708     dstctx = OPENSSL_zalloc(sizeof(*srcctx));
709     if (dstctx == NULL)
710         return NULL;
711
712     *dstctx = *srcctx;
713     dstctx->rsa = NULL;
714     dstctx->md = NULL;
715     dstctx->mdctx = NULL;
716     dstctx->tbuf = NULL;
717
718     if (srcctx->rsa != NULL && !RSA_up_ref(srcctx->rsa))
719         goto err;
720     dstctx->rsa = srcctx->rsa;
721
722     if (srcctx->md != NULL && !EVP_MD_up_ref(srcctx->md))
723         goto err;
724     dstctx->md = srcctx->md;
725
726     if (srcctx->mgf1_md != NULL && !EVP_MD_up_ref(srcctx->mgf1_md))
727         goto err;
728     dstctx->mgf1_md = srcctx->mgf1_md;
729
730     if (srcctx->mdctx != NULL) {
731         dstctx->mdctx = EVP_MD_CTX_new();
732         if (dstctx->mdctx == NULL
733                 || !EVP_MD_CTX_copy_ex(dstctx->mdctx, srcctx->mdctx))
734             goto err;
735     }
736
737     return dstctx;
738  err:
739     rsa_freectx(dstctx);
740     return NULL;
741 }
742
743 static int rsa_get_ctx_params(void *vprsactx, OSSL_PARAM *params)
744 {
745     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
746     OSSL_PARAM *p;
747
748     if (prsactx == NULL || params == NULL)
749         return 0;
750
751     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_ALGORITHM_ID);
752     if (p != NULL
753         && !OSSL_PARAM_set_octet_string(p, prsactx->aid, prsactx->aid_len))
754         return 0;
755
756     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
757     if (p != NULL)
758         switch (p->data_type) {
759         case OSSL_PARAM_INTEGER:
760             if (!OSSL_PARAM_set_int(p, prsactx->pad_mode))
761                 return 0;
762             break;
763         case OSSL_PARAM_UTF8_STRING:
764             {
765                 int i;
766                 const char *word = NULL;
767
768                 for (i = 0; padding_item[i].id != 0; i++) {
769                     if (prsactx->pad_mode == (int)padding_item[i].id) {
770                         word = padding_item[i].ptr;
771                         break;
772                     }
773                 }
774
775                 if (word != NULL) {
776                     if (!OSSL_PARAM_set_utf8_string(p, word))
777                         return 0;
778                 } else {
779                     ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
780                 }
781             }
782             break;
783         default:
784             return 0;
785         }
786
787     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_DIGEST);
788     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mdname))
789         return 0;
790
791     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
792     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mgf1_mdname))
793         return 0;
794
795     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
796     if (p != NULL) {
797         if (p->data_type == OSSL_PARAM_INTEGER) {
798             if (!OSSL_PARAM_set_int(p, prsactx->saltlen))
799                 return 0;
800         } else if (p->data_type == OSSL_PARAM_UTF8_STRING) {
801             switch (prsactx->saltlen) {
802             case RSA_PSS_SALTLEN_DIGEST:
803                 if (!OSSL_PARAM_set_utf8_string(p, "digest"))
804                     return 0;
805                 break;
806             case RSA_PSS_SALTLEN_MAX:
807                 if (!OSSL_PARAM_set_utf8_string(p, "max"))
808                     return 0;
809                 break;
810             case RSA_PSS_SALTLEN_AUTO:
811                 if (!OSSL_PARAM_set_utf8_string(p, "auto"))
812                     return 0;
813                 break;
814             default:
815                 if (BIO_snprintf(p->data, p->data_size, "%d", prsactx->saltlen)
816                     <= 0)
817                     return 0;
818                 break;
819             }
820         }
821     }
822
823     return 1;
824 }
825
826 static const OSSL_PARAM known_gettable_ctx_params[] = {
827     OSSL_PARAM_octet_string(OSSL_SIGNATURE_PARAM_ALGORITHM_ID, NULL, 0),
828     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
829     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
830     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
831     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
832     OSSL_PARAM_END
833 };
834
835 static const OSSL_PARAM *rsa_gettable_ctx_params(void)
836 {
837     return known_gettable_ctx_params;
838 }
839
840 static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
841 {
842     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
843     const OSSL_PARAM *p;
844
845     if (prsactx == NULL || params == NULL)
846         return 0;
847
848     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_DIGEST);
849     /* Not allowed during certain operations */
850     if (p != NULL && !prsactx->flag_allow_md)
851         return 0;
852     if (p != NULL) {
853         char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
854         char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
855         const OSSL_PARAM *propsp =
856             OSSL_PARAM_locate_const(params,
857                                     OSSL_SIGNATURE_PARAM_PROPERTIES);
858
859         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
860             return 0;
861         if (propsp != NULL
862             && !OSSL_PARAM_get_utf8_string(propsp, &pmdprops, sizeof(mdprops)))
863             return 0;
864
865         /* TODO(3.0) PSS check needs more work */
866         if (rsa_pss_restricted(prsactx)) {
867             /* TODO(3.0) figure out what to do for prsactx->md == NULL */
868             if (prsactx->md == NULL || EVP_MD_is_a(prsactx->md, mdname))
869                 return 1;
870             ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
871             return 0;
872         }
873
874         /* non-PSS code follows */
875         if (!rsa_setup_md(prsactx, mdname, mdprops))
876             return 0;
877     }
878
879     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
880     if (p != NULL) {
881         int pad_mode = 0;
882
883         switch (p->data_type) {
884         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
885             if (!OSSL_PARAM_get_int(p, &pad_mode))
886                 return 0;
887             break;
888         case OSSL_PARAM_UTF8_STRING:
889             {
890                 int i;
891
892                 if (p->data == NULL)
893                     return 0;
894
895                 for (i = 0; padding_item[i].id != 0; i++) {
896                     if (strcmp(p->data, padding_item[i].ptr) == 0) {
897                         pad_mode = padding_item[i].id;
898                         break;
899                     }
900                 }
901             }
902             break;
903         default:
904             return 0;
905         }
906
907         switch (pad_mode) {
908         case RSA_PKCS1_OAEP_PADDING:
909             /*
910              * OAEP padding is for asymmetric cipher only so is not compatible
911              * with signature use.
912              */
913             ERR_raise_data(ERR_LIB_PROV,
914                            PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE,
915                            "OAEP padding not allowed for signing / verifying");
916             return 0;
917         case RSA_PKCS1_PSS_PADDING:
918             if (prsactx->mdname[0] == '\0')
919                 rsa_setup_md(prsactx, "SHA1", "");
920             goto cont;
921         case RSA_PKCS1_PADDING:
922         case RSA_SSLV23_PADDING:
923         case RSA_NO_PADDING:
924         case RSA_X931_PADDING:
925             if (RSA_get0_pss_params(prsactx->rsa) != NULL) {
926                 ERR_raise_data(ERR_LIB_PROV,
927                                PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE,
928                                "X.931 padding not allowed with RSA-PSS");
929                 return 0;
930             }
931         cont:
932             if (!rsa_check_padding(prsactx->mdnid, pad_mode))
933                 return 0;
934             break;
935         default:
936             return 0;
937         }
938         prsactx->pad_mode = pad_mode;
939     }
940
941     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
942     if (p != NULL) {
943         int saltlen;
944
945         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
946             ERR_raise_data(ERR_LIB_PROV, PROV_R_NOT_SUPPORTED,
947                            "PSS saltlen can only be specified if "
948                            "PSS padding has been specified first");
949             return 0;
950         }
951
952         switch (p->data_type) {
953         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
954             if (!OSSL_PARAM_get_int(p, &saltlen))
955                 return 0;
956             break;
957         case OSSL_PARAM_UTF8_STRING:
958             if (strcmp(p->data, "digest") == 0)
959                 saltlen = RSA_PSS_SALTLEN_DIGEST;
960             else if (strcmp(p->data, "max") == 0)
961                 saltlen = RSA_PSS_SALTLEN_MAX;
962             else if (strcmp(p->data, "auto") == 0)
963                 saltlen = RSA_PSS_SALTLEN_AUTO;
964             else
965                 saltlen = atoi(p->data);
966             break;
967         default:
968             return 0;
969         }
970
971         /*
972          * RSA_PSS_SALTLEN_MAX seems curiously named in this check.
973          * Contrary to what it's name suggests, it's the currently
974          * lowest saltlen number possible.
975          */
976         if (saltlen < RSA_PSS_SALTLEN_MAX) {
977             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PSS_SALTLEN);
978             return 0;
979         }
980
981         prsactx->saltlen = saltlen;
982     }
983
984     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
985     if (p != NULL) {
986         char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
987         char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
988         const OSSL_PARAM *propsp =
989             OSSL_PARAM_locate_const(params,
990                                     OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES);
991
992         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
993             return 0;
994         if (propsp != NULL
995             && !OSSL_PARAM_get_utf8_string(propsp, &pmdprops, sizeof(mdprops)))
996             return 0;
997
998         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
999             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_MGF1_MD);
1000             return  0;
1001         }
1002
1003         /* TODO(3.0) PSS check needs more work */
1004         if (rsa_pss_restricted(prsactx)) {
1005             /* TODO(3.0) figure out what to do for prsactx->md == NULL */
1006             if (prsactx->mgf1_md == NULL
1007                 || EVP_MD_is_a(prsactx->mgf1_md, mdname))
1008                 return 1;
1009             ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
1010             return 0;
1011         }
1012
1013         /* non-PSS code follows */
1014         if (!rsa_setup_mgf1_md(prsactx, mdname, mdprops))
1015             return 0;
1016     }
1017
1018     return 1;
1019 }
1020
1021 static const OSSL_PARAM known_settable_ctx_params[] = {
1022     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
1023     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
1024     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PROPERTIES, NULL, 0),
1025     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
1026     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES, NULL, 0),
1027     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
1028     OSSL_PARAM_END
1029 };
1030
1031 static const OSSL_PARAM *rsa_settable_ctx_params(void)
1032 {
1033     /*
1034      * TODO(3.0): Should this function return a different set of settable ctx
1035      * params if the ctx is being used for a DigestSign/DigestVerify? In that
1036      * case it is not allowed to set the digest size/digest name because the
1037      * digest is explicitly set as part of the init.
1038      */
1039     return known_settable_ctx_params;
1040 }
1041
1042 static int rsa_get_ctx_md_params(void *vprsactx, OSSL_PARAM *params)
1043 {
1044     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1045
1046     if (prsactx->mdctx == NULL)
1047         return 0;
1048
1049     return EVP_MD_CTX_get_params(prsactx->mdctx, params);
1050 }
1051
1052 static const OSSL_PARAM *rsa_gettable_ctx_md_params(void *vprsactx)
1053 {
1054     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1055
1056     if (prsactx->md == NULL)
1057         return 0;
1058
1059     return EVP_MD_gettable_ctx_params(prsactx->md);
1060 }
1061
1062 static int rsa_set_ctx_md_params(void *vprsactx, const OSSL_PARAM params[])
1063 {
1064     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1065
1066     if (prsactx->mdctx == NULL)
1067         return 0;
1068
1069     return EVP_MD_CTX_set_params(prsactx->mdctx, params);
1070 }
1071
1072 static const OSSL_PARAM *rsa_settable_ctx_md_params(void *vprsactx)
1073 {
1074     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1075
1076     if (prsactx->md == NULL)
1077         return 0;
1078
1079     return EVP_MD_settable_ctx_params(prsactx->md);
1080 }
1081
1082 const OSSL_DISPATCH rsa_signature_functions[] = {
1083     { OSSL_FUNC_SIGNATURE_NEWCTX, (void (*)(void))rsa_newctx },
1084     { OSSL_FUNC_SIGNATURE_SIGN_INIT, (void (*)(void))rsa_signature_init },
1085     { OSSL_FUNC_SIGNATURE_SIGN, (void (*)(void))rsa_sign },
1086     { OSSL_FUNC_SIGNATURE_VERIFY_INIT, (void (*)(void))rsa_signature_init },
1087     { OSSL_FUNC_SIGNATURE_VERIFY, (void (*)(void))rsa_verify },
1088     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER_INIT, (void (*)(void))rsa_signature_init },
1089     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER, (void (*)(void))rsa_verify_recover },
1090     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_INIT,
1091       (void (*)(void))rsa_digest_signverify_init },
1092     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_UPDATE,
1093       (void (*)(void))rsa_digest_signverify_update },
1094     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_FINAL,
1095       (void (*)(void))rsa_digest_sign_final },
1096     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_INIT,
1097       (void (*)(void))rsa_digest_signverify_init },
1098     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_UPDATE,
1099       (void (*)(void))rsa_digest_signverify_update },
1100     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_FINAL,
1101       (void (*)(void))rsa_digest_verify_final },
1102     { OSSL_FUNC_SIGNATURE_FREECTX, (void (*)(void))rsa_freectx },
1103     { OSSL_FUNC_SIGNATURE_DUPCTX, (void (*)(void))rsa_dupctx },
1104     { OSSL_FUNC_SIGNATURE_GET_CTX_PARAMS, (void (*)(void))rsa_get_ctx_params },
1105     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_PARAMS,
1106       (void (*)(void))rsa_gettable_ctx_params },
1107     { OSSL_FUNC_SIGNATURE_SET_CTX_PARAMS, (void (*)(void))rsa_set_ctx_params },
1108     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_PARAMS,
1109       (void (*)(void))rsa_settable_ctx_params },
1110     { OSSL_FUNC_SIGNATURE_GET_CTX_MD_PARAMS,
1111       (void (*)(void))rsa_get_ctx_md_params },
1112     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_MD_PARAMS,
1113       (void (*)(void))rsa_gettable_ctx_md_params },
1114     { OSSL_FUNC_SIGNATURE_SET_CTX_MD_PARAMS,
1115       (void (*)(void))rsa_set_ctx_md_params },
1116     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_MD_PARAMS,
1117       (void (*)(void))rsa_settable_ctx_md_params },
1118     { 0, NULL }
1119 };