Copyright consolidation 07/10
[openssl.git] / crypto / kdf / hkdf.c
1 /*
2  * Copyright 2016 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <stdlib.h>
11 #include <string.h>
12 #include <openssl/hmac.h>
13 #include <openssl/kdf.h>
14 #include <openssl/evp.h>
15 #include "internal/cryptlib.h"
16 #include "internal/evp_int.h"
17
18 #define HKDF_MAXBUF 1024
19
20 static unsigned char *HKDF(const EVP_MD *evp_md,
21                            const unsigned char *salt, size_t salt_len,
22                            const unsigned char *key, size_t key_len,
23                            const unsigned char *info, size_t info_len,
24                            unsigned char *okm, size_t okm_len);
25
26 static unsigned char *HKDF_Extract(const EVP_MD *evp_md,
27                                    const unsigned char *salt, size_t salt_len,
28                                    const unsigned char *key, size_t key_len,
29                                    unsigned char *prk, size_t *prk_len);
30
31 static unsigned char *HKDF_Expand(const EVP_MD *evp_md,
32                                   const unsigned char *prk, size_t prk_len,
33                                   const unsigned char *info, size_t info_len,
34                                   unsigned char *okm, size_t okm_len);
35
36 typedef struct {
37     const EVP_MD *md;
38     unsigned char *salt;
39     size_t salt_len;
40     unsigned char *key;
41     size_t key_len;
42     unsigned char info[HKDF_MAXBUF];
43     size_t info_len;
44 } HKDF_PKEY_CTX;
45
46 static int pkey_hkdf_init(EVP_PKEY_CTX *ctx)
47 {
48     HKDF_PKEY_CTX *kctx;
49
50     kctx = OPENSSL_zalloc(sizeof(*kctx));
51     if (kctx == NULL)
52         return 0;
53
54     ctx->data = kctx;
55
56     return 1;
57 }
58
59 static void pkey_hkdf_cleanup(EVP_PKEY_CTX *ctx)
60 {
61     HKDF_PKEY_CTX *kctx = ctx->data;
62     OPENSSL_clear_free(kctx->salt, kctx->salt_len);
63     OPENSSL_clear_free(kctx->key, kctx->key_len);
64     OPENSSL_cleanse(kctx->info, kctx->info_len);
65     OPENSSL_free(kctx);
66 }
67
68 static int pkey_hkdf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
69 {
70     HKDF_PKEY_CTX *kctx = ctx->data;
71
72     switch (type) {
73     case EVP_PKEY_CTRL_HKDF_MD:
74         if (p2 == NULL)
75             return 0;
76
77         kctx->md = p2;
78         return 1;
79
80     case EVP_PKEY_CTRL_HKDF_SALT:
81         if (p1 == 0 || p2 == NULL)
82             return 1;
83
84         if (p1 < 0)
85             return 0;
86
87         if (kctx->salt != NULL)
88             OPENSSL_clear_free(kctx->salt, kctx->salt_len);
89
90         kctx->salt = OPENSSL_memdup(p2, p1);
91         if (kctx->salt == NULL)
92             return 0;
93
94         kctx->salt_len = p1;
95         return 1;
96
97     case EVP_PKEY_CTRL_HKDF_KEY:
98         if (p1 < 0)
99             return 0;
100
101         if (kctx->key != NULL)
102             OPENSSL_clear_free(kctx->key, kctx->key_len);
103
104         kctx->key = OPENSSL_memdup(p2, p1);
105         if (kctx->key == NULL)
106             return 0;
107
108         kctx->key_len  = p1;
109         return 1;
110
111     case EVP_PKEY_CTRL_HKDF_INFO:
112         if (p1 == 0 || p2 == NULL)
113             return 1;
114
115         if (p1 < 0 || p1 > (int)(HKDF_MAXBUF - kctx->info_len))
116             return 0;
117
118         memcpy(kctx->info + kctx->info_len, p2, p1);
119         kctx->info_len += p1;
120         return 1;
121
122     default:
123         return -2;
124
125     }
126 }
127
128 static int pkey_hkdf_ctrl_str(EVP_PKEY_CTX *ctx, const char *type,
129                               const char *value)
130 {
131     if (strcmp(type, "md") == 0)
132         return EVP_PKEY_CTX_set_hkdf_md(ctx, EVP_get_digestbyname(value));
133
134     if (strcmp(type, "salt") == 0)
135         return EVP_PKEY_CTX_str2ctrl(ctx, EVP_PKEY_CTRL_HKDF_SALT, value);
136
137     if (strcmp(type, "hexsalt") == 0)
138         return EVP_PKEY_CTX_hex2ctrl(ctx, EVP_PKEY_CTRL_HKDF_SALT, value);
139
140     if (strcmp(type, "key") == 0)
141         return EVP_PKEY_CTX_str2ctrl(ctx, EVP_PKEY_CTRL_HKDF_KEY, value);
142
143     if (strcmp(type, "hexkey") == 0)
144         return EVP_PKEY_CTX_hex2ctrl(ctx, EVP_PKEY_CTRL_HKDF_KEY, value);
145
146     if (strcmp(type, "info") == 0)
147         return EVP_PKEY_CTX_str2ctrl(ctx, EVP_PKEY_CTRL_HKDF_INFO, value);
148
149     if (strcmp(type, "hexinfo") == 0)
150         return EVP_PKEY_CTX_hex2ctrl(ctx, EVP_PKEY_CTRL_HKDF_INFO, value);
151
152     return -2;
153 }
154
155 static int pkey_hkdf_derive(EVP_PKEY_CTX *ctx, unsigned char *key,
156                             size_t *keylen)
157 {
158     HKDF_PKEY_CTX *kctx = ctx->data;
159
160     if (kctx->md == NULL || kctx->key == NULL)
161         return 0;
162
163     if (HKDF(kctx->md, kctx->salt, kctx->salt_len, kctx->key, kctx->key_len,
164              kctx->info, kctx->info_len, key, *keylen) == NULL)
165     {
166         return 0;
167     }
168
169     return 1;
170 }
171
172 const EVP_PKEY_METHOD hkdf_pkey_meth = {
173     EVP_PKEY_HKDF,
174     0,
175     pkey_hkdf_init,
176     0,
177     pkey_hkdf_cleanup,
178
179     0, 0,
180     0, 0,
181
182     0,
183     0,
184
185     0,
186     0,
187
188     0, 0,
189
190     0, 0, 0, 0,
191
192     0, 0,
193
194     0, 0,
195
196     0,
197     pkey_hkdf_derive,
198     pkey_hkdf_ctrl,
199     pkey_hkdf_ctrl_str
200 };
201
202 static unsigned char *HKDF(const EVP_MD *evp_md,
203                            const unsigned char *salt, size_t salt_len,
204                            const unsigned char *key, size_t key_len,
205                            const unsigned char *info, size_t info_len,
206                            unsigned char *okm, size_t okm_len)
207 {
208     unsigned char prk[EVP_MAX_MD_SIZE];
209     size_t prk_len;
210
211     if (!HKDF_Extract(evp_md, salt, salt_len, key, key_len, prk, &prk_len))
212         return NULL;
213
214     return HKDF_Expand(evp_md, prk, prk_len, info, info_len, okm, okm_len);
215 }
216
217 static unsigned char *HKDF_Extract(const EVP_MD *evp_md,
218                                    const unsigned char *salt, size_t salt_len,
219                                    const unsigned char *key, size_t key_len,
220                                    unsigned char *prk, size_t *prk_len)
221 {
222     unsigned int tmp_len;
223
224     if (!HMAC(evp_md, salt, salt_len, key, key_len, prk, &tmp_len))
225         return NULL;
226
227     *prk_len = tmp_len;
228     return prk;
229 }
230
231 static unsigned char *HKDF_Expand(const EVP_MD *evp_md,
232                                   const unsigned char *prk, size_t prk_len,
233                                   const unsigned char *info, size_t info_len,
234                                   unsigned char *okm, size_t okm_len)
235 {
236     HMAC_CTX *hmac;
237
238     unsigned int i;
239
240     unsigned char prev[EVP_MAX_MD_SIZE];
241
242     size_t done_len = 0, dig_len = EVP_MD_size(evp_md);
243
244     size_t n = okm_len / dig_len;
245     if (okm_len % dig_len)
246         n++;
247
248     if (n > 255)
249         return NULL;
250
251     if ((hmac = HMAC_CTX_new()) == NULL)
252         return NULL;
253
254     if (!HMAC_Init_ex(hmac, prk, prk_len, evp_md, NULL))
255         goto err;
256
257     for (i = 1; i <= n; i++) {
258         size_t copy_len;
259         const unsigned char ctr = i;
260
261         if (i > 1) {
262             if (!HMAC_Init_ex(hmac, NULL, 0, NULL, NULL))
263                 goto err;
264
265             if (!HMAC_Update(hmac, prev, dig_len))
266                 goto err;
267         }
268
269         if (!HMAC_Update(hmac, info, info_len))
270             goto err;
271
272         if (!HMAC_Update(hmac, &ctr, 1))
273             goto err;
274
275         if (!HMAC_Final(hmac, prev, NULL))
276             goto err;
277
278         copy_len = (done_len + dig_len > okm_len) ?
279                        okm_len - done_len :
280                        dig_len;
281
282         memcpy(okm + done_len, prev, copy_len);
283
284         done_len += copy_len;
285     }
286
287     HMAC_CTX_free(hmac);
288     return okm;
289
290  err:
291     HMAC_CTX_free(hmac);
292     return NULL;
293 }