bn_exp.c: further optimizations using more ideas from
authorAndy Polyakov <appro@openssl.org>
Mon, 17 Oct 2011 17:41:49 +0000 (17:41 +0000)
committerAndy Polyakov <appro@openssl.org>
Mon, 17 Oct 2011 17:41:49 +0000 (17:41 +0000)
http://eprint.iacr.org/2011/239.

crypto/bn/asm/x86_64-mont5.pl
crypto/bn/bn_exp.c

index d6e4c63..057cda2 100755 (executable)
@@ -607,7 +607,6 @@ $code.=<<___;
        add     $A[1],$N[1]             # np[j]*m1+ap[j]*bp[i]+tp[j]
        lea     4($j),$j                # j+=2
        adc     \$0,%rdx
-       mov     $N[1],(%rsp)            # tp[j-1]
        mov     %rdx,$N[0]
        jmp     .Linner4x
 .align 16
@@ -626,7 +625,7 @@ $code.=<<___;
        adc     \$0,%rdx
        add     $A[0],$N[0]
        adc     \$0,%rdx
-       mov     $N[0],-24(%rsp,$j,8)    # tp[j-1]
+       mov     $N[1],-32(%rsp,$j,8)    # tp[j-1]
        mov     %rdx,$N[1]
 
        mulq    $m0                     # ap[j]*bp[i]
@@ -643,7 +642,7 @@ $code.=<<___;
        adc     \$0,%rdx
        add     $A[1],$N[1]
        adc     \$0,%rdx
-       mov     $N[1],-16(%rsp,$j,8)    # tp[j-1]
+       mov     $N[0],-24(%rsp,$j,8)    # tp[j-1]
        mov     %rdx,$N[0]
 
        mulq    $m0                     # ap[j]*bp[i]
@@ -660,7 +659,7 @@ $code.=<<___;
        adc     \$0,%rdx
        add     $A[0],$N[0]
        adc     \$0,%rdx
-       mov     $N[0],-8(%rsp,$j,8)     # tp[j-1]
+       mov     $N[1],-16(%rsp,$j,8)    # tp[j-1]
        mov     %rdx,$N[1]
 
        mulq    $m0                     # ap[j]*bp[i]
@@ -678,7 +677,7 @@ $code.=<<___;
        adc     \$0,%rdx
        add     $A[1],$N[1]
        adc     \$0,%rdx
-       mov     $N[1],-32(%rsp,$j,8)    # tp[j-1]
+       mov     $N[0],-40(%rsp,$j,8)    # tp[j-1]
        mov     %rdx,$N[0]
        cmp     $num,$j
        jl      .Linner4x
@@ -697,7 +696,7 @@ $code.=<<___;
        adc     \$0,%rdx
        add     $A[0],$N[0]
        adc     \$0,%rdx
-       mov     $N[0],-24(%rsp,$j,8)    # tp[j-1]
+       mov     $N[1],-32(%rsp,$j,8)    # tp[j-1]
        mov     %rdx,$N[1]
 
        mulq    $m0                     # ap[j]*bp[i]
@@ -715,10 +714,11 @@ $code.=<<___;
        adc     \$0,%rdx
        add     $A[1],$N[1]
        adc     \$0,%rdx
-       mov     $N[1],-16(%rsp,$j,8)    # tp[j-1]
+       mov     $N[0],-24(%rsp,$j,8)    # tp[j-1]
        mov     %rdx,$N[0]
 
        movq    %xmm0,$m0               # bp[i+1]
+       mov     $N[1],-16(%rsp,$j,8)    # tp[j-1]
 
        xor     $N[1],$N[1]
        add     $A[0],$N[0]
@@ -831,6 +831,10 @@ ___
 {
 my ($inp,$num,$tbl,$idx)=$win64?("%rcx","%rdx","%r8", "%r9") : # Win64 order
                                ("%rdi","%rsi","%rdx","%rcx"); # Unix order
+my $out=$inp;
+my $STRIDE=2**5*8;
+my $N=$STRIDE/4;
+
 $code.=<<___;
 .globl bn_scatter5
 .type  bn_scatter5,\@abi-omnipotent
@@ -849,6 +853,61 @@ bn_scatter5:
 .Lscatter_epilogue:
        ret
 .size  bn_scatter5,.-bn_scatter5
+
+.globl bn_gather5
+.type  bn_gather5,\@abi-omnipotent
+.align 16
+bn_gather5:
+___
+$code.=<<___ if ($win64);
+.LSEH_begin_bn_gather5:
+       # I can't trust assembler to use specific encoding:-(
+       .byte   0x48,0x83,0xec,0x28             #sub    \$0x28,%rsp
+       .byte   0x0f,0x29,0x34,0x24             #movaps %xmm6,(%rsp)
+       .byte   0x0f,0x29,0x7c,0x24,0x10        #movdqa %xmm7,0x10(%rsp)
+___
+$code.=<<___;
+       mov     $idx,%r11
+       shr     \$`log($N/8)/log(2)`,$idx
+       and     \$`$N/8-1`,%r11
+       not     $idx
+       lea     .Lmagic_masks(%rip),%rax
+       and     \$`2**5/($N/8)-1`,$idx  # 5 is "window size"
+       lea     96($tbl,%r11,8),$tbl    # pointer within 1st cache line
+       movq    0(%rax,$idx,8),%xmm4    # set of masks denoting which
+       movq    8(%rax,$idx,8),%xmm5    # cache line contains element
+       movq    16(%rax,$idx,8),%xmm6   # denoted by 7th argument
+       movq    24(%rax,$idx,8),%xmm7
+       jmp     .Lgather
+.align 16
+.Lgather:
+       movq    `0*$STRIDE/4-96`($tbl),%xmm0
+       movq    `1*$STRIDE/4-96`($tbl),%xmm1
+       pand    %xmm4,%xmm0
+       movq    `2*$STRIDE/4-96`($tbl),%xmm2
+       pand    %xmm5,%xmm1
+       movq    `3*$STRIDE/4-96`($tbl),%xmm3
+       pand    %xmm6,%xmm2
+       por     %xmm1,%xmm0
+       pand    %xmm7,%xmm3
+       por     %xmm2,%xmm0
+       lea     $STRIDE($tbl),$tbl
+       por     %xmm3,%xmm0
+
+       movq    %xmm0,($out)            # m0=bp[0]
+       lea     8($out),$out
+       sub     \$1,$num
+       jnz     .Lgather
+___
+$code.=<<___ if ($win64);
+       movaps  %xmm6,(%rsp)
+       movaps  %xmm7,0x10(%rsp)
+       lea     0x28(%rsp),%rsp
+___
+$code.=<<___;
+       ret
+.LSEH_end_bn_gather5:
+.size  bn_gather5,.-bn_gather5
 ___
 }
 $code.=<<___;
@@ -980,6 +1039,10 @@ mul_handler:
        .rva    .LSEH_end_bn_mul4x_mont_gather5
        .rva    .LSEH_info_bn_mul4x_mont_gather5
 
+       .rva    .LSEH_begin_bn_gather5
+       .rva    .LSEH_end_bn_gather5
+       .rva    .LSEH_info_bn_gather5
+
 .section       .xdata
 .align 8
 .LSEH_info_bn_mul_mont_gather5:
@@ -992,6 +1055,12 @@ mul_handler:
        .rva    mul_handler
        .rva    .Lmul4x_alloca,.Lmul4x_body,.Lmul4x_epilogue    # HandlerData[]
 .align 8
+.LSEH_info_bn_gather5:
+        .byte   0x01,0x0d,0x05,0x00
+        .byte   0x0d,0x78,0x01,0x00    #movaps 0x10(rsp),xmm7
+        .byte   0x08,0x68,0x00,0x00    #movaps (rsp),xmm6
+        .byte   0x04,0x42,0x00,0x00    #sub    rsp,0x28
+.align 8
 ___
 }
 
index c69cd2c..5c49236 100644 (file)
@@ -535,23 +535,17 @@ err:
  * as cache lines are concerned.  The following functions are used to transfer a BIGNUM
  * from/to that table. */
 
-static int MOD_EXP_CTIME_COPY_TO_PREBUF(BIGNUM *b, int top, unsigned char *buf, int idx, int width)
+static int MOD_EXP_CTIME_COPY_TO_PREBUF(const BIGNUM *b, int top, unsigned char *buf, int idx, int width)
        {
        size_t i, j;
 
-       if (bn_wexpand(b, top) == NULL)
-               return 0;
-       while (b->top < top)
-               {
-               b->d[b->top++] = 0;
-               }
-       
+       if (top > b->top)
+               top = b->top; /* this works because 'buf' is explicitly zeroed */
        for (i = 0, j=idx; i < top * sizeof b->d[0]; i++, j+=width)
                {
                buf[j] = ((unsigned char*)b->d)[i];
                }
 
-       bn_correct_top(b);
        return 1;
        }
 
@@ -587,14 +581,13 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
        {
        int i,bits,ret=0,window,wvalue;
        int top;
-       BIGNUM *r;
        BN_MONT_CTX *mont=NULL;
 
        int numPowers;
        unsigned char *powerbufFree=NULL;
        int powerbufLen = 0;
        unsigned char *powerbuf=NULL;
-       BIGNUM computeTemp, *am=NULL;
+       BIGNUM tmp, am;
 
        bn_check_top(a);
        bn_check_top(p);
@@ -614,10 +607,7 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
                return ret;
                }
 
-       /* Initialize BIGNUM context and allocate intermediate result */
        BN_CTX_start(ctx);
-       r = BN_CTX_get(ctx);
-       if (r == NULL) goto err;
 
        /* Allocate a montgomery context if it was not supplied by the caller.
         * If this is not done, things will break in the montgomery part.
@@ -635,25 +625,13 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
 #if defined(OPENSSL_BN_ASM_MONT5)
        if (window==6 && bits<=1024) window=5;  /* ~5% improvement of 2048-bit RSA sign */
 #endif
-       /* Adjust the number of bits up to a multiple of the window size.
-        * If the exponent length is not a multiple of the window size, then
-        * this pads the most significant bits with zeros to normalize the
-        * scanning loop to there's no special cases.
-        *
-        * * NOTE: Making the window size a power of two less than the native
-        * * word size ensures that the padded bits won't go past the last
-        * * word in the internal BIGNUM structure. Going past the end will
-        * * still produce the correct result, but causes a different branch
-        * * to be taken in the BN_is_bit_set function.
-        */
-       bits = ((bits+window-1)/window)*window;
 
        /* Allocate a buffer large enough to hold all of the pre-computed
-        * powers of a, plus computeTemp.
+        * powers of am, am itself and tmp.
         */
        numPowers = 1 << window;
        powerbufLen = sizeof(m->d[0])*(top*numPowers +
-                               (top>numPowers?top:numPowers));
+                               ((2*top)>numPowers?(2*top):numPowers));
 #ifdef alloca
        if (powerbufLen < 3072)
                powerbufFree = alloca(powerbufLen+MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH);
@@ -670,28 +648,31 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
                powerbufFree = NULL;
 #endif
 
-       computeTemp.d = (BN_ULONG *)(powerbuf + sizeof(m->d[0])*top*numPowers);
-       computeTemp.top = computeTemp.dmax = top;
-       computeTemp.neg = 0;
-       computeTemp.flags = BN_FLG_STATIC_DATA;
-
-       /* Initialize the intermediate result. Do this early to save double conversion,
-        * once each for a^0 and intermediate result.
-        */
-       if (!BN_to_montgomery(r,BN_value_one(),mont,ctx)) goto err;
-
-       /* Initialize computeTemp as a^1 with montgomery precalcs */
-       am = BN_CTX_get(ctx);
-       if (am==NULL) goto err;
+       /* lay down tmp and am right after powers table */
+       tmp.d     = (BN_ULONG *)(powerbuf + sizeof(m->d[0])*top*numPowers);
+       am.d      = tmp.d + top;
+       tmp.top   = am.top  = 0;
+       tmp.dmax  = am.dmax = top;
+       tmp.neg   = am.neg  = 0;
+       tmp.flags = am.flags = BN_FLG_STATIC_DATA;
+
+       /* prepare a^0 in Montgomery domain */
+#if 1
+       if (!BN_to_montgomery(&tmp,BN_value_one(),mont,ctx))    goto err;
+#else
+       tmp.d[0] = (0-m->d[0])&BN_MASK2;        /* 2^(top*BN_BITS2) - m */
+       for (i=1;i<top;i++)
+               tmp.d[i] = (~m->d[i])&BN_MASK2;
+       tmp.top = top;
+#endif
 
+       /* prepare a^1 in Montgomery domain */
        if (a->neg || BN_ucmp(a,m) >= 0)
                {
-               if (!BN_mod(am,a,m,ctx))                goto err;
-               if (!BN_to_montgomery(am,am,mont,ctx))  goto err;
+               if (!BN_mod(&am,a,m,ctx))                       goto err;
+               if (!BN_to_montgomery(&am,&am,mont,ctx))        goto err;
                }
-       else    if (!BN_to_montgomery(am,a,mont,ctx))   goto err;
-
-       if (!BN_copy(&computeTemp, am)) goto err;
+       else    if (!BN_to_montgomery(&am,a,mont,ctx))          goto err;
 
 #if defined(OPENSSL_BN_ASM_MONT5)
     /* This optimization uses ideas from http://eprint.iacr.org/2011/239,
@@ -707,95 +688,83 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
                        const BN_ULONG *n0,int num,int power);
        void bn_scatter5(const BN_ULONG *inp,size_t num,
                        void *table,size_t power);
+       void bn_gather5(BN_ULONG *out,size_t num,
+                       void *table,size_t power);
 
-       BN_ULONG *acc, *np=mont->N.d, *n0=mont->n0;
+       BN_ULONG *np=mont->N.d, *n0=mont->n0;
 
-       bn_scatter5(r->d,r->top,powerbuf,0);
-       bn_scatter5(am->d,am->top,powerbuf,1);
+       bn_scatter5(tmp.d,top,powerbuf,0);
+       bn_scatter5(am.d,am.top,powerbuf,1);
+       bn_mul_mont(tmp.d,am.d,am.d,np,n0,top);
+       bn_scatter5(tmp.d,top,powerbuf,2);
 
-       acc = computeTemp.d;
-       /* bn_mul_mont() and bn_mul_mont_gather5() assume fixed length inputs.
-        * Pad the inputs with zeroes.
-        */
-       if (bn_wexpand(am,top)==NULL || bn_wexpand(r,top)==NULL ||
-           bn_wexpand(&computeTemp,top)==NULL)
-               goto err;
-       for (i = am->top; i < top; ++i)
-               {
-               am->d[i] = 0;
-               }
-       for (i = computeTemp.top; i < top; ++i)
-               {
-               computeTemp.d[i] = 0;
-               }
-       for (i = r->top; i < top; ++i)
-               {
-               r->d[i] = 0;
-               }
 #if 0
-       for (i=2; i<32; i++)
+       for (i=3; i<32; i++)
                {
-               bn_mul_mont_gather5(acc,am->d,powerbuf,np,n0,top,i-1);
-               bn_scatter5(acc,top,powerbuf,i);
+               /* Calculate a^i = a^(i-1) * a */
+               bn_mul_mont_gather5(tmp.d,am.d,powerbuf,np,n0,top,i-1);
+               bn_scatter5(tmp.d,top,powerbuf,i);
                }
 #else
        /* same as above, but uses squaring for 1/2 of operations */
-       for (i=2; i<32; i*=2)
+       for (i=4; i<32; i*=2)
                {
-               bn_mul_mont(acc,acc,acc,np,n0,top);
-               bn_scatter5(acc,top,powerbuf,i);
+               bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
+               bn_scatter5(tmp.d,top,powerbuf,i);
                }
        for (i=3; i<8; i+=2)
                {
                int j;
-               bn_mul_mont_gather5(acc,am->d,powerbuf,np,n0,top,i-1);
-               bn_scatter5(acc,top,powerbuf,i);
+               bn_mul_mont_gather5(tmp.d,am.d,powerbuf,np,n0,top,i-1);
+               bn_scatter5(tmp.d,top,powerbuf,i);
                for (j=2*i; j<32; j*=2)
                        {
-                       bn_mul_mont(acc,acc,acc,np,n0,top);
-                       bn_scatter5(acc,top,powerbuf,j);
+                       bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
+                       bn_scatter5(tmp.d,top,powerbuf,j);
                        }
                }
        for (; i<16; i+=2)
                {
-               bn_mul_mont_gather5(acc,am->d,powerbuf,np,n0,top,i-1);
-               bn_scatter5(acc,top,powerbuf,i);
-               bn_mul_mont(acc,acc,acc,np,n0,top);
-               bn_scatter5(acc,top,powerbuf,2*i);
+               bn_mul_mont_gather5(tmp.d,am.d,powerbuf,np,n0,top,i-1);
+               bn_scatter5(tmp.d,top,powerbuf,i);
+               bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
+               bn_scatter5(tmp.d,top,powerbuf,2*i);
                }
        for (; i<32; i+=2)
                {
-               bn_mul_mont_gather5(acc,am->d,powerbuf,np,n0,top,i-1);
-               bn_scatter5(acc,top,powerbuf,i);
+               bn_mul_mont_gather5(tmp.d,am.d,powerbuf,np,n0,top,i-1);
+               bn_scatter5(tmp.d,top,powerbuf,i);
                }
 #endif
-       acc = r->d;
+       bits--;
+       for (wvalue=0, i=bits%5; i>=0; i--,bits--)
+               wvalue = (wvalue<<1)+BN_is_bit_set(p,bits);
+       bn_gather5(tmp.d,top,powerbuf,wvalue);
 
        /* Scan the exponent one window at a time starting from the most
         * significant bits.
         */
-       bits--;
        while (bits >= 0)
                {
                for (wvalue=0, i=0; i<5; i++,bits--)
                        wvalue = (wvalue<<1)+BN_is_bit_set(p,bits);
 
-               bn_mul_mont(acc,acc,acc,np,n0,top);
-               bn_mul_mont(acc,acc,acc,np,n0,top);
-               bn_mul_mont(acc,acc,acc,np,n0,top);
-               bn_mul_mont(acc,acc,acc,np,n0,top);
-               bn_mul_mont(acc,acc,acc,np,n0,top);
-               bn_mul_mont_gather5(acc,acc,powerbuf,np,n0,top,wvalue);
+               bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
+               bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
+               bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
+               bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
+               bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
+               bn_mul_mont_gather5(tmp.d,tmp.d,powerbuf,np,n0,top,wvalue);
                }
 
-       r->top=top;
-       bn_correct_top(r);
+       tmp.top=top;
+       bn_correct_top(&tmp);
        }
     else
 #endif
        {
-       if (!MOD_EXP_CTIME_COPY_TO_PREBUF(r, top, powerbuf, 0, numPowers)) goto err;
-       if (!MOD_EXP_CTIME_COPY_TO_PREBUF(am, top, powerbuf, 1, numPowers)) goto err;
+       if (!MOD_EXP_CTIME_COPY_TO_PREBUF(&tmp, top, powerbuf, 0, numPowers)) goto err;
+       if (!MOD_EXP_CTIME_COPY_TO_PREBUF(&am,  top, powerbuf, 1, numPowers)) goto err;
 
        /* If the window size is greater than 1, then calculate
         * val[i=2..2^winsize-1]. Powers are computed as a*a^(i-1)
@@ -804,19 +773,25 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
         */
        if (window > 1)
                {
-               for (i=2; i<numPowers; i++)
+               if (!BN_mod_mul_montgomery(&tmp,&am,&am,mont,ctx))      goto err;
+               if (!MOD_EXP_CTIME_COPY_TO_PREBUF(&tmp, top, powerbuf, 2, numPowers)) goto err;
+               for (i=3; i<numPowers; i++)
                        {
                        /* Calculate a^i = a^(i-1) * a */
-                       if (!BN_mod_mul_montgomery(&computeTemp,am,&computeTemp,mont,ctx))
+                       if (!BN_mod_mul_montgomery(&tmp,&am,&tmp,mont,ctx))
                                goto err;
-                       if (!MOD_EXP_CTIME_COPY_TO_PREBUF(&computeTemp, top, powerbuf, i, numPowers)) goto err;
+                       if (!MOD_EXP_CTIME_COPY_TO_PREBUF(&tmp, top, powerbuf, i, numPowers)) goto err;
                        }
                }
 
-       /* Scan the exponent one window at a time starting from the most
-        * significant bits.
-        */
        bits--;
+       for (wvalue=0, i=bits%window; i>=0; i--,bits--)
+               wvalue = (wvalue<<1)+BN_is_bit_set(p,bits);
+       if (!MOD_EXP_CTIME_COPY_FROM_PREBUF(&tmp,top,powerbuf,wvalue,numPowers)) goto err;
+       /* Scan the exponent one window at a time starting from the most
+        * significant bits.
+        */
        while (bits >= 0)
                {
                wvalue=0; /* The 'value' of the window */
@@ -824,20 +799,20 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
                /* Scan the window, squaring the result as we go */
                for (i=0; i<window; i++,bits--)
                        {
-                       if (!BN_mod_mul_montgomery(r,r,r,mont,ctx))     goto err;
+                       if (!BN_mod_mul_montgomery(&tmp,&tmp,&tmp,mont,ctx))    goto err;
                        wvalue = (wvalue<<1)+BN_is_bit_set(p,bits);
                        }
                
                /* Fetch the appropriate pre-computed value from the pre-buf */
-               if (!MOD_EXP_CTIME_COPY_FROM_PREBUF(&computeTemp, top, powerbuf, wvalue, numPowers)) goto err;
+               if (!MOD_EXP_CTIME_COPY_FROM_PREBUF(&am, top, powerbuf, wvalue, numPowers)) goto err;
 
                /* Multiply the result into the intermediate result */
-               if (!BN_mod_mul_montgomery(r,r,&computeTemp,mont,ctx)) goto err;
+               if (!BN_mod_mul_montgomery(&tmp,&tmp,&am,mont,ctx)) goto err;
                }
        }
 
        /* Convert the final result from montgomery to standard format */
-       if (!BN_from_montgomery(rr,r,mont,ctx)) goto err;
+       if (!BN_from_montgomery(rr,&tmp,mont,ctx)) goto err;
        ret=1;
 err:
        if ((in_mont == NULL) && (mont != NULL)) BN_MONT_CTX_free(mont);
@@ -846,7 +821,6 @@ err:
                OPENSSL_cleanse(powerbuf,powerbufLen);
                if (powerbufFree) OPENSSL_free(powerbufFree);
                }
-       if (am!=NULL) BN_clear(am);
        BN_CTX_end(ctx);
        return(ret);
        }