/*
 * Copyright 2022-2023 The OpenSSL Project Authors. All Rights Reserved.
 *
 * Licensed under the Apache License 2.0 (the "License"). You may not use
 * this file except in compliance with the License. You can obtain a copy
 * in the file LICENSE in the source distribution or at
 * https://www.openssl.org/source/license.html
 */
#include <openssl/e_os2.h>
#include <stddef.h>
#include <sys/types.h>
#include <string.h>
#include <openssl/bn.h>
#include <openssl/err.h>
#include <openssl/rsaerr.h>
#include "internal/endian.h"
#include "internal/numbers.h"
#include "internal/constant_time.h"
23 typedef uint64_t limb_t;
24 # if defined(__SIZEOF_INT128__) && __SIZEOF_INT128__ == 16
25 typedef uint128_t limb2_t;
28 # define LIMB_BIT_SIZE 64
29 # define LIMB_BYTE_SIZE 8
31 typedef uint32_t limb_t;
32 typedef uint64_t limb2_t;
33 # define LIMB_BIT_SIZE 32
34 # define LIMB_BYTE_SIZE 4
37 # error "Not supported"
/*
 * For multiplication we're using schoolbook multiplication, so if we have
 * two numbers, each with n "digits" (limbs), every limb of one number is
 * multiplied with every limb of the other, and the partial products,
 * shifted to their proper limb positions, are summed up.
 *
 * This requires n^2 limb multiplications and 2n full length additions,
 * as we can keep every other result of limb multiplication in two
 * separate accumulators (even- and odd-indexed) and thus only have to
 * resolve carries once per row.
 */
108 #if defined HAVE_LIMB2_T
109 static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b)
113 * this is idiomatic code to tell compiler to use the native mul
114 * those three lines will actually compile to single instruction
118 *hi = t >> LIMB_BIT_SIZE;
121 #elif (BN_BYTES == 8) && (defined _MSC_VER)
124 * on x86_64 (x64) we can use the _umul128 intrinsic to get one `mul`
125 * instruction to get both high and low 64 bits of the multiplication.
126 * https://learn.microsoft.com/en-us/cpp/intrinsics/umul128?view=msvc-140
129 #pragma intrinsic(_umul128)
130 static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b)
132 *lo = _umul128(a, b, hi);
134 # elif defined(_M_ARM64) || defined (_M_IA64)
136 * We can't use the __umulh() on x86_64 as then msvc generates two `mul`
137 * instructions; so use this more portable intrinsic on platforms that
138 * don't support _umul128 (like aarch64 (ARM64) or ia64)
139 * https://learn.microsoft.com/en-us/cpp/intrinsics/umulh?view=msvc-140
142 static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b)
148 # error Only x64, ARM64 and IA64 supported.
149 # endif /* defined(_M_X64) */
152 * if the compiler doesn't have either a 128bit data type nor a "return
153 * high 64 bits of multiplication"
155 static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b)
157 limb_t a_low = (limb_t)(uint32_t)a;
158 limb_t a_hi = a >> 32;
159 limb_t b_low = (limb_t)(uint32_t)b;
160 limb_t b_hi = b >> 32;
162 limb_t p0 = a_low * b_low;
163 limb_t p1 = a_low * b_hi;
164 limb_t p2 = a_hi * b_low;
165 limb_t p3 = a_hi * b_hi;
167 uint32_t cy = (uint32_t)(((p0 >> 32) + (uint32_t)p1 + (uint32_t)p2) >> 32);
169 *lo = p0 + (p1 << 32) + (p2 << 32);
170 *hi = p3 + (p1 >> 32) + (p2 >> 32) + cy;
174 /* add two limbs with carry in, return carry out */
175 static ossl_inline limb_t _add_limb(limb_t *ret, limb_t a, limb_t b, limb_t carry)
177 limb_t carry1, carry2, t;
179 * `c = a + b; if (c < a)` is idiomatic code that makes compilers
180 * use add with carry on assembly level
196 return carry1 + carry2;
200 * add two numbers of the same size, return overflow
202 * add a to b, place result in ret; all arrays need to be n limbs long
203 * return overflow from addition (0 or 1)
205 static ossl_inline limb_t add(limb_t *ret, limb_t *a, limb_t *b, size_t n)
210 for(i = n - 1; i > -1; i--)
211 c = _add_limb(&ret[i], a[i], b[i], c);
217 * return number of limbs necessary for temporary values
218 * when multiplying numbers n limbs large
220 static ossl_inline size_t mul_limb_numb(size_t n)
226 * multiply two numbers of the same size
228 * multiply a by b, place result in ret; a and b need to be n limbs long
229 * ret needs to be 2*n limbs long, tmp needs to be mul_limb_numb(n) limbs
232 static void limb_mul(limb_t *ret, limb_t *a, limb_t *b, size_t n, limb_t *tmp)
234 limb_t *r_odd, *r_even;
238 r_even = &tmp[2 * n];
240 memset(ret, 0, 2 * n * sizeof(limb_t));
242 for (i = 0; i < n; i++) {
243 for (k = 0; k < i + n + 1; k++) {
247 for (j = 0; j < n; j++) {
249 * place results from even and odd limbs in separate arrays so that
250 * we don't have to calculate overflow every time we get individual
251 * limb multiplication result
254 _mul_limb(&r_even[i + j], &r_even[i + j + 1], a[i], b[j]);
256 _mul_limb(&r_odd[i + j], &r_odd[i + j + 1], a[i], b[j]);
259 * skip the least significant limbs when adding multiples of
260 * more significant limbs (they're zero anyway)
262 add(ret, ret, r_even, n + i + 1);
263 add(ret, ret, r_odd, n + i + 1);
267 /* modifies the value in place by performing a right shift by one bit */
268 static ossl_inline void rshift1(limb_t *val, size_t n)
270 limb_t shift_in = 0, shift_out = 0;
273 for (i = 0; i < n; i++) {
274 shift_out = val[i] & 1;
275 val[i] = shift_in << (LIMB_BIT_SIZE - 1) | (val[i] >> 1);
276 shift_in = shift_out;
280 /* extend the LSB of flag to all bits of limb */
281 static ossl_inline limb_t mk_mask(limb_t flag)
288 #if (LIMB_BYTE_SIZE == 8)
295 * copy from either a or b to ret based on flag
296 * when flag == 0, then copies from b
297 * when flag == 1, then copies from a
299 static ossl_inline void cselect(limb_t flag, limb_t *ret, limb_t *a, limb_t *b, size_t n)
302 * would be more efficient with non volatile mask, but then gcc
303 * generates code with jumps
305 volatile limb_t mask;
308 mask = mk_mask(flag);
309 for (i = 0; i < n; i++) {
310 #if (LIMB_BYTE_SIZE == 8)
311 ret[i] = constant_time_select_64(mask, a[i], b[i]);
313 ret[i] = constant_time_select_32(mask, a[i], b[i]);
318 static limb_t _sub_limb(limb_t *ret, limb_t a, limb_t b, limb_t borrow)
320 limb_t borrow1, borrow2, t;
322 * while it doesn't look constant-time, this is idiomatic code
323 * to tell compilers to use the carry bit from subtraction
339 return borrow1 + borrow2;
343 * place the result of a - b into ret, return the borrow bit.
344 * All arrays need to be n limbs long
346 static limb_t sub(limb_t *ret, limb_t *a, limb_t *b, size_t n)
351 for (i = n - 1; i > -1; i--)
352 borrow = _sub_limb(&ret[i], a[i], b[i], borrow);
357 /* return the number of limbs necessary to allocate for the mod() tmp operand */
358 static ossl_inline size_t mod_limb_numb(size_t anum, size_t modnum)
360 return (anum + modnum) * 3;
364 * calculate a % mod, place the result in ret
365 * size of a is defined by anum, size of ret and mod is modnum,
366 * size of tmp is returned by mod_limb_numb()
368 static void mod(limb_t *ret, limb_t *a, size_t anum, limb_t *mod,
369 size_t modnum, limb_t *tmp)
371 limb_t *atmp, *modtmp, *rettmp;
375 memset(tmp, 0, mod_limb_numb(anum, modnum) * LIMB_BYTE_SIZE);
378 modtmp = &tmp[anum + modnum];
379 rettmp = &tmp[(anum + modnum) * 2];
381 for (i = modnum; i <modnum + anum; i++)
382 atmp[i] = a[i-modnum];
384 for (i = 0; i < modnum; i++)
387 for (i = 0; i < anum * LIMB_BIT_SIZE; i++) {
388 rshift1(modtmp, anum + modnum);
389 res = sub(rettmp, atmp, modtmp, anum+modnum);
390 cselect(res, atmp, atmp, rettmp, anum+modnum);
393 memcpy(ret, &atmp[anum], sizeof(limb_t) * modnum);
396 /* necessary size of tmp for a _mul_add_limb() call with provided anum */
397 static ossl_inline size_t _mul_add_limb_numb(size_t anum)
399 return 2 * (anum + 1);
402 /* multiply a by m, add to ret, return carry */
403 static limb_t _mul_add_limb(limb_t *ret, limb_t *a, size_t anum,
404 limb_t m, limb_t *tmp)
407 limb_t *r_odd, *r_even;
410 memset(tmp, 0, sizeof(limb_t) * (anum + 1) * 2);
413 r_even = &tmp[anum + 1];
415 for (i = 0; i < anum; i++) {
417 * place the results from even and odd limbs in separate arrays
418 * so that we have to worry about carry just once
421 _mul_limb(&r_even[i], &r_even[i + 1], a[i], m);
423 _mul_limb(&r_odd[i], &r_odd[i + 1], a[i], m);
425 /* assert: add() carry here will be equal zero */
426 add(r_even, r_even, r_odd, anum + 1);
428 * while here it will not overflow as the max value from multiplication
429 * is -2 while max overflow from addition is 1, so the max value of
430 * carry is -1 (i.e. max int)
432 carry = add(ret, ret, &r_even[1], anum) + r_even[0];
437 static ossl_inline size_t mod_montgomery_limb_numb(size_t modnum)
439 return modnum * 2 + _mul_add_limb_numb(modnum);
443 * calculate a % mod, place result in ret
444 * assumes that a is in Montgomery form with the R (Montgomery modulus) being
445 * smallest power of two big enough to fit mod and that's also a power
446 * of the count of number of bits in limb_t (B).
447 * For calculation, we also need n', such that mod * n' == -1 mod B.
448 * anum must be <= 2 * modnum
449 * ret needs to be modnum words long
450 * tmp needs to be mod_montgomery_limb_numb(modnum) limbs long
452 static void mod_montgomery(limb_t *ret, limb_t *a, size_t anum, limb_t *mod,
453 size_t modnum, limb_t ni0, limb_t *tmp)
456 limb_t *res, *rp, *tmp2;
461 * for intermediate result we need an integer twice as long as modulus
462 * but keep the input in the least significant limbs
464 memset(res, 0, sizeof(limb_t) * (modnum * 2));
465 memcpy(&res[modnum * 2 - anum], a, sizeof(limb_t) * anum);
467 tmp2 = &res[modnum * 2];
471 /* add multiples of the modulus to the value until R divides it cleanly */
472 for (i = modnum; i > 0; i--, rp--) {
473 v = _mul_add_limb(rp, mod, modnum, rp[modnum-1] * ni0, tmp2);
474 v = v + carry + rp[-1];
475 carry |= (v != rp[-1]);
476 carry &= (v <= rp[-1]);
480 /* perform the final reduction by mod... */
481 carry -= sub(ret, rp, mod, modnum);
483 /* ...conditionally */
484 cselect(carry, ret, rp, ret, modnum);
487 /* allocated buffer should be freed afterwards */
488 static void BN_to_limb(const BIGNUM *bn, limb_t *buf, size_t limbs)
491 int real_limbs = (BN_num_bytes(bn) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;
492 limb_t *ptr = buf + (limbs - real_limbs);
494 for (i = 0; i < real_limbs; i++)
495 ptr[i] = bn->d[real_limbs - i - 1];
498 #if LIMB_BYTE_SIZE == 8
/*
 * Not all platforms have htobe64(): convert a 64-bit value from host byte
 * order to big-endian.  On big-endian hosts the value is returned as-is.
 */
static uint64_t be64(uint64_t host)
{
    uint64_t big = 0;
    /* runtime endianness probe (same trick as DECLARE_IS_ENDIAN) */
    const union {
        long one;
        char little;
    } is_endian = { 1 };

    if (!is_endian.little)
        return host;

    big |= (host & 0xff00000000000000ULL) >> 56;
    big |= (host & 0x00ff000000000000ULL) >> 40;
    big |= (host & 0x0000ff0000000000ULL) >> 24;
    big |= (host & 0x000000ff00000000ULL) >> 8;
    big |= (host & 0x00000000ff000000ULL) << 8;
    big |= (host & 0x0000000000ff0000ULL) << 24;
    big |= (host & 0x000000000000ff00ULL) << 40;
    big |= (host & 0x00000000000000ffULL) << 56;
    return big;
}
/*
 * Not all platforms have htobe32(): convert a 32-bit value from host byte
 * order to big-endian.  On big-endian hosts the value is returned as-is.
 */
static uint32_t be32(uint32_t host)
{
    uint32_t big = 0;
    /* runtime endianness probe (same trick as DECLARE_IS_ENDIAN) */
    const union {
        long one;
        char little;
    } is_endian = { 1 };

    if (!is_endian.little)
        return host;

    big |= (host & 0xff000000) >> 24;
    big |= (host & 0x00ff0000) >> 8;
    big |= (host & 0x0000ff00) << 8;
    big |= (host & 0x000000ff) << 24;
    return big;
}
537 * We assume that intermediate, possible_arg2, blinding, and ctx are used
538 * similar to BN_BLINDING_invert_ex() arguments.
539 * to_mod is RSA modulus.
540 * buf and num is the serialization buffer and its length.
542 * Here we use classic/Montgomery multiplication and modulo. After the calculation finished
543 * we serialize the new structure instead of BIGNUMs taking endianness into account.
545 int ossl_bn_rsa_do_unblind(const BIGNUM *intermediate,
546 const BN_BLINDING *blinding,
547 const BIGNUM *possible_arg2,
548 const BIGNUM *to_mod, BN_CTX *ctx,
549 unsigned char *buf, int num)
551 limb_t *l_im = NULL, *l_mul = NULL, *l_mod = NULL;
552 limb_t *l_ret = NULL, *l_tmp = NULL, l_buf;
553 size_t l_im_count = 0, l_mul_count = 0, l_size = 0, l_mod_count = 0;
554 size_t l_tmp_count = 0;
558 const BIGNUM *arg1 = intermediate;
559 const BIGNUM *arg2 = (possible_arg2 == NULL) ? blinding->Ai : possible_arg2;
561 l_im_count = (BN_num_bytes(arg1) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;
562 l_mul_count = (BN_num_bytes(arg2) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;
563 l_mod_count = (BN_num_bytes(to_mod) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;
565 l_size = l_im_count > l_mul_count ? l_im_count : l_mul_count;
566 l_im = OPENSSL_zalloc(l_size * LIMB_BYTE_SIZE);
567 l_mul = OPENSSL_zalloc(l_size * LIMB_BYTE_SIZE);
568 l_mod = OPENSSL_zalloc(l_mod_count * LIMB_BYTE_SIZE);
570 if ((l_im == NULL) || (l_mul == NULL) || (l_mod == NULL))
573 BN_to_limb(arg1, l_im, l_size);
574 BN_to_limb(arg2, l_mul, l_size);
575 BN_to_limb(to_mod, l_mod, l_mod_count);
577 l_ret = OPENSSL_malloc(2 * l_size * LIMB_BYTE_SIZE);
579 if (blinding->m_ctx != NULL) {
580 l_tmp_count = mul_limb_numb(l_size) > mod_montgomery_limb_numb(l_mod_count) ?
581 mul_limb_numb(l_size) : mod_montgomery_limb_numb(l_mod_count);
582 l_tmp = OPENSSL_malloc(l_tmp_count * LIMB_BYTE_SIZE);
584 l_tmp_count = mul_limb_numb(l_size) > mod_limb_numb(2 * l_size, l_mod_count) ?
585 mul_limb_numb(l_size) : mod_limb_numb(2 * l_size, l_mod_count);
586 l_tmp = OPENSSL_malloc(l_tmp_count * LIMB_BYTE_SIZE);
589 if ((l_ret == NULL) || (l_tmp == NULL))
592 if (blinding->m_ctx != NULL) {
593 limb_mul(l_ret, l_im, l_mul, l_size, l_tmp);
594 mod_montgomery(l_ret, l_ret, 2 * l_size, l_mod, l_mod_count,
595 blinding->m_ctx->n0[0], l_tmp);
597 limb_mul(l_ret, l_im, l_mul, l_size, l_tmp);
598 mod(l_ret, l_ret, 2 * l_size, l_mod, l_mod_count, l_tmp);
601 /* modulus size in bytes can be equal to num but after limbs conversion it becomes bigger */
602 if (num < BN_num_bytes(to_mod)) {
603 ERR_raise(ERR_LIB_BN, ERR_R_PASSED_INVALID_ARGUMENT);
608 tmp = buf + num - BN_num_bytes(to_mod);
609 for (i = 0; i < l_mod_count; i++) {
610 #if LIMB_BYTE_SIZE == 8
611 l_buf = be64(l_ret[i]);
613 l_buf = be32(l_ret[i]);
616 int delta = LIMB_BYTE_SIZE - ((l_mod_count * LIMB_BYTE_SIZE) - num);
618 memcpy(tmp, ((char *)&l_buf) + LIMB_BYTE_SIZE - delta, delta);
621 memcpy(tmp, &l_buf, LIMB_BYTE_SIZE);
622 tmp += LIMB_BYTE_SIZE;