Add a (D)TLS dumper BIO
[openssl.git] / test / ssltestlib.c
1 /*
2  * Copyright 2016 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include "ssltestlib.h"
11
12 static int tls_dump_new(BIO *bi);
13 static int tls_dump_free(BIO *a);
14 static int tls_dump_read(BIO *b, char *out, int outl);
15 static int tls_dump_write(BIO *b, const char *in, int inl);
16 static long tls_dump_ctrl(BIO *b, int cmd, long num, void *ptr);
17 static int tls_dump_gets(BIO *bp, char *buf, int size);
18 static int tls_dump_puts(BIO *bp, const char *str);
19
20 /* Choose a sufficiently large type likely to be unused for this custom BIO */
21 # define BIO_TYPE_TLS_DUMP_FILTER  (0x80 | BIO_TYPE_FILTER)
22
23 # define BIO_TYPE_MEMPACKET_TEST      0x81
24
25 static BIO_METHOD *method_tls_dump = NULL;
26 static BIO_METHOD *method_mempacket_test = NULL;
27
28 /* Note: Not thread safe! */
29 const BIO_METHOD *bio_f_tls_dump_filter(void)
30 {
31     if (method_tls_dump == NULL) {
32         method_tls_dump = BIO_meth_new(BIO_TYPE_TLS_DUMP_FILTER,
33                                         "TLS dump filter");
34         if (   method_tls_dump == NULL
35             || !BIO_meth_set_write(method_tls_dump, tls_dump_write)
36             || !BIO_meth_set_read(method_tls_dump, tls_dump_read)
37             || !BIO_meth_set_puts(method_tls_dump, tls_dump_puts)
38             || !BIO_meth_set_gets(method_tls_dump, tls_dump_gets)
39             || !BIO_meth_set_ctrl(method_tls_dump, tls_dump_ctrl)
40             || !BIO_meth_set_create(method_tls_dump, tls_dump_new)
41             || !BIO_meth_set_destroy(method_tls_dump, tls_dump_free))
42             return NULL;
43     }
44     return method_tls_dump;
45 }
46
47 void bio_f_tls_dump_filter_free(void)
48 {
49     BIO_meth_free(method_tls_dump);
50 }
51
52 static int tls_dump_new(BIO *bio)
53 {
54     BIO_set_init(bio, 1);
55     return 1;
56 }
57
58 static int tls_dump_free(BIO *bio)
59 {
60     BIO_set_init(bio, 0);
61
62     return 1;
63 }
64
65 static void copy_flags(BIO *bio)
66 {
67     int flags;
68     BIO *next = BIO_next(bio);
69
70     flags = BIO_test_flags(next, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_RWS);
71     BIO_clear_flags(bio, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_RWS);
72     BIO_set_flags(bio, flags);
73 }
74
75 #define RECORD_CONTENT_TYPE     0
76 #define RECORD_VERSION_HI       1
77 #define RECORD_VERSION_LO       2
78 #define RECORD_EPOCH_HI         3
79 #define RECORD_EPOCH_LO         4
80 #define RECORD_SEQUENCE_START   5
81 #define RECORD_SEQUENCE_END     10
82 #define RECORD_LEN_HI           11
83 #define RECORD_LEN_LO           12
84
85 #define MSG_TYPE                0
86 #define MSG_LEN_HI              1
87 #define MSG_LEN_MID             2
88 #define MSG_LEN_LO              3
89 #define MSG_SEQ_HI              4
90 #define MSG_SEQ_LO              5
91 #define MSG_FRAG_OFF_HI         6
92 #define MSG_FRAG_OFF_MID        7
93 #define MSG_FRAG_OFF_LO         8
94 #define MSG_FRAG_LEN_HI         9
95 #define MSG_FRAG_LEN_MID        10
96 #define MSG_FRAG_LEN_LO         11
97
98
99 static void dump_data(const char *data, int len)
100 {
101     int rem, i, content, reclen, msglen, fragoff, fraglen, epoch;
102     unsigned char *rec;
103
104     printf("---- START OF PACKET ----\n");
105
106     rem = len;
107     rec = (unsigned char *)data;
108
109     while (rem > 0) {
110         if (rem != len)
111             printf("*\n");
112         printf("*---- START OF RECORD ----\n");
113         if (rem < DTLS1_RT_HEADER_LENGTH) {
114             printf("*---- RECORD TRUNCATED ----\n");
115             break;
116         }
117         content = rec[RECORD_CONTENT_TYPE];
118         printf("** Record Content-type: %d\n", content);
119         printf("** Record Version: %02x%02x\n",
120                rec[RECORD_VERSION_HI], rec[RECORD_VERSION_LO]);
121         epoch = (rec[RECORD_EPOCH_HI] << 8) | rec[RECORD_EPOCH_LO];
122         printf("** Record Epoch: %d\n", epoch);
123         printf("** Record Sequence: ");
124         for (i = RECORD_SEQUENCE_START; i <= RECORD_SEQUENCE_END; i++)
125             printf("%02x", rec[i]);
126         reclen = (rec[RECORD_LEN_HI] << 8) | rec[RECORD_LEN_LO];
127         printf("\n** Record Length: %d\n", reclen);
128
129         /* Now look at message */
130         rec += DTLS1_RT_HEADER_LENGTH;
131         rem -= DTLS1_RT_HEADER_LENGTH;
132         if (content == SSL3_RT_HANDSHAKE) {
133             printf("**---- START OF HANDSHAKE MESSAGE FRAGMENT ----\n");
134             if (epoch > 0) {
135                 printf("**---- HANDSHAKE MESSAGE FRAGMENT ENCRYPTED ----\n");
136             } else if (rem < DTLS1_HM_HEADER_LENGTH
137                     || reclen < DTLS1_HM_HEADER_LENGTH) {
138                 printf("**---- HANDSHAKE MESSAGE FRAGMENT TRUNCATED ----\n");
139             } else {
140                 printf("*** Message Type: %d\n", rec[MSG_TYPE]);
141                 msglen = (rec[MSG_LEN_HI] << 16) | (rec[MSG_LEN_MID] << 8)
142                          | rec[MSG_LEN_LO];
143                 printf("*** Message Length: %d\n", msglen);
144                 printf("*** Message sequence: %d\n",
145                        (rec[MSG_SEQ_HI] << 8) | rec[MSG_SEQ_LO]);
146                 fragoff = (rec[MSG_FRAG_OFF_HI] << 16)
147                           | (rec[MSG_FRAG_OFF_MID] << 8)
148                           | rec[MSG_FRAG_OFF_LO];
149                 printf("*** Message Fragment offset: %d\n", fragoff);
150                 fraglen = (rec[MSG_FRAG_LEN_HI] << 16)
151                           | (rec[MSG_FRAG_LEN_MID] << 8)
152                           | rec[MSG_FRAG_LEN_LO];
153                 printf("*** Message Fragment len: %d\n", fraglen);
154                 if (fragoff + fraglen > msglen)
155                     printf("***---- HANDSHAKE MESSAGE FRAGMENT INVALID ----\n");
156                 else if(reclen < fraglen)
157                     printf("**---- HANDSHAKE MESSAGE FRAGMENT TRUNCATED ----\n");
158                 else
159                     printf("**---- END OF HANDSHAKE MESSAGE FRAGMENT ----\n");
160             }
161         }
162         if (rem < reclen) {
163             printf("*---- RECORD TRUNCATED ----\n");
164             rem = 0;
165         } else {
166             rec += reclen;
167             rem -= reclen;
168             printf("*---- END OF RECORD ----\n");
169         }
170     }
171     printf("---- END OF PACKET ----\n\n");
172     fflush(stdout);
173 }
174
175 static int tls_dump_read(BIO *bio, char *out, int outl)
176 {
177     int ret;
178     BIO *next = BIO_next(bio);
179
180     ret = BIO_read(next, out, outl);
181     copy_flags(bio);
182
183     if (ret > 0) {
184         dump_data(out, ret);
185     }
186
187     return ret;
188 }
189
190 static int tls_dump_write(BIO *bio, const char *in, int inl)
191 {
192     int ret;
193     BIO *next = BIO_next(bio);
194
195     ret = BIO_write(next, in, inl);
196     copy_flags(bio);
197
198     return ret;
199 }
200
201 static long tls_dump_ctrl(BIO *bio, int cmd, long num, void *ptr)
202 {
203     long ret;
204     BIO *next = BIO_next(bio);
205
206     if (next == NULL)
207         return 0;
208
209     switch (cmd) {
210     case BIO_CTRL_DUP:
211         ret = 0L;
212         break;
213     default:
214         ret = BIO_ctrl(next, cmd, num, ptr);
215         break;
216     }
217     return ret;
218 }
219
220 static int tls_dump_gets(BIO *bio, char *buf, int size)
221 {
222     /* We don't support this - not needed anyway */
223     return -1;
224 }
225
226 static int tls_dump_puts(BIO *bio, const char *str)
227 {
228     return tls_dump_write(bio, str, strlen(str));
229 }
230
231 int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,
232                         SSL_CTX **sctx, SSL_CTX **cctx, char *certfile,
233                         char *privkeyfile)
234 {
235     SSL_CTX *serverctx = NULL;
236     SSL_CTX *clientctx = NULL;
237
238     serverctx = SSL_CTX_new(sm);
239     clientctx = SSL_CTX_new(cm);
240     if (serverctx == NULL || clientctx == NULL) {
241         printf("Failed to create SSL_CTX\n");
242         goto err;
243     }
244
245     if (SSL_CTX_use_certificate_file(serverctx, certfile,
246                                      SSL_FILETYPE_PEM) <= 0) {
247         printf("Failed to load server certificate\n");
248         goto err;
249     }
250     if (SSL_CTX_use_PrivateKey_file(serverctx, privkeyfile,
251                                     SSL_FILETYPE_PEM) <= 0) {
252         printf("Failed to load server private key\n");
253     }
254     if (SSL_CTX_check_private_key(serverctx) <= 0) {
255         printf("Failed to check private key\n");
256         goto err;
257     }
258
259     *sctx = serverctx;
260     *cctx = clientctx;
261
262     return 1;
263  err:
264     SSL_CTX_free(serverctx);
265     SSL_CTX_free(clientctx);
266     return 0;
267 }
268
269 #define MAXLOOPS    100000
270
271 /*
272  * NOTE: Transfers control of the BIOs - this function will free them on error
273  */
274 int create_ssl_connection(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
275                           SSL **cssl, BIO *s_to_c_fbio, BIO *c_to_s_fbio)
276 {
277     int retc = -1, rets = -1, err, abortctr = 0;
278     int clienterr = 0, servererr = 0;
279     SSL *serverssl, *clientssl;
280     BIO *s_to_c_bio = NULL, *c_to_s_bio = NULL;
281
282     if (*sssl == NULL)
283         serverssl = SSL_new(serverctx);
284     else
285         serverssl = *sssl;
286     if (*cssl == NULL)
287         clientssl = SSL_new(clientctx);
288     else
289         clientssl = *cssl;
290
291     if (serverssl == NULL || clientssl == NULL) {
292         printf("Failed to create SSL object\n");
293         goto error;
294     }
295
296     s_to_c_bio = BIO_new(BIO_s_mem());
297     c_to_s_bio = BIO_new(BIO_s_mem());
298     if (s_to_c_bio == NULL || c_to_s_bio == NULL) {
299         printf("Failed to create mem BIOs\n");
300         goto error;
301     }
302
303     if (s_to_c_fbio != NULL)
304         s_to_c_bio = BIO_push(s_to_c_fbio, s_to_c_bio);
305     if (c_to_s_fbio != NULL)
306         c_to_s_bio = BIO_push(c_to_s_fbio, c_to_s_bio);
307     if (s_to_c_bio == NULL || c_to_s_bio == NULL) {
308         printf("Failed to create chained BIOs\n");
309         goto error;
310     }
311
312     /* Set Non-blocking IO behaviour */
313     BIO_set_mem_eof_return(s_to_c_bio, -1);
314     BIO_set_mem_eof_return(c_to_s_bio, -1);
315
316     /* Up ref these as we are passing them to two SSL objects */
317     BIO_up_ref(s_to_c_bio);
318     BIO_up_ref(c_to_s_bio);
319
320     SSL_set_bio(serverssl, c_to_s_bio, s_to_c_bio);
321     SSL_set_bio(clientssl, s_to_c_bio, c_to_s_bio);
322
323     /* BIOs will now be freed when SSL objects are freed */
324     s_to_c_bio = c_to_s_bio = NULL;
325     s_to_c_fbio = c_to_s_fbio = NULL;
326
327     do {
328         err = SSL_ERROR_WANT_WRITE;
329         while (!clienterr && retc <= 0 && err == SSL_ERROR_WANT_WRITE) {
330             retc = SSL_connect(clientssl);
331             if (retc <= 0)
332                 err = SSL_get_error(clientssl, retc);
333         }
334
335         if (!clienterr && retc <= 0 && err != SSL_ERROR_WANT_READ) {
336             printf("SSL_connect() failed %d, %d\n", retc, err);
337             clienterr = 1;
338         }
339
340         err = SSL_ERROR_WANT_WRITE;
341         while (!servererr && rets <= 0 && err == SSL_ERROR_WANT_WRITE) {
342             rets = SSL_accept(serverssl);
343             if (rets <= 0)
344                 err = SSL_get_error(serverssl, rets);
345         }
346
347         if (!servererr && rets <= 0 && err != SSL_ERROR_WANT_READ) {
348             printf("SSL_accept() failed %d, %d\n", retc, err);
349             servererr = 1;
350         }
351         if (clienterr && servererr)
352             goto error;
353         if (++abortctr == MAXLOOPS) {
354             printf("No progress made\n");
355             goto error;
356         }
357     } while (retc <=0 || rets <= 0);
358
359     *sssl = serverssl;
360     *cssl = clientssl;
361
362     return 1;
363
364  error:
365     if (*sssl == NULL) {
366         SSL_free(serverssl);
367         BIO_free(s_to_c_bio);
368         BIO_free(s_to_c_fbio);
369     }
370     if (*cssl == NULL) {
371         SSL_free(clientssl);
372         BIO_free(c_to_s_bio);
373         BIO_free(c_to_s_fbio);
374     }
375
376     return 0;
377 }