bed9fca4d9337a84322bffc353d76e65c94cbc6c
[openssl.git] / crypto / bn / bn_gcd.c
1 /*
2  * Copyright 1995-2018 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include "internal/cryptlib.h"
11 #include "bn_local.h"
12
13 /* solves ax == 1 (mod n) */
14 static BIGNUM *BN_mod_inverse_no_branch(BIGNUM *in,
15                                         const BIGNUM *a, const BIGNUM *n,
16                                         BN_CTX *ctx);
17
18 BIGNUM *BN_mod_inverse(BIGNUM *in,
19                        const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx)
20 {
21     BIGNUM *rv;
22     int noinv;
23     rv = int_bn_mod_inverse(in, a, n, ctx, &noinv);
24     if (noinv)
25         BNerr(BN_F_BN_MOD_INVERSE, BN_R_NO_INVERSE);
26     return rv;
27 }
28
29 BIGNUM *int_bn_mod_inverse(BIGNUM *in,
30                            const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx,
31                            int *pnoinv)
32 {
33     BIGNUM *A, *B, *X, *Y, *M, *D, *T, *R = NULL;
34     BIGNUM *ret = NULL;
35     int sign;
36
37     /* This is invalid input so we don't worry about constant time here */
38     if (BN_abs_is_word(n, 1) || BN_is_zero(n)) {
39         if (pnoinv != NULL)
40             *pnoinv = 1;
41         return NULL;
42     }
43
44     if (pnoinv != NULL)
45         *pnoinv = 0;
46
47     if ((BN_get_flags(a, BN_FLG_CONSTTIME) != 0)
48         || (BN_get_flags(n, BN_FLG_CONSTTIME) != 0)) {
49         return BN_mod_inverse_no_branch(in, a, n, ctx);
50     }
51
52     bn_check_top(a);
53     bn_check_top(n);
54
55     BN_CTX_start(ctx);
56     A = BN_CTX_get(ctx);
57     B = BN_CTX_get(ctx);
58     X = BN_CTX_get(ctx);
59     D = BN_CTX_get(ctx);
60     M = BN_CTX_get(ctx);
61     Y = BN_CTX_get(ctx);
62     T = BN_CTX_get(ctx);
63     if (T == NULL)
64         goto err;
65
66     if (in == NULL)
67         R = BN_new();
68     else
69         R = in;
70     if (R == NULL)
71         goto err;
72
73     BN_one(X);
74     BN_zero(Y);
75     if (BN_copy(B, a) == NULL)
76         goto err;
77     if (BN_copy(A, n) == NULL)
78         goto err;
79     A->neg = 0;
80     if (B->neg || (BN_ucmp(B, A) >= 0)) {
81         if (!BN_nnmod(B, B, A, ctx))
82             goto err;
83     }
84     sign = -1;
85     /*-
86      * From  B = a mod |n|,  A = |n|  it follows that
87      *
88      *      0 <= B < A,
89      *     -sign*X*a  ==  B   (mod |n|),
90      *      sign*Y*a  ==  A   (mod |n|).
91      */
92
93     if (BN_is_odd(n) && (BN_num_bits(n) <= 2048)) {
94         /*
95          * Binary inversion algorithm; requires odd modulus. This is faster
96          * than the general algorithm if the modulus is sufficiently small
97          * (about 400 .. 500 bits on 32-bit systems, but much more on 64-bit
98          * systems)
99          */
100         int shift;
101
102         while (!BN_is_zero(B)) {
103             /*-
104              *      0 < B < |n|,
105              *      0 < A <= |n|,
106              * (1) -sign*X*a  ==  B   (mod |n|),
107              * (2)  sign*Y*a  ==  A   (mod |n|)
108              */
109
110             /*
111              * Now divide B by the maximum possible power of two in the
112              * integers, and divide X by the same value mod |n|. When we're
113              * done, (1) still holds.
114              */
115             shift = 0;
116             while (!BN_is_bit_set(B, shift)) { /* note that 0 < B */
117                 shift++;
118
119                 if (BN_is_odd(X)) {
120                     if (!BN_uadd(X, X, n))
121                         goto err;
122                 }
123                 /*
124                  * now X is even, so we can easily divide it by two
125                  */
126                 if (!BN_rshift1(X, X))
127                     goto err;
128             }
129             if (shift > 0) {
130                 if (!BN_rshift(B, B, shift))
131                     goto err;
132             }
133
134             /*
135              * Same for A and Y.  Afterwards, (2) still holds.
136              */
137             shift = 0;
138             while (!BN_is_bit_set(A, shift)) { /* note that 0 < A */
139                 shift++;
140
141                 if (BN_is_odd(Y)) {
142                     if (!BN_uadd(Y, Y, n))
143                         goto err;
144                 }
145                 /* now Y is even */
146                 if (!BN_rshift1(Y, Y))
147                     goto err;
148             }
149             if (shift > 0) {
150                 if (!BN_rshift(A, A, shift))
151                     goto err;
152             }
153
154             /*-
155              * We still have (1) and (2).
156              * Both  A  and  B  are odd.
157              * The following computations ensure that
158              *
159              *     0 <= B < |n|,
160              *      0 < A < |n|,
161              * (1) -sign*X*a  ==  B   (mod |n|),
162              * (2)  sign*Y*a  ==  A   (mod |n|),
163              *
164              * and that either  A  or  B  is even in the next iteration.
165              */
166             if (BN_ucmp(B, A) >= 0) {
167                 /* -sign*(X + Y)*a == B - A  (mod |n|) */
168                 if (!BN_uadd(X, X, Y))
169                     goto err;
170                 /*
171                  * NB: we could use BN_mod_add_quick(X, X, Y, n), but that
172                  * actually makes the algorithm slower
173                  */
174                 if (!BN_usub(B, B, A))
175                     goto err;
176             } else {
177                 /*  sign*(X + Y)*a == A - B  (mod |n|) */
178                 if (!BN_uadd(Y, Y, X))
179                     goto err;
180                 /*
181                  * as above, BN_mod_add_quick(Y, Y, X, n) would slow things down
182                  */
183                 if (!BN_usub(A, A, B))
184                     goto err;
185             }
186         }
187     } else {
188         /* general inversion algorithm */
189
190         while (!BN_is_zero(B)) {
191             BIGNUM *tmp;
192
193             /*-
194              *      0 < B < A,
195              * (*) -sign*X*a  ==  B   (mod |n|),
196              *      sign*Y*a  ==  A   (mod |n|)
197              */
198
199             /* (D, M) := (A/B, A%B) ... */
200             if (BN_num_bits(A) == BN_num_bits(B)) {
201                 if (!BN_one(D))
202                     goto err;
203                 if (!BN_sub(M, A, B))
204                     goto err;
205             } else if (BN_num_bits(A) == BN_num_bits(B) + 1) {
206                 /* A/B is 1, 2, or 3 */
207                 if (!BN_lshift1(T, B))
208                     goto err;
209                 if (BN_ucmp(A, T) < 0) {
210                     /* A < 2*B, so D=1 */
211                     if (!BN_one(D))
212                         goto err;
213                     if (!BN_sub(M, A, B))
214                         goto err;
215                 } else {
216                     /* A >= 2*B, so D=2 or D=3 */
217                     if (!BN_sub(M, A, T))
218                         goto err;
219                     if (!BN_add(D, T, B))
220                         goto err; /* use D (:= 3*B) as temp */
221                     if (BN_ucmp(A, D) < 0) {
222                         /* A < 3*B, so D=2 */
223                         if (!BN_set_word(D, 2))
224                             goto err;
225                         /*
226                          * M (= A - 2*B) already has the correct value
227                          */
228                     } else {
229                         /* only D=3 remains */
230                         if (!BN_set_word(D, 3))
231                             goto err;
232                         /*
233                          * currently M = A - 2*B, but we need M = A - 3*B
234                          */
235                         if (!BN_sub(M, M, B))
236                             goto err;
237                     }
238                 }
239             } else {
240                 if (!BN_div(D, M, A, B, ctx))
241                     goto err;
242             }
243
244             /*-
245              * Now
246              *      A = D*B + M;
247              * thus we have
248              * (**)  sign*Y*a  ==  D*B + M   (mod |n|).
249              */
250
251             tmp = A;    /* keep the BIGNUM object, the value does not matter */
252
253             /* (A, B) := (B, A mod B) ... */
254             A = B;
255             B = M;
256             /* ... so we have  0 <= B < A  again */
257
258             /*-
259              * Since the former  M  is now  B  and the former  B  is now  A,
260              * (**) translates into
261              *       sign*Y*a  ==  D*A + B    (mod |n|),
262              * i.e.
263              *       sign*Y*a - D*A  ==  B    (mod |n|).
264              * Similarly, (*) translates into
265              *      -sign*X*a  ==  A          (mod |n|).
266              *
267              * Thus,
268              *   sign*Y*a + D*sign*X*a  ==  B  (mod |n|),
269              * i.e.
270              *        sign*(Y + D*X)*a  ==  B  (mod |n|).
271              *
272              * So if we set  (X, Y, sign) := (Y + D*X, X, -sign), we arrive back at
273              *      -sign*X*a  ==  B   (mod |n|),
274              *       sign*Y*a  ==  A   (mod |n|).
275              * Note that  X  and  Y  stay non-negative all the time.
276              */
277
278             /*
279              * most of the time D is very small, so we can optimize tmp := D*X+Y
280              */
281             if (BN_is_one(D)) {
282                 if (!BN_add(tmp, X, Y))
283                     goto err;
284             } else {
285                 if (BN_is_word(D, 2)) {
286                     if (!BN_lshift1(tmp, X))
287                         goto err;
288                 } else if (BN_is_word(D, 4)) {
289                     if (!BN_lshift(tmp, X, 2))
290                         goto err;
291                 } else if (D->top == 1) {
292                     if (!BN_copy(tmp, X))
293                         goto err;
294                     if (!BN_mul_word(tmp, D->d[0]))
295                         goto err;
296                 } else {
297                     if (!BN_mul(tmp, D, X, ctx))
298                         goto err;
299                 }
300                 if (!BN_add(tmp, tmp, Y))
301                     goto err;
302             }
303
304             M = Y;      /* keep the BIGNUM object, the value does not matter */
305             Y = X;
306             X = tmp;
307             sign = -sign;
308         }
309     }
310
311     /*-
312      * The while loop (Euclid's algorithm) ends when
313      *      A == gcd(a,n);
314      * we have
315      *       sign*Y*a  ==  A  (mod |n|),
316      * where  Y  is non-negative.
317      */
318
319     if (sign < 0) {
320         if (!BN_sub(Y, n, Y))
321             goto err;
322     }
323     /* Now  Y*a  ==  A  (mod |n|).  */
324
325     if (BN_is_one(A)) {
326         /* Y*a == 1  (mod |n|) */
327         if (!Y->neg && BN_ucmp(Y, n) < 0) {
328             if (!BN_copy(R, Y))
329                 goto err;
330         } else {
331             if (!BN_nnmod(R, Y, n, ctx))
332                 goto err;
333         }
334     } else {
335         if (pnoinv)
336             *pnoinv = 1;
337         goto err;
338     }
339     ret = R;
340  err:
341     if ((ret == NULL) && (in == NULL))
342         BN_free(R);
343     BN_CTX_end(ctx);
344     bn_check_top(ret);
345     return ret;
346 }
347
348 /*
349  * BN_mod_inverse_no_branch is a special version of BN_mod_inverse. It does
350  * not contain branches that may leak sensitive information.
351  */
352 static BIGNUM *BN_mod_inverse_no_branch(BIGNUM *in,
353                                         const BIGNUM *a, const BIGNUM *n,
354                                         BN_CTX *ctx)
355 {
356     BIGNUM *A, *B, *X, *Y, *M, *D, *T, *R = NULL;
357     BIGNUM *ret = NULL;
358     int sign;
359
360     bn_check_top(a);
361     bn_check_top(n);
362
363     BN_CTX_start(ctx);
364     A = BN_CTX_get(ctx);
365     B = BN_CTX_get(ctx);
366     X = BN_CTX_get(ctx);
367     D = BN_CTX_get(ctx);
368     M = BN_CTX_get(ctx);
369     Y = BN_CTX_get(ctx);
370     T = BN_CTX_get(ctx);
371     if (T == NULL)
372         goto err;
373
374     if (in == NULL)
375         R = BN_new();
376     else
377         R = in;
378     if (R == NULL)
379         goto err;
380
381     BN_one(X);
382     BN_zero(Y);
383     if (BN_copy(B, a) == NULL)
384         goto err;
385     if (BN_copy(A, n) == NULL)
386         goto err;
387     A->neg = 0;
388
389     if (B->neg || (BN_ucmp(B, A) >= 0)) {
390         /*
391          * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked,
392          * BN_div_no_branch will be called eventually.
393          */
394          {
395             BIGNUM local_B;
396             bn_init(&local_B);
397             BN_with_flags(&local_B, B, BN_FLG_CONSTTIME);
398             if (!BN_nnmod(B, &local_B, A, ctx))
399                 goto err;
400             /* Ensure local_B goes out of scope before any further use of B */
401         }
402     }
403     sign = -1;
404     /*-
405      * From  B = a mod |n|,  A = |n|  it follows that
406      *
407      *      0 <= B < A,
408      *     -sign*X*a  ==  B   (mod |n|),
409      *      sign*Y*a  ==  A   (mod |n|).
410      */
411
412     while (!BN_is_zero(B)) {
413         BIGNUM *tmp;
414
415         /*-
416          *      0 < B < A,
417          * (*) -sign*X*a  ==  B   (mod |n|),
418          *      sign*Y*a  ==  A   (mod |n|)
419          */
420
421         /*
422          * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked,
423          * BN_div_no_branch will be called eventually.
424          */
425         {
426             BIGNUM local_A;
427             bn_init(&local_A);
428             BN_with_flags(&local_A, A, BN_FLG_CONSTTIME);
429
430             /* (D, M) := (A/B, A%B) ... */
431             if (!BN_div(D, M, &local_A, B, ctx))
432                 goto err;
433             /* Ensure local_A goes out of scope before any further use of A */
434         }
435
436         /*-
437          * Now
438          *      A = D*B + M;
439          * thus we have
440          * (**)  sign*Y*a  ==  D*B + M   (mod |n|).
441          */
442
443         tmp = A;                /* keep the BIGNUM object, the value does not
444                                  * matter */
445
446         /* (A, B) := (B, A mod B) ... */
447         A = B;
448         B = M;
449         /* ... so we have  0 <= B < A  again */
450
451         /*-
452          * Since the former  M  is now  B  and the former  B  is now  A,
453          * (**) translates into
454          *       sign*Y*a  ==  D*A + B    (mod |n|),
455          * i.e.
456          *       sign*Y*a - D*A  ==  B    (mod |n|).
457          * Similarly, (*) translates into
458          *      -sign*X*a  ==  A          (mod |n|).
459          *
460          * Thus,
461          *   sign*Y*a + D*sign*X*a  ==  B  (mod |n|),
462          * i.e.
463          *        sign*(Y + D*X)*a  ==  B  (mod |n|).
464          *
465          * So if we set  (X, Y, sign) := (Y + D*X, X, -sign), we arrive back at
466          *      -sign*X*a  ==  B   (mod |n|),
467          *       sign*Y*a  ==  A   (mod |n|).
468          * Note that  X  and  Y  stay non-negative all the time.
469          */
470
471         if (!BN_mul(tmp, D, X, ctx))
472             goto err;
473         if (!BN_add(tmp, tmp, Y))
474             goto err;
475
476         M = Y;                  /* keep the BIGNUM object, the value does not
477                                  * matter */
478         Y = X;
479         X = tmp;
480         sign = -sign;
481     }
482
483     /*-
484      * The while loop (Euclid's algorithm) ends when
485      *      A == gcd(a,n);
486      * we have
487      *       sign*Y*a  ==  A  (mod |n|),
488      * where  Y  is non-negative.
489      */
490
491     if (sign < 0) {
492         if (!BN_sub(Y, n, Y))
493             goto err;
494     }
495     /* Now  Y*a  ==  A  (mod |n|).  */
496
497     if (BN_is_one(A)) {
498         /* Y*a == 1  (mod |n|) */
499         if (!Y->neg && BN_ucmp(Y, n) < 0) {
500             if (!BN_copy(R, Y))
501                 goto err;
502         } else {
503             if (!BN_nnmod(R, Y, n, ctx))
504                 goto err;
505         }
506     } else {
507         BNerr(BN_F_BN_MOD_INVERSE_NO_BRANCH, BN_R_NO_INVERSE);
508         goto err;
509     }
510     ret = R;
511  err:
512     if ((ret == NULL) && (in == NULL))
513         BN_free(R);
514     BN_CTX_end(ctx);
515     bn_check_top(ret);
516     return ret;
517 }
518
519 /*-
520  * This function is based on the constant-time GCD work by Bernstein and Yang:
521  * https://eprint.iacr.org/2019/266
522  * Generalized fast GCD function to allow even inputs.
523  * The algorithm first finds the shared powers of 2 between
524  * the inputs, and removes them, reducing at least one of the
525  * inputs to an odd value. Then it proceeds to calculate the GCD.
526  * Before returning the resulting GCD, we take care of adding
527  * back the powers of two removed at the beginning.
528  * Note 1: we assume the bit length of both inputs is public information,
529  * since access to top potentially leaks this information.
530  */
531 int BN_gcd(BIGNUM *r, const BIGNUM *in_a, const BIGNUM *in_b, BN_CTX *ctx)
532 {
533     BIGNUM *g, *temp = NULL;
534     BN_ULONG mask = 0;
535     int i, j, top, rlen, glen, m, bit = 1, delta = 1, cond = 0, shifts = 0, ret = 0;
536
537     /* Note 2: zero input corner cases are not constant-time since they are
538      * handled immediately. An attacker can run an attack under this
539      * assumption without the need of side-channel information. */
540     if (BN_is_zero(in_b)) {
541         ret = BN_copy(r, in_a) != NULL;
542         r->neg = 0;
543         return ret;
544     }
545     if (BN_is_zero(in_a)) {
546         ret = BN_copy(r, in_b) != NULL;
547         r->neg = 0;
548         return ret;
549     }
550
551     bn_check_top(in_a);
552     bn_check_top(in_b);
553
554     BN_CTX_start(ctx);
555     temp = BN_CTX_get(ctx);
556     g = BN_CTX_get(ctx);
557
558     /* make r != 0, g != 0 even, so BN_rshift is not a potential nop */
559     if (g == NULL
560         || !BN_lshift1(g, in_b)
561         || !BN_lshift1(r, in_a))
562         goto err;
563
564     /* find shared powers of two, i.e. "shifts" >= 1 */
565     for (i = 0; i < r->dmax && i < g->dmax; i++) {
566         mask = ~(r->d[i] | g->d[i]);
567         for (j = 0; j < BN_BITS2; j++) {
568             bit &= mask;
569             shifts += bit;
570             mask >>= 1;
571         }
572     }
573
574     /* subtract shared powers of two; shifts >= 1 */
575     if (!BN_rshift(r, r, shifts)
576         || !BN_rshift(g, g, shifts))
577         goto err;
578
579     /* expand to biggest nword, with room for a possible extra word */
580     top = 1 + ((r->top >= g->top) ? r->top : g->top);
581     if (bn_wexpand(r, top) == NULL
582         || bn_wexpand(g, top) == NULL
583         || bn_wexpand(temp, top) == NULL)
584         goto err;
585
586     /* re arrange inputs s.t. r is odd */
587     BN_consttime_swap((~r->d[0]) & 1, r, g, top);
588
589     /* compute the number of iterations */
590     rlen = BN_num_bits(r);
591     glen = BN_num_bits(g);
592     m = 4 + 3 * ((rlen >= glen) ? rlen : glen);
593
594     for (i = 0; i < m; i++) {
595         /* conditionally flip signs if delta is positive and g is odd */
596         cond = (-delta >> (8 * sizeof(delta) - 1)) & g->d[0] & 1
597             /* make sure g->top > 0 (i.e. if top == 0 then g == 0 always) */
598             & (~((g->top - 1) >> (sizeof(g->top) * 8 - 1)));
599         delta = (-cond & -delta) | ((cond - 1) & delta);
600         r->neg ^= cond;
601         /* swap */
602         BN_consttime_swap(cond, r, g, top);
603
604         /* elimination step */
605         delta++;
606         if (!BN_add(temp, g, r))
607             goto err;
608         BN_consttime_swap(g->d[0] & 1 /* g is odd */
609                 /* make sure g->top > 0 (i.e. if top == 0 then g == 0 always) */
610                 & (~((g->top - 1) >> (sizeof(g->top) * 8 - 1))),
611                 g, temp, top);
612         if (!BN_rshift1(g, g))
613             goto err;
614     }
615
616     /* remove possible negative sign */
617     r->neg = 0;
618     /* add powers of 2 removed, then correct the artificial shift */
619     if (!BN_lshift(r, r, shifts)
620         || !BN_rshift1(r, r))
621         goto err;
622
623     ret = 1;
624
625  err:
626     BN_CTX_end(ctx);
627     bn_check_top(r);
628     return ret;
629 }