Improve BN_mod_inverse performance.
authorBodo Möller <bodo@openssl.org>
Wed, 29 Nov 2000 09:41:19 +0000 (09:41 +0000)
committerBodo Möller <bodo@openssl.org>
Wed, 29 Nov 2000 09:41:19 +0000 (09:41 +0000)
Get the BN_mod_exp_mont bugfix (for handling negative inputs) correct
this time.

CHANGES
crypto/bn/bn_exp.c
crypto/bn/bn_gcd.c
crypto/bn/expspeed.c

diff --git a/CHANGES b/CHANGES
index 998ea1639cf9628384dd3166621572fe99067ee1..7a5bac81d0ce94852349cb2547a92337b48ec275 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -3,6 +3,12 @@
 
  Changes between 0.9.6 and 0.9.7  [xx XXX 2000]
 
+  *) Make BN_mod_inverse faster by explicitly handling small quotients
+     in the Euclid loop instead of always using BN_div.
+     (Speed gain about 20% for small moduli [256 or 512 bits], about
+     30% for larger ones [1024 or 2048 bits].)
+     [Bodo Moeller]
+
   *) Disable ssl2_peek and ssl3_peek (i.e., both implementations
      of SSL_peek) because they both are completely broken.
      They will be fixed RSN by adding an additional 'peek' parameter
index eab394b96231051d0dc5070dd679142f3ceca8f9..35ab56efc04d6b1ab90cb1fbdb5e0e910dfac7e4 100644 (file)
@@ -376,7 +376,7 @@ int BN_mod_exp_mont(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
 
        BN_init(&val[0]);
        ts=1;
-       if (!a->neg && BN_ucmp(a,m) >= 0)
+       if (a->neg || BN_ucmp(a,m) >= 0)
                {
                if (!BN_nnmod(&(val[0]),a,m,ctx))
                        goto err;
index ea6816a43fc32a6deb82a68f3fb552e3bf19fb0c..d53f32656b2e84927109a1415030d734d7a30834 100644 (file)
@@ -204,7 +204,7 @@ err:
 BIGNUM *BN_mod_inverse(BIGNUM *in,
        const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx)
        {
-       BIGNUM *A,*B,*X,*Y,*M,*D,*R=NULL;
+       BIGNUM *A,*B,*X,*Y,*M,*D,*T,*R=NULL;
        BIGNUM *ret=NULL;
        int sign;
 
@@ -218,7 +218,8 @@ BIGNUM *BN_mod_inverse(BIGNUM *in,
        D = BN_CTX_get(ctx);
        M = BN_CTX_get(ctx);
        Y = BN_CTX_get(ctx);
-       if (Y == NULL) goto err;
+       T = BN_CTX_get(ctx);
+       if (T == NULL) goto err;
 
        if (in == NULL)
                R=BN_new();
@@ -253,7 +254,47 @@ BIGNUM *BN_mod_inverse(BIGNUM *in,
                 *     -sign*Y*a  ==  A   (mod |n|)
                 */
 
-               if (!BN_div(D,M,A,B,ctx)) goto err;
+               /* (D, M) := (A/B, A%B) ... */
+               if (BN_num_bits(A) == BN_num_bits(B))
+                       {
+                       if (!BN_one(D)) goto err;
+                       if (!BN_sub(M,A,B)) goto err;
+                       }
+               else if (BN_num_bits(A) == BN_num_bits(B) + 1)
+                       {
+                       /* A/B is 1, 2, or 3 */
+                       if (!BN_lshift1(T,B)) goto err;
+                       if (BN_ucmp(A,T) < 0)
+                               {
+                               /* A < 2*B, so D=1 */
+                               if (!BN_one(D)) goto err;
+                               if (!BN_sub(M,A,B)) goto err;
+                               }
+                       else
+                               {
+                               /* A >= 2*B, so D=2 or D=3 */
+                               if (!BN_sub(M,A,T)) goto err;
+                               if (!BN_add(D,T,B)) goto err; /* use D (:= 3*B) as temp */
+                               if (BN_ucmp(A,D) < 0)
+                                       {
+                                       /* A < 3*B, so D=2 */
+                                       if (!BN_set_word(D,2)) goto err;
+                                       /* M (= A - 2*B) already has the correct value */
+                                       }
+                               else
+                                       {
+                                       /* only D=3 remains */
+                                       if (!BN_set_word(D,3)) goto err;
+                                       /* currently  M = A - 2*B,  but we need  M = A - 3*B */
+                                       if (!BN_sub(M,M,B)) goto err;
+                                       }
+                               }
+                       }
+               else
+                       {
+                       if (!BN_div(D,M,A,B,ctx)) goto err;
+                       }
+               
                /* Now
                 *      A = D*B + M;
                 * thus we have
@@ -286,8 +327,33 @@ BIGNUM *BN_mod_inverse(BIGNUM *in,
                 * Note that  X  and  Y  stay non-negative all the time.
                 */
 
-               if (!BN_mul(tmp,D,X,ctx)) goto err;
-               if (!BN_add(tmp,tmp,Y)) goto err;
+               /* most of the time D is very small, so we can optimize tmp := D*X+Y */
+               if (BN_is_one(D))
+                       {
+                       if (!BN_add(tmp,X,Y)) goto err;
+                       }
+               else
+                       {
+                       if (BN_is_word(D,2))
+                               {
+                               if (!BN_lshift1(tmp,X)) goto err;
+                               }
+                       else if (BN_is_word(D,3))
+                               {
+                               if (!BN_lshift1(tmp,X)) goto err;
+                               if (!BN_add(tmp,tmp,X)) goto err;
+                               }
+                       else if (BN_is_word(D,4))
+                               {
+                               if (!BN_lshift(tmp,X,2)) goto err;
+                               }
+                       else
+                               {
+                               if (!BN_mul(tmp,D,X,ctx)) goto err;
+                               }
+                       if (!BN_add(tmp,tmp,Y)) goto err;
+                       }
+               
                M=Y; /* keep the BIGNUM object, the value does not matter */
                Y=X;
                X=tmp;
@@ -312,7 +378,10 @@ BIGNUM *BN_mod_inverse(BIGNUM *in,
        if (BN_is_one(A))
                {
                /* Y*a == 1  (mod |n|) */
-               if (!BN_mod(R,Y,n,ctx)) goto err;
+               if (BN_ucmp(Y,n) < 0)
+                       if (!BN_copy(R,Y)) goto err;
+               else
+                       if (!BN_nnmod(R,Y,n,ctx)) goto err;
                }
        else
                {
index 99cf2c52a856a99d77ec33e0feb347aabae161df..5f76aa41260bf351960f3fc9eb52a6f45822ae0d 100644 (file)
 
 /* most of this code has been pilfered from my libdes speed.c program */
 
-#define BASENUM        5000
+#define BASENUM        10000
+#define NUM_START 0
+
+
+/* determine timings for modexp, gcd, or modular inverse */
+#define TEST_EXP
+#undef TEST_GCD
+#undef TEST_INV
+
+
 #undef PROG
 #define PROG bnspeed_main
 
@@ -161,11 +170,30 @@ static double Time_F(int s)
 #endif
        }
 
-#define NUM_SIZES      6
-static int sizes[NUM_SIZES]={256,512,1024,2048,4096,8192};
-static int mul_c[NUM_SIZES]={8*8*8*8*8,8*8*8*8,8*8*8,8*8,8,1};
+#define NUM_SIZES      7
+#if NUM_START > NUM_SIZES
+#   error "NUM_START > NUM_SIZES"
+#endif
+static int sizes[NUM_SIZES]={128,256,512,1024,2048,4096,8192};
+static int mul_c[NUM_SIZES]={8*8*8*8*8*8,8*8*8*8*8,8*8*8*8,8*8*8,8*8,8,1};
 /*static int sizes[NUM_SIZES]={59,179,299,419,539}; */
 
+#define RAND_SEED(string) { const char str[] = string; RAND_seed(string, sizeof string); }
+
+static void genprime_cb(int p, int n, void *arg)
+       {
+       char c='*';
+
+       if (p == 0) c='.';
+       if (p == 1) c='+';
+       if (p == 2) c='*';
+       if (p == 3) c='\n';
+       putc(c, stderr);
+       fflush(stderr);
+       (void)n;
+       (void)arg;
+       }
+
 void do_mul_exp(BIGNUM *r,BIGNUM *a,BIGNUM *b,BIGNUM *c,BN_CTX *ctx); 
 
 int main(int argc, char **argv)
@@ -179,6 +207,10 @@ int main(int argc, char **argv)
        c=BN_new();
        r=BN_new();
 
+       while (!RAND_status())
+               /* not enough bits */
+               RAND_SEED("I demand a manual recount!");
+
        do_mul_exp(r,a,b,c,ctx);
        }
 
@@ -188,23 +220,61 @@ void do_mul_exp(BIGNUM *r, BIGNUM *a, BIGNUM *b, BIGNUM *c, BN_CTX *ctx)
        double tm;
        long num;
 
+#if defined(TEST_EXP) + defined(TEST_GCD) + defined(TEST_INV) != 1
+#  error "choose one test"
+#endif
+
+#ifdef TEST_INV
+#  define C_PRIME
+#endif
+
        num=BASENUM;
-       for (i=0; i<NUM_SIZES; i++)
+       for (i=NUM_START; i<NUM_SIZES; i++)
                {
-               BN_pseudo_rand(a,sizes[i],1,0);
-               BN_pseudo_rand(b,sizes[i],1,0);
-               BN_pseudo_rand(c,sizes[i],1,1);
-               BN_mod(a,a,c,ctx);
-               BN_mod(b,b,c,ctx);
+#ifdef C_PRIME
+               if (!BN_generate_prime(c,sizes[i],0,NULL,NULL,genprime_cb,NULL)) goto err;
+               putc('\n', stderr);
+               fflush(stderr);
+#endif
 
                Time_F(START);
                for (k=0; k<num; k++)
-                       BN_mod_exp(r,a,b,c,ctx);
+                       {
+                       if (k%50 == 0) /* Average over num/50 different choices of random numbers. */
+                               {
+                               if (!BN_pseudo_rand(a,sizes[i],1,0)) goto err;
+                               if (!BN_pseudo_rand(b,sizes[i],1,0)) goto err;
+#ifndef C_PRIME
+                               if (!BN_pseudo_rand(c,sizes[i],1,1)) goto err;
+#endif
+                               }
+#if defined(TEST_EXP)
+                       if (!BN_mod_exp(r,a,b,c,ctx)) goto err;
+#elif defined(TEST_GCD)
+                       if (!BN_gcd(r,a,b,ctx)) goto err;
+                       if (!BN_gcd(r,b,c,ctx)) goto err;
+                       if (!BN_gcd(r,b,c,ctx)) goto err;
+#else /* TEST_INV */
+                       if (!BN_mod_inverse(r,a,c,ctx)) goto err;
+                       if (!BN_mod_inverse(r,b,c,ctx)) goto err;
+#endif
+                       }
                tm=Time_F(STOP);
-               printf("mul %4d ^ %4d %% %d -> %8.3fms %5.1f\n",sizes[i],sizes[i],sizes[i],tm*1000.0/num,tm*mul_c[i]/num);
+               printf(
+#if defined(TEST_EXP)
+                       "modexp %4d ^ %4d %% %4d"
+#elif defined(TEST_GCD)
+                       "3*gcd %4d %4d %4d"
+#else /* TEST_INV */
+                       "2*inv %4d %4d mod %4d"
+#endif
+                       " -> %8.3fms %5.1f (%d)\n",sizes[i],sizes[i],sizes[i],tm*1000.0/num,tm*mul_c[i]/num, num);
                num/=7;
                if (num <= 0) num=1;
                }
+       return;
 
+ err:
+       ERR_print_errors_fp(stderr);
        }