Fix from 0.9.8-stable.
[openssl.git] / crypto / rsa / rsa_oaep.c
index 3652677a99822b9cd47b9354bc0ba6366e3a2fc7..e238d10e5cc6fc517e0a6ce4cb42f453da2552ad 100644 (file)
@@ -28,7 +28,7 @@
 #include <openssl/rand.h>
 #include <openssl/sha.h>
 
-int MGF1(unsigned char *mask, long len,
+static int MGF1(unsigned char *mask, long len,
        const unsigned char *seed, long seedlen);
 
 int RSA_padding_add_PKCS1_OAEP(unsigned char *to, int tlen,
@@ -52,13 +52,6 @@ int RSA_padding_add_PKCS1_OAEP(unsigned char *to, int tlen,
                return 0;
                }
 
-       dbmask = OPENSSL_malloc(emlen - SHA_DIGEST_LENGTH);
-       if (dbmask == NULL)
-               {
-               RSAerr(RSA_F_RSA_PADDING_ADD_PKCS1_OAEP, ERR_R_MALLOC_FAILURE);
-               return 0;
-               }
-
        to[0] = 0;
        seed = to + 1;
        db = to + SHA_DIGEST_LENGTH + 1;
@@ -76,11 +69,20 @@ int RSA_padding_add_PKCS1_OAEP(unsigned char *to, int tlen,
           20);
 #endif
 
-       MGF1(dbmask, emlen - SHA_DIGEST_LENGTH, seed, SHA_DIGEST_LENGTH);
+       dbmask = OPENSSL_malloc(emlen - SHA_DIGEST_LENGTH);
+       if (dbmask == NULL)
+               {
+               RSAerr(RSA_F_RSA_PADDING_ADD_PKCS1_OAEP, ERR_R_MALLOC_FAILURE);
+               return 0;
+               }
+
+       if (MGF1(dbmask, emlen - SHA_DIGEST_LENGTH, seed, SHA_DIGEST_LENGTH) < 0)
+               return 0;
        for (i = 0; i < emlen - SHA_DIGEST_LENGTH; i++)
                db[i] ^= dbmask[i];
 
-       MGF1(seedmask, SHA_DIGEST_LENGTH, db, emlen - SHA_DIGEST_LENGTH);
+       if (MGF1(seedmask, SHA_DIGEST_LENGTH, db, emlen - SHA_DIGEST_LENGTH) < 0)
+               return 0;
        for (i = 0; i < SHA_DIGEST_LENGTH; i++)
                seed[i] ^= seedmask[i];
 
@@ -133,11 +135,13 @@ int RSA_padding_check_PKCS1_OAEP(unsigned char *to, int tlen,
 
        maskeddb = padded_from + SHA_DIGEST_LENGTH;
 
-       MGF1(seed, SHA_DIGEST_LENGTH, maskeddb, dblen);
+       if (MGF1(seed, SHA_DIGEST_LENGTH, maskeddb, dblen))
+               return -1;
        for (i = 0; i < SHA_DIGEST_LENGTH; i++)
                seed[i] ^= padded_from[i];
   
-       MGF1(db, dblen, seed, SHA_DIGEST_LENGTH);
+       if (MGF1(db, dblen, seed, SHA_DIGEST_LENGTH))
+               return -1;
        for (i = 0; i < dblen; i++)
                db[i] ^= maskeddb[i];
 
@@ -188,6 +192,8 @@ int PKCS1_MGF1(unsigned char *mask, long len,
 
        EVP_MD_CTX_init(&c);
        mdlen = EVP_MD_size(dgst);
+       if (mdlen < 0)
+               return -1;
        for (i = 0; outlen < len; i++)
                {
                cnt[0] = (unsigned char)((i >> 24) & 255);
@@ -213,7 +219,8 @@ int PKCS1_MGF1(unsigned char *mask, long len,
        return 0;
        }
 
-int MGF1(unsigned char *mask, long len, const unsigned char *seed, long seedlen)
+static int MGF1(unsigned char *mask, long len, const unsigned char *seed,
+                long seedlen)
        {
        return PKCS1_MGF1(mask, len, seed, seedlen, EVP_sha1());
        }