If we're going to return errors (no matter how stupid), then we should
[openssl.git] / crypto / rsa / rsa_oaep.c
index 3652677..70bacf8 100644 (file)
@@ -28,7 +28,7 @@
 #include <openssl/rand.h>
 #include <openssl/sha.h>
 
-int MGF1(unsigned char *mask, long len,
+static int MGF1(unsigned char *mask, long len,
        const unsigned char *seed, long seedlen);
 
 int RSA_padding_add_PKCS1_OAEP(unsigned char *to, int tlen,
@@ -76,11 +76,13 @@ int RSA_padding_add_PKCS1_OAEP(unsigned char *to, int tlen,
           20);
 #endif
 
-       MGF1(dbmask, emlen - SHA_DIGEST_LENGTH, seed, SHA_DIGEST_LENGTH);
+       if (MGF1(dbmask, emlen - SHA_DIGEST_LENGTH, seed, SHA_DIGEST_LENGTH) < 0)
+               return 0;
        for (i = 0; i < emlen - SHA_DIGEST_LENGTH; i++)
                db[i] ^= dbmask[i];
 
-       MGF1(seedmask, SHA_DIGEST_LENGTH, db, emlen - SHA_DIGEST_LENGTH);
+       if (MGF1(seedmask, SHA_DIGEST_LENGTH, db, emlen - SHA_DIGEST_LENGTH) < 0)
+               return 0;
        for (i = 0; i < SHA_DIGEST_LENGTH; i++)
                seed[i] ^= seedmask[i];
 
@@ -133,11 +135,13 @@ int RSA_padding_check_PKCS1_OAEP(unsigned char *to, int tlen,
 
        maskeddb = padded_from + SHA_DIGEST_LENGTH;
 
-       MGF1(seed, SHA_DIGEST_LENGTH, maskeddb, dblen);
+       if (MGF1(seed, SHA_DIGEST_LENGTH, maskeddb, dblen))
+               return -1;
        for (i = 0; i < SHA_DIGEST_LENGTH; i++)
                seed[i] ^= padded_from[i];
   
-       MGF1(db, dblen, seed, SHA_DIGEST_LENGTH);
+       if (MGF1(db, dblen, seed, SHA_DIGEST_LENGTH))
+               return -1;
        for (i = 0; i < dblen; i++)
                db[i] ^= maskeddb[i];
 
@@ -188,6 +192,8 @@ int PKCS1_MGF1(unsigned char *mask, long len,
 
        EVP_MD_CTX_init(&c);
        mdlen = EVP_MD_size(dgst);
+       if (mdlen < 0)
+               return -1;
        for (i = 0; outlen < len; i++)
                {
                cnt[0] = (unsigned char)((i >> 24) & 255);
@@ -213,7 +219,8 @@ int PKCS1_MGF1(unsigned char *mask, long len,
        return 0;
        }
 
-int MGF1(unsigned char *mask, long len, const unsigned char *seed, long seedlen)
+static int MGF1(unsigned char *mask, long len, const unsigned char *seed,
+                long seedlen)
        {
        return PKCS1_MGF1(mask, len, seed, seedlen, EVP_sha1());
        }