Reduce inputs before the RSAZ code.
[openssl.git] / crypto / bn / bn_exp.c
index 83b0e5a..9ea120b 100644 (file)
@@ -648,34 +648,41 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
             goto err;
     }
 
+    if (a->neg || BN_ucmp(a, m) >= 0) {
+        BIGNUM *reduced = BN_CTX_get(ctx);
+        if (reduced == NULL
+            || !BN_nnmod(reduced, a, m, ctx)) {
+            goto err;
+        }
+        a = reduced;
+    }
+
 #ifdef RSAZ_ENABLED
-    if (!a->neg) {
-        /*
-         * If the size of the operands allow it, perform the optimized
-         * RSAZ exponentiation. For further information see
-         * crypto/bn/rsaz_exp.c and accompanying assembly modules.
-         */
-        if ((16 == a->top) && (16 == p->top) && (BN_num_bits(m) == 1024)
-            && rsaz_avx2_eligible()) {
-            if (NULL == bn_wexpand(rr, 16))
-                goto err;
-            RSAZ_1024_mod_exp_avx2(rr->d, a->d, p->d, m->d, mont->RR.d,
-                                   mont->n0[0]);
-            rr->top = 16;
-            rr->neg = 0;
-            bn_correct_top(rr);
-            ret = 1;
+    /*
+     * If the size of the operands allow it, perform the optimized
+     * RSAZ exponentiation. For further information see
+     * crypto/bn/rsaz_exp.c and accompanying assembly modules.
+     */
+    if ((16 == a->top) && (16 == p->top) && (BN_num_bits(m) == 1024)
+        && rsaz_avx2_eligible()) {
+        if (NULL == bn_wexpand(rr, 16))
             goto err;
-        } else if ((8 == a->top) && (8 == p->top) && (BN_num_bits(m) == 512)) {
-            if (NULL == bn_wexpand(rr, 8))
-                goto err;
-            RSAZ_512_mod_exp(rr->d, a->d, p->d, m->d, mont->n0[0], mont->RR.d);
-            rr->top = 8;
-            rr->neg = 0;
-            bn_correct_top(rr);
-            ret = 1;
+        RSAZ_1024_mod_exp_avx2(rr->d, a->d, p->d, m->d, mont->RR.d,
+                               mont->n0[0]);
+        rr->top = 16;
+        rr->neg = 0;
+        bn_correct_top(rr);
+        ret = 1;
+        goto err;
+    } else if ((8 == a->top) && (8 == p->top) && (BN_num_bits(m) == 512)) {
+        if (NULL == bn_wexpand(rr, 8))
             goto err;
-        }
+        RSAZ_512_mod_exp(rr->d, a->d, p->d, m->d, mont->n0[0], mont->RR.d);
+        rr->top = 8;
+        rr->neg = 0;
+        bn_correct_top(rr);
+        ret = 1;
+        goto err;
     }
 #endif
 
@@ -747,12 +754,7 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
         goto err;
 
     /* prepare a^1 in Montgomery domain */
-    if (a->neg || BN_ucmp(a, m) >= 0) {
-        if (!BN_nnmod(&am, a, m, ctx))
-            goto err;
-        if (!bn_to_mont_fixed_top(&am, &am, mont, ctx))
-            goto err;
-    } else if (!bn_to_mont_fixed_top(&am, a, mont, ctx))
+    if (!bn_to_mont_fixed_top(&am, a, mont, ctx))
         goto err;
 
 #if defined(SPARC_T4_MONT)