Tolerate negative numbers in BN_is_prime.
[openssl.git] / crypto / bn / bn_prime.c
index 21d49affda6435a3b5110c0d2d8973a53821a499..1b62e60b0e0c0212cb22225b97b54c217ea62fe7 100644 (file)
@@ -154,13 +154,13 @@ err:
        return(found ? rnd : NULL);
        }
 
-int BN_is_prime(BIGNUM *a, int checks, void (*callback)(int,int,void *),
+int BN_is_prime(const BIGNUM *a, int checks, void (*callback)(int,int,void *),
        BN_CTX *ctx_passed, void *cb_arg)
        {
        return BN_is_prime_fasttest(a, checks, callback, ctx_passed, cb_arg, 0);
        }
 
-int BN_is_prime_fasttest(BIGNUM *a, int checks,
+int BN_is_prime_fasttest(const BIGNUM *a, int checks,
                void (*callback)(int,int,void *),
                BN_CTX *ctx_passed, void *cb_arg,
                int do_trial_division)
@@ -168,15 +168,13 @@ int BN_is_prime_fasttest(BIGNUM *a, int checks,
        int i, j, ret = -1;
        int k;
        BN_CTX *ctx = NULL;
-       BIGNUM *a1, *a1_odd, *check; /* taken from ctx */
+       BIGNUM *A1, *A1_odd, *check; /* taken from ctx */
        BN_MONT_CTX *mont = NULL;
+       BIGNUM *A;
 
        if (checks == BN_prime_checks)
                checks = BN_prime_checks_for_size(BN_num_bits(a));
 
-       if (a->neg) /* for now, refuse to handle negative numbers */
-               return -1;
-
        /* first look for small factors */
        if (!BN_is_odd(a))
                return(0);
@@ -193,47 +191,56 @@ int BN_is_prime_fasttest(BIGNUM *a, int checks,
        else
                if ((ctx=BN_CTX_new()) == NULL)
                        goto err;
-       a1 = &(ctx->bn[ctx->tos++]);
-       a1_odd = &(ctx->bn[ctx->tos++]);
+       /* A := abs(a) */
+       if (a->neg)
+               {
+               A = &(ctx->bn[ctx->tos++]);
+               BN_copy(A, a);
+               A->neg = 0;
+               }
+       else
+               A = a;
+       A1 = &(ctx->bn[ctx->tos++]);
+       A1_odd = &(ctx->bn[ctx->tos++]);
        check = &(ctx->bn[ctx->tos++]);;
 
-       /* compute a1 := a - 1 */
-       if (!BN_copy(a1, a))
+       /* compute A1 := A - 1 */
+       if (!BN_copy(A1, A))
                goto err;
-       if (!BN_sub_word(a1, 1))
+       if (!BN_sub_word(A1, 1))
                goto err;
-       if (BN_is_zero(a1))
+       if (BN_is_zero(A1))
                {
                ret = 0;
                goto err;
                }
 
-       /* write  a1  as  a1_odd * 2^k */
+       /* write  A1  as  A1_odd * 2^k */
        k = 1;
-       while (!BN_is_bit_set(a1, k))
+       while (!BN_is_bit_set(A1, k))
                k++;
-       if (!BN_rshift(a1_odd, a1, k))
+       if (!BN_rshift(A1_odd, A1, k))
                goto err;
 
-       /* Montgomery setup for computations mod a */
+       /* Montgomery setup for computations mod A */
        mont = BN_MONT_CTX_new();
        if (mont == NULL)
                goto err;
-       if (!BN_MONT_CTX_set(mont, a, ctx))
+       if (!BN_MONT_CTX_set(mont, A, ctx))
                goto err;
        
        for (i = 0; i < checks; i++)
                {
-               if (!BN_pseudo_rand(check, BN_num_bits(a1), 0, 0))
+               if (!BN_pseudo_rand(check, BN_num_bits(A1), 0, 0))
                        goto err;
-               if (BN_cmp(check, a1) >= 0)
-                       if (!BN_sub(check, check, a1))
+               if (BN_cmp(check, A1) >= 0)
+                       if (!BN_sub(check, check, A1))
                                goto err;
                if (!BN_add_word(check, 1))
                        goto err;
-               /* now 1 <= check < a */
+               /* now 1 <= check < A */
 
-               j = witness(check, a, a1, a1_odd, k, ctx, mont);
+               j = witness(check, A, A1, A1_odd, k, ctx, mont);
                if (j == -1) goto err;
                if (j)
                        {
@@ -245,7 +252,11 @@ int BN_is_prime_fasttest(BIGNUM *a, int checks,
        ret=1;
 err:
        if (ctx_passed != NULL)
-               ctx_passed->tos -= 3; /* a1, a1_odd, check */
+               {
+               ctx_passed->tos -= 3; /* A1, A1_odd, check */
+               if (a != A)
+                       --ctx_passed->tos; /* A */
+               }
        else if (ctx != NULL)
                BN_CTX_free(ctx);
        if (mont != NULL)