Add a pointer to a paper (is the algorithm in section 4.2 the
[openssl.git] / crypto / bn / bn_mont.c
index e73b0cbb693df6c96fdba26c833cba7e4cf2a6cc..5ef08d9157247ba70a2f024f6c56b52f6b16426f 100644 (file)
  */
 
 /*
  */
 
 /*
- * Details about Montgomery multiplication algorithms can be found at:
- * http://www.ece.orst.edu/ISL/Publications.html
- * http://www.ece.orst.edu/ISL/Koc/papers/j37acmon.pdf
+ * Details about Montgomery multiplication algorithms can be found at
+ * http://security.ece.orst.edu/publications.html, e.g.
+ * http://security.ece.orst.edu/koc/papers/j37acmon.pdf and
+ * sections 3.8 and 4.2 in http://security.ece.orst.edu/koc/papers/r01rsasw.pdf
  */
 
 #include <stdio.h>
 #include "cryptlib.h"
 #include "bn_lcl.h"
 
  */
 
 #include <stdio.h>
 #include "cryptlib.h"
 #include "bn_lcl.h"
 
-#define MONT_WORD
-
-int BN_mod_mul_montgomery(r,a,b,mont,ctx)
-BIGNUM *r,*a,*b;
-BN_MONT_CTX *mont;
-BN_CTX *ctx;
+int BN_mod_mul_montgomery(BIGNUM *r, BIGNUM *a, BIGNUM *b,
+                         BN_MONT_CTX *mont, BN_CTX *ctx)
        {
        BIGNUM *tmp,*tmp2;
 
        {
        BIGNUM *tmp,*tmp2;
 
@@ -107,36 +104,34 @@ err:
        return(0);
        }
 
        return(0);
        }
 
-int BN_from_montgomery(ret,a,mont,ctx)
-BIGNUM *ret;
-BIGNUM *a;
-BN_MONT_CTX *mont;
-BN_CTX *ctx;
+int BN_from_montgomery(BIGNUM *ret, BIGNUM *a, BN_MONT_CTX *mont,
+            BN_CTX *ctx)
        {
        {
-#ifdef BN_RECURSION
+       int retn=0;
+#ifdef BN_RECURSION_MONT
        if (mont->use_word)
 #endif
                {
                BIGNUM *n,*r;
                BN_ULONG *ap,*np,*rp,n0,v,*nrp;
                int al,nl,max,i,x,ri;
        if (mont->use_word)
 #endif
                {
                BIGNUM *n,*r;
                BN_ULONG *ap,*np,*rp,n0,v,*nrp;
                int al,nl,max,i,x,ri;
-               int retn=0;
 
                r= &(ctx->bn[ctx->tos]);
 
 
                r= &(ctx->bn[ctx->tos]);
 
-               if (!BN_copy(r,a)) goto err1;
+               if (!BN_copy(r,a)) goto err;
                n= &(mont->N);
 
                ap=a->d;
                n= &(mont->N);
 
                ap=a->d;
-               /* mont->ri is the size of mont->N in bits/words */
+               /* mont->ri is the size of mont->N in bits (rounded up
+                   to the word size) */
                al=ri=mont->ri/BN_BITS2;
 
                nl=n->top;
                if ((al == 0) || (nl == 0)) { r->top=0; return(1); }
 
                max=(nl+al+1); /* allow for overflow (no?) XXX */
                al=ri=mont->ri/BN_BITS2;
 
                nl=n->top;
                if ((al == 0) || (nl == 0)) { r->top=0; return(1); }
 
                max=(nl+al+1); /* allow for overflow (no?) XXX */
-               if (bn_wexpand(r,max) == NULL) goto err1;
-               if (bn_wexpand(ret,max) == NULL) goto err1;
+               if (bn_wexpand(r,max) == NULL) goto err;
+               if (bn_wexpand(ret,max) == NULL) goto err;
 
                r->neg=a->neg^n->neg;
                np=n->d;
 
                r->neg=a->neg^n->neg;
                np=n->d;
@@ -209,67 +204,37 @@ printf("word BN_from_montgomery %d * %d\n",nl,nl);
                        BN_usub(ret,ret,&(mont->N)); /* XXX */
                        }
                retn=1;
                        BN_usub(ret,ret,&(mont->N)); /* XXX */
                        }
                retn=1;
-err1:
-               return(retn);
                }
                }
-#ifdef BN_RECURSION
+#ifdef BN_RECURSION_MONT
        else /* bignum version */ 
                {
        else /* bignum version */ 
                {
-               BIGNUM *t1,*t2,*t3;
-               int j,i;
+               BIGNUM *t1,*t2;
 
 
-#ifdef BN_COUNT
-printf("number BN_from_montgomery\n");
-#endif
+               t1=&(ctx->bn[ctx->tos]);
+               t2=&(ctx->bn[ctx->tos+1]);
+               ctx->tos+=2;
 
 
-               t1= &(ctx->bn[ctx->tos]);
-               t2= &(ctx->bn[ctx->tos+1]);
-               t3= &(ctx->bn[ctx->tos+2]);
+               if (!BN_copy(t1,a)) goto err;
+               BN_mask_bits(t1,mont->ri);
 
 
-               i=mont->Ni.top;
-               bn_wexpand(ret,i); /* perhaps only i*2 */
-               bn_wexpand(t1,i*4); /* perhaps only i*2 */
-               bn_wexpand(t2,i*2); /* perhaps only i   */
+               if (!BN_mul(t2,t1,&mont->Ni,ctx)) goto err;
+               BN_mask_bits(t2,mont->ri);
 
 
-               bn_mul_low_recursive(t2->d,a->d,mont->Ni.d,i,t1->d);
+               if (!BN_mul(t1,t2,&mont->N,ctx)) goto err;
+               if (!BN_add(t2,a,t1)) goto err;
+               BN_rshift(ret,t2,mont->ri);
 
 
-               BN_zero(t3);
-               BN_set_bit(t3,mont->N.top*BN_BITS2);
-               bn_sub_words(t3->d,t3->d,a->d,i);
-               bn_mul_high(ret->d,t2->d,mont->N.d,t3->d,i,t1->d);
-
-               /* hmm... if a is between i and 2*i, things are bad */
-               if (a->top > i)
-                       {
-                       j=(int)(bn_add_words(ret->d,ret->d,&(a->d[i]),i));
-                       if (j) /* overflow */
-                               bn_sub_words(ret->d,ret->d,mont->N.d,i);
-                       }
-               ret->top=i;
-               bn_fix_top(ret);
-               if (a->d[0])
-                       BN_add_word(ret,1); /* Always? */
-               else    /* Very very rare */
-                       {
-                       for (i=1; i<mont->N.top-1; i++)
-                               {
-                               if (a->d[i])
-                                       {
-                                       BN_add_word(ret,1); /* Always? */
-                                       break;
-                                       }
-                               }
-                       }
-
-               if (BN_ucmp(ret,&(mont->N)) >= 0)
-                       BN_usub(ret,ret,&(mont->N));
-
-               return(1);
+               if (BN_ucmp(ret,&mont->N) >= 0)
+                       BN_usub(ret,ret,&mont->N);
+               ctx->tos-=2;
+               retn=1;
                }
 #endif
                }
 #endif
+ err:
+       return(retn);
        }
 
        }
 
-BN_MONT_CTX *BN_MONT_CTX_new()
+BN_MONT_CTX *BN_MONT_CTX_new(void)
        {
        BN_MONT_CTX *ret;
 
        {
        BN_MONT_CTX *ret;
 
@@ -281,8 +246,7 @@ BN_MONT_CTX *BN_MONT_CTX_new()
        return(ret);
        }
 
        return(ret);
        }
 
-void BN_MONT_CTX_init(ctx)
-BN_MONT_CTX *ctx;
+void BN_MONT_CTX_init(BN_MONT_CTX *ctx)
        {
        ctx->use_word=0;
        ctx->ri=0;
        {
        ctx->use_word=0;
        ctx->ri=0;
@@ -292,8 +256,7 @@ BN_MONT_CTX *ctx;
        ctx->flags=0;
        }
 
        ctx->flags=0;
        }
 
-void BN_MONT_CTX_free(mont)
-BN_MONT_CTX *mont;
+void BN_MONT_CTX_free(BN_MONT_CTX *mont)
        {
        if(mont == NULL)
            return;
        {
        if(mont == NULL)
            return;
@@ -305,10 +268,7 @@ BN_MONT_CTX *mont;
                Free(mont);
        }
 
                Free(mont);
        }
 
-int BN_MONT_CTX_set(mont,mod,ctx)
-BN_MONT_CTX *mont;
-BIGNUM *mod;
-BN_CTX *ctx;
+int BN_MONT_CTX_set(BN_MONT_CTX *mont, const BIGNUM *mod, BN_CTX *ctx)
        {
        BIGNUM Ri,*R;
 
        {
        BIGNUM Ri,*R;
 
@@ -316,8 +276,9 @@ BN_CTX *ctx;
        R= &(mont->RR);                                 /* grab RR as a temp */
        BN_copy(&(mont->N),mod);                        /* Set N */
 
        R= &(mont->RR);                                 /* grab RR as a temp */
        BN_copy(&(mont->N),mod);                        /* Set N */
 
-#ifdef BN_RECURSION
-       if (mont->N.top < BN_MONT_CTX_SET_SIZE_WORD)
+#ifdef BN_RECURSION_MONT
+       /* the word-based algorithm is faster */
+       if (mont->N.top > BN_MONT_CTX_SET_SIZE_WORD)
 #endif
                {
                BIGNUM tmod;
 #endif
                {
                BIGNUM tmod;
@@ -327,74 +288,47 @@ BN_CTX *ctx;
 
                mont->ri=(BN_num_bits(mod)+(BN_BITS2-1))/BN_BITS2*BN_BITS2;
                BN_zero(R);
 
                mont->ri=(BN_num_bits(mod)+(BN_BITS2-1))/BN_BITS2*BN_BITS2;
                BN_zero(R);
-               BN_set_bit(R,BN_BITS2);
-               /* I was bad, this modification of a passed variable was
-                * breaking the multithreaded stuff :-(
-                * z=mod->top;
-                * mod->top=1; */
+               BN_set_bit(R,BN_BITS2);                 /* R = 2^BN_BITS2 (one word) */
 
 
-               buf[0]=mod->d[0];
+               buf[0]=mod->d[0]; /* tmod = N mod word size */
                buf[1]=0;
                tmod.d=buf;
                tmod.top=1;
                buf[1]=0;
                tmod.d=buf;
                tmod.top=1;
-               tmod.max=mod->max;
+               tmod.max=2;
                tmod.neg=mod->neg;
                tmod.neg=mod->neg;
-
+                                                       /* Ri = R^-1 mod tmod (tmod = N mod word size) */
                if ((BN_mod_inverse(&Ri,R,&tmod,ctx)) == NULL)
                        goto err;
                if ((BN_mod_inverse(&Ri,R,&tmod,ctx)) == NULL)
                        goto err;
-               BN_lshift(&Ri,&Ri,BN_BITS2);                    /* R*Ri */
+               BN_lshift(&Ri,&Ri,BN_BITS2);            /* R*Ri */
                if (!BN_is_zero(&Ri))
                if (!BN_is_zero(&Ri))
-                       {
-#if 1
                        BN_sub_word(&Ri,1);
                        BN_sub_word(&Ri,1);
-#else
-                       BN_usub(&Ri,&Ri,BN_value_one());        /* R*Ri - 1 */
-#endif
-                       }
-               else
-                       {
-                       /* This is not common..., 1 in BN_MASK2,
-                        * It happens when buf[0] was == 1.  So for 8 bit,
-                        * this is 1/256, 16bit, 1 in 2^16 etc.
-                        */
-                       BN_set_word(&Ri,BN_MASK2);
-                       }
-               BN_div(&Ri,NULL,&Ri,&tmod,ctx);
+               else /* if N mod word size == 1 */
+                       BN_set_word(&Ri,BN_MASK2);  /* Ri-- (mod word size) */
+               BN_div(&Ri,NULL,&Ri,&tmod,ctx);          /* Ni = (R*Ri-1)/N */
                mont->n0=Ri.d[0];
                BN_free(&Ri);
                mont->n0=Ri.d[0];
                BN_free(&Ri);
-               /* mod->top=z; */
                }
                }
-#ifdef BN_RECURSION
+#ifdef BN_RECURSION_MONT
        else
        else
-               {
+               { /* bignum version */
                mont->use_word=0;
                mont->use_word=0;
-               mont->ri=(BN_num_bits(mod)+(BN_BITS2-1))/BN_BITS2*BN_BITS2;
-#if 1
+               mont->ri=BN_num_bits(mod);
                BN_zero(R);
                BN_zero(R);
-               BN_set_bit(R,mont->ri);
-#else
-               BN_lshift(R,BN_value_one(),mont->ri);   /* R */
-#endif
+               BN_set_bit(R,mont->ri);                 /* R = 2^ri */
+                                                       /* Ri = R^-1 mod N*/
                if ((BN_mod_inverse(&Ri,R,mod,ctx)) == NULL)
                        goto err;
                BN_lshift(&Ri,&Ri,mont->ri);            /* R*Ri */
                if ((BN_mod_inverse(&Ri,R,mod,ctx)) == NULL)
                        goto err;
                BN_lshift(&Ri,&Ri,mont->ri);            /* R*Ri */
-#if 1
                BN_sub_word(&Ri,1);
                BN_sub_word(&Ri,1);
-#else
-               BN_usub(&Ri,&Ri,BN_value_one());        /* R*Ri - 1 */
-#endif
+                                                       /* Ni = (R*Ri-1) / N */
                BN_div(&(mont->Ni),NULL,&Ri,mod,ctx);
                BN_free(&Ri);
                }
 #endif
 
        /* setup RR for conversions */
                BN_div(&(mont->Ni),NULL,&Ri,mod,ctx);
                BN_free(&Ri);
                }
 #endif
 
        /* setup RR for conversions */
-#if 1
        BN_zero(&(mont->RR));
        BN_set_bit(&(mont->RR),mont->ri*2);
        BN_zero(&(mont->RR));
        BN_set_bit(&(mont->RR),mont->ri*2);
-#else
-       BN_lshift(mont->RR,BN_value_one(),mont->ri*2);
-#endif
        BN_mod(&(mont->RR),&(mont->RR),&(mont->N),ctx);
 
        return(1);
        BN_mod(&(mont->RR),&(mont->RR),&(mont->N),ctx);
 
        return(1);
@@ -402,8 +336,7 @@ err:
        return(0);
        }
 
        return(0);
        }
 
-BN_MONT_CTX *BN_MONT_CTX_copy(to, from)
-BN_MONT_CTX *to, *from;
+BN_MONT_CTX *BN_MONT_CTX_copy(BN_MONT_CTX *to, BN_MONT_CTX *from)
        {
        if (to == from) return(to);
 
        {
        if (to == from) return(to);