Handle special cases correctly in exponentation functions.
[openssl.git] / crypto / bn / bn_exp.c
index 51c8282593ee366f886b471fa20665ac8c6207b2..f7e7ced2ca0840973aa205f6323b7eccf4ee9f49 100644 (file)
@@ -240,11 +240,6 @@ int BN_mod_exp_recp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
                ret = BN_one(r);
                return ret;
                }
-       if (BN_is_zero(a))
-               {
-               ret = BN_zero(r);
-               return ret;
-               }
 
        BN_CTX_start(ctx);
        if ((aa = BN_CTX_get(ctx)) == NULL) goto err;
@@ -256,6 +251,11 @@ int BN_mod_exp_recp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
        ts=1;
 
        if (!BN_nnmod(&(val[0]),a,m,ctx)) goto err;             /* 1 */
+       if (BN_is_zero(&(val[0])))
+               {
+               ret = BN_zero(r);
+               goto err;
+               }
 
        window = BN_window_bits_for_exponent_size(bits);
        if (window > 1)
@@ -365,11 +365,7 @@ int BN_mod_exp_mont(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
                ret = BN_one(rr);
                return ret;
                }
-       if (BN_is_zero(a))
-               {
-               ret = BN_zero(rr);
-               return ret;
-               }
+
        BN_CTX_start(ctx);
        d = BN_CTX_get(ctx);
        r = BN_CTX_get(ctx);
@@ -396,6 +392,11 @@ int BN_mod_exp_mont(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
                }
        else
                aa=a;
+       if (BN_is_zero(aa))
+               {
+               ret = BN_zero(rr);
+               goto err;
+               }
        if (!BN_to_montgomery(&(val[0]),aa,mont,ctx)) goto err; /* 1 */
 
        window = BN_window_bits_for_exponent_size(bits);
@@ -632,11 +633,6 @@ int BN_mod_exp_simple(BIGNUM *r,
                ret = BN_one(r);
                return ret;
                }
-       if (BN_is_zero(a))
-               {
-               ret = BN_one(r);
-               return ret;
-               }
 
        BN_CTX_start(ctx);
        if ((d = BN_CTX_get(ctx)) == NULL) goto err;
@@ -644,6 +640,11 @@ int BN_mod_exp_simple(BIGNUM *r,
        BN_init(&(val[0]));
        ts=1;
        if (!BN_nnmod(&(val[0]),a,m,ctx)) goto err;             /* 1 */
+       if (BN_is_zero(&(val[0])))
+               {
+               ret = BN_one(r);
+               return ret;
+               }
 
        window = BN_window_bits_for_exponent_size(bits);
        if (window > 1)