Simplify BN_rand_range
[openssl.git] / crypto / bn / bn_mont2.c
1 /*
2  *
3  *      bn_mont2.c
4  *
5  *      Montgomery Modular Arithmetic Functions.
6  *
7  *      Copyright (C) Lenka Fibikova 2000
8  *
9  *
10  */
11
12
13 #include <stdio.h>
14 #include <stdlib.h>
15 #include <assert.h>
16
17 #include "bn_lcl.h"
18 #include "bn_mont2.h"
19
20 #define BN_mask_word(x, m) ((x->d[0]) & (m))
21
22 BN_MONTGOMERY *BN_mont_new()
23         {
24         BN_MONTGOMERY *ret;
25
26         ret=(BN_MONTGOMERY *)malloc(sizeof(BN_MONTGOMERY));
27
28         if (ret == NULL) return NULL;
29
30         if ((ret->p = BN_new()) == NULL)
31                 {
32                 free(ret);
33                 return NULL;
34                 }
35
36         return ret;
37         }
38
39
40 void BN_mont_clear_free(BN_MONTGOMERY *mont)
41         {
42         if (mont == NULL) return;
43
44         if (mont->p != NULL) BN_clear_free(mont->p);
45
46         mont->p_num_bytes = 0;
47         mont->R_num_bits = 0;
48         mont->p_inv_b_neg = 0;
49         }
50
51
52 int BN_to_mont(BIGNUM *x, BN_MONTGOMERY *mont, BN_CTX *ctx)
53         {
54         assert(x != NULL);
55
56         assert(mont != NULL);
57         assert(mont->p != NULL);
58
59         assert(ctx != NULL);
60
61         if (!BN_lshift(x, x, mont->R_num_bits)) return 0;
62         if (!BN_mod(x, x, mont->p, ctx)) return 0;
63
64         return 1;
65         }
66
67
68 static BN_ULONG BN_mont_inv(BIGNUM *a, int e, BN_CTX *ctx)
69 /* y = a^{-1} (mod 2^e) for an odd number a */
70         {
71         BN_ULONG y, exp, mask;
72         BIGNUM *x, *xy, *x_sh;
73         int i;
74
75         assert(a != NULL && ctx != NULL);
76         assert(e <= BN_BITS2);
77         assert(BN_is_odd(a));
78         assert(!BN_is_zero(a) && !a->neg);
79
80
81         y = 1;
82         exp = 2;
83         mask = 3;
84         if((x = BN_dup(a)) == NULL) return 0;
85         if(!BN_mask_bits(x, e)) return 0;
86
87         BN_CTX_start(ctx);
88         xy = BN_CTX_get(ctx);
89         x_sh = BN_CTX_get(ctx);
90         if (x_sh == NULL) goto err;
91
92         if (BN_copy(xy, x) == NULL) goto err;
93         if (!BN_lshift1(x_sh, x)) goto err;
94
95
96         for (i = 2; i <= e; i++)
97                 {
98                 if (exp < BN_mask_word(xy, mask))
99                         {
100                         y = y + exp;
101                         if (!BN_add(xy, xy, x_sh)) goto err;
102                         }
103
104                 exp <<= 1;
105                 if (!BN_lshift1(x_sh, x_sh)) goto err;
106                 mask <<= 1;
107                 mask++;
108                 }
109
110
111 #ifdef TEST
112         if (xy->d[0] != 1) goto err;
113 #endif
114
115         if (x != NULL) BN_clear_free(x);
116         BN_CTX_end(ctx);
117         return y;
118
119
120 err:
121         if (x != NULL) BN_clear_free(x);
122         BN_CTX_end(ctx);
123         return 0;
124         }
125
126
127 int BN_mont_set(BIGNUM *p, BN_MONTGOMERY *mont, BN_CTX *ctx)
128         {
129         assert(p != NULL && ctx != NULL);
130         assert(mont != NULL);
131         assert(mont->p != NULL);
132         assert(!BN_is_zero(p) && !p->neg);
133
134
135         mont->p_num_bytes = p->top;
136         mont->R_num_bits = (mont->p_num_bytes) * BN_BITS2;
137
138         if (BN_copy(mont->p, p) == NULL);
139         
140         mont->p_inv_b_neg =  BN_mont_inv(p, BN_BITS2, ctx);
141         mont->p_inv_b_neg = 0 - mont->p_inv_b_neg;
142
143         return 1;
144         }
145
146
147 #ifdef BN_LLONG
148 #define cpy_mul_add(r, b, a, w, c) { \
149         BN_ULLONG t; \
150         t = (BN_ULLONG)w * (a) + (b) + (c); \
151         (r)= Lw(t); \
152         (c)= Hw(t); \
153         }
154
155 BN_ULONG BN_mul_add_rshift(BN_ULONG *r, BN_ULONG *a, int num, BN_ULONG w)
156 /* r = (r + a * w) >> BN_BITS2 */
157         {
158         BN_ULONG c = 0;
159
160         mul_add(r[0], a[0], w, c);
161         if (--num == 0) return c;
162         a++;
163
164         for (;;)
165                 {
166                 cpy_mul_add(r[0], r[1], a[0], w, c);
167                 if (--num == 0) break;
168                 cpy_mul_add(r[1], r[2], a[1], w, c);
169                 if (--num == 0) break;
170                 cpy_mul_add(r[2], r[3], a[2], w, c);
171                 if (--num == 0) break;
172                 cpy_mul_add(r[3], r[4], a[3], w, c);
173                 if (--num == 0) break;
174                 a += 4;
175                 r += 4;
176                 }
177         
178         return c;
179         }
180 #else
181
182 #define cpy_mul_add(r, b, a, bl, bh, c) { \
183         BN_ULONG l,h; \
184  \
185         h=(a); \
186         l=LBITS(h); \
187         h=HBITS(h); \
188         mul64(l,h,(bl),(bh)); \
189  \
190         /* non-multiply part */ \
191         l=(l+(c))&BN_MASK2; if (l < (c)) h++; \
192         (c)=(b); \
193         l=(l+(c))&BN_MASK2; if (l < (c)) h++; \
194         (c)=h&BN_MASK2; \
195         (r)=l; \
196         }
197
198 static BN_ULONG BN_mul_add_rshift(BN_ULONG *r, BN_ULONG *a, int num, BN_ULONG w)
199 /* ret = (ret + a * w) << shift * BN_BITS2 */
200         {
201         BN_ULONG c = 0;
202         BN_ULONG bl, bh;
203
204         bl = LBITS(w);
205         bh = HBITS(w);
206
207         mul_add(r[0], a[0], bl, bh, c);
208         if (--num == 0) return c;
209         a++;
210
211         for (;;)
212                 {
213                 cpy_mul_add(r[0], r[1], a[0], bl, bh, c);
214                 if (--num == 0) break;
215                 cpy_mul_add(r[1], r[2], a[1], bl, bh, c);
216                 if (--num == 0) break;
217                 cpy_mul_add(r[2], r[3], a[2], bl, bh, c);
218                 if (--num == 0) break;
219                 cpy_mul_add(r[3], r[4], a[3], bl, bh, c);
220                 if (--num == 0) break;
221                 a += 4;
222                 r += 4;
223                 }
224         return c;
225         }
226 #endif /* BN_LLONG */
227
228
229
230 int BN_mont_red(BIGNUM *y, BN_MONTGOMERY *mont)
231 /* yR^{-1} (mod p) */
232         {
233         BIGNUM *p;
234         BN_ULONG c;
235         int i, max;
236
237         assert(y != NULL && mont != NULL);
238         assert(mont->p != NULL);
239         assert(BN_cmp(y, mont->p) < 0);
240         assert(!y->neg);
241
242
243         if (BN_is_zero(y)) return 1;
244
245         p = mont->p;
246         max = mont->p_num_bytes;
247
248         if (bn_wexpand(y, max) == NULL) return 0;
249         for (i = y->top; i < max; i++) y->d[i] = 0;
250         y->top = max;
251
252         /* r = [r + (y_0 * p') * p] / b */
253         for (i = 0; i < max; i++)
254                 {
255                 c = BN_mul_add_rshift(y->d, p->d, max, ((y->d[0]) * mont->p_inv_b_neg) & BN_MASK2); 
256                 y->d[max - 1] = c;
257                 }
258
259         while (y->d[y->top - 1] == 0) y->top--;
260
261         if (BN_cmp(y, p) >= 0) 
262                 {
263                 if (!BN_sub(y, y, p)) return 0;
264                 }
265
266         return 1;
267         }
268
269
270 int BN_mont_mod_mul(BIGNUM *r, BIGNUM *x, BIGNUM *y, BN_MONTGOMERY *mont)
271 /* r = x * y mod p */
272 /* r != x && r! = y !!! */
273         {
274         BN_ULONG c;
275         BIGNUM *p;
276         int i, j, max;
277
278         assert(r != x && r != y);
279         assert(r != NULL && x != NULL  && y != NULL && mont != NULL);
280         assert(mont->p != NULL);
281         assert(BN_cmp(x, mont->p) < 0);
282         assert(BN_cmp(y, mont->p) < 0);
283         assert(!x->neg);
284         assert(!y->neg);
285
286         if (BN_is_zero(x) || BN_is_zero(y))
287                 {
288                 if (!BN_zero(r)) return 0;
289                 return 1;
290                 }
291
292         p = mont->p;
293         max = mont->p_num_bytes;
294
295         /* for multiplication we need at most max + 2 words
296                 the last one --- max + 3 --- is only as a backstop
297                 for incorrect input 
298         */
299         if (bn_wexpand(r, max + 3) == NULL) return 0;
300         for (i = 0; i < max + 3; i++) r->d[i] = 0;
301         r->top = max + 2;
302
303         for (i = 0; i < x->top; i++)
304                 {
305                 /* r = r + (r_0 + x_i * y_0) * p' * p */
306                 c = bn_mul_add_words(r->d, p->d, max, \
307                         ((r->d[0] + x->d[i] * y->d[0]) * mont->p_inv_b_neg) & BN_MASK2);
308                 if (c)
309                         {
310                         if (((r->d[max] += c) & BN_MASK2) < c)
311                                 if (((r->d[max + 1] ++) & BN_MASK2) == 0) return 0;
312                         }
313                 
314                 /* r = (r + x_i * y) / b */
315                 c = BN_mul_add_rshift(r->d, y->d, y->top, x->d[i]); 
316                 for(j = y->top; j <= max + 1; j++) r->d[j - 1] = r->d[j];
317                 if (c)
318                         {
319                         if (((r->d[y->top - 1] += c) & BN_MASK2) < c)
320                                 {
321                                 j = y->top;
322                                 while (((++ (r->d[j]) ) & BN_MASK2) == 0) 
323                                         j++;
324                                 if (j > max) return 0;
325                                 }
326                         }
327                 r->d[max + 1] = 0;
328                 }
329
330         for (i = x->top; i < max; i++)
331                 {
332                 /* r = (r + r_0 * p' * p) / b */
333                 c = BN_mul_add_rshift(r->d, p->d, max, ((r->d[0]) * mont->p_inv_b_neg) & BN_MASK2); 
334                 j = max - 1;
335                 r->d[j] = c + r->d[max];
336                 if (r->d[j++] < c) r->d[j] = r->d[++j] + 1;
337                 else r->d[j] = r->d[++j];
338                 r->d[max + 1] = 0;
339                 }
340
341         while (r->d[r->top - 1] == 0) r->top--;
342
343         if (BN_cmp(r, mont->p) >= 0) 
344                 {
345                 if (!BN_sub(r, r, mont->p)) return 0;
346                 }
347
348         return 1;
349         }