[crypto/dh] side channel hardening for computing DH shared keys
[openssl.git] / crypto / dh / dh_key.c
1 /*
2  * Copyright 1995-2020 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 /*
11  * DH low level APIs are deprecated for public use, but still ok for
12  * internal use.
13  */
14 #include "internal/deprecated.h"
15
16 #include <stdio.h>
17 #include "internal/cryptlib.h"
18 #include "dh_local.h"
19 #include "crypto/bn.h"
20 #include "crypto/dh.h"
21 #include "crypto/security_bits.h"
22
23 #ifdef FIPS_MODULE
24 # define MIN_STRENGTH 112
25 #else
26 # define MIN_STRENGTH 80
27 #endif
28
29 static int generate_key(DH *dh);
30 static int dh_bn_mod_exp(const DH *dh, BIGNUM *r,
31                          const BIGNUM *a, const BIGNUM *p,
32                          const BIGNUM *m, BN_CTX *ctx, BN_MONT_CTX *m_ctx);
33 static int dh_init(DH *dh);
34 static int dh_finish(DH *dh);
35
36 static int compute_key(unsigned char *key, const BIGNUM *pub_key, DH *dh)
37 {
38     BN_CTX *ctx = NULL;
39     BN_MONT_CTX *mont = NULL;
40     BIGNUM *tmp;
41     int ret = -1;
42 #ifndef FIPS_MODULE
43     int check_result;
44 #endif
45
46     if (BN_num_bits(dh->params.p) > OPENSSL_DH_MAX_MODULUS_BITS) {
47         ERR_raise(ERR_LIB_DH, DH_R_MODULUS_TOO_LARGE);
48         goto err;
49     }
50
51     if (BN_num_bits(dh->params.p) < DH_MIN_MODULUS_BITS) {
52         ERR_raise(ERR_LIB_DH, DH_R_MODULUS_TOO_SMALL);
53         return 0;
54     }
55
56     ctx = BN_CTX_new_ex(dh->libctx);
57     if (ctx == NULL)
58         goto err;
59     BN_CTX_start(ctx);
60     tmp = BN_CTX_get(ctx);
61     if (tmp == NULL)
62         goto err;
63
64     if (dh->priv_key == NULL) {
65         ERR_raise(ERR_LIB_DH, DH_R_NO_PRIVATE_VALUE);
66         goto err;
67     }
68
69     if (dh->flags & DH_FLAG_CACHE_MONT_P) {
70         mont = BN_MONT_CTX_set_locked(&dh->method_mont_p,
71                                       dh->lock, dh->params.p, ctx);
72         BN_set_flags(dh->priv_key, BN_FLG_CONSTTIME);
73         if (!mont)
74             goto err;
75     }
76 /* TODO(3.0) : Solve in a PR related to Key validation for DH */
77 #ifndef FIPS_MODULE
78     if (!DH_check_pub_key(dh, pub_key, &check_result) || check_result) {
79         ERR_raise(ERR_LIB_DH, DH_R_INVALID_PUBKEY);
80         goto err;
81     }
82 #endif
83     if (!dh->meth->bn_mod_exp(dh, tmp, pub_key, dh->priv_key, dh->params.p, ctx,
84                               mont)) {
85         ERR_raise(ERR_LIB_DH, ERR_R_BN_LIB);
86         goto err;
87     }
88
89     /* return the padded key, i.e. same number of bytes as the modulus */
90     ret = BN_bn2binpad(tmp, key, BN_num_bytes(dh->params.p));
91  err:
92     BN_CTX_end(ctx);
93     BN_CTX_free(ctx);
94     return ret;
95 }
96
97 /*-
98  * NB: This function is inherently not constant time due to the
99  * RFC 5246 (8.1.2) padding style that strips leading zero bytes.
100  */
101 int DH_compute_key(unsigned char *key, const BIGNUM *pub_key, DH *dh)
102 {
103     int ret = 0, i;
104     volatile size_t npad = 0, mask = 1;
105
106     /* compute the key; ret is constant unless compute_key is external */
107 #ifdef FIPS_MODULE
108     ret = compute_key(key, pub_key, dh);
109 #else
110     ret = dh->meth->compute_key(key, pub_key, dh);
111 #endif
112     if (ret <= 0)
113         return ret;
114
115     /* count leading zero bytes, yet still touch all bytes */
116     for (i = 0; i < ret; i++) {
117         mask &= !key[i];
118         npad += mask;
119     }
120
121     /* unpad key */
122     ret -= npad;
123     /* key-dependent memory access, potentially leaking npad / ret */
124     memmove(key, key + npad, ret);
125     /* key-dependent memory access, potentially leaking npad / ret */
126     memset(key + ret, 0, npad);
127
128     return ret;
129 }
130
131 int DH_compute_key_padded(unsigned char *key, const BIGNUM *pub_key, DH *dh)
132 {
133     int rv, pad;
134
135     /* rv is constant unless compute_key is external */
136 #ifdef FIPS_MODULE
137     rv = compute_key(key, pub_key, dh);
138 #else
139     rv = dh->meth->compute_key(key, pub_key, dh);
140 #endif
141     if (rv <= 0)
142         return rv;
143     pad = BN_num_bytes(dh->params.p) - rv;
144     /* pad is constant (zero) unless compute_key is external */
145     if (pad > 0) {
146         memmove(key + pad, key, rv);
147         memset(key, 0, pad);
148     }
149     return rv + pad;
150 }
151
152 static DH_METHOD dh_ossl = {
153     "OpenSSL DH Method",
154     generate_key,
155     compute_key,
156     dh_bn_mod_exp,
157     dh_init,
158     dh_finish,
159     DH_FLAG_FIPS_METHOD,
160     NULL,
161     NULL
162 };
163
164 static const DH_METHOD *default_DH_method = &dh_ossl;
165
166 const DH_METHOD *DH_OpenSSL(void)
167 {
168     return &dh_ossl;
169 }
170
171 const DH_METHOD *DH_get_default_method(void)
172 {
173     return default_DH_method;
174 }
175
176 static int dh_bn_mod_exp(const DH *dh, BIGNUM *r,
177                          const BIGNUM *a, const BIGNUM *p,
178                          const BIGNUM *m, BN_CTX *ctx, BN_MONT_CTX *m_ctx)
179 {
180     return BN_mod_exp_mont(r, a, p, m, ctx, m_ctx);
181 }
182
183 static int dh_init(DH *dh)
184 {
185     dh->flags |= DH_FLAG_CACHE_MONT_P;
186     ossl_ffc_params_init(&dh->params);
187     dh->dirty_cnt++;
188     return 1;
189 }
190
191 static int dh_finish(DH *dh)
192 {
193     BN_MONT_CTX_free(dh->method_mont_p);
194     return 1;
195 }
196
197 #ifndef FIPS_MODULE
198 void DH_set_default_method(const DH_METHOD *meth)
199 {
200     default_DH_method = meth;
201 }
202 #endif /* FIPS_MODULE */
203
204 int DH_generate_key(DH *dh)
205 {
206 #ifdef FIPS_MODULE
207     return generate_key(dh);
208 #else
209     return dh->meth->generate_key(dh);
210 #endif
211 }
212
213 int dh_generate_public_key(BN_CTX *ctx, const DH *dh, const BIGNUM *priv_key,
214                            BIGNUM *pub_key)
215 {
216     int ret = 0;
217     BIGNUM *prk = BN_new();
218     BN_MONT_CTX *mont = NULL;
219
220     if (prk == NULL)
221         return 0;
222
223     if (dh->flags & DH_FLAG_CACHE_MONT_P) {
224         /*
225          * We take the input DH as const, but we lie, because in some cases we
226          * want to get a hold of its Montgomery context.
227          *
228          * We cast to remove the const qualifier in this case, it should be
229          * fine...
230          */
231         BN_MONT_CTX **pmont = (BN_MONT_CTX **)&dh->method_mont_p;
232
233         mont = BN_MONT_CTX_set_locked(pmont, dh->lock, dh->params.p, ctx);
234         if (mont == NULL)
235             goto err;
236     }
237     BN_with_flags(prk, priv_key, BN_FLG_CONSTTIME);
238
239     /* pub_key = g^priv_key mod p */
240     if (!dh->meth->bn_mod_exp(dh, pub_key, dh->params.g, prk, dh->params.p,
241                               ctx, mont))
242         goto err;
243     ret = 1;
244 err:
245     BN_clear_free(prk);
246     return ret;
247 }
248
249 static int generate_key(DH *dh)
250 {
251     int ok = 0;
252     int generate_new_key = 0;
253 #ifndef FIPS_MODULE
254     unsigned l;
255 #endif
256     BN_CTX *ctx = NULL;
257     BIGNUM *pub_key = NULL, *priv_key = NULL;
258
259     if (BN_num_bits(dh->params.p) > OPENSSL_DH_MAX_MODULUS_BITS) {
260         ERR_raise(ERR_LIB_DH, DH_R_MODULUS_TOO_LARGE);
261         return 0;
262     }
263
264     if (BN_num_bits(dh->params.p) < DH_MIN_MODULUS_BITS) {
265         ERR_raise(ERR_LIB_DH, DH_R_MODULUS_TOO_SMALL);
266         return 0;
267     }
268
269     ctx = BN_CTX_new_ex(dh->libctx);
270     if (ctx == NULL)
271         goto err;
272
273     if (dh->priv_key == NULL) {
274         priv_key = BN_secure_new();
275         if (priv_key == NULL)
276             goto err;
277         generate_new_key = 1;
278     } else {
279         priv_key = dh->priv_key;
280     }
281
282     if (dh->pub_key == NULL) {
283         pub_key = BN_new();
284         if (pub_key == NULL)
285             goto err;
286     } else {
287         pub_key = dh->pub_key;
288     }
289     if (generate_new_key) {
290         /* Is it an approved safe prime ?*/
291         if (DH_get_nid(dh) != NID_undef) {
292             int max_strength =
293                     ifc_ffc_compute_security_bits(BN_num_bits(dh->params.p));
294
295             if (dh->params.q == NULL
296                 || dh->length > BN_num_bits(dh->params.q))
297                 goto err;
298             /* dh->length = maximum bit length of generated private key */
299             if (!ossl_ffc_generate_private_key(ctx, &dh->params, dh->length,
300                                                max_strength, priv_key))
301                 goto err;
302         } else {
303 #ifdef FIPS_MODULE
304             if (dh->params.q == NULL)
305                 goto err;
306 #else
307             if (dh->params.q == NULL) {
308                 /* secret exponent length, must satisfy 2^(l-1) <= p */
309                 if (dh->length != 0
310                     && dh->length >= BN_num_bits(dh->params.p))
311                     goto err;
312                 l = dh->length ? dh->length : BN_num_bits(dh->params.p) - 1;
313                 if (!BN_priv_rand_ex(priv_key, l, BN_RAND_TOP_ONE,
314                                      BN_RAND_BOTTOM_ANY, ctx))
315                     goto err;
316                 /*
317                  * We handle just one known case where g is a quadratic non-residue:
318                  * for g = 2: p % 8 == 3
319                  */
320                 if (BN_is_word(dh->params.g, DH_GENERATOR_2)
321                     && !BN_is_bit_set(dh->params.p, 2)) {
322                     /* clear bit 0, since it won't be a secret anyway */
323                     if (!BN_clear_bit(priv_key, 0))
324                         goto err;
325                 }
326             } else
327 #endif
328             {
329                 /* Do a partial check for invalid p, q, g */
330                 if (!ossl_ffc_params_simple_validate(dh->libctx, &dh->params,
331                                                      FFC_PARAM_TYPE_DH))
332                     goto err;
333                 /*
334                  * For FFC FIPS 186-4 keygen
335                  * security strength s = 112,
336                  * Max Private key size N = len(q)
337                  */
338                 if (!ossl_ffc_generate_private_key(ctx, &dh->params,
339                                                    BN_num_bits(dh->params.q),
340                                                    MIN_STRENGTH,
341                                                    priv_key))
342                     goto err;
343             }
344         }
345     }
346
347     if (!dh_generate_public_key(ctx, dh, priv_key, pub_key))
348         goto err;
349
350     dh->pub_key = pub_key;
351     dh->priv_key = priv_key;
352     dh->dirty_cnt++;
353     ok = 1;
354  err:
355     if (ok != 1)
356         ERR_raise(ERR_LIB_DH, ERR_R_BN_LIB);
357
358     if (pub_key != dh->pub_key)
359         BN_free(pub_key);
360     if (priv_key != dh->priv_key)
361         BN_free(priv_key);
362     BN_CTX_free(ctx);
363     return ok;
364 }
365
366 int dh_buf2key(DH *dh, const unsigned char *buf, size_t len)
367 {
368     int err_reason = DH_R_BN_ERROR;
369     BIGNUM *pubkey = NULL;
370     const BIGNUM *p;
371     size_t p_size;
372
373     if ((pubkey = BN_bin2bn(buf, len, NULL)) == NULL)
374         goto err;
375     DH_get0_pqg(dh, &p, NULL, NULL);
376     if (p == NULL || (p_size = BN_num_bytes(p)) == 0) {
377         err_reason = DH_R_NO_PARAMETERS_SET;
378         goto err;
379     }
380     /*
381      * As per Section 4.2.8.1 of RFC 8446 fail if DHE's
382      * public key is of size not equal to size of p
383      */
384     if (BN_is_zero(pubkey) || p_size != len) {
385         err_reason = DH_R_INVALID_PUBKEY;
386         goto err;
387     }
388     if (DH_set0_key(dh, pubkey, NULL) != 1)
389         goto err;
390     return 1;
391 err:
392     ERR_raise(ERR_LIB_DH, err_reason);
393     BN_free(pubkey);
394     return 0;
395 }
396
397 size_t dh_key2buf(const DH *dh, unsigned char **pbuf_out, size_t size, int alloc)
398 {
399     const BIGNUM *pubkey;
400     unsigned char *pbuf = NULL;
401     const BIGNUM *p;
402     int p_size;
403
404     DH_get0_pqg(dh, &p, NULL, NULL);
405     DH_get0_key(dh, &pubkey, NULL);
406     if (p == NULL || pubkey == NULL
407             || (p_size = BN_num_bytes(p)) == 0
408             || BN_num_bytes(pubkey) == 0) {
409         ERR_raise(ERR_LIB_DH, DH_R_INVALID_PUBKEY);
410         return 0;
411     }
412     if (pbuf_out != NULL && (alloc || *pbuf_out != NULL)) {
413         if (!alloc) {
414             if (size >= (size_t)p_size)
415                 pbuf = *pbuf_out;
416         } else {
417             pbuf = OPENSSL_malloc(p_size);
418         }
419
420         if (pbuf == NULL) {
421             ERR_raise(ERR_LIB_DH, ERR_R_MALLOC_FAILURE);
422             return 0;
423         }
424         /*
425          * As per Section 4.2.8.1 of RFC 8446 left pad public
426          * key with zeros to the size of p
427          */
428         if (BN_bn2binpad(pubkey, pbuf, p_size) < 0) {
429             if (alloc)
430                 OPENSSL_free(pbuf);
431             ERR_raise(ERR_LIB_DH, DH_R_BN_ERROR);
432             return 0;
433         }
434         *pbuf_out = pbuf;
435     }
436     return p_size;
437 }