Split bignum code out of the sparcv9cap.c
[openssl.git] / crypto / bn / bn_exp.c
index 10d3912c5ad28a9faf89be7313fc2ef3473ed664..5329cd12a98470fdbe4546811adbcd0a9de25e11 100644 (file)
@@ -1,15 +1,15 @@
 /*
- * Copyright 1995-2018 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 1995-2021 The OpenSSL Project Authors. All Rights Reserved.
  *
- * Licensed under the OpenSSL license (the "License").  You may not use
+ * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
  * in the file LICENSE in the source distribution or at
  * https://www.openssl.org/source/license.html
  */
 
 #include "internal/cryptlib.h"
-#include "internal/constant_time_locl.h"
-#include "bn_lcl.h"
+#include "internal/constant_time.h"
+#include "bn_local.h"
 
 #include <stdlib.h>
 #ifdef _WIN32
@@ -29,8 +29,7 @@
 
 #undef SPARC_T4_MONT
 #if defined(OPENSSL_BN_ASM_MONT) && (defined(__sparc__) || defined(__sparc))
-# include "sparc_arch.h"
-extern unsigned int OPENSSL_sparcv9cap_P[];
+# include "crypto/sparc_arch.h"
 # define SPARC_T4_MONT
 #endif
 
@@ -46,7 +45,7 @@ int BN_exp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p, BN_CTX *ctx)
     if (BN_get_flags(p, BN_FLG_CONSTTIME) != 0
             || BN_get_flags(a, BN_FLG_CONSTTIME) != 0) {
         /* BN_FLG_CONSTTIME only supported by BN_mod_exp_mont() */
-        BNerr(BN_F_BN_EXP, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        ERR_raise(ERR_LIB_BN, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
         return 0;
     }
 
@@ -172,7 +171,7 @@ int BN_mod_exp_recp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
             || BN_get_flags(a, BN_FLG_CONSTTIME) != 0
             || BN_get_flags(m, BN_FLG_CONSTTIME) != 0) {
         /* BN_FLG_CONSTTIME only supported by BN_mod_exp_mont() */
-        BNerr(BN_F_BN_MOD_EXP_RECP, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        ERR_raise(ERR_LIB_BN, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
         return 0;
     }
 
@@ -252,7 +251,6 @@ int BN_mod_exp_recp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
          * a window to do.  To do this we need to scan forward until the last
          * set bit before the end of the window
          */
-        j = wstart;
         wvalue = 1;
         wend = 0;
         for (i = 1; i < window; i++) {
@@ -315,7 +313,7 @@ int BN_mod_exp_mont(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
     bn_check_top(m);
 
     if (!BN_is_odd(m)) {
-        BNerr(BN_F_BN_MOD_EXP_MONT, BN_R_CALLED_WITH_EVEN_MODULUS);
+        ERR_raise(ERR_LIB_BN, BN_R_CALLED_WITH_EVEN_MODULUS);
         return 0;
     }
     bits = BN_num_bits(p);
@@ -356,22 +354,17 @@ int BN_mod_exp_mont(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
         aa = val[0];
     } else
         aa = a;
-    if (BN_is_zero(aa)) {
-        BN_zero(rr);
-        ret = 1;
-        goto err;
-    }
-    if (!BN_to_montgomery(val[0], aa, mont, ctx))
+    if (!bn_to_mont_fixed_top(val[0], aa, mont, ctx))
         goto err;               /* 1 */
 
     window = BN_window_bits_for_exponent_size(bits);
     if (window > 1) {
-        if (!BN_mod_mul_montgomery(d, val[0], val[0], mont, ctx))
+        if (!bn_mul_mont_fixed_top(d, val[0], val[0], mont, ctx))
             goto err;           /* 2 */
         j = 1 << (window - 1);
         for (i = 1; i < j; i++) {
             if (((val[i] = BN_CTX_get(ctx)) == NULL) ||
-                !BN_mod_mul_montgomery(val[i], val[i - 1], d, mont, ctx))
+                !bn_mul_mont_fixed_top(val[i], val[i - 1], d, mont, ctx))
                 goto err;
         }
     }
@@ -393,19 +386,15 @@ int BN_mod_exp_mont(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
         for (i = 1; i < j; i++)
             r->d[i] = (~m->d[i]) & BN_MASK2;
         r->top = j;
-        /*
-         * Upper words will be zero if the corresponding words of 'm' were
-         * 0xfff[...], so decrement r->top accordingly.
-         */
-        bn_correct_top(r);
+        r->flags |= BN_FLG_FIXED_TOP;
     } else
 #endif
-    if (!BN_to_montgomery(r, BN_value_one(), mont, ctx))
+    if (!bn_to_mont_fixed_top(r, BN_value_one(), mont, ctx))
         goto err;
     for (;;) {
         if (BN_is_bit_set(p, wstart) == 0) {
             if (!start) {
-                if (!BN_mod_mul_montgomery(r, r, r, mont, ctx))
+                if (!bn_mul_mont_fixed_top(r, r, r, mont, ctx))
                     goto err;
             }
             if (wstart == 0)
@@ -418,7 +407,6 @@ int BN_mod_exp_mont(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
          * a window to do.  To do this we need to scan forward until the last
          * set bit before the end of the window
          */
-        j = wstart;
         wvalue = 1;
         wend = 0;
         for (i = 1; i < window; i++) {
@@ -436,12 +424,12 @@ int BN_mod_exp_mont(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
         /* add the 'bytes above' */
         if (!start)
             for (i = 0; i < j; i++) {
-                if (!BN_mod_mul_montgomery(r, r, r, mont, ctx))
+                if (!bn_mul_mont_fixed_top(r, r, r, mont, ctx))
                     goto err;
             }
 
         /* wvalue will be an odd number < 2^window */
-        if (!BN_mod_mul_montgomery(r, r, val[wvalue >> 1], mont, ctx))
+        if (!bn_mul_mont_fixed_top(r, r, val[wvalue >> 1], mont, ctx))
             goto err;
 
         /* move the 'window' down further */
@@ -451,6 +439,11 @@ int BN_mod_exp_mont(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
         if (wstart < 0)
             break;
     }
+    /*
+     * Done with zero-padded intermediate BIGNUMs. Final BN_from_montgomery
+     * removes padding [if any] and makes return value suitable for public
+     * API consumer.
+     */
 #if defined(SPARC_T4_MONT)
     if (OPENSSL_sparcv9cap_P[0] & (SPARCV9_VIS3 | SPARCV9_PREFER_FPU)) {
         j = mont->N.top;        /* borrow j */
@@ -575,7 +568,7 @@ static int MOD_EXP_CTIME_COPY_FROM_PREBUF(BIGNUM *b, int top,
     }
 
     b->top = top;
-    bn_correct_top(b);
+    b->flags |= BN_FLG_FIXED_TOP;
     return 1;
 }
 
@@ -615,7 +608,7 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
     bn_check_top(m);
 
     if (!BN_is_odd(m)) {
-        BNerr(BN_F_BN_MOD_EXP_MONT_CONSTTIME, BN_R_CALLED_WITH_EVEN_MODULUS);
+        ERR_raise(ERR_LIB_BN, BN_R_CALLED_WITH_EVEN_MODULUS);
         return 0;
     }
 
@@ -652,34 +645,41 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
             goto err;
     }
 
+    if (a->neg || BN_ucmp(a, m) >= 0) {
+        BIGNUM *reduced = BN_CTX_get(ctx);
+        if (reduced == NULL
+            || !BN_nnmod(reduced, a, m, ctx)) {
+            goto err;
+        }
+        a = reduced;
+    }
+
 #ifdef RSAZ_ENABLED
-    if (!a->neg) {
-        /*
-         * If the size of the operands allow it, perform the optimized
-         * RSAZ exponentiation. For further information see
-         * crypto/bn/rsaz_exp.c and accompanying assembly modules.
-         */
-        if ((16 == a->top) && (16 == p->top) && (BN_num_bits(m) == 1024)
-            && rsaz_avx2_eligible()) {
-            if (NULL == bn_wexpand(rr, 16))
-                goto err;
-            RSAZ_1024_mod_exp_avx2(rr->d, a->d, p->d, m->d, mont->RR.d,
-                                   mont->n0[0]);
-            rr->top = 16;
-            rr->neg = 0;
-            bn_correct_top(rr);
-            ret = 1;
+    /*
+     * If the size of the operands allow it, perform the optimized
+     * RSAZ exponentiation. For further information see
+     * crypto/bn/rsaz_exp.c and accompanying assembly modules.
+     */
+    if ((16 == a->top) && (16 == p->top) && (BN_num_bits(m) == 1024)
+        && rsaz_avx2_eligible()) {
+        if (NULL == bn_wexpand(rr, 16))
             goto err;
-        } else if ((8 == a->top) && (8 == p->top) && (BN_num_bits(m) == 512)) {
-            if (NULL == bn_wexpand(rr, 8))
-                goto err;
-            RSAZ_512_mod_exp(rr->d, a->d, p->d, m->d, mont->n0[0], mont->RR.d);
-            rr->top = 8;
-            rr->neg = 0;
-            bn_correct_top(rr);
-            ret = 1;
+        RSAZ_1024_mod_exp_avx2(rr->d, a->d, p->d, m->d, mont->RR.d,
+                               mont->n0[0]);
+        rr->top = 16;
+        rr->neg = 0;
+        bn_correct_top(rr);
+        ret = 1;
+        goto err;
+    } else if ((8 == a->top) && (8 == p->top) && (BN_num_bits(m) == 512)) {
+        if (NULL == bn_wexpand(rr, 8))
             goto err;
-        }
+        RSAZ_512_mod_exp(rr->d, a->d, p->d, m->d, mont->n0[0], mont->RR.d);
+        rr->top = 8;
+        rr->neg = 0;
+        bn_correct_top(rr);
+        ret = 1;
+        goto err;
     }
 #endif
 
@@ -747,16 +747,11 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
         tmp.top = top;
     } else
 #endif
-    if (!BN_to_montgomery(&tmp, BN_value_one(), mont, ctx))
+    if (!bn_to_mont_fixed_top(&tmp, BN_value_one(), mont, ctx))
         goto err;
 
     /* prepare a^1 in Montgomery domain */
-    if (a->neg || BN_ucmp(a, m) >= 0) {
-        if (!BN_nnmod(&am, a, m, ctx))
-            goto err;
-        if (!BN_to_montgomery(&am, &am, mont, ctx))
-            goto err;
-    } else if (!BN_to_montgomery(&am, a, mont, ctx))
+    if (!bn_to_mont_fixed_top(&am, a, mont, ctx))
         goto err;
 
 #if defined(SPARC_T4_MONT)
@@ -823,7 +818,7 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
 
         /*
          * BN_to_montgomery can contaminate words above .top [in
-         * BN_DEBUG[_DEBUG] build]...
+         * BN_DEBUG build...
          */
         for (i = am.top; i < top; i++)
             am.d[i] = 0;
@@ -928,7 +923,7 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
 
         /*
          * BN_to_montgomery can contaminate words above .top [in
-         * BN_DEBUG[_DEBUG] build]...
+         * BN_DEBUG build...
          */
         for (i = am.top; i < top; i++)
             am.d[i] = 0;
@@ -1034,14 +1029,14 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
          * performance advantage of sqr over mul).
          */
         if (window > 1) {
-            if (!BN_mod_mul_montgomery(&tmp, &am, &am, mont, ctx))
+            if (!bn_mul_mont_fixed_top(&tmp, &am, &am, mont, ctx))
                 goto err;
             if (!MOD_EXP_CTIME_COPY_TO_PREBUF(&tmp, top, powerbuf, 2,
                                               window))
                 goto err;
             for (i = 3; i < numPowers; i++) {
                 /* Calculate a^i = a^(i-1) * a */
-                if (!BN_mod_mul_montgomery(&tmp, &am, &tmp, mont, ctx))
+                if (!bn_mul_mont_fixed_top(&tmp, &am, &tmp, mont, ctx))
                     goto err;
                 if (!MOD_EXP_CTIME_COPY_TO_PREBUF(&tmp, top, powerbuf, i,
                                                   window))
@@ -1072,7 +1067,7 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
 
             /* Square the result window-size times */
             for (i = 0; i < window; i++)
-                if (!BN_mod_mul_montgomery(&tmp, &tmp, &tmp, mont, ctx))
+                if (!bn_mul_mont_fixed_top(&tmp, &tmp, &tmp, mont, ctx))
                     goto err;
 
             /*
@@ -1081,7 +1076,7 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
              * is not only slower but also makes each bit vulnerable to
              * EM (and likely other) side-channel attacks like One&Done
              * (for details see "One&Done: A Single-Decryption EM-Based
-             *  Attack on OpenSSLs Constant-Time Blinded RSA" by M. Alam,
+             *  Attack on OpenSSL's Constant-Time Blinded RSA" by M. Alam,
              *  H. Khan, M. Dey, N. Sinha, R. Callan, A. Zajic, and
              *  M. Prvulovic, in USENIX Security'18)
              */
@@ -1095,12 +1090,16 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
                 goto err;
 
             /* Multiply the result into the intermediate result */
-            if (!BN_mod_mul_montgomery(&tmp, &tmp, &am, mont, ctx))
+            if (!bn_mul_mont_fixed_top(&tmp, &tmp, &am, mont, ctx))
                 goto err;
         }
     }
 
-    /* Convert the final result from montgomery to standard format */
+    /*
+     * Done with zero-padded intermediate BIGNUMs. Final BN_from_montgomery
+     * removes padding [if any] and makes return value suitable for public
+     * API consumer.
+     */
 #if defined(SPARC_T4_MONT)
     if (OPENSSL_sparcv9cap_P[0] & (SPARCV9_VIS3 | SPARCV9_PREFER_FPU)) {
         am.d[0] = 1;            /* borrow am */
@@ -1153,7 +1152,7 @@ int BN_mod_exp_mont_word(BIGNUM *rr, BN_ULONG a, const BIGNUM *p,
     if (BN_get_flags(p, BN_FLG_CONSTTIME) != 0
             || BN_get_flags(m, BN_FLG_CONSTTIME) != 0) {
         /* BN_FLG_CONSTTIME only supported by BN_mod_exp_mont() */
-        BNerr(BN_F_BN_MOD_EXP_MONT_WORD, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        ERR_raise(ERR_LIB_BN, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
         return 0;
     }
 
@@ -1161,7 +1160,7 @@ int BN_mod_exp_mont_word(BIGNUM *rr, BN_ULONG a, const BIGNUM *p,
     bn_check_top(m);
 
     if (!BN_is_odd(m)) {
-        BNerr(BN_F_BN_MOD_EXP_MONT_WORD, BN_R_CALLED_WITH_EVEN_MODULUS);
+        ERR_raise(ERR_LIB_BN, BN_R_CALLED_WITH_EVEN_MODULUS);
         return 0;
     }
     if (m->top == 1)
@@ -1285,7 +1284,7 @@ int BN_mod_exp_simple(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
             || BN_get_flags(a, BN_FLG_CONSTTIME) != 0
             || BN_get_flags(m, BN_FLG_CONSTTIME) != 0) {
         /* BN_FLG_CONSTTIME only supported by BN_mod_exp_mont() */
-        BNerr(BN_F_BN_MOD_EXP_SIMPLE, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        ERR_raise(ERR_LIB_BN, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
         return 0;
     }
 
@@ -1352,7 +1351,6 @@ int BN_mod_exp_simple(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
          * a window to do.  To do this we need to scan forward until the last
          * set bit before the end of the window
          */
-        j = wstart;
         wvalue = 1;
         wend = 0;
         for (i = 1; i < window; i++) {
@@ -1391,3 +1389,85 @@ int BN_mod_exp_simple(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
     bn_check_top(r);
     return ret;
 }
+
+/*
+ * This is a variant of modular exponentiation optimization that does
+ * parallel 2-primes exponentiation using 256-bit (AVX512VL) AVX512_IFMA ISA
+ * in 52-bit binary redundant representation.
+ * If such instructions are not available, or input data size is not supported,
+ * it falls back to two BN_mod_exp_mont_consttime() calls.
+ */
+int BN_mod_exp_mont_consttime_x2(BIGNUM *rr1, const BIGNUM *a1, const BIGNUM *p1,
+                                 const BIGNUM *m1, BN_MONT_CTX *in_mont1,
+                                 BIGNUM *rr2, const BIGNUM *a2, const BIGNUM *p2,
+                                 const BIGNUM *m2, BN_MONT_CTX *in_mont2,
+                                 BN_CTX *ctx)
+{
+    int ret = 0;
+
+#ifdef RSAZ_ENABLED
+    BN_MONT_CTX *mont1 = NULL;
+    BN_MONT_CTX *mont2 = NULL;
+
+    if (ossl_rsaz_avx512ifma_eligible() &&
+        ((a1->top == 16) && (p1->top == 16) && (BN_num_bits(m1) == 1024) &&
+         (a2->top == 16) && (p2->top == 16) && (BN_num_bits(m2) == 1024))) {
+
+        if (bn_wexpand(rr1, 16) == NULL)
+            goto err;
+        if (bn_wexpand(rr2, 16) == NULL)
+            goto err;
+
+        /*  Ensure that montgomery contexts are initialized */
+        if (in_mont1 != NULL) {
+            mont1 = in_mont1;
+        } else {
+            if ((mont1 = BN_MONT_CTX_new()) == NULL)
+                goto err;
+            if (!BN_MONT_CTX_set(mont1, m1, ctx))
+                goto err;
+        }
+        if (in_mont2 != NULL) {
+            mont2 = in_mont2;
+        } else {
+            if ((mont2 = BN_MONT_CTX_new()) == NULL)
+                goto err;
+            if (!BN_MONT_CTX_set(mont2, m2, ctx))
+                goto err;
+        }
+
+        ret = ossl_rsaz_mod_exp_avx512_x2(rr1->d, a1->d, p1->d, m1->d,
+                                          mont1->RR.d, mont1->n0[0],
+                                          rr2->d, a2->d, p2->d, m2->d,
+                                          mont2->RR.d, mont2->n0[0],
+                                          1024 /* factor bit size */);
+
+        rr1->top = 16;
+        rr1->neg = 0;
+        bn_correct_top(rr1);
+        bn_check_top(rr1);
+
+        rr2->top = 16;
+        rr2->neg = 0;
+        bn_correct_top(rr2);
+        bn_check_top(rr2);
+
+        goto err;
+    }
+#endif
+
+    /* rr1 = a1^p1 mod m1 */
+    ret = BN_mod_exp_mont_consttime(rr1, a1, p1, m1, ctx, in_mont1);
+    /* rr2 = a2^p2 mod m2 */
+    ret &= BN_mod_exp_mont_consttime(rr2, a2, p2, m2, ctx, in_mont2);
+
+#ifdef RSAZ_ENABLED
+err:
+    if (in_mont2 == NULL)
+        BN_MONT_CTX_free(mont2);
+    if (in_mont1 == NULL)
+        BN_MONT_CTX_free(mont1);
+#endif
+
+    return ret;
+}