Adapt the provider AES for more use of OSSL_PARAM
[openssl.git] / providers / common / ciphers / aes.c
1 /*
2  * Copyright 2019 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <string.h>
11 #include <openssl/crypto.h>
12 #include <openssl/core_numbers.h>
13 #include <openssl/core_names.h>
14 #include <openssl/evp.h>
15 #include <openssl/params.h>
16 #include "internal/cryptlib.h"
17 #include "internal/provider_algs.h"
18 #include "ciphers_locl.h"
19 #include "internal/providercommonerr.h"
20
21 static OSSL_OP_cipher_encrypt_init_fn aes_einit;
22 static OSSL_OP_cipher_decrypt_init_fn aes_dinit;
23 static OSSL_OP_cipher_update_fn aes_block_update;
24 static OSSL_OP_cipher_final_fn aes_block_final;
25 static OSSL_OP_cipher_update_fn aes_stream_update;
26 static OSSL_OP_cipher_final_fn aes_stream_final;
27 static OSSL_OP_cipher_cipher_fn aes_cipher;
28 static OSSL_OP_cipher_freectx_fn aes_freectx;
29 static OSSL_OP_cipher_dupctx_fn aes_dupctx;
30 static OSSL_OP_cipher_ctx_get_params_fn aes_ctx_get_params;
31 static OSSL_OP_cipher_ctx_set_params_fn aes_ctx_set_params;
32
33 static int PROV_AES_KEY_generic_init(PROV_AES_KEY *ctx,
34                                       const unsigned char *iv,
35                                       size_t ivlen,
36                                       int enc)
37 {
38     if (iv != NULL && ctx->mode != EVP_CIPH_ECB_MODE) {
39         if (ivlen != AES_BLOCK_SIZE) {
40             PROVerr(PROV_F_PROV_AES_KEY_GENERIC_INIT, ERR_R_INTERNAL_ERROR);
41             return 0;
42         }
43         memcpy(ctx->iv, iv, AES_BLOCK_SIZE);
44     }
45     ctx->enc = enc;
46
47     return 1;
48 }
49
50 static int aes_einit(void *vctx, const unsigned char *key, size_t keylen,
51                            const unsigned char *iv, size_t ivlen)
52 {
53     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
54
55     if (!PROV_AES_KEY_generic_init(ctx, iv, ivlen, 1)) {
56         /* PROVerr already called */
57         return 0;
58     }
59     if (key != NULL) {
60         if (keylen != ctx->keylen) {
61             PROVerr(PROV_F_AES_EINIT, PROV_R_INVALID_KEYLEN);
62             return 0;
63         }
64         return ctx->ciph->init(ctx, key, ctx->keylen);
65     }
66
67     return 1;
68 }
69
70 static int aes_dinit(void *vctx, const unsigned char *key, size_t keylen,
71                      const unsigned char *iv, size_t ivlen)
72 {
73     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
74
75     if (!PROV_AES_KEY_generic_init(ctx, iv, ivlen, 0)) {
76         /* PROVerr already called */
77         return 0;
78     }
79     if (key != NULL) {
80         if (keylen != ctx->keylen) {
81             PROVerr(PROV_F_AES_DINIT, PROV_R_INVALID_KEYLEN);
82             return 0;
83         }
84         return ctx->ciph->init(ctx, key, ctx->keylen);
85     }
86
87     return 1;
88 }
89
90 static int aes_block_update(void *vctx, unsigned char *out, size_t *outl,
91                             size_t outsize, const unsigned char *in, size_t inl)
92 {
93     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
94     size_t nextblocks = fillblock(ctx->buf, &ctx->bufsz, AES_BLOCK_SIZE, &in,
95                                   &inl);
96     size_t outlint = 0;
97
98     /*
99      * If we're decrypting and we end an update on a block boundary we hold
100      * the last block back in case this is the last update call and the last
101      * block is padded.
102      */
103     if (ctx->bufsz == AES_BLOCK_SIZE
104             && (ctx->enc || inl > 0 || !ctx->pad)) {
105         if (outsize < AES_BLOCK_SIZE) {
106             PROVerr(PROV_F_AES_BLOCK_UPDATE, PROV_R_OUTPUT_BUFFER_TOO_SMALL);
107             return 0;
108         }
109         if (!ctx->ciph->cipher(ctx, out, ctx->buf, AES_BLOCK_SIZE)) {
110             PROVerr(PROV_F_AES_BLOCK_UPDATE, PROV_R_CIPHER_OPERATION_FAILED);
111             return 0;
112         }
113         ctx->bufsz = 0;
114         outlint = AES_BLOCK_SIZE;
115         out += AES_BLOCK_SIZE;
116     }
117     if (nextblocks > 0) {
118         if (!ctx->enc && ctx->pad && nextblocks == inl) {
119             if (!ossl_assert(inl >= AES_BLOCK_SIZE)) {
120                 PROVerr(PROV_F_AES_BLOCK_UPDATE, PROV_R_OUTPUT_BUFFER_TOO_SMALL);
121                 return 0;
122             }
123             nextblocks -= AES_BLOCK_SIZE;
124         }
125         outlint += nextblocks;
126         if (outsize < outlint) {
127             PROVerr(PROV_F_AES_BLOCK_UPDATE, PROV_R_OUTPUT_BUFFER_TOO_SMALL);
128             return 0;
129         }
130         if (!ctx->ciph->cipher(ctx, out, in, nextblocks)) {
131             PROVerr(PROV_F_AES_BLOCK_UPDATE, PROV_R_CIPHER_OPERATION_FAILED);
132             return 0;
133         }
134         in += nextblocks;
135         inl -= nextblocks;
136     }
137     if (!trailingdata(ctx->buf, &ctx->bufsz, AES_BLOCK_SIZE, &in, &inl)) {
138         /* PROVerr already called */
139         return 0;
140     }
141
142     *outl = outlint;
143     return inl == 0;
144 }
145
146 static int aes_block_final(void *vctx, unsigned char *out, size_t *outl,
147                            size_t outsize)
148 {
149     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
150
151     if (ctx->enc) {
152         if (ctx->pad) {
153             padblock(ctx->buf, &ctx->bufsz, AES_BLOCK_SIZE);
154         } else if (ctx->bufsz == 0) {
155             *outl = 0;
156             return 1;
157         } else if (ctx->bufsz != AES_BLOCK_SIZE) {
158             PROVerr(PROV_F_AES_BLOCK_FINAL, PROV_R_WRONG_FINAL_BLOCK_LENGTH);
159             return 0;
160         }
161
162         if (outsize < AES_BLOCK_SIZE) {
163             PROVerr(PROV_F_AES_BLOCK_FINAL, PROV_R_OUTPUT_BUFFER_TOO_SMALL);
164             return 0;
165         }
166         if (!ctx->ciph->cipher(ctx, out, ctx->buf, AES_BLOCK_SIZE)) {
167             PROVerr(PROV_F_AES_BLOCK_FINAL, PROV_R_CIPHER_OPERATION_FAILED);
168             return 0;
169         }
170         ctx->bufsz = 0;
171         *outl = AES_BLOCK_SIZE;
172         return 1;
173     }
174
175     /* Decrypting */
176     if (ctx->bufsz != AES_BLOCK_SIZE) {
177         if (ctx->bufsz == 0 && !ctx->pad) {
178             *outl = 0;
179             return 1;
180         }
181         PROVerr(PROV_F_AES_BLOCK_FINAL, PROV_R_WRONG_FINAL_BLOCK_LENGTH);
182         return 0;
183     }
184
185     if (!ctx->ciph->cipher(ctx, ctx->buf, ctx->buf, AES_BLOCK_SIZE)) {
186         PROVerr(PROV_F_AES_BLOCK_FINAL, PROV_R_CIPHER_OPERATION_FAILED);
187         return 0;
188     }
189
190     if (ctx->pad && !unpadblock(ctx->buf, &ctx->bufsz, AES_BLOCK_SIZE)) {
191         /* PROVerr already called */
192         return 0;
193     }
194
195     if (outsize < ctx->bufsz) {
196         PROVerr(PROV_F_AES_BLOCK_FINAL, PROV_R_OUTPUT_BUFFER_TOO_SMALL);
197         return 0;
198     }
199     memcpy(out, ctx->buf, ctx->bufsz);
200     *outl = ctx->bufsz;
201     ctx->bufsz = 0;
202     return 1;
203 }
204
205 static int aes_stream_update(void *vctx, unsigned char *out, size_t *outl,
206                              size_t outsize, const unsigned char *in,
207                              size_t inl)
208 {
209     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
210
211     if (outsize < inl) {
212         PROVerr(PROV_F_AES_STREAM_UPDATE, PROV_R_OUTPUT_BUFFER_TOO_SMALL);
213         return 0;
214     }
215
216     if (!ctx->ciph->cipher(ctx, out, in, inl)) {
217         PROVerr(PROV_F_AES_STREAM_UPDATE, PROV_R_CIPHER_OPERATION_FAILED);
218         return 0;
219     }
220
221     *outl = inl;
222     return 1;
223 }
224 static int aes_stream_final(void *vctx, unsigned char *out, size_t *outl,
225                             size_t outsize)
226 {
227     *outl = 0;
228     return 1;
229 }
230
231 static int aes_cipher(void *vctx,
232                       unsigned char *out, size_t *outl, size_t outsize,
233                       const unsigned char *in, size_t inl)
234 {
235     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
236
237     if (outsize < inl) {
238         PROVerr(PROV_F_AES_CIPHER, PROV_R_OUTPUT_BUFFER_TOO_SMALL);
239         return 0;
240     }
241
242     if (!ctx->ciph->cipher(ctx, out, in, inl)) {
243         PROVerr(PROV_F_AES_CIPHER, PROV_R_CIPHER_OPERATION_FAILED);
244         return 0;
245     }
246
247     *outl = inl;
248     return 1;
249 }
250
251 #define IMPLEMENT_cipher(lcmode, UCMODE, flags, kbits, blkbits, ivbits) \
252     static OSSL_OP_cipher_get_params_fn aes_##kbits##_##lcmode##_get_params; \
253     static int aes_##kbits##_##lcmode##_get_params(OSSL_PARAM params[]) \
254     { \
255         OSSL_PARAM *p; \
256                                                                 \
257         p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_MODE);          \
258         if (p != NULL) {                                                \
259             if (!OSSL_PARAM_set_int(p, EVP_CIPH_##UCMODE##_MODE))           \
260                 return 0;                                               \
261         }                                                           \
262         p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_FLAGS); \
263         if (p != NULL) {                                                \
264             if (!OSSL_PARAM_set_ulong(p, (flags)))                          \
265                 return 0;                                               \
266         }                                                           \
267         p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_KEYLEN);        \
268         if (p != NULL) {                                                \
269             if (!OSSL_PARAM_set_int(p, (kbits) / 8))                         \
270                 return 0;                                               \
271         }                                                           \
272         p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_BLOCK_SIZE);    \
273         if (p != NULL) {                                                \
274             if (!OSSL_PARAM_set_int(p, (blkbits) / 8))                   \
275                 return 0;                                               \
276         }                                                               \
277         p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_IVLEN);         \
278         if (p != NULL) {                                                \
279             if (!OSSL_PARAM_set_int(p, (ivbits) / 8))                    \
280                 return 0;                                               \
281         }                                                               \
282     \
283         return 1; \
284     } \
285     static OSSL_OP_cipher_newctx_fn aes_##kbits##_##lcmode##_newctx; \
286     static void *aes_##kbits##_##lcmode##_newctx(void *provctx) \
287     { \
288         PROV_AES_KEY *ctx = OPENSSL_zalloc(sizeof(*ctx)); \
289     \
290         ctx->pad = 1; \
291         ctx->keylen = ((kbits) / 8);                        \
292         ctx->ciph = PROV_AES_CIPHER_##lcmode(ctx->keylen); \
293         ctx->mode = EVP_CIPH_##UCMODE##_MODE; \
294         return ctx; \
295     }
296
297 /* ECB */
298 IMPLEMENT_cipher(ecb, ECB, 0, 256, 128, 0)
299 IMPLEMENT_cipher(ecb, ECB, 0, 192, 128, 0)
300 IMPLEMENT_cipher(ecb, ECB, 0, 128, 128, 0)
301
302 /* CBC */
303 IMPLEMENT_cipher(cbc, CBC, 0, 256, 128, 128)
304 IMPLEMENT_cipher(cbc, CBC, 0, 192, 128, 128)
305 IMPLEMENT_cipher(cbc, CBC, 0, 128, 128, 128)
306
307 /* OFB */
308 IMPLEMENT_cipher(ofb, OFB, 0, 256, 8, 128)
309 IMPLEMENT_cipher(ofb, OFB, 0, 192, 8, 128)
310 IMPLEMENT_cipher(ofb, OFB, 0, 128, 8, 128)
311
312 /* CFB */
313 IMPLEMENT_cipher(cfb, CFB, 0, 256, 8, 128)
314 IMPLEMENT_cipher(cfb, CFB, 0, 192, 8, 128)
315 IMPLEMENT_cipher(cfb, CFB, 0, 128, 8, 128)
316 IMPLEMENT_cipher(cfb1, CFB, 0, 256, 8, 128)
317 IMPLEMENT_cipher(cfb1, CFB, 0, 192, 8, 128)
318 IMPLEMENT_cipher(cfb1, CFB, 0, 128, 8, 128)
319 IMPLEMENT_cipher(cfb8, CFB, 0, 256, 8, 128)
320 IMPLEMENT_cipher(cfb8, CFB, 0, 192, 8, 128)
321 IMPLEMENT_cipher(cfb8, CFB, 0, 128, 8, 128)
322
323 /* CTR */
324 IMPLEMENT_cipher(ctr, CTR, 0, 256, 8, 128)
325 IMPLEMENT_cipher(ctr, CTR, 0, 192, 8, 128)
326 IMPLEMENT_cipher(ctr, CTR, 0, 128, 8, 128)
327
328 static void aes_freectx(void *vctx)
329 {
330     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
331
332     OPENSSL_clear_free(ctx,  sizeof(*ctx));
333 }
334
335 static void *aes_dupctx(void *ctx)
336 {
337     PROV_AES_KEY *in = (PROV_AES_KEY *)ctx;
338     PROV_AES_KEY *ret = OPENSSL_malloc(sizeof(*ret));
339
340     if (ret == NULL) {
341         PROVerr(PROV_F_AES_DUPCTX, ERR_R_MALLOC_FAILURE);
342         return NULL;
343     }
344     *ret = *in;
345
346     return ret;
347 }
348
349 static int aes_ctx_get_params(void *vctx, OSSL_PARAM params[])
350 {
351     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
352     OSSL_PARAM *p;
353
354     p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_PADDING);
355     if (p != NULL && !OSSL_PARAM_set_int(p, ctx->pad)) {
356         PROVerr(PROV_F_AES_CTX_GET_PARAMS, PROV_R_FAILED_TO_SET_PARAMETER);
357         return 0;
358     }
359     p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_IV);
360     if (p != NULL
361         && !OSSL_PARAM_set_octet_ptr(p, &ctx->iv, AES_BLOCK_SIZE)
362         && !OSSL_PARAM_set_octet_string(p, &ctx->iv, AES_BLOCK_SIZE)) {
363         PROVerr(PROV_F_AES_CTX_GET_PARAMS,
364                 PROV_R_FAILED_TO_SET_PARAMETER);
365         return 0;
366     }
367     p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_NUM);
368     if (p != NULL && !OSSL_PARAM_set_size_t(p, ctx->num)) {
369         PROVerr(PROV_F_AES_CTX_GET_PARAMS,
370                 PROV_R_FAILED_TO_SET_PARAMETER);
371         return 0;
372     }
373     p = OSSL_PARAM_locate(params, OSSL_CIPHER_PARAM_KEYLEN);
374     if (p != NULL && !OSSL_PARAM_set_int(p, ctx->keylen)) {
375         PROVerr(PROV_F_AES_CTX_GET_PARAMS,
376                 PROV_R_FAILED_TO_SET_PARAMETER);
377         return 0;
378     }
379
380     return 1;
381 }
382
383 static int aes_ctx_set_params(void *vctx, const OSSL_PARAM params[])
384 {
385     PROV_AES_KEY *ctx = (PROV_AES_KEY *)vctx;
386     const OSSL_PARAM *p;
387
388     p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_PADDING);
389     if (p != NULL) {
390         int pad;
391
392         if (!OSSL_PARAM_get_int(p, &pad)) {
393             PROVerr(PROV_F_AES_CTX_SET_PARAMS,
394                     PROV_R_FAILED_TO_GET_PARAMETER);
395             return 0;
396         }
397         ctx->pad = pad ? 1 : 0;
398     }
399     p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_NUM);
400     if (p != NULL) {
401         int num;
402
403         if (!OSSL_PARAM_get_int(p, &num)) {
404             PROVerr(PROV_F_AES_CTX_SET_PARAMS,
405                     PROV_R_FAILED_TO_GET_PARAMETER);
406             return 0;
407         }
408         ctx->num = num;
409     }
410     p = OSSL_PARAM_locate_const(params, OSSL_CIPHER_PARAM_KEYLEN);
411     if (p != NULL) {
412         int keylen;
413
414         if (!OSSL_PARAM_get_int(p, &keylen)) {
415             PROVerr(PROV_F_AES_CTX_SET_PARAMS,
416                     PROV_R_FAILED_TO_GET_PARAMETER);
417             return 0;
418         }
419         ctx->keylen = keylen;
420     }
421     return 1;
422 }
423
424 #define IMPLEMENT_block_funcs(mode, kbits) \
425     const OSSL_DISPATCH aes##kbits##mode##_functions[] = { \
426         { OSSL_FUNC_CIPHER_NEWCTX, (void (*)(void))aes_##kbits##_##mode##_newctx }, \
427         { OSSL_FUNC_CIPHER_ENCRYPT_INIT, (void (*)(void))aes_einit }, \
428         { OSSL_FUNC_CIPHER_DECRYPT_INIT, (void (*)(void))aes_dinit }, \
429         { OSSL_FUNC_CIPHER_UPDATE, (void (*)(void))aes_block_update }, \
430         { OSSL_FUNC_CIPHER_FINAL, (void (*)(void))aes_block_final }, \
431         { OSSL_FUNC_CIPHER_CIPHER, (void (*)(void))aes_cipher }, \
432         { OSSL_FUNC_CIPHER_FREECTX, (void (*)(void))aes_freectx }, \
433         { OSSL_FUNC_CIPHER_DUPCTX, (void (*)(void))aes_dupctx }, \
434         { OSSL_FUNC_CIPHER_GET_PARAMS, (void (*)(void))aes_##kbits##_##mode##_get_params }, \
435         { OSSL_FUNC_CIPHER_CTX_GET_PARAMS, (void (*)(void))aes_ctx_get_params }, \
436         { OSSL_FUNC_CIPHER_CTX_SET_PARAMS, (void (*)(void))aes_ctx_set_params }, \
437         { 0, NULL } \
438     };
439
440 #define IMPLEMENT_stream_funcs(mode, kbits) \
441     const OSSL_DISPATCH aes##kbits##mode##_functions[] = { \
442         { OSSL_FUNC_CIPHER_NEWCTX, (void (*)(void))aes_##kbits##_##mode##_newctx }, \
443         { OSSL_FUNC_CIPHER_ENCRYPT_INIT, (void (*)(void))aes_einit }, \
444         { OSSL_FUNC_CIPHER_DECRYPT_INIT, (void (*)(void))aes_dinit }, \
445         { OSSL_FUNC_CIPHER_UPDATE, (void (*)(void))aes_stream_update }, \
446         { OSSL_FUNC_CIPHER_FINAL, (void (*)(void))aes_stream_final }, \
447         { OSSL_FUNC_CIPHER_CIPHER, (void (*)(void))aes_cipher }, \
448         { OSSL_FUNC_CIPHER_FREECTX, (void (*)(void))aes_freectx }, \
449         { OSSL_FUNC_CIPHER_DUPCTX, (void (*)(void))aes_dupctx }, \
450         { OSSL_FUNC_CIPHER_GET_PARAMS, (void (*)(void))aes_##kbits##_##mode##_get_params }, \
451         { OSSL_FUNC_CIPHER_CTX_GET_PARAMS, (void (*)(void))aes_ctx_get_params }, \
452         { OSSL_FUNC_CIPHER_CTX_SET_PARAMS, (void (*)(void))aes_ctx_set_params }, \
453         { 0, NULL } \
454     };
455
456 /* ECB */
457 IMPLEMENT_block_funcs(ecb, 256)
458 IMPLEMENT_block_funcs(ecb, 192)
459 IMPLEMENT_block_funcs(ecb, 128)
460
461 /* CBC */
462 IMPLEMENT_block_funcs(cbc, 256)
463 IMPLEMENT_block_funcs(cbc, 192)
464 IMPLEMENT_block_funcs(cbc, 128)
465
466 /* OFB */
467 IMPLEMENT_stream_funcs(ofb, 256)
468 IMPLEMENT_stream_funcs(ofb, 192)
469 IMPLEMENT_stream_funcs(ofb, 128)
470
471 /* CFB */
472 IMPLEMENT_stream_funcs(cfb, 256)
473 IMPLEMENT_stream_funcs(cfb, 192)
474 IMPLEMENT_stream_funcs(cfb, 128)
475 IMPLEMENT_stream_funcs(cfb1, 256)
476 IMPLEMENT_stream_funcs(cfb1, 192)
477 IMPLEMENT_stream_funcs(cfb1, 128)
478 IMPLEMENT_stream_funcs(cfb8, 256)
479 IMPLEMENT_stream_funcs(cfb8, 192)
480 IMPLEMENT_stream_funcs(cfb8, 128)
481
482 /* CTR */
483 IMPLEMENT_stream_funcs(ctr, 256)
484 IMPLEMENT_stream_funcs(ctr, 192)
485 IMPLEMENT_stream_funcs(ctr, 128)