f5e106346131d45b1ee1adc2cd0d62548dd101ad
[openssl.git] / crypto / kdf / tls1_prf.c
1 /*
2  * Copyright 2016 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <stdio.h>
11 #include "internal/cryptlib.h"
12 #include <openssl/kdf.h>
13 #include <openssl/evp.h>
14 #include "internal/evp_int.h"
15
16 static int tls1_prf_alg(const EVP_MD *md,
17                         const unsigned char *sec, size_t slen,
18                         const unsigned char *seed, size_t seed_len,
19                         unsigned char *out, size_t olen);
20
21 #define TLS1_PRF_MAXBUF 1024
22
23 /* TLS KDF pkey context structure */
24
25 typedef struct {
26     /* Digest to use for PRF */
27     const EVP_MD *md;
28     /* Secret value to use for PRF */
29     unsigned char *sec;
30     size_t seclen;
31     /* Buffer of concatenated seed data */
32     unsigned char seed[TLS1_PRF_MAXBUF];
33     size_t seedlen;
34 } TLS1_PRF_PKEY_CTX;
35
36 static int pkey_tls1_prf_init(EVP_PKEY_CTX *ctx)
37 {
38     TLS1_PRF_PKEY_CTX *kctx;
39
40     kctx = OPENSSL_zalloc(sizeof(*kctx));
41     if (kctx == NULL)
42         return 0;
43     ctx->data = kctx;
44
45     return 1;
46 }
47
48 static void pkey_tls1_prf_cleanup(EVP_PKEY_CTX *ctx)
49 {
50     TLS1_PRF_PKEY_CTX *kctx = ctx->data;
51     OPENSSL_clear_free(kctx->sec, kctx->seclen);
52     OPENSSL_cleanse(kctx->seed, kctx->seedlen);
53     OPENSSL_free(kctx);
54 }
55
56 static int pkey_tls1_prf_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
57 {
58     TLS1_PRF_PKEY_CTX *kctx = ctx->data;
59     switch (type) {
60     case EVP_PKEY_CTRL_TLS_MD:
61         kctx->md = p2;
62         return 1;
63
64     case EVP_PKEY_CTRL_TLS_SECRET:
65         if (p1 < 0)
66             return 0;
67         if (kctx->sec != NULL)
68             OPENSSL_clear_free(kctx->sec, kctx->seclen);
69         OPENSSL_cleanse(kctx->seed, kctx->seedlen);
70         kctx->seedlen = 0;
71         kctx->sec = OPENSSL_memdup(p2, p1);
72         if (kctx->sec == NULL)
73             return 0;
74         kctx->seclen  = p1;
75         return 1;
76
77     case EVP_PKEY_CTRL_TLS_SEED:
78         if (p1 == 0 || p2 == NULL)
79             return 1;
80         if (p1 < 0 || p1 > (int)(TLS1_PRF_MAXBUF - kctx->seedlen))
81             return 0;
82         memcpy(kctx->seed + kctx->seedlen, p2, p1);
83         kctx->seedlen += p1;
84         return 1;
85
86     default:
87         return -2;
88
89     }
90 }
91
92 static int pkey_tls1_prf_ctrl_str(EVP_PKEY_CTX *ctx,
93                                   const char *type, const char *value)
94 {
95     if (value == NULL) {
96         KDFerr(KDF_F_PKEY_TLS1_PRF_CTRL_STR, KDF_R_VALUE_MISSING);
97         return 0;
98     }
99     if (strcmp(type, "md") == 0) {
100         TLS1_PRF_PKEY_CTX *kctx = ctx->data;
101
102         const EVP_MD *md = EVP_get_digestbyname(value);
103         if (md == NULL) {
104             KDFerr(KDF_F_PKEY_TLS1_PRF_CTRL_STR, KDF_R_INVALID_DIGEST);
105             return 0;
106         }
107         kctx->md = md;
108         return 1;
109     }
110     if (strcmp(type, "secret") == 0)
111         return EVP_PKEY_CTX_str2ctrl(ctx, EVP_PKEY_CTRL_TLS_SECRET, value);
112     if (strcmp(type, "hexsecret") == 0)
113         return EVP_PKEY_CTX_hex2ctrl(ctx, EVP_PKEY_CTRL_TLS_SECRET, value);
114     if (strcmp(type, "seed") == 0)
115         return EVP_PKEY_CTX_str2ctrl(ctx, EVP_PKEY_CTRL_TLS_SEED, value);
116     if (strcmp(type, "hexseed") == 0)
117         return EVP_PKEY_CTX_hex2ctrl(ctx, EVP_PKEY_CTRL_TLS_SEED, value);
118
119     KDFerr(KDF_F_PKEY_TLS1_PRF_CTRL_STR, KDF_R_UNKNOWN_PARAMETER_TYPE);
120     return -2;
121 }
122
123 static int pkey_tls1_prf_derive(EVP_PKEY_CTX *ctx, unsigned char *key,
124                                 size_t *keylen)
125 {
126     TLS1_PRF_PKEY_CTX *kctx = ctx->data;
127     if (kctx->md == NULL) {
128         KDFerr(KDF_F_PKEY_TLS1_PRF_DERIVE, KDF_R_MISSING_MESSAGE_DIGEST);
129         return 0;
130     }
131     if (kctx->sec == NULL || kctx->seedlen == 0) {
132         KDFerr(KDF_F_PKEY_TLS1_PRF_DERIVE, KDF_R_MISSING_SEED);
133         return 0;
134     }
135     return tls1_prf_alg(kctx->md, kctx->sec, kctx->seclen,
136                         kctx->seed, kctx->seedlen,
137                         key, *keylen);
138 }
139
140 const EVP_PKEY_METHOD tls1_prf_pkey_meth = {
141     EVP_PKEY_TLS1_PRF,
142     0,
143     pkey_tls1_prf_init,
144     0,
145     pkey_tls1_prf_cleanup,
146
147     0, 0,
148     0, 0,
149
150     0,
151     0,
152
153     0,
154     0,
155
156     0, 0,
157
158     0, 0, 0, 0,
159
160     0, 0,
161
162     0, 0,
163
164     0,
165     pkey_tls1_prf_derive,
166     pkey_tls1_prf_ctrl,
167     pkey_tls1_prf_ctrl_str
168 };
169
170 static int tls1_prf_P_hash(const EVP_MD *md,
171                            const unsigned char *sec, size_t sec_len,
172                            const unsigned char *seed, size_t seed_len,
173                            unsigned char *out, size_t olen)
174 {
175     int chunk;
176     EVP_MD_CTX *ctx = NULL, *ctx_tmp = NULL, *ctx_init = NULL;
177     EVP_PKEY *mac_key = NULL;
178     unsigned char A1[EVP_MAX_MD_SIZE];
179     size_t A1_len;
180     int ret = 0;
181
182     chunk = EVP_MD_size(md);
183     OPENSSL_assert(chunk >= 0);
184
185     ctx = EVP_MD_CTX_new();
186     ctx_tmp = EVP_MD_CTX_new();
187     ctx_init = EVP_MD_CTX_new();
188     if (ctx == NULL || ctx_tmp == NULL || ctx_init == NULL)
189         goto err;
190     EVP_MD_CTX_set_flags(ctx_init, EVP_MD_CTX_FLAG_NON_FIPS_ALLOW);
191     mac_key = EVP_PKEY_new_mac_key(EVP_PKEY_HMAC, NULL, sec, sec_len);
192     if (mac_key == NULL)
193         goto err;
194     if (!EVP_DigestSignInit(ctx_init, NULL, md, NULL, mac_key))
195         goto err;
196     if (!EVP_MD_CTX_copy_ex(ctx, ctx_init))
197         goto err;
198     if (seed != NULL && !EVP_DigestSignUpdate(ctx, seed, seed_len))
199         goto err;
200     if (!EVP_DigestSignFinal(ctx, A1, &A1_len))
201         goto err;
202
203     for (;;) {
204         /* Reinit mac contexts */
205         if (!EVP_MD_CTX_copy_ex(ctx, ctx_init))
206             goto err;
207         if (!EVP_DigestSignUpdate(ctx, A1, A1_len))
208             goto err;
209         if (olen > (size_t)chunk && !EVP_MD_CTX_copy_ex(ctx_tmp, ctx))
210             goto err;
211         if (seed && !EVP_DigestSignUpdate(ctx, seed, seed_len))
212             goto err;
213
214         if (olen > (size_t)chunk) {
215             size_t mac_len;
216             if (!EVP_DigestSignFinal(ctx, out, &mac_len))
217                 goto err;
218             out += mac_len;
219             olen -= mac_len;
220             /* calc the next A1 value */
221             if (!EVP_DigestSignFinal(ctx_tmp, A1, &A1_len))
222                 goto err;
223         } else {                /* last one */
224
225             if (!EVP_DigestSignFinal(ctx, A1, &A1_len))
226                 goto err;
227             memcpy(out, A1, olen);
228             break;
229         }
230     }
231     ret = 1;
232  err:
233     EVP_PKEY_free(mac_key);
234     EVP_MD_CTX_free(ctx);
235     EVP_MD_CTX_free(ctx_tmp);
236     EVP_MD_CTX_free(ctx_init);
237     OPENSSL_cleanse(A1, sizeof(A1));
238     return ret;
239 }
240
241 static int tls1_prf_alg(const EVP_MD *md,
242                         const unsigned char *sec, size_t slen,
243                         const unsigned char *seed, size_t seed_len,
244                         unsigned char *out, size_t olen)
245 {
246
247     if (EVP_MD_type(md) == NID_md5_sha1) {
248         size_t i;
249         unsigned char *tmp;
250         if (!tls1_prf_P_hash(EVP_md5(), sec, slen/2 + (slen & 1),
251                          seed, seed_len, out, olen))
252             return 0;
253
254         tmp = OPENSSL_malloc(olen);
255         if (tmp == NULL)
256             return 0;
257         if (!tls1_prf_P_hash(EVP_sha1(), sec + slen/2, slen/2 + (slen & 1),
258                          seed, seed_len, tmp, olen)) {
259             OPENSSL_clear_free(tmp, olen);
260             return 0;
261         }
262         for (i = 0; i < olen; i++)
263             out[i] ^= tmp[i];
264         OPENSSL_clear_free(tmp, olen);
265         return 1;
266     }
267     if (!tls1_prf_P_hash(md, sec, slen, seed, seed_len, out, olen))
268         return 0;
269
270     return 1;
271 }