bn_exp.c: improve portability.
[openssl.git] / crypto / bn / bn_exp.c
index ce31ad0a58a61f74f33fbe12118fedb188ecea3c..aae00491db21202ef1e484fa98837f8a1b7914fd 100644 (file)
 #include "cryptlib.h"
 #include "bn_lcl.h"
 
+#include <stdlib.h>
+#ifdef _WIN32
+# include <malloc.h>
+# ifndef alloca
+#  define alloca _alloca
+# endif
+#elif defined(__GNUC__)
+# ifndef alloca
+#  define alloca(s) __builtin_alloca((s))
+# endif
+#endif
+
 /* maximum precomputation table size for *variable* sliding windows */
 #define TABLE_SIZE     32
 
@@ -562,7 +574,7 @@ static int MOD_EXP_CTIME_COPY_FROM_PREBUF(BIGNUM *b, int top, unsigned char *buf
 
 /* Given a pointer value, compute the next address that is a cache line multiple. */
 #define MOD_EXP_CTIME_ALIGN(x_) \
-       ((unsigned char*)(x_) + (MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH - (((BN_ULONG)(x_)) & (MOD_EXP_CTIME_MIN_CACHE_LINE_MASK))))
+       ((unsigned char*)(x_) + (MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH - (((size_t)(x_)) & (MOD_EXP_CTIME_MIN_CACHE_LINE_MASK))))
 
 /* This variant of BN_mod_exp_mont() uses fixed windows and the special
  * precomputation memory layout to limit data-dependency to a minimum
@@ -573,17 +585,16 @@ static int MOD_EXP_CTIME_COPY_FROM_PREBUF(BIGNUM *b, int top, unsigned char *buf
 int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
                    const BIGNUM *m, BN_CTX *ctx, BN_MONT_CTX *in_mont)
        {
-       int i,bits,ret=0,idx,window,wvalue;
+       int i,bits,ret=0,window,wvalue;
        int top;
        BIGNUM *r;
-       const BIGNUM *aa;
        BN_MONT_CTX *mont=NULL;
 
        int numPowers;
        unsigned char *powerbufFree=NULL;
        int powerbufLen = 0;
        unsigned char *powerbuf=NULL;
-       BIGNUM *computeTemp=NULL, *am=NULL;
+       BIGNUM computeTemp, *am=NULL;
 
        bn_check_top(a);
        bn_check_top(p);
@@ -621,39 +632,154 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
 
        /* Get the window size to use with size of p. */
        window = BN_window_bits_for_ctime_exponent_size(bits);
+#if defined(OPENSSL_BN_ASM_MONT5)
+       if (window==6 && bits<=1024) window=5;  /* ~5% improvement of 2048-bit RSA sign */
+#endif
+       /* Adjust the number of bits up to a multiple of the window size.
+        * If the exponent length is not a multiple of the window size, then
+        * this pads the most significant bits with zeros to normalize the
+        * scanning loop to there's no special cases.
+        *
+        * * NOTE: Making the window size a power of two less than the native
+        * * word size ensures that the padded bits won't go past the last
+        * * word in the internal BIGNUM structure. Going past the end will
+        * * still produce the correct result, but causes a different branch
+        * * to be taken in the BN_is_bit_set function.
+        */
+       bits = ((bits+window-1)/window)*window;
 
        /* Allocate a buffer large enough to hold all of the pre-computed
-        * powers of a.
+        * powers of a, plus computeTemp.
         */
        numPowers = 1 << window;
-       powerbufLen = sizeof(m->d[0])*top*numPowers;
+       powerbufLen = sizeof(m->d[0])*(top*numPowers +
+                               (top>numPowers?top:numPowers));
+#ifdef alloca
+       if (powerbufLen < 3072)
+               powerbufFree = alloca(powerbufLen+MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH);
+       else
+#endif
        if ((powerbufFree=(unsigned char*)OPENSSL_malloc(powerbufLen+MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH)) == NULL)
                goto err;
                
        powerbuf = MOD_EXP_CTIME_ALIGN(powerbufFree);
        memset(powerbuf, 0, powerbufLen);
 
+#ifdef alloca
+       if (powerbufLen < 3072)
+               powerbufFree = NULL;
+#endif
+
+       computeTemp.d = (BN_ULONG *)(powerbuf + sizeof(m->d[0])*top*numPowers);
+       computeTemp.top = computeTemp.dmax = top;
+       computeTemp.neg = 0;
+       computeTemp.flags = BN_FLG_STATIC_DATA;
+
        /* Initialize the intermediate result. Do this early to save double conversion,
         * once each for a^0 and intermediate result.
         */
        if (!BN_to_montgomery(r,BN_value_one(),mont,ctx)) goto err;
-       if (!MOD_EXP_CTIME_COPY_TO_PREBUF(r, top, powerbuf, 0, numPowers)) goto err;
 
        /* Initialize computeTemp as a^1 with montgomery precalcs */
-       computeTemp = BN_CTX_get(ctx);
        am = BN_CTX_get(ctx);
-       if (computeTemp==NULL || am==NULL) goto err;
+       if (am==NULL) goto err;
 
        if (a->neg || BN_ucmp(a,m) >= 0)
                {
-               if (!BN_mod(am,a,m,ctx))
-                       goto err;
-               aa= am;
+               if (!BN_mod(am,a,m,ctx))                goto err;
+               if (!BN_to_montgomery(am,am,mont,ctx))  goto err;
                }
-       else
-               aa=a;
-       if (!BN_to_montgomery(am,aa,mont,ctx)) goto err;
-       if (!BN_copy(computeTemp, am)) goto err;
+       else    if (!BN_to_montgomery(am,a,mont,ctx))   goto err;
+
+       if (!BN_copy(&computeTemp, am)) goto err;
+
+       if (bn_wexpand(am,top)==NULL || bn_wexpand(r,top)==NULL)
+               goto err;
+
+#if defined(OPENSSL_BN_ASM_MONT5)
+    /* This optimization uses ideas from http://eprint.iacr.org/2011/239,
+     * specifically optimization of cache-timing attack countermeasures
+     * and pre-computation optimization. */
+
+    /* Dedicated window==4 case improves 512-bit RSA sign by ~15%, but as
+     * 512-bit RSA is hardly relevant, we omit it to spare size... */ 
+    if (window==5)
+       {
+       void bn_mul_mont_gather5(BN_ULONG *rp,const BN_ULONG *ap,
+                       const void *table,const BN_ULONG *np,
+                       const BN_ULONG *n0,int num,int power);
+       void bn_scatter5(const BN_ULONG *inp,size_t num,
+                       void *table,size_t power);
+
+       BN_ULONG *acc, *np=mont->N.d, *n0=mont->n0;
+
+       bn_scatter5(r->d,r->top,powerbuf,0);
+       bn_scatter5(am->d,am->top,powerbuf,1);
+
+       acc = computeTemp.d;
+#if 0
+       for (i=2; i<32; i++)
+               {
+               bn_mul_mont_gather5(acc,am->d,powerbuf,np,n0,top,i-1);
+               bn_scatter5(acc,top,powerbuf,i);
+               }
+#else
+       /* same as above, but uses squaring for 1/2 of operations */
+       for (i=2; i<32; i*=2)
+               {
+               bn_mul_mont(acc,acc,acc,np,n0,top);
+               bn_scatter5(acc,top,powerbuf,i);
+               }
+       for (i=3; i<8; i+=2)
+               {
+               int j;
+               bn_mul_mont_gather5(acc,am->d,powerbuf,np,n0,top,i-1);
+               bn_scatter5(acc,top,powerbuf,i);
+               for (j=2*i; j<32; j*=2)
+                       {
+                       bn_mul_mont(acc,acc,acc,np,n0,top);
+                       bn_scatter5(acc,top,powerbuf,j);
+                       }
+               }
+       for (; i<16; i+=2)
+               {
+               bn_mul_mont_gather5(acc,am->d,powerbuf,np,n0,top,i-1);
+               bn_scatter5(acc,top,powerbuf,i);
+               bn_mul_mont(acc,acc,acc,np,n0,top);
+               bn_scatter5(acc,top,powerbuf,2*i);
+               }
+       for (; i<32; i+=2)
+               {
+               bn_mul_mont_gather5(acc,am->d,powerbuf,np,n0,top,i-1);
+               bn_scatter5(acc,top,powerbuf,i);
+               }
+#endif
+       acc = r->d;
+
+       /* Scan the exponent one window at a time starting from the most
+        * significant bits.
+        */
+       bits--;
+       while (bits >= 0)
+               {
+               for (wvalue=0, i=0; i<5; i++,bits--)
+                       wvalue = (wvalue<<1)+BN_is_bit_set(p,bits);
+
+               bn_mul_mont(acc,acc,acc,np,n0,top);
+               bn_mul_mont(acc,acc,acc,np,n0,top);
+               bn_mul_mont(acc,acc,acc,np,n0,top);
+               bn_mul_mont(acc,acc,acc,np,n0,top);
+               bn_mul_mont(acc,acc,acc,np,n0,top);
+               bn_mul_mont_gather5(acc,acc,powerbuf,np,n0,top,wvalue);
+               }
+
+       r->top=top;
+       bn_correct_top(r);
+       }
+    else
+#endif
+       {
+       if (!MOD_EXP_CTIME_COPY_TO_PREBUF(r, top, powerbuf, 0, numPowers)) goto err;
        if (!MOD_EXP_CTIME_COPY_TO_PREBUF(am, top, powerbuf, 1, numPowers)) goto err;
 
        /* If the window size is greater than 1, then calculate
@@ -666,46 +792,34 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
                for (i=2; i<numPowers; i++)
                        {
                        /* Calculate a^i = a^(i-1) * a */
-                       if (!BN_mod_mul_montgomery(computeTemp,am,computeTemp,mont,ctx))
+                       if (!BN_mod_mul_montgomery(&computeTemp,am,&computeTemp,mont,ctx))
                                goto err;
-                       if (!MOD_EXP_CTIME_COPY_TO_PREBUF(computeTemp, top, powerbuf, i, numPowers)) goto err;
+                       if (!MOD_EXP_CTIME_COPY_TO_PREBUF(&computeTemp, top, powerbuf, i, numPowers)) goto err;
                        }
                }
 
-       /* Adjust the number of bits up to a multiple of the window size.
-        * If the exponent length is not a multiple of the window size, then
-        * this pads the most significant bits with zeros to normalize the
-        * scanning loop to there's no special cases.
-        *
-        * * NOTE: Making the window size a power of two less than the native
-        * * word size ensures that the padded bits won't go past the last
-        * * word in the internal BIGNUM structure. Going past the end will
-        * * still produce the correct result, but causes a different branch
-        * * to be taken in the BN_is_bit_set function.
-        */
-       bits = ((bits+window-1)/window)*window;
-       idx=bits-1;     /* The top bit of the window */
-
-       /* Scan the exponent one window at a time starting from the most
+       /* Scan the exponent one window at a time starting from the most
         * significant bits.
         */
-       while (idx >= 0)
+       bits--;
+       while (bits >= 0)
                {
                wvalue=0; /* The 'value' of the window */
                
                /* Scan the window, squaring the result as we go */
-               for (i=0; i<window; i++,idx--)
+               for (i=0; i<window; i++,bits--)
                        {
                        if (!BN_mod_mul_montgomery(r,r,r,mont,ctx))     goto err;
-                       wvalue = (wvalue<<1)+BN_is_bit_set(p,idx);
+                       wvalue = (wvalue<<1)+BN_is_bit_set(p,bits);
                        }
                
                /* Fetch the appropriate pre-computed value from the pre-buf */
-               if (!MOD_EXP_CTIME_COPY_FROM_PREBUF(computeTemp, top, powerbuf, wvalue, numPowers)) goto err;
+               if (!MOD_EXP_CTIME_COPY_FROM_PREBUF(&computeTemp, top, powerbuf, wvalue, numPowers)) goto err;
 
                /* Multiply the result into the intermediate result */
-               if (!BN_mod_mul_montgomery(r,r,computeTemp,mont,ctx)) goto err;
+               if (!BN_mod_mul_montgomery(r,r,&computeTemp,mont,ctx)) goto err;
                }
+       }
 
        /* Convert the final result from montgomery to standard format */
        if (!BN_from_montgomery(rr,r,mont,ctx)) goto err;
@@ -715,10 +829,9 @@ err:
        if (powerbuf!=NULL)
                {
                OPENSSL_cleanse(powerbuf,powerbufLen);
-               OPENSSL_free(powerbufFree);
+               if (powerbufFree) OPENSSL_free(powerbufFree);
                }
        if (am!=NULL) BN_clear(am);
-       if (computeTemp!=NULL) BN_clear(computeTemp);
        BN_CTX_end(ctx);
        return(ret);
        }