uint_set: convert uint_set to use the list data type
authorPauli <pauli@openssl.org>
Tue, 11 Oct 2022 07:41:04 +0000 (18:41 +1100)
committerPauli <pauli@openssl.org>
Wed, 16 Nov 2022 07:02:02 +0000 (18:02 +1100)
This is instead of re-implementing a linked list itself.

Reviewed-by: Tim Hudson <tjh@openssl.org>
Reviewed-by: Shane Lontis <shane.lontis@oracle.com>
(Merged from https://github.com/openssl/openssl/pull/19377)

include/internal/uint_set.h
ssl/quic/quic_ackm.c
ssl/quic/quic_sstream.c
ssl/quic/uint_set.c

index 800860718e019caa40dde31641b0f4ec9e788f79..dcb29b33f3cca7d6f772301850b7b8084696775d 100644 (file)
@@ -10,6 +10,7 @@
 # define OSSL_UINT_SET_H
 
 #include "openssl/params.h"
+#include "internal/list.h"
 
 /*
  * uint64_t Integer Sets
@@ -27,17 +28,15 @@ typedef struct uint_range_st {
     uint64_t    start, end;
 } UINT_RANGE;
 
-typedef struct uint_set_item_st {
-    struct uint_set_item_st    *prev, *next;
+typedef struct uint_set_item_st UINT_SET_ITEM;
+struct uint_set_item_st {
+    OSSL_LIST_MEMBER(uint_set, UINT_SET_ITEM);
     UINT_RANGE                  range;
-} UINT_SET_ITEM;
+};
 
-typedef struct uint_set_st {
-    UINT_SET_ITEM  *head, *tail;
+DEFINE_LIST_OF(uint_set, UINT_SET_ITEM);
 
-    /* Number of ranges (not integers) in the set. */
-    size_t          num_ranges;
-} UINT_SET;
+typedef OSSL_LIST(uint_set) UINT_SET;
 
 void ossl_uint_set_init(UINT_SET *s);
 void ossl_uint_set_destroy(UINT_SET *s);
index 492bf2f1e4937a3bc9bcd9099615a1b3af06e64b..0f7166eaf39d4c605a68c57702af4d5a56d8219b 100644 (file)
@@ -440,8 +440,8 @@ static void rx_pkt_history_trim_range_count(struct rx_pkt_history_st *h)
 {
     QUIC_PN highest = QUIC_PN_INVALID;
 
-    while (h->set.num_ranges > MAX_RX_ACK_RANGES) {
-        UINT_RANGE r = h->set.head->range;
+    while (ossl_list_uint_set_num(&h->set) > MAX_RX_ACK_RANGES) {
+        UINT_RANGE r = ossl_list_uint_set_head(&h->set)->range;
 
         highest = (highest == QUIC_PN_INVALID)
             ? r.end : ossl_quic_pn_max(highest, r.end);
@@ -1416,7 +1416,7 @@ static int ackm_has_newly_missing(OSSL_ACKM *ackm, int pkt_space)
 
     h = get_rx_history(ackm, pkt_space);
 
-    if (h->set.tail == NULL)
+    if (ossl_list_uint_set_num(&h->set) == 0)
         return 0;
 
     /*
@@ -1432,8 +1432,9 @@ static int ackm_has_newly_missing(OSSL_ACKM *ackm, int pkt_space)
      * the PNs we have ACK'd previously and the PN we have just received.
      */
     return ackm->ack[pkt_space].num_ack_ranges > 0
-        && h->set.tail->range.start == h->set.tail->range.end
-        && h->set.tail->range.start
+        && ossl_list_uint_set_tail(&h->set)->range.start
+           == ossl_list_uint_set_tail(&h->set)->range.end
+        && ossl_list_uint_set_tail(&h->set)->range.start
             > ackm->ack[pkt_space].ack_ranges[0].end + 1;
 }
 
@@ -1582,9 +1583,9 @@ static void ackm_fill_rx_ack_ranges(OSSL_ACKM *ackm, int pkt_space,
      * Copy out ranges from the PN set, starting at the end, until we reach our
      * maximum number of ranges.
      */
-    for (x = h->set.tail;
+    for (x = ossl_list_uint_set_tail(&h->set);
          x != NULL && i < OSSL_NELEM(ackm->ack_ranges);
-         x = x->prev, ++i) {
+         x = ossl_list_uint_set_prev(x), ++i) {
         ackm->ack_ranges[pkt_space][i].start = x->range.start;
         ackm->ack_ranges[pkt_space][i].end   = x->range.end;
     }
index 56d80cb56812ad3aa8f3d42dc3eb6b2b061f8108..07113884e12f9593a13779f06e91d296bf1209ac 100644 (file)
@@ -271,13 +271,13 @@ int ossl_quic_sstream_get_stream_frame(QUIC_SSTREAM *qss,
     size_t num_iov_ = 0, src_len = 0, total_len = 0, i;
     uint64_t max_len;
     const unsigned char *src = NULL;
-    UINT_SET_ITEM *range = qss->new_set.head;
+    UINT_SET_ITEM *range = ossl_list_uint_set_head(&qss->new_set);
 
     if (*num_iov < 2)
         return 0;
 
     for (i = 0; i < skip && range != NULL; ++i)
-        range = range->next;
+        range = ossl_list_uint_set_next(range);
 
     if (range == NULL) {
         /* No new bytes to send, but we might have a FIN */
@@ -476,6 +476,8 @@ int ossl_quic_sstream_append(QUIC_SSTREAM *qss,
 
 static void qss_cull(QUIC_SSTREAM *qss)
 {
+    UINT_SET_ITEM *h = ossl_list_uint_set_head(&qss->acked_set);
+
     /*
      * Potentially cull data from our ring buffer. This can happen once data has
      * been ACKed and we know we are never going to have to transmit it again.
@@ -492,10 +494,8 @@ static void qss_cull(QUIC_SSTREAM *qss)
      * We only need to check the first range entry in the integer set because we
      * can only cull contiguous areas at the start of the ring buffer anyway.
      */
-    if (qss->acked_set.head != NULL)
-        ring_buf_cpop_range(&qss->ring_buf,
-                            qss->acked_set.head->range.start,
-                            qss->acked_set.head->range.end);
+    if (h != NULL)
+        ring_buf_cpop_range(&qss->ring_buf, h->range.start, h->range.end);
 }
 
 int ossl_quic_sstream_set_buffer_size(QUIC_SSTREAM *qss, size_t num_bytes)
index bfa8e3f93f07ae26d473b5c9e440ec722e5d9cea..9d0440b42361e9a9a765c6afa9b6552913a22eb8 100644 (file)
  */
 void ossl_uint_set_init(UINT_SET *s)
 {
-    s->head = s->tail = NULL;
-    s->num_ranges = 0;
+    ossl_list_uint_set_init(s);
 }
 
 void ossl_uint_set_destroy(UINT_SET *s)
 {
     UINT_SET_ITEM *x, *xnext;
 
-    for (x = s->head; x != NULL; x = xnext) {
-        xnext = x->next;
+    for (x = ossl_list_uint_set_head(s); x != NULL; x = xnext) {
+        xnext = ossl_list_uint_set_next(x);
         OPENSSL_free(x);
     }
 }
 
-/* Possible merge of x, x->prev */
+/* Possible merge of x, prev(x) */
 static void uint_set_merge_adjacent(UINT_SET *s, UINT_SET_ITEM *x)
 {
-    UINT_SET_ITEM *xprev = x->prev;
+    UINT_SET_ITEM *xprev = ossl_list_uint_set_prev(x);
 
     if (xprev == NULL)
         return;
@@ -85,15 +84,8 @@ static void uint_set_merge_adjacent(UINT_SET *s, UINT_SET_ITEM *x)
         return;
 
     x->range.start = xprev->range.start;
-    x->prev = xprev->prev;
-    if (x->prev != NULL)
-        x->prev->next = x;
-
-    if (s->head == xprev)
-        s->head = x;
-
+    ossl_list_uint_set_remove(s, xprev);
     OPENSSL_free(xprev);
-    --s->num_ranges;
 }
 
 static uint64_t u64_min(uint64_t x, uint64_t y)
@@ -117,28 +109,37 @@ static int uint_range_overlaps(const UINT_RANGE *a,
         >= u64_max(a->start, b->start);
 }
 
+static UINT_SET_ITEM *create_set_item(uint64_t start, uint64_t end)
+{
+        UINT_SET_ITEM *x = OPENSSL_malloc(sizeof(UINT_SET_ITEM));
+
+        ossl_list_uint_set_init_elem(x);
+        if (x != NULL) {
+            x->range.start = start;
+            x->range.end   = end;
+        }
+        return x;
+}
+
 int ossl_uint_set_insert(UINT_SET *s, const UINT_RANGE *range)
 {
-    UINT_SET_ITEM *x, *z, *xnext, *f, *fnext;
+    UINT_SET_ITEM *x, *xnext, *z, *zprev, *f;
     uint64_t start = range->start, end = range->end;
 
     if (!ossl_assert(start <= end))
         return 0;
 
-    if (s->head == NULL) {
+    if (ossl_list_uint_set_is_empty(s)) {
         /* Nothing in the set yet, so just add this range. */
-        x = OPENSSL_zalloc(sizeof(UINT_SET_ITEM));
+        x = create_set_item(start, end);
         if (x == NULL)
             return 0;
-
-        x->range.start = start;
-        x->range.end   = end;
-        s->head = s->tail = x;
-        ++s->num_ranges;
+        ossl_list_uint_set_insert_head(s, x);
         return 1;
     }
 
-    if (start > s->tail->range.end) {
+    z = ossl_list_uint_set_tail(s);
+    if (start > z->range.end) {
         /*
          * Range is after the latest range in the set, so append.
          *
@@ -146,42 +147,33 @@ int ossl_uint_set_insert(UINT_SET *s, const UINT_RANGE *range)
          * set is handled as a degenerate case of the final case below. See
          * optimization note (*) below.
          */
-        if (s->tail->range.end + 1 == start) {
-            s->tail->range.end = end;
+        if (z->range.end + 1 == start) {
+            z->range.end = end;
             return 1;
         }
 
-        x = OPENSSL_zalloc(sizeof(UINT_SET_ITEM));
+        x = create_set_item(start, end);
         if (x == NULL)
             return 0;
-
-        x->range.start = start;
-        x->range.end   = end;
-        x->prev        = s->tail;
-        if (s->tail != NULL)
-            s->tail->next = x;
-        s->tail = x;
-        ++s->num_ranges;
+        ossl_list_uint_set_insert_tail(s, x);
         return 1;
     }
 
-    if (start <= s->head->range.start && end >= s->tail->range.end) {
+    f = ossl_list_uint_set_head(s);
+    if (start <= f->range.start && end >= z->range.end) {
         /*
          * New range dwarfs all ranges in our set.
          *
          * Free everything except the first range in the set, which we scavenge
          * and reuse.
          */
-        for (x = s->head->next; x != NULL; x = xnext) {
-            xnext = x->next;
-            OPENSSL_free(x);
+        x = ossl_list_uint_set_head(s);
+        x->range.start = start;
+        x->range.end = end;
+        for (x = ossl_list_uint_set_next(x); x != NULL; x = xnext) {
+            xnext = ossl_list_uint_set_next(x);
+            ossl_list_uint_set_remove(s, x);
         }
-
-        s->head->range.start = start;
-        s->head->range.end   = end;
-        s->head->next = s->head->prev = NULL;
-        s->tail = s->head;
-        s->num_ranges = 1;
         return 1;
     }
 
@@ -192,9 +184,11 @@ int ossl_uint_set_insert(UINT_SET *s, const UINT_RANGE *range)
      * insertion at the start and end of the space will be the most common
      * operations. (*)
      */
-    z = end < s->head->range.start ? s->head : s->tail;
+    z = end < f->range.start ? f : z;
+
+    for (; z != NULL; z = zprev) {
+        zprev = ossl_list_uint_set_prev(z);
 
-    for (; z != NULL; z = z->prev) {
         /* An existing range dwarfs our new range (optimisation). */
         if (z->range.start <= start && z->range.end >= end)
             return 1;
@@ -205,35 +199,26 @@ int ossl_uint_set_insert(UINT_SET *s, const UINT_RANGE *range)
              * existing ranges.
              */
             UINT_SET_ITEM *ovend = z;
-            UINT_RANGE t;
-            size_t n = 0;
 
-            t.end = u64_max(end, z->range.end);
+            ovend->range.end = u64_max(end, z->range.end);
 
             /* Get earliest overlapping range. */
-            for (; z->prev != NULL && uint_range_overlaps(&z->prev->range, range);
-                   z = z->prev);
-
-            t.start = u64_min(start, z->range.start);
-
-            /* Replace sequence of nodes z..ovend with ovend only. */
-            ovend->range = t;
-            ovend->prev = z->prev;
-            if (z->prev != NULL)
-                z->prev->next = ovend;
-            if (s->head == z)
-                s->head = ovend;
-
-            /* Free now unused nodes. */
-            for (f = z; f != ovend; f = fnext, ++n) {
-                fnext = f->next;
-                OPENSSL_free(f);
+            while (zprev != NULL && uint_range_overlaps(&zprev->range, range)) {
+                z = zprev;
+                zprev = ossl_list_uint_set_prev(z);
             }
 
-            s->num_ranges -= n;
+            ovend->range.start = u64_min(start, z->range.start);
+
+            /* Replace sequence of nodes z..ovend with updated ovend only. */
+            while (z != ovend) {
+                z = ossl_list_uint_set_next(x = z);
+                ossl_list_uint_set_remove(s, x);
+                OPENSSL_free(x);
+            }
             break;
         } else if (end < z->range.start
-                    && (z->prev == NULL || start > z->prev->range.end)) {
+                    && (zprev == NULL || start > zprev->range.end)) {
             if (z->range.start == end + 1) {
                 /* We can extend the following range backwards. */
                 z->range.start = start;
@@ -243,9 +228,9 @@ int ossl_uint_set_insert(UINT_SET *s, const UINT_RANGE *range)
                  * consecutive nodes.
                  */
                 uint_set_merge_adjacent(s, z);
-            } else if (z->prev != NULL && z->prev->range.end + 1 == start) {
+            } else if (zprev != NULL && zprev->range.end + 1 == start) {
                 /* We can extend the preceding range forwards. */
-                z->prev->range.end = end;
+                zprev->range.end = end;
 
                 /*
                  * If this closes a gap we now need to merge
@@ -257,22 +242,10 @@ int ossl_uint_set_insert(UINT_SET *s, const UINT_RANGE *range)
                  * The new interval is between intervals without overlapping or
                  * touching them, so insert between, preserving sort.
                  */
-                x = OPENSSL_zalloc(sizeof(UINT_SET_ITEM));
+                x = create_set_item(start, end);
                 if (x == NULL)
                     return 0;
-
-                x->range.start = start;
-                x->range.end   = end;
-
-                x->next = z;
-                x->prev = z->prev;
-                if (x->prev != NULL)
-                    x->prev->next = x;
-                z->prev = x;
-                if (s->head == z)
-                    s->head = x;
-
-                ++s->num_ranges;
+                ossl_list_uint_set_insert_before(s, z, x);
             }
             break;
         }
@@ -290,8 +263,8 @@ int ossl_uint_set_remove(UINT_SET *s, const UINT_RANGE *range)
         return 0;
 
     /* Walk backwards since we will most often be removing at the end. */
-    for (z = s->tail; z != NULL; z = zprev) {
-        zprev = z->prev;
+    for (z = ossl_list_uint_set_tail(s); z != NULL; z = zprev) {
+        zprev = ossl_list_uint_set_prev(z);
 
         if (start > z->range.end)
             /* No overlapping ranges can exist beyond this point, so stop. */
@@ -302,17 +275,8 @@ int ossl_uint_set_remove(UINT_SET *s, const UINT_RANGE *range)
              * The range being removed dwarfs this range, so it should be
              * removed.
              */
-            if (z->next != NULL)
-                z->next->prev = z->prev;
-            if (z->prev != NULL)
-                z->prev->next = z->next;
-            if (s->head == z)
-                s->head = z->next;
-            if (s->tail == z)
-                s->tail = z->prev;
-
+            ossl_list_uint_set_remove(s, z);
             OPENSSL_free(z);
-            --s->num_ranges;
         } else if (start <= z->range.start) {
             /*
              * The range being removed includes start of this range, but does
@@ -337,24 +301,8 @@ int ossl_uint_set_remove(UINT_SET *s, const UINT_RANGE *range)
              * into two. Cases where a zero-length range would be created are
              * handled by the above cases.
              */
-            y = OPENSSL_zalloc(sizeof(UINT_SET_ITEM));
-            if (y == NULL)
-                return 0;
-
-            y->range.end   = z->range.end;
-            y->range.start = end + 1;
-            y->next = z->next;
-            y->prev = z;
-            if (y->next != NULL)
-                y->next->prev = y;
-
-            z->range.end = start - 1;
-            z->next = y;
-
-            if (s->tail == z)
-                s->tail = y;
-
-            ++s->num_ranges;
+            y = create_set_item(end + 1, z->range.end);
+            ossl_list_uint_set_insert_after(s, z, y);
             break;
         } else {
             /* Assert no partial overlap; all cases should be covered above. */
@@ -369,10 +317,10 @@ int ossl_uint_set_query(const UINT_SET *s, uint64_t v)
 {
     UINT_SET_ITEM *x;
 
-    if (s->head == NULL)
+    if (ossl_list_uint_set_is_empty(s))
         return 0;
 
-    for (x = s->tail; x != NULL; x = x->prev)
+    for (x = ossl_list_uint_set_tail(s); x != NULL; x = ossl_list_uint_set_prev(x))
         if (x->range.start <= v && x->range.end >= v)
             return 1;
         else if (x->range.end < v)