Ensure that x**0 mod 1 = 0.
[openssl.git] / crypto / bn / bn_sqr.c
index bbff1ad72af28c0cac922e394b24ed224495a101..65bbf165d0e2f3bc3de3d662c1c02800dbb33c70 100644 (file)
@@ -66,7 +66,7 @@ int BN_sqr(BIGNUM *r, const BIGNUM *a, BN_CTX *ctx)
        {
        int max,al;
        int ret = 0;
-       BIGNUM *tmp,*rr,*free_a = NULL;
+       BIGNUM *tmp,*rr;
 
 #ifdef BN_COUNT
        fprintf(stderr,"BN_sqr %d * %d\n",a->top,a->top);
@@ -77,18 +77,18 @@ int BN_sqr(BIGNUM *r, const BIGNUM *a, BN_CTX *ctx)
        if (al <= 0)
                {
                r->top=0;
-               return(1);
+               r->neg = 0;
+               return 1;
                }
 
        BN_CTX_start(ctx);
        rr=(a != r) ? r : BN_CTX_get(ctx);
        tmp=BN_CTX_get(ctx);
-       if (tmp == NULL) goto err;
+       if (!rr || !tmp) goto err;
 
-       max=(al+al);
-       if (bn_wexpand(rr,max+1) == NULL) goto err;
+       max = 2 * al; /* Non-zero (from above) */
+       if (bn_wexpand(rr,max) == NULL) goto err;
 
-       r->neg=0;
        if (al == 4)
                {
 #ifndef BN_SQR_COMBA
@@ -139,12 +139,18 @@ int BN_sqr(BIGNUM *r, const BIGNUM *a, BN_CTX *ctx)
 #endif
                }
 
-       rr->top=max;
-       if ((max > 0) && (rr->d[max-1] == 0)) rr->top--;
+       rr->neg=0;
+       /* If the most-significant half of the top word of 'a' is zero, then
+        * the square of 'a' will max-1 words. */
+       if(a->d[al - 1] == (a->d[al - 1] & BN_MASK2l))
+               rr->top = max - 1;
+       else
+               rr->top = max;
        if (rr != r) BN_copy(r,rr);
        ret = 1;
  err:
-       if (free_a) BN_free(free_a);
+       bn_check_top(rr);
+       bn_check_top(tmp);
        BN_CTX_end(ctx);
        return(ret);
        }
@@ -246,7 +252,7 @@ void bn_sqr_recursive(BN_ULONG *r, const BN_ULONG *a, int n2, BN_ULONG *t)
        if (!zero)
                bn_sqr_recursive(&(t[n2]),t,n,p);
        else
-               memset(&(t[n2]),0,n*sizeof(BN_ULONG));
+               memset(&(t[n2]),0,n2*sizeof(BN_ULONG));
        bn_sqr_recursive(r,a,n,p);
        bn_sqr_recursive(&(r[n2]),&(a[n]),n,p);