QUIC: use list.h
[openssl.git] / ssl / quic / quic_demux.c
1 /*
2  * Copyright 2022 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include "internal/quic_demux.h"
11 #include "internal/quic_wire_pkt.h"
12 #include "internal/common.h"
13 #include <openssl/lhash.h>
14
15 #define DEMUX_MAX_MSGS_PER_CALL    32
16
17 /* Structure used to track a given connection ID. */
18 typedef struct quic_demux_conn_st QUIC_DEMUX_CONN;
19
20 struct quic_demux_conn_st {
21     QUIC_DEMUX_CONN            *next; /* used when unregistering only */
22     QUIC_CONN_ID                dst_conn_id;
23     ossl_quic_demux_cb_fn      *cb;
24     void                       *cb_arg;
25 };
26
27 DEFINE_LHASH_OF_EX(QUIC_DEMUX_CONN);
28
29 static unsigned long demux_conn_hash(const QUIC_DEMUX_CONN *conn)
30 {
31     size_t i;
32     unsigned long v = 0;
33
34     assert(conn->dst_conn_id.id_len <= QUIC_MAX_CONN_ID_LEN);
35
36     for (i = 0; i < conn->dst_conn_id.id_len; ++i)
37         v ^= ((unsigned long)conn->dst_conn_id.id[i])
38              << ((i * 8) % (sizeof(unsigned long) * 8));
39
40     return v;
41 }
42
43 static int demux_conn_cmp(const QUIC_DEMUX_CONN *a, const QUIC_DEMUX_CONN *b)
44 {
45     return !ossl_quic_conn_id_eq(&a->dst_conn_id, &b->dst_conn_id);
46 }
47
48 struct quic_demux_st {
49     /* The underlying transport BIO with datagram semantics. */
50     BIO                        *net_bio;
51
52     /*
53      * QUIC short packets do not contain the length of the connection ID field,
54      * therefore it must be known contextually. The demuxer requires connection
55      * IDs of the same length to be used for all incoming packets.
56      */
57     size_t                      short_conn_id_len;
58
59     /* Default URXE buffer size in bytes. */
60     size_t                      default_urxe_alloc_len;
61
62     /* Time retrieval callback. */
63     OSSL_TIME                 (*now)(void *arg);
64     void                       *now_arg;
65
66     /* Hashtable mapping connection IDs to QUIC_DEMUX_CONN structures. */
67     LHASH_OF(QUIC_DEMUX_CONN)  *conns_by_id;
68
69     /*
70      * List of URXEs which are not currently in use (i.e., not filled with
71      * unconsumed data). These are moved to the pending list as they are filled.
72      */
73     QUIC_URXE_LIST              urx_free;
74
75     /*
76      * List of URXEs which are filled with received encrypted data. These are
77      * removed from this list as we invoke the callbacks for each of them. They
78      * are then not on any list managed by us; we forget about them until our
79      * user calls ossl_quic_demux_release_urxe to return the URXE to us, at
80      * which point we add it to the free list.
81      */
82     QUIC_URXE_LIST              urx_pending;
83
84     /* Whether to use local address support. */
85     char                        use_local_addr;
86 };
87
88 QUIC_DEMUX *ossl_quic_demux_new(BIO *net_bio,
89                                 size_t short_conn_id_len,
90                                 size_t default_urxe_alloc_len,
91                                 OSSL_TIME (*now)(void *arg),
92                                 void *now_arg)
93 {
94     QUIC_DEMUX *demux;
95
96     demux = OPENSSL_zalloc(sizeof(QUIC_DEMUX));
97     if (demux == NULL)
98         return NULL;
99
100     demux->net_bio                  = net_bio;
101     demux->short_conn_id_len        = short_conn_id_len;
102     demux->default_urxe_alloc_len   = default_urxe_alloc_len;
103     demux->now                      = now;
104     demux->now_arg                  = now_arg;
105
106     demux->conns_by_id
107         = lh_QUIC_DEMUX_CONN_new(demux_conn_hash, demux_conn_cmp);
108     if (demux->conns_by_id == NULL) {
109         OPENSSL_free(demux);
110         return NULL;
111     }
112
113     if (net_bio != NULL
114         && BIO_dgram_get_local_addr_cap(net_bio)
115         && BIO_dgram_set_local_addr_enable(net_bio, 1))
116         demux->use_local_addr = 1;
117
118     return demux;
119 }
120
121 static void demux_free_conn_it(QUIC_DEMUX_CONN *conn, void *arg)
122 {
123     OPENSSL_free(conn);
124 }
125
126 static void demux_free_urxl(QUIC_URXE_LIST *l)
127 {
128     QUIC_URXE *e, *enext;
129
130     for (e = ossl_list_urxe_head(l); e != NULL; e = enext) {
131         enext = ossl_list_urxe_next(e);
132         ossl_list_urxe_remove(l, e);
133         OPENSSL_free(e);
134     }
135 }
136
137 void ossl_quic_demux_free(QUIC_DEMUX *demux)
138 {
139     if (demux == NULL)
140         return;
141
142     /* Free all connection structures. */
143     lh_QUIC_DEMUX_CONN_doall_arg(demux->conns_by_id, demux_free_conn_it, NULL);
144     lh_QUIC_DEMUX_CONN_free(demux->conns_by_id);
145
146     /* Free all URXEs we are holding. */
147     demux_free_urxl(&demux->urx_free);
148     demux_free_urxl(&demux->urx_pending);
149
150     OPENSSL_free(demux);
151 }
152
153 static QUIC_DEMUX_CONN *demux_get_by_conn_id(QUIC_DEMUX *demux,
154                                              const QUIC_CONN_ID *dst_conn_id)
155 {
156     QUIC_DEMUX_CONN key;
157
158     if (dst_conn_id->id_len > QUIC_MAX_CONN_ID_LEN)
159         return NULL;
160
161     key.dst_conn_id = *dst_conn_id;
162     return lh_QUIC_DEMUX_CONN_retrieve(demux->conns_by_id, &key);
163 }
164
165 int ossl_quic_demux_register(QUIC_DEMUX *demux,
166                              const QUIC_CONN_ID *dst_conn_id,
167                              ossl_quic_demux_cb_fn *cb, void *cb_arg)
168 {
169     QUIC_DEMUX_CONN *conn;
170
171     if (dst_conn_id == NULL
172         || dst_conn_id->id_len > QUIC_MAX_CONN_ID_LEN
173         || cb == NULL)
174         return 0;
175
176     /* Ensure not already registered. */
177     if (demux_get_by_conn_id(demux, dst_conn_id) != NULL)
178         /* Handler already registered with this connection ID. */
179         return 0;
180
181     conn = OPENSSL_zalloc(sizeof(QUIC_DEMUX_CONN));
182     if (conn == NULL)
183         return 0;
184
185     conn->dst_conn_id   = *dst_conn_id;
186     conn->cb            = cb;
187     conn->cb_arg        = cb_arg;
188
189     lh_QUIC_DEMUX_CONN_insert(demux->conns_by_id, conn);
190     return 1;
191 }
192
193 static void demux_unregister(QUIC_DEMUX *demux,
194                              QUIC_DEMUX_CONN *conn)
195 {
196     lh_QUIC_DEMUX_CONN_delete(demux->conns_by_id, conn);
197     OPENSSL_free(conn);
198 }
199
200 int ossl_quic_demux_unregister(QUIC_DEMUX *demux,
201                                const QUIC_CONN_ID *dst_conn_id)
202 {
203     QUIC_DEMUX_CONN *conn;
204
205     if (dst_conn_id == NULL
206         || dst_conn_id->id_len > QUIC_MAX_CONN_ID_LEN)
207         return 0;
208
209     conn = demux_get_by_conn_id(demux, dst_conn_id);
210     if (conn == NULL)
211         return 0;
212
213     demux_unregister(demux, conn);
214     return 1;
215 }
216
217 struct unreg_arg {
218     ossl_quic_demux_cb_fn *cb;
219     void *cb_arg;
220     QUIC_DEMUX_CONN *head;
221 };
222
223 static void demux_unregister_by_cb(QUIC_DEMUX_CONN *conn, void *arg_)
224 {
225     struct unreg_arg *arg = arg_;
226
227     if (conn->cb == arg->cb && conn->cb_arg == arg->cb_arg) {
228         conn->next = arg->head;
229         arg->head = conn;
230     }
231 }
232
233 void ossl_quic_demux_unregister_by_cb(QUIC_DEMUX *demux,
234                                       ossl_quic_demux_cb_fn *cb,
235                                       void *cb_arg)
236 {
237     QUIC_DEMUX_CONN *conn, *cnext;
238     struct unreg_arg arg = {0};
239     arg.cb      = cb;
240     arg.cb_arg  = cb_arg;
241
242     lh_QUIC_DEMUX_CONN_doall_arg(demux->conns_by_id,
243                                  demux_unregister_by_cb, &arg);
244
245     for (conn = arg.head; conn != NULL; conn = cnext) {
246         cnext = conn->next;
247         demux_unregister(demux, conn);
248     }
249 }
250
251 static QUIC_URXE *demux_alloc_urxe(size_t alloc_len)
252 {
253     QUIC_URXE *e;
254
255     if (alloc_len >= SIZE_MAX - sizeof(QUIC_URXE))
256         return NULL;
257
258     e = OPENSSL_malloc(sizeof(QUIC_URXE) + alloc_len);
259     if (e == NULL)
260         return NULL;
261
262     ossl_list_urxe_init_elem(e);
263     e->alloc_len        = alloc_len;
264     e->data_len = 0;
265     return e;
266 }
267
268 static int demux_ensure_free_urxe(QUIC_DEMUX *demux, size_t min_num_free)
269 {
270     QUIC_URXE *e;
271
272     while (ossl_list_urxe_num(&demux->urx_free) < min_num_free) {
273         e = demux_alloc_urxe(demux->default_urxe_alloc_len);
274         if (e == NULL)
275             return 0;
276
277         ossl_list_urxe_insert_tail(&demux->urx_free, e);
278     }
279
280     return 1;
281 }
282
283 /*
284  * Receive datagrams from network, placing them into URXEs.
285  *
286  * Returns 1 on success or 0 on failure.
287  *
288  * Precondition: at least one URXE is free
289  * Precondition: there are no pending URXEs
290  */
291 static int demux_recv(QUIC_DEMUX *demux)
292 {
293     BIO_MSG msg[DEMUX_MAX_MSGS_PER_CALL];
294     size_t rd, i;
295     QUIC_URXE *urxe = ossl_list_urxe_head(&demux->urx_free), *unext;
296     OSSL_TIME now;
297
298     /* This should never be called when we have any pending URXE. */
299     assert(ossl_list_urxe_head(&demux->urx_pending) == NULL);
300
301     if (demux->net_bio == NULL)
302         return 0;
303
304     /*
305      * Opportunistically receive as many messages as possible in a single
306      * syscall, determined by how many free URXEs are available.
307      */
308     for (i = 0; i < (ossl_ssize_t)OSSL_NELEM(msg);
309             ++i, urxe = ossl_list_urxe_next(urxe)) {
310         if (urxe == NULL) {
311             /* We need at least one URXE to receive into. */
312             if (!ossl_assert(i > 0))
313                 return 0;
314
315             break;
316         }
317
318         /* Ensure we zero any fields added to BIO_MSG at a later date. */
319         memset(&msg[i], 0, sizeof(BIO_MSG));
320         msg[i].data     = ossl_quic_urxe_data(urxe);
321         msg[i].data_len = urxe->alloc_len;
322         msg[i].peer     = &urxe->peer;
323         if (demux->use_local_addr)
324             msg[i].local = &urxe->local;
325         else
326             BIO_ADDR_clear(&urxe->local);
327     }
328
329     if (!BIO_recvmmsg(demux->net_bio, msg, sizeof(BIO_MSG), i, 0, &rd))
330         return 0;
331
332     now = demux->now != NULL ? demux->now(demux->now_arg) : ossl_time_zero();
333
334     urxe = ossl_list_urxe_head(&demux->urx_free);
335     for (i = 0; i < rd; ++i, urxe = unext) {
336         unext = ossl_list_urxe_next(urxe);
337         /* Set URXE with actual length of received datagram. */
338         urxe->data_len      = msg[i].data_len;
339         /* Time we received datagram. */
340         urxe->time          = now;
341         /* Move from free list to pending list. */
342         ossl_list_urxe_remove(&demux->urx_free, urxe);
343         ossl_list_urxe_insert_tail(&demux->urx_pending, urxe);
344     }
345
346     return 1;
347 }
348
349 /* Extract destination connection ID from the first packet in a datagram. */
350 static int demux_identify_conn_id(QUIC_DEMUX *demux,
351                                   QUIC_URXE *e,
352                                   QUIC_CONN_ID *dst_conn_id)
353 {
354     return ossl_quic_wire_get_pkt_hdr_dst_conn_id(ossl_quic_urxe_data(e),
355                                                   e->data_len,
356                                                   demux->short_conn_id_len,
357                                                   dst_conn_id);
358 }
359
360 /* Identify the connection structure corresponding to a given URXE. */
361 static QUIC_DEMUX_CONN *demux_identify_conn(QUIC_DEMUX *demux, QUIC_URXE *e)
362 {
363     QUIC_CONN_ID dst_conn_id;
364
365     if (!demux_identify_conn_id(demux, e, &dst_conn_id))
366         /*
367          * Datagram is so badly malformed we can't get the DCID from the first
368          * packet in it, so just give up.
369          */
370         return NULL;
371
372     return demux_get_by_conn_id(demux, &dst_conn_id);
373 }
374
375 /* Process a single pending URXE. */
376 static int demux_process_pending_urxe(QUIC_DEMUX *demux, QUIC_URXE *e)
377 {
378     QUIC_DEMUX_CONN *conn;
379
380     /* The next URXE we process should be at the head of the pending list. */
381     if (!ossl_assert(e == ossl_list_urxe_head(&demux->urx_pending)))
382         return 0;
383
384     conn = demux_identify_conn(demux, e);
385     if (conn == NULL) {
386         /*
387          * We could not identify a connection. We will never be able to process
388          * this datagram, so get rid of it.
389          */
390         ossl_list_urxe_remove(&demux->urx_pending, e);
391         ossl_list_urxe_insert_tail(&demux->urx_free, e);
392         return 1; /* keep processing pending URXEs */
393     }
394
395     /*
396      * Remove from list and invoke callback. The URXE now belongs to the
397      * callback. (QUIC_DEMUX_CONN never has non-NULL cb.)
398      */
399     ossl_list_urxe_remove(&demux->urx_pending, e);
400     conn->cb(e, conn->cb_arg);
401     return 1;
402 }
403
404 /* Process pending URXEs to generate callbacks. */
405 static int demux_process_pending_urxl(QUIC_DEMUX *demux)
406 {
407     QUIC_URXE *e;
408
409     while ((e = ossl_list_urxe_head(&demux->urx_pending)) != NULL)
410         if (!demux_process_pending_urxe(demux, e))
411             return 0;
412
413     return 1;
414 }
415
416 /*
417  * Drain the pending URXE list, processing any pending URXEs by making their
418  * callbacks. If no URXEs are pending, a network read is attempted first.
419  */
420 int ossl_quic_demux_pump(QUIC_DEMUX *demux)
421 {
422     int ret;
423
424     if (ossl_list_urxe_head(&demux->urx_pending) == NULL) {
425         ret = demux_ensure_free_urxe(demux, DEMUX_MAX_MSGS_PER_CALL);
426         if (ret != 1)
427             return 0;
428
429         ret = demux_recv(demux);
430         if (ret != 1)
431             return 0;
432
433         /*
434          * If demux_recv returned successfully, we should always have something.
435          */
436         assert(ossl_list_urxe_head(&demux->urx_pending) != NULL);
437     }
438
439     return demux_process_pending_urxl(demux);
440 }
441
442 /* Artificially inject a packet into the demuxer for testing purposes. */
443 int ossl_quic_demux_inject(QUIC_DEMUX *demux,
444                            const unsigned char *buf,
445                            size_t buf_len,
446                            const BIO_ADDR *peer,
447                            const BIO_ADDR *local)
448 {
449     int ret;
450     QUIC_URXE *urxe;
451
452     ret = demux_ensure_free_urxe(demux, 1);
453     if (ret != 1)
454         return 0;
455
456     urxe = ossl_list_urxe_head(&demux->urx_free);
457     if (buf_len > urxe->alloc_len)
458         return 0;
459
460     memcpy(ossl_quic_urxe_data(urxe), buf, buf_len);
461     urxe->data_len = buf_len;
462
463     if (peer != NULL)
464         urxe->peer = *peer;
465     else
466         BIO_ADDR_clear(&urxe->local);
467
468     if (local != NULL)
469         urxe->local = *local;
470     else
471         BIO_ADDR_clear(&urxe->local);
472
473     /* Move from free list to pending list. */
474     ossl_list_urxe_remove(&demux->urx_free, urxe);
475     ossl_list_urxe_insert_tail(&demux->urx_pending, urxe);
476
477     return demux_process_pending_urxl(demux);
478 }
479
480 /* Called by our user to return a URXE to the free list. */
481 void ossl_quic_demux_release_urxe(QUIC_DEMUX *demux,
482                                   QUIC_URXE *e)
483 {
484     assert(ossl_list_urxe_prev(e) == NULL && ossl_list_urxe_next(e) == NULL);
485     ossl_list_urxe_insert_tail(&demux->urx_free, e);
486 }