sha512-x86_64.pl: fix typo.
[openssl.git] / crypto / bn / bn_sqr.c
index 19ec0ddf842cc6a5366f68296909c71f6540abe0..270d0cd348b90056f14ce429676b700cd577118b 100644 (file)
 
 /* r must not be a */
 /* I've just gone over this and it is now %20 faster on x86 - eay - 27 Jun 96 */
-int BN_sqr(r, a, ctx)
-BIGNUM *r;
-BIGNUM *a;
-BN_CTX *ctx;
+int BN_sqr(BIGNUM *r, const BIGNUM *a, BN_CTX *ctx)
        {
        int max,al;
-       BIGNUM *tmp;
+       int ret = 0;
+       BIGNUM *tmp,*rr;
 
 #ifdef BN_COUNT
-printf("BN_sqr %d * %d\n",a->top,a->top);
+       fprintf(stderr,"BN_sqr %d * %d\n",a->top,a->top);
 #endif
        bn_check_top(a);
-       tmp= &(ctx->bn[ctx->tos]);
 
        al=a->top;
        if (al <= 0)
                {
                r->top=0;
-               return(1);
+               return 1;
                }
 
-       max=(al+al);
-       if (bn_wexpand(r,max+1) == NULL) return(0);
+       BN_CTX_start(ctx);
+       rr=(a != r) ? r : BN_CTX_get(ctx);
+       tmp=BN_CTX_get(ctx);
+       if (!rr || !tmp) goto err;
+
+       max = 2 * al; /* Non-zero (from above) */
+       if (bn_wexpand(rr,max) == NULL) goto err;
 
-       r->neg=0;
        if (al == 4)
                {
 #ifndef BN_SQR_COMBA
                BN_ULONG t[8];
-               bn_sqr_normal(r->d,a->d,4,t);
+               bn_sqr_normal(rr->d,a->d,4,t);
 #else
-               bn_sqr_comba4(r->d,a->d);
+               bn_sqr_comba4(rr->d,a->d);
 #endif
                }
        else if (al == 8)
                {
 #ifndef BN_SQR_COMBA
                BN_ULONG t[16];
-               bn_sqr_normal(r->d,a->d,8,t);
+               bn_sqr_normal(rr->d,a->d,8,t);
 #else
-               bn_sqr_comba8(r->d,a->d);
+               bn_sqr_comba8(rr->d,a->d);
 #endif
                }
        else 
@@ -111,33 +112,54 @@ printf("BN_sqr %d * %d\n",a->top,a->top);
                if (al < BN_SQR_RECURSIVE_SIZE_NORMAL)
                        {
                        BN_ULONG t[BN_SQR_RECURSIVE_SIZE_NORMAL*2];
-                       bn_sqr_normal(r->d,a->d,al,t);
+                       bn_sqr_normal(rr->d,a->d,al,t);
                        }
                else
                        {
-                       if (bn_wexpand(tmp,2*max+1) == NULL) return(0);
-                       bn_sqr_recursive(r->d,a->d,al,tmp->d);
+                       int j,k;
+
+                       j=BN_num_bits_word((BN_ULONG)al);
+                       j=1<<(j-1);
+                       k=j+j;
+                       if (al == j)
+                               {
+                               if (bn_wexpand(tmp,k*2) == NULL) goto err;
+                               bn_sqr_recursive(rr->d,a->d,al,tmp->d);
+                               }
+                       else
+                               {
+                               if (bn_wexpand(tmp,max) == NULL) goto err;
+                               bn_sqr_normal(rr->d,a->d,al,tmp->d);
+                               }
                        }
 #else
-               if (bn_wexpand(tmp,max) == NULL) return(0);
-               bn_sqr_normal(r->d,a->d,al,tmp->d);
+               if (bn_wexpand(tmp,max) == NULL) goto err;
+               bn_sqr_normal(rr->d,a->d,al,tmp->d);
 #endif
                }
 
-       r->top=max;
-       if ((max > 0) && (r->d[max-1] == 0)) r->top--;
-       return(1);
+       rr->neg=0;
+       /* If the most-significant half of the top word of 'a' is zero, then
+        * the square of 'a' will max-1 words. */
+       if(a->d[al - 1] == (a->d[al - 1] & BN_MASK2l))
+               rr->top = max - 1;
+       else
+               rr->top = max;
+       if (rr != r) BN_copy(r,rr);
+       ret = 1;
+ err:
+       bn_check_top(rr);
+       bn_check_top(tmp);
+       BN_CTX_end(ctx);
+       return(ret);
        }
 
 /* tmp must have 2*n words */
-void bn_sqr_normal(r, a, n, tmp)
-BN_ULONG *r;
-BN_ULONG *a;
-int n;
-BN_ULONG *tmp;
+void bn_sqr_normal(BN_ULONG *r, const BN_ULONG *a, int n, BN_ULONG *tmp)
        {
        int i,j,max;
-       BN_ULONG *ap,*rp;
+       const BN_ULONG *ap;
+       BN_ULONG *rp;
 
        max=n*2;
        ap=a;
@@ -172,26 +194,23 @@ BN_ULONG *tmp;
 
 #ifdef BN_RECURSION
 /* r is 2*n words in size,
- * a and b are both n words in size.
+ * a and b are both n words in size.    (There's not actually a 'b' here ...)
  * n must be a power of 2.
  * We multiply and return the result.
  * t must be 2*n words in size
- * We calulate
+ * We calculate
  * a[0]*b[0]
  * a[0]*b[0]+a[1]*b[1]+(a[0]-a[1])*(b[1]-b[0])
  * a[1]*b[1]
  */
-void bn_sqr_recursive(r,a,n2,t)
-BN_ULONG *r,*a;
-int n2;
-BN_ULONG *t;
+void bn_sqr_recursive(BN_ULONG *r, const BN_ULONG *a, int n2, BN_ULONG *t)
        {
        int n=n2/2;
        int zero,c1;
        BN_ULONG ln,lo,*p;
 
 #ifdef BN_COUNT
-printf(" bn_sqr_recursive %d * %d\n",n2,n2);
+       fprintf(stderr," bn_sqr_recursive %d * %d\n",n2,n2);
 #endif
        if (n2 == 4)
                {
@@ -232,7 +251,7 @@ printf(" bn_sqr_recursive %d * %d\n",n2,n2);
        if (!zero)
                bn_sqr_recursive(&(t[n2]),t,n,p);
        else
-               memset(&(t[n2]),0,n*sizeof(BN_ULONG));
+               memset(&(t[n2]),0,n2*sizeof(BN_ULONG));
        bn_sqr_recursive(r,a,n,p);
        bn_sqr_recursive(&(r[n2]),&(a[n]),n,p);
 
@@ -241,17 +260,17 @@ printf(" bn_sqr_recursive %d * %d\n",n2,n2);
         * r[32] holds (b[1]*b[1])
         */
 
-       c1=bn_add_words(t,r,&(r[n2]),n2);
+       c1=(int)(bn_add_words(t,r,&(r[n2]),n2));
 
        /* t[32] is negative */
-       c1-=bn_sub_words(&(t[n2]),t,&(t[n2]),n2);
+       c1-=(int)(bn_sub_words(&(t[n2]),t,&(t[n2]),n2));
 
        /* t[32] holds (a[0]-a[1])*(a[1]-a[0])+(a[0]*a[0])+(a[1]*a[1])
         * r[10] holds (a[0]*a[0])
         * r[32] holds (a[1]*a[1])
         * c1 holds the carry bits
         */
-       c1+=bn_add_words(&(r[n]),&(r[n]),&(t[n2]),n2);
+       c1+=(int)(bn_add_words(&(r[n]),&(r[n]),&(t[n2]),n2));
        if (c1)
                {
                p= &(r[n+n2]);