e532b6e668c1575bfee671fb6b061177a5b14b09
[openssl.git] / crypto / bn / bn_recp.c
1 /*
2  * Copyright 1995-2016 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include "internal/cryptlib.h"
11 #include "bn_lcl.h"
12
13 void BN_RECP_CTX_init(BN_RECP_CTX *recp)
14 {
15     bn_init(&(recp->N));
16     bn_init(&(recp->Nr));
17     recp->num_bits = 0;
18     recp->flags = 0;
19 }
20
21 BN_RECP_CTX *BN_RECP_CTX_new(void)
22 {
23     BN_RECP_CTX *ret;
24
25     if ((ret = OPENSSL_zalloc(sizeof(*ret))) == NULL)
26         return (NULL);
27
28     BN_RECP_CTX_init(ret);
29     ret->flags = BN_FLG_MALLOCED;
30     return (ret);
31 }
32
33 void BN_RECP_CTX_free(BN_RECP_CTX *recp)
34 {
35     if (recp == NULL)
36         return;
37
38     BN_free(&(recp->N));
39     BN_free(&(recp->Nr));
40     if (recp->flags & BN_FLG_MALLOCED)
41         OPENSSL_free(recp);
42 }
43
44 int BN_RECP_CTX_set(BN_RECP_CTX *recp, const BIGNUM *d, BN_CTX *ctx)
45 {
46     if (!BN_copy(&(recp->N), d))
47         return 0;
48     BN_zero(&(recp->Nr));
49     recp->num_bits = BN_num_bits(d);
50     recp->shift = 0;
51     return (1);
52 }
53
54 int BN_mod_mul_reciprocal(BIGNUM *r, const BIGNUM *x, const BIGNUM *y,
55                           BN_RECP_CTX *recp, BN_CTX *ctx)
56 {
57     int ret = 0;
58     BIGNUM *a;
59     const BIGNUM *ca;
60
61     BN_CTX_start(ctx);
62     if ((a = BN_CTX_get(ctx)) == NULL)
63         goto err;
64     if (y != NULL) {
65         if (x == y) {
66             if (!BN_sqr(a, x, ctx))
67                 goto err;
68         } else {
69             if (!BN_mul(a, x, y, ctx))
70                 goto err;
71         }
72         ca = a;
73     } else
74         ca = x;                 /* Just do the mod */
75
76     ret = BN_div_recp(NULL, r, ca, recp, ctx);
77  err:
78     BN_CTX_end(ctx);
79     bn_check_top(r);
80     return (ret);
81 }
82
83 int BN_div_recp(BIGNUM *dv, BIGNUM *rem, const BIGNUM *m,
84                 BN_RECP_CTX *recp, BN_CTX *ctx)
85 {
86     int i, j, ret = 0;
87     BIGNUM *a, *b, *d, *r;
88
89     BN_CTX_start(ctx);
90     a = BN_CTX_get(ctx);
91     b = BN_CTX_get(ctx);
92     if (dv != NULL)
93         d = dv;
94     else
95         d = BN_CTX_get(ctx);
96     if (rem != NULL)
97         r = rem;
98     else
99         r = BN_CTX_get(ctx);
100     if (a == NULL || b == NULL || d == NULL || r == NULL)
101         goto err;
102
103     if (BN_ucmp(m, &(recp->N)) < 0) {
104         BN_zero(d);
105         if (!BN_copy(r, m)) {
106             BN_CTX_end(ctx);
107             return 0;
108         }
109         BN_CTX_end(ctx);
110         return (1);
111     }
112
113     /*
114      * We want the remainder Given input of ABCDEF / ab we need multiply
115      * ABCDEF by 3 digests of the reciprocal of ab
116      */
117
118     /* i := max(BN_num_bits(m), 2*BN_num_bits(N)) */
119     i = BN_num_bits(m);
120     j = recp->num_bits << 1;
121     if (j > i)
122         i = j;
123
124     /* Nr := round(2^i / N) */
125     if (i != recp->shift)
126         recp->shift = BN_reciprocal(&(recp->Nr), &(recp->N), i, ctx);
127     /* BN_reciprocal could have returned -1 for an error */
128     if (recp->shift == -1)
129         goto err;
130
131     /*-
132      * d := |round(round(m / 2^BN_num_bits(N)) * recp->Nr / 2^(i - BN_num_bits(N)))|
133      *    = |round(round(m / 2^BN_num_bits(N)) * round(2^i / N) / 2^(i - BN_num_bits(N)))|
134      *   <= |(m / 2^BN_num_bits(N)) * (2^i / N) * (2^BN_num_bits(N) / 2^i)|
135      *    = |m/N|
136      */
137     if (!BN_rshift(a, m, recp->num_bits))
138         goto err;
139     if (!BN_mul(b, a, &(recp->Nr), ctx))
140         goto err;
141     if (!BN_rshift(d, b, i - recp->num_bits))
142         goto err;
143     d->neg = 0;
144
145     if (!BN_mul(b, &(recp->N), d, ctx))
146         goto err;
147     if (!BN_usub(r, m, b))
148         goto err;
149     r->neg = 0;
150
151     j = 0;
152     while (BN_ucmp(r, &(recp->N)) >= 0) {
153         if (j++ > 2) {
154             BNerr(BN_F_BN_DIV_RECP, BN_R_BAD_RECIPROCAL);
155             goto err;
156         }
157         if (!BN_usub(r, r, &(recp->N)))
158             goto err;
159         if (!BN_add_word(d, 1))
160             goto err;
161     }
162
163     r->neg = BN_is_zero(r) ? 0 : m->neg;
164     d->neg = m->neg ^ recp->N.neg;
165     ret = 1;
166  err:
167     BN_CTX_end(ctx);
168     bn_check_top(dv);
169     bn_check_top(rem);
170     return (ret);
171 }
172
173 /*
174  * len is the expected size of the result We actually calculate with an extra
175  * word of precision, so we can do faster division if the remainder is not
176  * required.
177  */
178 /* r := 2^len / m */
179 int BN_reciprocal(BIGNUM *r, const BIGNUM *m, int len, BN_CTX *ctx)
180 {
181     int ret = -1;
182     BIGNUM *t;
183
184     BN_CTX_start(ctx);
185     if ((t = BN_CTX_get(ctx)) == NULL)
186         goto err;
187
188     if (!BN_set_bit(t, len))
189         goto err;
190
191     if (!BN_div(r, NULL, t, m, ctx))
192         goto err;
193
194     ret = len;
195  err:
196     bn_check_top(r);
197     BN_CTX_end(ctx);
198     return (ret);
199 }