New functions BN_CTX_start(), BN_CTX_get(), BN_CTX_end() to access
[openssl.git] / crypto / bn / bn_sqr.c
index 1874c14628a2199186438d9cf4a30e46af998458..fe00c5f69a01025918a5b06e271cfb52f269e415 100644 (file)
 int BN_sqr(BIGNUM *r, BIGNUM *a, BN_CTX *ctx)
        {
        int max,al;
+       int ret = 0;
        BIGNUM *tmp,*rr;
 
 #ifdef BN_COUNT
 printf("BN_sqr %d * %d\n",a->top,a->top);
 #endif
        bn_check_top(a);
-       tmp= &(ctx->bn[ctx->tos]);
-       rr=(a != r)?r: (&ctx->bn[ctx->tos+1]);
 
        al=a->top;
        if (al <= 0)
@@ -81,8 +80,13 @@ printf("BN_sqr %d * %d\n",a->top,a->top);
                return(1);
                }
 
+       BN_CTX_start(ctx);
+       rr=(a != r) ? r : BN_CTX_get(ctx);
+       tmp=BN_CTX_get(ctx);
+       if (tmp == NULL) goto err;
+
        max=(al+al);
-       if (bn_wexpand(rr,max+1) == NULL) return(0);
+       if (bn_wexpand(rr,max+1) == NULL) goto err;
 
        r->neg=0;
        if (al == 4)
@@ -120,18 +124,18 @@ printf("BN_sqr %d * %d\n",a->top,a->top);
                        k=j+j;
                        if (al == j)
                                {
-                               if (bn_wexpand(a,k*2) == NULL) return(0);
-                               if (bn_wexpand(tmp,k*2) == NULL) return(0);
+                               if (bn_wexpand(a,k*2) == NULL) goto err;
+                               if (bn_wexpand(tmp,k*2) == NULL) goto err;
                                bn_sqr_recursive(rr->d,a->d,al,tmp->d);
                                }
                        else
                                {
-                               if (bn_wexpand(tmp,max) == NULL) return(0);
+                               if (bn_wexpand(tmp,max) == NULL) goto err;
                                bn_sqr_normal(rr->d,a->d,al,tmp->d);
                                }
                        }
 #else
-               if (bn_wexpand(tmp,max) == NULL) return(0);
+               if (bn_wexpand(tmp,max) == NULL) goto err;
                bn_sqr_normal(rr->d,a->d,al,tmp->d);
 #endif
                }
@@ -139,7 +143,10 @@ printf("BN_sqr %d * %d\n",a->top,a->top);
        rr->top=max;
        if ((max > 0) && (rr->d[max-1] == 0)) rr->top--;
        if (rr != r) BN_copy(r,rr);
-       return(1);
+       ret = 1;
+ err:
+       BN_CTX_end(ctx);
+       return(ret);
        }
 
 /* tmp must have 2*n words */