Fix CTS cipher decrypt so that the updated IV is returned correctly.
[openssl.git] / providers / implementations / ciphers / cipher_cts.c
1 /*
2  * Copyright 2020-2021 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 /*
11  * Helper functions for 128 bit CBC CTS ciphers (Currently AES and Camellia).
12  *
13  * The function dispatch tables are embedded into cipher_aes.c
14  * and cipher_camellia.c using cipher_aes_cts.inc and cipher_camellia_cts.inc
15  */
16
17 /*
18  * Refer to SP800-38A-Addendum
19  *
20  * Ciphertext stealing encrypts plaintext using a block cipher, without padding
21  * the message to a multiple of the block size, so the ciphertext is the same
22  * size as the plaintext.
23  * It does this by altering processing of the last two blocks of the message.
24  * The processing of all but the last two blocks is unchanged, but a portion of
25  * the second-last block's ciphertext is "stolen" to pad the last plaintext
26  * block. The padded final block is then encrypted as usual.
27  * The final ciphertext for the last two blocks, consists of the partial block
28  * (with the "stolen" portion omitted) plus the full final block,
29  * which are the same size as the original plaintext.
30  * Decryption requires decrypting the final block first, then restoring the
31  * stolen ciphertext to the partial block, which can then be decrypted as usual.
32
33  * AES_CBC_CTS has 3 variants:
34  *  (1) CS1 The NIST variant.
35  *      If the length is a multiple of the blocksize it is the same as CBC mode.
36  *      otherwise it produces C1||C2||(C(n-1))*||Cn.
37  *      Where C(n-1)* is a partial block.
38  *  (2) CS2
39  *      If the length is a multiple of the blocksize it is the same as CBC mode.
40  *      otherwise it produces C1||C2||Cn||(C(n-1))*.
41  *      Where C(n-1)* is a partial block.
42  *  (3) CS3 The Kerberos5 variant.
43  *      Produces C1||C2||Cn||(C(n-1))* regardless of the length.
44  *      If the length is a multiple of the blocksize it looks similar to CBC mode
45  *      with the last 2 blocks swapped.
46  *      Otherwise it is the same as CS2.
47  */
48
49 #include "e_os.h" /* strcasecmp */
50 #include <openssl/core_names.h>
51 #include "prov/ciphercommon.h"
52 #include "internal/nelem.h"
53 #include "cipher_cts.h"
54
55 /* The value assigned to 0 is the default */
56 #define CTS_CS1 0
57 #define CTS_CS2 1
58 #define CTS_CS3 2
59
60 #define CTS_BLOCK_SIZE 16
61
62 typedef union {
63     size_t align;
64     unsigned char c[CTS_BLOCK_SIZE];
65 } aligned_16bytes;
66
67 typedef struct cts_mode_name2id_st {
68     unsigned int id;
69     const char *name;
70 } CTS_MODE_NAME2ID;
71
72 static CTS_MODE_NAME2ID cts_modes[] =
73 {
74     { CTS_CS1, OSSL_CIPHER_CTS_MODE_CS1 },
75     { CTS_CS2, OSSL_CIPHER_CTS_MODE_CS2 },
76     { CTS_CS3, OSSL_CIPHER_CTS_MODE_CS3 },
77 };
78
79 const char *ossl_cipher_cbc_cts_mode_id2name(unsigned int id)
80 {
81     size_t i;
82
83     for (i = 0; i < OSSL_NELEM(cts_modes); ++i) {
84         if (cts_modes[i].id == id)
85             return cts_modes[i].name;
86     }
87     return NULL;
88 }
89
90 int ossl_cipher_cbc_cts_mode_name2id(const char *name)
91 {
92     size_t i;
93
94     for (i = 0; i < OSSL_NELEM(cts_modes); ++i) {
95         if (strcasecmp(name, cts_modes[i].name) == 0)
96             return (int)cts_modes[i].id;
97     }
98     return -1;
99 }
100
101 static size_t cts128_cs1_encrypt(PROV_CIPHER_CTX *ctx, const unsigned char *in,
102                                  unsigned char *out, size_t len)
103 {
104     aligned_16bytes tmp_in;
105     size_t residue;
106
107     residue = len % CTS_BLOCK_SIZE;
108     len -= residue;
109     if (!ctx->hw->cipher(ctx, out, in, len))
110         return 0;
111
112     if (residue == 0)
113         return len;
114
115     in += len;
116     out += len;
117
118     memset(tmp_in.c, 0, sizeof(tmp_in));
119     memcpy(tmp_in.c, in, residue);
120     if (!ctx->hw->cipher(ctx, out - CTS_BLOCK_SIZE + residue, tmp_in.c,
121                          CTS_BLOCK_SIZE))
122         return 0;
123     return len + residue;
124 }
125
126 static void do_xor(const unsigned char *in1, const unsigned char *in2,
127                    size_t len, unsigned char *out)
128 {
129     size_t i;
130
131     for (i = 0; i < len; ++i)
132         out[i] = in1[i] ^ in2[i];
133 }
134
135 static size_t cts128_cs1_decrypt(PROV_CIPHER_CTX *ctx, const unsigned char *in,
136                                  unsigned char *out, size_t len)
137 {
138     aligned_16bytes mid_iv, ct_mid, cn, pt_last;
139     size_t residue;
140
141     residue = len % CTS_BLOCK_SIZE;
142     if (residue == 0) {
143         /* If there are no partial blocks then it is the same as CBC mode */
144         if (!ctx->hw->cipher(ctx, out, in, len))
145             return 0;
146         return len;
147     }
148     /* Process blocks at the start - but leave the last 2 blocks */
149     len -= CTS_BLOCK_SIZE + residue;
150     if (len > 0) {
151         if (!ctx->hw->cipher(ctx, out, in, len))
152             return 0;
153         in += len;
154         out += len;
155     }
156     /* Save the iv that will be used by the second last block */
157     memcpy(mid_iv.c, ctx->iv, CTS_BLOCK_SIZE);
158     /* Save the C(n) block */
159     memcpy(cn.c, in + residue, CTS_BLOCK_SIZE);
160
161     /* Decrypt the last block first using an iv of zero */
162     memset(ctx->iv, 0, CTS_BLOCK_SIZE);
163     if (!ctx->hw->cipher(ctx, pt_last.c, in + residue, CTS_BLOCK_SIZE))
164         return 0;
165
166     /*
167      * Rebuild the ciphertext of the second last block as a combination of
168      * the decrypted last block + replace the start with the ciphertext bytes
169      * of the partial second last block.
170      */
171     memcpy(ct_mid.c, in, residue);
172     memcpy(ct_mid.c + residue, pt_last.c + residue, CTS_BLOCK_SIZE - residue);
173     /*
174      * Restore the last partial ciphertext block.
175      * Now that we have the cipher text of the second last block, apply
176      * that to the partial plaintext end block. We have already decrypted the
177      * block using an IV of zero. For decryption the IV is just XORed after
178      * doing an Cipher CBC block - so just XOR in the cipher text.
179      */
180     do_xor(ct_mid.c, pt_last.c, residue, out + CTS_BLOCK_SIZE);
181
182     /* Restore the iv needed by the second last block */
183     memcpy(ctx->iv, mid_iv.c, CTS_BLOCK_SIZE);
184
185     /*
186      * Decrypt the second last plaintext block now that we have rebuilt the
187      * ciphertext.
188      */
189     if (!ctx->hw->cipher(ctx, out, ct_mid.c, CTS_BLOCK_SIZE))
190         return 0;
191
192     /* The returned iv is the C(n) block */
193     memcpy(ctx->iv, cn.c, CTS_BLOCK_SIZE);
194     return len + CTS_BLOCK_SIZE + residue;
195 }
196
197 static size_t cts128_cs3_encrypt(PROV_CIPHER_CTX *ctx, const unsigned char *in,
198                                  unsigned char *out, size_t len)
199 {
200     aligned_16bytes tmp_in;
201     size_t residue;
202
203     if (len < CTS_BLOCK_SIZE)  /* CS3 requires at least one block */
204         return 0;
205
206     /* If we only have one block then just process the aligned block */
207     if (len == CTS_BLOCK_SIZE)
208         return ctx->hw->cipher(ctx, out, in, len) ? len : 0;
209
210     residue = len % CTS_BLOCK_SIZE;
211     if (residue == 0)
212         residue = CTS_BLOCK_SIZE;
213     len -= residue;
214
215     if (!ctx->hw->cipher(ctx, out, in, len))
216         return 0;
217
218     in += len;
219     out += len;
220
221     memset(tmp_in.c, 0, sizeof(tmp_in));
222     memcpy(tmp_in.c, in, residue);
223     memcpy(out, out - CTS_BLOCK_SIZE, residue);
224     if (!ctx->hw->cipher(ctx, out - CTS_BLOCK_SIZE, tmp_in.c, CTS_BLOCK_SIZE))
225         return 0;
226     return len + residue;
227 }
228
229 /*
230  * Note:
231  *  The cipher text (in) is of the form C(0), C(1), ., C(n), C(n-1)* where
232  *  C(n) is a full block and C(n-1)* can be a partial block
233  *  (but could be a full block).
234  *  This means that the output plaintext (out) needs to swap the plaintext of
235  *  the last two decoded ciphertext blocks.
236  */
237 static size_t cts128_cs3_decrypt(PROV_CIPHER_CTX *ctx, const unsigned char *in,
238                                  unsigned char *out, size_t len)
239 {
240     aligned_16bytes mid_iv, ct_mid, cn, pt_last;
241     size_t residue;
242
243     if (len < CTS_BLOCK_SIZE) /* CS3 requires at least one block */
244         return 0;
245
246     /* If we only have one block then just process the aligned block */
247     if (len == CTS_BLOCK_SIZE)
248         return ctx->hw->cipher(ctx, out, in, len) ? len : 0;
249
250     /* Process blocks at the start - but leave the last 2 blocks */
251     residue = len % CTS_BLOCK_SIZE;
252     if (residue == 0)
253         residue = CTS_BLOCK_SIZE;
254     len -= CTS_BLOCK_SIZE + residue;
255
256     if (len > 0) {
257         if (!ctx->hw->cipher(ctx, out, in, len))
258             return 0;
259         in += len;
260         out += len;
261     }
262     /* Save the iv that will be used by the second last block */
263     memcpy(mid_iv.c, ctx->iv, CTS_BLOCK_SIZE);
264     /* Save the C(n) block : For CS3 it is C(1)||...||C(n-2)||C(n)||C(n-1)* */
265     memcpy(cn.c, in, CTS_BLOCK_SIZE);
266
267     /* Decrypt the C(n) block first using an iv of zero */
268     memset(ctx->iv, 0, CTS_BLOCK_SIZE);
269     if (!ctx->hw->cipher(ctx, pt_last.c, in, CTS_BLOCK_SIZE))
270         return 0;
271
272     /*
273      * Rebuild the ciphertext of C(n-1) as a combination of
274      * the decrypted C(n) block + replace the start with the ciphertext bytes
275      * of the partial last block.
276      */
277     memcpy(ct_mid.c, in + CTS_BLOCK_SIZE, residue);
278     if (residue != CTS_BLOCK_SIZE)
279         memcpy(ct_mid.c + residue, pt_last.c + residue, CTS_BLOCK_SIZE - residue);
280     /*
281      * Restore the last partial ciphertext block.
282      * Now that we have the cipher text of the second last block, apply
283      * that to the partial plaintext end block. We have already decrypted the
284      * block using an IV of zero. For decryption the IV is just XORed after
285      * doing an AES block - so just XOR in the ciphertext.
286      */
287     do_xor(ct_mid.c, pt_last.c, residue, out + CTS_BLOCK_SIZE);
288
289     /* Restore the iv needed by the second last block */
290     memcpy(ctx->iv, mid_iv.c, CTS_BLOCK_SIZE);
291     /*
292      * Decrypt the second last plaintext block now that we have rebuilt the
293      * ciphertext.
294      */
295     if (!ctx->hw->cipher(ctx, out, ct_mid.c, CTS_BLOCK_SIZE))
296         return 0;
297
298     /* The returned iv is the C(n) block */
299     memcpy(ctx->iv, cn.c, CTS_BLOCK_SIZE);
300     return len + CTS_BLOCK_SIZE + residue;
301 }
302
303 static size_t cts128_cs2_encrypt(PROV_CIPHER_CTX *ctx, const unsigned char *in,
304                                  unsigned char *out, size_t len)
305 {
306     if (len % CTS_BLOCK_SIZE == 0) {
307         /* If there are no partial blocks then it is the same as CBC mode */
308         if (!ctx->hw->cipher(ctx, out, in, len))
309             return 0;
310         return len;
311     }
312     /* For partial blocks CS2 is equivalent to CS3 */
313     return cts128_cs3_encrypt(ctx, in, out, len);
314 }
315
316 static size_t cts128_cs2_decrypt(PROV_CIPHER_CTX *ctx, const unsigned char *in,
317                                  unsigned char *out, size_t len)
318 {
319     if (len % CTS_BLOCK_SIZE == 0) {
320         /* If there are no partial blocks then it is the same as CBC mode */
321         if (!ctx->hw->cipher(ctx, out, in, len))
322             return 0;
323         return len;
324     }
325     /* For partial blocks CS2 is equivalent to CS3 */
326     return cts128_cs3_decrypt(ctx, in, out, len);
327 }
328
329 int ossl_cipher_cbc_cts_block_update(void *vctx, unsigned char *out, size_t *outl,
330                                      size_t outsize, const unsigned char *in,
331                                      size_t inl)
332 {
333     PROV_CIPHER_CTX *ctx = (PROV_CIPHER_CTX *)vctx;
334     size_t sz = 0;
335
336     if (inl < CTS_BLOCK_SIZE) /* There must be at least one block for CTS mode */
337         return 0;
338     if (outsize < inl)
339         return 0;
340     if (out == NULL) {
341         *outl = inl;
342         return 1;
343     }
344
345     /*
346      * Return an error if the update is called multiple times, only one shot
347      * is supported.
348      */
349     if (ctx->updated == 1)
350         return 0;
351
352     if (ctx->enc) {
353         if (ctx->cts_mode == CTS_CS1)
354             sz = cts128_cs1_encrypt(ctx, in, out, inl);
355         else if (ctx->cts_mode == CTS_CS2)
356             sz = cts128_cs2_encrypt(ctx, in, out, inl);
357         else if (ctx->cts_mode == CTS_CS3)
358             sz = cts128_cs3_encrypt(ctx, in, out, inl);
359     } else {
360         if (ctx->cts_mode == CTS_CS1)
361             sz = cts128_cs1_decrypt(ctx, in, out, inl);
362         else if (ctx->cts_mode == CTS_CS2)
363             sz = cts128_cs2_decrypt(ctx, in, out, inl);
364         else if (ctx->cts_mode == CTS_CS3)
365             sz = cts128_cs3_decrypt(ctx, in, out, inl);
366     }
367     if (sz == 0)
368         return 0;
369     ctx->updated = 1; /* Stop multiple updates being allowed */
370     *outl = sz;
371     return 1;
372 }
373
374 int ossl_cipher_cbc_cts_block_final(void *vctx, unsigned char *out, size_t *outl,
375                                     size_t outsize)
376 {
377     *outl = 0;
378     return 1;
379 }