9a3f4130ede203b1357b789e87fe8cd12b35eb4c
[openssl.git] / crypto / bn / stuff / bn_knuth.c
1 /* crypto/bn/bn_knuth.c */
2
3 #include <stdio.h>
4 #include "cryptlib.h"
5 #include "bn.h"
6
7 /* This is just a test implementation, it has not been modified for
8  * speed and it still has memory leaks. */
9
10 int BN_mask_bits(BIGNUM *a,int n);
11
12 #undef DEBUG
13 #define MAIN
14
15 /* r must be different to a and b
16  * Toom-Cook multiplication algorithm, taken from
17  * The Art Of Computer Programming, Volume 2, Donald Knuth
18  */
19
20 #define CODE1           ((BIGNUM *)0x01)
21 #define CODE2           ((BIGNUM *)0x02)
22 #define CODE3           ((BIGNUM *)0x03)
23 #define MAXK            (30+1)
24
25 #define C3      3
26 #define C4      4
27 #define C5      5
28 #define C6      6
29 #define C7      7
30 #define C8      8
31 #define C9      9
32 #define C10     10
33 #define DONE    11
34
35 int new_total=0;
36 int Free_total=0;
37 int max=0,max_total=0;
38
39 BIGNUM *LBN_new(void );
40 BIGNUM *LBN_dup(BIGNUM *a);
41 void LBN_free(BIGNUM *a);
42
43 int BN_mul_knuth(w, a, b)
44 BIGNUM *w;
45 BIGNUM *a;
46 BIGNUM *b;
47         {
48         int ret=1;
49         int i,j,n,an,bn,y,z;
50         BIGNUM *U[MAXK],*V[MAXK],*T[MAXK];
51         BIGNUM *C[(MAXK*2*3)];
52         BIGNUM *W[(MAXK*2)],*t1,*t2,*t3,*t4;
53         int Utos,Vtos,Ctos,Wtos,Ttos;
54         unsigned int k,Q,R;
55         unsigned int q[MAXK];
56         unsigned int r[MAXK];
57         int state;
58
59         /* C1 */
60         Utos=Vtos=Ctos=Wtos=Ttos=0;
61         k=1;
62         q[0]=q[1]=64;
63         r[0]=r[1]=4;
64         Q=6;
65         R=2;
66
67         if (!bn_expand(w,BN_BITS2*2)) goto err;
68         an=BN_num_bits(a);
69         bn=BN_num_bits(b);
70         n=(an > bn)?an:bn;
71         while ((q[k-1]+q[k]) < n)
72                 {
73                 k++;
74                 Q+=R;
75                 i=R+1;
76                 if ((i*i) <= Q) R=i;
77                 q[k]=(1<<Q);
78                 r[k]=(1<<R);
79                 }
80 #ifdef DEBUG
81         printf("k   =");
82         for (i=0; i<=k; i++) printf("%7d",i);
83         printf("\nq[k]=");
84         for (i=0; i<=k; i++) printf("%7d",q[i]);
85         printf("\nr[k]=");
86         for (i=0; i<=k; i++) printf("%7d",r[i]);
87         printf("\n");
88 #endif
89
90         /* C2 */
91         C[Ctos++]=CODE1;
92         if ((t1=LBN_dup(a)) == NULL) goto err;
93         C[Ctos++]=t1;
94         if ((t1=LBN_dup(b)) == NULL) goto err;
95         C[Ctos++]=t1;
96
97         state=C3;
98         for (;;)
99                 {
100 #ifdef DEBUG
101                 printf("state=C%d, Ctos=%d Wtos=%d\n",state,Ctos,Wtos);
102 #endif
103                 switch (state)
104                         {
105                         int lr,lq,lp;
106                 case C3:
107                         k--;
108                         if (k == 0)
109                                 {
110                                 t1=C[--Ctos];
111                                 t2=C[--Ctos];
112 #ifdef DEBUG
113                                 printf("Ctos=%d poped %d\n",Ctos,2);
114 #endif
115                                 if ((t2->top == 0) || (t1->top == 0))
116                                         w->top=0;
117                                 else
118                                         BN_mul(w,t1,t2);
119
120                                 LBN_free(t1); /* FREE */
121                                 LBN_free(t2); /* FREE */
122                                 state=C10;
123                                 }
124                         else
125                                 {
126                                 lr=r[k];
127                                 lq=q[k];
128                                 lp=q[k-1]+q[k];
129                                 state=C4;
130                                 }
131                         break;
132                 case C4:
133                         for (z=0; z<2; z++) /* do for u and v */
134                                 {
135                                 /* break the item at C[Ctos-1] 
136                                  * into lr+1 parts of lq bits each
137                                  * for j=0; j<=2r; j++
138                                  */
139                                 t1=C[--Ctos]; /* pop off u */
140 #ifdef DEBUG
141                                 printf("Ctos=%d poped %d\n",Ctos,1);
142 #endif
143                                 if ((t2=LBN_dup(t1)) == NULL) goto err;
144                                 BN_mask_bits(t2,lq);
145                                 T[Ttos++]=t2;
146 #ifdef DEBUG
147                                 printf("C4 r=0 bits=%d\n",BN_num_bits(t2));
148 #endif
149                                 for (i=1; i<=lr; i++)
150                                         {
151                                         if (!BN_rshift(t1,t1,lq)) goto err;
152                                         if ((t2=LBN_dup(t1)) == NULL) goto err;
153                                         BN_mask_bits(t2,lq);
154                                         T[Ttos++]=t2;
155 #ifdef DEBUG
156                                         printf("C4 r=%d bits=%d\n",i,
157                                                 BN_num_bits(t2));
158 #endif
159                                         }
160                                 LBN_free(t1);
161
162                                 if ((t2=LBN_new()) == NULL) goto err;
163                                 if ((t3=LBN_new()) == NULL) goto err;
164                                 for (j=0; j<=2*lr; j++)
165                                         {
166                                         if ((t1=LBN_new()) == NULL) goto err;
167
168                                         if (!BN_set_word(t3,j)) goto err;
169                                         for (i=lr; i>=0; i--)
170                                                 {
171                                                 if (!BN_mul(t2,t1,t3)) goto err;
172                                                 if (!BN_add(t1,t2,T[i])) goto err;
173                                                 }
174                                         /* t1 is U(j) */
175                                         if (z == 0)
176                                                 U[Utos++]=t1;
177                                         else
178                                                 V[Vtos++]=t1;
179                                         }
180                                 LBN_free(t2);
181                                 LBN_free(t3);
182                                 while (Ttos) LBN_free(T[--Ttos]);
183                                 }
184 #ifdef DEBUG
185                         for (i=0; i<Utos; i++)
186                                 printf("U[%2d]=%4d bits\n",i,BN_num_bits(U[i]));
187                         for (i=0; i<Vtos; i++)
188                                 printf("V[%2d]=%4d bits\n",i,BN_num_bits(V[i]));
189 #endif
190                         /* C5 */
191 #ifdef DEBUG
192                         printf("PUSH CODE2 and %d CODE3 onto stack\n",2*lr);
193 #endif
194                         C[Ctos++]=CODE2;
195                         for (i=2*lr; i>0; i--)
196                                 {
197                                 C[Ctos++]=V[i];
198                                 C[Ctos++]=U[i];
199                                 C[Ctos++]=CODE3;
200                                 }
201                         C[Ctos++]=V[0];
202                         C[Ctos++]=U[0];
203 #ifdef DEBUG
204                                 printf("Ctos=%d pushed %d\n",Ctos,2*lr*3+3);
205 #endif
206                         Vtos=Utos=0;
207                         state=C3;
208                         break;
209                 case C6:
210                         if ((t1=LBN_dup(w)) == NULL) goto err;
211                         W[Wtos++]=t1;
212 #ifdef DEBUG
213                         printf("put %d bit number onto w\n",BN_num_bits(t1));
214 #endif
215                         state=C3;
216                         break;
217                 case C7:
218                         lr=r[k];
219                         lq=q[k];
220                         lp=q[k]+q[k-1];
221                         z=Wtos-2*lr-1;
222                         for (j=1; j<=2*lr; j++)
223                                 {
224                                 for (i=2*lr; i>=j; i--)
225                                         {
226                                         if (!BN_sub(W[z+i],W[z+i],W[z+i-1])) goto err;
227                                         BN_div_word(W[z+i],j);
228                                         }
229                                 }
230                         state=C8;
231                         break;
232                 case C8:
233                         y=2*lr-1;
234                         if ((t1=LBN_new()) == NULL) goto err;
235                         if ((t3=LBN_new()) == NULL) goto err;
236
237                         for (j=y; j>0; j--)
238                                 {
239                                 if (!BN_set_word(t3,j)) goto err;
240                                 for (i=j; i<=y; i++)
241                                         {
242                                         if (!BN_mul(t1,W[z+i+1],t3)) goto err;
243                                         if (!BN_sub(W[z+i],W[z+i],t1)) goto err;
244                                         }
245                                 }
246                         LBN_free(t1);
247                         LBN_free(t3);
248                         state=C9;
249                         break;
250                 case C9:
251                         BN_zero(w);
252 #ifdef DEBUG
253                         printf("lq=%d\n",lq);
254 #endif
255                         for (i=lr*2; i>=0; i--)
256                                 {
257                                 BN_lshift(w,w,lq);
258                                 BN_add(w,w,W[z+i]);
259                                 }
260                         for (i=0; i<=lr*2; i++)
261                                 LBN_free(W[--Wtos]);
262                         state=C10;
263                         break;
264                 case C10:
265                         k++;
266                         t1=C[--Ctos];
267 #ifdef DEBUG
268                         printf("Ctos=%d poped %d\n",Ctos,1);
269                         printf("code= CODE%d\n",t1);
270 #endif
271                         if (t1 == CODE3)
272                                 state=C6;
273                         else if (t1 == CODE2)
274                                 {
275                                 if ((t2=LBN_dup(w)) == NULL) goto err;
276                                 W[Wtos++]=t2;
277                                 state=C7;
278                                 }
279                         else if (t1 == CODE1)
280                                 {
281                                 state=DONE;
282                                 }
283                         else
284                                 {
285                                 printf("BAD ERROR\n");
286                                 goto err;
287                                 }
288                         break;
289                 default:
290                         printf("bad state\n");
291                         goto err;
292                         break;
293                         }
294                 if (state == DONE) break;
295                 }
296         ret=1;
297 err:
298         if (ret == 0) printf("ERROR\n");
299         return(ret);
300         }
301
302 #ifdef MAIN
303 main()
304         {
305         BIGNUM *a,*b,*r;
306         int i;
307
308         if ((a=LBN_new()) == NULL) goto err;
309         if ((b=LBN_new()) == NULL) goto err;
310         if ((r=LBN_new()) == NULL) goto err;
311
312         if (!BN_rand(a,1024*2,0,0)) goto err;
313         if (!BN_rand(b,1024*2,0,0)) goto err;
314
315         for (i=0; i<10; i++)
316                 {
317                 if (!BN_mul_knuth(r,a,b)) goto err; /**/
318                 /*if (!BN_mul(r,a,b)) goto err; /**/
319                 }
320 BN_print(stdout,a); printf(" * ");
321 BN_print(stdout,b); printf(" =\n");
322 BN_print(stdout,r); printf("\n");
323
324 printf("BN_new() =%d\nBN_free()=%d max=%d\n",new_total,Free_total,max);
325
326
327         exit(0);
328 err:
329         ERR_load_crypto_strings();
330         ERR_print_errors(stderr);
331         exit(1);
332         }
333 #endif
334
335 int BN_mask_bits(a,n)
336 BIGNUM *a;
337 int n;
338         {
339         int b,w;
340
341         w=n/BN_BITS2;
342         b=n%BN_BITS2;
343         if (w >= a->top) return(0);
344         if (b == 0)
345                 a->top=w;
346         else
347                 {
348                 a->top=w+1;
349                 a->d[w]&= ~(BN_MASK2<<b);
350                 }
351         return(1);
352         }
353
354 BIGNUM *LBN_dup(a)
355 BIGNUM *a;
356         {
357         new_total++;
358         max_total++;
359         if (max_total > max) max=max_total;
360         return(BN_dup(a));
361         }
362
363 BIGNUM *LBN_new()
364         {
365         new_total++;
366         max_total++;
367         if (max_total > max) max=max_total;
368         return(BN_new());
369         }
370
371 void LBN_free(a)
372 BIGNUM *a;
373         {
374         max_total--;
375         if (max_total > max) max=max_total;
376         Free_total++;
377         BN_free(a);
378         }