Simplify the overflow checks in WPACKET_allocate_bytes()
[openssl.git] / ssl / packet.c
1 /*
2  * Copyright 2015-2016 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include "packet_locl.h"
11
12 #define DEFAULT_BUF_SIZE    256
13
14 int WPACKET_allocate_bytes(WPACKET *pkt, size_t len, unsigned char **allocbytes)
15 {
16     if (pkt->subs == NULL || len == 0)
17         return 0;
18
19     if (pkt->maxsize - pkt->written < len)
20         return 0;
21
22     if (pkt->buf->length - pkt->written < len) {
23         size_t newlen;
24
25         if (pkt->buf->length > SIZE_MAX / 2) {
26             newlen = SIZE_MAX;
27         } else {
28             if (pkt->buf->length == 0)
29                 newlen = DEFAULT_BUF_SIZE;
30             else
31                 newlen = pkt->buf->length * 2;
32         }
33         if (BUF_MEM_grow(pkt->buf, newlen) == 0)
34             return 0;
35         if (pkt->curr == NULL) {
36             /*
37              * Can happen if initialised with a BUF_MEM that hasn't been
38              * pre-allocated.
39              */
40             pkt->curr = (unsigned char *)pkt->buf->data;
41         }
42     }
43     pkt->written += len;
44     *allocbytes = pkt->curr;
45     pkt->curr += len;
46
47     return 1;
48 }
49
50 static size_t maxmaxsize(size_t lenbytes)
51 {
52     if (lenbytes >= sizeof(size_t) || lenbytes == 0)
53         return SIZE_MAX;
54     else
55         return ((size_t)1 << (lenbytes * 8)) - 1 + lenbytes;
56 }
57
58 int WPACKET_init_len(WPACKET *pkt, BUF_MEM *buf, size_t lenbytes)
59 {
60     /* Sanity check */
61     if (buf == NULL)
62         return 0;
63
64     pkt->buf = buf;
65     pkt->curr = (unsigned char *)buf->data;
66     pkt->written = 0;
67     pkt->maxsize = maxmaxsize(lenbytes);
68
69     pkt->subs = OPENSSL_zalloc(sizeof(*pkt->subs));
70     if (pkt->subs == NULL)
71         return 0;
72
73     if (lenbytes == 0)
74         return 1;
75
76     pkt->subs->pwritten = lenbytes;
77     pkt->subs->lenbytes = lenbytes;
78
79     if (!WPACKET_allocate_bytes(pkt, lenbytes, &(pkt->subs->packet_len))) {
80         OPENSSL_free(pkt->subs);
81         pkt->subs = NULL;
82         return 0;
83     }
84
85     return 1;
86 }
87
88 int WPACKET_init(WPACKET *pkt, BUF_MEM *buf)
89 {
90     return WPACKET_init_len(pkt, buf, 0);
91 }
92
93 int WPACKET_set_packet_len(WPACKET *pkt, unsigned char *packet_len,
94                            size_t lenbytes)
95 {
96     size_t maxmax;
97
98     /* We only allow this to be set once */
99     if (pkt->subs == NULL || pkt->subs->lenbytes != 0)
100         return 0;
101
102     pkt->subs->lenbytes = lenbytes;
103     pkt->subs->packet_len = packet_len;
104
105     maxmax = maxmaxsize(lenbytes);
106     if (pkt->maxsize > maxmax)
107         pkt->maxsize = maxmax;
108
109     return 1;
110 }
111
112 int WPACKET_set_flags(WPACKET *pkt, unsigned int flags)
113 {
114     if (pkt->subs == NULL)
115         return 0;
116
117     pkt->subs->flags = flags;
118
119     return 1;
120 }
121
122
123 /*
124  * Internal helper function used by WPACKET_close() and WPACKET_finish() to
125  * close a sub-packet and write out its length if necessary.
126  */
127 static int wpacket_intern_close(WPACKET *pkt)
128 {
129     size_t packlen;
130     WPACKET_SUB *sub = pkt->subs;
131
132     packlen = pkt->written - sub->pwritten;
133     if (packlen == 0
134             && sub->flags & OPENSSL_WPACKET_FLAGS_NON_ZERO_LENGTH)
135         return 0;
136
137     if (packlen == 0
138             && sub->flags & OPENSSL_WPACKET_FLAGS_ABANDON_ON_ZERO_LENGTH) {
139         /* Deallocate any bytes allocated for the length of the WPACKET */
140         if ((pkt->curr - sub->lenbytes) == sub->packet_len) {
141             pkt->written -= sub->lenbytes;
142             pkt->curr -= sub->lenbytes;
143         }
144
145         /* Don't write out the packet length */
146         sub->packet_len = NULL;
147     }
148
149     /* Write out the WPACKET length if needed */
150     if (sub->packet_len != NULL) {
151         size_t lenbytes;
152
153         lenbytes = sub->lenbytes;
154
155         for (; lenbytes > 0; lenbytes--) {
156             sub->packet_len[lenbytes - 1]
157                 = (unsigned char)(packlen & 0xff);
158             packlen >>= 8;
159         }
160         if (packlen > 0) {
161             /*
162              * We've extended beyond the max allowed for the number of len bytes
163              */
164             return 0;
165         }
166     }
167
168     pkt->subs = sub->parent;
169     OPENSSL_free(sub);
170
171     return 1;
172 }
173
174 int WPACKET_close(WPACKET *pkt)
175 {
176     if (pkt->subs == NULL || pkt->subs->parent == NULL)
177         return 0;
178
179     return wpacket_intern_close(pkt);
180 }
181
182 int WPACKET_finish(WPACKET *pkt)
183 {
184     int ret;
185
186     if (pkt->subs == NULL || pkt->subs->parent != NULL)
187         return 0;
188
189     ret = wpacket_intern_close(pkt);
190
191     if (ret) {
192         OPENSSL_free(pkt->subs);
193         pkt->subs = NULL;
194     }
195     return ret;
196 }
197
198 int WPACKET_start_sub_packet_len(WPACKET *pkt, size_t lenbytes)
199 {
200     WPACKET_SUB *sub;
201
202     if (pkt->subs == NULL)
203         return 0;
204
205     sub = OPENSSL_zalloc(sizeof(*sub));
206     if (sub == NULL)
207         return 0;
208
209     sub->parent = pkt->subs;
210     pkt->subs = sub;
211     sub->pwritten = pkt->written + lenbytes;
212     sub->lenbytes = lenbytes;
213
214     if (lenbytes == 0) {
215         sub->packet_len = NULL;
216         return 1;
217     }
218
219     if (!WPACKET_allocate_bytes(pkt, lenbytes, &sub->packet_len)) {
220         return 0;
221     }
222
223     return 1;
224 }
225
226 int WPACKET_start_sub_packet(WPACKET *pkt)
227 {
228     return WPACKET_start_sub_packet_len(pkt, 0);
229 }
230
231 int WPACKET_put_bytes(WPACKET *pkt, unsigned int val, size_t bytes)
232 {
233     unsigned char *data;
234
235     if (bytes > sizeof(unsigned int)
236             || !WPACKET_allocate_bytes(pkt, bytes, &data))
237         return 0;
238
239     data += bytes - 1;
240     for (; bytes > 0; bytes--) {
241         *data = (unsigned char)(val & 0xff);
242         data--;
243         val >>= 8;
244     }
245
246     /* Check whether we could fit the value in the assigned number of bytes */
247     if (val > 0)
248         return 0;
249
250     return 1;
251 }
252
253 int WPACKET_set_max_size(WPACKET *pkt, size_t maxsize)
254 {
255     WPACKET_SUB *sub;
256     size_t lenbytes;
257
258     if (pkt->subs == NULL)
259         return 0;
260
261     /* Find the WPACKET_SUB for the top level */
262     for (sub = pkt->subs; sub->parent != NULL; sub = sub->parent);
263
264     lenbytes = sub->lenbytes;
265     if (lenbytes == 0)
266         lenbytes = sizeof(pkt->maxsize);
267
268     if (maxmaxsize(lenbytes) < maxsize || maxsize < pkt->written)
269         return 0;
270
271     pkt->maxsize = maxsize;
272
273     return 1;
274 }
275
276 int WPACKET_memcpy(WPACKET *pkt, const void *src, size_t len)
277 {
278     unsigned char *dest;
279
280     if (len == 0)
281         return 1;
282
283     if (!WPACKET_allocate_bytes(pkt, len, &dest))
284         return 0;
285
286     memcpy(dest, src, len);
287
288     return 1;
289 }
290
291 int WPACKET_sub_memcpy(WPACKET *pkt, const void *src, size_t len, size_t lenbytes)
292 {
293     if (!WPACKET_start_sub_packet_len(pkt, lenbytes)
294             || !WPACKET_memcpy(pkt, src, len)
295             || !WPACKET_close(pkt))
296         return 0;
297
298     return 1;
299 }
300
301 int WPACKET_get_total_written(WPACKET *pkt, size_t *written)
302 {
303     if (written == NULL)
304         return 0;
305
306     *written = pkt->written;
307
308     return 1;
309 }
310
311 int WPACKET_get_length(WPACKET *pkt, size_t *len)
312 {
313     if (pkt->subs == NULL || len == NULL)
314         return 0;
315
316     *len = pkt->written - pkt->subs->pwritten;
317
318     return 1;
319 }
320
321 void WPACKET_cleanup(WPACKET *pkt)
322 {
323     WPACKET_SUB *sub, *parent;
324
325     for (sub = pkt->subs; sub != NULL; sub = parent) {
326         parent = sub->parent;
327         OPENSSL_free(sub);
328     }
329     pkt->subs = NULL;
330 }