Following the changes to HKDF to accept a mode, add some tests for this
[openssl.git] / crypto / kdf / hkdf.c
1 /*
2  * Copyright 2016 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <stdlib.h>
11 #include <string.h>
12 #include <openssl/hmac.h>
13 #include <openssl/kdf.h>
14 #include <openssl/evp.h>
15 #include "internal/cryptlib.h"
16 #include "internal/evp_int.h"
17
18 #define HKDF_MAXBUF 1024
19
20 static unsigned char *HKDF(const EVP_MD *evp_md,
21                            const unsigned char *salt, size_t salt_len,
22                            const unsigned char *key, size_t key_len,
23                            const unsigned char *info, size_t info_len,
24                            unsigned char *okm, size_t okm_len);
25
26 static unsigned char *HKDF_Extract(const EVP_MD *evp_md,
27                                    const unsigned char *salt, size_t salt_len,
28                                    const unsigned char *key, size_t key_len,
29                                    unsigned char *prk, size_t *prk_len);
30
31 static unsigned char *HKDF_Expand(const EVP_MD *evp_md,
32                                   const unsigned char *prk, size_t prk_len,
33                                   const unsigned char *info, size_t info_len,
34                                   unsigned char *okm, size_t okm_len);
35
36 typedef struct {
37     int mode;
38     const EVP_MD *md;
39     unsigned char *salt;
40     size_t salt_len;
41     unsigned char *key;
42     size_t key_len;
43     unsigned char info[HKDF_MAXBUF];
44     size_t info_len;
45 } HKDF_PKEY_CTX;
46
47 static int pkey_hkdf_init(EVP_PKEY_CTX *ctx)
48 {
49     HKDF_PKEY_CTX *kctx;
50
51     kctx = OPENSSL_zalloc(sizeof(*kctx));
52     if (kctx == NULL)
53         return 0;
54
55     ctx->data = kctx;
56
57     return 1;
58 }
59
60 static void pkey_hkdf_cleanup(EVP_PKEY_CTX *ctx)
61 {
62     HKDF_PKEY_CTX *kctx = ctx->data;
63     OPENSSL_clear_free(kctx->salt, kctx->salt_len);
64     OPENSSL_clear_free(kctx->key, kctx->key_len);
65     OPENSSL_cleanse(kctx->info, kctx->info_len);
66     OPENSSL_free(kctx);
67 }
68
69 static int pkey_hkdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
70 {
71     HKDF_PKEY_CTX *kctx = ctx->data;
72
73     switch (type) {
74     case EVP_PKEY_CTRL_HKDF_MD:
75         if (p2 == NULL)
76             return 0;
77
78         kctx->md = p2;
79         return 1;
80
81     case EVP_PKEY_CTRL_HKDF_MODE:
82         kctx->mode = p1;
83         return 1;
84
85     case EVP_PKEY_CTRL_HKDF_SALT:
86         if (p1 == 0 || p2 == NULL)
87             return 1;
88
89         if (p1 < 0)
90             return 0;
91
92         if (kctx->salt != NULL)
93             OPENSSL_clear_free(kctx->salt, kctx->salt_len);
94
95         kctx->salt = OPENSSL_memdup(p2, p1);
96         if (kctx->salt == NULL)
97             return 0;
98
99         kctx->salt_len = p1;
100         return 1;
101
102     case EVP_PKEY_CTRL_HKDF_KEY:
103         if (p1 < 0)
104             return 0;
105
106         if (kctx->key != NULL)
107             OPENSSL_clear_free(kctx->key, kctx->key_len);
108
109         kctx->key = OPENSSL_memdup(p2, p1);
110         if (kctx->key == NULL)
111             return 0;
112
113         kctx->key_len  = p1;
114         return 1;
115
116     case EVP_PKEY_CTRL_HKDF_INFO:
117         if (p1 == 0 || p2 == NULL)
118             return 1;
119
120         if (p1 < 0 || p1 > (int)(HKDF_MAXBUF - kctx->info_len))
121             return 0;
122
123         memcpy(kctx->info + kctx->info_len, p2, p1);
124         kctx->info_len += p1;
125         return 1;
126
127     default:
128         return -2;
129
130     }
131 }
132
133 static int pkey_hkdf_ctrl_str(EVP_PKEY_CTX *ctx, const char *type,
134                               const char *value)
135 {
136     if (strcmp(type, "mode") == 0) {
137         int mode;
138
139         if (strcmp(value, "EXTRACT_AND_EXPAND") == 0)
140             mode = EVP_PKEY_HKDEF_MODE_EXTRACT_AND_EXPAND;
141         else if (strcmp(value, "EXTRACT_ONLY") == 0)
142             mode = EVP_PKEY_HKDEF_MODE_EXTRACT_ONLY;
143         else if (strcmp(value, "EXPAND_ONLY") == 0)
144             mode = EVP_PKEY_HKDEF_MODE_EXPAND_ONLY;
145         else
146             return 0;
147
148         return EVP_PKEY_CTX_hkdf_mode(ctx, mode);
149     }
150
151     if (strcmp(type, "md") == 0)
152         return EVP_PKEY_CTX_set_hkdf_md(ctx, EVP_get_digestbyname(value));
153
154     if (strcmp(type, "salt") == 0)
155         return EVP_PKEY_CTX_str2ctrl(ctx, EVP_PKEY_CTRL_HKDF_SALT, value);
156
157     if (strcmp(type, "hexsalt") == 0)
158         return EVP_PKEY_CTX_hex2ctrl(ctx, EVP_PKEY_CTRL_HKDF_SALT, value);
159
160     if (strcmp(type, "key") == 0)
161         return EVP_PKEY_CTX_str2ctrl(ctx, EVP_PKEY_CTRL_HKDF_KEY, value);
162
163     if (strcmp(type, "hexkey") == 0)
164         return EVP_PKEY_CTX_hex2ctrl(ctx, EVP_PKEY_CTRL_HKDF_KEY, value);
165
166     if (strcmp(type, "info") == 0)
167         return EVP_PKEY_CTX_str2ctrl(ctx, EVP_PKEY_CTRL_HKDF_INFO, value);
168
169     if (strcmp(type, "hexinfo") == 0)
170         return EVP_PKEY_CTX_hex2ctrl(ctx, EVP_PKEY_CTRL_HKDF_INFO, value);
171
172     return -2;
173 }
174
175 static int pkey_hkdf_derive(EVP_PKEY_CTX *ctx, unsigned char *key,
176                             size_t *keylen)
177 {
178     HKDF_PKEY_CTX *kctx = ctx->data;
179
180     if (kctx->md == NULL || kctx->key == NULL)
181         return 0;
182
183     switch (kctx->mode) {
184     case EVP_PKEY_HKDEF_MODE_EXTRACT_AND_EXPAND:
185         return HKDF(kctx->md, kctx->salt, kctx->salt_len, kctx->key,
186                     kctx->key_len, kctx->info, kctx->info_len, key,
187                     *keylen) != NULL;
188
189     case EVP_PKEY_HKDEF_MODE_EXTRACT_ONLY:
190         if (key == NULL) {
191             *keylen = EVP_MD_size(kctx->md);
192             return 1;
193         }
194         return HKDF_Extract(kctx->md, kctx->salt, kctx->salt_len, kctx->key,
195                             kctx->key_len, key, keylen) != NULL;
196
197     case EVP_PKEY_HKDEF_MODE_EXPAND_ONLY:
198         return HKDF_Expand(kctx->md, kctx->key, kctx->key_len, kctx->info,
199                            kctx->info_len, key, *keylen) != NULL;
200
201     default:
202         return 0;
203     }
204 }
205
206 const EVP_PKEY_METHOD hkdf_pkey_meth = {
207     EVP_PKEY_HKDF,
208     0,
209     pkey_hkdf_init,
210     0,
211     pkey_hkdf_cleanup,
212
213     0, 0,
214     0, 0,
215
216     0,
217     0,
218
219     0,
220     0,
221
222     0, 0,
223
224     0, 0, 0, 0,
225
226     0, 0,
227
228     0, 0,
229
230     0,
231     pkey_hkdf_derive,
232     pkey_hkdf_ctrl,
233     pkey_hkdf_ctrl_str
234 };
235
236 static unsigned char *HKDF(const EVP_MD *evp_md,
237                            const unsigned char *salt, size_t salt_len,
238                            const unsigned char *key, size_t key_len,
239                            const unsigned char *info, size_t info_len,
240                            unsigned char *okm, size_t okm_len)
241 {
242     unsigned char prk[EVP_MAX_MD_SIZE];
243     unsigned char *ret;
244     size_t prk_len;
245
246     if (!HKDF_Extract(evp_md, salt, salt_len, key, key_len, prk, &prk_len))
247         return NULL;
248
249     ret = HKDF_Expand(evp_md, prk, prk_len, info, info_len, okm, okm_len);
250     OPENSSL_cleanse(prk, sizeof(prk));
251
252     return ret;
253 }
254
255 static unsigned char *HKDF_Extract(const EVP_MD *evp_md,
256                                    const unsigned char *salt, size_t salt_len,
257                                    const unsigned char *key, size_t key_len,
258                                    unsigned char *prk, size_t *prk_len)
259 {
260     unsigned int tmp_len;
261
262     if (!HMAC(evp_md, salt, salt_len, key, key_len, prk, &tmp_len))
263         return NULL;
264
265     *prk_len = tmp_len;
266     return prk;
267 }
268
269 static unsigned char *HKDF_Expand(const EVP_MD *evp_md,
270                                   const unsigned char *prk, size_t prk_len,
271                                   const unsigned char *info, size_t info_len,
272                                   unsigned char *okm, size_t okm_len)
273 {
274     HMAC_CTX *hmac;
275
276     unsigned int i;
277
278     unsigned char prev[EVP_MAX_MD_SIZE];
279
280     size_t done_len = 0, dig_len = EVP_MD_size(evp_md);
281
282     size_t n = okm_len / dig_len;
283     if (okm_len % dig_len)
284         n++;
285
286     if (n > 255 || okm == NULL)
287         return NULL;
288
289     if ((hmac = HMAC_CTX_new()) == NULL)
290         return NULL;
291
292     if (!HMAC_Init_ex(hmac, prk, prk_len, evp_md, NULL))
293         goto err;
294
295     for (i = 1; i <= n; i++) {
296         size_t copy_len;
297         const unsigned char ctr = i;
298
299         if (i > 1) {
300             if (!HMAC_Init_ex(hmac, NULL, 0, NULL, NULL))
301                 goto err;
302
303             if (!HMAC_Update(hmac, prev, dig_len))
304                 goto err;
305         }
306
307         if (!HMAC_Update(hmac, info, info_len))
308             goto err;
309
310         if (!HMAC_Update(hmac, &ctr, 1))
311             goto err;
312
313         if (!HMAC_Final(hmac, prev, NULL))
314             goto err;
315
316         copy_len = (done_len + dig_len > okm_len) ?
317                        okm_len - done_len :
318                        dig_len;
319
320         memcpy(okm + done_len, prev, copy_len);
321
322         done_len += copy_len;
323     }
324
325     HMAC_CTX_free(hmac);
326     return okm;
327
328  err:
329     HMAC_CTX_free(hmac);
330     return NULL;
331 }