/* r must not be a */
/* I've just gone over this and it is now 20% faster on x86 - eay - 27 Jun 96 */
/*
 * BN_sqr, shown as a diff: '-' lines are the old K&R-style version, '+'
 * lines are the replacement, unmarked lines are shared context (some
 * context lines are elided from this view).
 *
 * New ('+') side, as far as the visible lines show:
 *  - the visible early exit returns 1 when a->top <= 0; otherwise the
 *    function returns ret, which stays 0 on every `goto err` path and is
 *    set to 1 once the result has been stored;
 *  - scratch BIGNUMs are obtained between BN_CTX_start and BN_CTX_end
 *    instead of indexing ctx->bn[ctx->tos] directly;
 *  - the 2*al-word result (al = a->top) is written into rr, which is r
 *    itself unless the caller passed a == r, in which case rr comes from
 *    BN_CTX_get and is copied into r at the end.
 */
-int BN_sqr(r, a, ctx)
-BIGNUM *r;
-BIGNUM *a;
-BN_CTX *ctx;
+int BN_sqr(BIGNUM *r, const BIGNUM *a, BN_CTX *ctx)
{
int max,al;
- BIGNUM *tmp;
+ int ret = 0;
+ BIGNUM *tmp,*rr,*free_a = NULL;
/* free_a holds the bn_dup_expand() copy of a made on the recursive path;
 * it is released at the err label. */
#ifdef BN_COUNT
printf("BN_sqr %d * %d\n",a->top,a->top);
#endif
bn_check_top(a);
- tmp= &(ctx->bn[ctx->tos]);
al=a->top;
if (al <= 0)
return(1);
}
/* This return happens before BN_CTX_start, so no BN_CTX_end is owed. */
+ BN_CTX_start(ctx);
+ rr=(a != r) ? r : BN_CTX_get(ctx);
+ tmp=BN_CTX_get(ctx);
+ if (tmp == NULL) goto err;
/* Only tmp is tested; presumably a failed BN_CTX_get for rr also makes
 * the following call fail -- TODO confirm against BN_CTX_get. */
+
max=(al+al);
- if (bn_wexpand(r,max+1) == NULL) return(0);
+ if (bn_wexpand(rr,max+1) == NULL) goto err;
r->neg=0;
/* NOTE(review): the sign is cleared on r, while every other store on the
 * '+' side goes to rr. If BN_copy also copies the neg field, then for
 * a == r the final sign comes from the scratch BIGNUM -- confirm whether
 * this should be rr->neg. */
if (al == 4)
{
#ifndef BN_SQR_COMBA
BN_ULONG t[8];
- bn_sqr_normal(r->d,a->d,4,t);
+ bn_sqr_normal(rr->d,a->d,4,t);
#else
- bn_sqr_comba4(r->d,a->d);
+ bn_sqr_comba4(rr->d,a->d);
#endif
}
else if (al == 8)
{
#ifndef BN_SQR_COMBA
BN_ULONG t[16];
- bn_sqr_normal(r->d,a->d,8,t);
+ bn_sqr_normal(rr->d,a->d,8,t);
#else
- bn_sqr_comba8(r->d,a->d);
+ bn_sqr_comba8(rr->d,a->d);
#endif
}
else
if (al < BN_SQR_RECURSIVE_SIZE_NORMAL)
{
/* Small inputs: the 2*n-word scratch area lives on the stack. */
BN_ULONG t[BN_SQR_RECURSIVE_SIZE_NORMAL*2];
- bn_sqr_normal(r->d,a->d,al,t);
+ bn_sqr_normal(rr->d,a->d,al,t);
}
else
{
- if (bn_wexpand(tmp,2*max+1) == NULL) return(0);
- bn_sqr_recursive(r->d,a->d,al,tmp->d);
+ int j,k;
+
/* j presumably becomes the largest power of two <= al (depends on
 * BN_num_bits_word -- TODO confirm); k is twice that. */
+ j=BN_num_bits_word((BN_ULONG)al);
+ j=1<<(j-1);
+ k=j+j;
+ if (al == j)
+ {
/* al equals j here, so the recursive routine is used; otherwise the
 * else-branch below falls back to bn_sqr_normal. */
/* NOTE(review): free_a is still NULL at this point (its only assignment
 * is on the next line), so tmp_bn is always NULL and the BN_free(tmp_bn)
 * below never runs. */
+ BIGNUM *tmp_bn = free_a;
+ if ((a = free_a = bn_dup_expand(a,k*2)) == NULL) goto err;
+ if (bn_wexpand(tmp,k*2) == NULL) goto err;
+ if (tmp_bn) BN_free(tmp_bn);
+ bn_sqr_recursive(rr->d,a->d,al,tmp->d);
+ }
+ else
+ {
+ if (bn_wexpand(tmp,max) == NULL) goto err;
+ bn_sqr_normal(rr->d,a->d,al,tmp->d);
+ }
}
#else
- if (bn_wexpand(tmp,max) == NULL) return(0);
- bn_sqr_normal(r->d,a->d,al,tmp->d);
+ if (bn_wexpand(tmp,max) == NULL) goto err;
+ bn_sqr_normal(rr->d,a->d,al,tmp->d);
#endif
}
- r->top=max;
- if ((max > 0) && (r->d[max-1] == 0)) r->top--;
- return(1);
/* Result length is 2*al words, minus one if the top word is zero. */
+ rr->top=max;
+ if ((max > 0) && (rr->d[max-1] == 0)) rr->top--;
/* NOTE(review): the BN_copy result is discarded; if BN_copy can fail,
 * ret is still set to 1 below -- confirm. */
+ if (rr != r) BN_copy(r,rr);
+ ret = 1;
+ err:
/* Shared exit for success and failure: drop the private copy of a (if
 * one was made) and close the BN_CTX_start scope. */
+ if (free_a) BN_free(free_a);
+ BN_CTX_end(ctx);
+ return(ret);
}
/* tmp must have 2*n words */
-void bn_sqr_normal(r, a, n, tmp)
-BN_ULONG *r;
-BN_ULONG *a;
-int n;
-BN_ULONG *tmp;
+void bn_sqr_normal(BN_ULONG *r, BN_ULONG *a, int n, BN_ULONG *tmp)
{
int i,j,max;
BN_ULONG *ap,*rp;
#ifdef BN_RECURSION
/* r is 2*n words in size,
- * a and b are both n words in size.
+ * a and b are both n words in size. (There's not actually a 'b' here ...)
* n must be a power of 2.
* We multiply and return the result.
* t must be 2*n words in size
- * We calulate
+ * We calculate
* a[0]*b[0]
* a[0]*b[0]+a[1]*b[1]+(a[0]-a[1])*(b[1]-b[0])
* a[1]*b[1]
*/
-void bn_sqr_recursive(r,a,n2,t)
-BN_ULONG *r,*a;
-int n2;
-BN_ULONG *t;
+void bn_sqr_recursive(BN_ULONG *r, BN_ULONG *a, int n2, BN_ULONG *t)
{
int n=n2/2;
int zero,c1;
* r[32] holds (b[1]*b[1])
*/
- c1=bn_add_words(t,r,&(r[n2]),n2);
+ c1=(int)(bn_add_words(t,r,&(r[n2]),n2));
/* t[32] is negative */
- c1-=bn_sub_words(&(t[n2]),t,&(t[n2]),n2);
+ c1-=(int)(bn_sub_words(&(t[n2]),t,&(t[n2]),n2));
/* t[32] holds (a[0]-a[1])*(a[1]-a[0])+(a[0]*a[0])+(a[1]*a[1])
* r[10] holds (a[0]*a[0])
* r[32] holds (a[1]*a[1])
* c1 holds the carry bits
*/
- c1+=bn_add_words(&(r[n]),&(r[n]),&(t[n2]),n2);
+ c1+=(int)(bn_add_words(&(r[n]),&(r[n]),&(t[n2]),n2));
if (c1)
{
p= &(r[n+n2]);