Fix BN_mod_word bug
[openssl.git] / crypto / bn / bn_exp2.c
index 4f4e9e329989d5b415bba73dd14eda0a245e7f83..5141c21f6d6bfed04818f0b97ba0024a09a7b12a 100644 (file)
+/*
+ * Copyright 1995-2016 The OpenSSL Project Authors. All Rights Reserved.
+ *
+ * Licensed under the OpenSSL license (the "License").  You may not use
+ * this file except in compliance with the License.  You can obtain a copy
+ * in the file LICENSE in the source distribution or at
+ * https://www.openssl.org/source/license.html
+ */
+
 #include <stdio.h>
-#include "cryptlib.h"
+#include "internal/cryptlib.h"
 #include "bn_lcl.h"
 
-/* I've done some timing with different table sizes.
- * The main hassle is that even with bits set at 3, this requires
- * 63 BIGNUMs to store the pre-calculated values.
- *          512   1024 
- * bits=1  75.4%  79.4%
- * bits=2  61.2%  62.4%
- * bits=3  61.3%  59.3%
- * The lack of speed improvement is also a function of the pre-calculation
- * which could be removed.
- */
-#define EXP2_TABLE_BITS        2 /* 1  2  3  4  5  */
-#define EXP2_TABLE_SIZE        4 /* 2  4  8 16 32  */
-
-int BN_mod_exp2_mont(BIGNUM *rr, BIGNUM *a1, BIGNUM *p1, BIGNUM *a2,
-            BIGNUM *p2, BIGNUM *m, BN_CTX *ctx, BN_MONT_CTX *in_mont)
-       {
-       int i,j,k,bits,bits1,bits2,ret=0,wstart,wend,window,xvalue,yvalue;
-       int start=1,ts=0,x,y;
-       BIGNUM *d,*aa1,*aa2,*r;
-       BIGNUM val[EXP2_TABLE_SIZE][EXP2_TABLE_SIZE];
-       BN_MONT_CTX *mont=NULL;
-
-       bn_check_top(a1);
-       bn_check_top(p1);
-       bn_check_top(a2);
-       bn_check_top(p2);
-       bn_check_top(m);
-
-       if (!(m->d[0] & 1))
-               {
-               BNerr(BN_F_BN_MOD_EXP_MONT,BN_R_CALLED_WITH_EVEN_MODULUS);
-               return(0);
-               }
-       bits1=BN_num_bits(p1);
-       bits2=BN_num_bits(p2);
-       if ((bits1 == 0) && (bits2 == 0))
-               {
-               BN_one(rr);
-               return(1);
-               }
-
-       BN_CTX_start(ctx);
-       d = BN_CTX_get(ctx);
-       r = BN_CTX_get(ctx);
-       if (d == NULL || r == NULL) goto err;
-
-       bits=(bits1 > bits2)?bits1:bits2;
-
-       /* If this is not done, things will break in the montgomery
-        * part */
-
-       if (in_mont != NULL)
-               mont=in_mont;
-       else
-               {
-               if ((mont=BN_MONT_CTX_new()) == NULL) goto err;
-               if (!BN_MONT_CTX_set(mont,m,ctx)) goto err;
-               }
-
-       BN_init(&(val[0][0]));
-       BN_init(&(val[1][1]));
-       BN_init(&(val[0][1]));
-       BN_init(&(val[1][0]));
-       ts=1;
-       if (BN_ucmp(a1,m) >= 0)
-               {
-               BN_mod(&(val[1][0]),a1,m,ctx);
-               aa1= &(val[1][0]);
-               }
-       else
-               aa1=a1;
-       if (BN_ucmp(a2,m) >= 0)
-               {
-               BN_mod(&(val[0][1]),a2,m,ctx);
-               aa2= &(val[0][1]);
-               }
-       else
-               aa2=a2;
-       if (!BN_to_montgomery(&(val[1][0]),aa1,mont,ctx)) goto err;
-       if (!BN_to_montgomery(&(val[0][1]),aa2,mont,ctx)) goto err;
-       if (!BN_mod_mul_montgomery(&(val[1][1]),
-               &(val[1][0]),&(val[0][1]),mont,ctx))
-               goto err;
-
-#if 0
-       if (bits <= 20) /* This is probably 3 or 0x10001, so just do singles */
-               window=1;
-       else if (bits > 250)
-               window=5;       /* max size of window */
-       else if (bits >= 120)
-               window=4;
-       else
-               window=3;
-#else
-       window=EXP2_TABLE_BITS;
-#endif
-
-       k=1<<window;
-       for (x=0; x<k; x++)
-               {
-               if (x >= 2)
-                       {
-                       BN_init(&(val[x][0]));
-                       BN_init(&(val[x][1]));
-                       if (!BN_mod_mul_montgomery(&(val[x][0]),
-                               &(val[1][0]),&(val[x-1][0]),mont,ctx)) goto err;
-                       if (!BN_mod_mul_montgomery(&(val[x][1]),
-                               &(val[1][0]),&(val[x-1][1]),mont,ctx)) goto err;
-                       }
-               for (y=2; y<k; y++)
-                       {
-                       BN_init(&(val[x][y]));
-                       if (!BN_mod_mul_montgomery(&(val[x][y]),
-                               &(val[x][y-1]),&(val[0][1]),mont,ctx))
-                               goto err;
-                       }
-               }
-       ts=k;
-
-       start=1;        /* This is used to avoid multiplication etc
-                        * when there is only the value '1' in the
-                        * buffer. */
-       xvalue=0;       /* The 'x value' of the window */
-       yvalue=0;       /* The 'y value' of the window */
-       wstart=bits-1;  /* The top bit of the window */
-       wend=0;         /* The bottom bit of the window */
-
-        if (!BN_to_montgomery(r,BN_value_one(),mont,ctx)) goto err;
-       for (;;)
-               {
-               xvalue=BN_is_bit_set(p1,wstart);
-               yvalue=BN_is_bit_set(p2,wstart);
-               if (!(xvalue || yvalue))
-                       {
-                       if (!start)
-                               {
-                               if (!BN_mod_mul_montgomery(r,r,r,mont,ctx))
-                                       goto err;
-                               }
-                       wstart--;
-                       if (wstart < 0) break;
-                       continue;
-                       }
-               /* We now have wstart on a 'set' bit, we now need to work out
-                * how bit a window to do.  To do this we need to scan
-                * forward until the last set bit before the end of the
-                * window */
-               j=wstart;
-               /* xvalue=BN_is_bit_set(p1,wstart); already set */
-               /* yvalue=BN_is_bit_set(p1,wstart); already set */
-               wend=0;
-               for (i=1; i<window; i++)
-                       {
-                       if (wstart-i < 0) break;
-                       xvalue+=xvalue;
-                       xvalue|=BN_is_bit_set(p1,wstart-i);
-                       yvalue+=yvalue;
-                       yvalue|=BN_is_bit_set(p2,wstart-i);
-                       }
-
-               /* i is the size of the current window */
-               /* add the 'bytes above' */
-               if (!start)
-                       for (j=0; j<i; j++)
-                               {
-                               if (!BN_mod_mul_montgomery(r,r,r,mont,ctx))
-                                       goto err;
-                               }
-               
-               /* wvalue will be an odd number < 2^window */
-               if (xvalue || yvalue)
-                       {
-                       if (!BN_mod_mul_montgomery(r,r,&(val[xvalue][yvalue]),
-                               mont,ctx)) goto err;
-                       }
-
-               /* move the 'window' down further */
-               wstart-=i;
-               start=0;
-               if (wstart < 0) break;
-               }
-       BN_from_montgomery(rr,r,mont,ctx);
-       ret=1;
-err:
-       if ((in_mont == NULL) && (mont != NULL)) BN_MONT_CTX_free(mont);
-       BN_CTX_end(ctx);
-       for (i=0; i<ts; i++)
-               {
-               for (j=0; j<ts; j++)
-                       {
-                       BN_clear_free(&(val[i][j]));
-                       }
-               }
-       return(ret);
-       }
+#define TABLE_SIZE      32
+
+int BN_mod_exp2_mont(BIGNUM *rr, const BIGNUM *a1, const BIGNUM *p1,
+                     const BIGNUM *a2, const BIGNUM *p2, const BIGNUM *m,
+                     BN_CTX *ctx, BN_MONT_CTX *in_mont)
+{
+    int i, j, bits, b, bits1, bits2, ret =
+        0, wpos1, wpos2, window1, window2, wvalue1, wvalue2;
+    int r_is_one = 1;
+    BIGNUM *d, *r;
+    const BIGNUM *a_mod_m;
+    /* Tables of variables obtained from 'ctx' */
+    BIGNUM *val1[TABLE_SIZE], *val2[TABLE_SIZE];
+    BN_MONT_CTX *mont = NULL;
+
+    bn_check_top(a1);
+    bn_check_top(p1);
+    bn_check_top(a2);
+    bn_check_top(p2);
+    bn_check_top(m);
+
+    if (!(m->d[0] & 1)) {
+        BNerr(BN_F_BN_MOD_EXP2_MONT, BN_R_CALLED_WITH_EVEN_MODULUS);
+        return (0);
+    }
+    bits1 = BN_num_bits(p1);
+    bits2 = BN_num_bits(p2);
+    if ((bits1 == 0) && (bits2 == 0)) {
+        ret = BN_one(rr);
+        return ret;
+    }
+
+    bits = (bits1 > bits2) ? bits1 : bits2;
+
+    BN_CTX_start(ctx);
+    d = BN_CTX_get(ctx);
+    r = BN_CTX_get(ctx);
+    val1[0] = BN_CTX_get(ctx);
+    val2[0] = BN_CTX_get(ctx);
+    if (!d || !r || !val1[0] || !val2[0])
+        goto err;
+
+    if (in_mont != NULL)
+        mont = in_mont;
+    else {
+        if ((mont = BN_MONT_CTX_new()) == NULL)
+            goto err;
+        if (!BN_MONT_CTX_set(mont, m, ctx))
+            goto err;
+    }
+
+    window1 = BN_window_bits_for_exponent_size(bits1);
+    window2 = BN_window_bits_for_exponent_size(bits2);
+
+    /*
+     * Build table for a1:   val1[i] := a1^(2*i + 1) mod m  for i = 0 .. 2^(window1-1)
+     */
+    if (a1->neg || BN_ucmp(a1, m) >= 0) {
+        if (!BN_mod(val1[0], a1, m, ctx))
+            goto err;
+        a_mod_m = val1[0];
+    } else
+        a_mod_m = a1;
+    if (BN_is_zero(a_mod_m)) {
+        BN_zero(rr);
+        ret = 1;
+        goto err;
+    }
+
+    if (!BN_to_montgomery(val1[0], a_mod_m, mont, ctx))
+        goto err;
+    if (window1 > 1) {
+        if (!BN_mod_mul_montgomery(d, val1[0], val1[0], mont, ctx))
+            goto err;
+
+        j = 1 << (window1 - 1);
+        for (i = 1; i < j; i++) {
+            if (((val1[i] = BN_CTX_get(ctx)) == NULL) ||
+                !BN_mod_mul_montgomery(val1[i], val1[i - 1], d, mont, ctx))
+                goto err;
+        }
+    }
+
+    /*
+     * Build table for a2:   val2[i] := a2^(2*i + 1) mod m  for i = 0 .. 2^(window2-1)
+     */
+    if (a2->neg || BN_ucmp(a2, m) >= 0) {
+        if (!BN_mod(val2[0], a2, m, ctx))
+            goto err;
+        a_mod_m = val2[0];
+    } else
+        a_mod_m = a2;
+    if (BN_is_zero(a_mod_m)) {
+        BN_zero(rr);
+        ret = 1;
+        goto err;
+    }
+    if (!BN_to_montgomery(val2[0], a_mod_m, mont, ctx))
+        goto err;
+    if (window2 > 1) {
+        if (!BN_mod_mul_montgomery(d, val2[0], val2[0], mont, ctx))
+            goto err;
+
+        j = 1 << (window2 - 1);
+        for (i = 1; i < j; i++) {
+            if (((val2[i] = BN_CTX_get(ctx)) == NULL) ||
+                !BN_mod_mul_montgomery(val2[i], val2[i - 1], d, mont, ctx))
+                goto err;
+        }
+    }
+
+    /* Now compute the power product, using independent windows. */
+    r_is_one = 1;
+    wvalue1 = 0;                /* The 'value' of the first window */
+    wvalue2 = 0;                /* The 'value' of the second window */
+    wpos1 = 0;                  /* If wvalue1 > 0, the bottom bit of the
+                                 * first window */
+    wpos2 = 0;                  /* If wvalue2 > 0, the bottom bit of the
+                                 * second window */
+
+    if (!BN_to_montgomery(r, BN_value_one(), mont, ctx))
+        goto err;
+    for (b = bits - 1; b >= 0; b--) {
+        if (!r_is_one) {
+            if (!BN_mod_mul_montgomery(r, r, r, mont, ctx))
+                goto err;
+        }
+
+        if (!wvalue1)
+            if (BN_is_bit_set(p1, b)) {
+                /*
+                 * consider bits b-window1+1 .. b for this window
+                 */
+                i = b - window1 + 1;
+                while (!BN_is_bit_set(p1, i)) /* works for i<0 */
+                    i++;
+                wpos1 = i;
+                wvalue1 = 1;
+                for (i = b - 1; i >= wpos1; i--) {
+                    wvalue1 <<= 1;
+                    if (BN_is_bit_set(p1, i))
+                        wvalue1++;
+                }
+            }
+
+        if (!wvalue2)
+            if (BN_is_bit_set(p2, b)) {
+                /*
+                 * consider bits b-window2+1 .. b for this window
+                 */
+                i = b - window2 + 1;
+                while (!BN_is_bit_set(p2, i))
+                    i++;
+                wpos2 = i;
+                wvalue2 = 1;
+                for (i = b - 1; i >= wpos2; i--) {
+                    wvalue2 <<= 1;
+                    if (BN_is_bit_set(p2, i))
+                        wvalue2++;
+                }
+            }
+
+        if (wvalue1 && b == wpos1) {
+            /* wvalue1 is odd and < 2^window1 */
+            if (!BN_mod_mul_montgomery(r, r, val1[wvalue1 >> 1], mont, ctx))
+                goto err;
+            wvalue1 = 0;
+            r_is_one = 0;
+        }
+
+        if (wvalue2 && b == wpos2) {
+            /* wvalue2 is odd and < 2^window2 */
+            if (!BN_mod_mul_montgomery(r, r, val2[wvalue2 >> 1], mont, ctx))
+                goto err;
+            wvalue2 = 0;
+            r_is_one = 0;
+        }
+    }
+    if (!BN_from_montgomery(rr, r, mont, ctx))
+        goto err;
+    ret = 1;
+ err:
+    if (in_mont == NULL)
+        BN_MONT_CTX_free(mont);
+    BN_CTX_end(ctx);
+    bn_check_top(rr);
+    return (ret);
+}