Erase temporary buffer in EVP_PKEY_get_bn_param()
[openssl.git] / crypto / bn / rsa_sup_mul.c
1 /*
2  * Copyright 2022-2023 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <openssl/e_os2.h>
11 #include <stddef.h>
12 #include <sys/types.h>
13 #include <string.h>
14 #include <openssl/bn.h>
15 #include <openssl/err.h>
16 #include <openssl/rsaerr.h>
17 #include "internal/endian.h"
18 #include "internal/numbers.h"
19 #include "internal/constant_time.h"
20 #include "bn_local.h"
21
22 # if BN_BYTES == 8
23 typedef uint64_t limb_t;
24 #  if defined(__SIZEOF_INT128__) && __SIZEOF_INT128__ == 16
25 typedef uint128_t limb2_t;
26 #   define HAVE_LIMB2_T
27 #  endif
28 #  define LIMB_BIT_SIZE 64
29 #  define LIMB_BYTE_SIZE 8
30 # elif BN_BYTES == 4
31 typedef uint32_t limb_t;
32 typedef uint64_t limb2_t;
33 #  define LIMB_BIT_SIZE 32
34 #  define LIMB_BYTE_SIZE 4
35 #  define HAVE_LIMB2_T
36 # else
37 #  error "Not supported"
38 # endif
39
40 /*
41  * For multiplication we're using schoolbook multiplication,
42  * so if we have two numbers, each with 6 "digits" (words)
43  * the multiplication is calculated as follows:
44  *                        A B C D E F
45  *                     x  I J K L M N
46  *                     --------------
47  *                                N*F
48  *                              N*E
49  *                            N*D
50  *                          N*C
51  *                        N*B
52  *                      N*A
53  *                              M*F
54  *                            M*E
55  *                          M*D
56  *                        M*C
57  *                      M*B
58  *                    M*A
59  *                            L*F
60  *                          L*E
61  *                        L*D
62  *                      L*C
63  *                    L*B
64  *                  L*A
65  *                          K*F
66  *                        K*E
67  *                      K*D
68  *                    K*C
69  *                  K*B
70  *                K*A
71  *                        J*F
72  *                      J*E
73  *                    J*D
74  *                  J*C
75  *                J*B
76  *              J*A
77  *                      I*F
78  *                    I*E
79  *                  I*D
80  *                I*C
81  *              I*B
82  *         +  I*A
83  *         ==========================
84  *                        N*B N*D N*F
85  *                    + N*A N*C N*E
86  *                    + M*B M*D M*F
87  *                  + M*A M*C M*E
88  *                  + L*B L*D L*F
89  *                + L*A L*C L*E
90  *                + K*B K*D K*F
91  *              + K*A K*C K*E
92  *              + J*B J*D J*F
93  *            + J*A J*C J*E
94  *            + I*B I*D I*F
95  *          + I*A I*C I*E
96  *
97  *                1+1 1+3 1+5
98  *              1+0 1+2 1+4
99  *              0+1 0+3 0+5
100  *            0+0 0+2 0+4
101  *
102  *            0 1 2 3 4 5 6
103  * which requires n^2 multiplications and 2n full length additions
104  * as we can keep every other result of limb multiplication in two separate
105  * limbs
106  */
107
108 #if defined HAVE_LIMB2_T
109 static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b)
110 {
111     limb2_t t;
112     /*
113      * this is idiomatic code to tell compiler to use the native mul
114      * those three lines will actually compile to single instruction
115      */
116
117     t = (limb2_t)a * b;
118     *hi = t >> LIMB_BIT_SIZE;
119     *lo = (limb_t)t;
120 }
121 #elif (BN_BYTES == 8) && (defined _MSC_VER)
122 # if defined(_M_X64)
123 /*
124  * on x86_64 (x64) we can use the _umul128 intrinsic to get one `mul`
125  * instruction to get both high and low 64 bits of the multiplication.
126  * https://learn.microsoft.com/en-us/cpp/intrinsics/umul128?view=msvc-140
127  */
128 #include <intrin.h>
129 #pragma intrinsic(_umul128)
130 static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b)
131 {
132     *lo = _umul128(a, b, hi);
133 }
134 # elif defined(_M_ARM64) || defined (_M_IA64)
135 /*
136  * We can't use the __umulh() on x86_64 as then msvc generates two `mul`
137  * instructions; so use this more portable intrinsic on platforms that
138  * don't support _umul128 (like aarch64 (ARM64) or ia64)
139  * https://learn.microsoft.com/en-us/cpp/intrinsics/umulh?view=msvc-140
140  */
141 #include <intrin.h>
142 static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b)
143 {
144     *lo = a * b;
145     *hi = __umulh(a, b);
146 }
147 # else
148 # error Only x64, ARM64 and IA64 supported.
149 # endif /* defined(_M_X64) */
150 #else
151 /*
152  * if the compiler doesn't have either a 128bit data type nor a "return
153  * high 64 bits of multiplication"
154  */
155 static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b)
156 {
157     limb_t a_low = (limb_t)(uint32_t)a;
158     limb_t a_hi = a >> 32;
159     limb_t b_low = (limb_t)(uint32_t)b;
160     limb_t b_hi = b >> 32;
161
162     limb_t p0 = a_low * b_low;
163     limb_t p1 = a_low * b_hi;
164     limb_t p2 = a_hi * b_low;
165     limb_t p3 = a_hi * b_hi;
166
167     uint32_t cy = (uint32_t)(((p0 >> 32) + (uint32_t)p1 + (uint32_t)p2) >> 32);
168
169     *lo = p0 + (p1 << 32) + (p2 << 32);
170     *hi = p3 + (p1 >> 32) + (p2 >> 32) + cy;
171 }
172 #endif
173
174 /* add two limbs with carry in, return carry out */
175 static ossl_inline limb_t _add_limb(limb_t *ret, limb_t a, limb_t b, limb_t carry)
176 {
177     limb_t carry1, carry2, t;
178     /*
179      * `c = a + b; if (c < a)` is idiomatic code that makes compilers
180      * use add with carry on assembly level
181      */
182
183     *ret = a + carry;
184     if (*ret < a)
185         carry1 = 1;
186     else
187         carry1 = 0;
188
189     t = *ret;
190     *ret = t + b;
191     if (*ret < t)
192         carry2 = 1;
193     else
194         carry2 = 0;
195
196     return carry1 + carry2;
197 }
198
199 /*
200  * add two numbers of the same size, return overflow
201  *
202  * add a to b, place result in ret; all arrays need to be n limbs long
203  * return overflow from addition (0 or 1)
204  */
205 static ossl_inline limb_t add(limb_t *ret, limb_t *a, limb_t *b, size_t n)
206 {
207     limb_t c = 0;
208     ossl_ssize_t i;
209
210     for(i = n - 1; i > -1; i--)
211         c = _add_limb(&ret[i], a[i], b[i], c);
212
213     return c;
214 }
215
216 /*
217  * return number of limbs necessary for temporary values
218  * when multiplying numbers n limbs large
219  */
220 static ossl_inline size_t mul_limb_numb(size_t n)
221 {
222     return  2 * n * 2;
223 }
224
225 /*
226  * multiply two numbers of the same size
227  *
228  * multiply a by b, place result in ret; a and b need to be n limbs long
229  * ret needs to be 2*n limbs long, tmp needs to be mul_limb_numb(n) limbs
230  * long
231  */
232 static void limb_mul(limb_t *ret, limb_t *a, limb_t *b, size_t n, limb_t *tmp)
233 {
234     limb_t *r_odd, *r_even;
235     size_t i, j, k;
236
237     r_odd = tmp;
238     r_even = &tmp[2 * n];
239
240     memset(ret, 0, 2 * n * sizeof(limb_t));
241
242     for (i = 0; i < n; i++) {
243         for (k = 0; k < i + n + 1; k++) {
244             r_even[k] = 0;
245             r_odd[k] = 0;
246         }
247         for (j = 0; j < n; j++) {
248             /*
249              * place results from even and odd limbs in separate arrays so that
250              * we don't have to calculate overflow every time we get individual
251              * limb multiplication result
252              */
253             if (j % 2 == 0)
254                 _mul_limb(&r_even[i + j], &r_even[i + j + 1], a[i], b[j]);
255             else
256                 _mul_limb(&r_odd[i + j], &r_odd[i + j + 1], a[i], b[j]);
257         }
258         /*
259          * skip the least significant limbs when adding multiples of
260          * more significant limbs (they're zero anyway)
261          */
262         add(ret, ret, r_even, n + i + 1);
263         add(ret, ret, r_odd, n + i + 1);
264     }
265 }
266
267 /* modifies the value in place by performing a right shift by one bit */
268 static ossl_inline void rshift1(limb_t *val, size_t n)
269 {
270     limb_t shift_in = 0, shift_out = 0;
271     size_t i;
272
273     for (i = 0; i < n; i++) {
274         shift_out = val[i] & 1;
275         val[i] = shift_in << (LIMB_BIT_SIZE - 1) | (val[i] >> 1);
276         shift_in = shift_out;
277     }
278 }
279
280 /* extend the LSB of flag to all bits of limb */
281 static ossl_inline limb_t mk_mask(limb_t flag)
282 {
283     flag |= flag << 1;
284     flag |= flag << 2;
285     flag |= flag << 4;
286     flag |= flag << 8;
287     flag |= flag << 16;
288 #if (LIMB_BYTE_SIZE == 8)
289     flag |= flag << 32;
290 #endif
291     return flag;
292 }
293
294 /*
295  * copy from either a or b to ret based on flag
296  * when flag == 0, then copies from b
297  * when flag == 1, then copies from a
298  */
299 static ossl_inline void cselect(limb_t flag, limb_t *ret, limb_t *a, limb_t *b, size_t n)
300 {
301     /*
302      * would be more efficient with non volatile mask, but then gcc
303      * generates code with jumps
304      */
305     volatile limb_t mask;
306     size_t i;
307
308     mask = mk_mask(flag);
309     for (i = 0; i < n; i++) {
310 #if (LIMB_BYTE_SIZE == 8)
311         ret[i] = constant_time_select_64(mask, a[i], b[i]);
312 #else
313         ret[i] = constant_time_select_32(mask, a[i], b[i]);
314 #endif
315     }
316 }
317
318 static limb_t _sub_limb(limb_t *ret, limb_t a, limb_t b, limb_t borrow)
319 {
320     limb_t borrow1, borrow2, t;
321     /*
322      * while it doesn't look constant-time, this is idiomatic code
323      * to tell compilers to use the carry bit from subtraction
324      */
325
326     *ret = a - borrow;
327     if (*ret > a)
328         borrow1 = 1;
329     else
330         borrow1 = 0;
331
332     t = *ret;
333     *ret = t - b;
334     if (*ret > t)
335         borrow2 = 1;
336     else
337         borrow2 = 0;
338
339     return borrow1 + borrow2;
340 }
341
342 /*
343  * place the result of a - b into ret, return the borrow bit.
344  * All arrays need to be n limbs long
345  */
346 static limb_t sub(limb_t *ret, limb_t *a, limb_t *b, size_t n)
347 {
348     limb_t borrow = 0;
349     ossl_ssize_t i;
350
351     for (i = n - 1; i > -1; i--)
352         borrow = _sub_limb(&ret[i], a[i], b[i], borrow);
353
354     return borrow;
355 }
356
357 /* return the number of limbs necessary to allocate for the mod() tmp operand */
358 static ossl_inline size_t mod_limb_numb(size_t anum, size_t modnum)
359 {
360     return (anum + modnum) * 3;
361 }
362
363 /*
364  * calculate a % mod, place the result in ret
365  * size of a is defined by anum, size of ret and mod is modnum,
366  * size of tmp is returned by mod_limb_numb()
367  */
368 static void mod(limb_t *ret, limb_t *a, size_t anum, limb_t *mod,
369                size_t modnum, limb_t *tmp)
370 {
371     limb_t *atmp, *modtmp, *rettmp;
372     limb_t res;
373     size_t i;
374
375     memset(tmp, 0, mod_limb_numb(anum, modnum) * LIMB_BYTE_SIZE);
376
377     atmp = tmp;
378     modtmp = &tmp[anum + modnum];
379     rettmp = &tmp[(anum + modnum) * 2];
380
381     for (i = modnum; i <modnum + anum; i++)
382         atmp[i] = a[i-modnum];
383
384     for (i = 0; i < modnum; i++)
385         modtmp[i] = mod[i];
386
387     for (i = 0; i < anum * LIMB_BIT_SIZE; i++) {
388         rshift1(modtmp, anum + modnum);
389         res = sub(rettmp, atmp, modtmp, anum+modnum);
390         cselect(res, atmp, atmp, rettmp, anum+modnum);
391     }
392
393     memcpy(ret, &atmp[anum], sizeof(limb_t) * modnum);
394 }
395
396 /* necessary size of tmp for a _mul_add_limb() call with provided anum */
397 static ossl_inline size_t _mul_add_limb_numb(size_t anum)
398 {
399     return 2 * (anum + 1);
400 }
401
402 /* multiply a by m, add to ret, return carry */
403 static limb_t _mul_add_limb(limb_t *ret, limb_t *a, size_t anum,
404                            limb_t m, limb_t *tmp)
405 {
406     limb_t carry = 0;
407     limb_t *r_odd, *r_even;
408     size_t i;
409
410     memset(tmp, 0, sizeof(limb_t) * (anum + 1) * 2);
411
412     r_odd = tmp;
413     r_even = &tmp[anum + 1];
414
415     for (i = 0; i < anum; i++) {
416         /*
417          * place the results from even and odd limbs in separate arrays
418          * so that we have to worry about carry just once
419          */
420         if (i % 2 == 0)
421             _mul_limb(&r_even[i], &r_even[i + 1], a[i], m);
422         else
423             _mul_limb(&r_odd[i], &r_odd[i + 1], a[i], m);
424     }
425     /* assert: add() carry here will be equal zero */
426     add(r_even, r_even, r_odd, anum + 1);
427     /*
428      * while here it will not overflow as the max value from multiplication
429      * is -2 while max overflow from addition is 1, so the max value of
430      * carry is -1 (i.e. max int)
431      */
432     carry = add(ret, ret, &r_even[1], anum) + r_even[0];
433
434     return carry;
435 }
436
437 static ossl_inline size_t mod_montgomery_limb_numb(size_t modnum)
438 {
439     return modnum * 2 + _mul_add_limb_numb(modnum);
440 }
441
442 /*
443  * calculate a % mod, place result in ret
444  * assumes that a is in Montgomery form with the R (Montgomery modulus) being
445  * smallest power of two big enough to fit mod and that's also a power
446  * of the count of number of bits in limb_t (B).
447  * For calculation, we also need n', such that mod * n' == -1 mod B.
448  * anum must be <= 2 * modnum
449  * ret needs to be modnum words long
450  * tmp needs to be mod_montgomery_limb_numb(modnum) limbs long
451  */
452 static void mod_montgomery(limb_t *ret, limb_t *a, size_t anum, limb_t *mod,
453                           size_t modnum, limb_t ni0, limb_t *tmp)
454 {
455     limb_t carry, v;
456     limb_t *res, *rp, *tmp2;
457     ossl_ssize_t i;
458
459     res = tmp;
460     /*
461      * for intermediate result we need an integer twice as long as modulus
462      * but keep the input in the least significant limbs
463      */
464     memset(res, 0, sizeof(limb_t) * (modnum * 2));
465     memcpy(&res[modnum * 2 - anum], a, sizeof(limb_t) * anum);
466     rp = &res[modnum];
467     tmp2 = &res[modnum * 2];
468
469     carry = 0;
470
471     /* add multiples of the modulus to the value until R divides it cleanly */
472     for (i = modnum; i > 0; i--, rp--) {
473         v = _mul_add_limb(rp, mod, modnum, rp[modnum-1] * ni0, tmp2);
474         v = v + carry + rp[-1];
475         carry |= (v != rp[-1]);
476         carry &= (v <= rp[-1]);
477         rp[-1] = v;
478     }
479
480     /* perform the final reduction by mod... */
481     carry -= sub(ret, rp, mod, modnum);
482
483     /* ...conditionally */
484     cselect(carry, ret, rp, ret, modnum);
485 }
486
487 /* allocated buffer should be freed afterwards */
488 static void BN_to_limb(const BIGNUM *bn, limb_t *buf, size_t limbs)
489 {
490     int i;
491     int real_limbs = (BN_num_bytes(bn) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;
492     limb_t *ptr = buf + (limbs - real_limbs);
493
494     for (i = 0; i < real_limbs; i++)
495          ptr[i] = bn->d[real_limbs - i - 1];
496 }
497
498 #if LIMB_BYTE_SIZE == 8
499 static ossl_inline uint64_t be64(uint64_t host)
500 {
501     uint64_t big = 0;
502     DECLARE_IS_ENDIAN;
503
504     if (!IS_LITTLE_ENDIAN)
505         return host;
506
507     big |= (host & 0xff00000000000000) >> 56;
508     big |= (host & 0x00ff000000000000) >> 40;
509     big |= (host & 0x0000ff0000000000) >> 24;
510     big |= (host & 0x000000ff00000000) >>  8;
511     big |= (host & 0x00000000ff000000) <<  8;
512     big |= (host & 0x0000000000ff0000) << 24;
513     big |= (host & 0x000000000000ff00) << 40;
514     big |= (host & 0x00000000000000ff) << 56;
515     return big;
516 }
517
518 #else
519 /* Not all platforms have htobe32(). */
520 static ossl_inline uint32_t be32(uint32_t host)
521 {
522     uint32_t big = 0;
523     DECLARE_IS_ENDIAN;
524
525     if (!IS_LITTLE_ENDIAN)
526         return host;
527
528     big |= (host & 0xff000000) >> 24;
529     big |= (host & 0x00ff0000) >> 8;
530     big |= (host & 0x0000ff00) << 8;
531     big |= (host & 0x000000ff) << 24;
532     return big;
533 }
534 #endif
535
536 /*
537  * We assume that intermediate, possible_arg2, blinding, and ctx are used
538  * similar to BN_BLINDING_invert_ex() arguments.
539  * to_mod is RSA modulus.
540  * buf and num is the serialization buffer and its length.
541  *
542  * Here we use classic/Montgomery multiplication and modulo. After the calculation finished
543  * we serialize the new structure instead of BIGNUMs taking endianness into account.
544  */
545 int ossl_bn_rsa_do_unblind(const BIGNUM *intermediate,
546                            const BN_BLINDING *blinding,
547                            const BIGNUM *possible_arg2,
548                            const BIGNUM *to_mod, BN_CTX *ctx,
549                            unsigned char *buf, int num)
550 {
551     limb_t *l_im = NULL, *l_mul = NULL, *l_mod = NULL;
552     limb_t *l_ret = NULL, *l_tmp = NULL, l_buf;
553     size_t l_im_count = 0, l_mul_count = 0, l_size = 0, l_mod_count = 0;
554     size_t l_tmp_count = 0;
555     int ret = 0;
556     size_t i;
557     unsigned char *tmp;
558     const BIGNUM *arg1 = intermediate;
559     const BIGNUM *arg2 = (possible_arg2 == NULL) ? blinding->Ai : possible_arg2;
560
561     l_im_count  = (BN_num_bytes(arg1)   + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;
562     l_mul_count = (BN_num_bytes(arg2)   + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;
563     l_mod_count = (BN_num_bytes(to_mod) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;
564
565     l_size = l_im_count > l_mul_count ? l_im_count : l_mul_count;
566     l_im  = OPENSSL_zalloc(l_size * LIMB_BYTE_SIZE);
567     l_mul = OPENSSL_zalloc(l_size * LIMB_BYTE_SIZE);
568     l_mod = OPENSSL_zalloc(l_mod_count * LIMB_BYTE_SIZE);
569
570     if ((l_im == NULL) || (l_mul == NULL) || (l_mod == NULL))
571         goto err;
572
573     BN_to_limb(arg1,   l_im,  l_size);
574     BN_to_limb(arg2,   l_mul, l_size);
575     BN_to_limb(to_mod, l_mod, l_mod_count);
576
577     l_ret = OPENSSL_malloc(2 * l_size * LIMB_BYTE_SIZE);
578
579     if (blinding->m_ctx != NULL) {
580         l_tmp_count = mul_limb_numb(l_size) > mod_montgomery_limb_numb(l_mod_count) ?
581                       mul_limb_numb(l_size) : mod_montgomery_limb_numb(l_mod_count);
582         l_tmp = OPENSSL_malloc(l_tmp_count * LIMB_BYTE_SIZE);
583     } else {
584         l_tmp_count = mul_limb_numb(l_size) > mod_limb_numb(2 * l_size, l_mod_count) ?
585                       mul_limb_numb(l_size) : mod_limb_numb(2 * l_size, l_mod_count);
586         l_tmp = OPENSSL_malloc(l_tmp_count * LIMB_BYTE_SIZE);
587     }
588
589     if ((l_ret == NULL) || (l_tmp == NULL))
590         goto err;
591
592     if (blinding->m_ctx != NULL) {
593         limb_mul(l_ret, l_im, l_mul, l_size, l_tmp);
594         mod_montgomery(l_ret, l_ret, 2 * l_size, l_mod, l_mod_count,
595                        blinding->m_ctx->n0[0], l_tmp);
596     } else {
597         limb_mul(l_ret, l_im, l_mul, l_size, l_tmp);
598         mod(l_ret, l_ret, 2 * l_size, l_mod, l_mod_count, l_tmp);
599     }
600
601     /* modulus size in bytes can be equal to num but after limbs conversion it becomes bigger */
602     if (num < BN_num_bytes(to_mod)) {
603         ERR_raise(ERR_LIB_BN, ERR_R_PASSED_INVALID_ARGUMENT);
604         goto err;
605     }
606
607     memset(buf, 0, num);
608     tmp = buf + num - BN_num_bytes(to_mod);
609     for (i = 0; i < l_mod_count; i++) {
610 #if LIMB_BYTE_SIZE == 8
611         l_buf = be64(l_ret[i]);
612 #else
613         l_buf = be32(l_ret[i]);
614 #endif
615         if (i == 0) {
616             int delta = LIMB_BYTE_SIZE - ((l_mod_count * LIMB_BYTE_SIZE) - num);
617
618             memcpy(tmp, ((char *)&l_buf) + LIMB_BYTE_SIZE - delta, delta);
619             tmp += delta;
620         } else {
621             memcpy(tmp, &l_buf, LIMB_BYTE_SIZE);
622             tmp += LIMB_BYTE_SIZE;
623         }
624     }
625     ret = num;
626
627  err:
628     OPENSSL_free(l_im);
629     OPENSSL_free(l_mul);
630     OPENSSL_free(l_mod);
631     OPENSSL_free(l_tmp);
632     OPENSSL_free(l_ret);
633
634     return ret;
635 }