067642644ec1020ec7b5c1031a922c0421ae76a2
[openssl.git] / crypto / bn / bn_gcd.c
1 /*
2  * Copyright 1995-2016 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include "internal/cryptlib.h"
11 #include "bn_lcl.h"
12
13 static BIGNUM *euclid(BIGNUM *a, BIGNUM *b);
14
15 int BN_gcd(BIGNUM *r, const BIGNUM *in_a, const BIGNUM *in_b, BN_CTX *ctx)
16 {
17     BIGNUM *a, *b, *t;
18     int ret = 0;
19
20     bn_check_top(in_a);
21     bn_check_top(in_b);
22
23     BN_CTX_start(ctx);
24     a = BN_CTX_get(ctx);
25     b = BN_CTX_get(ctx);
26     if (a == NULL || b == NULL)
27         goto err;
28
29     if (BN_copy(a, in_a) == NULL)
30         goto err;
31     if (BN_copy(b, in_b) == NULL)
32         goto err;
33     a->neg = 0;
34     b->neg = 0;
35
36     if (BN_cmp(a, b) < 0) {
37         t = a;
38         a = b;
39         b = t;
40     }
41     t = euclid(a, b);
42     if (t == NULL)
43         goto err;
44
45     if (BN_copy(r, t) == NULL)
46         goto err;
47     ret = 1;
48  err:
49     BN_CTX_end(ctx);
50     bn_check_top(r);
51     return (ret);
52 }
53
54 static BIGNUM *euclid(BIGNUM *a, BIGNUM *b)
55 {
56     BIGNUM *t;
57     int shifts = 0;
58
59     bn_check_top(a);
60     bn_check_top(b);
61
62     /* 0 <= b <= a */
63     while (!BN_is_zero(b)) {
64         /* 0 < b <= a */
65
66         if (BN_is_odd(a)) {
67             if (BN_is_odd(b)) {
68                 if (!BN_sub(a, a, b))
69                     goto err;
70                 if (!BN_rshift1(a, a))
71                     goto err;
72                 if (BN_cmp(a, b) < 0) {
73                     t = a;
74                     a = b;
75                     b = t;
76                 }
77             } else {            /* a odd - b even */
78
79                 if (!BN_rshift1(b, b))
80                     goto err;
81                 if (BN_cmp(a, b) < 0) {
82                     t = a;
83                     a = b;
84                     b = t;
85                 }
86             }
87         } else {                /* a is even */
88
89             if (BN_is_odd(b)) {
90                 if (!BN_rshift1(a, a))
91                     goto err;
92                 if (BN_cmp(a, b) < 0) {
93                     t = a;
94                     a = b;
95                     b = t;
96                 }
97             } else {            /* a even - b even */
98
99                 if (!BN_rshift1(a, a))
100                     goto err;
101                 if (!BN_rshift1(b, b))
102                     goto err;
103                 shifts++;
104             }
105         }
106         /* 0 <= b <= a */
107     }
108
109     if (shifts) {
110         if (!BN_lshift(a, a, shifts))
111             goto err;
112     }
113     bn_check_top(a);
114     return (a);
115  err:
116     return (NULL);
117 }
118
119 /* solves ax == 1 (mod n) */
120 static BIGNUM *BN_mod_inverse_no_branch(BIGNUM *in,
121                                         const BIGNUM *a, const BIGNUM *n,
122                                         BN_CTX *ctx);
123
124 BIGNUM *BN_mod_inverse(BIGNUM *in,
125                        const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx)
126 {
127     BIGNUM *rv;
128     int noinv;
129     rv = int_bn_mod_inverse(in, a, n, ctx, &noinv);
130     if (noinv)
131         BNerr(BN_F_BN_MOD_INVERSE, BN_R_NO_INVERSE);
132     return rv;
133 }
134
135 BIGNUM *int_bn_mod_inverse(BIGNUM *in,
136                            const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx,
137                            int *pnoinv)
138 {
139     BIGNUM *A, *B, *X, *Y, *M, *D, *T, *R = NULL;
140     BIGNUM *ret = NULL;
141     int sign;
142
143     if (pnoinv)
144         *pnoinv = 0;
145
146     if ((BN_get_flags(a, BN_FLG_CONSTTIME) != 0)
147         || (BN_get_flags(n, BN_FLG_CONSTTIME) != 0)) {
148         return BN_mod_inverse_no_branch(in, a, n, ctx);
149     }
150
151     bn_check_top(a);
152     bn_check_top(n);
153
154     BN_CTX_start(ctx);
155     A = BN_CTX_get(ctx);
156     B = BN_CTX_get(ctx);
157     X = BN_CTX_get(ctx);
158     D = BN_CTX_get(ctx);
159     M = BN_CTX_get(ctx);
160     Y = BN_CTX_get(ctx);
161     T = BN_CTX_get(ctx);
162     if (T == NULL)
163         goto err;
164
165     if (in == NULL)
166         R = BN_new();
167     else
168         R = in;
169     if (R == NULL)
170         goto err;
171
172     BN_one(X);
173     BN_zero(Y);
174     if (BN_copy(B, a) == NULL)
175         goto err;
176     if (BN_copy(A, n) == NULL)
177         goto err;
178     A->neg = 0;
179     if (B->neg || (BN_ucmp(B, A) >= 0)) {
180         if (!BN_nnmod(B, B, A, ctx))
181             goto err;
182     }
183     sign = -1;
184     /*-
185      * From  B = a mod |n|,  A = |n|  it follows that
186      *
187      *      0 <= B < A,
188      *     -sign*X*a  ==  B   (mod |n|),
189      *      sign*Y*a  ==  A   (mod |n|).
190      */
191
192     if (BN_is_odd(n) && (BN_num_bits(n) <= 2048)) {
193         /*
194          * Binary inversion algorithm; requires odd modulus. This is faster
195          * than the general algorithm if the modulus is sufficiently small
196          * (about 400 .. 500 bits on 32-bit systems, but much more on 64-bit
197          * systems)
198          */
199         int shift;
200
201         while (!BN_is_zero(B)) {
202             /*-
203              *      0 < B < |n|,
204              *      0 < A <= |n|,
205              * (1) -sign*X*a  ==  B   (mod |n|),
206              * (2)  sign*Y*a  ==  A   (mod |n|)
207              */
208
209             /*
210              * Now divide B by the maximum possible power of two in the
211              * integers, and divide X by the same value mod |n|. When we're
212              * done, (1) still holds.
213              */
214             shift = 0;
215             while (!BN_is_bit_set(B, shift)) { /* note that 0 < B */
216                 shift++;
217
218                 if (BN_is_odd(X)) {
219                     if (!BN_uadd(X, X, n))
220                         goto err;
221                 }
222                 /*
223                  * now X is even, so we can easily divide it by two
224                  */
225                 if (!BN_rshift1(X, X))
226                     goto err;
227             }
228             if (shift > 0) {
229                 if (!BN_rshift(B, B, shift))
230                     goto err;
231             }
232
233             /*
234              * Same for A and Y.  Afterwards, (2) still holds.
235              */
236             shift = 0;
237             while (!BN_is_bit_set(A, shift)) { /* note that 0 < A */
238                 shift++;
239
240                 if (BN_is_odd(Y)) {
241                     if (!BN_uadd(Y, Y, n))
242                         goto err;
243                 }
244                 /* now Y is even */
245                 if (!BN_rshift1(Y, Y))
246                     goto err;
247             }
248             if (shift > 0) {
249                 if (!BN_rshift(A, A, shift))
250                     goto err;
251             }
252
253             /*-
254              * We still have (1) and (2).
255              * Both  A  and  B  are odd.
256              * The following computations ensure that
257              *
258              *     0 <= B < |n|,
259              *      0 < A < |n|,
260              * (1) -sign*X*a  ==  B   (mod |n|),
261              * (2)  sign*Y*a  ==  A   (mod |n|),
262              *
263              * and that either  A  or  B  is even in the next iteration.
264              */
265             if (BN_ucmp(B, A) >= 0) {
266                 /* -sign*(X + Y)*a == B - A  (mod |n|) */
267                 if (!BN_uadd(X, X, Y))
268                     goto err;
269                 /*
270                  * NB: we could use BN_mod_add_quick(X, X, Y, n), but that
271                  * actually makes the algorithm slower
272                  */
273                 if (!BN_usub(B, B, A))
274                     goto err;
275             } else {
276                 /*  sign*(X + Y)*a == A - B  (mod |n|) */
277                 if (!BN_uadd(Y, Y, X))
278                     goto err;
279                 /*
280                  * as above, BN_mod_add_quick(Y, Y, X, n) would slow things down
281                  */
282                 if (!BN_usub(A, A, B))
283                     goto err;
284             }
285         }
286     } else {
287         /* general inversion algorithm */
288
289         while (!BN_is_zero(B)) {
290             BIGNUM *tmp;
291
292             /*-
293              *      0 < B < A,
294              * (*) -sign*X*a  ==  B   (mod |n|),
295              *      sign*Y*a  ==  A   (mod |n|)
296              */
297
298             /* (D, M) := (A/B, A%B) ... */
299             if (BN_num_bits(A) == BN_num_bits(B)) {
300                 if (!BN_one(D))
301                     goto err;
302                 if (!BN_sub(M, A, B))
303                     goto err;
304             } else if (BN_num_bits(A) == BN_num_bits(B) + 1) {
305                 /* A/B is 1, 2, or 3 */
306                 if (!BN_lshift1(T, B))
307                     goto err;
308                 if (BN_ucmp(A, T) < 0) {
309                     /* A < 2*B, so D=1 */
310                     if (!BN_one(D))
311                         goto err;
312                     if (!BN_sub(M, A, B))
313                         goto err;
314                 } else {
315                     /* A >= 2*B, so D=2 or D=3 */
316                     if (!BN_sub(M, A, T))
317                         goto err;
318                     if (!BN_add(D, T, B))
319                         goto err; /* use D (:= 3*B) as temp */
320                     if (BN_ucmp(A, D) < 0) {
321                         /* A < 3*B, so D=2 */
322                         if (!BN_set_word(D, 2))
323                             goto err;
324                         /*
325                          * M (= A - 2*B) already has the correct value
326                          */
327                     } else {
328                         /* only D=3 remains */
329                         if (!BN_set_word(D, 3))
330                             goto err;
331                         /*
332                          * currently M = A - 2*B, but we need M = A - 3*B
333                          */
334                         if (!BN_sub(M, M, B))
335                             goto err;
336                     }
337                 }
338             } else {
339                 if (!BN_div(D, M, A, B, ctx))
340                     goto err;
341             }
342
343             /*-
344              * Now
345              *      A = D*B + M;
346              * thus we have
347              * (**)  sign*Y*a  ==  D*B + M   (mod |n|).
348              */
349
350             tmp = A;    /* keep the BIGNUM object, the value does not matter */
351
352             /* (A, B) := (B, A mod B) ... */
353             A = B;
354             B = M;
355             /* ... so we have  0 <= B < A  again */
356
357             /*-
358              * Since the former  M  is now  B  and the former  B  is now  A,
359              * (**) translates into
360              *       sign*Y*a  ==  D*A + B    (mod |n|),
361              * i.e.
362              *       sign*Y*a - D*A  ==  B    (mod |n|).
363              * Similarly, (*) translates into
364              *      -sign*X*a  ==  A          (mod |n|).
365              *
366              * Thus,
367              *   sign*Y*a + D*sign*X*a  ==  B  (mod |n|),
368              * i.e.
369              *        sign*(Y + D*X)*a  ==  B  (mod |n|).
370              *
371              * So if we set  (X, Y, sign) := (Y + D*X, X, -sign), we arrive back at
372              *      -sign*X*a  ==  B   (mod |n|),
373              *       sign*Y*a  ==  A   (mod |n|).
374              * Note that  X  and  Y  stay non-negative all the time.
375              */
376
377             /*
378              * most of the time D is very small, so we can optimize tmp := D*X+Y
379              */
380             if (BN_is_one(D)) {
381                 if (!BN_add(tmp, X, Y))
382                     goto err;
383             } else {
384                 if (BN_is_word(D, 2)) {
385                     if (!BN_lshift1(tmp, X))
386                         goto err;
387                 } else if (BN_is_word(D, 4)) {
388                     if (!BN_lshift(tmp, X, 2))
389                         goto err;
390                 } else if (D->top == 1) {
391                     if (!BN_copy(tmp, X))
392                         goto err;
393                     if (!BN_mul_word(tmp, D->d[0]))
394                         goto err;
395                 } else {
396                     if (!BN_mul(tmp, D, X, ctx))
397                         goto err;
398                 }
399                 if (!BN_add(tmp, tmp, Y))
400                     goto err;
401             }
402
403             M = Y;      /* keep the BIGNUM object, the value does not matter */
404             Y = X;
405             X = tmp;
406             sign = -sign;
407         }
408     }
409
410     /*-
411      * The while loop (Euclid's algorithm) ends when
412      *      A == gcd(a,n);
413      * we have
414      *       sign*Y*a  ==  A  (mod |n|),
415      * where  Y  is non-negative.
416      */
417
418     if (sign < 0) {
419         if (!BN_sub(Y, n, Y))
420             goto err;
421     }
422     /* Now  Y*a  ==  A  (mod |n|).  */
423
424     if (BN_is_one(A)) {
425         /* Y*a == 1  (mod |n|) */
426         if (!Y->neg && BN_ucmp(Y, n) < 0) {
427             if (!BN_copy(R, Y))
428                 goto err;
429         } else {
430             if (!BN_nnmod(R, Y, n, ctx))
431                 goto err;
432         }
433     } else {
434         if (pnoinv)
435             *pnoinv = 1;
436         goto err;
437     }
438     ret = R;
439  err:
440     if ((ret == NULL) && (in == NULL))
441         BN_free(R);
442     BN_CTX_end(ctx);
443     bn_check_top(ret);
444     return (ret);
445 }
446
447 /*
448  * BN_mod_inverse_no_branch is a special version of BN_mod_inverse. It does
449  * not contain branches that may leak sensitive information.
450  */
451 static BIGNUM *BN_mod_inverse_no_branch(BIGNUM *in,
452                                         const BIGNUM *a, const BIGNUM *n,
453                                         BN_CTX *ctx)
454 {
455     BIGNUM *A, *B, *X, *Y, *M, *D, *T, *R = NULL;
456     BIGNUM *ret = NULL;
457     int sign;
458
459     bn_check_top(a);
460     bn_check_top(n);
461
462     BN_CTX_start(ctx);
463     A = BN_CTX_get(ctx);
464     B = BN_CTX_get(ctx);
465     X = BN_CTX_get(ctx);
466     D = BN_CTX_get(ctx);
467     M = BN_CTX_get(ctx);
468     Y = BN_CTX_get(ctx);
469     T = BN_CTX_get(ctx);
470     if (T == NULL)
471         goto err;
472
473     if (in == NULL)
474         R = BN_new();
475     else
476         R = in;
477     if (R == NULL)
478         goto err;
479
480     BN_one(X);
481     BN_zero(Y);
482     if (BN_copy(B, a) == NULL)
483         goto err;
484     if (BN_copy(A, n) == NULL)
485         goto err;
486     A->neg = 0;
487
488     if (B->neg || (BN_ucmp(B, A) >= 0)) {
489         /*
490          * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked,
491          * BN_div_no_branch will be called eventually.
492          */
493          {
494             BIGNUM local_B;
495             bn_init(&local_B);
496             BN_with_flags(&local_B, B, BN_FLG_CONSTTIME);
497             if (!BN_nnmod(B, &local_B, A, ctx))
498                 goto err;
499             /* Ensure local_B goes out of scope before any further use of B */
500         }
501     }
502     sign = -1;
503     /*-
504      * From  B = a mod |n|,  A = |n|  it follows that
505      *
506      *      0 <= B < A,
507      *     -sign*X*a  ==  B   (mod |n|),
508      *      sign*Y*a  ==  A   (mod |n|).
509      */
510
511     while (!BN_is_zero(B)) {
512         BIGNUM *tmp;
513
514         /*-
515          *      0 < B < A,
516          * (*) -sign*X*a  ==  B   (mod |n|),
517          *      sign*Y*a  ==  A   (mod |n|)
518          */
519
520         /*
521          * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked,
522          * BN_div_no_branch will be called eventually.
523          */
524         {
525             BIGNUM local_A;
526             bn_init(&local_A);
527             BN_with_flags(&local_A, A, BN_FLG_CONSTTIME);
528
529             /* (D, M) := (A/B, A%B) ... */
530             if (!BN_div(D, M, &local_A, B, ctx))
531                 goto err;
532             /* Ensure local_A goes out of scope before any further use of A */
533         }
534
535         /*-
536          * Now
537          *      A = D*B + M;
538          * thus we have
539          * (**)  sign*Y*a  ==  D*B + M   (mod |n|).
540          */
541
542         tmp = A;                /* keep the BIGNUM object, the value does not
543                                  * matter */
544
545         /* (A, B) := (B, A mod B) ... */
546         A = B;
547         B = M;
548         /* ... so we have  0 <= B < A  again */
549
550         /*-
551          * Since the former  M  is now  B  and the former  B  is now  A,
552          * (**) translates into
553          *       sign*Y*a  ==  D*A + B    (mod |n|),
554          * i.e.
555          *       sign*Y*a - D*A  ==  B    (mod |n|).
556          * Similarly, (*) translates into
557          *      -sign*X*a  ==  A          (mod |n|).
558          *
559          * Thus,
560          *   sign*Y*a + D*sign*X*a  ==  B  (mod |n|),
561          * i.e.
562          *        sign*(Y + D*X)*a  ==  B  (mod |n|).
563          *
564          * So if we set  (X, Y, sign) := (Y + D*X, X, -sign), we arrive back at
565          *      -sign*X*a  ==  B   (mod |n|),
566          *       sign*Y*a  ==  A   (mod |n|).
567          * Note that  X  and  Y  stay non-negative all the time.
568          */
569
570         if (!BN_mul(tmp, D, X, ctx))
571             goto err;
572         if (!BN_add(tmp, tmp, Y))
573             goto err;
574
575         M = Y;                  /* keep the BIGNUM object, the value does not
576                                  * matter */
577         Y = X;
578         X = tmp;
579         sign = -sign;
580     }
581
582     /*-
583      * The while loop (Euclid's algorithm) ends when
584      *      A == gcd(a,n);
585      * we have
586      *       sign*Y*a  ==  A  (mod |n|),
587      * where  Y  is non-negative.
588      */
589
590     if (sign < 0) {
591         if (!BN_sub(Y, n, Y))
592             goto err;
593     }
594     /* Now  Y*a  ==  A  (mod |n|).  */
595
596     if (BN_is_one(A)) {
597         /* Y*a == 1  (mod |n|) */
598         if (!Y->neg && BN_ucmp(Y, n) < 0) {
599             if (!BN_copy(R, Y))
600                 goto err;
601         } else {
602             if (!BN_nnmod(R, Y, n, ctx))
603                 goto err;
604         }
605     } else {
606         BNerr(BN_F_BN_MOD_INVERSE_NO_BRANCH, BN_R_NO_INVERSE);
607         goto err;
608     }
609     ret = R;
610  err:
611     if ((ret == NULL) && (in == NULL))
612         BN_free(R);
613     BN_CTX_end(ctx);
614     bn_check_top(ret);
615     return (ret);
616 }