Deprecate the low level RSA functions.
[openssl.git] / crypto / rsa / rsa_ossl.c
1 /*
2  * Copyright 1995-2018 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 /*
11  * RSA low level APIs are deprecated for public use, but still ok for
12  * internal use.
13  */
14 #include "internal/deprecated.h"
15
16 #include "internal/cryptlib.h"
17 #include "crypto/bn.h"
18 #include "rsa_local.h"
19 #include "internal/constant_time.h"
20
21 static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
22                                   unsigned char *to, RSA *rsa, int padding);
23 static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
24                                    unsigned char *to, RSA *rsa, int padding);
25 static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
26                                   unsigned char *to, RSA *rsa, int padding);
27 static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
28                                    unsigned char *to, RSA *rsa, int padding);
29 static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *i, RSA *rsa,
30                            BN_CTX *ctx);
31 static int rsa_ossl_init(RSA *rsa);
32 static int rsa_ossl_finish(RSA *rsa);
33 static RSA_METHOD rsa_pkcs1_ossl_meth = {
34     "OpenSSL PKCS#1 RSA",
35     rsa_ossl_public_encrypt,
36     rsa_ossl_public_decrypt,     /* signature verification */
37     rsa_ossl_private_encrypt,    /* signing */
38     rsa_ossl_private_decrypt,
39     rsa_ossl_mod_exp,
40     BN_mod_exp_mont,            /* XXX probably we should not use Montgomery
41                                  * if e == 3 */
42     rsa_ossl_init,
43     rsa_ossl_finish,
44     RSA_FLAG_FIPS_METHOD,       /* flags */
45     NULL,
46     0,                          /* rsa_sign */
47     0,                          /* rsa_verify */
48     NULL,                       /* rsa_keygen */
49     NULL                        /* rsa_multi_prime_keygen */
50 };
51
52 static const RSA_METHOD *default_RSA_meth = &rsa_pkcs1_ossl_meth;
53
54 void RSA_set_default_method(const RSA_METHOD *meth)
55 {
56     default_RSA_meth = meth;
57 }
58
59 const RSA_METHOD *RSA_get_default_method(void)
60 {
61     return default_RSA_meth;
62 }
63
64 const RSA_METHOD *RSA_PKCS1_OpenSSL(void)
65 {
66     return &rsa_pkcs1_ossl_meth;
67 }
68
69 const RSA_METHOD *RSA_null_method(void)
70 {
71     return NULL;
72 }
73
74 static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
75                                   unsigned char *to, RSA *rsa, int padding)
76 {
77     BIGNUM *f, *ret;
78     int i, num = 0, r = -1;
79     unsigned char *buf = NULL;
80     BN_CTX *ctx = NULL;
81
82     if (BN_num_bits(rsa->n) > OPENSSL_RSA_MAX_MODULUS_BITS) {
83         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, RSA_R_MODULUS_TOO_LARGE);
84         return -1;
85     }
86
87     if (BN_ucmp(rsa->n, rsa->e) <= 0) {
88         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, RSA_R_BAD_E_VALUE);
89         return -1;
90     }
91
92     /* for large moduli, enforce exponent limit */
93     if (BN_num_bits(rsa->n) > OPENSSL_RSA_SMALL_MODULUS_BITS) {
94         if (BN_num_bits(rsa->e) > OPENSSL_RSA_MAX_PUBEXP_BITS) {
95             RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, RSA_R_BAD_E_VALUE);
96             return -1;
97         }
98     }
99
100     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
101         goto err;
102     BN_CTX_start(ctx);
103     f = BN_CTX_get(ctx);
104     ret = BN_CTX_get(ctx);
105     num = BN_num_bytes(rsa->n);
106     buf = OPENSSL_malloc(num);
107     if (ret == NULL || buf == NULL) {
108         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, ERR_R_MALLOC_FAILURE);
109         goto err;
110     }
111
112     switch (padding) {
113     case RSA_PKCS1_PADDING:
114         i = RSA_padding_add_PKCS1_type_2(buf, num, from, flen);
115         break;
116     case RSA_PKCS1_OAEP_PADDING:
117         i = RSA_padding_add_PKCS1_OAEP(buf, num, from, flen, NULL, 0);
118         break;
119 #ifndef FIPS_MODE
120     case RSA_SSLV23_PADDING:
121         i = RSA_padding_add_SSLv23(buf, num, from, flen);
122         break;
123 #endif
124     case RSA_NO_PADDING:
125         i = RSA_padding_add_none(buf, num, from, flen);
126         break;
127     default:
128         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
129         goto err;
130     }
131     if (i <= 0)
132         goto err;
133
134     if (BN_bin2bn(buf, num, f) == NULL)
135         goto err;
136
137     if (BN_ucmp(f, rsa->n) >= 0) {
138         /* usually the padding functions would catch this */
139         RSAerr(RSA_F_RSA_OSSL_PUBLIC_ENCRYPT,
140                RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
141         goto err;
142     }
143
144     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
145         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
146                                     rsa->n, ctx))
147             goto err;
148
149     if (!rsa->meth->bn_mod_exp(ret, f, rsa->e, rsa->n, ctx,
150                                rsa->_method_mod_n))
151         goto err;
152
153     /*
154      * BN_bn2binpad puts in leading 0 bytes if the number is less than
155      * the length of the modulus.
156      */
157     r = BN_bn2binpad(ret, to, num);
158  err:
159     BN_CTX_end(ctx);
160     BN_CTX_free(ctx);
161     OPENSSL_clear_free(buf, num);
162     return r;
163 }
164
165 static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx)
166 {
167     BN_BLINDING *ret;
168
169     CRYPTO_THREAD_write_lock(rsa->lock);
170
171     if (rsa->blinding == NULL) {
172         rsa->blinding = RSA_setup_blinding(rsa, ctx);
173     }
174
175     ret = rsa->blinding;
176     if (ret == NULL)
177         goto err;
178
179     if (BN_BLINDING_is_current_thread(ret)) {
180         /* rsa->blinding is ours! */
181
182         *local = 1;
183     } else {
184         /* resort to rsa->mt_blinding instead */
185
186         /*
187          * instructs rsa_blinding_convert(), rsa_blinding_invert() that the
188          * BN_BLINDING is shared, meaning that accesses require locks, and
189          * that the blinding factor must be stored outside the BN_BLINDING
190          */
191         *local = 0;
192
193         if (rsa->mt_blinding == NULL) {
194             rsa->mt_blinding = RSA_setup_blinding(rsa, ctx);
195         }
196         ret = rsa->mt_blinding;
197     }
198
199  err:
200     CRYPTO_THREAD_unlock(rsa->lock);
201     return ret;
202 }
203
204 static int rsa_blinding_convert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
205                                 BN_CTX *ctx)
206 {
207     if (unblind == NULL) {
208         /*
209          * Local blinding: store the unblinding factor in BN_BLINDING.
210          */
211         return BN_BLINDING_convert_ex(f, NULL, b, ctx);
212     } else {
213         /*
214          * Shared blinding: store the unblinding factor outside BN_BLINDING.
215          */
216         int ret;
217
218         BN_BLINDING_lock(b);
219         ret = BN_BLINDING_convert_ex(f, unblind, b, ctx);
220         BN_BLINDING_unlock(b);
221
222         return ret;
223     }
224 }
225
226 static int rsa_blinding_invert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
227                                BN_CTX *ctx)
228 {
229     /*
230      * For local blinding, unblind is set to NULL, and BN_BLINDING_invert_ex
231      * will use the unblinding factor stored in BN_BLINDING. If BN_BLINDING
232      * is shared between threads, unblind must be non-null:
233      * BN_BLINDING_invert_ex will then use the local unblinding factor, and
234      * will only read the modulus from BN_BLINDING. In both cases it's safe
235      * to access the blinding without a lock.
236      */
237     return BN_BLINDING_invert_ex(f, unblind, b, ctx);
238 }
239
240 /* signing */
241 static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
242                                    unsigned char *to, RSA *rsa, int padding)
243 {
244     BIGNUM *f, *ret, *res;
245     int i, num = 0, r = -1;
246     unsigned char *buf = NULL;
247     BN_CTX *ctx = NULL;
248     int local_blinding = 0;
249     /*
250      * Used only if the blinding structure is shared. A non-NULL unblind
251      * instructs rsa_blinding_convert() and rsa_blinding_invert() to store
252      * the unblinding factor outside the blinding structure.
253      */
254     BIGNUM *unblind = NULL;
255     BN_BLINDING *blinding = NULL;
256
257     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
258         goto err;
259     BN_CTX_start(ctx);
260     f = BN_CTX_get(ctx);
261     ret = BN_CTX_get(ctx);
262     num = BN_num_bytes(rsa->n);
263     buf = OPENSSL_malloc(num);
264     if (ret == NULL || buf == NULL) {
265         RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, ERR_R_MALLOC_FAILURE);
266         goto err;
267     }
268
269     switch (padding) {
270     case RSA_PKCS1_PADDING:
271         i = RSA_padding_add_PKCS1_type_1(buf, num, from, flen);
272         break;
273     case RSA_X931_PADDING:
274         i = RSA_padding_add_X931(buf, num, from, flen);
275         break;
276     case RSA_NO_PADDING:
277         i = RSA_padding_add_none(buf, num, from, flen);
278         break;
279     case RSA_SSLV23_PADDING:
280     default:
281         RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
282         goto err;
283     }
284     if (i <= 0)
285         goto err;
286
287     if (BN_bin2bn(buf, num, f) == NULL)
288         goto err;
289
290     if (BN_ucmp(f, rsa->n) >= 0) {
291         /* usually the padding functions would catch this */
292         RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT,
293                RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
294         goto err;
295     }
296
297     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
298         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
299                                     rsa->n, ctx))
300             goto err;
301
302     if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) {
303         blinding = rsa_get_blinding(rsa, &local_blinding, ctx);
304         if (blinding == NULL) {
305             RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, ERR_R_INTERNAL_ERROR);
306             goto err;
307         }
308     }
309
310     if (blinding != NULL) {
311         if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL)) {
312             RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, ERR_R_MALLOC_FAILURE);
313             goto err;
314         }
315         if (!rsa_blinding_convert(blinding, f, unblind, ctx))
316             goto err;
317     }
318
319     if ((rsa->flags & RSA_FLAG_EXT_PKEY) ||
320         (rsa->version == RSA_ASN1_VERSION_MULTI) ||
321         ((rsa->p != NULL) &&
322          (rsa->q != NULL) &&
323          (rsa->dmp1 != NULL) && (rsa->dmq1 != NULL) && (rsa->iqmp != NULL))) {
324         if (!rsa->meth->rsa_mod_exp(ret, f, rsa, ctx))
325             goto err;
326     } else {
327         BIGNUM *d = BN_new();
328         if (d == NULL) {
329             RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, ERR_R_MALLOC_FAILURE);
330             goto err;
331         }
332         if (rsa->d == NULL) {
333             RSAerr(RSA_F_RSA_OSSL_PRIVATE_ENCRYPT, RSA_R_MISSING_PRIVATE_KEY);
334             BN_free(d);
335             goto err;
336         }
337         BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
338
339         if (!rsa->meth->bn_mod_exp(ret, f, d, rsa->n, ctx,
340                                    rsa->_method_mod_n)) {
341             BN_free(d);
342             goto err;
343         }
344         /* We MUST free d before any further use of rsa->d */
345         BN_free(d);
346     }
347
348     if (blinding)
349         if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
350             goto err;
351
352     if (padding == RSA_X931_PADDING) {
353         if (!BN_sub(f, rsa->n, ret))
354             goto err;
355         if (BN_cmp(ret, f) > 0)
356             res = f;
357         else
358             res = ret;
359     } else {
360         res = ret;
361     }
362
363     /*
364      * BN_bn2binpad puts in leading 0 bytes if the number is less than
365      * the length of the modulus.
366      */
367     r = BN_bn2binpad(res, to, num);
368  err:
369     BN_CTX_end(ctx);
370     BN_CTX_free(ctx);
371     OPENSSL_clear_free(buf, num);
372     return r;
373 }
374
375 static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
376                                    unsigned char *to, RSA *rsa, int padding)
377 {
378     BIGNUM *f, *ret;
379     int j, num = 0, r = -1;
380     unsigned char *buf = NULL;
381     BN_CTX *ctx = NULL;
382     int local_blinding = 0;
383     /*
384      * Used only if the blinding structure is shared. A non-NULL unblind
385      * instructs rsa_blinding_convert() and rsa_blinding_invert() to store
386      * the unblinding factor outside the blinding structure.
387      */
388     BIGNUM *unblind = NULL;
389     BN_BLINDING *blinding = NULL;
390
391     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
392         goto err;
393     BN_CTX_start(ctx);
394     f = BN_CTX_get(ctx);
395     ret = BN_CTX_get(ctx);
396     num = BN_num_bytes(rsa->n);
397     buf = OPENSSL_malloc(num);
398     if (ret == NULL || buf == NULL) {
399         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, ERR_R_MALLOC_FAILURE);
400         goto err;
401     }
402
403     /*
404      * This check was for equality but PGP does evil things and chops off the
405      * top '0' bytes
406      */
407     if (flen > num) {
408         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT,
409                RSA_R_DATA_GREATER_THAN_MOD_LEN);
410         goto err;
411     }
412
413     /* make data into a big number */
414     if (BN_bin2bn(from, (int)flen, f) == NULL)
415         goto err;
416
417     if (BN_ucmp(f, rsa->n) >= 0) {
418         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT,
419                RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
420         goto err;
421     }
422
423     if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) {
424         blinding = rsa_get_blinding(rsa, &local_blinding, ctx);
425         if (blinding == NULL) {
426             RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, ERR_R_INTERNAL_ERROR);
427             goto err;
428         }
429     }
430
431     if (blinding != NULL) {
432         if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL)) {
433             RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, ERR_R_MALLOC_FAILURE);
434             goto err;
435         }
436         if (!rsa_blinding_convert(blinding, f, unblind, ctx))
437             goto err;
438     }
439
440     /* do the decrypt */
441     if ((rsa->flags & RSA_FLAG_EXT_PKEY) ||
442         (rsa->version == RSA_ASN1_VERSION_MULTI) ||
443         ((rsa->p != NULL) &&
444          (rsa->q != NULL) &&
445          (rsa->dmp1 != NULL) && (rsa->dmq1 != NULL) && (rsa->iqmp != NULL))) {
446         if (!rsa->meth->rsa_mod_exp(ret, f, rsa, ctx))
447             goto err;
448     } else {
449         BIGNUM *d = BN_new();
450         if (d == NULL) {
451             RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, ERR_R_MALLOC_FAILURE);
452             goto err;
453         }
454         if (rsa->d == NULL) {
455             RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, RSA_R_MISSING_PRIVATE_KEY);
456             BN_free(d);
457             goto err;
458         }
459         BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
460
461         if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
462             if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
463                                         rsa->n, ctx)) {
464                 BN_free(d);
465                 goto err;
466             }
467         if (!rsa->meth->bn_mod_exp(ret, f, d, rsa->n, ctx,
468                                    rsa->_method_mod_n)) {
469             BN_free(d);
470             goto err;
471         }
472         /* We MUST free d before any further use of rsa->d */
473         BN_free(d);
474     }
475
476     if (blinding)
477         if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
478             goto err;
479
480     j = BN_bn2binpad(ret, buf, num);
481     if (j < 0)
482         goto err;
483
484     switch (padding) {
485     case RSA_PKCS1_PADDING:
486         r = RSA_padding_check_PKCS1_type_2(to, num, buf, j, num);
487         break;
488     case RSA_PKCS1_OAEP_PADDING:
489         r = RSA_padding_check_PKCS1_OAEP(to, num, buf, j, num, NULL, 0);
490         break;
491 #ifndef FIPS_MODE
492     case RSA_SSLV23_PADDING:
493         r = RSA_padding_check_SSLv23(to, num, buf, j, num);
494         break;
495 #endif
496     case RSA_NO_PADDING:
497         memcpy(to, buf, (r = j));
498         break;
499     default:
500         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
501         goto err;
502     }
503 #ifndef FIPS_MODE
504     /*
505      * This trick doesn't work in the FIPS provider because libcrypto manages
506      * the error stack. Instead we opt not to put an error on the stack at all
507      * in case of padding failure in the FIPS provider.
508      */
509     RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, RSA_R_PADDING_CHECK_FAILED);
510     err_clear_last_constant_time(1 & ~constant_time_msb(r));
511 #endif
512
513  err:
514     BN_CTX_end(ctx);
515     BN_CTX_free(ctx);
516     OPENSSL_clear_free(buf, num);
517     return r;
518 }
519
520 /* signature verification */
521 static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
522                                   unsigned char *to, RSA *rsa, int padding)
523 {
524     BIGNUM *f, *ret;
525     int i, num = 0, r = -1;
526     unsigned char *buf = NULL;
527     BN_CTX *ctx = NULL;
528
529     if (BN_num_bits(rsa->n) > OPENSSL_RSA_MAX_MODULUS_BITS) {
530         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_MODULUS_TOO_LARGE);
531         return -1;
532     }
533
534     if (BN_ucmp(rsa->n, rsa->e) <= 0) {
535         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_BAD_E_VALUE);
536         return -1;
537     }
538
539     /* for large moduli, enforce exponent limit */
540     if (BN_num_bits(rsa->n) > OPENSSL_RSA_SMALL_MODULUS_BITS) {
541         if (BN_num_bits(rsa->e) > OPENSSL_RSA_MAX_PUBEXP_BITS) {
542             RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_BAD_E_VALUE);
543             return -1;
544         }
545     }
546
547     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
548         goto err;
549     BN_CTX_start(ctx);
550     f = BN_CTX_get(ctx);
551     ret = BN_CTX_get(ctx);
552     num = BN_num_bytes(rsa->n);
553     buf = OPENSSL_malloc(num);
554     if (ret == NULL || buf == NULL) {
555         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, ERR_R_MALLOC_FAILURE);
556         goto err;
557     }
558
559     /*
560      * This check was for equality but PGP does evil things and chops off the
561      * top '0' bytes
562      */
563     if (flen > num) {
564         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_DATA_GREATER_THAN_MOD_LEN);
565         goto err;
566     }
567
568     if (BN_bin2bn(from, flen, f) == NULL)
569         goto err;
570
571     if (BN_ucmp(f, rsa->n) >= 0) {
572         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT,
573                RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
574         goto err;
575     }
576
577     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
578         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
579                                     rsa->n, ctx))
580             goto err;
581
582     if (!rsa->meth->bn_mod_exp(ret, f, rsa->e, rsa->n, ctx,
583                                rsa->_method_mod_n))
584         goto err;
585
586     if ((padding == RSA_X931_PADDING) && ((bn_get_words(ret)[0] & 0xf) != 12))
587         if (!BN_sub(ret, rsa->n, ret))
588             goto err;
589
590     i = BN_bn2binpad(ret, buf, num);
591     if (i < 0)
592         goto err;
593
594     switch (padding) {
595     case RSA_PKCS1_PADDING:
596         r = RSA_padding_check_PKCS1_type_1(to, num, buf, i, num);
597         break;
598     case RSA_X931_PADDING:
599         r = RSA_padding_check_X931(to, num, buf, i, num);
600         break;
601     case RSA_NO_PADDING:
602         memcpy(to, buf, (r = i));
603         break;
604     default:
605         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
606         goto err;
607     }
608     if (r < 0)
609         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_PADDING_CHECK_FAILED);
610
611  err:
612     BN_CTX_end(ctx);
613     BN_CTX_free(ctx);
614     OPENSSL_clear_free(buf, num);
615     return r;
616 }
617
618 static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
619 {
620     BIGNUM *r1, *m1, *vrfy;
621     int ret = 0, smooth = 0;
622 #ifndef FIPS_MODE
623     BIGNUM *r2, *m[RSA_MAX_PRIME_NUM - 2];
624     int i, ex_primes = 0;
625     RSA_PRIME_INFO *pinfo;
626 #endif
627
628     BN_CTX_start(ctx);
629
630     r1 = BN_CTX_get(ctx);
631 #ifndef FIPS_MODE
632     r2 = BN_CTX_get(ctx);
633 #endif
634     m1 = BN_CTX_get(ctx);
635     vrfy = BN_CTX_get(ctx);
636     if (vrfy == NULL)
637         goto err;
638
639 #ifndef FIPS_MODE
640     if (rsa->version == RSA_ASN1_VERSION_MULTI
641         && ((ex_primes = sk_RSA_PRIME_INFO_num(rsa->prime_infos)) <= 0
642              || ex_primes > RSA_MAX_PRIME_NUM - 2))
643         goto err;
644 #endif
645
646     if (rsa->flags & RSA_FLAG_CACHE_PRIVATE) {
647         BIGNUM *factor = BN_new();
648
649         if (factor == NULL)
650             goto err;
651
652         /*
653          * Make sure BN_mod_inverse in Montgomery initialization uses the
654          * BN_FLG_CONSTTIME flag
655          */
656         if (!(BN_with_flags(factor, rsa->p, BN_FLG_CONSTTIME),
657               BN_MONT_CTX_set_locked(&rsa->_method_mod_p, rsa->lock,
658                                      factor, ctx))
659             || !(BN_with_flags(factor, rsa->q, BN_FLG_CONSTTIME),
660                  BN_MONT_CTX_set_locked(&rsa->_method_mod_q, rsa->lock,
661                                         factor, ctx))) {
662             BN_free(factor);
663             goto err;
664         }
665 #ifndef FIPS_MODE
666         for (i = 0; i < ex_primes; i++) {
667             pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
668             BN_with_flags(factor, pinfo->r, BN_FLG_CONSTTIME);
669             if (!BN_MONT_CTX_set_locked(&pinfo->m, rsa->lock, factor, ctx)) {
670                 BN_free(factor);
671                 goto err;
672             }
673         }
674 #endif
675         /*
676          * We MUST free |factor| before any further use of the prime factors
677          */
678         BN_free(factor);
679
680         smooth = (rsa->meth->bn_mod_exp == BN_mod_exp_mont)
681 #ifndef FIPS_MODE
682                  && (ex_primes == 0)
683 #endif
684                  && (BN_num_bits(rsa->q) == BN_num_bits(rsa->p));
685     }
686
687     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
688         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
689                                     rsa->n, ctx))
690             goto err;
691
692     if (smooth) {
693         /*
694          * Conversion from Montgomery domain, a.k.a. Montgomery reduction,
695          * accepts values in [0-m*2^w) range. w is m's bit width rounded up
696          * to limb width. So that at the very least if |I| is fully reduced,
697          * i.e. less than p*q, we can count on from-to round to perform
698          * below modulo operations on |I|. Unlike BN_mod it's constant time.
699          */
700         if (/* m1 = I moq q */
701             !bn_from_mont_fixed_top(m1, I, rsa->_method_mod_q, ctx)
702             || !bn_to_mont_fixed_top(m1, m1, rsa->_method_mod_q, ctx)
703             /* m1 = m1^dmq1 mod q */
704             || !BN_mod_exp_mont_consttime(m1, m1, rsa->dmq1, rsa->q, ctx,
705                                           rsa->_method_mod_q)
706             /* r1 = I mod p */
707             || !bn_from_mont_fixed_top(r1, I, rsa->_method_mod_p, ctx)
708             || !bn_to_mont_fixed_top(r1, r1, rsa->_method_mod_p, ctx)
709             /* r1 = r1^dmp1 mod p */
710             || !BN_mod_exp_mont_consttime(r1, r1, rsa->dmp1, rsa->p, ctx,
711                                           rsa->_method_mod_p)
712             /* r1 = (r1 - m1) mod p */
713             /*
714              * bn_mod_sub_fixed_top is not regular modular subtraction,
715              * it can tolerate subtrahend to be larger than modulus, but
716              * not bit-wise wider. This makes up for uncommon q>p case,
717              * when |m1| can be larger than |rsa->p|.
718              */
719             || !bn_mod_sub_fixed_top(r1, r1, m1, rsa->p)
720
721             /* r1 = r1 * iqmp mod p */
722             || !bn_to_mont_fixed_top(r1, r1, rsa->_method_mod_p, ctx)
723             || !bn_mul_mont_fixed_top(r1, r1, rsa->iqmp, rsa->_method_mod_p,
724                                       ctx)
725             /* r0 = r1 * q + m1 */
726             || !bn_mul_fixed_top(r0, r1, rsa->q, ctx)
727             || !bn_mod_add_fixed_top(r0, r0, m1, rsa->n))
728             goto err;
729
730         goto tail;
731     }
732
733     /* compute I mod q */
734     {
735         BIGNUM *c = BN_new();
736         if (c == NULL)
737             goto err;
738         BN_with_flags(c, I, BN_FLG_CONSTTIME);
739
740         if (!BN_mod(r1, c, rsa->q, ctx)) {
741             BN_free(c);
742             goto err;
743         }
744
745         {
746             BIGNUM *dmq1 = BN_new();
747             if (dmq1 == NULL) {
748                 BN_free(c);
749                 goto err;
750             }
751             BN_with_flags(dmq1, rsa->dmq1, BN_FLG_CONSTTIME);
752
753             /* compute r1^dmq1 mod q */
754             if (!rsa->meth->bn_mod_exp(m1, r1, dmq1, rsa->q, ctx,
755                                        rsa->_method_mod_q)) {
756                 BN_free(c);
757                 BN_free(dmq1);
758                 goto err;
759             }
760             /* We MUST free dmq1 before any further use of rsa->dmq1 */
761             BN_free(dmq1);
762         }
763
764         /* compute I mod p */
765         if (!BN_mod(r1, c, rsa->p, ctx)) {
766             BN_free(c);
767             goto err;
768         }
769         /* We MUST free c before any further use of I */
770         BN_free(c);
771     }
772
773     {
774         BIGNUM *dmp1 = BN_new();
775         if (dmp1 == NULL)
776             goto err;
777         BN_with_flags(dmp1, rsa->dmp1, BN_FLG_CONSTTIME);
778
779         /* compute r1^dmp1 mod p */
780         if (!rsa->meth->bn_mod_exp(r0, r1, dmp1, rsa->p, ctx,
781                                    rsa->_method_mod_p)) {
782             BN_free(dmp1);
783             goto err;
784         }
785         /* We MUST free dmp1 before any further use of rsa->dmp1 */
786         BN_free(dmp1);
787     }
788
789 #ifndef FIPS_MODE
790     /*
791      * calculate m_i in multi-prime case
792      *
793      * TODO:
794      * 1. squash the following two loops and calculate |m_i| there.
795      * 2. remove cc and reuse |c|.
796      * 3. remove |dmq1| and |dmp1| in previous block and use |di|.
797      *
798      * If these things are done, the code will be more readable.
799      */
800     if (ex_primes > 0) {
801         BIGNUM *di = BN_new(), *cc = BN_new();
802
803         if (cc == NULL || di == NULL) {
804             BN_free(cc);
805             BN_free(di);
806             goto err;
807         }
808
809         for (i = 0; i < ex_primes; i++) {
810             /* prepare m_i */
811             if ((m[i] = BN_CTX_get(ctx)) == NULL) {
812                 BN_free(cc);
813                 BN_free(di);
814                 goto err;
815             }
816
817             pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
818
819             /* prepare c and d_i */
820             BN_with_flags(cc, I, BN_FLG_CONSTTIME);
821             BN_with_flags(di, pinfo->d, BN_FLG_CONSTTIME);
822
823             if (!BN_mod(r1, cc, pinfo->r, ctx)) {
824                 BN_free(cc);
825                 BN_free(di);
826                 goto err;
827             }
828             /* compute r1 ^ d_i mod r_i */
829             if (!rsa->meth->bn_mod_exp(m[i], r1, di, pinfo->r, ctx, pinfo->m)) {
830                 BN_free(cc);
831                 BN_free(di);
832                 goto err;
833             }
834         }
835
836         BN_free(cc);
837         BN_free(di);
838     }
839 #endif
840
841     if (!BN_sub(r0, r0, m1))
842         goto err;
843     /*
844      * This will help stop the size of r0 increasing, which does affect the
845      * multiply if it optimised for a power of 2 size
846      */
847     if (BN_is_negative(r0))
848         if (!BN_add(r0, r0, rsa->p))
849             goto err;
850
851     if (!BN_mul(r1, r0, rsa->iqmp, ctx))
852         goto err;
853
854     {
855         BIGNUM *pr1 = BN_new();
856         if (pr1 == NULL)
857             goto err;
858         BN_with_flags(pr1, r1, BN_FLG_CONSTTIME);
859
860         if (!BN_mod(r0, pr1, rsa->p, ctx)) {
861             BN_free(pr1);
862             goto err;
863         }
864         /* We MUST free pr1 before any further use of r1 */
865         BN_free(pr1);
866     }
867
868     /*
869      * If p < q it is occasionally possible for the correction of adding 'p'
870      * if r0 is negative above to leave the result still negative. This can
871      * break the private key operations: the following second correction
872      * should *always* correct this rare occurrence. This will *never* happen
873      * with OpenSSL generated keys because they ensure p > q [steve]
874      */
875     if (BN_is_negative(r0))
876         if (!BN_add(r0, r0, rsa->p))
877             goto err;
878     if (!BN_mul(r1, r0, rsa->q, ctx))
879         goto err;
880     if (!BN_add(r0, r1, m1))
881         goto err;
882
883 #ifndef FIPS_MODE
884     /* add m_i to m in multi-prime case */
885     if (ex_primes > 0) {
886         BIGNUM *pr2 = BN_new();
887
888         if (pr2 == NULL)
889             goto err;
890
891         for (i = 0; i < ex_primes; i++) {
892             pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
893             if (!BN_sub(r1, m[i], r0)) {
894                 BN_free(pr2);
895                 goto err;
896             }
897
898             if (!BN_mul(r2, r1, pinfo->t, ctx)) {
899                 BN_free(pr2);
900                 goto err;
901             }
902
903             BN_with_flags(pr2, r2, BN_FLG_CONSTTIME);
904
905             if (!BN_mod(r1, pr2, pinfo->r, ctx)) {
906                 BN_free(pr2);
907                 goto err;
908             }
909
910             if (BN_is_negative(r1))
911                 if (!BN_add(r1, r1, pinfo->r)) {
912                     BN_free(pr2);
913                     goto err;
914                 }
915             if (!BN_mul(r1, r1, pinfo->pp, ctx)) {
916                 BN_free(pr2);
917                 goto err;
918             }
919             if (!BN_add(r0, r0, r1)) {
920                 BN_free(pr2);
921                 goto err;
922             }
923         }
924         BN_free(pr2);
925     }
926 #endif
927
928  tail:
929     if (rsa->e && rsa->n) {
930         if (rsa->meth->bn_mod_exp == BN_mod_exp_mont) {
931             if (!BN_mod_exp_mont(vrfy, r0, rsa->e, rsa->n, ctx,
932                                  rsa->_method_mod_n))
933                 goto err;
934         } else {
935             bn_correct_top(r0);
936             if (!rsa->meth->bn_mod_exp(vrfy, r0, rsa->e, rsa->n, ctx,
937                                        rsa->_method_mod_n))
938                 goto err;
939         }
940         /*
941          * If 'I' was greater than (or equal to) rsa->n, the operation will
942          * be equivalent to using 'I mod n'. However, the result of the
943          * verify will *always* be less than 'n' so we don't check for
944          * absolute equality, just congruency.
945          */
946         if (!BN_sub(vrfy, vrfy, I))
947             goto err;
948         if (BN_is_zero(vrfy)) {
949             bn_correct_top(r0);
950             ret = 1;
951             goto err;   /* not actually error */
952         }
953         if (!BN_mod(vrfy, vrfy, rsa->n, ctx))
954             goto err;
955         if (BN_is_negative(vrfy))
956             if (!BN_add(vrfy, vrfy, rsa->n))
957                 goto err;
958         if (!BN_is_zero(vrfy)) {
959             /*
960              * 'I' and 'vrfy' aren't congruent mod n. Don't leak
961              * miscalculated CRT output, just do a raw (slower) mod_exp and
962              * return that instead.
963              */
964
965             BIGNUM *d = BN_new();
966             if (d == NULL)
967                 goto err;
968             BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
969
970             if (!rsa->meth->bn_mod_exp(r0, I, d, rsa->n, ctx,
971                                        rsa->_method_mod_n)) {
972                 BN_free(d);
973                 goto err;
974             }
975             /* We MUST free d before any further use of rsa->d */
976             BN_free(d);
977         }
978     }
979     /*
980      * It's unfortunate that we have to bn_correct_top(r0). What hopefully
981      * saves the day is that correction is highly unlike, and private key
982      * operations are customarily performed on blinded message. Which means
983      * that attacker won't observe correlation with chosen plaintext.
984      * Secondly, remaining code would still handle it in same computational
985      * time and even conceal memory access pattern around corrected top.
986      */
987     bn_correct_top(r0);
988     ret = 1;
989  err:
990     BN_CTX_end(ctx);
991     return ret;
992 }
993
994 static int rsa_ossl_init(RSA *rsa)
995 {
996     rsa->flags |= RSA_FLAG_CACHE_PUBLIC | RSA_FLAG_CACHE_PRIVATE;
997     return 1;
998 }
999
1000 static int rsa_ossl_finish(RSA *rsa)
1001 {
1002 #ifndef FIPS_MODE
1003     int i;
1004     RSA_PRIME_INFO *pinfo;
1005
1006     for (i = 0; i < sk_RSA_PRIME_INFO_num(rsa->prime_infos); i++) {
1007         pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
1008         BN_MONT_CTX_free(pinfo->m);
1009     }
1010 #endif
1011
1012     BN_MONT_CTX_free(rsa->_method_mod_n);
1013     BN_MONT_CTX_free(rsa->_method_mod_p);
1014     BN_MONT_CTX_free(rsa->_method_mod_q);
1015     return 1;
1016 }