Add a maximum output length to update and final calls
[openssl.git] / providers / common / ciphers / aes.c
1 /*
2  * Copyright 2019 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <string.h>
11 #include <openssl/crypto.h>
12 #include <openssl/core_numbers.h>
13 #include <openssl/core_names.h>
14 #include <openssl/evp.h>
15 #include <openssl/params.h>
16 #include "internal/cryptlib.h"
17 #include "internal/provider_algs.h"
18 #include "ciphers_locl.h"
19
20 static int PROV_AES_KEY_generic_init(PROV_AES_KEY *ctx,
21                                       const unsigned char *iv,
22                                       size_t ivlen,
23                                       int enc)
24 {
25     if (iv != NULL && ctx->mode != EVP_CIPH_ECB_MODE) {
26         if (ivlen != AES_BLOCK_SIZE)
27             return 0;
28         memcpy(ctx->iv, iv, AES_BLOCK_SIZE);
29     }
30     ctx->enc = enc;
31
32     return 1;
33 }
34
35 static int aes_einit(void *vctx, const unsigned char *key, size_t keylen,
36                            const unsigned char *iv, size_t ivlen)
37 {
38     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
39
40     if (!PROV_AES_KEY_generic_init(ctx, iv, ivlen, 1))
41         return 0;
42     if (key != NULL) {
43         if (keylen != ctx->keylen)
44             return 0;
45         return ctx->ciph->init(ctx, key, ctx->keylen);
46     }
47
48     return 1;
49 }
50
51 static int aes_dinit(void *vctx, const unsigned char *key, size_t keylen,
52                      const unsigned char *iv, size_t ivlen)
53 {
54     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
55
56     if (!PROV_AES_KEY_generic_init(ctx, iv, ivlen, 0))
57         return 0;
58     if (key != NULL) {
59         if (keylen != ctx->keylen)
60             return 0;
61         return ctx->ciph->init(ctx, key, ctx->keylen);
62     }
63
64     return 1;
65 }
66
67 static int aes_block_update(void *vctx, unsigned char *out, size_t *outl,
68                             size_t outsize, const unsigned char *in, size_t inl)
69 {
70     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
71     size_t nextblocks = fillblock(ctx->buf, &ctx->bufsz, AES_BLOCK_SIZE, &in,
72                                   &inl);
73     size_t outlint = 0;
74
75     /*
76      * If we're decrypting and we end an update on a block boundary we hold
77      * the last block back in case this is the last update call and the last
78      * block is padded.
79      */
80     if (ctx->bufsz == AES_BLOCK_SIZE
81             && (ctx->enc || inl > 0 || !ctx->pad)) {
82         if (outsize < AES_BLOCK_SIZE)
83             return 0;
84         if (!ctx->ciph->cipher(ctx, out, ctx->buf, AES_BLOCK_SIZE))
85             return 0;
86         ctx->bufsz = 0;
87         outlint = AES_BLOCK_SIZE;
88         out += AES_BLOCK_SIZE;
89     }
90     if (nextblocks > 0) {
91         if (!ctx->enc && ctx->pad && nextblocks == inl) {
92             if (!ossl_assert(inl >= AES_BLOCK_SIZE))
93                 return 0;
94             nextblocks -= AES_BLOCK_SIZE;
95         }
96         outlint += nextblocks;
97         if (outsize < outlint)
98             return 0;
99         if (!ctx->ciph->cipher(ctx, out, in, nextblocks))
100             return 0;
101         in += nextblocks;
102         inl -= nextblocks;
103     }
104     if (!trailingdata(ctx->buf, &ctx->bufsz, AES_BLOCK_SIZE, &in, &inl))
105         return 0;
106
107     *outl = outlint;
108     return inl == 0;
109 }
110
111 static int aes_block_final(void *vctx, unsigned char *out, size_t *outl,
112                            size_t outsize)
113 {
114     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
115
116     if (ctx->enc) {
117         if (ctx->pad) {
118             padblock(ctx->buf, &ctx->bufsz, AES_BLOCK_SIZE);
119         } else if (ctx->bufsz == 0) {
120             *outl = 0;
121             return 1;
122         } else if (ctx->bufsz != AES_BLOCK_SIZE) {
123             /* TODO(3.0): What is the correct error code here? */
124             return 0;
125         }
126
127         if (outsize < AES_BLOCK_SIZE)
128             return 0;
129         if (!ctx->ciph->cipher(ctx, out, ctx->buf, AES_BLOCK_SIZE))
130             return 0;
131         ctx->bufsz = 0;
132         *outl = AES_BLOCK_SIZE;
133         return 1;
134     }
135
136     /* Decrypting */
137     /* TODO(3.0): What's the correct error here */
138     if (ctx->bufsz != AES_BLOCK_SIZE) {
139         if (ctx->bufsz == 0 && !ctx->pad) {
140             *outl = 0;
141             return 1;
142         }
143         return 0;
144     }
145
146     if (!ctx->ciph->cipher(ctx, ctx->buf, ctx->buf, AES_BLOCK_SIZE))
147         return 0;
148
149     /* TODO(3.0): What is the correct error here */
150     if (ctx->pad && !unpadblock(ctx->buf, &ctx->bufsz, AES_BLOCK_SIZE))
151         return 0;
152
153     if (outsize < ctx->bufsz)
154         return 0;
155     memcpy(out, ctx->buf, ctx->bufsz);
156     *outl = ctx->bufsz;
157     ctx->bufsz = 0;
158     return 1;
159 }
160
161 static int aes_stream_update(void *vctx, unsigned char *out, size_t *outl,
162                              size_t outsize, const unsigned char *in,
163                              size_t inl)
164 {
165     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
166
167     if (outsize < inl)
168         return 0;
169
170     if (!ctx->ciph->cipher(ctx, out, in, inl))
171         return 0;
172
173     *outl = inl;
174     return 1;
175 }
176 static int aes_stream_final(void *vctx, unsigned char *out, size_t *outl,
177                             size_t outsize)
178 {
179     *outl = 0;
180     return 1;
181 }
182
183 static int aes_cipher(void *vctx, unsigned char *out, const unsigned char *in,
184                       size_t inl)
185 {
186     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
187
188     if (!ctx->ciph->cipher(ctx, out, in, inl))
189         return 0;
190
191     return 1;
192 }
193
194 #define IMPLEMENT_new_params(lcmode, UCMODE) \
195     static int aes_##lcmode##_get_params(const OSSL_PARAM params[]) \
196     { \
197         const OSSL_PARAM *p; \
198     \
199         p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_MODE); \
200         if (p != NULL && !OSSL_PARAM_set_int(p, EVP_CIPH_##UCMODE##_MODE)) \
201             return 0; \
202     \
203         return 1; \
204     }
205
206 #define IMPLEMENT_new_ctx(lcmode, UCMODE, len) \
207     static void *aes_##len##_##lcmode##_newctx(void) \
208     { \
209         PROV_AES_KEY *ctx = OPENSSL_zalloc(sizeof(*ctx)); \
210     \
211         ctx->pad = 1; \
212         ctx->keylen = (len / 8); \
213         ctx->ciph = PROV_AES_CIPHER_##lcmode(); \
214         ctx->mode = EVP_CIPH_##UCMODE##_MODE; \
215         return ctx; \
216     }
217
218 /* ECB */
219 IMPLEMENT_new_params(ecb, ECB)
220 IMPLEMENT_new_ctx(ecb, ECB, 256)
221 IMPLEMENT_new_ctx(ecb, ECB, 192)
222 IMPLEMENT_new_ctx(ecb, ECB, 128)
223
224 /* CBC */
225 IMPLEMENT_new_params(cbc, CBC)
226 IMPLEMENT_new_ctx(cbc, CBC, 256)
227 IMPLEMENT_new_ctx(cbc, CBC, 192)
228 IMPLEMENT_new_ctx(cbc, CBC, 128)
229
230 /* OFB */
231 IMPLEMENT_new_params(ofb, OFB)
232 IMPLEMENT_new_ctx(ofb, OFB, 256)
233 IMPLEMENT_new_ctx(ofb, OFB, 192)
234 IMPLEMENT_new_ctx(ofb, OFB, 128)
235
236 /* CFB */
237 IMPLEMENT_new_params(cfb, CFB)
238 IMPLEMENT_new_params(cfb1, CFB)
239 IMPLEMENT_new_params(cfb8, CFB)
240 IMPLEMENT_new_ctx(cfb, CFB, 256)
241 IMPLEMENT_new_ctx(cfb, CFB, 192)
242 IMPLEMENT_new_ctx(cfb, CFB, 128)
243 IMPLEMENT_new_ctx(cfb1, CFB, 256)
244 IMPLEMENT_new_ctx(cfb1, CFB, 192)
245 IMPLEMENT_new_ctx(cfb1, CFB, 128)
246 IMPLEMENT_new_ctx(cfb8, CFB, 256)
247 IMPLEMENT_new_ctx(cfb8, CFB, 192)
248 IMPLEMENT_new_ctx(cfb8, CFB, 128)
249
250 /* CTR */
251 IMPLEMENT_new_params(ctr, CTR)
252 IMPLEMENT_new_ctx(ctr, CTR, 256)
253 IMPLEMENT_new_ctx(ctr, CTR, 192)
254 IMPLEMENT_new_ctx(ctr, CTR, 128)
255
256 static void aes_freectx(void *vctx)
257 {
258     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
259
260     OPENSSL_clear_free(ctx,  sizeof(*ctx));
261 }
262
263 static void *aes_dupctx(void *ctx)
264 {
265     PROV_AES_KEY *in = (PROV_AES_KEY *)ctx;
266     PROV_AES_KEY *ret = OPENSSL_malloc(sizeof(*ret));
267
268     *ret = *in;
269
270     return ret;
271 }
272
273 static size_t key_length_256(void)
274 {
275     return 256 / 8;
276 }
277
278 static size_t key_length_192(void)
279 {
280     return 192 / 8;
281 }
282
283 static size_t key_length_128(void)
284 {
285     return 128 / 8;
286 }
287
288 static size_t iv_length_16(void)
289 {
290     return 16;
291 }
292
293 static size_t iv_length_0(void)
294 {
295     return 0;
296 }
297
298 static size_t block_size_16(void)
299 {
300     return 16;
301 }
302
303 static size_t block_size_1(void)
304 {
305     return 1;
306 }
307
308 static int aes_ctx_get_params(void *vctx, const OSSL_PARAM params[])
309 {
310     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
311     const OSSL_PARAM *p;
312
313     p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_PADDING);
314     if (p != NULL && !OSSL_PARAM_set_uint(p, ctx->pad))
315         return 0;
316
317     return 1;
318 }
319
320 static int aes_ctx_set_params(void *vctx, const OSSL_PARAM params[])
321 {
322     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
323     const OSSL_PARAM *p;
324
325     p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_PADDING);
326     if (p != NULL) {
327         int pad;
328
329         if (!OSSL_PARAM_get_int(p, &pad))
330             return 0;
331         ctx->pad = pad ? 1 : 0;
332     }
333     return 1;
334 }
335
336 #define IMPLEMENT_block_funcs(mode, keylen, ivlen) \
337     const OSSL_DISPATCH aes##keylen##mode##_functions[] = { \
338         { OSSL_FUNC_CIPHER_NEWCTX, (void (*)(void))aes_##keylen##_##mode##_newctx }, \
339         { OSSL_FUNC_CIPHER_ENCRYPT_INIT, (void (*)(void))aes_einit }, \
340         { OSSL_FUNC_CIPHER_DECRYPT_INIT, (void (*)(void))aes_dinit }, \
341         { OSSL_FUNC_CIPHER_UPDATE, (void (*)(void))aes_block_update }, \
342         { OSSL_FUNC_CIPHER_FINAL, (void (*)(void))aes_block_final }, \
343         { OSSL_FUNC_CIPHER_CIPHER, (void (*)(void))aes_cipher }, \
344         { OSSL_FUNC_CIPHER_FREECTX, (void (*)(void))aes_freectx }, \
345         { OSSL_FUNC_CIPHER_DUPCTX, (void (*)(void))aes_dupctx }, \
346         { OSSL_FUNC_CIPHER_KEY_LENGTH, (void (*)(void))key_length_##keylen }, \
347         { OSSL_FUNC_CIPHER_IV_LENGTH, (void (*)(void))iv_length_##ivlen }, \
348         { OSSL_FUNC_CIPHER_BLOCK_SIZE, (void (*)(void))block_size_16 }, \
349         { OSSL_FUNC_CIPHER_GET_PARAMS, (void (*)(void))aes_##mode##_get_params }, \
350         { OSSL_FUNC_CIPHER_CTX_GET_PARAMS, (void (*)(void))aes_ctx_get_params }, \
351         { OSSL_FUNC_CIPHER_CTX_SET_PARAMS, (void (*)(void))aes_ctx_set_params }, \
352         { 0, NULL } \
353     };
354
355 #define IMPLEMENT_stream_funcs(mode, keylen, ivlen) \
356     const OSSL_DISPATCH aes##keylen##mode##_functions[] = { \
357         { OSSL_FUNC_CIPHER_NEWCTX, (void (*)(void))aes_##keylen##_##mode##_newctx }, \
358         { OSSL_FUNC_CIPHER_ENCRYPT_INIT, (void (*)(void))aes_einit }, \
359         { OSSL_FUNC_CIPHER_DECRYPT_INIT, (void (*)(void))aes_dinit }, \
360         { OSSL_FUNC_CIPHER_UPDATE, (void (*)(void))aes_stream_update }, \
361         { OSSL_FUNC_CIPHER_FINAL, (void (*)(void))aes_stream_final }, \
362         { OSSL_FUNC_CIPHER_CIPHER, (void (*)(void))aes_cipher }, \
363         { OSSL_FUNC_CIPHER_FREECTX, (void (*)(void))aes_freectx }, \
364         { OSSL_FUNC_CIPHER_DUPCTX, (void (*)(void))aes_dupctx }, \
365         { OSSL_FUNC_CIPHER_KEY_LENGTH, (void (*)(void))key_length_##keylen }, \
366         { OSSL_FUNC_CIPHER_IV_LENGTH, (void (*)(void))iv_length_##ivlen }, \
367         { OSSL_FUNC_CIPHER_BLOCK_SIZE, (void (*)(void))block_size_1 }, \
368         { OSSL_FUNC_CIPHER_GET_PARAMS, (void (*)(void))aes_##mode##_get_params }, \
369         { OSSL_FUNC_CIPHER_CTX_GET_PARAMS, (void (*)(void))aes_ctx_get_params }, \
370         { OSSL_FUNC_CIPHER_CTX_SET_PARAMS, (void (*)(void))aes_ctx_set_params }, \
371         { 0, NULL } \
372     };
373
374 /* ECB */
375 IMPLEMENT_block_funcs(ecb, 256, 0)
376 IMPLEMENT_block_funcs(ecb, 192, 0)
377 IMPLEMENT_block_funcs(ecb, 128, 0)
378
379 /* CBC */
380 IMPLEMENT_block_funcs(cbc, 256, 16)
381 IMPLEMENT_block_funcs(cbc, 192, 16)
382 IMPLEMENT_block_funcs(cbc, 128, 16)
383
384 /* OFB */
385 IMPLEMENT_stream_funcs(ofb, 256, 16)
386 IMPLEMENT_stream_funcs(ofb, 192, 16)
387 IMPLEMENT_stream_funcs(ofb, 128, 16)
388
389 /* CFB */
390 IMPLEMENT_stream_funcs(cfb, 256, 16)
391 IMPLEMENT_stream_funcs(cfb, 192, 16)
392 IMPLEMENT_stream_funcs(cfb, 128, 16)
393 IMPLEMENT_stream_funcs(cfb1, 256, 16)
394 IMPLEMENT_stream_funcs(cfb1, 192, 16)
395 IMPLEMENT_stream_funcs(cfb1, 128, 16)
396 IMPLEMENT_stream_funcs(cfb8, 256, 16)
397 IMPLEMENT_stream_funcs(cfb8, 192, 16)
398 IMPLEMENT_stream_funcs(cfb8, 128, 16)
399
400 /* CTR */
401 IMPLEMENT_stream_funcs(ctr, 256, 16)
402 IMPLEMENT_stream_funcs(ctr, 192, 16)
403 IMPLEMENT_stream_funcs(ctr, 128, 16)