Ensure that the addition mods[i]+delta cannot overflow in probable_prime().
[openssl.git] / crypto / bn / bn_prime.c
index 6c16029957ed7f43eac0c4783e3c68989121af4c..5bab019553bf2ecf3486e46c561c01408520be89 100644 (file)
@@ -142,6 +142,8 @@ int BN_GENCB_call(BN_GENCB *cb, int a, int b)
                {
        case 1:
                /* Deprecated-style callbacks */
+               if(!cb->cb.cb_1)
+                       return 1;
                cb->cb.cb_1(a, b, cb->arg);
                return 1;
        case 2:
@@ -157,15 +159,17 @@ int BN_GENCB_call(BN_GENCB *cb, int a, int b)
 int BN_generate_prime_ex(BIGNUM *ret, int bits, int safe,
        const BIGNUM *add, const BIGNUM *rem, BN_GENCB *cb)
        {
-       BIGNUM t;
+       BIGNUM *t;
        int found=0;
        int i,j,c1=0;
        BN_CTX *ctx;
        int checks = BN_prime_checks_for_size(bits);
 
-       BN_init(&t);
        ctx=BN_CTX_new();
        if (ctx == NULL) goto err;
+       BN_CTX_start(ctx);
+       t = BN_CTX_get(ctx);
+       if(!t) goto err;
 loop: 
        /* make a random number and set the top and bottom bits */
        if (add == NULL)
@@ -202,7 +206,7 @@ loop:
                 * check that (p-1)/2 is prime.
                 * Since a prime is odd, We just
                 * need to divide by 2 */
-               if (!BN_rshift1(&t,ret)) goto err;
+               if (!BN_rshift1(t,ret)) goto err;
 
                for (i=0; i<checks; i++)
                        {
@@ -210,7 +214,7 @@ loop:
                        if (j == -1) goto err;
                        if (j == 0) goto loop;
 
-                       j=BN_is_prime_fasttest_ex(&t,1,ctx,0,cb);
+                       j=BN_is_prime_fasttest_ex(t,1,ctx,0,cb);
                        if (j == -1) goto err;
                        if (j == 0) goto loop;
 
@@ -222,8 +226,12 @@ loop:
        /* we have a prime :-) */
        found = 1;
 err:
-       BN_free(&t);
-       if (ctx != NULL) BN_CTX_free(ctx);
+       if (ctx != NULL)
+               {
+               BN_CTX_end(ctx);
+               BN_CTX_free(ctx);
+               }
+       bn_check_top(ret);
        return found;
        }
 
@@ -250,7 +258,8 @@ int BN_is_prime_fasttest_ex(const BIGNUM *a, int checks, BN_CTX *ctx_passed,
 
        /* first look for small factors */
        if (!BN_is_odd(a))
-               return 0;
+               /* a is even => a is prime if and only if a == 2 */
+               return BN_is_word(a, 2);
        if (do_trial_division)
                {
                for (i = 1; i < NUMPRIMES; i++)
@@ -361,6 +370,7 @@ static int witness(BIGNUM *w, const BIGNUM *a, const BIGNUM *a1,
                }
        /* If we get here, 'w' is the (a-1)/2-th power of the original 'w',
         * and it is neither -1 nor +1 -- so 'a' cannot be prime */
+       bn_check_top(w);
        return 1;
        }
 
@@ -368,13 +378,14 @@ static int probable_prime(BIGNUM *rnd, int bits)
        {
        int i;
        BN_ULONG mods[NUMPRIMES];
-       BN_ULONG delta,d;
+       BN_ULONG delta,maxdelta;
 
 again:
        if (!BN_rand(rnd,bits,1,1)) return(0);
        /* we now have a random number 'rand' to test. */
        for (i=1; i<NUMPRIMES; i++)
                mods[i]=BN_mod_word(rnd,(BN_ULONG)primes[i]);
+       maxdelta=BN_MASK2 - primes[NUMPRIMES-1];
        delta=0;
        loop: for (i=1; i<NUMPRIMES; i++)
                {
@@ -382,16 +393,13 @@ again:
                 * that gcd(rnd-1,primes) == 1 (except for 2) */
                if (((mods[i]+delta)%primes[i]) <= 1)
                        {
-                       d=delta;
                        delta+=2;
-                       /* perhaps need to check for overflow of
-                        * delta (but delta can be up to 2^32)
-                        * 21-May-98 eay - added overflow check */
-                       if (delta < d) goto again;
+                       if (delta > maxdelta) goto again;
                        goto loop;
                        }
                }
        if (!BN_add_word(rnd,delta)) return(0);
+       bn_check_top(rnd);
        return(1);
        }
 
@@ -429,6 +437,7 @@ static int probable_prime_dh(BIGNUM *rnd, int bits,
        ret=1;
 err:
        BN_CTX_end(ctx);
+       bn_check_top(rnd);
        return(ret);
        }
 
@@ -480,5 +489,6 @@ static int probable_prime_dh_safe(BIGNUM *p, int bits, const BIGNUM *padd,
        ret=1;
 err:
        BN_CTX_end(ctx);
+       bn_check_top(p);
        return(ret);
        }