Add blinding in BN_GF2m_mod_inv for binary field inversions
[openssl.git] / crypto / bn / bn_gf2m.c
1 /*
2  * Copyright 2002-2017 The OpenSSL Project Authors. All Rights Reserved.
3  * Copyright (c) 2002, Oracle and/or its affiliates. All rights reserved
4  *
5  * Licensed under the OpenSSL license (the "License").  You may not use
6  * this file except in compliance with the License.  You can obtain a copy
7  * in the file LICENSE in the source distribution or at
8  * https://www.openssl.org/source/license.html
9  */
10
11 #include <assert.h>
12 #include <limits.h>
13 #include <stdio.h>
14 #include "internal/cryptlib.h"
15 #include "bn_lcl.h"
16
17 #ifndef OPENSSL_NO_EC2M
18
19 /*
20  * Maximum number of iterations before BN_GF2m_mod_solve_quad_arr should
21  * fail.
22  */
23 # define MAX_ITERATIONS 50
24
25 static const BN_ULONG SQR_tb[16] = { 0, 1, 4, 5, 16, 17, 20, 21,
26     64, 65, 68, 69, 80, 81, 84, 85
27 };
28
29 /* Platform-specific macros to accelerate squaring. */
30 # if defined(SIXTY_FOUR_BIT) || defined(SIXTY_FOUR_BIT_LONG)
31 #  define SQR1(w) \
32     SQR_tb[(w) >> 60 & 0xF] << 56 | SQR_tb[(w) >> 56 & 0xF] << 48 | \
33     SQR_tb[(w) >> 52 & 0xF] << 40 | SQR_tb[(w) >> 48 & 0xF] << 32 | \
34     SQR_tb[(w) >> 44 & 0xF] << 24 | SQR_tb[(w) >> 40 & 0xF] << 16 | \
35     SQR_tb[(w) >> 36 & 0xF] <<  8 | SQR_tb[(w) >> 32 & 0xF]
36 #  define SQR0(w) \
37     SQR_tb[(w) >> 28 & 0xF] << 56 | SQR_tb[(w) >> 24 & 0xF] << 48 | \
38     SQR_tb[(w) >> 20 & 0xF] << 40 | SQR_tb[(w) >> 16 & 0xF] << 32 | \
39     SQR_tb[(w) >> 12 & 0xF] << 24 | SQR_tb[(w) >>  8 & 0xF] << 16 | \
40     SQR_tb[(w) >>  4 & 0xF] <<  8 | SQR_tb[(w)       & 0xF]
41 # endif
42 # ifdef THIRTY_TWO_BIT
43 #  define SQR1(w) \
44     SQR_tb[(w) >> 28 & 0xF] << 24 | SQR_tb[(w) >> 24 & 0xF] << 16 | \
45     SQR_tb[(w) >> 20 & 0xF] <<  8 | SQR_tb[(w) >> 16 & 0xF]
46 #  define SQR0(w) \
47     SQR_tb[(w) >> 12 & 0xF] << 24 | SQR_tb[(w) >>  8 & 0xF] << 16 | \
48     SQR_tb[(w) >>  4 & 0xF] <<  8 | SQR_tb[(w)       & 0xF]
49 # endif
50
51 # if !defined(OPENSSL_BN_ASM_GF2m)
52 /*
53  * Product of two polynomials a, b each with degree < BN_BITS2 - 1, result is
54  * a polynomial r with degree < 2 * BN_BITS - 1 The caller MUST ensure that
55  * the variables have the right amount of space allocated.
56  */
57 #  ifdef THIRTY_TWO_BIT
58 static void bn_GF2m_mul_1x1(BN_ULONG *r1, BN_ULONG *r0, const BN_ULONG a,
59                             const BN_ULONG b)
60 {
61     register BN_ULONG h, l, s;
62     BN_ULONG tab[8], top2b = a >> 30;
63     register BN_ULONG a1, a2, a4;
64
65     a1 = a & (0x3FFFFFFF);
66     a2 = a1 << 1;
67     a4 = a2 << 1;
68
69     tab[0] = 0;
70     tab[1] = a1;
71     tab[2] = a2;
72     tab[3] = a1 ^ a2;
73     tab[4] = a4;
74     tab[5] = a1 ^ a4;
75     tab[6] = a2 ^ a4;
76     tab[7] = a1 ^ a2 ^ a4;
77
78     s = tab[b & 0x7];
79     l = s;
80     s = tab[b >> 3 & 0x7];
81     l ^= s << 3;
82     h = s >> 29;
83     s = tab[b >> 6 & 0x7];
84     l ^= s << 6;
85     h ^= s >> 26;
86     s = tab[b >> 9 & 0x7];
87     l ^= s << 9;
88     h ^= s >> 23;
89     s = tab[b >> 12 & 0x7];
90     l ^= s << 12;
91     h ^= s >> 20;
92     s = tab[b >> 15 & 0x7];
93     l ^= s << 15;
94     h ^= s >> 17;
95     s = tab[b >> 18 & 0x7];
96     l ^= s << 18;
97     h ^= s >> 14;
98     s = tab[b >> 21 & 0x7];
99     l ^= s << 21;
100     h ^= s >> 11;
101     s = tab[b >> 24 & 0x7];
102     l ^= s << 24;
103     h ^= s >> 8;
104     s = tab[b >> 27 & 0x7];
105     l ^= s << 27;
106     h ^= s >> 5;
107     s = tab[b >> 30];
108     l ^= s << 30;
109     h ^= s >> 2;
110
111     /* compensate for the top two bits of a */
112
113     if (top2b & 01) {
114         l ^= b << 30;
115         h ^= b >> 2;
116     }
117     if (top2b & 02) {
118         l ^= b << 31;
119         h ^= b >> 1;
120     }
121
122     *r1 = h;
123     *r0 = l;
124 }
125 #  endif
126 #  if defined(SIXTY_FOUR_BIT) || defined(SIXTY_FOUR_BIT_LONG)
127 static void bn_GF2m_mul_1x1(BN_ULONG *r1, BN_ULONG *r0, const BN_ULONG a,
128                             const BN_ULONG b)
129 {
130     register BN_ULONG h, l, s;
131     BN_ULONG tab[16], top3b = a >> 61;
132     register BN_ULONG a1, a2, a4, a8;
133
134     a1 = a & (0x1FFFFFFFFFFFFFFFULL);
135     a2 = a1 << 1;
136     a4 = a2 << 1;
137     a8 = a4 << 1;
138
139     tab[0] = 0;
140     tab[1] = a1;
141     tab[2] = a2;
142     tab[3] = a1 ^ a2;
143     tab[4] = a4;
144     tab[5] = a1 ^ a4;
145     tab[6] = a2 ^ a4;
146     tab[7] = a1 ^ a2 ^ a4;
147     tab[8] = a8;
148     tab[9] = a1 ^ a8;
149     tab[10] = a2 ^ a8;
150     tab[11] = a1 ^ a2 ^ a8;
151     tab[12] = a4 ^ a8;
152     tab[13] = a1 ^ a4 ^ a8;
153     tab[14] = a2 ^ a4 ^ a8;
154     tab[15] = a1 ^ a2 ^ a4 ^ a8;
155
156     s = tab[b & 0xF];
157     l = s;
158     s = tab[b >> 4 & 0xF];
159     l ^= s << 4;
160     h = s >> 60;
161     s = tab[b >> 8 & 0xF];
162     l ^= s << 8;
163     h ^= s >> 56;
164     s = tab[b >> 12 & 0xF];
165     l ^= s << 12;
166     h ^= s >> 52;
167     s = tab[b >> 16 & 0xF];
168     l ^= s << 16;
169     h ^= s >> 48;
170     s = tab[b >> 20 & 0xF];
171     l ^= s << 20;
172     h ^= s >> 44;
173     s = tab[b >> 24 & 0xF];
174     l ^= s << 24;
175     h ^= s >> 40;
176     s = tab[b >> 28 & 0xF];
177     l ^= s << 28;
178     h ^= s >> 36;
179     s = tab[b >> 32 & 0xF];
180     l ^= s << 32;
181     h ^= s >> 32;
182     s = tab[b >> 36 & 0xF];
183     l ^= s << 36;
184     h ^= s >> 28;
185     s = tab[b >> 40 & 0xF];
186     l ^= s << 40;
187     h ^= s >> 24;
188     s = tab[b >> 44 & 0xF];
189     l ^= s << 44;
190     h ^= s >> 20;
191     s = tab[b >> 48 & 0xF];
192     l ^= s << 48;
193     h ^= s >> 16;
194     s = tab[b >> 52 & 0xF];
195     l ^= s << 52;
196     h ^= s >> 12;
197     s = tab[b >> 56 & 0xF];
198     l ^= s << 56;
199     h ^= s >> 8;
200     s = tab[b >> 60];
201     l ^= s << 60;
202     h ^= s >> 4;
203
204     /* compensate for the top three bits of a */
205
206     if (top3b & 01) {
207         l ^= b << 61;
208         h ^= b >> 3;
209     }
210     if (top3b & 02) {
211         l ^= b << 62;
212         h ^= b >> 2;
213     }
214     if (top3b & 04) {
215         l ^= b << 63;
216         h ^= b >> 1;
217     }
218
219     *r1 = h;
220     *r0 = l;
221 }
222 #  endif
223
224 /*
225  * Product of two polynomials a, b each with degree < 2 * BN_BITS2 - 1,
226  * result is a polynomial r with degree < 4 * BN_BITS2 - 1 The caller MUST
227  * ensure that the variables have the right amount of space allocated.
228  */
229 static void bn_GF2m_mul_2x2(BN_ULONG *r, const BN_ULONG a1, const BN_ULONG a0,
230                             const BN_ULONG b1, const BN_ULONG b0)
231 {
232     BN_ULONG m1, m0;
233     /* r[3] = h1, r[2] = h0; r[1] = l1; r[0] = l0 */
234     bn_GF2m_mul_1x1(r + 3, r + 2, a1, b1);
235     bn_GF2m_mul_1x1(r + 1, r, a0, b0);
236     bn_GF2m_mul_1x1(&m1, &m0, a0 ^ a1, b0 ^ b1);
237     /* Correction on m1 ^= l1 ^ h1; m0 ^= l0 ^ h0; */
238     r[2] ^= m1 ^ r[1] ^ r[3];   /* h0 ^= m1 ^ l1 ^ h1; */
239     r[1] = r[3] ^ r[2] ^ r[0] ^ m1 ^ m0; /* l1 ^= l0 ^ h0 ^ m0; */
240 }
241 # else
242 void bn_GF2m_mul_2x2(BN_ULONG *r, BN_ULONG a1, BN_ULONG a0, BN_ULONG b1,
243                      BN_ULONG b0);
244 # endif
245
246 /*
247  * Add polynomials a and b and store result in r; r could be a or b, a and b
248  * could be equal; r is the bitwise XOR of a and b.
249  */
250 int BN_GF2m_add(BIGNUM *r, const BIGNUM *a, const BIGNUM *b)
251 {
252     int i;
253     const BIGNUM *at, *bt;
254
255     bn_check_top(a);
256     bn_check_top(b);
257
258     if (a->top < b->top) {
259         at = b;
260         bt = a;
261     } else {
262         at = a;
263         bt = b;
264     }
265
266     if (bn_wexpand(r, at->top) == NULL)
267         return 0;
268
269     for (i = 0; i < bt->top; i++) {
270         r->d[i] = at->d[i] ^ bt->d[i];
271     }
272     for (; i < at->top; i++) {
273         r->d[i] = at->d[i];
274     }
275
276     r->top = at->top;
277     bn_correct_top(r);
278
279     return 1;
280 }
281
282 /*-
283  * Some functions allow for representation of the irreducible polynomials
284  * as an int[], say p.  The irreducible f(t) is then of the form:
285  *     t^p[0] + t^p[1] + ... + t^p[k]
286  * where m = p[0] > p[1] > ... > p[k] = 0.
287  */
288
289 /* Performs modular reduction of a and store result in r.  r could be a. */
290 int BN_GF2m_mod_arr(BIGNUM *r, const BIGNUM *a, const int p[])
291 {
292     int j, k;
293     int n, dN, d0, d1;
294     BN_ULONG zz, *z;
295
296     bn_check_top(a);
297
298     if (!p[0]) {
299         /* reduction mod 1 => return 0 */
300         BN_zero(r);
301         return 1;
302     }
303
304     /*
305      * Since the algorithm does reduction in the r value, if a != r, copy the
306      * contents of a into r so we can do reduction in r.
307      */
308     if (a != r) {
309         if (!bn_wexpand(r, a->top))
310             return 0;
311         for (j = 0; j < a->top; j++) {
312             r->d[j] = a->d[j];
313         }
314         r->top = a->top;
315     }
316     z = r->d;
317
318     /* start reduction */
319     dN = p[0] / BN_BITS2;
320     for (j = r->top - 1; j > dN;) {
321         zz = z[j];
322         if (z[j] == 0) {
323             j--;
324             continue;
325         }
326         z[j] = 0;
327
328         for (k = 1; p[k] != 0; k++) {
329             /* reducing component t^p[k] */
330             n = p[0] - p[k];
331             d0 = n % BN_BITS2;
332             d1 = BN_BITS2 - d0;
333             n /= BN_BITS2;
334             z[j - n] ^= (zz >> d0);
335             if (d0)
336                 z[j - n - 1] ^= (zz << d1);
337         }
338
339         /* reducing component t^0 */
340         n = dN;
341         d0 = p[0] % BN_BITS2;
342         d1 = BN_BITS2 - d0;
343         z[j - n] ^= (zz >> d0);
344         if (d0)
345             z[j - n - 1] ^= (zz << d1);
346     }
347
348     /* final round of reduction */
349     while (j == dN) {
350
351         d0 = p[0] % BN_BITS2;
352         zz = z[dN] >> d0;
353         if (zz == 0)
354             break;
355         d1 = BN_BITS2 - d0;
356
357         /* clear up the top d1 bits */
358         if (d0)
359             z[dN] = (z[dN] << d1) >> d1;
360         else
361             z[dN] = 0;
362         z[0] ^= zz;             /* reduction t^0 component */
363
364         for (k = 1; p[k] != 0; k++) {
365             BN_ULONG tmp_ulong;
366
367             /* reducing component t^p[k] */
368             n = p[k] / BN_BITS2;
369             d0 = p[k] % BN_BITS2;
370             d1 = BN_BITS2 - d0;
371             z[n] ^= (zz << d0);
372             if (d0 && (tmp_ulong = zz >> d1))
373                 z[n + 1] ^= tmp_ulong;
374         }
375
376     }
377
378     bn_correct_top(r);
379     return 1;
380 }
381
382 /*
383  * Performs modular reduction of a by p and store result in r.  r could be a.
384  * This function calls down to the BN_GF2m_mod_arr implementation; this wrapper
385  * function is only provided for convenience; for best performance, use the
386  * BN_GF2m_mod_arr function.
387  */
388 int BN_GF2m_mod(BIGNUM *r, const BIGNUM *a, const BIGNUM *p)
389 {
390     int ret = 0;
391     int arr[6];
392     bn_check_top(a);
393     bn_check_top(p);
394     ret = BN_GF2m_poly2arr(p, arr, OSSL_NELEM(arr));
395     if (!ret || ret > (int)OSSL_NELEM(arr)) {
396         BNerr(BN_F_BN_GF2M_MOD, BN_R_INVALID_LENGTH);
397         return 0;
398     }
399     ret = BN_GF2m_mod_arr(r, a, arr);
400     bn_check_top(r);
401     return ret;
402 }
403
404 /*
405  * Compute the product of two polynomials a and b, reduce modulo p, and store
406  * the result in r.  r could be a or b; a could be b.
407  */
408 int BN_GF2m_mod_mul_arr(BIGNUM *r, const BIGNUM *a, const BIGNUM *b,
409                         const int p[], BN_CTX *ctx)
410 {
411     int zlen, i, j, k, ret = 0;
412     BIGNUM *s;
413     BN_ULONG x1, x0, y1, y0, zz[4];
414
415     bn_check_top(a);
416     bn_check_top(b);
417
418     if (a == b) {
419         return BN_GF2m_mod_sqr_arr(r, a, p, ctx);
420     }
421
422     BN_CTX_start(ctx);
423     if ((s = BN_CTX_get(ctx)) == NULL)
424         goto err;
425
426     zlen = a->top + b->top + 4;
427     if (!bn_wexpand(s, zlen))
428         goto err;
429     s->top = zlen;
430
431     for (i = 0; i < zlen; i++)
432         s->d[i] = 0;
433
434     for (j = 0; j < b->top; j += 2) {
435         y0 = b->d[j];
436         y1 = ((j + 1) == b->top) ? 0 : b->d[j + 1];
437         for (i = 0; i < a->top; i += 2) {
438             x0 = a->d[i];
439             x1 = ((i + 1) == a->top) ? 0 : a->d[i + 1];
440             bn_GF2m_mul_2x2(zz, x1, x0, y1, y0);
441             for (k = 0; k < 4; k++)
442                 s->d[i + j + k] ^= zz[k];
443         }
444     }
445
446     bn_correct_top(s);
447     if (BN_GF2m_mod_arr(r, s, p))
448         ret = 1;
449     bn_check_top(r);
450
451  err:
452     BN_CTX_end(ctx);
453     return ret;
454 }
455
456 /*
457  * Compute the product of two polynomials a and b, reduce modulo p, and store
458  * the result in r.  r could be a or b; a could equal b. This function calls
459  * down to the BN_GF2m_mod_mul_arr implementation; this wrapper function is
460  * only provided for convenience; for best performance, use the
461  * BN_GF2m_mod_mul_arr function.
462  */
463 int BN_GF2m_mod_mul(BIGNUM *r, const BIGNUM *a, const BIGNUM *b,
464                     const BIGNUM *p, BN_CTX *ctx)
465 {
466     int ret = 0;
467     const int max = BN_num_bits(p) + 1;
468     int *arr = NULL;
469     bn_check_top(a);
470     bn_check_top(b);
471     bn_check_top(p);
472     if ((arr = OPENSSL_malloc(sizeof(*arr) * max)) == NULL)
473         goto err;
474     ret = BN_GF2m_poly2arr(p, arr, max);
475     if (!ret || ret > max) {
476         BNerr(BN_F_BN_GF2M_MOD_MUL, BN_R_INVALID_LENGTH);
477         goto err;
478     }
479     ret = BN_GF2m_mod_mul_arr(r, a, b, arr, ctx);
480     bn_check_top(r);
481  err:
482     OPENSSL_free(arr);
483     return ret;
484 }
485
486 /* Square a, reduce the result mod p, and store it in a.  r could be a. */
487 int BN_GF2m_mod_sqr_arr(BIGNUM *r, const BIGNUM *a, const int p[],
488                         BN_CTX *ctx)
489 {
490     int i, ret = 0;
491     BIGNUM *s;
492
493     bn_check_top(a);
494     BN_CTX_start(ctx);
495     if ((s = BN_CTX_get(ctx)) == NULL)
496         goto err;
497     if (!bn_wexpand(s, 2 * a->top))
498         goto err;
499
500     for (i = a->top - 1; i >= 0; i--) {
501         s->d[2 * i + 1] = SQR1(a->d[i]);
502         s->d[2 * i] = SQR0(a->d[i]);
503     }
504
505     s->top = 2 * a->top;
506     bn_correct_top(s);
507     if (!BN_GF2m_mod_arr(r, s, p))
508         goto err;
509     bn_check_top(r);
510     ret = 1;
511  err:
512     BN_CTX_end(ctx);
513     return ret;
514 }
515
516 /*
517  * Square a, reduce the result mod p, and store it in a.  r could be a. This
518  * function calls down to the BN_GF2m_mod_sqr_arr implementation; this
519  * wrapper function is only provided for convenience; for best performance,
520  * use the BN_GF2m_mod_sqr_arr function.
521  */
522 int BN_GF2m_mod_sqr(BIGNUM *r, const BIGNUM *a, const BIGNUM *p, BN_CTX *ctx)
523 {
524     int ret = 0;
525     const int max = BN_num_bits(p) + 1;
526     int *arr = NULL;
527
528     bn_check_top(a);
529     bn_check_top(p);
530     if ((arr = OPENSSL_malloc(sizeof(*arr) * max)) == NULL)
531         goto err;
532     ret = BN_GF2m_poly2arr(p, arr, max);
533     if (!ret || ret > max) {
534         BNerr(BN_F_BN_GF2M_MOD_SQR, BN_R_INVALID_LENGTH);
535         goto err;
536     }
537     ret = BN_GF2m_mod_sqr_arr(r, a, arr, ctx);
538     bn_check_top(r);
539  err:
540     OPENSSL_free(arr);
541     return ret;
542 }
543
544 /*
545  * Invert a, reduce modulo p, and store the result in r. r could be a. Uses
546  * Modified Almost Inverse Algorithm (Algorithm 10) from Hankerson, D.,
547  * Hernandez, J.L., and Menezes, A.  "Software Implementation of Elliptic
548  * Curve Cryptography Over Binary Fields".
549  */
550 static int BN_GF2m_mod_inv_vartime(BIGNUM *r, const BIGNUM *a,
551                                    const BIGNUM *p, BN_CTX *ctx)
552 {
553     BIGNUM *b, *c = NULL, *u = NULL, *v = NULL, *tmp;
554     int ret = 0;
555
556     bn_check_top(a);
557     bn_check_top(p);
558
559     BN_CTX_start(ctx);
560
561     b = BN_CTX_get(ctx);
562     c = BN_CTX_get(ctx);
563     u = BN_CTX_get(ctx);
564     v = BN_CTX_get(ctx);
565     if (v == NULL)
566         goto err;
567
568     if (!BN_GF2m_mod(u, a, p))
569         goto err;
570     if (BN_is_zero(u))
571         goto err;
572
573     if (!BN_copy(v, p))
574         goto err;
575 # if 0
576     if (!BN_one(b))
577         goto err;
578
579     while (1) {
580         while (!BN_is_odd(u)) {
581             if (BN_is_zero(u))
582                 goto err;
583             if (!BN_rshift1(u, u))
584                 goto err;
585             if (BN_is_odd(b)) {
586                 if (!BN_GF2m_add(b, b, p))
587                     goto err;
588             }
589             if (!BN_rshift1(b, b))
590                 goto err;
591         }
592
593         if (BN_abs_is_word(u, 1))
594             break;
595
596         if (BN_num_bits(u) < BN_num_bits(v)) {
597             tmp = u;
598             u = v;
599             v = tmp;
600             tmp = b;
601             b = c;
602             c = tmp;
603         }
604
605         if (!BN_GF2m_add(u, u, v))
606             goto err;
607         if (!BN_GF2m_add(b, b, c))
608             goto err;
609     }
610 # else
611     {
612         int i;
613         int ubits = BN_num_bits(u);
614         int vbits = BN_num_bits(v); /* v is copy of p */
615         int top = p->top;
616         BN_ULONG *udp, *bdp, *vdp, *cdp;
617
618         if (!bn_wexpand(u, top))
619             goto err;
620         udp = u->d;
621         for (i = u->top; i < top; i++)
622             udp[i] = 0;
623         u->top = top;
624         if (!bn_wexpand(b, top))
625           goto err;
626         bdp = b->d;
627         bdp[0] = 1;
628         for (i = 1; i < top; i++)
629             bdp[i] = 0;
630         b->top = top;
631         if (!bn_wexpand(c, top))
632           goto err;
633         cdp = c->d;
634         for (i = 0; i < top; i++)
635             cdp[i] = 0;
636         c->top = top;
637         vdp = v->d;             /* It pays off to "cache" *->d pointers,
638                                  * because it allows optimizer to be more
639                                  * aggressive. But we don't have to "cache"
640                                  * p->d, because *p is declared 'const'... */
641         while (1) {
642             while (ubits && !(udp[0] & 1)) {
643                 BN_ULONG u0, u1, b0, b1, mask;
644
645                 u0 = udp[0];
646                 b0 = bdp[0];
647                 mask = (BN_ULONG)0 - (b0 & 1);
648                 b0 ^= p->d[0] & mask;
649                 for (i = 0; i < top - 1; i++) {
650                     u1 = udp[i + 1];
651                     udp[i] = ((u0 >> 1) | (u1 << (BN_BITS2 - 1))) & BN_MASK2;
652                     u0 = u1;
653                     b1 = bdp[i + 1] ^ (p->d[i + 1] & mask);
654                     bdp[i] = ((b0 >> 1) | (b1 << (BN_BITS2 - 1))) & BN_MASK2;
655                     b0 = b1;
656                 }
657                 udp[i] = u0 >> 1;
658                 bdp[i] = b0 >> 1;
659                 ubits--;
660             }
661
662             if (ubits <= BN_BITS2) {
663                 if (udp[0] == 0) /* poly was reducible */
664                     goto err;
665                 if (udp[0] == 1)
666                     break;
667             }
668
669             if (ubits < vbits) {
670                 i = ubits;
671                 ubits = vbits;
672                 vbits = i;
673                 tmp = u;
674                 u = v;
675                 v = tmp;
676                 tmp = b;
677                 b = c;
678                 c = tmp;
679                 udp = vdp;
680                 vdp = v->d;
681                 bdp = cdp;
682                 cdp = c->d;
683             }
684             for (i = 0; i < top; i++) {
685                 udp[i] ^= vdp[i];
686                 bdp[i] ^= cdp[i];
687             }
688             if (ubits == vbits) {
689                 BN_ULONG ul;
690                 int utop = (ubits - 1) / BN_BITS2;
691
692                 while ((ul = udp[utop]) == 0 && utop)
693                     utop--;
694                 ubits = utop * BN_BITS2 + BN_num_bits_word(ul);
695             }
696         }
697         bn_correct_top(b);
698     }
699 # endif
700
701     if (!BN_copy(r, b))
702         goto err;
703     bn_check_top(r);
704     ret = 1;
705
706  err:
707 # ifdef BN_DEBUG                /* BN_CTX_end would complain about the
708                                  * expanded form */
709     bn_correct_top(c);
710     bn_correct_top(u);
711     bn_correct_top(v);
712 # endif
713     BN_CTX_end(ctx);
714     return ret;
715 }
716
717 /*-
718  * Wrapper for BN_GF2m_mod_inv_vartime that blinds the input before calling.
719  * This is not constant time.
720  * But it does eliminate first order deduction on the input.
721  */
722 int BN_GF2m_mod_inv(BIGNUM *r, const BIGNUM *a, const BIGNUM *p, BN_CTX *ctx)
723 {
724     BIGNUM *b = NULL;
725     int ret = 0;
726
727     BN_CTX_start(ctx);
728     if ((b = BN_CTX_get(ctx)) == NULL)
729         goto err;
730
731     /* generate blinding value */
732     do {
733         if (!BN_priv_rand(b, BN_num_bits(p) - 1,
734                           BN_RAND_TOP_ANY, BN_RAND_BOTTOM_ANY))
735             goto err;
736     } while (BN_is_zero(b));
737
738     /* r := a * b */
739     if (!BN_GF2m_mod_mul(r, a, b, p, ctx))
740         goto err;
741
742     /* r := 1/(a * b) */
743     if (!BN_GF2m_mod_inv_vartime(r, r, p, ctx))
744         goto err;
745
746     /* r := b/(a * b) = 1/a */
747     if (!BN_GF2m_mod_mul(r, r, b, p, ctx))
748         goto err;
749
750     ret = 1;
751
752  err:
753     BN_CTX_end(ctx);
754     return ret;
755 }
756
757 /*
758  * Invert xx, reduce modulo p, and store the result in r. r could be xx.
759  * This function calls down to the BN_GF2m_mod_inv implementation; this
760  * wrapper function is only provided for convenience; for best performance,
761  * use the BN_GF2m_mod_inv function.
762  */
763 int BN_GF2m_mod_inv_arr(BIGNUM *r, const BIGNUM *xx, const int p[],
764                         BN_CTX *ctx)
765 {
766     BIGNUM *field;
767     int ret = 0;
768
769     bn_check_top(xx);
770     BN_CTX_start(ctx);
771     if ((field = BN_CTX_get(ctx)) == NULL)
772         goto err;
773     if (!BN_GF2m_arr2poly(p, field))
774         goto err;
775
776     ret = BN_GF2m_mod_inv(r, xx, field, ctx);
777     bn_check_top(r);
778
779  err:
780     BN_CTX_end(ctx);
781     return ret;
782 }
783
784 /*
785  * Divide y by x, reduce modulo p, and store the result in r. r could be x
786  * or y, x could equal y.
787  */
788 int BN_GF2m_mod_div(BIGNUM *r, const BIGNUM *y, const BIGNUM *x,
789                     const BIGNUM *p, BN_CTX *ctx)
790 {
791     BIGNUM *xinv = NULL;
792     int ret = 0;
793
794     bn_check_top(y);
795     bn_check_top(x);
796     bn_check_top(p);
797
798     BN_CTX_start(ctx);
799     xinv = BN_CTX_get(ctx);
800     if (xinv == NULL)
801         goto err;
802
803     if (!BN_GF2m_mod_inv(xinv, x, p, ctx))
804         goto err;
805     if (!BN_GF2m_mod_mul(r, y, xinv, p, ctx))
806         goto err;
807     bn_check_top(r);
808     ret = 1;
809
810  err:
811     BN_CTX_end(ctx);
812     return ret;
813 }
814
815 /*
816  * Divide yy by xx, reduce modulo p, and store the result in r. r could be xx
817  * * or yy, xx could equal yy. This function calls down to the
818  * BN_GF2m_mod_div implementation; this wrapper function is only provided for
819  * convenience; for best performance, use the BN_GF2m_mod_div function.
820  */
821 int BN_GF2m_mod_div_arr(BIGNUM *r, const BIGNUM *yy, const BIGNUM *xx,
822                         const int p[], BN_CTX *ctx)
823 {
824     BIGNUM *field;
825     int ret = 0;
826
827     bn_check_top(yy);
828     bn_check_top(xx);
829
830     BN_CTX_start(ctx);
831     if ((field = BN_CTX_get(ctx)) == NULL)
832         goto err;
833     if (!BN_GF2m_arr2poly(p, field))
834         goto err;
835
836     ret = BN_GF2m_mod_div(r, yy, xx, field, ctx);
837     bn_check_top(r);
838
839  err:
840     BN_CTX_end(ctx);
841     return ret;
842 }
843
844 /*
845  * Compute the bth power of a, reduce modulo p, and store the result in r.  r
846  * could be a. Uses simple square-and-multiply algorithm A.5.1 from IEEE
847  * P1363.
848  */
849 int BN_GF2m_mod_exp_arr(BIGNUM *r, const BIGNUM *a, const BIGNUM *b,
850                         const int p[], BN_CTX *ctx)
851 {
852     int ret = 0, i, n;
853     BIGNUM *u;
854
855     bn_check_top(a);
856     bn_check_top(b);
857
858     if (BN_is_zero(b))
859         return BN_one(r);
860
861     if (BN_abs_is_word(b, 1))
862         return (BN_copy(r, a) != NULL);
863
864     BN_CTX_start(ctx);
865     if ((u = BN_CTX_get(ctx)) == NULL)
866         goto err;
867
868     if (!BN_GF2m_mod_arr(u, a, p))
869         goto err;
870
871     n = BN_num_bits(b) - 1;
872     for (i = n - 1; i >= 0; i--) {
873         if (!BN_GF2m_mod_sqr_arr(u, u, p, ctx))
874             goto err;
875         if (BN_is_bit_set(b, i)) {
876             if (!BN_GF2m_mod_mul_arr(u, u, a, p, ctx))
877                 goto err;
878         }
879     }
880     if (!BN_copy(r, u))
881         goto err;
882     bn_check_top(r);
883     ret = 1;
884  err:
885     BN_CTX_end(ctx);
886     return ret;
887 }
888
889 /*
890  * Compute the bth power of a, reduce modulo p, and store the result in r.  r
891  * could be a. This function calls down to the BN_GF2m_mod_exp_arr
892  * implementation; this wrapper function is only provided for convenience;
893  * for best performance, use the BN_GF2m_mod_exp_arr function.
894  */
895 int BN_GF2m_mod_exp(BIGNUM *r, const BIGNUM *a, const BIGNUM *b,
896                     const BIGNUM *p, BN_CTX *ctx)
897 {
898     int ret = 0;
899     const int max = BN_num_bits(p) + 1;
900     int *arr = NULL;
901     bn_check_top(a);
902     bn_check_top(b);
903     bn_check_top(p);
904     if ((arr = OPENSSL_malloc(sizeof(*arr) * max)) == NULL)
905         goto err;
906     ret = BN_GF2m_poly2arr(p, arr, max);
907     if (!ret || ret > max) {
908         BNerr(BN_F_BN_GF2M_MOD_EXP, BN_R_INVALID_LENGTH);
909         goto err;
910     }
911     ret = BN_GF2m_mod_exp_arr(r, a, b, arr, ctx);
912     bn_check_top(r);
913  err:
914     OPENSSL_free(arr);
915     return ret;
916 }
917
918 /*
919  * Compute the square root of a, reduce modulo p, and store the result in r.
920  * r could be a. Uses exponentiation as in algorithm A.4.1 from IEEE P1363.
921  */
922 int BN_GF2m_mod_sqrt_arr(BIGNUM *r, const BIGNUM *a, const int p[],
923                          BN_CTX *ctx)
924 {
925     int ret = 0;
926     BIGNUM *u;
927
928     bn_check_top(a);
929
930     if (!p[0]) {
931         /* reduction mod 1 => return 0 */
932         BN_zero(r);
933         return 1;
934     }
935
936     BN_CTX_start(ctx);
937     if ((u = BN_CTX_get(ctx)) == NULL)
938         goto err;
939
940     if (!BN_set_bit(u, p[0] - 1))
941         goto err;
942     ret = BN_GF2m_mod_exp_arr(r, a, u, p, ctx);
943     bn_check_top(r);
944
945  err:
946     BN_CTX_end(ctx);
947     return ret;
948 }
949
950 /*
951  * Compute the square root of a, reduce modulo p, and store the result in r.
952  * r could be a. This function calls down to the BN_GF2m_mod_sqrt_arr
953  * implementation; this wrapper function is only provided for convenience;
954  * for best performance, use the BN_GF2m_mod_sqrt_arr function.
955  */
956 int BN_GF2m_mod_sqrt(BIGNUM *r, const BIGNUM *a, const BIGNUM *p, BN_CTX *ctx)
957 {
958     int ret = 0;
959     const int max = BN_num_bits(p) + 1;
960     int *arr = NULL;
961     bn_check_top(a);
962     bn_check_top(p);
963     if ((arr = OPENSSL_malloc(sizeof(*arr) * max)) == NULL)
964         goto err;
965     ret = BN_GF2m_poly2arr(p, arr, max);
966     if (!ret || ret > max) {
967         BNerr(BN_F_BN_GF2M_MOD_SQRT, BN_R_INVALID_LENGTH);
968         goto err;
969     }
970     ret = BN_GF2m_mod_sqrt_arr(r, a, arr, ctx);
971     bn_check_top(r);
972  err:
973     OPENSSL_free(arr);
974     return ret;
975 }
976
977 /*
978  * Find r such that r^2 + r = a mod p.  r could be a. If no r exists returns
979  * 0. Uses algorithms A.4.7 and A.4.6 from IEEE P1363.
980  */
981 int BN_GF2m_mod_solve_quad_arr(BIGNUM *r, const BIGNUM *a_, const int p[],
982                                BN_CTX *ctx)
983 {
984     int ret = 0, count = 0, j;
985     BIGNUM *a, *z, *rho, *w, *w2, *tmp;
986
987     bn_check_top(a_);
988
989     if (!p[0]) {
990         /* reduction mod 1 => return 0 */
991         BN_zero(r);
992         return 1;
993     }
994
995     BN_CTX_start(ctx);
996     a = BN_CTX_get(ctx);
997     z = BN_CTX_get(ctx);
998     w = BN_CTX_get(ctx);
999     if (w == NULL)
1000         goto err;
1001
1002     if (!BN_GF2m_mod_arr(a, a_, p))
1003         goto err;
1004
1005     if (BN_is_zero(a)) {
1006         BN_zero(r);
1007         ret = 1;
1008         goto err;
1009     }
1010
1011     if (p[0] & 0x1) {           /* m is odd */
1012         /* compute half-trace of a */
1013         if (!BN_copy(z, a))
1014             goto err;
1015         for (j = 1; j <= (p[0] - 1) / 2; j++) {
1016             if (!BN_GF2m_mod_sqr_arr(z, z, p, ctx))
1017                 goto err;
1018             if (!BN_GF2m_mod_sqr_arr(z, z, p, ctx))
1019                 goto err;
1020             if (!BN_GF2m_add(z, z, a))
1021                 goto err;
1022         }
1023
1024     } else {                    /* m is even */
1025
1026         rho = BN_CTX_get(ctx);
1027         w2 = BN_CTX_get(ctx);
1028         tmp = BN_CTX_get(ctx);
1029         if (tmp == NULL)
1030             goto err;
1031         do {
1032             if (!BN_priv_rand(rho, p[0], BN_RAND_TOP_ONE, BN_RAND_BOTTOM_ANY))
1033                 goto err;
1034             if (!BN_GF2m_mod_arr(rho, rho, p))
1035                 goto err;
1036             BN_zero(z);
1037             if (!BN_copy(w, rho))
1038                 goto err;
1039             for (j = 1; j <= p[0] - 1; j++) {
1040                 if (!BN_GF2m_mod_sqr_arr(z, z, p, ctx))
1041                     goto err;
1042                 if (!BN_GF2m_mod_sqr_arr(w2, w, p, ctx))
1043                     goto err;
1044                 if (!BN_GF2m_mod_mul_arr(tmp, w2, a, p, ctx))
1045                     goto err;
1046                 if (!BN_GF2m_add(z, z, tmp))
1047                     goto err;
1048                 if (!BN_GF2m_add(w, w2, rho))
1049                     goto err;
1050             }
1051             count++;
1052         } while (BN_is_zero(w) && (count < MAX_ITERATIONS));
1053         if (BN_is_zero(w)) {
1054             BNerr(BN_F_BN_GF2M_MOD_SOLVE_QUAD_ARR, BN_R_TOO_MANY_ITERATIONS);
1055             goto err;
1056         }
1057     }
1058
1059     if (!BN_GF2m_mod_sqr_arr(w, z, p, ctx))
1060         goto err;
1061     if (!BN_GF2m_add(w, z, w))
1062         goto err;
1063     if (BN_GF2m_cmp(w, a)) {
1064         BNerr(BN_F_BN_GF2M_MOD_SOLVE_QUAD_ARR, BN_R_NO_SOLUTION);
1065         goto err;
1066     }
1067
1068     if (!BN_copy(r, z))
1069         goto err;
1070     bn_check_top(r);
1071
1072     ret = 1;
1073
1074  err:
1075     BN_CTX_end(ctx);
1076     return ret;
1077 }
1078
1079 /*
1080  * Find r such that r^2 + r = a mod p.  r could be a. If no r exists returns
1081  * 0. This function calls down to the BN_GF2m_mod_solve_quad_arr
1082  * implementation; this wrapper function is only provided for convenience;
1083  * for best performance, use the BN_GF2m_mod_solve_quad_arr function.
1084  */
1085 int BN_GF2m_mod_solve_quad(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
1086                            BN_CTX *ctx)
1087 {
1088     int ret = 0;
1089     const int max = BN_num_bits(p) + 1;
1090     int *arr = NULL;
1091     bn_check_top(a);
1092     bn_check_top(p);
1093     if ((arr = OPENSSL_malloc(sizeof(*arr) * max)) == NULL)
1094         goto err;
1095     ret = BN_GF2m_poly2arr(p, arr, max);
1096     if (!ret || ret > max) {
1097         BNerr(BN_F_BN_GF2M_MOD_SOLVE_QUAD, BN_R_INVALID_LENGTH);
1098         goto err;
1099     }
1100     ret = BN_GF2m_mod_solve_quad_arr(r, a, arr, ctx);
1101     bn_check_top(r);
1102  err:
1103     OPENSSL_free(arr);
1104     return ret;
1105 }
1106
1107 /*
1108  * Convert the bit-string representation of a polynomial ( \sum_{i=0}^n a_i *
1109  * x^i) into an array of integers corresponding to the bits with non-zero
1110  * coefficient.  Array is terminated with -1. Up to max elements of the array
1111  * will be filled.  Return value is total number of array elements that would
1112  * be filled if array was large enough.
1113  */
1114 int BN_GF2m_poly2arr(const BIGNUM *a, int p[], int max)
1115 {
1116     int i, j, k = 0;
1117     BN_ULONG mask;
1118
1119     if (BN_is_zero(a))
1120         return 0;
1121
1122     for (i = a->top - 1; i >= 0; i--) {
1123         if (!a->d[i])
1124             /* skip word if a->d[i] == 0 */
1125             continue;
1126         mask = BN_TBIT;
1127         for (j = BN_BITS2 - 1; j >= 0; j--) {
1128             if (a->d[i] & mask) {
1129                 if (k < max)
1130                     p[k] = BN_BITS2 * i + j;
1131                 k++;
1132             }
1133             mask >>= 1;
1134         }
1135     }
1136
1137     if (k < max) {
1138         p[k] = -1;
1139         k++;
1140     }
1141
1142     return k;
1143 }
1144
1145 /*
1146  * Convert the coefficient array representation of a polynomial to a
1147  * bit-string.  The array must be terminated by -1.
1148  */
1149 int BN_GF2m_arr2poly(const int p[], BIGNUM *a)
1150 {
1151     int i;
1152
1153     bn_check_top(a);
1154     BN_zero(a);
1155     for (i = 0; p[i] != -1; i++) {
1156         if (BN_set_bit(a, p[i]) == 0)
1157             return 0;
1158     }
1159     bn_check_top(a);
1160
1161     return 1;
1162 }
1163
1164 #endif