More BN_mod_... functions.
[openssl.git] / crypto / bn / bn_mont2.c
1 /*
2  *
3  *      bn_mont2.c
4  *
5  *      Montgomery Modular Arithmetic Functions.
6  *
7  *      Copyright (C) Lenka Fibikova 2000
8  *
9  *
10  */
11
12
13 #include <stdio.h>
14 #include <stdlib.h>
15 #include <assert.h>
16
17 #include "bn.h"
18 #include "bn_modfs.h"
19 #include "bn_mont2.h"
20
21 #define BN_mask_word(x, m) ((x->d[0]) & (m))
22
23 BN_MONTGOMERY *BN_mont_new()
24 {
25         BN_MONTGOMERY *ret;
26
27         ret=(BN_MONTGOMERY *)malloc(sizeof(BN_MONTGOMERY));
28
29         if (ret == NULL) return NULL;
30
31         if ((ret->p = BN_new()) == NULL)
32         {
33                 free(ret);
34                 return NULL;
35         }
36
37         return ret;
38 }
39
40
41 void BN_mont_clear_free(BN_MONTGOMERY *mont)
42 {
43         if (mont == NULL) return;
44
45         if (mont->p != NULL) BN_clear_free(mont->p);
46
47         mont->p_num_bytes = 0;
48         mont->R_num_bits = 0;
49         mont->p_inv_b_neg = 0;
50 }
51
52 int BN_to_mont(BIGNUM *x, BN_MONTGOMERY *mont, BN_CTX *ctx)
53 {
54         assert(x != NULL);
55
56         assert(mont != NULL);
57         assert(mont->p != NULL);
58
59         assert(ctx != NULL);
60
61         if (!BN_lshift(x, x, mont->R_num_bits)) return 0;
62         if (!BN_mod(x, x, mont->p, ctx)) return 0;
63
64         return 1;
65 }
66
67
68 static BN_ULONG BN_mont_inv(BIGNUM *a, int e, BN_CTX *ctx)
69 /* y = a^{-1} (mod 2^e) for an odd number a */
70 {
71         BN_ULONG y, exp, mask;
72         BIGNUM *x, *xy, *x_sh;
73         int i;
74
75         assert(a != NULL && ctx != NULL);
76         assert(e <= BN_BITS2);
77         assert(BN_is_odd(a));
78         assert(!BN_is_zero(a) && !a->neg);
79
80
81         y = 1;
82         exp = 2;
83         mask = 3;
84         if((x = BN_dup(a)) == NULL) return 0;
85         if(!BN_mask_bits(x, e)) return 0;
86
87         xy = ctx->bn[ctx->tos]; 
88         x_sh = ctx->bn[ctx->tos + 1]; 
89         ctx->tos += 2;
90
91         if (BN_copy(xy, x) == NULL) goto err;
92         if (!BN_lshift1(x_sh, x)) goto err;
93
94
95         for (i = 2; i <= e; i++)
96         {
97                 if (exp < BN_mask_word(xy, mask))
98                 {
99                         y = y + exp;
100                         if (!BN_add(xy, xy, x_sh)) goto err;
101                 }
102
103                 exp <<= 1;
104                 if (!BN_lshift1(x_sh, x_sh)) goto err;
105                 mask <<= 1;
106                 mask++;
107         }
108
109
110 #ifdef TEST
111         if (xy->d[0] != 1) goto err;
112 #endif
113
114         if (x != NULL) BN_clear_free(x);
115         ctx->tos -= 2;
116         return y;
117
118
119 err:
120         if (x != NULL) BN_clear_free(x);
121         ctx->tos -= 2;
122         return 0;
123
124 }
125
126 int BN_mont_set(BIGNUM *p, BN_MONTGOMERY *mont, BN_CTX *ctx)
127 {
128         assert(p != NULL && ctx != NULL);
129         assert(mont != NULL);
130         assert(mont->p != NULL);
131         assert(!BN_is_zero(p) && !p->neg);
132
133
134         mont->p_num_bytes = p->top;
135         mont->R_num_bits = (mont->p_num_bytes) * BN_BITS2;
136
137         if (BN_copy(mont->p, p) == NULL);
138         
139         mont->p_inv_b_neg =  BN_mont_inv(p, BN_BITS2, ctx);
140         mont->p_inv_b_neg = 0 - mont->p_inv_b_neg;
141
142         return 1;
143 }
144
145 static int BN_cpy_mul_word(BIGNUM *ret, BIGNUM *a, BN_ULONG w)
146 /* ret = a * w */
147 {
148         if (BN_copy(ret, a) == NULL) return 0;
149
150         if (!BN_mul_word(ret, w)) return 0;
151
152         return 1;
153 }
154
155
156 int BN_mont_red(BIGNUM *y, BN_MONTGOMERY *mont, BN_CTX *ctx)
157 /* yR^{-1} (mod p) */
158 {
159         int i;
160         BIGNUM *up, *p;
161         BN_ULONG u;
162
163         assert(y != NULL && mont != NULL && ctx != NULL);
164         assert(mont->p != NULL);
165         assert(BN_cmp(y, mont->p) < 0);
166         assert(!y->neg);
167
168
169         if (BN_is_zero(y)) return 1;
170
171         p = mont->p;
172         up = ctx->bn[ctx->tos]; 
173         ctx->tos += 1;
174
175
176         for (i = 0; i < mont->p_num_bytes; i++)
177         {
178                 u = (y->d[0]) * mont->p_inv_b_neg;                      /* u = y_0 * p' */
179
180                 if (!BN_cpy_mul_word(up, p, u)) goto err;       /* up = u * p */
181
182                 if (!BN_add(y, y, up)) goto err;                        
183 #ifdef TEST
184                 if (y->d[0]) goto err;
185 #endif
186                 if (!BN_rshift(y, y, BN_BITS2)) goto err;       /* y = (y + up)/b */
187         }
188
189
190         if (BN_cmp(y, mont->p) >= 0) 
191         {
192                 if (!BN_sub(y, y, mont->p)) goto err;
193         }
194
195         ctx->tos -= 1;
196         return 1;
197
198 err:
199         ctx->tos -= 1;
200         return 0;
201
202 }
203
204
205 int BN_mont_mod_mul(BIGNUM *r, BIGNUM *x, BIGNUM *y, BN_MONTGOMERY *mont, BN_CTX *ctx)
206 /* r = x * y mod p */
207 /* r != x && r! = y !!! */
208 {
209         BIGNUM *xiy, *up;
210         BN_ULONG u;
211         int i;
212         
213
214         assert(r != x && r != y);
215         assert(r != NULL && x != NULL  && y != NULL && mont != NULL && ctx != NULL);
216         assert(mont->p != NULL);
217         assert(BN_cmp(x, mont->p) < 0);
218         assert(BN_cmp(y, mont->p) < 0);
219         assert(!x->neg);
220         assert(!y->neg);
221
222         if (BN_is_zero(x) || BN_is_zero(y))
223         {
224                 if (!BN_zero(r)) return 0;
225                 return 1;
226         }
227
228
229
230         xiy = ctx->bn[ctx->tos]; 
231         up = ctx->bn[ctx->tos + 1]; 
232         ctx->tos += 2;
233
234         if (!BN_zero(r)) goto err;
235
236         for (i = 0; i < x->top; i++)
237         {
238                 u = (r->d[0] + x->d[i] * y->d[0]) * mont->p_inv_b_neg;
239
240                 if (!BN_cpy_mul_word(xiy, y, x->d[i])) goto err;
241                 if (!BN_cpy_mul_word(up, mont->p, u)) goto err;
242
243                 if (!BN_add(r, r, xiy)) goto err;
244                 if (!BN_add(r, r, up)) goto err;
245
246 #ifdef TEST
247                 if (r->d[0]) goto err;
248 #endif
249                 if (!BN_rshift(r, r, BN_BITS2)) goto err; 
250         }
251
252         for (i = x->top; i < mont->p_num_bytes; i++)
253         {
254                 u = (r->d[0]) * mont->p_inv_b_neg;
255
256                 if (!BN_cpy_mul_word(up, mont->p, u)) goto err;
257
258                 if (!BN_add(r, r, up)) goto err;
259
260 #ifdef TEST
261                 if (r->d[0]) goto err;
262 #endif
263                 if (!BN_rshift(r, r, BN_BITS2)) goto err; 
264         }
265
266
267         if (BN_cmp(r, mont->p) >= 0) 
268         {
269                 if (!BN_sub(r, r, mont->p)) goto err;
270         }
271
272
273         ctx->tos -= 2;
274         return 1;
275
276 err:
277         ctx->tos -= 2;
278         return 0;
279 }