Make the RSA ASYM_CIPHER implementation available inside the FIPS module
[openssl.git] / crypto / rsa / rsa_ossl.c
1 /*
2  * Copyright 1995-2018 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include "internal/cryptlib.h"
11 #include "crypto/bn.h"
12 #include "rsa_local.h"
13 #include "internal/constant_time.h"
14
15 static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
16                                   unsigned char *to, RSA *rsa, int padding);
17 static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
18                                    unsigned char *to, RSA *rsa, int padding);
19 static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
20                                   unsigned char *to, RSA *rsa, int padding);
21 static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
22                                    unsigned char *to, RSA *rsa, int padding);
23 static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *i, RSA *rsa,
24                            BN_CTX *ctx);
25 static int rsa_ossl_init(RSA *rsa);
26 static int rsa_ossl_finish(RSA *rsa);
27 static RSA_METHOD rsa_pkcs1_ossl_meth = {
28     "OpenSSL PKCS#1 RSA",
29     rsa_ossl_public_encrypt,
30     rsa_ossl_public_decrypt,     /* signature verification */
31     rsa_ossl_private_encrypt,    /* signing */
32     rsa_ossl_private_decrypt,
33     rsa_ossl_mod_exp,
34     BN_mod_exp_mont,            /* XXX probably we should not use Montgomery
35                                  * if e == 3 */
36     rsa_ossl_init,
37     rsa_ossl_finish,
38     RSA_FLAG_FIPS_METHOD,       /* flags */
39     NULL,
40     0,                          /* rsa_sign */
41     0,                          /* rsa_verify */
42     NULL,                       /* rsa_keygen */
43     NULL                        /* rsa_multi_prime_keygen */
44 };
45
46 static const RSA_METHOD *default_RSA_meth = &rsa_pkcs1_ossl_meth;
47
48 void RSA_set_default_method(const RSA_METHOD *meth)
49 {
50     default_RSA_meth = meth;
51 }
52
53 const RSA_METHOD *RSA_get_default_method(void)
54 {
55     return default_RSA_meth;
56 }
57
58 const RSA_METHOD *RSA_PKCS1_OpenSSL(void)
59 {
60     return &rsa_pkcs1_ossl_meth;
61 }
62
63 const RSA_METHOD *RSA_null_method(void)
64 {
65     return NULL;
66 }
67
68 static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
69                                   unsigned char *to, RSA *rsa, int padding)
70 {
71     BIGNUM *f, *ret;
72     int i, num = 0, r = -1;
73     unsigned char *buf = NULL;
74     BN_CTX *ctx = NULL;
75
76     if (BN_num_bits(rsa->n) > OPENSSL_RSA_MAX_MODULUS_BITS) {
77         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, RSA_R_MODULUS_TOO_LARGE);
78         return -1;
79     }
80
81     if (BN_ucmp(rsa->n, rsa->e) <= 0) {
82         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, RSA_R_BAD_E_VALUE);
83         return -1;
84     }
85
86     /* for large moduli, enforce exponent limit */
87     if (BN_num_bits(rsa->n) > OPENSSL_RSA_SMALL_MODULUS_BITS) {
88         if (BN_num_bits(rsa->e) > OPENSSL_RSA_MAX_PUBEXP_BITS) {
89             RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, RSA_R_BAD_E_VALUE);
90             return -1;
91         }
92     }
93
94     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
95         goto err;
96     BN_CTX_start(ctx);
97     f = BN_CTX_get(ctx);
98     ret = BN_CTX_get(ctx);
99     num = BN_num_bytes(rsa->n);
100     buf = OPENSSL_malloc(num);
101     if (ret == NULL || buf == NULL) {
102         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, ERR_R_MALLOC_FAILURE);
103         goto err;
104     }
105
106     switch (padding) {
107     case RSA_PKCS1_PADDING:
108         i = RSA_padding_add_PKCS1_type_2(buf, num, from, flen);
109         break;
110     case RSA_PKCS1_OAEP_PADDING:
111         i = RSA_padding_add_PKCS1_OAEP(buf, num, from, flen, NULL, 0);
112         break;
113 #ifndef FIPS_MODE
114     case RSA_SSLV23_PADDING:
115         i = RSA_padding_add_SSLv23(buf, num, from, flen);
116         break;
117 #endif
118     case RSA_NO_PADDING:
119         i = RSA_padding_add_none(buf, num, from, flen);
120         break;
121     default:
122         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
123         goto err;
124     }
125     if (i <= 0)
126         goto err;
127
128     if (BN_bin2bn(buf, num, f) == NULL)
129         goto err;
130
131     if (BN_ucmp(f, rsa->n) >= 0) {
132         /* usually the padding functions would catch this */
133         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT,
134                RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
135         goto err;
136     }
137
138     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
139         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
140                                     rsa->n, ctx))
141             goto err;
142
143     if (!rsa->meth->bn_mod_exp(ret, f, rsa->e, rsa->n, ctx,
144                                rsa->_method_mod_n))
145         goto err;
146
147     /*
148      * BN_bn2binpad puts in leading 0 bytes if the number is less than
149      * the length of the modulus.
150      */
151     r = BN_bn2binpad(ret, to, num);
152  err:
153     BN_CTX_end(ctx);
154     BN_CTX_free(ctx);
155     OPENSSL_clear_free(buf, num);
156     return r;
157 }
158
159 static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx)
160 {
161     BN_BLINDING *ret;
162
163     CRYPTO_THREAD_write_lock(rsa->lock);
164
165     if (rsa->blinding == NULL) {
166         rsa->blinding = RSA_setup_blinding(rsa, ctx);
167     }
168
169     ret = rsa->blinding;
170     if (ret == NULL)
171         goto err;
172
173     if (BN_BLINDING_is_current_thread(ret)) {
174         /* rsa->blinding is ours! */
175
176         *local = 1;
177     } else {
178         /* resort to rsa->mt_blinding instead */
179
180         /*
181          * instructs rsa_blinding_convert(), rsa_blinding_invert() that the
182          * BN_BLINDING is shared, meaning that accesses require locks, and
183          * that the blinding factor must be stored outside the BN_BLINDING
184          */
185         *local = 0;
186
187         if (rsa->mt_blinding == NULL) {
188             rsa->mt_blinding = RSA_setup_blinding(rsa, ctx);
189         }
190         ret = rsa->mt_blinding;
191     }
192
193  err:
194     CRYPTO_THREAD_unlock(rsa->lock);
195     return ret;
196 }
197
198 static int rsa_blinding_convert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
199                                 BN_CTX *ctx)
200 {
201     if (unblind == NULL) {
202         /*
203          * Local blinding: store the unblinding factor in BN_BLINDING.
204          */
205         return BN_BLINDING_convert_ex(f, NULL, b, ctx);
206     } else {
207         /*
208          * Shared blinding: store the unblinding factor outside BN_BLINDING.
209          */
210         int ret;
211
212         BN_BLINDING_lock(b);
213         ret = BN_BLINDING_convert_ex(f, unblind, b, ctx);
214         BN_BLINDING_unlock(b);
215
216         return ret;
217     }
218 }
219
220 static int rsa_blinding_invert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
221                                BN_CTX *ctx)
222 {
223     /*
224      * For local blinding, unblind is set to NULL, and BN_BLINDING_invert_ex
225      * will use the unblinding factor stored in BN_BLINDING. If BN_BLINDING
226      * is shared between threads, unblind must be non-null:
227      * BN_BLINDING_invert_ex will then use the local unblinding factor, and
228      * will only read the modulus from BN_BLINDING. In both cases it's safe
229      * to access the blinding without a lock.
230      */
231     return BN_BLINDING_invert_ex(f, unblind, b, ctx);
232 }
233
234 /* signing */
235 static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
236                                    unsigned char *to, RSA *rsa, int padding)
237 {
238     BIGNUM *f, *ret, *res;
239     int i, num = 0, r = -1;
240     unsigned char *buf = NULL;
241     BN_CTX *ctx = NULL;
242     int local_blinding = 0;
243     /*
244      * Used only if the blinding structure is shared. A non-NULL unblind
245      * instructs rsa_blinding_convert() and rsa_blinding_invert() to store
246      * the unblinding factor outside the blinding structure.
247      */
248     BIGNUM *unblind = NULL;
249     BN_BLINDING *blinding = NULL;
250
251     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
252         goto err;
253     BN_CTX_start(ctx);
254     f = BN_CTX_get(ctx);
255     ret = BN_CTX_get(ctx);
256     num = BN_num_bytes(rsa->n);
257     buf = OPENSSL_malloc(num);
258     if (ret == NULL || buf == NULL) {
259         RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, ERR_R_MALLOC_FAILURE);
260         goto err;
261     }
262
263     switch (padding) {
264     case RSA_PKCS1_PADDING:
265         i = RSA_padding_add_PKCS1_type_1(buf, num, from, flen);
266         break;
267     case RSA_X931_PADDING:
268         i = RSA_padding_add_X931(buf, num, from, flen);
269         break;
270     case RSA_NO_PADDING:
271         i = RSA_padding_add_none(buf, num, from, flen);
272         break;
273     case RSA_SSLV23_PADDING:
274     default:
275         RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
276         goto err;
277     }
278     if (i <= 0)
279         goto err;
280
281     if (BN_bin2bn(buf, num, f) == NULL)
282         goto err;
283
284     if (BN_ucmp(f, rsa->n) >= 0) {
285         /* usually the padding functions would catch this */
286         RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT,
287                RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
288         goto err;
289     }
290
291     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
292         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
293                                     rsa->n, ctx))
294             goto err;
295
296     if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) {
297         blinding = rsa_get_blinding(rsa, &local_blinding, ctx);
298         if (blinding == NULL) {
299             RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, ERR_R_INTERNAL_ERROR);
300             goto err;
301         }
302     }
303
304     if (blinding != NULL) {
305         if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL)) {
306             RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, ERR_R_MALLOC_FAILURE);
307             goto err;
308         }
309         if (!rsa_blinding_convert(blinding, f, unblind, ctx))
310             goto err;
311     }
312
313     if ((rsa->flags & RSA_FLAG_EXT_PKEY) ||
314         (rsa->version == RSA_ASN1_VERSION_MULTI) ||
315         ((rsa->p != NULL) &&
316          (rsa->q != NULL) &&
317          (rsa->dmp1 != NULL) && (rsa->dmq1 != NULL) && (rsa->iqmp != NULL))) {
318         if (!rsa->meth->rsa_mod_exp(ret, f, rsa, ctx))
319             goto err;
320     } else {
321         BIGNUM *d = BN_new();
322         if (d == NULL) {
323             RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, ERR_R_MALLOC_FAILURE);
324             goto err;
325         }
326         if (rsa->d == NULL) {
327             RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, RSA_R_MISSING_PRIVATE_KEY);
328             BN_free(d);
329             goto err;
330         }
331         BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
332
333         if (!rsa->meth->bn_mod_exp(ret, f, d, rsa->n, ctx,
334                                    rsa->_method_mod_n)) {
335             BN_free(d);
336             goto err;
337         }
338         /* We MUST free d before any further use of rsa->d */
339         BN_free(d);
340     }
341
342     if (blinding)
343         if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
344             goto err;
345
346     if (padding == RSA_X931_PADDING) {
347         if (!BN_sub(f, rsa->n, ret))
348             goto err;
349         if (BN_cmp(ret, f) > 0)
350             res = f;
351         else
352             res = ret;
353     } else {
354         res = ret;
355     }
356
357     /*
358      * BN_bn2binpad puts in leading 0 bytes if the number is less than
359      * the length of the modulus.
360      */
361     r = BN_bn2binpad(res, to, num);
362  err:
363     BN_CTX_end(ctx);
364     BN_CTX_free(ctx);
365     OPENSSL_clear_free(buf, num);
366     return r;
367 }
368
369 static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
370                                    unsigned char *to, RSA *rsa, int padding)
371 {
372     BIGNUM *f, *ret;
373     int j, num = 0, r = -1;
374     unsigned char *buf = NULL;
375     BN_CTX *ctx = NULL;
376     int local_blinding = 0;
377     /*
378      * Used only if the blinding structure is shared. A non-NULL unblind
379      * instructs rsa_blinding_convert() and rsa_blinding_invert() to store
380      * the unblinding factor outside the blinding structure.
381      */
382     BIGNUM *unblind = NULL;
383     BN_BLINDING *blinding = NULL;
384
385     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
386         goto err;
387     BN_CTX_start(ctx);
388     f = BN_CTX_get(ctx);
389     ret = BN_CTX_get(ctx);
390     num = BN_num_bytes(rsa->n);
391     buf = OPENSSL_malloc(num);
392     if (ret == NULL || buf == NULL) {
393         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, ERR_R_MALLOC_FAILURE);
394         goto err;
395     }
396
397     /*
398      * This check was for equality but PGP does evil things and chops off the
399      * top '0' bytes
400      */
401     if (flen > num) {
402         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT,
403                RSA_R_DATA_GREATER_THAN_MOD_LEN);
404         goto err;
405     }
406
407     /* make data into a big number */
408     if (BN_bin2bn(from, (int)flen, f) == NULL)
409         goto err;
410
411     if (BN_ucmp(f, rsa->n) >= 0) {
412         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT,
413                RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
414         goto err;
415     }
416
417     if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) {
418         blinding = rsa_get_blinding(rsa, &local_blinding, ctx);
419         if (blinding == NULL) {
420             RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, ERR_R_INTERNAL_ERROR);
421             goto err;
422         }
423     }
424
425     if (blinding != NULL) {
426         if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL)) {
427             RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, ERR_R_MALLOC_FAILURE);
428             goto err;
429         }
430         if (!rsa_blinding_convert(blinding, f, unblind, ctx))
431             goto err;
432     }
433
434     /* do the decrypt */
435     if ((rsa->flags & RSA_FLAG_EXT_PKEY) ||
436         (rsa->version == RSA_ASN1_VERSION_MULTI) ||
437         ((rsa->p != NULL) &&
438          (rsa->q != NULL) &&
439          (rsa->dmp1 != NULL) && (rsa->dmq1 != NULL) && (rsa->iqmp != NULL))) {
440         if (!rsa->meth->rsa_mod_exp(ret, f, rsa, ctx))
441             goto err;
442     } else {
443         BIGNUM *d = BN_new();
444         if (d == NULL) {
445             RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, ERR_R_MALLOC_FAILURE);
446             goto err;
447         }
448         if (rsa->d == NULL) {
449             RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, RSA_R_MISSING_PRIVATE_KEY);
450             BN_free(d);
451             goto err;
452         }
453         BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
454
455         if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
456             if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
457                                         rsa->n, ctx)) {
458                 BN_free(d);
459                 goto err;
460             }
461         if (!rsa->meth->bn_mod_exp(ret, f, d, rsa->n, ctx,
462                                    rsa->_method_mod_n)) {
463             BN_free(d);
464             goto err;
465         }
466         /* We MUST free d before any further use of rsa->d */
467         BN_free(d);
468     }
469
470     if (blinding)
471         if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
472             goto err;
473
474     j = BN_bn2binpad(ret, buf, num);
475     if (j < 0)
476         goto err;
477
478     switch (padding) {
479     case RSA_PKCS1_PADDING:
480         r = RSA_padding_check_PKCS1_type_2(to, num, buf, j, num);
481         break;
482     case RSA_PKCS1_OAEP_PADDING:
483         r = RSA_padding_check_PKCS1_OAEP(to, num, buf, j, num, NULL, 0);
484         break;
485 #ifndef FIPS_MODE
486     case RSA_SSLV23_PADDING:
487         r = RSA_padding_check_SSLv23(to, num, buf, j, num);
488         break;
489 #endif
490     case RSA_NO_PADDING:
491         memcpy(to, buf, (r = j));
492         break;
493     default:
494         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
495         goto err;
496     }
497 #ifndef FIPS_MODE
498     /*
499      * This trick doesn't work in the FIPS provider because libcrypto manages
500      * the error stack. Instead we opt not to put an error on the stack at all
501      * in case of padding failure in the FIPS provider.
502      */
503     RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, RSA_R_PADDING_CHECK_FAILED);
504     err_clear_last_constant_time(1 & ~constant_time_msb(r));
505 #endif
506
507  err:
508     BN_CTX_end(ctx);
509     BN_CTX_free(ctx);
510     OPENSSL_clear_free(buf, num);
511     return r;
512 }
513
514 /* signature verification */
515 static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
516                                   unsigned char *to, RSA *rsa, int padding)
517 {
518     BIGNUM *f, *ret;
519     int i, num = 0, r = -1;
520     unsigned char *buf = NULL;
521     BN_CTX *ctx = NULL;
522
523     if (BN_num_bits(rsa->n) > OPENSSL_RSA_MAX_MODULUS_BITS) {
524         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_MODULUS_TOO_LARGE);
525         return -1;
526     }
527
528     if (BN_ucmp(rsa->n, rsa->e) <= 0) {
529         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_BAD_E_VALUE);
530         return -1;
531     }
532
533     /* for large moduli, enforce exponent limit */
534     if (BN_num_bits(rsa->n) > OPENSSL_RSA_SMALL_MODULUS_BITS) {
535         if (BN_num_bits(rsa->e) > OPENSSL_RSA_MAX_PUBEXP_BITS) {
536             RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_BAD_E_VALUE);
537             return -1;
538         }
539     }
540
541     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
542         goto err;
543     BN_CTX_start(ctx);
544     f = BN_CTX_get(ctx);
545     ret = BN_CTX_get(ctx);
546     num = BN_num_bytes(rsa->n);
547     buf = OPENSSL_malloc(num);
548     if (ret == NULL || buf == NULL) {
549         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, ERR_R_MALLOC_FAILURE);
550         goto err;
551     }
552
553     /*
554      * This check was for equality but PGP does evil things and chops off the
555      * top '0' bytes
556      */
557     if (flen > num) {
558         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_DATA_GREATER_THAN_MOD_LEN);
559         goto err;
560     }
561
562     if (BN_bin2bn(from, flen, f) == NULL)
563         goto err;
564
565     if (BN_ucmp(f, rsa->n) >= 0) {
566         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT,
567                RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
568         goto err;
569     }
570
571     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
572         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
573                                     rsa->n, ctx))
574             goto err;
575
576     if (!rsa->meth->bn_mod_exp(ret, f, rsa->e, rsa->n, ctx,
577                                rsa->_method_mod_n))
578         goto err;
579
580     if ((padding == RSA_X931_PADDING) && ((bn_get_words(ret)[0] & 0xf) != 12))
581         if (!BN_sub(ret, rsa->n, ret))
582             goto err;
583
584     i = BN_bn2binpad(ret, buf, num);
585     if (i < 0)
586         goto err;
587
588     switch (padding) {
589     case RSA_PKCS1_PADDING:
590         r = RSA_padding_check_PKCS1_type_1(to, num, buf, i, num);
591         break;
592     case RSA_X931_PADDING:
593         r = RSA_padding_check_X931(to, num, buf, i, num);
594         break;
595     case RSA_NO_PADDING:
596         memcpy(to, buf, (r = i));
597         break;
598     default:
599         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
600         goto err;
601     }
602     if (r < 0)
603         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_PADDING_CHECK_FAILED);
604
605  err:
606     BN_CTX_end(ctx);
607     BN_CTX_free(ctx);
608     OPENSSL_clear_free(buf, num);
609     return r;
610 }
611
612 static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
613 {
614     BIGNUM *r1, *m1, *vrfy;
615     int ret = 0, smooth = 0;
616 #ifndef FIPS_MODE
617     BIGNUM *r2, *m[RSA_MAX_PRIME_NUM - 2];
618     int i, ex_primes = 0;
619     RSA_PRIME_INFO *pinfo;
620 #endif
621
622     BN_CTX_start(ctx);
623
624     r1 = BN_CTX_get(ctx);
625 #ifndef FIPS_MODE
626     r2 = BN_CTX_get(ctx);
627 #endif
628     m1 = BN_CTX_get(ctx);
629     vrfy = BN_CTX_get(ctx);
630     if (vrfy == NULL)
631         goto err;
632
633 #ifndef FIPS_MODE
634     if (rsa->version == RSA_ASN1_VERSION_MULTI
635         && ((ex_primes = sk_RSA_PRIME_INFO_num(rsa->prime_infos)) <= 0
636              || ex_primes > RSA_MAX_PRIME_NUM - 2))
637         goto err;
638 #endif
639
640     if (rsa->flags & RSA_FLAG_CACHE_PRIVATE) {
641         BIGNUM *factor = BN_new();
642
643         if (factor == NULL)
644             goto err;
645
646         /*
647          * Make sure BN_mod_inverse in Montgomery initialization uses the
648          * BN_FLG_CONSTTIME flag
649          */
650         if (!(BN_with_flags(factor, rsa->p, BN_FLG_CONSTTIME),
651               BN_MONT_CTX_set_locked(&rsa->_method_mod_p, rsa->lock,
652                                      factor, ctx))
653             || !(BN_with_flags(factor, rsa->q, BN_FLG_CONSTTIME),
654                  BN_MONT_CTX_set_locked(&rsa->_method_mod_q, rsa->lock,
655                                         factor, ctx))) {
656             BN_free(factor);
657             goto err;
658         }
659 #ifndef FIPS_MODE
660         for (i = 0; i < ex_primes; i++) {
661             pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
662             BN_with_flags(factor, pinfo->r, BN_FLG_CONSTTIME);
663             if (!BN_MONT_CTX_set_locked(&pinfo->m, rsa->lock, factor, ctx)) {
664                 BN_free(factor);
665                 goto err;
666             }
667         }
668 #endif
669         /*
670          * We MUST free |factor| before any further use of the prime factors
671          */
672         BN_free(factor);
673
674         smooth = (rsa->meth->bn_mod_exp == BN_mod_exp_mont)
675 #ifndef FIPS_MODE
676                  && (ex_primes == 0)
677 #endif
678                  && (BN_num_bits(rsa->q) == BN_num_bits(rsa->p));
679     }
680
681     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
682         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
683                                     rsa->n, ctx))
684             goto err;
685
686     if (smooth) {
687         /*
688          * Conversion from Montgomery domain, a.k.a. Montgomery reduction,
689          * accepts values in [0-m*2^w) range. w is m's bit width rounded up
690          * to limb width. So that at the very least if |I| is fully reduced,
691          * i.e. less than p*q, we can count on from-to round to perform
692          * below modulo operations on |I|. Unlike BN_mod it's constant time.
693          */
694         if (/* m1 = I moq q */
695             !bn_from_mont_fixed_top(m1, I, rsa->_method_mod_q, ctx)
696             || !bn_to_mont_fixed_top(m1, m1, rsa->_method_mod_q, ctx)
697             /* m1 = m1^dmq1 mod q */
698             || !BN_mod_exp_mont_consttime(m1, m1, rsa->dmq1, rsa->q, ctx,
699                                           rsa->_method_mod_q)
700             /* r1 = I mod p */
701             || !bn_from_mont_fixed_top(r1, I, rsa->_method_mod_p, ctx)
702             || !bn_to_mont_fixed_top(r1, r1, rsa->_method_mod_p, ctx)
703             /* r1 = r1^dmp1 mod p */
704             || !BN_mod_exp_mont_consttime(r1, r1, rsa->dmp1, rsa->p, ctx,
705                                           rsa->_method_mod_p)
706             /* r1 = (r1 - m1) mod p */
707             /*
708              * bn_mod_sub_fixed_top is not regular modular subtraction,
709              * it can tolerate subtrahend to be larger than modulus, but
710              * not bit-wise wider. This makes up for uncommon q>p case,
711              * when |m1| can be larger than |rsa->p|.
712              */
713             || !bn_mod_sub_fixed_top(r1, r1, m1, rsa->p)
714
715             /* r1 = r1 * iqmp mod p */
716             || !bn_to_mont_fixed_top(r1, r1, rsa->_method_mod_p, ctx)
717             || !bn_mul_mont_fixed_top(r1, r1, rsa->iqmp, rsa->_method_mod_p,
718                                       ctx)
719             /* r0 = r1 * q + m1 */
720             || !bn_mul_fixed_top(r0, r1, rsa->q, ctx)
721             || !bn_mod_add_fixed_top(r0, r0, m1, rsa->n))
722             goto err;
723
724         goto tail;
725     }
726
727     /* compute I mod q */
728     {
729         BIGNUM *c = BN_new();
730         if (c == NULL)
731             goto err;
732         BN_with_flags(c, I, BN_FLG_CONSTTIME);
733
734         if (!BN_mod(r1, c, rsa->q, ctx)) {
735             BN_free(c);
736             goto err;
737         }
738
739         {
740             BIGNUM *dmq1 = BN_new();
741             if (dmq1 == NULL) {
742                 BN_free(c);
743                 goto err;
744             }
745             BN_with_flags(dmq1, rsa->dmq1, BN_FLG_CONSTTIME);
746
747             /* compute r1^dmq1 mod q */
748             if (!rsa->meth->bn_mod_exp(m1, r1, dmq1, rsa->q, ctx,
749                                        rsa->_method_mod_q)) {
750                 BN_free(c);
751                 BN_free(dmq1);
752                 goto err;
753             }
754             /* We MUST free dmq1 before any further use of rsa->dmq1 */
755             BN_free(dmq1);
756         }
757
758         /* compute I mod p */
759         if (!BN_mod(r1, c, rsa->p, ctx)) {
760             BN_free(c);
761             goto err;
762         }
763         /* We MUST free c before any further use of I */
764         BN_free(c);
765     }
766
767     {
768         BIGNUM *dmp1 = BN_new();
769         if (dmp1 == NULL)
770             goto err;
771         BN_with_flags(dmp1, rsa->dmp1, BN_FLG_CONSTTIME);
772
773         /* compute r1^dmp1 mod p */
774         if (!rsa->meth->bn_mod_exp(r0, r1, dmp1, rsa->p, ctx,
775                                    rsa->_method_mod_p)) {
776             BN_free(dmp1);
777             goto err;
778         }
779         /* We MUST free dmp1 before any further use of rsa->dmp1 */
780         BN_free(dmp1);
781     }
782
783 #ifndef FIPS_MODE
784     /*
785      * calculate m_i in multi-prime case
786      *
787      * TODO:
788      * 1. squash the following two loops and calculate |m_i| there.
789      * 2. remove cc and reuse |c|.
790      * 3. remove |dmq1| and |dmp1| in previous block and use |di|.
791      *
792      * If these things are done, the code will be more readable.
793      */
794     if (ex_primes > 0) {
795         BIGNUM *di = BN_new(), *cc = BN_new();
796
797         if (cc == NULL || di == NULL) {
798             BN_free(cc);
799             BN_free(di);
800             goto err;
801         }
802
803         for (i = 0; i < ex_primes; i++) {
804             /* prepare m_i */
805             if ((m[i] = BN_CTX_get(ctx)) == NULL) {
806                 BN_free(cc);
807                 BN_free(di);
808                 goto err;
809             }
810
811             pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
812
813             /* prepare c and d_i */
814             BN_with_flags(cc, I, BN_FLG_CONSTTIME);
815             BN_with_flags(di, pinfo->d, BN_FLG_CONSTTIME);
816
817             if (!BN_mod(r1, cc, pinfo->r, ctx)) {
818                 BN_free(cc);
819                 BN_free(di);
820                 goto err;
821             }
822             /* compute r1 ^ d_i mod r_i */
823             if (!rsa->meth->bn_mod_exp(m[i], r1, di, pinfo->r, ctx, pinfo->m)) {
824                 BN_free(cc);
825                 BN_free(di);
826                 goto err;
827             }
828         }
829
830         BN_free(cc);
831         BN_free(di);
832     }
833 #endif
834
835     if (!BN_sub(r0, r0, m1))
836         goto err;
837     /*
838      * This will help stop the size of r0 increasing, which does affect the
839      * multiply if it optimised for a power of 2 size
840      */
841     if (BN_is_negative(r0))
842         if (!BN_add(r0, r0, rsa->p))
843             goto err;
844
845     if (!BN_mul(r1, r0, rsa->iqmp, ctx))
846         goto err;
847
848     {
849         BIGNUM *pr1 = BN_new();
850         if (pr1 == NULL)
851             goto err;
852         BN_with_flags(pr1, r1, BN_FLG_CONSTTIME);
853
854         if (!BN_mod(r0, pr1, rsa->p, ctx)) {
855             BN_free(pr1);
856             goto err;
857         }
858         /* We MUST free pr1 before any further use of r1 */
859         BN_free(pr1);
860     }
861
862     /*
863      * If p < q it is occasionally possible for the correction of adding 'p'
864      * if r0 is negative above to leave the result still negative. This can
865      * break the private key operations: the following second correction
866      * should *always* correct this rare occurrence. This will *never* happen
867      * with OpenSSL generated keys because they ensure p > q [steve]
868      */
869     if (BN_is_negative(r0))
870         if (!BN_add(r0, r0, rsa->p))
871             goto err;
872     if (!BN_mul(r1, r0, rsa->q, ctx))
873         goto err;
874     if (!BN_add(r0, r1, m1))
875         goto err;
876
877 #ifndef FIPS_MODE
878     /* add m_i to m in multi-prime case */
879     if (ex_primes > 0) {
880         BIGNUM *pr2 = BN_new();
881
882         if (pr2 == NULL)
883             goto err;
884
885         for (i = 0; i < ex_primes; i++) {
886             pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
887             if (!BN_sub(r1, m[i], r0)) {
888                 BN_free(pr2);
889                 goto err;
890             }
891
892             if (!BN_mul(r2, r1, pinfo->t, ctx)) {
893                 BN_free(pr2);
894                 goto err;
895             }
896
897             BN_with_flags(pr2, r2, BN_FLG_CONSTTIME);
898
899             if (!BN_mod(r1, pr2, pinfo->r, ctx)) {
900                 BN_free(pr2);
901                 goto err;
902             }
903
904             if (BN_is_negative(r1))
905                 if (!BN_add(r1, r1, pinfo->r)) {
906                     BN_free(pr2);
907                     goto err;
908                 }
909             if (!BN_mul(r1, r1, pinfo->pp, ctx)) {
910                 BN_free(pr2);
911                 goto err;
912             }
913             if (!BN_add(r0, r0, r1)) {
914                 BN_free(pr2);
915                 goto err;
916             }
917         }
918         BN_free(pr2);
919     }
920 #endif
921
922  tail:
923     if (rsa->e && rsa->n) {
924         if (rsa->meth->bn_mod_exp == BN_mod_exp_mont) {
925             if (!BN_mod_exp_mont(vrfy, r0, rsa->e, rsa->n, ctx,
926                                  rsa->_method_mod_n))
927                 goto err;
928         } else {
929             bn_correct_top(r0);
930             if (!rsa->meth->bn_mod_exp(vrfy, r0, rsa->e, rsa->n, ctx,
931                                        rsa->_method_mod_n))
932                 goto err;
933         }
934         /*
935          * If 'I' was greater than (or equal to) rsa->n, the operation will
936          * be equivalent to using 'I mod n'. However, the result of the
937          * verify will *always* be less than 'n' so we don't check for
938          * absolute equality, just congruency.
939          */
940         if (!BN_sub(vrfy, vrfy, I))
941             goto err;
942         if (BN_is_zero(vrfy)) {
943             bn_correct_top(r0);
944             ret = 1;
945             goto err;   /* not actually error */
946         }
947         if (!BN_mod(vrfy, vrfy, rsa->n, ctx))
948             goto err;
949         if (BN_is_negative(vrfy))
950             if (!BN_add(vrfy, vrfy, rsa->n))
951                 goto err;
952         if (!BN_is_zero(vrfy)) {
953             /*
954              * 'I' and 'vrfy' aren't congruent mod n. Don't leak
955              * miscalculated CRT output, just do a raw (slower) mod_exp and
956              * return that instead.
957              */
958
959             BIGNUM *d = BN_new();
960             if (d == NULL)
961                 goto err;
962             BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
963
964             if (!rsa->meth->bn_mod_exp(r0, I, d, rsa->n, ctx,
965                                        rsa->_method_mod_n)) {
966                 BN_free(d);
967                 goto err;
968             }
969             /* We MUST free d before any further use of rsa->d */
970             BN_free(d);
971         }
972     }
973     /*
974      * It's unfortunate that we have to bn_correct_top(r0). What hopefully
975      * saves the day is that correction is highly unlike, and private key
976      * operations are customarily performed on blinded message. Which means
977      * that attacker won't observe correlation with chosen plaintext.
978      * Secondly, remaining code would still handle it in same computational
979      * time and even conceal memory access pattern around corrected top.
980      */
981     bn_correct_top(r0);
982     ret = 1;
983  err:
984     BN_CTX_end(ctx);
985     return ret;
986 }
987
988 static int rsa_ossl_init(RSA *rsa)
989 {
990     rsa->flags |= RSA_FLAG_CACHE_PUBLIC | RSA_FLAG_CACHE_PRIVATE;
991     return 1;
992 }
993
994 static int rsa_ossl_finish(RSA *rsa)
995 {
996 #ifndef FIPS_MODE
997     int i;
998     RSA_PRIME_INFO *pinfo;
999
1000     for (i = 0; i < sk_RSA_PRIME_INFO_num(rsa->prime_infos); i++) {
1001         pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
1002         BN_MONT_CTX_free(pinfo->m);
1003     }
1004 #endif
1005
1006     BN_MONT_CTX_free(rsa->_method_mod_n);
1007     BN_MONT_CTX_free(rsa->_method_mod_p);
1008     BN_MONT_CTX_free(rsa->_method_mod_q);
1009     return 1;
1010 }