More error handling to HKDF and one more case in TLS1-PRF
[openssl.git] / crypto / kdf / tls1_prf.c
1 /*
2  * Copyright 2016 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <stdio.h>
11 #include "internal/cryptlib.h"
12 #include <openssl/kdf.h>
13 #include <openssl/evp.h>
14 #include "internal/evp_int.h"
15
16 static int tls1_prf_alg(const EVP_MD *md,
17                         const unsigned char *sec, size_t slen,
18                         const unsigned char *seed, size_t seed_len,
19                         unsigned char *out, size_t olen);
20
21 #define TLS1_PRF_MAXBUF 1024
22
23 /* TLS KDF pkey context structure */
24
25 typedef struct {
26     /* Digest to use for PRF */
27     const EVP_MD *md;
28     /* Secret value to use for PRF */
29     unsigned char *sec;
30     size_t seclen;
31     /* Buffer of concatenated seed data */
32     unsigned char seed[TLS1_PRF_MAXBUF];
33     size_t seedlen;
34 } TLS1_PRF_PKEY_CTX;
35
36 static int pkey_tls1_prf_init(EVP_PKEY_CTX *ctx)
37 {
38     TLS1_PRF_PKEY_CTX *kctx;
39
40     kctx = OPENSSL_zalloc(sizeof(*kctx));
41     if (kctx == NULL)
42         return 0;
43     ctx->data = kctx;
44
45     return 1;
46 }
47
48 static void pkey_tls1_prf_cleanup(EVP_PKEY_CTX *ctx)
49 {
50     TLS1_PRF_PKEY_CTX *kctx = ctx->data;
51     OPENSSL_clear_free(kctx->sec, kctx->seclen);
52     OPENSSL_cleanse(kctx->seed, kctx->seedlen);
53     OPENSSL_free(kctx);
54 }
55
56 static int pkey_tls1_prf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
57 {
58     TLS1_PRF_PKEY_CTX *kctx = ctx->data;
59     switch (type) {
60     case EVP_PKEY_CTRL_TLS_MD:
61         kctx->md = p2;
62         return 1;
63
64     case EVP_PKEY_CTRL_TLS_SECRET:
65         if (p1 < 0)
66             return 0;
67         if (kctx->sec != NULL)
68             OPENSSL_clear_free(kctx->sec, kctx->seclen);
69         OPENSSL_cleanse(kctx->seed, kctx->seedlen);
70         kctx->seedlen = 0;
71         kctx->sec = OPENSSL_memdup(p2, p1);
72         if (kctx->sec == NULL)
73             return 0;
74         kctx->seclen  = p1;
75         return 1;
76
77     case EVP_PKEY_CTRL_TLS_SEED:
78         if (p1 == 0 || p2 == NULL)
79             return 1;
80         if (p1 < 0 || p1 > (int)(TLS1_PRF_MAXBUF - kctx->seedlen))
81             return 0;
82         memcpy(kctx->seed + kctx->seedlen, p2, p1);
83         kctx->seedlen += p1;
84         return 1;
85
86     default:
87         return -2;
88
89     }
90 }
91
92 static int pkey_tls1_prf_ctrl_str(EVP_PKEY_CTX *ctx,
93                                   const char *type, const char *value)
94 {
95     if (value == NULL) {
96         KDFerr(KDF_F_PKEY_TLS1_PRF_CTRL_STR, KDF_R_VALUE_MISSING);
97         return 0;
98     }
99     if (strcmp(type, "md") == 0) {
100         TLS1_PRF_PKEY_CTX *kctx = ctx->data;
101
102         const EVP_MD *md = EVP_get_digestbyname(value);
103         if (md == NULL) {
104             KDFerr(KDF_F_PKEY_TLS1_PRF_CTRL_STR, KDF_R_INVALID_DIGEST);
105             return 0;
106         }
107         kctx->md = md;
108         return 1;
109     }
110     if (strcmp(type, "secret") == 0)
111         return EVP_PKEY_CTX_str2ctrl(ctx, EVP_PKEY_CTRL_TLS_SECRET, value);
112     if (strcmp(type, "hexsecret") == 0)
113         return EVP_PKEY_CTX_hex2ctrl(ctx, EVP_PKEY_CTRL_TLS_SECRET, value);
114     if (strcmp(type, "seed") == 0)
115         return EVP_PKEY_CTX_str2ctrl(ctx, EVP_PKEY_CTRL_TLS_SEED, value);
116     if (strcmp(type, "hexseed") == 0)
117         return EVP_PKEY_CTX_hex2ctrl(ctx, EVP_PKEY_CTRL_TLS_SEED, value);
118
119     KDFerr(KDF_F_PKEY_TLS1_PRF_CTRL_STR, KDF_R_UNKNOWN_PARAMETER_TYPE);
120     return -2;
121 }
122
123 static int pkey_tls1_prf_derive(EVP_PKEY_CTX *ctx, unsigned char *key,
124                                 size_t *keylen)
125 {
126     TLS1_PRF_PKEY_CTX *kctx = ctx->data;
127     if (kctx->md == NULL || kctx->sec == NULL || kctx->seedlen == 0) {
128         KDFerr(KDF_F_PKEY_TLS1_PRF_DERIVE, KDF_R_MISSING_PARAMETER);
129         return 0;
130     }
131     return tls1_prf_alg(kctx->md, kctx->sec, kctx->seclen,
132                         kctx->seed, kctx->seedlen,
133                         key, *keylen);
134 }
135
136 const EVP_PKEY_METHOD tls1_prf_pkey_meth = {
137     EVP_PKEY_TLS1_PRF,
138     0,
139     pkey_tls1_prf_init,
140     0,
141     pkey_tls1_prf_cleanup,
142
143     0, 0,
144     0, 0,
145
146     0,
147     0,
148
149     0,
150     0,
151
152     0, 0,
153
154     0, 0, 0, 0,
155
156     0, 0,
157
158     0, 0,
159
160     0,
161     pkey_tls1_prf_derive,
162     pkey_tls1_prf_ctrl,
163     pkey_tls1_prf_ctrl_str
164 };
165
166 static int tls1_prf_P_hash(const EVP_MD *md,
167                            const unsigned char *sec, size_t sec_len,
168                            const unsigned char *seed, size_t seed_len,
169                            unsigned char *out, size_t olen)
170 {
171     int chunk;
172     EVP_MD_CTX *ctx = NULL, *ctx_tmp = NULL, *ctx_init = NULL;
173     EVP_PKEY *mac_key = NULL;
174     unsigned char A1[EVP_MAX_MD_SIZE];
175     size_t A1_len;
176     int ret = 0;
177
178     chunk = EVP_MD_size(md);
179     OPENSSL_assert(chunk >= 0);
180
181     ctx = EVP_MD_CTX_new();
182     ctx_tmp = EVP_MD_CTX_new();
183     ctx_init = EVP_MD_CTX_new();
184     if (ctx == NULL || ctx_tmp == NULL || ctx_init == NULL)
185         goto err;
186     EVP_MD_CTX_set_flags(ctx_init, EVP_MD_CTX_FLAG_NON_FIPS_ALLOW);
187     mac_key = EVP_PKEY_new_mac_key(EVP_PKEY_HMAC, NULL, sec, sec_len);
188     if (mac_key == NULL)
189         goto err;
190     if (!EVP_DigestSignInit(ctx_init, NULL, md, NULL, mac_key))
191         goto err;
192     if (!EVP_MD_CTX_copy_ex(ctx, ctx_init))
193         goto err;
194     if (seed != NULL && !EVP_DigestSignUpdate(ctx, seed, seed_len))
195         goto err;
196     if (!EVP_DigestSignFinal(ctx, A1, &A1_len))
197         goto err;
198
199     for (;;) {
200         /* Reinit mac contexts */
201         if (!EVP_MD_CTX_copy_ex(ctx, ctx_init))
202             goto err;
203         if (!EVP_DigestSignUpdate(ctx, A1, A1_len))
204             goto err;
205         if (olen > (size_t)chunk && !EVP_MD_CTX_copy_ex(ctx_tmp, ctx))
206             goto err;
207         if (seed && !EVP_DigestSignUpdate(ctx, seed, seed_len))
208             goto err;
209
210         if (olen > (size_t)chunk) {
211             size_t mac_len;
212             if (!EVP_DigestSignFinal(ctx, out, &mac_len))
213                 goto err;
214             out += mac_len;
215             olen -= mac_len;
216             /* calc the next A1 value */
217             if (!EVP_DigestSignFinal(ctx_tmp, A1, &A1_len))
218                 goto err;
219         } else {                /* last one */
220
221             if (!EVP_DigestSignFinal(ctx, A1, &A1_len))
222                 goto err;
223             memcpy(out, A1, olen);
224             break;
225         }
226     }
227     ret = 1;
228  err:
229     EVP_PKEY_free(mac_key);
230     EVP_MD_CTX_free(ctx);
231     EVP_MD_CTX_free(ctx_tmp);
232     EVP_MD_CTX_free(ctx_init);
233     OPENSSL_cleanse(A1, sizeof(A1));
234     return ret;
235 }
236
237 static int tls1_prf_alg(const EVP_MD *md,
238                         const unsigned char *sec, size_t slen,
239                         const unsigned char *seed, size_t seed_len,
240                         unsigned char *out, size_t olen)
241 {
242
243     if (EVP_MD_type(md) == NID_md5_sha1) {
244         size_t i;
245         unsigned char *tmp;
246         if (!tls1_prf_P_hash(EVP_md5(), sec, slen/2 + (slen & 1),
247                          seed, seed_len, out, olen))
248             return 0;
249
250         tmp = OPENSSL_malloc(olen);
251         if (tmp == NULL)
252             return 0;
253         if (!tls1_prf_P_hash(EVP_sha1(), sec + slen/2, slen/2 + (slen & 1),
254                          seed, seed_len, tmp, olen)) {
255             OPENSSL_clear_free(tmp, olen);
256             return 0;
257         }
258         for (i = 0; i < olen; i++)
259             out[i] ^= tmp[i];
260         OPENSSL_clear_free(tmp, olen);
261         return 1;
262     }
263     if (!tls1_prf_P_hash(md, sec, slen, seed, seed_len, out, olen))
264         return 0;
265
266     return 1;
267 }