PROV: add RSA signature implementation
[openssl.git] / providers / implementations / signature / rsa.c
1 /*
2  * Copyright 2019 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 /*
11  * RSA low level APIs are deprecated for public use, but still ok for
12  * internal use.
13  */
14 #include "internal/deprecated.h"
15
16 #include <string.h>
17 #include <openssl/crypto.h>
18 #include <openssl/core_numbers.h>
19 #include <openssl/core_names.h>
20 #include <openssl/err.h>
21 #include <openssl/rsa.h>
22 #include <openssl/params.h>
23 #include <openssl/evp.h>
24 #include "internal/cryptlib.h"
25 #include "internal/nelem.h"
26 #include "internal/sizes.h"
27 #include "crypto/rsa.h"
28 #include "prov/providercommonerr.h"
29 #include "prov/implementations.h"
30 #include "prov/provider_ctx.h"
31
32 static OSSL_OP_signature_newctx_fn rsa_newctx;
33 static OSSL_OP_signature_sign_init_fn rsa_signature_init;
34 static OSSL_OP_signature_verify_init_fn rsa_signature_init;
35 static OSSL_OP_signature_verify_recover_init_fn rsa_signature_init;
36 static OSSL_OP_signature_sign_fn rsa_sign;
37 static OSSL_OP_signature_verify_fn rsa_verify;
38 static OSSL_OP_signature_verify_recover_fn rsa_verify_recover;
39 static OSSL_OP_signature_digest_sign_init_fn rsa_digest_signverify_init;
40 static OSSL_OP_signature_digest_sign_update_fn rsa_digest_signverify_update;
41 static OSSL_OP_signature_digest_sign_final_fn rsa_digest_sign_final;
42 static OSSL_OP_signature_digest_verify_init_fn rsa_digest_signverify_init;
43 static OSSL_OP_signature_digest_verify_update_fn rsa_digest_signverify_update;
44 static OSSL_OP_signature_digest_verify_final_fn rsa_digest_verify_final;
45 static OSSL_OP_signature_freectx_fn rsa_freectx;
46 static OSSL_OP_signature_dupctx_fn rsa_dupctx;
47 static OSSL_OP_signature_get_ctx_params_fn rsa_get_ctx_params;
48 static OSSL_OP_signature_gettable_ctx_params_fn rsa_gettable_ctx_params;
49 static OSSL_OP_signature_set_ctx_params_fn rsa_set_ctx_params;
50 static OSSL_OP_signature_settable_ctx_params_fn rsa_settable_ctx_params;
51 static OSSL_OP_signature_get_ctx_md_params_fn rsa_get_ctx_md_params;
52 static OSSL_OP_signature_gettable_ctx_md_params_fn rsa_gettable_ctx_md_params;
53 static OSSL_OP_signature_set_ctx_md_params_fn rsa_set_ctx_md_params;
54 static OSSL_OP_signature_settable_ctx_md_params_fn rsa_settable_ctx_md_params;
55
56 static OSSL_ITEM padding_item[] = {
57     { RSA_PKCS1_PADDING,        "pkcs1"  },
58     { RSA_SSLV23_PADDING,       "sslv23" },
59     { RSA_NO_PADDING,           "none"   },
60     { RSA_PKCS1_OAEP_PADDING,   "oaep"   }, /* Correct spelling first */
61     { RSA_PKCS1_OAEP_PADDING,   "oeap"   },
62     { RSA_X931_PADDING,         "x931"   },
63     { RSA_PKCS1_PSS_PADDING,    "pss"    },
64     { 0,                        NULL     }
65 };
66
67 /*
68  * What's passed as an actual key is defined by the KEYMGMT interface.
69  * We happen to know that our KEYMGMT simply passes RSA structures, so
70  * we use that here too.
71  */
72
73 typedef struct {
74     OPENSSL_CTX *libctx;
75     RSA *rsa;
76
77     /*
78      * Flag to determine if the hash function can be changed (1) or not (0)
79      * Because it's dangerous to change during a DigestSign or DigestVerify
80      * operation, this flag is cleared by their Init function, and set again
81      * by their Final function.
82      */
83     unsigned int flag_allow_md : 1;
84
85     /* The Algorithm Identifier of the combined signature agorithm */
86     unsigned char aid[128];
87     size_t  aid_len;
88
89     /* main digest */
90     EVP_MD *md;
91     EVP_MD_CTX *mdctx;
92     int mdnid;
93     char mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
94
95     /* RSA padding mode */
96     int pad_mode;
97     /* message digest for MGF1 */
98     EVP_MD *mgf1_md;
99     char mgf1_mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
100     /* PSS salt length */
101     int saltlen;
102     /* Minimum salt length or -1 if no PSS parameter restriction */
103     int min_saltlen;
104
105     /* Temp buffer */
106     unsigned char *tbuf;
107
108 } PROV_RSA_CTX;
109
110 static size_t rsa_get_md_size(const PROV_RSA_CTX *prsactx)
111 {
112     if (prsactx->md != NULL)
113         return EVP_MD_size(prsactx->md);
114     return 0;
115 }
116
117 static int rsa_get_md_nid(const EVP_MD *md)
118 {
119     /*
120      * Because the RSA library deals with NIDs, we need to translate.
121      * We do so using EVP_MD_is_a(), and therefore need a name to NID
122      * map.
123      */
124     static const OSSL_ITEM name_to_nid[] = {
125         { NID_sha1,      OSSL_DIGEST_NAME_SHA1      },
126         { NID_sha224,    OSSL_DIGEST_NAME_SHA2_224  },
127         { NID_sha256,    OSSL_DIGEST_NAME_SHA2_256  },
128         { NID_sha384,    OSSL_DIGEST_NAME_SHA2_384  },
129         { NID_sha512,    OSSL_DIGEST_NAME_SHA2_512  },
130         { NID_md5,       OSSL_DIGEST_NAME_MD5       },
131         { NID_md5_sha1,  OSSL_DIGEST_NAME_MD5_SHA1  },
132         { NID_md2,       OSSL_DIGEST_NAME_MD2       },
133         { NID_md4,       OSSL_DIGEST_NAME_MD4       },
134         { NID_mdc2,      OSSL_DIGEST_NAME_MDC2      },
135         { NID_ripemd160, OSSL_DIGEST_NAME_RIPEMD160 },
136         { NID_sha3_224,  OSSL_DIGEST_NAME_SHA3_224  },
137         { NID_sha3_256,  OSSL_DIGEST_NAME_SHA3_256  },
138         { NID_sha3_384,  OSSL_DIGEST_NAME_SHA3_384  },
139         { NID_sha3_512,  OSSL_DIGEST_NAME_SHA3_512  },
140     };
141     size_t i;
142     int mdnid = NID_undef;
143
144     if (md == NULL)
145         goto end;
146
147     for (i = 0; i < OSSL_NELEM(name_to_nid); i++) {
148         if (EVP_MD_is_a(md, name_to_nid[i].ptr)) {
149             mdnid = (int)name_to_nid[i].id;
150             break;
151         }
152     }
153
154     if (mdnid == NID_undef)
155         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST);
156
157  end:
158     return mdnid;
159 }
160
161 static int rsa_check_padding(int mdnid, int padding)
162 {
163     if (padding == RSA_NO_PADDING) {
164         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE);
165         return 0;
166     }
167
168     if (padding == RSA_X931_PADDING) {
169         if (RSA_X931_hash_id(mdnid) == -1) {
170             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_X931_DIGEST);
171             return 0;
172         }
173     }
174
175     return 1;
176 }
177
178 static void *rsa_newctx(void *provctx)
179 {
180     PROV_RSA_CTX *prsactx = OPENSSL_zalloc(sizeof(PROV_RSA_CTX));
181
182     if (prsactx == NULL)
183         return NULL;
184
185     prsactx->libctx = PROV_LIBRARY_CONTEXT_OF(provctx);
186     prsactx->flag_allow_md = 1;
187     return prsactx;
188 }
189
190 /* True if PSS parameters are restricted */
191 #define rsa_pss_restricted(prsactx) (prsactx->min_saltlen != -1)
192
193 static int rsa_signature_init(void *vprsactx, void *vrsa)
194 {
195     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
196
197     if (prsactx == NULL || vrsa == NULL || !RSA_up_ref(vrsa))
198         return 0;
199
200     RSA_free(prsactx->rsa);
201     prsactx->rsa = vrsa;
202     if (RSA_get0_pss_params(prsactx->rsa) != NULL)
203         prsactx->pad_mode = RSA_PKCS1_PSS_PADDING;
204     else
205         prsactx->pad_mode = RSA_PKCS1_PADDING;
206     /* Maximum for sign, auto for verify */
207     prsactx->saltlen = RSA_PSS_SALTLEN_AUTO;
208     prsactx->min_saltlen = -1;
209
210     return 1;
211 }
212
213 static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
214                         const char *mdprops)
215 {
216     if (mdname != NULL) {
217         EVP_MD *md = EVP_MD_fetch(ctx->libctx, mdname, mdprops);
218         int md_nid = rsa_get_md_nid(md);
219         size_t algorithmidentifier_len = 0;
220         const unsigned char *algorithmidentifier = NULL;
221
222         if (md == NULL)
223             return 0;
224
225         if (!rsa_check_padding(md_nid, ctx->pad_mode)) {
226             EVP_MD_free(md);
227             return 0;
228         }
229
230         EVP_MD_CTX_free(ctx->mdctx);
231         EVP_MD_free(ctx->md);
232         ctx->md = NULL;
233         ctx->mdctx = NULL;
234         ctx->mdname[0] = '\0';
235         ctx->aid[0] = '\0';
236         ctx->aid_len = 0;
237
238         algorithmidentifier =
239             rsa_algorithmidentifier_encoding(md_nid, &algorithmidentifier_len);
240
241         ctx->md = md;
242         ctx->mdnid = md_nid;
243         OPENSSL_strlcpy(ctx->mdname, mdname, sizeof(ctx->mdname));
244         if (algorithmidentifier != NULL) {
245             memcpy(ctx->aid, algorithmidentifier, algorithmidentifier_len);
246             ctx->aid_len = algorithmidentifier_len;
247         }
248     }
249
250     return 1;
251 }
252
253 static int rsa_setup_mgf1_md(PROV_RSA_CTX *ctx, const char *mdname,
254                              const char *props)
255 {
256     if (ctx->mgf1_mdname[0] != '\0')
257         EVP_MD_free(ctx->mgf1_md);
258
259     if ((ctx->mgf1_md = EVP_MD_fetch(ctx->libctx, mdname, props)) == NULL)
260         return 0;
261     OPENSSL_strlcpy(ctx->mgf1_mdname, mdname, sizeof(ctx->mgf1_mdname));
262
263     return 1;
264 }
265
266 static int setup_tbuf(PROV_RSA_CTX *ctx)
267 {
268     if (ctx->tbuf != NULL)
269         return 1;
270     if ((ctx->tbuf = OPENSSL_malloc(RSA_size(ctx->rsa))) == NULL) {
271         ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
272         return 0;
273     }
274     return 1;
275 }
276
277 static void clean_tbuf(PROV_RSA_CTX *ctx)
278 {
279     if (ctx->tbuf != NULL)
280         OPENSSL_cleanse(ctx->tbuf, RSA_size(ctx->rsa));
281 }
282
283 static void free_tbuf(PROV_RSA_CTX *ctx)
284 {
285     OPENSSL_clear_free(ctx->tbuf, RSA_size(ctx->rsa));
286     ctx->tbuf = NULL;
287 }
288
289 static int rsa_sign(void *vprsactx, unsigned char *sig, size_t *siglen,
290                     size_t sigsize, const unsigned char *tbs, size_t tbslen)
291 {
292     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
293     int ret;
294     size_t rsasize = RSA_size(prsactx->rsa);
295     size_t mdsize = rsa_get_md_size(prsactx);
296
297     if (sig == NULL) {
298         *siglen = rsasize;
299         return 1;
300     }
301
302     if (sigsize < (size_t)rsasize)
303         return 0;
304
305     if (mdsize != 0) {
306         if (tbslen != mdsize) {
307             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH);
308             return 0;
309         }
310
311         if (EVP_MD_is_a(prsactx->md, OSSL_DIGEST_NAME_MDC2)) {
312             unsigned int sltmp;
313
314             if (prsactx->pad_mode != RSA_PKCS1_PADDING) {
315                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
316                                "only PKCS#1 padding supported with MDC2");
317                 return 0;
318             }
319             ret = RSA_sign_ASN1_OCTET_STRING(0, tbs, tbslen, sig, &sltmp,
320                                              prsactx->rsa);
321
322             if (ret <= 0) {
323                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
324                 return 0;
325             }
326             ret = sltmp;
327             goto end;
328         }
329
330         switch (prsactx->pad_mode) {
331         case RSA_X931_PADDING:
332             if ((size_t)RSA_size(prsactx->rsa) < tbslen + 1) {
333                 ERR_raise(ERR_LIB_PROV, PROV_R_KEY_SIZE_TOO_SMALL);
334                 return 0;
335             }
336             if (!setup_tbuf(prsactx)) {
337                 ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
338                 return 0;
339             }
340             memcpy(prsactx->tbuf, tbs, tbslen);
341             prsactx->tbuf[tbslen] = RSA_X931_hash_id(prsactx->mdnid);
342             ret = RSA_private_encrypt(tbslen + 1, prsactx->tbuf,
343                                       sig, prsactx->rsa, RSA_X931_PADDING);
344             clean_tbuf(prsactx);
345             break;
346
347         case RSA_PKCS1_PADDING:
348             {
349                 unsigned int sltmp;
350
351                 ret = RSA_sign(prsactx->mdnid, tbs, tbslen, sig, &sltmp,
352                                prsactx->rsa);
353                 if (ret <= 0) {
354                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
355                     return 0;
356                 }
357                 ret = sltmp;
358             }
359             break;
360
361         case RSA_PKCS1_PSS_PADDING:
362             /* Check PSS restrictions */
363             if (rsa_pss_restricted(prsactx)) {
364                 switch (prsactx->saltlen) {
365                 case RSA_PSS_SALTLEN_DIGEST:
366                     if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
367                         ERR_raise(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL);
368                         return 0;
369                     }
370                     /* FALLTHRU */
371                 default:
372                     if (prsactx->saltlen >= 0
373                         && prsactx->saltlen < prsactx->min_saltlen) {
374                         ERR_raise(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL);
375                         return 0;
376                     }
377                     break;
378                 }
379             }
380             if (!setup_tbuf(prsactx))
381                 return 0;
382             if (!RSA_padding_add_PKCS1_PSS_mgf1(prsactx->rsa,
383                                                 prsactx->tbuf, tbs,
384                                                 prsactx->md, prsactx->mgf1_md,
385                                                 prsactx->saltlen)) {
386                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
387                 return 0;
388             }
389             ret = RSA_private_encrypt(RSA_size(prsactx->rsa), prsactx->tbuf,
390                                       sig, prsactx->rsa, RSA_NO_PADDING);
391             clean_tbuf(prsactx);
392             break;
393
394         default:
395             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
396                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
397             return 0;
398         }
399     } else {
400         ret = RSA_private_encrypt(tbslen, tbs, sig, prsactx->rsa,
401                                   prsactx->pad_mode);
402     }
403
404 #ifdef LEGACY_MODE
405  end:
406 #endif
407     if (ret <= 0) {
408         ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
409         return 0;
410     }
411
412     *siglen = ret;
413     return 1;
414 }
415
416 static int rsa_verify_recover(void *vprsactx,
417                               unsigned char *rout,
418                               size_t *routlen,
419                               size_t routsize,
420                               const unsigned char *sig,
421                               size_t siglen)
422 {
423     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
424     int ret;
425
426     if (rout == NULL) {
427         *routlen = RSA_size(prsactx->rsa);
428         return 1;
429     }
430
431     if (prsactx->md != NULL) {
432         switch (prsactx->pad_mode) {
433         case RSA_X931_PADDING:
434             if (!setup_tbuf(prsactx))
435                 return 0;
436             ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
437                                      RSA_X931_PADDING);
438             if (ret < 1) {
439                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
440                 return 0;
441             }
442             ret--;
443             if (prsactx->tbuf[ret] != RSA_X931_hash_id(prsactx->mdnid)) {
444                 ERR_raise(ERR_LIB_PROV, PROV_R_ALGORITHM_MISMATCH);
445                 return 0;
446             }
447             if (ret != EVP_MD_size(prsactx->md)) {
448                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
449                                "Should be %d, but got %d",
450                                EVP_MD_size(prsactx->md), ret);
451                 return 0;
452             }
453
454             *routlen = ret;
455             if (routsize < (size_t)ret) {
456                 ERR_raise(ERR_LIB_PROV, PROV_R_OUTPUT_BUFFER_TOO_SMALL);
457                 return 0;
458             }
459             memcpy(rout, prsactx->tbuf, ret);
460             break;
461
462         case RSA_PKCS1_PADDING:
463             {
464                 size_t sltmp;
465
466                 ret = int_rsa_verify(prsactx->mdnid, NULL, 0, rout, &sltmp,
467                                      sig, siglen, prsactx->rsa);
468                 if (ret <= 0) {
469                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
470                     return 0;
471                 }
472                 ret = sltmp;
473             }
474             break;
475
476         default:
477             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
478                            "Only X.931 or PKCS#1 v1.5 padding allowed");
479             return 0;
480         }
481     } else {
482         ret = RSA_public_decrypt(siglen, sig, rout, prsactx->rsa,
483                                  prsactx->pad_mode);
484         if (ret < 0) {
485             ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
486             return 0;
487         }
488     }
489     *routlen = ret;
490     return 1;
491 }
492
493 static int rsa_verify(void *vprsactx, const unsigned char *sig, size_t siglen,
494                       const unsigned char *tbs, size_t tbslen)
495 {
496     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
497     size_t rslen;
498
499     if (prsactx->md != NULL) {
500         switch (prsactx->pad_mode) {
501         case RSA_PKCS1_PADDING:
502             if (!RSA_verify(prsactx->mdnid, tbs, tbslen, sig, siglen,
503                             prsactx->rsa)) {
504                 ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
505                 return 0;
506             }
507             return 1;
508         case RSA_X931_PADDING:
509             if (rsa_verify_recover(prsactx, NULL, &rslen, 0, sig, siglen) <= 0)
510                 return 0;
511             break;
512         case RSA_PKCS1_PSS_PADDING:
513             {
514                 int ret;
515                 size_t mdsize;
516
517                 /* Check PSS restrictions */
518                 if (rsa_pss_restricted(prsactx)) {
519                     switch (prsactx->saltlen) {
520                     case RSA_PSS_SALTLEN_AUTO:
521                         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PSS_SALTLEN);
522                         return 0;
523                     case RSA_PSS_SALTLEN_DIGEST:
524                         if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
525                             ERR_raise(ERR_LIB_PROV,
526                                       PROV_R_PSS_SALTLEN_TOO_SMALL);
527                             return 0;
528                         }
529                         /* FALLTHRU */
530                     default:
531                         if (prsactx->saltlen >= 0
532                             && prsactx->saltlen < prsactx->min_saltlen) {
533                             ERR_raise(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL);
534                             return 0;
535                         }
536                         break;
537                     }
538                 }
539
540                 /*
541                  * We need to check this for the RSA_verify_PKCS1_PSS_mgf1()
542                  * call
543                  */
544                 mdsize = rsa_get_md_size(prsactx);
545                 if (tbslen != mdsize) {
546                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST_LENGTH,
547                                    "Should be %d, but got %d",
548                                    mdsize, tbslen);
549                     return 0;
550                 }
551
552                 if (!setup_tbuf(prsactx))
553                     return 0;
554                 ret = RSA_public_decrypt(siglen, sig, prsactx->tbuf,
555                                          prsactx->rsa, RSA_NO_PADDING);
556                 if (ret <= 0) {
557                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
558                     return 0;
559                 }
560                 ret = RSA_verify_PKCS1_PSS_mgf1(prsactx->rsa, tbs,
561                                                 prsactx->md, prsactx->mgf1_md,
562                                                 prsactx->tbuf,
563                                                 prsactx->saltlen);
564                 if (ret <= 0) {
565                     ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
566                     return 0;
567                 }
568                 return 1;
569             }
570         default:
571             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE,
572                            "Only X.931, PKCS#1 v1.5 or PSS padding allowed");
573             return 0;
574         }
575     } else {
576         if (!setup_tbuf(prsactx))
577             return 0;
578         rslen = RSA_public_decrypt(siglen, sig, prsactx->tbuf, prsactx->rsa,
579                                    prsactx->pad_mode);
580         if (rslen == 0) {
581             ERR_raise(ERR_LIB_PROV, ERR_LIB_RSA);
582             return 0;
583         }
584     }
585
586     if ((rslen != tbslen) || memcmp(tbs, prsactx->tbuf, rslen))
587         return 0;
588
589     return 1;
590 }
591
592 static int rsa_digest_signverify_init(void *vprsactx, const char *mdname,
593                                       const char *props, void *vrsa)
594 {
595     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
596
597     prsactx->flag_allow_md = 0;
598     if (!rsa_signature_init(vprsactx, vrsa)
599         || !rsa_setup_md(prsactx, mdname, props))
600         return 0;
601
602     prsactx->mdctx = EVP_MD_CTX_new();
603     if (prsactx->mdctx == NULL)
604         goto error;
605
606     if (!EVP_DigestInit_ex(prsactx->mdctx, prsactx->md, NULL))
607         goto error;
608
609     return 1;
610
611  error:
612     EVP_MD_CTX_free(prsactx->mdctx);
613     EVP_MD_free(prsactx->md);
614     prsactx->mdctx = NULL;
615     prsactx->md = NULL;
616     return 0;
617 }
618
619 int rsa_digest_signverify_update(void *vprsactx, const unsigned char *data,
620                                  size_t datalen)
621 {
622     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
623
624     if (prsactx == NULL || prsactx->mdctx == NULL)
625         return 0;
626
627     return EVP_DigestUpdate(prsactx->mdctx, data, datalen);
628 }
629
630 int rsa_digest_sign_final(void *vprsactx, unsigned char *sig, size_t *siglen,
631                           size_t sigsize)
632 {
633     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
634     unsigned char digest[EVP_MAX_MD_SIZE];
635     unsigned int dlen = 0;
636
637     prsactx->flag_allow_md = 1;
638     if (prsactx == NULL || prsactx->mdctx == NULL)
639         return 0;
640
641     /*
642      * If sig is NULL then we're just finding out the sig size. Other fields
643      * are ignored. Defer to rsa_sign.
644      */
645     if (sig != NULL) {
646         /*
647          * TODO(3.0): There is the possibility that some externally provided
648          * digests exceed EVP_MAX_MD_SIZE. We should probably handle that somehow -
649          * but that problem is much larger than just in RSA.
650          */
651         if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
652             return 0;
653     }
654
655     return rsa_sign(vprsactx, sig, siglen, sigsize, digest, (size_t)dlen);
656 }
657
658
659 int rsa_digest_verify_final(void *vprsactx, const unsigned char *sig,
660                             size_t siglen)
661 {
662     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
663     unsigned char digest[EVP_MAX_MD_SIZE];
664     unsigned int dlen = 0;
665
666     prsactx->flag_allow_md = 1;
667     if (prsactx == NULL || prsactx->mdctx == NULL)
668         return 0;
669
670     /*
671      * TODO(3.0): There is the possibility that some externally provided
672      * digests exceed EVP_MAX_MD_SIZE. We should probably handle that somehow -
673      * but that problem is much larger than just in RSA.
674      */
675     if (!EVP_DigestFinal_ex(prsactx->mdctx, digest, &dlen))
676         return 0;
677
678     return rsa_verify(vprsactx, sig, siglen, digest, (size_t)dlen);
679 }
680
681 static void rsa_freectx(void *vprsactx)
682 {
683     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
684
685     if (prsactx == NULL)
686         return;
687
688     RSA_free(prsactx->rsa);
689     EVP_MD_CTX_free(prsactx->mdctx);
690     EVP_MD_free(prsactx->md);
691     EVP_MD_free(prsactx->mgf1_md);
692     free_tbuf(prsactx);
693
694     OPENSSL_clear_free(prsactx, sizeof(prsactx));
695 }
696
697 static void *rsa_dupctx(void *vprsactx)
698 {
699     PROV_RSA_CTX *srcctx = (PROV_RSA_CTX *)vprsactx;
700     PROV_RSA_CTX *dstctx;
701
702     dstctx = OPENSSL_zalloc(sizeof(*srcctx));
703     if (dstctx == NULL)
704         return NULL;
705
706     *dstctx = *srcctx;
707     dstctx->rsa = NULL;
708     dstctx->md = NULL;
709     dstctx->mdctx = NULL;
710     dstctx->tbuf = NULL;
711
712     if (srcctx->rsa != NULL && !RSA_up_ref(srcctx->rsa))
713         goto err;
714     dstctx->rsa = srcctx->rsa;
715
716     if (srcctx->md != NULL && !EVP_MD_up_ref(srcctx->md))
717         goto err;
718     dstctx->md = srcctx->md;
719
720     if (srcctx->mgf1_md != NULL && !EVP_MD_up_ref(srcctx->mgf1_md))
721         goto err;
722     dstctx->mgf1_md = srcctx->mgf1_md;
723
724     if (srcctx->mdctx != NULL) {
725         dstctx->mdctx = EVP_MD_CTX_new();
726         if (dstctx->mdctx == NULL
727                 || !EVP_MD_CTX_copy_ex(dstctx->mdctx, srcctx->mdctx))
728             goto err;
729     }
730
731     return dstctx;
732  err:
733     rsa_freectx(dstctx);
734     return NULL;
735 }
736
737 static int rsa_get_ctx_params(void *vprsactx, OSSL_PARAM *params)
738 {
739     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
740     OSSL_PARAM *p;
741
742     if (prsactx == NULL || params == NULL)
743         return 0;
744
745     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_ALGORITHM_ID);
746     if (p != NULL
747         && !OSSL_PARAM_set_octet_string(p, prsactx->aid, prsactx->aid_len))
748         return 0;
749
750     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
751     if (p != NULL)
752         switch (p->data_type) {
753         case OSSL_PARAM_INTEGER:
754             if (!OSSL_PARAM_set_int(p, prsactx->pad_mode))
755                 return 0;
756             break;
757         case OSSL_PARAM_UTF8_STRING:
758             {
759                 int i;
760                 const char *word = NULL;
761
762                 for (i = 0; padding_item[i].id != 0; i++) {
763                     if (prsactx->pad_mode == (int)padding_item[i].id) {
764                         word = padding_item[i].ptr;
765                         break;
766                     }
767                 }
768
769                 if (word != NULL) {
770                     if (!OSSL_PARAM_set_utf8_string(p, word))
771                         return 0;
772                 } else {
773                     ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
774                 }
775             }
776             break;
777         default:
778             return 0;
779         }
780
781     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_DIGEST);
782     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mdname))
783         return 0;
784
785     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
786     if (p != NULL && !OSSL_PARAM_set_utf8_string(p, prsactx->mgf1_mdname))
787         return 0;
788
789     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
790     if (p != NULL) {
791         if (p->data_type == OSSL_PARAM_INTEGER) {
792             if (!OSSL_PARAM_set_int(p, prsactx->saltlen))
793                 return 0;
794         } else if (p->data_type == OSSL_PARAM_UTF8_STRING) {
795             switch (prsactx->saltlen) {
796             case RSA_PSS_SALTLEN_DIGEST:
797                 if (!OSSL_PARAM_set_utf8_string(p, "digest"))
798                     return 0;
799                 break;
800             case RSA_PSS_SALTLEN_MAX:
801                 if (!OSSL_PARAM_set_utf8_string(p, "max"))
802                     return 0;
803                 break;
804             case RSA_PSS_SALTLEN_AUTO:
805                 if (!OSSL_PARAM_set_utf8_string(p, "auto"))
806                     return 0;
807                 break;
808             default:
809                 if (BIO_snprintf(p->data, p->data_size, "%d", prsactx->saltlen)
810                     <= 0)
811                     return 0;
812                 break;
813             }
814         }
815     }
816
817     return 1;
818 }
819
820 static const OSSL_PARAM known_gettable_ctx_params[] = {
821     OSSL_PARAM_octet_string(OSSL_SIGNATURE_PARAM_ALGORITHM_ID, NULL, 0),
822     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
823     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
824     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
825     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
826     OSSL_PARAM_END
827 };
828
829 static const OSSL_PARAM *rsa_gettable_ctx_params(void)
830 {
831     return known_gettable_ctx_params;
832 }
833
834 static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
835 {
836     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
837     const OSSL_PARAM *p;
838
839     if (prsactx == NULL || params == NULL)
840         return 0;
841
842     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_DIGEST);
843     /* Not allowed during certain operations */
844     if (p != NULL && !prsactx->flag_allow_md)
845         return 0;
846     if (p != NULL) {
847         char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
848         char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
849         const OSSL_PARAM *propsp =
850             OSSL_PARAM_locate_const(params,
851                                     OSSL_SIGNATURE_PARAM_PROPERTIES);
852
853         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
854             return 0;
855         if (propsp != NULL
856             && !OSSL_PARAM_get_utf8_string(propsp, &pmdprops, sizeof(mdprops)))
857             return 0;
858
859         /* TODO(3.0) PSS check needs more work */
860         if (rsa_pss_restricted(prsactx)) {
861             /* TODO(3.0) figure out what to do for prsactx->md == NULL */
862             if (prsactx->md == NULL || EVP_MD_is_a(prsactx->md, mdname))
863                 return 1;
864             ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
865             return 0;
866         }
867
868         /* non-PSS code follows */
869         if (!rsa_setup_md(prsactx, mdname, mdprops))
870             return 0;
871     }
872
873     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
874     if (p != NULL) {
875         int pad_mode = 0;
876
877         switch (p->data_type) {
878         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
879             if (!OSSL_PARAM_get_int(p, &pad_mode))
880                 return 0;
881             break;
882         case OSSL_PARAM_UTF8_STRING:
883             {
884                 int i;
885
886                 if (p->data == NULL)
887                     return 0;
888
889                 for (i = 0; padding_item[i].id != 0; i++) {
890                     if (strcmp(p->data, padding_item[i].ptr) == 0) {
891                         pad_mode = padding_item[i].id;
892                         break;
893                     }
894                 }
895             }
896             break;
897         default:
898             return 0;
899         }
900
901         switch (pad_mode) {
902         case RSA_PKCS1_OAEP_PADDING:
903             /*
904              * OAEP padding is for asymmetric cipher only so is not compatible
905              * with signature use.
906              */
907             ERR_raise_data(ERR_LIB_PROV,
908                            PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE,
909                            "OAEP padding not allowed for signing / verifying");
910             return 0;
911         case RSA_PKCS1_PSS_PADDING:
912             if (prsactx->mdname[0] == '\0')
913                 rsa_setup_md(prsactx, "SHA1", "");
914             goto cont;
915         case RSA_PKCS1_PADDING:
916         case RSA_SSLV23_PADDING:
917         case RSA_NO_PADDING:
918         case RSA_X931_PADDING:
919             if (RSA_get0_pss_params(prsactx->rsa) != NULL) {
920                 ERR_raise_data(ERR_LIB_PROV,
921                                PROV_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE,
922                                "X.931 padding not allowed with RSA-PSS");
923                 return 0;
924             }
925         cont:
926             if (!rsa_check_padding(prsactx->mdnid, pad_mode))
927                 return 0;
928             break;
929         default:
930             return 0;
931         }
932         prsactx->pad_mode = pad_mode;
933     }
934
935     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
936     if (p != NULL) {
937         int saltlen;
938
939         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
940             ERR_raise_data(ERR_LIB_PROV, PROV_R_NOT_SUPPORTED,
941                            "PSS saltlen can only be specified if "
942                            "PSS padding has been specified first");
943             return 0;
944         }
945
946         switch (p->data_type) {
947         case OSSL_PARAM_INTEGER: /* Support for legacy pad mode number */
948             if (!OSSL_PARAM_get_int(p, &saltlen))
949                 return 0;
950             break;
951         case OSSL_PARAM_UTF8_STRING:
952             if (strcmp(p->data, "digest") == 0)
953                 saltlen = RSA_PSS_SALTLEN_DIGEST;
954             else if (strcmp(p->data, "max") == 0)
955                 saltlen = RSA_PSS_SALTLEN_MAX;
956             else if (strcmp(p->data, "auto") == 0)
957                 saltlen = RSA_PSS_SALTLEN_AUTO;
958             else
959                 saltlen = atoi(p->data);
960             break;
961         default:
962             return 0;
963         }
964
965         /*
966          * RSA_PSS_SALTLEN_MAX seems curiously named in this check.
967          * Contrary to what it's name suggests, it's the currently
968          * lowest saltlen number possible.
969          */
970         if (saltlen < RSA_PSS_SALTLEN_MAX) {
971             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PSS_SALTLEN);
972             return 0;
973         }
974
975         prsactx->saltlen = saltlen;
976     }
977
978     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
979     if (p != NULL) {
980         char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
981         char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
982         const OSSL_PARAM *propsp =
983             OSSL_PARAM_locate_const(params,
984                                     OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES);
985
986         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
987             return 0;
988         if (propsp != NULL
989             && !OSSL_PARAM_get_utf8_string(propsp, &pmdprops, sizeof(mdprops)))
990             return 0;
991
992         if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
993             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_MGF1_MD);
994             return  0;
995         }
996
997         /* TODO(3.0) PSS check needs more work */
998         if (rsa_pss_restricted(prsactx)) {
999             /* TODO(3.0) figure out what to do for prsactx->md == NULL */
1000             if (prsactx->mgf1_md == NULL
1001                 || EVP_MD_is_a(prsactx->mgf1_md, mdname))
1002                 return 1;
1003             ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
1004             return 0;
1005         }
1006
1007         /* non-PSS code follows */
1008         if (!rsa_setup_mgf1_md(prsactx, mdname, mdprops))
1009             return 0;
1010     }
1011
1012     return 1;
1013 }
1014
1015 static const OSSL_PARAM known_settable_ctx_params[] = {
1016     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PAD_MODE, NULL, 0),
1017     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_DIGEST, NULL, 0),
1018     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PROPERTIES, NULL, 0),
1019     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_DIGEST, NULL, 0),
1020     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES, NULL, 0),
1021     OSSL_PARAM_utf8_string(OSSL_SIGNATURE_PARAM_PSS_SALTLEN, NULL, 0),
1022     OSSL_PARAM_END
1023 };
1024
1025 static const OSSL_PARAM *rsa_settable_ctx_params(void)
1026 {
1027     /*
1028      * TODO(3.0): Should this function return a different set of settable ctx
1029      * params if the ctx is being used for a DigestSign/DigestVerify? In that
1030      * case it is not allowed to set the digest size/digest name because the
1031      * digest is explicitly set as part of the init.
1032      */
1033     return known_settable_ctx_params;
1034 }
1035
1036 static int rsa_get_ctx_md_params(void *vprsactx, OSSL_PARAM *params)
1037 {
1038     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1039
1040     if (prsactx->mdctx == NULL)
1041         return 0;
1042
1043     return EVP_MD_CTX_get_params(prsactx->mdctx, params);
1044 }
1045
1046 static const OSSL_PARAM *rsa_gettable_ctx_md_params(void *vprsactx)
1047 {
1048     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1049
1050     if (prsactx->md == NULL)
1051         return 0;
1052
1053     return EVP_MD_gettable_ctx_params(prsactx->md);
1054 }
1055
1056 static int rsa_set_ctx_md_params(void *vprsactx, const OSSL_PARAM params[])
1057 {
1058     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1059
1060     if (prsactx->mdctx == NULL)
1061         return 0;
1062
1063     return EVP_MD_CTX_set_params(prsactx->mdctx, params);
1064 }
1065
1066 static const OSSL_PARAM *rsa_settable_ctx_md_params(void *vprsactx)
1067 {
1068     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
1069
1070     if (prsactx->md == NULL)
1071         return 0;
1072
1073     return EVP_MD_settable_ctx_params(prsactx->md);
1074 }
1075
1076 const OSSL_DISPATCH rsa_signature_functions[] = {
1077     { OSSL_FUNC_SIGNATURE_NEWCTX, (void (*)(void))rsa_newctx },
1078     { OSSL_FUNC_SIGNATURE_SIGN_INIT, (void (*)(void))rsa_signature_init },
1079     { OSSL_FUNC_SIGNATURE_SIGN, (void (*)(void))rsa_sign },
1080     { OSSL_FUNC_SIGNATURE_VERIFY_INIT, (void (*)(void))rsa_signature_init },
1081     { OSSL_FUNC_SIGNATURE_VERIFY, (void (*)(void))rsa_verify },
1082     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER_INIT, (void (*)(void))rsa_signature_init },
1083     { OSSL_FUNC_SIGNATURE_VERIFY_RECOVER, (void (*)(void))rsa_verify_recover },
1084     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_INIT,
1085       (void (*)(void))rsa_digest_signverify_init },
1086     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_UPDATE,
1087       (void (*)(void))rsa_digest_signverify_update },
1088     { OSSL_FUNC_SIGNATURE_DIGEST_SIGN_FINAL,
1089       (void (*)(void))rsa_digest_sign_final },
1090     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_INIT,
1091       (void (*)(void))rsa_digest_signverify_init },
1092     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_UPDATE,
1093       (void (*)(void))rsa_digest_signverify_update },
1094     { OSSL_FUNC_SIGNATURE_DIGEST_VERIFY_FINAL,
1095       (void (*)(void))rsa_digest_verify_final },
1096     { OSSL_FUNC_SIGNATURE_FREECTX, (void (*)(void))rsa_freectx },
1097     { OSSL_FUNC_SIGNATURE_DUPCTX, (void (*)(void))rsa_dupctx },
1098     { OSSL_FUNC_SIGNATURE_GET_CTX_PARAMS, (void (*)(void))rsa_get_ctx_params },
1099     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_PARAMS,
1100       (void (*)(void))rsa_gettable_ctx_params },
1101     { OSSL_FUNC_SIGNATURE_SET_CTX_PARAMS, (void (*)(void))rsa_set_ctx_params },
1102     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_PARAMS,
1103       (void (*)(void))rsa_settable_ctx_params },
1104     { OSSL_FUNC_SIGNATURE_GET_CTX_MD_PARAMS,
1105       (void (*)(void))rsa_get_ctx_md_params },
1106     { OSSL_FUNC_SIGNATURE_GETTABLE_CTX_MD_PARAMS,
1107       (void (*)(void))rsa_gettable_ctx_md_params },
1108     { OSSL_FUNC_SIGNATURE_SET_CTX_MD_PARAMS,
1109       (void (*)(void))rsa_set_ctx_md_params },
1110     { OSSL_FUNC_SIGNATURE_SETTABLE_CTX_MD_PARAMS,
1111       (void (*)(void))rsa_settable_ctx_md_params },
1112     { 0, NULL }
1113 };