Correct serious bug in AES-CBC decryption when the message length isn't
[openssl.git] / crypto / aes / aes_cbc.c
index bf1a808cf0e1ea168cba4403ae7819e92ee5c77a..0a28ab8d3437d424ce41bb55cd9d79abd924760f 100644 (file)
  *
  */
 
+#ifndef AES_DEBUG
+# ifndef NDEBUG
+#  define NDEBUG
+# endif
+#endif
 #include <assert.h>
+
 #include <openssl/aes.h>
 #include "aes_locl.h"
 
@@ -57,22 +63,22 @@ void AES_cbc_encrypt(const unsigned char *in, unsigned char *out,
                     const unsigned long length, const AES_KEY *key,
                     unsigned char *ivec, const int enc) {
 
-       int n;
+       unsigned long n;
        unsigned long len = length;
-       unsigned char tmp[16];
+       unsigned char tmp[AES_BLOCK_SIZE];
 
        assert(in && out && key && ivec);
        assert((AES_ENCRYPT == enc)||(AES_DECRYPT == enc));
 
        if (AES_ENCRYPT == enc) {
                while (len >= AES_BLOCK_SIZE) {
-                       for(n=0; n < 16; ++n)
+                       for(n=0; n < AES_BLOCK_SIZE; ++n)
                                tmp[n] = in[n] ^ ivec[n];
                        AES_encrypt(tmp, out, key);
-                       memcpy(ivec, out, 16);
-                       len -= 16;
-                       in += 16;
-                       out += 16;
+                       memcpy(ivec, out, AES_BLOCK_SIZE);
+                       len -= AES_BLOCK_SIZE;
+                       in += AES_BLOCK_SIZE;
+                       out += AES_BLOCK_SIZE;
                }
                if (len) {
                        for(n=0; n < len; ++n)
@@ -80,26 +86,25 @@ void AES_cbc_encrypt(const unsigned char *in, unsigned char *out,
                        for(n=len; n < AES_BLOCK_SIZE; ++n)
                                tmp[n] = ivec[n];
                        AES_encrypt(tmp, tmp, key);
-                       memcpy(out, tmp, len);
-                       memcpy(ivec, tmp, 16);
+                       memcpy(out, tmp, AES_BLOCK_SIZE);
+                       memcpy(ivec, tmp, AES_BLOCK_SIZE);
                }                       
        } else {
                while (len >= AES_BLOCK_SIZE) {
-                       memcpy(tmp, in, 16);
                        AES_decrypt(in, out, key);
-                       for(n=0; n < 16; ++n)
+                       for(n=0; n < AES_BLOCK_SIZE; ++n)
                                out[n] ^= ivec[n];
-                       memcpy(ivec, tmp, 16);
-                       len -= 16;
-                       in += 16;
-                       out += 16;
+                       memcpy(ivec, in, AES_BLOCK_SIZE);
+                       len -= AES_BLOCK_SIZE;
+                       in += AES_BLOCK_SIZE;
+                       out += AES_BLOCK_SIZE;
                }
                if (len) {
-                       memcpy(tmp, in, 16);
-                       AES_decrypt(tmp, tmp, key);
+                       memcpy(tmp, in, AES_BLOCK_SIZE);
+                       AES_decrypt(in, tmp, key);
                        for(n=0; n < len; ++n)
-                               out[n] ^= ivec[n];
-                       memcpy(ivec, tmp, 16);
+                               out[n] = tmp[n] ^ ivec[n];
+                       memcpy(ivec, in, AES_BLOCK_SIZE);
                }                       
        }
 }