821dcd3e55688c6add3c5fe10dc96b4bc40323c6
[openssl.git] / crypto / bn / bn_mont2.c
1 /*
2  *
3  *      bn_mont2.c
4  *
5  *      Montgomery Modular Arithmetic Functions.
6  *
7  *      Copyright (C) Lenka Fibikova 2000
8  *
9  *
10  */
11
12
13 #include <stdio.h>
14 #include <stdlib.h>
15 #include <assert.h>
16
17 #include "bn_lcl.h"
18 #include "bn_mont2.h"
19
20 #define BN_mask_word(x, m) ((x->d[0]) & (m))
21
22 BN_MONTGOMERY *BN_mont_new()
23         {
24         BN_MONTGOMERY *ret;
25
26         ret=(BN_MONTGOMERY *)malloc(sizeof(BN_MONTGOMERY));
27
28         if (ret == NULL) return NULL;
29
30         if ((ret->p = BN_new()) == NULL)
31                 {
32                 free(ret);
33                 return NULL;
34                 }
35
36         return ret;
37         }
38
39
40 void BN_mont_clear_free(BN_MONTGOMERY *mont)
41         {
42         if (mont == NULL) return;
43
44         if (mont->p != NULL) BN_clear_free(mont->p);
45
46         mont->p_num_bytes = 0;
47         mont->R_num_bits = 0;
48         mont->p_inv_b_neg = 0;
49         }
50
51
52 int BN_to_mont(BIGNUM *x, BN_MONTGOMERY *mont, BN_CTX *ctx)
53         {
54         assert(x != NULL);
55
56         assert(mont != NULL);
57         assert(mont->p != NULL);
58
59         assert(ctx != NULL);
60
61         if (!BN_lshift(x, x, mont->R_num_bits)) return 0;
62         if (!BN_mod(x, x, mont->p, ctx)) return 0;
63
64         return 1;
65         }
66
67
68 static BN_ULONG BN_mont_inv(BIGNUM *a, int e, BN_CTX *ctx)
69 /* y = a^{-1} (mod 2^e) for an odd number a */
70         {
71         BN_ULONG y, exp, mask;
72         BIGNUM *x, *xy, *x_sh;
73         int i;
74
75         assert(a != NULL && ctx != NULL);
76         assert(e <= BN_BITS2);
77         assert(BN_is_odd(a));
78         assert(!BN_is_zero(a) && !a->neg);
79
80
81         y = 1;
82         exp = 2;
83         mask = 3;
84         if((x = BN_dup(a)) == NULL) return 0;
85         if (x->top > e/BN_BITS2)
86                 if(!BN_mask_bits(x, e)) return 0;
87
88         BN_CTX_start(ctx);
89         xy = BN_CTX_get(ctx);
90         x_sh = BN_CTX_get(ctx);
91         if (x_sh == NULL) goto err;
92
93         if (BN_copy(xy, x) == NULL) goto err;
94         if (!BN_lshift1(x_sh, x)) goto err;
95
96
97         for (i = 2; i <= e; i++)
98                 {
99                 if (exp < BN_mask_word(xy, mask))
100                         {
101                         y = y + exp;
102                         if (!BN_add(xy, xy, x_sh)) goto err;
103                         }
104
105                 exp <<= 1;
106                 if (!BN_lshift1(x_sh, x_sh)) goto err;
107                 mask <<= 1;
108                 mask++;
109                 }
110
111
112 #ifdef TEST
113         if (xy->d[0] != 1) goto err;
114 #endif
115
116         if (x != NULL) BN_clear_free(x);
117         BN_CTX_end(ctx);
118         return y;
119
120
121 err:
122         if (x != NULL) BN_clear_free(x);
123         BN_CTX_end(ctx);
124         return 0;
125         }
126
127
128 int BN_mont_set(BIGNUM *p, BN_MONTGOMERY *mont, BN_CTX *ctx)
129         {
130         assert(p != NULL && ctx != NULL);
131         assert(mont != NULL);
132         assert(mont->p != NULL);
133         assert(!BN_is_zero(p) && !p->neg);
134
135
136         mont->p_num_bytes = p->top;
137         mont->R_num_bits = (mont->p_num_bytes) * BN_BITS2;
138
139         if (BN_copy(mont->p, p) == NULL);
140         
141         mont->p_inv_b_neg =  BN_mont_inv(p, BN_BITS2, ctx);
142         if (!mont->p_inv_b_neg) return 0;
143         mont->p_inv_b_neg = 0 - mont->p_inv_b_neg;
144
145         return 1;
146         }
147
148
149 #ifdef BN_LLONG
150 #define cpy_mul_add(r, b, a, w, c) { \
151         BN_ULLONG t; \
152         t = (BN_ULLONG)w * (a) + (b) + (c); \
153         (r)= Lw(t); \
154         (c)= Hw(t); \
155         }
156
157 BN_ULONG BN_mul_add_rshift(BN_ULONG *r, BN_ULONG *a, int num, BN_ULONG w)
158 /* r = (r + a * w) >> BN_BITS2 */
159         {
160         BN_ULONG c = 0;
161
162         mul_add(r[0], a[0], w, c);
163         if (--num == 0) return c;
164         a++;
165
166         for (;;)
167                 {
168                 cpy_mul_add(r[0], r[1], a[0], w, c);
169                 if (--num == 0) break;
170                 cpy_mul_add(r[1], r[2], a[1], w, c);
171                 if (--num == 0) break;
172                 cpy_mul_add(r[2], r[3], a[2], w, c);
173                 if (--num == 0) break;
174                 cpy_mul_add(r[3], r[4], a[3], w, c);
175                 if (--num == 0) break;
176                 a += 4;
177                 r += 4;
178                 }
179         
180         return c;
181         }
182 #else
183
184 #define cpy_mul_add(r, b, a, bl, bh, c) { \
185         BN_ULONG l,h; \
186  \
187         h=(a); \
188         l=LBITS(h); \
189         h=HBITS(h); \
190         mul64(l,h,(bl),(bh)); \
191  \
192         /* non-multiply part */ \
193         l=(l+(c))&BN_MASK2; if (l < (c)) h++; \
194         (c)=(b); \
195         l=(l+(c))&BN_MASK2; if (l < (c)) h++; \
196         (c)=h&BN_MASK2; \
197         (r)=l; \
198         }
199
200 static BN_ULONG BN_mul_add_rshift(BN_ULONG *r, BN_ULONG *a, int num, BN_ULONG w)
201 /* ret = (ret + a * w) << shift * BN_BITS2 */
202         {
203         BN_ULONG c = 0;
204         BN_ULONG bl, bh;
205
206         bl = LBITS(w);
207         bh = HBITS(w);
208
209         mul_add(r[0], a[0], bl, bh, c);
210         if (--num == 0) return c;
211         a++;
212
213         for (;;)
214                 {
215                 cpy_mul_add(r[0], r[1], a[0], bl, bh, c);
216                 if (--num == 0) break;
217                 cpy_mul_add(r[1], r[2], a[1], bl, bh, c);
218                 if (--num == 0) break;
219                 cpy_mul_add(r[2], r[3], a[2], bl, bh, c);
220                 if (--num == 0) break;
221                 cpy_mul_add(r[3], r[4], a[3], bl, bh, c);
222                 if (--num == 0) break;
223                 a += 4;
224                 r += 4;
225                 }
226         return c;
227         }
228 #endif /* BN_LLONG */
229
230
231
232 int BN_mont_red(BIGNUM *y, BN_MONTGOMERY *mont)
233 /* yR^{-1} (mod p) */
234         {
235         BIGNUM *p;
236         BN_ULONG c;
237         int i, max;
238
239         assert(y != NULL && mont != NULL);
240         assert(mont->p != NULL);
241         assert(BN_cmp(y, mont->p) < 0);
242         assert(!y->neg);
243
244
245         if (BN_is_zero(y)) return 1;
246
247         p = mont->p;
248         max = mont->p_num_bytes;
249
250         if (bn_wexpand(y, max) == NULL) return 0;
251         for (i = y->top; i < max; i++) y->d[i] = 0;
252         y->top = max;
253
254         /* r = [r + (y_0 * p') * p] / b */
255         for (i = 0; i < max; i++)
256                 {
257                 c = BN_mul_add_rshift(y->d, p->d, max, ((y->d[0]) * mont->p_inv_b_neg) & BN_MASK2); 
258                 y->d[max - 1] = c;
259                 }
260
261         while (y->d[y->top - 1] == 0) y->top--;
262
263         if (BN_cmp(y, p) >= 0) 
264                 {
265                 if (!BN_sub(y, y, p)) return 0;
266                 }
267
268         return 1;
269         }
270
271
272 int BN_mont_mod_mul(BIGNUM *r_, BIGNUM *x, BIGNUM *y, BN_MONTGOMERY *mont, BN_CTX *ctx)
273 /* r = x * y mod p */
274 /* r != x && r! = y !!! */
275         {
276         BN_ULONG c;
277         BIGNUM *p;
278         int i, j, max;
279         BIGNUM *r;
280
281         assert(r_!= NULL && x != NULL  && y != NULL && mont != NULL);
282         assert(mont->p != NULL);
283         assert(BN_cmp(x, mont->p) < 0);
284         assert(BN_cmp(y, mont->p) < 0);
285         assert(!x->neg);
286         assert(!y->neg);
287
288         if (BN_is_zero(x) || BN_is_zero(y))
289                 {
290                 if (!BN_zero(r)) return 0;
291                 return 1;
292                 }
293
294         if (r_ == x || r_ == y)
295                 {
296                 BN_CTX_start(ctx);
297                 r = BN_CTX_get(ctx);
298                 }
299         else
300                 r = r_;
301
302         p = mont->p;
303         max = mont->p_num_bytes;
304
305         /* for multiplication we need at most max + 2 words
306                 the last one --- max + 3 --- is only as a backstop
307                 for incorrect input 
308         */
309         if (bn_wexpand(r, max + 3) == NULL) goto err;
310         for (i = 0; i < max + 3; i++) r->d[i] = 0;
311         r->top = max + 2;
312
313         for (i = 0; i < x->top; i++)
314                 {
315                 /* r = r + (r_0 + x_i * y_0) * p' * p */
316                 c = bn_mul_add_words(r->d, p->d, max, \
317                         ((r->d[0] + x->d[i] * y->d[0]) * mont->p_inv_b_neg) & BN_MASK2);
318                 if (c)
319                         {
320                         if (((r->d[max] += c) & BN_MASK2) < c)
321                                 if (((r->d[max + 1] ++) & BN_MASK2) == 0) goto err;
322                         }
323                 
324                 /* r = (r + x_i * y) / b */
325                 c = BN_mul_add_rshift(r->d, y->d, y->top, x->d[i]); 
326                 for(j = y->top; j <= max + 1; j++) r->d[j - 1] = r->d[j];
327                 if (c)
328                         {
329                         if (((r->d[y->top - 1] += c) & BN_MASK2) < c)
330                                 {
331                                 j = y->top;
332                                 while (((++ (r->d[j]) ) & BN_MASK2) == 0) 
333                                         j++;
334                                 if (j > max) goto err;
335                                 }
336                         }
337                 r->d[max + 1] = 0;
338                 }
339
340         for (i = x->top; i < max; i++)
341                 {
342                 /* r = (r + r_0 * p' * p) / b */
343                 c = BN_mul_add_rshift(r->d, p->d, max, ((r->d[0]) * mont->p_inv_b_neg) & BN_MASK2); 
344                 j = max - 1;
345                 r->d[j] = c + r->d[max];
346                 if (r->d[j++] < c) r->d[j] = r->d[++j] + 1;
347                 else r->d[j] = r->d[++j];
348                 r->d[max + 1] = 0;
349                 }
350
351         while (r->d[r->top - 1] == 0) r->top--;
352
353         if (BN_cmp(r, mont->p) >= 0) 
354                 {
355                 if (!BN_sub(r, r, mont->p)) goto err;
356                 }
357
358         if (r != r_)
359                 {
360                 if (!BN_copy(r_, r)) goto err;
361                 BN_CTX_end(ctx);
362                 }
363
364         return 1;
365
366  err:
367         if (r != r_)
368                 BN_CTX_end(ctx);
369         return 0;
370         }