rsa: add ossl_ prefix to internal rsa_ calls.
[openssl.git] / crypto / rsa / rsa_ossl.c
1 /*
2  * Copyright 1995-2020 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 /*
11  * RSA low level APIs are deprecated for public use, but still ok for
12  * internal use.
13  */
14 #include "internal/deprecated.h"
15
16 #include "internal/cryptlib.h"
17 #include "crypto/bn.h"
18 #include "rsa_local.h"
19 #include "internal/constant_time.h"
20
21 static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
22                                   unsigned char *to, RSA *rsa, int padding);
23 static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
24                                    unsigned char *to, RSA *rsa, int padding);
25 static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
26                                   unsigned char *to, RSA *rsa, int padding);
27 static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
28                                    unsigned char *to, RSA *rsa, int padding);
29 static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *i, RSA *rsa,
30                            BN_CTX *ctx);
31 static int rsa_ossl_init(RSA *rsa);
32 static int rsa_ossl_finish(RSA *rsa);
33 static RSA_METHOD rsa_pkcs1_ossl_meth = {
34     "OpenSSL PKCS#1 RSA",
35     rsa_ossl_public_encrypt,
36     rsa_ossl_public_decrypt,     /* signature verification */
37     rsa_ossl_private_encrypt,    /* signing */
38     rsa_ossl_private_decrypt,
39     rsa_ossl_mod_exp,
40     BN_mod_exp_mont,            /* XXX probably we should not use Montgomery
41                                  * if e == 3 */
42     rsa_ossl_init,
43     rsa_ossl_finish,
44     RSA_FLAG_FIPS_METHOD,       /* flags */
45     NULL,
46     0,                          /* rsa_sign */
47     0,                          /* rsa_verify */
48     NULL,                       /* rsa_keygen */
49     NULL                        /* rsa_multi_prime_keygen */
50 };
51
52 static const RSA_METHOD *default_RSA_meth = &rsa_pkcs1_ossl_meth;
53
54 void RSA_set_default_method(const RSA_METHOD *meth)
55 {
56     default_RSA_meth = meth;
57 }
58
59 const RSA_METHOD *RSA_get_default_method(void)
60 {
61     return default_RSA_meth;
62 }
63
64 const RSA_METHOD *RSA_PKCS1_OpenSSL(void)
65 {
66     return &rsa_pkcs1_ossl_meth;
67 }
68
69 const RSA_METHOD *RSA_null_method(void)
70 {
71     return NULL;
72 }
73
74 static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
75                                   unsigned char *to, RSA *rsa, int padding)
76 {
77     BIGNUM *f, *ret;
78     int i, num = 0, r = -1;
79     unsigned char *buf = NULL;
80     BN_CTX *ctx = NULL;
81
82     if (BN_num_bits(rsa->n) > OPENSSL_RSA_MAX_MODULUS_BITS) {
83         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, RSA_R_MODULUS_TOO_LARGE);
84         return -1;
85     }
86
87     if (BN_ucmp(rsa->n, rsa->e) <= 0) {
88         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, RSA_R_BAD_E_VALUE);
89         return -1;
90     }
91
92     /* for large moduli, enforce exponent limit */
93     if (BN_num_bits(rsa->n) > OPENSSL_RSA_SMALL_MODULUS_BITS) {
94         if (BN_num_bits(rsa->e) > OPENSSL_RSA_MAX_PUBEXP_BITS) {
95             RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, RSA_R_BAD_E_VALUE);
96             return -1;
97         }
98     }
99
100     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
101         goto err;
102     BN_CTX_start(ctx);
103     f = BN_CTX_get(ctx);
104     ret = BN_CTX_get(ctx);
105     num = BN_num_bytes(rsa->n);
106     buf = OPENSSL_malloc(num);
107     if (ret == NULL || buf == NULL) {
108         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, ERR_R_MALLOC_FAILURE);
109         goto err;
110     }
111
112     switch (padding) {
113     case RSA_PKCS1_PADDING:
114         i = ossl_rsa_padding_add_PKCS1_type_2_ex(rsa->libctx, buf, num,
115                                                  from, flen);
116         break;
117     case RSA_PKCS1_OAEP_PADDING:
118         i = ossl_rsa_padding_add_PKCS1_OAEP_mgf1_ex(rsa->libctx, buf, num,
119                                                     from, flen, NULL, 0,
120                                                     NULL, NULL);
121         break;
122 #ifndef FIPS_MODULE
123     case RSA_SSLV23_PADDING:
124         i = ossl_rsa_padding_add_SSLv23_ex(rsa->libctx, buf, num, from, flen);
125         break;
126 #endif
127     case RSA_NO_PADDING:
128         i = RSA_padding_add_none(buf, num, from, flen);
129         break;
130     default:
131         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
132         goto err;
133     }
134     if (i <= 0)
135         goto err;
136
137     if (BN_bin2bn(buf, num, f) == NULL)
138         goto err;
139
140     if (BN_ucmp(f, rsa->n) >= 0) {
141         /* usually the padding functions would catch this */
142         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT,
143                RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
144         goto err;
145     }
146
147     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
148         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
149                                     rsa->n, ctx))
150             goto err;
151
152     if (!rsa->meth->bn_mod_exp(ret, f, rsa->e, rsa->n, ctx,
153                                rsa->_method_mod_n))
154         goto err;
155
156     /*
157      * BN_bn2binpad puts in leading 0 bytes if the number is less than
158      * the length of the modulus.
159      */
160     r = BN_bn2binpad(ret, to, num);
161  err:
162     BN_CTX_end(ctx);
163     BN_CTX_free(ctx);
164     OPENSSL_clear_free(buf, num);
165     return r;
166 }
167
168 static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx)
169 {
170     BN_BLINDING *ret;
171
172     CRYPTO_THREAD_write_lock(rsa->lock);
173
174     if (rsa->blinding == NULL) {
175         rsa->blinding = RSA_setup_blinding(rsa, ctx);
176     }
177
178     ret = rsa->blinding;
179     if (ret == NULL)
180         goto err;
181
182     if (BN_BLINDING_is_current_thread(ret)) {
183         /* rsa->blinding is ours! */
184
185         *local = 1;
186     } else {
187         /* resort to rsa->mt_blinding instead */
188
189         /*
190          * instructs rsa_blinding_convert(), rsa_blinding_invert() that the
191          * BN_BLINDING is shared, meaning that accesses require locks, and
192          * that the blinding factor must be stored outside the BN_BLINDING
193          */
194         *local = 0;
195
196         if (rsa->mt_blinding == NULL) {
197             rsa->mt_blinding = RSA_setup_blinding(rsa, ctx);
198         }
199         ret = rsa->mt_blinding;
200     }
201
202  err:
203     CRYPTO_THREAD_unlock(rsa->lock);
204     return ret;
205 }
206
207 static int rsa_blinding_convert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
208                                 BN_CTX *ctx)
209 {
210     if (unblind == NULL) {
211         /*
212          * Local blinding: store the unblinding factor in BN_BLINDING.
213          */
214         return BN_BLINDING_convert_ex(f, NULL, b, ctx);
215     } else {
216         /*
217          * Shared blinding: store the unblinding factor outside BN_BLINDING.
218          */
219         int ret;
220
221         BN_BLINDING_lock(b);
222         ret = BN_BLINDING_convert_ex(f, unblind, b, ctx);
223         BN_BLINDING_unlock(b);
224
225         return ret;
226     }
227 }
228
229 static int rsa_blinding_invert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
230                                BN_CTX *ctx)
231 {
232     /*
233      * For local blinding, unblind is set to NULL, and BN_BLINDING_invert_ex
234      * will use the unblinding factor stored in BN_BLINDING. If BN_BLINDING
235      * is shared between threads, unblind must be non-null:
236      * BN_BLINDING_invert_ex will then use the local unblinding factor, and
237      * will only read the modulus from BN_BLINDING. In both cases it's safe
238      * to access the blinding without a lock.
239      */
240     return BN_BLINDING_invert_ex(f, unblind, b, ctx);
241 }
242
243 /* signing */
244 static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
245                                    unsigned char *to, RSA *rsa, int padding)
246 {
247     BIGNUM *f, *ret, *res;
248     int i, num = 0, r = -1;
249     unsigned char *buf = NULL;
250     BN_CTX *ctx = NULL;
251     int local_blinding = 0;
252     /*
253      * Used only if the blinding structure is shared. A non-NULL unblind
254      * instructs rsa_blinding_convert() and rsa_blinding_invert() to store
255      * the unblinding factor outside the blinding structure.
256      */
257     BIGNUM *unblind = NULL;
258     BN_BLINDING *blinding = NULL;
259
260     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
261         goto err;
262     BN_CTX_start(ctx);
263     f = BN_CTX_get(ctx);
264     ret = BN_CTX_get(ctx);
265     num = BN_num_bytes(rsa->n);
266     buf = OPENSSL_malloc(num);
267     if (ret == NULL || buf == NULL) {
268         RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, ERR_R_MALLOC_FAILURE);
269         goto err;
270     }
271
272     switch (padding) {
273     case RSA_PKCS1_PADDING:
274         i = RSA_padding_add_PKCS1_type_1(buf, num, from, flen);
275         break;
276     case RSA_X931_PADDING:
277         i = RSA_padding_add_X931(buf, num, from, flen);
278         break;
279     case RSA_NO_PADDING:
280         i = RSA_padding_add_none(buf, num, from, flen);
281         break;
282     case RSA_SSLV23_PADDING:
283     default:
284         RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
285         goto err;
286     }
287     if (i <= 0)
288         goto err;
289
290     if (BN_bin2bn(buf, num, f) == NULL)
291         goto err;
292
293     if (BN_ucmp(f, rsa->n) >= 0) {
294         /* usually the padding functions would catch this */
295         RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT,
296                RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
297         goto err;
298     }
299
300     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
301         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
302                                     rsa->n, ctx))
303             goto err;
304
305     if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) {
306         blinding = rsa_get_blinding(rsa, &local_blinding, ctx);
307         if (blinding == NULL) {
308             RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, ERR_R_INTERNAL_ERROR);
309             goto err;
310         }
311     }
312
313     if (blinding != NULL) {
314         if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL)) {
315             RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, ERR_R_MALLOC_FAILURE);
316             goto err;
317         }
318         if (!rsa_blinding_convert(blinding, f, unblind, ctx))
319             goto err;
320     }
321
322     if ((rsa->flags & RSA_FLAG_EXT_PKEY) ||
323         (rsa->version == RSA_ASN1_VERSION_MULTI) ||
324         ((rsa->p != NULL) &&
325          (rsa->q != NULL) &&
326          (rsa->dmp1 != NULL) && (rsa->dmq1 != NULL) && (rsa->iqmp != NULL))) {
327         if (!rsa->meth->rsa_mod_exp(ret, f, rsa, ctx))
328             goto err;
329     } else {
330         BIGNUM *d = BN_new();
331         if (d == NULL) {
332             RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, ERR_R_MALLOC_FAILURE);
333             goto err;
334         }
335         if (rsa->d == NULL) {
336             RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, RSA_R_MISSING_PRIVATE_KEY);
337             BN_free(d);
338             goto err;
339         }
340         BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
341
342         if (!rsa->meth->bn_mod_exp(ret, f, d, rsa->n, ctx,
343                                    rsa->_method_mod_n)) {
344             BN_free(d);
345             goto err;
346         }
347         /* We MUST free d before any further use of rsa->d */
348         BN_free(d);
349     }
350
351     if (blinding)
352         if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
353             goto err;
354
355     if (padding == RSA_X931_PADDING) {
356         if (!BN_sub(f, rsa->n, ret))
357             goto err;
358         if (BN_cmp(ret, f) > 0)
359             res = f;
360         else
361             res = ret;
362     } else {
363         res = ret;
364     }
365
366     /*
367      * BN_bn2binpad puts in leading 0 bytes if the number is less than
368      * the length of the modulus.
369      */
370     r = BN_bn2binpad(res, to, num);
371  err:
372     BN_CTX_end(ctx);
373     BN_CTX_free(ctx);
374     OPENSSL_clear_free(buf, num);
375     return r;
376 }
377
378 static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
379                                    unsigned char *to, RSA *rsa, int padding)
380 {
381     BIGNUM *f, *ret;
382     int j, num = 0, r = -1;
383     unsigned char *buf = NULL;
384     BN_CTX *ctx = NULL;
385     int local_blinding = 0;
386     /*
387      * Used only if the blinding structure is shared. A non-NULL unblind
388      * instructs rsa_blinding_convert() and rsa_blinding_invert() to store
389      * the unblinding factor outside the blinding structure.
390      */
391     BIGNUM *unblind = NULL;
392     BN_BLINDING *blinding = NULL;
393
394     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
395         goto err;
396     BN_CTX_start(ctx);
397     f = BN_CTX_get(ctx);
398     ret = BN_CTX_get(ctx);
399     num = BN_num_bytes(rsa->n);
400     buf = OPENSSL_malloc(num);
401     if (ret == NULL || buf == NULL) {
402         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, ERR_R_MALLOC_FAILURE);
403         goto err;
404     }
405
406     /*
407      * This check was for equality but PGP does evil things and chops off the
408      * top '0' bytes
409      */
410     if (flen > num) {
411         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT,
412                RSA_R_DATA_GREATER_THAN_MOD_LEN);
413         goto err;
414     }
415
416     /* make data into a big number */
417     if (BN_bin2bn(from, (int)flen, f) == NULL)
418         goto err;
419
420     if (BN_ucmp(f, rsa->n) >= 0) {
421         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT,
422                RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
423         goto err;
424     }
425
426     if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) {
427         blinding = rsa_get_blinding(rsa, &local_blinding, ctx);
428         if (blinding == NULL) {
429             RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, ERR_R_INTERNAL_ERROR);
430             goto err;
431         }
432     }
433
434     if (blinding != NULL) {
435         if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL)) {
436             RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, ERR_R_MALLOC_FAILURE);
437             goto err;
438         }
439         if (!rsa_blinding_convert(blinding, f, unblind, ctx))
440             goto err;
441     }
442
443     /* do the decrypt */
444     if ((rsa->flags & RSA_FLAG_EXT_PKEY) ||
445         (rsa->version == RSA_ASN1_VERSION_MULTI) ||
446         ((rsa->p != NULL) &&
447          (rsa->q != NULL) &&
448          (rsa->dmp1 != NULL) && (rsa->dmq1 != NULL) && (rsa->iqmp != NULL))) {
449         if (!rsa->meth->rsa_mod_exp(ret, f, rsa, ctx))
450             goto err;
451     } else {
452         BIGNUM *d = BN_new();
453         if (d == NULL) {
454             RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, ERR_R_MALLOC_FAILURE);
455             goto err;
456         }
457         if (rsa->d == NULL) {
458             RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, RSA_R_MISSING_PRIVATE_KEY);
459             BN_free(d);
460             goto err;
461         }
462         BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
463
464         if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
465             if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
466                                         rsa->n, ctx)) {
467                 BN_free(d);
468                 goto err;
469             }
470         if (!rsa->meth->bn_mod_exp(ret, f, d, rsa->n, ctx,
471                                    rsa->_method_mod_n)) {
472             BN_free(d);
473             goto err;
474         }
475         /* We MUST free d before any further use of rsa->d */
476         BN_free(d);
477     }
478
479     if (blinding)
480         if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
481             goto err;
482
483     j = BN_bn2binpad(ret, buf, num);
484     if (j < 0)
485         goto err;
486
487     switch (padding) {
488     case RSA_PKCS1_PADDING:
489         r = RSA_padding_check_PKCS1_type_2(to, num, buf, j, num);
490         break;
491     case RSA_PKCS1_OAEP_PADDING:
492         r = RSA_padding_check_PKCS1_OAEP(to, num, buf, j, num, NULL, 0);
493         break;
494 #ifndef FIPS_MODULE
495     case RSA_SSLV23_PADDING:
496         r = RSA_padding_check_SSLv23(to, num, buf, j, num);
497         break;
498 #endif
499     case RSA_NO_PADDING:
500         memcpy(to, buf, (r = j));
501         break;
502     default:
503         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
504         goto err;
505     }
506 #ifndef FIPS_MODULE
507     /*
508      * This trick doesn't work in the FIPS provider because libcrypto manages
509      * the error stack. Instead we opt not to put an error on the stack at all
510      * in case of padding failure in the FIPS provider.
511      */
512     RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, RSA_R_PADDING_CHECK_FAILED);
513     err_clear_last_constant_time(1 & ~constant_time_msb(r));
514 #endif
515
516  err:
517     BN_CTX_end(ctx);
518     BN_CTX_free(ctx);
519     OPENSSL_clear_free(buf, num);
520     return r;
521 }
522
523 /* signature verification */
524 static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
525                                   unsigned char *to, RSA *rsa, int padding)
526 {
527     BIGNUM *f, *ret;
528     int i, num = 0, r = -1;
529     unsigned char *buf = NULL;
530     BN_CTX *ctx = NULL;
531
532     if (BN_num_bits(rsa->n) > OPENSSL_RSA_MAX_MODULUS_BITS) {
533         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_MODULUS_TOO_LARGE);
534         return -1;
535     }
536
537     if (BN_ucmp(rsa->n, rsa->e) <= 0) {
538         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_BAD_E_VALUE);
539         return -1;
540     }
541
542     /* for large moduli, enforce exponent limit */
543     if (BN_num_bits(rsa->n) > OPENSSL_RSA_SMALL_MODULUS_BITS) {
544         if (BN_num_bits(rsa->e) > OPENSSL_RSA_MAX_PUBEXP_BITS) {
545             RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_BAD_E_VALUE);
546             return -1;
547         }
548     }
549
550     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
551         goto err;
552     BN_CTX_start(ctx);
553     f = BN_CTX_get(ctx);
554     ret = BN_CTX_get(ctx);
555     num = BN_num_bytes(rsa->n);
556     buf = OPENSSL_malloc(num);
557     if (ret == NULL || buf == NULL) {
558         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, ERR_R_MALLOC_FAILURE);
559         goto err;
560     }
561
562     /*
563      * This check was for equality but PGP does evil things and chops off the
564      * top '0' bytes
565      */
566     if (flen > num) {
567         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_DATA_GREATER_THAN_MOD_LEN);
568         goto err;
569     }
570
571     if (BN_bin2bn(from, flen, f) == NULL)
572         goto err;
573
574     if (BN_ucmp(f, rsa->n) >= 0) {
575         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT,
576                RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
577         goto err;
578     }
579
580     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
581         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
582                                     rsa->n, ctx))
583             goto err;
584
585     if (!rsa->meth->bn_mod_exp(ret, f, rsa->e, rsa->n, ctx,
586                                rsa->_method_mod_n))
587         goto err;
588
589     if ((padding == RSA_X931_PADDING) && ((bn_get_words(ret)[0] & 0xf) != 12))
590         if (!BN_sub(ret, rsa->n, ret))
591             goto err;
592
593     i = BN_bn2binpad(ret, buf, num);
594     if (i < 0)
595         goto err;
596
597     switch (padding) {
598     case RSA_PKCS1_PADDING:
599         r = RSA_padding_check_PKCS1_type_1(to, num, buf, i, num);
600         break;
601     case RSA_X931_PADDING:
602         r = RSA_padding_check_X931(to, num, buf, i, num);
603         break;
604     case RSA_NO_PADDING:
605         memcpy(to, buf, (r = i));
606         break;
607     default:
608         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
609         goto err;
610     }
611     if (r < 0)
612         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_PADDING_CHECK_FAILED);
613
614  err:
615     BN_CTX_end(ctx);
616     BN_CTX_free(ctx);
617     OPENSSL_clear_free(buf, num);
618     return r;
619 }
620
621 static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
622 {
623     BIGNUM *r1, *m1, *vrfy;
624     int ret = 0, smooth = 0;
625 #ifndef FIPS_MODULE
626     BIGNUM *r2, *m[RSA_MAX_PRIME_NUM - 2];
627     int i, ex_primes = 0;
628     RSA_PRIME_INFO *pinfo;
629 #endif
630
631     BN_CTX_start(ctx);
632
633     r1 = BN_CTX_get(ctx);
634 #ifndef FIPS_MODULE
635     r2 = BN_CTX_get(ctx);
636 #endif
637     m1 = BN_CTX_get(ctx);
638     vrfy = BN_CTX_get(ctx);
639     if (vrfy == NULL)
640         goto err;
641
642 #ifndef FIPS_MODULE
643     if (rsa->version == RSA_ASN1_VERSION_MULTI
644         && ((ex_primes = sk_RSA_PRIME_INFO_num(rsa->prime_infos)) <= 0
645              || ex_primes > RSA_MAX_PRIME_NUM - 2))
646         goto err;
647 #endif
648
649     if (rsa->flags & RSA_FLAG_CACHE_PRIVATE) {
650         BIGNUM *factor = BN_new();
651
652         if (factor == NULL)
653             goto err;
654
655         /*
656          * Make sure BN_mod_inverse in Montgomery initialization uses the
657          * BN_FLG_CONSTTIME flag
658          */
659         if (!(BN_with_flags(factor, rsa->p, BN_FLG_CONSTTIME),
660               BN_MONT_CTX_set_locked(&rsa->_method_mod_p, rsa->lock,
661                                      factor, ctx))
662             || !(BN_with_flags(factor, rsa->q, BN_FLG_CONSTTIME),
663                  BN_MONT_CTX_set_locked(&rsa->_method_mod_q, rsa->lock,
664                                         factor, ctx))) {
665             BN_free(factor);
666             goto err;
667         }
668 #ifndef FIPS_MODULE
669         for (i = 0; i < ex_primes; i++) {
670             pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
671             BN_with_flags(factor, pinfo->r, BN_FLG_CONSTTIME);
672             if (!BN_MONT_CTX_set_locked(&pinfo->m, rsa->lock, factor, ctx)) {
673                 BN_free(factor);
674                 goto err;
675             }
676         }
677 #endif
678         /*
679          * We MUST free |factor| before any further use of the prime factors
680          */
681         BN_free(factor);
682
683         smooth = (rsa->meth->bn_mod_exp == BN_mod_exp_mont)
684 #ifndef FIPS_MODULE
685                  && (ex_primes == 0)
686 #endif
687                  && (BN_num_bits(rsa->q) == BN_num_bits(rsa->p));
688     }
689
690     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
691         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
692                                     rsa->n, ctx))
693             goto err;
694
695     if (smooth) {
696         /*
697          * Conversion from Montgomery domain, a.k.a. Montgomery reduction,
698          * accepts values in [0-m*2^w) range. w is m's bit width rounded up
699          * to limb width. So that at the very least if |I| is fully reduced,
700          * i.e. less than p*q, we can count on from-to round to perform
701          * below modulo operations on |I|. Unlike BN_mod it's constant time.
702          */
703         if (/* m1 = I moq q */
704             !bn_from_mont_fixed_top(m1, I, rsa->_method_mod_q, ctx)
705             || !bn_to_mont_fixed_top(m1, m1, rsa->_method_mod_q, ctx)
706             /* m1 = m1^dmq1 mod q */
707             || !BN_mod_exp_mont_consttime(m1, m1, rsa->dmq1, rsa->q, ctx,
708                                           rsa->_method_mod_q)
709             /* r1 = I mod p */
710             || !bn_from_mont_fixed_top(r1, I, rsa->_method_mod_p, ctx)
711             || !bn_to_mont_fixed_top(r1, r1, rsa->_method_mod_p, ctx)
712             /* r1 = r1^dmp1 mod p */
713             || !BN_mod_exp_mont_consttime(r1, r1, rsa->dmp1, rsa->p, ctx,
714                                           rsa->_method_mod_p)
715             /* r1 = (r1 - m1) mod p */
716             /*
717              * bn_mod_sub_fixed_top is not regular modular subtraction,
718              * it can tolerate subtrahend to be larger than modulus, but
719              * not bit-wise wider. This makes up for uncommon q>p case,
720              * when |m1| can be larger than |rsa->p|.
721              */
722             || !bn_mod_sub_fixed_top(r1, r1, m1, rsa->p)
723
724             /* r1 = r1 * iqmp mod p */
725             || !bn_to_mont_fixed_top(r1, r1, rsa->_method_mod_p, ctx)
726             || !bn_mul_mont_fixed_top(r1, r1, rsa->iqmp, rsa->_method_mod_p,
727                                       ctx)
728             /* r0 = r1 * q + m1 */
729             || !bn_mul_fixed_top(r0, r1, rsa->q, ctx)
730             || !bn_mod_add_fixed_top(r0, r0, m1, rsa->n))
731             goto err;
732
733         goto tail;
734     }
735
736     /* compute I mod q */
737     {
738         BIGNUM *c = BN_new();
739         if (c == NULL)
740             goto err;
741         BN_with_flags(c, I, BN_FLG_CONSTTIME);
742
743         if (!BN_mod(r1, c, rsa->q, ctx)) {
744             BN_free(c);
745             goto err;
746         }
747
748         {
749             BIGNUM *dmq1 = BN_new();
750             if (dmq1 == NULL) {
751                 BN_free(c);
752                 goto err;
753             }
754             BN_with_flags(dmq1, rsa->dmq1, BN_FLG_CONSTTIME);
755
756             /* compute r1^dmq1 mod q */
757             if (!rsa->meth->bn_mod_exp(m1, r1, dmq1, rsa->q, ctx,
758                                        rsa->_method_mod_q)) {
759                 BN_free(c);
760                 BN_free(dmq1);
761                 goto err;
762             }
763             /* We MUST free dmq1 before any further use of rsa->dmq1 */
764             BN_free(dmq1);
765         }
766
767         /* compute I mod p */
768         if (!BN_mod(r1, c, rsa->p, ctx)) {
769             BN_free(c);
770             goto err;
771         }
772         /* We MUST free c before any further use of I */
773         BN_free(c);
774     }
775
776     {
777         BIGNUM *dmp1 = BN_new();
778         if (dmp1 == NULL)
779             goto err;
780         BN_with_flags(dmp1, rsa->dmp1, BN_FLG_CONSTTIME);
781
782         /* compute r1^dmp1 mod p */
783         if (!rsa->meth->bn_mod_exp(r0, r1, dmp1, rsa->p, ctx,
784                                    rsa->_method_mod_p)) {
785             BN_free(dmp1);
786             goto err;
787         }
788         /* We MUST free dmp1 before any further use of rsa->dmp1 */
789         BN_free(dmp1);
790     }
791
792 #ifndef FIPS_MODULE
793     /*
794      * calculate m_i in multi-prime case
795      *
796      * TODO:
797      * 1. squash the following two loops and calculate |m_i| there.
798      * 2. remove cc and reuse |c|.
799      * 3. remove |dmq1| and |dmp1| in previous block and use |di|.
800      *
801      * If these things are done, the code will be more readable.
802      */
803     if (ex_primes > 0) {
804         BIGNUM *di = BN_new(), *cc = BN_new();
805
806         if (cc == NULL || di == NULL) {
807             BN_free(cc);
808             BN_free(di);
809             goto err;
810         }
811
812         for (i = 0; i < ex_primes; i++) {
813             /* prepare m_i */
814             if ((m[i] = BN_CTX_get(ctx)) == NULL) {
815                 BN_free(cc);
816                 BN_free(di);
817                 goto err;
818             }
819
820             pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
821
822             /* prepare c and d_i */
823             BN_with_flags(cc, I, BN_FLG_CONSTTIME);
824             BN_with_flags(di, pinfo->d, BN_FLG_CONSTTIME);
825
826             if (!BN_mod(r1, cc, pinfo->r, ctx)) {
827                 BN_free(cc);
828                 BN_free(di);
829                 goto err;
830             }
831             /* compute r1 ^ d_i mod r_i */
832             if (!rsa->meth->bn_mod_exp(m[i], r1, di, pinfo->r, ctx, pinfo->m)) {
833                 BN_free(cc);
834                 BN_free(di);
835                 goto err;
836             }
837         }
838
839         BN_free(cc);
840         BN_free(di);
841     }
842 #endif
843
844     if (!BN_sub(r0, r0, m1))
845         goto err;
846     /*
847      * This will help stop the size of r0 increasing, which does affect the
848      * multiply if it optimised for a power of 2 size
849      */
850     if (BN_is_negative(r0))
851         if (!BN_add(r0, r0, rsa->p))
852             goto err;
853
854     if (!BN_mul(r1, r0, rsa->iqmp, ctx))
855         goto err;
856
857     {
858         BIGNUM *pr1 = BN_new();
859         if (pr1 == NULL)
860             goto err;
861         BN_with_flags(pr1, r1, BN_FLG_CONSTTIME);
862
863         if (!BN_mod(r0, pr1, rsa->p, ctx)) {
864             BN_free(pr1);
865             goto err;
866         }
867         /* We MUST free pr1 before any further use of r1 */
868         BN_free(pr1);
869     }
870
871     /*
872      * If p < q it is occasionally possible for the correction of adding 'p'
873      * if r0 is negative above to leave the result still negative. This can
874      * break the private key operations: the following second correction
875      * should *always* correct this rare occurrence. This will *never* happen
876      * with OpenSSL generated keys because they ensure p > q [steve]
877      */
878     if (BN_is_negative(r0))
879         if (!BN_add(r0, r0, rsa->p))
880             goto err;
881     if (!BN_mul(r1, r0, rsa->q, ctx))
882         goto err;
883     if (!BN_add(r0, r1, m1))
884         goto err;
885
886 #ifndef FIPS_MODULE
887     /* add m_i to m in multi-prime case */
888     if (ex_primes > 0) {
889         BIGNUM *pr2 = BN_new();
890
891         if (pr2 == NULL)
892             goto err;
893
894         for (i = 0; i < ex_primes; i++) {
895             pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
896             if (!BN_sub(r1, m[i], r0)) {
897                 BN_free(pr2);
898                 goto err;
899             }
900
901             if (!BN_mul(r2, r1, pinfo->t, ctx)) {
902                 BN_free(pr2);
903                 goto err;
904             }
905
906             BN_with_flags(pr2, r2, BN_FLG_CONSTTIME);
907
908             if (!BN_mod(r1, pr2, pinfo->r, ctx)) {
909                 BN_free(pr2);
910                 goto err;
911             }
912
913             if (BN_is_negative(r1))
914                 if (!BN_add(r1, r1, pinfo->r)) {
915                     BN_free(pr2);
916                     goto err;
917                 }
918             if (!BN_mul(r1, r1, pinfo->pp, ctx)) {
919                 BN_free(pr2);
920                 goto err;
921             }
922             if (!BN_add(r0, r0, r1)) {
923                 BN_free(pr2);
924                 goto err;
925             }
926         }
927         BN_free(pr2);
928     }
929 #endif
930
931  tail:
932     if (rsa->e && rsa->n) {
933         if (rsa->meth->bn_mod_exp == BN_mod_exp_mont) {
934             if (!BN_mod_exp_mont(vrfy, r0, rsa->e, rsa->n, ctx,
935                                  rsa->_method_mod_n))
936                 goto err;
937         } else {
938             bn_correct_top(r0);
939             if (!rsa->meth->bn_mod_exp(vrfy, r0, rsa->e, rsa->n, ctx,
940                                        rsa->_method_mod_n))
941                 goto err;
942         }
943         /*
944          * If 'I' was greater than (or equal to) rsa->n, the operation will
945          * be equivalent to using 'I mod n'. However, the result of the
946          * verify will *always* be less than 'n' so we don't check for
947          * absolute equality, just congruency.
948          */
949         if (!BN_sub(vrfy, vrfy, I))
950             goto err;
951         if (BN_is_zero(vrfy)) {
952             bn_correct_top(r0);
953             ret = 1;
954             goto err;   /* not actually error */
955         }
956         if (!BN_mod(vrfy, vrfy, rsa->n, ctx))
957             goto err;
958         if (BN_is_negative(vrfy))
959             if (!BN_add(vrfy, vrfy, rsa->n))
960                 goto err;
961         if (!BN_is_zero(vrfy)) {
962             /*
963              * 'I' and 'vrfy' aren't congruent mod n. Don't leak
964              * miscalculated CRT output, just do a raw (slower) mod_exp and
965              * return that instead.
966              */
967
968             BIGNUM *d = BN_new();
969             if (d == NULL)
970                 goto err;
971             BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
972
973             if (!rsa->meth->bn_mod_exp(r0, I, d, rsa->n, ctx,
974                                        rsa->_method_mod_n)) {
975                 BN_free(d);
976                 goto err;
977             }
978             /* We MUST free d before any further use of rsa->d */
979             BN_free(d);
980         }
981     }
982     /*
983      * It's unfortunate that we have to bn_correct_top(r0). What hopefully
984      * saves the day is that correction is highly unlike, and private key
985      * operations are customarily performed on blinded message. Which means
986      * that attacker won't observe correlation with chosen plaintext.
987      * Secondly, remaining code would still handle it in same computational
988      * time and even conceal memory access pattern around corrected top.
989      */
990     bn_correct_top(r0);
991     ret = 1;
992  err:
993     BN_CTX_end(ctx);
994     return ret;
995 }
996
997 static int rsa_ossl_init(RSA *rsa)
998 {
999     rsa->flags |= RSA_FLAG_CACHE_PUBLIC | RSA_FLAG_CACHE_PRIVATE;
1000     return 1;
1001 }
1002
1003 static int rsa_ossl_finish(RSA *rsa)
1004 {
1005 #ifndef FIPS_MODULE
1006     int i;
1007     RSA_PRIME_INFO *pinfo;
1008
1009     for (i = 0; i < sk_RSA_PRIME_INFO_num(rsa->prime_infos); i++) {
1010         pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
1011         BN_MONT_CTX_free(pinfo->m);
1012     }
1013 #endif
1014
1015     BN_MONT_CTX_free(rsa->_method_mod_n);
1016     BN_MONT_CTX_free(rsa->_method_mod_p);
1017     BN_MONT_CTX_free(rsa->_method_mod_q);
1018     return 1;
1019 }