Simplify BN_rand_range
[openssl.git] / crypto / bn / bn_rand.c
index f2c79b5e319f537ddcf3d205155ee8c5e4f43c67..54d622e6b412f5fe508633188d2f3ae2a9d6cfa2 100644 (file)
@@ -169,13 +169,54 @@ int     BN_bntest_rand(BIGNUM *rnd, int bits, int top, int bottom)
        }
 #endif
 
-/* random number r: min <= r < max */
-int    BN_rand_range(BIGNUM *r, BIGNUM *min, BIGNUM *max)
+
+/* random number r:  0 <= r < range */
+int    BN_rand_range(BIGNUM *r, BIGNUM *range)
        {
-       int n = BN_num_bits(max);
-       do
+       int n;
+
+       if (range->neg || BN_is_zero(range))
                {
-               if (!BN_rand(r, n, 0, 0)) return 0;
-               } while ((min && BN_cmp(r, min) < 0) || BN_cmp(r, max) >= 0);
+               BNerr(BN_F_BN_RAND_RANGE, BN_R_INVALID_RANGE);
+               return 0;
+               }
+
+       n = BN_num_bits(range); /* n > 0 */
+
+       if (n == 1)
+               {
+               if (!BN_zero(r)) return 0;
+               }
+       else if (BN_is_bit_set(range, n - 2))
+               {
+               do
+                       {
+                       /* range = 11..._2, so each iteration succeeds with probability >= .75 */
+                       if (!BN_rand(r, n, 0, 0)) return 0;
+                       }
+               while (BN_cmp(r, range) >= 0);
+               }
+       else
+               {
+               /* range = 10..._2,
+                * so  3*range (= 11..._2)  is exactly one bit longer than  range */
+               do
+                       {
+                       if (!BN_rand(r, n + 1, 0, 0)) return 0;
+                       /* If  r < 3*range,  use  r := r MOD range
+                        * (which is either  r, r - range,  or  r - 2*range).
+                        * Otherwise, iterate once more.
+                        * Since  3*range = 11..._2, each iteration succeeds with
+                        * probability >= .75. */
+                       if (BN_cmp(r ,range) >= 0)
+                               {
+                               if (!BN_sub(r, r, range)) return 0;
+                               if (BN_cmp(r, range) >= 0)
+                                       if (!BN_sub(r, r, range)) return 0;
+                               }
+                       }
+               while (BN_cmp(r, range) >= 0);
+               }
+
        return 1;
        }