EVP_DecryptInit() should call EVP_CipherInit() not EVP_CipherInit_ex().
[openssl.git] / crypto / evp / evp_enc.c
index 22a7b745c174ac3d68bd36698da6d8cf57fa82f5..ccfcc7e1b15ba815daad602608af989f1e979540 100644 (file)
@@ -63,8 +63,6 @@
 #include <openssl/engine.h>
 #include "evp_locl.h"
 
-#include <assert.h>
-
 const char *EVP_version="EVP" OPENSSL_VERSION_PTEXT;
 
 void EVP_CIPHER_CTX_init(EVP_CIPHER_CTX *ctx)
@@ -85,7 +83,14 @@ int EVP_CipherInit(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher,
 int EVP_CipherInit_ex(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *impl,
             const unsigned char *key, const unsigned char *iv, int enc)
        {
-       if(enc && (enc != -1)) enc = 1;
+       if (enc == -1)
+               enc = ctx->encrypt;
+       else
+               {
+               if (enc)
+                       enc = 1;
+               ctx->encrypt = enc;
+               }
        /* Whether it's nice or not, "Inits" can be used on "Final"'d contexts
         * so this context may already have an ENGINE! Try to avoid releasing
         * the previous handle, re-querying for an ENGINE, and having a
@@ -95,11 +100,13 @@ int EVP_CipherInit_ex(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *imp
                goto skip_to_init;
        if (cipher)
                {
-               /* Ensure an ENGINE left lying around from last time is cleared
+               /* Ensure a context left lying around from last time is cleared
                 * (the previous check attempted to avoid this if the same
                 * ENGINE and EVP_CIPHER could be used). */
-               if(ctx->engine)
-                       ENGINE_finish(ctx->engine);
+               EVP_CIPHER_CTX_cleanup(ctx);
+
+               /* Restore encrypt field: it is zeroed by cleanup */
+               ctx->encrypt = enc;
                if(impl)
                        {
                        if (!ENGINE_init(impl))
@@ -133,6 +140,7 @@ int EVP_CipherInit_ex(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *imp
                        }
                else
                        ctx->engine = NULL;
+
                ctx->cipher=cipher;
                ctx->cipher_data=OPENSSL_malloc(ctx->cipher->ctx_size);
                ctx->key_len = cipher->key_len;
@@ -153,9 +161,9 @@ int EVP_CipherInit_ex(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *imp
                }
 skip_to_init:
        /* we assume block size is a power of 2 in *cryptUpdate */
-       assert(ctx->cipher->block_size == 1
-              || ctx->cipher->block_size == 8
-              || ctx->cipher->block_size == 16);
+       OPENSSL_assert(ctx->cipher->block_size == 1
+           || ctx->cipher->block_size == 8
+           || ctx->cipher->block_size == 16);
 
        if(!(EVP_CIPHER_CTX_flags(ctx) & EVP_CIPH_CUSTOM_IV)) {
                switch(EVP_CIPHER_CTX_mode(ctx)) {
@@ -171,6 +179,7 @@ skip_to_init:
 
                        case EVP_CIPH_CBC_MODE:
 
+                       OPENSSL_assert(EVP_CIPHER_CTX_iv_length(ctx) <= sizeof ctx->iv);
                        if(iv) memcpy(ctx->oiv, iv, EVP_CIPHER_CTX_iv_length(ctx));
                        memcpy(ctx->iv, ctx->oiv, EVP_CIPHER_CTX_iv_length(ctx));
                        break;
@@ -184,7 +193,6 @@ skip_to_init:
        if(key || (ctx->cipher->flags & EVP_CIPH_ALWAYS_CALL_INIT)) {
                if(!ctx->cipher->init(ctx,key,iv,enc)) return 0;
        }
-       if(enc != -1) ctx->encrypt=enc;
        ctx->buf_len=0;
        ctx->final_used=0;
        ctx->block_mask=ctx->cipher->block_size-1;
@@ -228,7 +236,7 @@ int EVP_EncryptInit_ex(EVP_CIPHER_CTX *ctx,const EVP_CIPHER *cipher, ENGINE *imp
 int EVP_DecryptInit(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher,
             const unsigned char *key, const unsigned char *iv)
        {
-       return EVP_CipherInit_ex(ctx, cipher, NULL, key, iv, 0);
+       return EVP_CipherInit(ctx, cipher, key, iv, 0);
        }
 
 int EVP_DecryptInit_ex(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *impl,
@@ -242,6 +250,7 @@ int EVP_EncryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl,
        {
        int i,j,bl;
 
+       OPENSSL_assert(inl > 0);
        if(ctx->buf_len == 0 && (inl&(ctx->block_mask)) == 0)
                {
                if(ctx->cipher->do_cipher(ctx,out,in,inl))
@@ -257,6 +266,7 @@ int EVP_EncryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl,
                }
        i=ctx->buf_len;
        bl=ctx->cipher->block_size;
+       OPENSSL_assert(bl <= sizeof ctx->buf);
        if (i != 0)
                {
                if (i+inl < bl)
@@ -297,7 +307,6 @@ int EVP_EncryptFinal(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl)
        {
        int ret;
        ret = EVP_EncryptFinal_ex(ctx, out, outl);
-       EVP_CIPHER_CTX_cleanup(ctx);
        return ret;
        }
 
@@ -306,16 +315,15 @@ int EVP_EncryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl)
        int i,n,b,bl,ret;
 
        b=ctx->cipher->block_size;
+       OPENSSL_assert(b <= sizeof ctx->buf);
        if (b == 1)
                {
-               EVP_CIPHER_CTX_cleanup(ctx);
                *outl=0;
                return 1;
                }
        bl=ctx->buf_len;
        if (ctx->flags & EVP_CIPH_NO_PADDING)
                {
-               EVP_CIPHER_CTX_cleanup(ctx);
                if(bl)
                        {
                        EVPerr(EVP_F_EVP_ENCRYPTFINAL,EVP_R_DATA_NOT_MULTIPLE_OF_BLOCK_LENGTH);
@@ -330,7 +338,6 @@ int EVP_EncryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl)
                ctx->buf[i]=n;
        ret=ctx->cipher->do_cipher(ctx,out,ctx->buf,b);
 
-       EVP_CIPHER_CTX_cleanup(ctx);
 
        if(ret)
                *outl=b;
@@ -353,6 +360,7 @@ int EVP_DecryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl,
                return EVP_EncryptUpdate(ctx, out, outl, in, inl);
 
        b=ctx->cipher->block_size;
+       OPENSSL_assert(b <= sizeof ctx->final);
 
        if(ctx->final_used)
                {
@@ -388,7 +396,6 @@ int EVP_DecryptFinal(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl)
        {
        int ret;
        ret = EVP_DecryptFinal_ex(ctx, out, outl);
-       EVP_CIPHER_CTX_cleanup(ctx);
        return ret;
        }
 
@@ -401,7 +408,6 @@ int EVP_DecryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl)
        b=ctx->cipher->block_size;
        if (ctx->flags & EVP_CIPH_NO_PADDING)
                {
-               EVP_CIPHER_CTX_cleanup(ctx);
                if(ctx->buf_len)
                        {
                        EVPerr(EVP_F_EVP_DECRYPTFINAL,EVP_R_DATA_NOT_MULTIPLE_OF_BLOCK_LENGTH);
@@ -414,14 +420,13 @@ int EVP_DecryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl)
                {
                if (ctx->buf_len || !ctx->final_used)
                        {
-                       EVP_CIPHER_CTX_cleanup(ctx);
                        EVPerr(EVP_F_EVP_DECRYPTFINAL,EVP_R_WRONG_FINAL_BLOCK_LENGTH);
                        return(0);
                        }
+               OPENSSL_assert(b <= sizeof ctx->final);
                n=ctx->final[b-1];
                if (n > b)
                        {
-                       EVP_CIPHER_CTX_cleanup(ctx);
                        EVPerr(EVP_F_EVP_DECRYPTFINAL,EVP_R_BAD_DECRYPT);
                        return(0);
                        }
@@ -429,7 +434,6 @@ int EVP_DecryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl)
                        {
                        if (ctx->final[--b] != n)
                                {
-                               EVP_CIPHER_CTX_cleanup(ctx);
                                EVPerr(EVP_F_EVP_DECRYPTFINAL,EVP_R_BAD_DECRYPT);
                                return(0);
                                }
@@ -441,17 +445,21 @@ int EVP_DecryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl)
                }
        else
                *outl=0;
-       EVP_CIPHER_CTX_cleanup(ctx);
        return(1);
        }
 
 int EVP_CIPHER_CTX_cleanup(EVP_CIPHER_CTX *c)
        {
-       if ((c->cipher != NULL) && (c->cipher->cleanup != NULL))
+       if (c->cipher != NULL)
                {
-               if(!c->cipher->cleanup(c)) return 0;
+               if(c->cipher->cleanup && !c->cipher->cleanup(c))
+                       return 0;
+               /* Cleanse cipher context data */
+               if (c->cipher_data)
+                       OPENSSL_cleanse(c->cipher_data, c->cipher->ctx_size);
                }
-       OPENSSL_free(c->cipher_data);
+       if (c->cipher_data)
+               OPENSSL_free(c->cipher_data);
        if (c->engine)
                /* The EVP_CIPHER we used belongs to an ENGINE, release the
                 * functional reference we held for this reason. */