fabbf6cb7ee25973b393d6b804198dc4a36a6e1d
[openssl.git] / ssl / quic / quic_demux.c
1 /*
2  * Copyright 2022 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include "internal/quic_demux.h"
11 #include "internal/quic_wire_pkt.h"
12 #include "internal/common.h"
13 #include <openssl/lhash.h>
14 #include <openssl/err.h>
15
16 #define URXE_DEMUX_STATE_FREE       0 /* on urx_free list */
17 #define URXE_DEMUX_STATE_PENDING    1 /* on urx_pending list */
18 #define URXE_DEMUX_STATE_ISSUED     2 /* on neither list */
19
20 #define DEMUX_MAX_MSGS_PER_CALL    32
21
22 #define DEMUX_DEFAULT_MTU        1500
23
24 /* Structure used to track a given connection ID. */
25 typedef struct quic_demux_conn_st QUIC_DEMUX_CONN;
26
27 struct quic_demux_conn_st {
28     QUIC_DEMUX_CONN                 *next; /* used when unregistering only */
29     QUIC_CONN_ID                    dst_conn_id;
30     ossl_quic_demux_cb_fn           *cb;
31     void                            *cb_arg;
32 };
33
34 DEFINE_LHASH_OF_EX(QUIC_DEMUX_CONN);
35
36 static unsigned long demux_conn_hash(const QUIC_DEMUX_CONN *conn)
37 {
38     size_t i;
39     unsigned long v = 0;
40
41     assert(conn->dst_conn_id.id_len <= QUIC_MAX_CONN_ID_LEN);
42
43     for (i = 0; i < conn->dst_conn_id.id_len; ++i)
44         v ^= ((unsigned long)conn->dst_conn_id.id[i])
45              << ((i * 8) % (sizeof(unsigned long) * 8));
46
47     return v;
48 }
49
50 static int demux_conn_cmp(const QUIC_DEMUX_CONN *a, const QUIC_DEMUX_CONN *b)
51 {
52     return !ossl_quic_conn_id_eq(&a->dst_conn_id, &b->dst_conn_id);
53 }
54
55 struct quic_demux_st {
56     /* The underlying transport BIO with datagram semantics. */
57     BIO                        *net_bio;
58
59     /*
60      * QUIC short packets do not contain the length of the connection ID field,
61      * therefore it must be known contextually. The demuxer requires connection
62      * IDs of the same length to be used for all incoming packets.
63      */
64     size_t                      short_conn_id_len;
65
66     /*
67      * Our current understanding of the upper bound on an incoming datagram size
68      * in bytes.
69      */
70     size_t                      mtu;
71
72     /* Time retrieval callback. */
73     OSSL_TIME                 (*now)(void *arg);
74     void                       *now_arg;
75
76     /* Hashtable mapping connection IDs to QUIC_DEMUX_CONN structures. */
77     LHASH_OF(QUIC_DEMUX_CONN)  *conns_by_id;
78
79     /* The default packet handler, if any. */
80     ossl_quic_demux_cb_fn      *default_cb;
81     void                       *default_cb_arg;
82
83     /*
84      * List of URXEs which are not currently in use (i.e., not filled with
85      * unconsumed data). These are moved to the pending list as they are filled.
86      */
87     QUIC_URXE_LIST              urx_free;
88
89     /*
90      * List of URXEs which are filled with received encrypted data. These are
91      * removed from this list as we invoke the callbacks for each of them. They
92      * are then not on any list managed by us; we forget about them until our
93      * user calls ossl_quic_demux_release_urxe to return the URXE to us, at
94      * which point we add it to the free list.
95      */
96     QUIC_URXE_LIST              urx_pending;
97
98     /* Whether to use local address support. */
99     char                        use_local_addr;
100 };
101
102 QUIC_DEMUX *ossl_quic_demux_new(BIO *net_bio,
103                                 size_t short_conn_id_len,
104                                 OSSL_TIME (*now)(void *arg),
105                                 void *now_arg)
106 {
107     QUIC_DEMUX *demux;
108
109     demux = OPENSSL_zalloc(sizeof(QUIC_DEMUX));
110     if (demux == NULL)
111         return NULL;
112
113     demux->net_bio                  = net_bio;
114     demux->short_conn_id_len        = short_conn_id_len;
115     /* We update this if possible when we get a BIO. */
116     demux->mtu                      = DEMUX_DEFAULT_MTU;
117     demux->now                      = now;
118     demux->now_arg                  = now_arg;
119
120     demux->conns_by_id
121         = lh_QUIC_DEMUX_CONN_new(demux_conn_hash, demux_conn_cmp);
122     if (demux->conns_by_id == NULL) {
123         OPENSSL_free(demux);
124         return NULL;
125     }
126
127     if (net_bio != NULL
128         && BIO_dgram_get_local_addr_cap(net_bio)
129         && BIO_dgram_set_local_addr_enable(net_bio, 1))
130         demux->use_local_addr = 1;
131
132     return demux;
133 }
134
135 static void demux_free_conn_it(QUIC_DEMUX_CONN *conn, void *arg)
136 {
137     OPENSSL_free(conn);
138 }
139
140 static void demux_free_urxl(QUIC_URXE_LIST *l)
141 {
142     QUIC_URXE *e, *enext;
143
144     for (e = ossl_list_urxe_head(l); e != NULL; e = enext) {
145         enext = ossl_list_urxe_next(e);
146         ossl_list_urxe_remove(l, e);
147         OPENSSL_free(e);
148     }
149 }
150
151 void ossl_quic_demux_free(QUIC_DEMUX *demux)
152 {
153     if (demux == NULL)
154         return;
155
156     /* Free all connection structures. */
157     lh_QUIC_DEMUX_CONN_doall_arg(demux->conns_by_id, demux_free_conn_it, NULL);
158     lh_QUIC_DEMUX_CONN_free(demux->conns_by_id);
159
160     /* Free all URXEs we are holding. */
161     demux_free_urxl(&demux->urx_free);
162     demux_free_urxl(&demux->urx_pending);
163
164     OPENSSL_free(demux);
165 }
166
167 void ossl_quic_demux_set_bio(QUIC_DEMUX *demux, BIO *net_bio)
168 {
169     unsigned int mtu;
170
171     demux->net_bio = net_bio;
172
173     if (net_bio != NULL) {
174         /*
175          * Try to determine our MTU if possible. The BIO is not required to
176          * support this, in which case we remain at the last known MTU, or our
177          * initial default.
178          */
179         mtu = BIO_dgram_get_mtu(net_bio);
180         if (mtu >= QUIC_MIN_INITIAL_DGRAM_LEN)
181             ossl_quic_demux_set_mtu(demux, mtu); /* best effort */
182     }
183 }
184
185 int ossl_quic_demux_set_mtu(QUIC_DEMUX *demux, unsigned int mtu)
186 {
187     if (mtu < QUIC_MIN_INITIAL_DGRAM_LEN)
188         return 0;
189
190     demux->mtu = mtu;
191     return 1;
192 }
193
194 static QUIC_DEMUX_CONN *demux_get_by_conn_id(QUIC_DEMUX *demux,
195                                              const QUIC_CONN_ID *dst_conn_id)
196 {
197     QUIC_DEMUX_CONN key;
198
199     if (dst_conn_id->id_len > QUIC_MAX_CONN_ID_LEN)
200         return NULL;
201
202     key.dst_conn_id = *dst_conn_id;
203     return lh_QUIC_DEMUX_CONN_retrieve(demux->conns_by_id, &key);
204 }
205
206 int ossl_quic_demux_register(QUIC_DEMUX *demux,
207                              const QUIC_CONN_ID *dst_conn_id,
208                              ossl_quic_demux_cb_fn *cb, void *cb_arg)
209 {
210     QUIC_DEMUX_CONN *conn;
211
212     if (dst_conn_id == NULL
213         || dst_conn_id->id_len > QUIC_MAX_CONN_ID_LEN
214         || cb == NULL)
215         return 0;
216
217     /* Ensure not already registered. */
218     if (demux_get_by_conn_id(demux, dst_conn_id) != NULL)
219         /* Handler already registered with this connection ID. */
220         return 0;
221
222     conn = OPENSSL_zalloc(sizeof(QUIC_DEMUX_CONN));
223     if (conn == NULL)
224         return 0;
225
226     conn->dst_conn_id   = *dst_conn_id;
227     conn->cb            = cb;
228     conn->cb_arg        = cb_arg;
229
230     lh_QUIC_DEMUX_CONN_insert(demux->conns_by_id, conn);
231     return 1;
232 }
233
234 static void demux_unregister(QUIC_DEMUX *demux,
235                              QUIC_DEMUX_CONN *conn)
236 {
237     lh_QUIC_DEMUX_CONN_delete(demux->conns_by_id, conn);
238     OPENSSL_free(conn);
239 }
240
241 int ossl_quic_demux_unregister(QUIC_DEMUX *demux,
242                                const QUIC_CONN_ID *dst_conn_id)
243 {
244     QUIC_DEMUX_CONN *conn;
245
246     if (dst_conn_id == NULL
247         || dst_conn_id->id_len > QUIC_MAX_CONN_ID_LEN)
248         return 0;
249
250     conn = demux_get_by_conn_id(demux, dst_conn_id);
251     if (conn == NULL)
252         return 0;
253
254     demux_unregister(demux, conn);
255     return 1;
256 }
257
258 struct unreg_arg {
259     ossl_quic_demux_cb_fn *cb;
260     void *cb_arg;
261     QUIC_DEMUX_CONN *head;
262 };
263
264 static void demux_unregister_by_cb(QUIC_DEMUX_CONN *conn, void *arg_)
265 {
266     struct unreg_arg *arg = arg_;
267
268     if (conn->cb == arg->cb && conn->cb_arg == arg->cb_arg) {
269         conn->next = arg->head;
270         arg->head = conn;
271     }
272 }
273
274 void ossl_quic_demux_unregister_by_cb(QUIC_DEMUX *demux,
275                                       ossl_quic_demux_cb_fn *cb,
276                                       void *cb_arg)
277 {
278     QUIC_DEMUX_CONN *conn, *cnext;
279     struct unreg_arg arg = {0};
280     arg.cb      = cb;
281     arg.cb_arg  = cb_arg;
282
283     lh_QUIC_DEMUX_CONN_doall_arg(demux->conns_by_id,
284                                  demux_unregister_by_cb, &arg);
285
286     for (conn = arg.head; conn != NULL; conn = cnext) {
287         cnext = conn->next;
288         demux_unregister(demux, conn);
289     }
290 }
291
292 void ossl_quic_demux_set_default_handler(QUIC_DEMUX *demux,
293                                          ossl_quic_demux_cb_fn *cb,
294                                          void *cb_arg)
295 {
296     demux->default_cb       = cb;
297     demux->default_cb_arg   = cb_arg;
298 }
299
300 static QUIC_URXE *demux_alloc_urxe(size_t alloc_len)
301 {
302     QUIC_URXE *e;
303
304     if (alloc_len >= SIZE_MAX - sizeof(QUIC_URXE))
305         return NULL;
306
307     e = OPENSSL_malloc(sizeof(QUIC_URXE) + alloc_len);
308     if (e == NULL)
309         return NULL;
310
311     ossl_list_urxe_init_elem(e);
312     e->alloc_len   = alloc_len;
313     e->data_len    = 0;
314     return e;
315 }
316
317 static QUIC_URXE *demux_resize_urxe(QUIC_DEMUX *demux, QUIC_URXE *e,
318                                     size_t new_alloc_len)
319 {
320     QUIC_URXE *e2, *prev;
321
322     if (!ossl_assert(e->demux_state == URXE_DEMUX_STATE_FREE))
323         /* Never attempt to resize a URXE which is not on the free list. */
324         return NULL;
325
326     prev = ossl_list_urxe_prev(e);
327     ossl_list_urxe_remove(&demux->urx_free, e);
328
329     e2 = OPENSSL_realloc(e, sizeof(QUIC_URXE) + new_alloc_len);
330     if (e2 == NULL) {
331         /* Failed to resize, abort. */
332         if (prev == NULL)
333             ossl_list_urxe_insert_head(&demux->urx_free, e);
334         else
335             ossl_list_urxe_insert_after(&demux->urx_free, prev, e);
336
337         return NULL;
338     }
339
340     if (prev == NULL)
341         ossl_list_urxe_insert_head(&demux->urx_free, e2);
342     else
343         ossl_list_urxe_insert_after(&demux->urx_free, prev, e2);
344
345     e2->alloc_len = new_alloc_len;
346     return e2;
347 }
348
349 static QUIC_URXE *demux_reserve_urxe(QUIC_DEMUX *demux, QUIC_URXE *e,
350                                      size_t alloc_len)
351 {
352     return e->alloc_len < alloc_len ? demux_resize_urxe(demux, e, alloc_len) : e;
353 }
354
355 static int demux_ensure_free_urxe(QUIC_DEMUX *demux, size_t min_num_free)
356 {
357     QUIC_URXE *e;
358
359     while (ossl_list_urxe_num(&demux->urx_free) < min_num_free) {
360         e = demux_alloc_urxe(demux->mtu);
361         if (e == NULL)
362             return 0;
363
364         ossl_list_urxe_insert_tail(&demux->urx_free, e);
365         e->demux_state = URXE_DEMUX_STATE_FREE;
366     }
367
368     return 1;
369 }
370
371 /*
372  * Receive datagrams from network, placing them into URXEs.
373  *
374  * Returns 1 on success or 0 on failure.
375  *
376  * Precondition: at least one URXE is free
377  * Precondition: there are no pending URXEs
378  */
379 static int demux_recv(QUIC_DEMUX *demux)
380 {
381     BIO_MSG msg[DEMUX_MAX_MSGS_PER_CALL];
382     size_t rd, i;
383     QUIC_URXE *urxe = ossl_list_urxe_head(&demux->urx_free), *unext;
384     OSSL_TIME now;
385
386     /* This should never be called when we have any pending URXE. */
387     assert(ossl_list_urxe_head(&demux->urx_pending) == NULL);
388     assert(urxe->demux_state == URXE_DEMUX_STATE_FREE);
389
390     if (demux->net_bio == NULL)
391         /*
392          * If no BIO is plugged in, treat this as no datagram being available.
393          */
394         return QUIC_DEMUX_PUMP_RES_TRANSIENT_FAIL;
395
396     /*
397      * Opportunistically receive as many messages as possible in a single
398      * syscall, determined by how many free URXEs are available.
399      */
400     for (i = 0; i < (ossl_ssize_t)OSSL_NELEM(msg);
401             ++i, urxe = ossl_list_urxe_next(urxe)) {
402         if (urxe == NULL) {
403             /* We need at least one URXE to receive into. */
404             if (!ossl_assert(i > 0))
405                 return QUIC_DEMUX_PUMP_RES_PERMANENT_FAIL;
406
407             break;
408         }
409
410         /* Ensure the URXE is big enough. */
411         urxe = demux_reserve_urxe(demux, urxe, demux->mtu);
412         if (urxe == NULL)
413             /* Allocation error, fail. */
414             return QUIC_DEMUX_PUMP_RES_PERMANENT_FAIL;
415
416         /* Ensure we zero any fields added to BIO_MSG at a later date. */
417         memset(&msg[i], 0, sizeof(BIO_MSG));
418         msg[i].data     = ossl_quic_urxe_data(urxe);
419         msg[i].data_len = urxe->alloc_len;
420         msg[i].peer     = &urxe->peer;
421         BIO_ADDR_clear(&urxe->peer);
422         if (demux->use_local_addr)
423             msg[i].local = &urxe->local;
424         else
425             BIO_ADDR_clear(&urxe->local);
426     }
427
428     ERR_set_mark();
429     if (!BIO_recvmmsg(demux->net_bio, msg, sizeof(BIO_MSG), i, 0, &rd)) {
430         if (BIO_err_is_non_fatal(ERR_peek_last_error())) {
431             /* Transient error, clear the error and stop. */
432             ERR_pop_to_mark();
433             return QUIC_DEMUX_PUMP_RES_TRANSIENT_FAIL;
434         } else {
435             /* Non-transient error, do not clear the error. */
436             ERR_clear_last_mark();
437             return QUIC_DEMUX_PUMP_RES_PERMANENT_FAIL;
438         }
439     }
440
441     ERR_clear_last_mark();
442     now = demux->now != NULL ? demux->now(demux->now_arg) : ossl_time_zero();
443
444     urxe = ossl_list_urxe_head(&demux->urx_free);
445     for (i = 0; i < rd; ++i, urxe = unext) {
446         unext = ossl_list_urxe_next(urxe);
447         /* Set URXE with actual length of received datagram. */
448         urxe->data_len      = msg[i].data_len;
449         /* Time we received datagram. */
450         urxe->time          = now;
451         /* Move from free list to pending list. */
452         ossl_list_urxe_remove(&demux->urx_free, urxe);
453         ossl_list_urxe_insert_tail(&demux->urx_pending, urxe);
454         urxe->demux_state = URXE_DEMUX_STATE_PENDING;
455     }
456
457     return QUIC_DEMUX_PUMP_RES_OK;
458 }
459
460 /* Extract destination connection ID from the first packet in a datagram. */
461 static int demux_identify_conn_id(QUIC_DEMUX *demux,
462                                   QUIC_URXE *e,
463                                   QUIC_CONN_ID *dst_conn_id)
464 {
465     return ossl_quic_wire_get_pkt_hdr_dst_conn_id(ossl_quic_urxe_data(e),
466                                                   e->data_len,
467                                                   demux->short_conn_id_len,
468                                                   dst_conn_id);
469 }
470
471 /* Identify the connection structure corresponding to a given URXE. */
472 static QUIC_DEMUX_CONN *demux_identify_conn(QUIC_DEMUX *demux, QUIC_URXE *e)
473 {
474     QUIC_CONN_ID dst_conn_id;
475
476     if (!demux_identify_conn_id(demux, e, &dst_conn_id))
477         /*
478          * Datagram is so badly malformed we can't get the DCID from the first
479          * packet in it, so just give up.
480          */
481         return NULL;
482
483     return demux_get_by_conn_id(demux, &dst_conn_id);
484 }
485
486 /* Process a single pending URXE. */
487 static int demux_process_pending_urxe(QUIC_DEMUX *demux, QUIC_URXE *e)
488 {
489     QUIC_DEMUX_CONN *conn;
490
491     /* The next URXE we process should be at the head of the pending list. */
492     if (!ossl_assert(e == ossl_list_urxe_head(&demux->urx_pending)))
493         return 0;
494
495     assert(e->demux_state == URXE_DEMUX_STATE_PENDING);
496
497     conn = demux_identify_conn(demux, e);
498     if (conn == NULL) {
499         /*
500          * We could not identify a connection. If we have a default packet
501          * handler, pass it to the handler. Otherwise, we will never be able to
502          * process this datagram, so get rid of it.
503          */
504         ossl_list_urxe_remove(&demux->urx_pending, e);
505         if (demux->default_cb != NULL) {
506             /* Pass to default handler. */
507             e->demux_state = URXE_DEMUX_STATE_ISSUED;
508             demux->default_cb(e, demux->default_cb_arg);
509         } else {
510             /* Discard. */
511             ossl_list_urxe_insert_tail(&demux->urx_free, e);
512             e->demux_state = URXE_DEMUX_STATE_FREE;
513         }
514         return 1; /* keep processing pending URXEs */
515     }
516
517     /*
518      * Remove from list and invoke callback. The URXE now belongs to the
519      * callback. (QUIC_DEMUX_CONN never has non-NULL cb.)
520      */
521     ossl_list_urxe_remove(&demux->urx_pending, e);
522     e->demux_state = URXE_DEMUX_STATE_ISSUED;
523     conn->cb(e, conn->cb_arg);
524     return 1;
525 }
526
527 /* Process pending URXEs to generate callbacks. */
528 static int demux_process_pending_urxl(QUIC_DEMUX *demux)
529 {
530     QUIC_URXE *e;
531
532     while ((e = ossl_list_urxe_head(&demux->urx_pending)) != NULL)
533         if (!demux_process_pending_urxe(demux, e))
534             return 0;
535
536     return 1;
537 }
538
539 /*
540  * Drain the pending URXE list, processing any pending URXEs by making their
541  * callbacks. If no URXEs are pending, a network read is attempted first.
542  */
543 int ossl_quic_demux_pump(QUIC_DEMUX *demux)
544 {
545     int ret;
546
547     if (ossl_list_urxe_head(&demux->urx_pending) == NULL) {
548         ret = demux_ensure_free_urxe(demux, DEMUX_MAX_MSGS_PER_CALL);
549         if (ret != 1)
550             return QUIC_DEMUX_PUMP_RES_PERMANENT_FAIL;
551
552         ret = demux_recv(demux);
553         if (ret != QUIC_DEMUX_PUMP_RES_OK)
554             return ret;
555
556         /*
557          * If demux_recv returned successfully, we should always have something.
558          */
559         assert(ossl_list_urxe_head(&demux->urx_pending) != NULL);
560     }
561
562     if (!demux_process_pending_urxl(demux))
563         return QUIC_DEMUX_PUMP_RES_PERMANENT_FAIL;
564
565     return QUIC_DEMUX_PUMP_RES_OK;
566 }
567
568 /* Artificially inject a packet into the demuxer for testing purposes. */
569 int ossl_quic_demux_inject(QUIC_DEMUX *demux,
570                            const unsigned char *buf,
571                            size_t buf_len,
572                            const BIO_ADDR *peer,
573                            const BIO_ADDR *local)
574 {
575     int ret;
576     QUIC_URXE *urxe;
577
578     ret = demux_ensure_free_urxe(demux, 1);
579     if (ret != 1)
580         return 0;
581
582     urxe = ossl_list_urxe_head(&demux->urx_free);
583
584     assert(urxe->demux_state == URXE_DEMUX_STATE_FREE);
585
586     urxe = demux_reserve_urxe(demux, urxe, buf_len);
587     if (urxe == NULL)
588         return 0;
589
590     memcpy(ossl_quic_urxe_data(urxe), buf, buf_len);
591     urxe->data_len = buf_len;
592
593     if (peer != NULL)
594         urxe->peer = *peer;
595     else
596         BIO_ADDR_clear(&urxe->peer);
597
598     if (local != NULL)
599         urxe->local = *local;
600     else
601         BIO_ADDR_clear(&urxe->local);
602
603     /* Move from free list to pending list. */
604     ossl_list_urxe_remove(&demux->urx_free, urxe);
605     ossl_list_urxe_insert_tail(&demux->urx_pending, urxe);
606     urxe->demux_state = URXE_DEMUX_STATE_PENDING;
607
608     return demux_process_pending_urxl(demux);
609 }
610
611 /* Called by our user to return a URXE to the free list. */
612 void ossl_quic_demux_release_urxe(QUIC_DEMUX *demux,
613                                   QUIC_URXE *e)
614 {
615     assert(ossl_list_urxe_prev(e) == NULL && ossl_list_urxe_next(e) == NULL);
616     assert(e->demux_state == URXE_DEMUX_STATE_ISSUED);
617     ossl_list_urxe_insert_tail(&demux->urx_free, e);
618     e->demux_state = URXE_DEMUX_STATE_FREE;
619 }
620
621 void ossl_quic_demux_reinject_urxe(QUIC_DEMUX *demux,
622                                    QUIC_URXE *e)
623 {
624     assert(ossl_list_urxe_prev(e) == NULL && ossl_list_urxe_next(e) == NULL);
625     assert(e->demux_state == URXE_DEMUX_STATE_ISSUED);
626     ossl_list_urxe_insert_head(&demux->urx_pending, e);
627     e->demux_state = URXE_DEMUX_STATE_PENDING;
628 }