Check return from BN_sub
[openssl.git] / crypto / rsa / rsa_ossl.c
index ced11ad883c98e17b53e073e4da63c5010d5f0ea..c441905526c82b0e49c1945e9262498bbb0c6688 100644 (file)
@@ -68,7 +68,7 @@ static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
                                   unsigned char *to, RSA *rsa, int padding)
 {
     BIGNUM *f, *ret;
                                   unsigned char *to, RSA *rsa, int padding)
 {
     BIGNUM *f, *ret;
-    int i, j, k, num = 0, r = -1;
+    int i, num = 0, r = -1;
     unsigned char *buf = NULL;
     BN_CTX *ctx = NULL;
 
     unsigned char *buf = NULL;
     BN_CTX *ctx = NULL;
 
@@ -142,15 +142,10 @@ static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
         goto err;
 
     /*
         goto err;
 
     /*
-     * put in leading 0 bytes if the number is less than the length of the
-     * modulus
+     * BN_bn2binpad puts in leading 0 bytes if the number is less than
+     * the length of the modulus.
      */
      */
-    j = BN_num_bytes(ret);
-    i = BN_bn2bin(ret, &(to[num - j]));
-    for (k = 0; k < (num - i); k++)
-        to[k] = 0;
-
-    r = num;
+    r = BN_bn2binpad(ret, to, num);
  err:
     if (ctx != NULL)
         BN_CTX_end(ctx);
  err:
     if (ctx != NULL)
         BN_CTX_end(ctx);
@@ -239,7 +234,7 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
                                    unsigned char *to, RSA *rsa, int padding)
 {
     BIGNUM *f, *ret, *res;
                                    unsigned char *to, RSA *rsa, int padding)
 {
     BIGNUM *f, *ret, *res;
-    int i, j, k, num = 0, r = -1;
+    int i, num = 0, r = -1;
     unsigned char *buf = NULL;
     BN_CTX *ctx = NULL;
     int local_blinding = 0;
     unsigned char *buf = NULL;
     BN_CTX *ctx = NULL;
     int local_blinding = 0;
@@ -344,7 +339,8 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
             goto err;
 
     if (padding == RSA_X931_PADDING) {
             goto err;
 
     if (padding == RSA_X931_PADDING) {
-        BN_sub(f, rsa->n, ret);
+        if (!BN_sub(f, rsa->n, ret))
+            goto err;
         if (BN_cmp(ret, f) > 0)
             res = f;
         else
         if (BN_cmp(ret, f) > 0)
             res = f;
         else
@@ -354,15 +350,10 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
     }
 
     /*
     }
 
     /*
-     * put in leading 0 bytes if the number is less than the length of the
-     * modulus
+     * BN_bn2binpad puts in leading 0 bytes if the number is less than
+     * the length of the modulus.
      */
      */
-    j = BN_num_bytes(res);
-    i = BN_bn2bin(res, &(to[num - j]));
-    for (k = 0; k < (num - i); k++)
-        to[k] = 0;
-
-    r = num;
+    r = BN_bn2binpad(res, to, num);
  err:
     if (ctx != NULL)
         BN_CTX_end(ctx);
  err:
     if (ctx != NULL)
         BN_CTX_end(ctx);
@@ -376,7 +367,6 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
 {
     BIGNUM *f, *ret;
     int j, num = 0, r = -1;
 {
     BIGNUM *f, *ret;
     int j, num = 0, r = -1;
-    unsigned char *p;
     unsigned char *buf = NULL;
     BN_CTX *ctx = NULL;
     int local_blinding = 0;
     unsigned char *buf = NULL;
     BN_CTX *ctx = NULL;
     int local_blinding = 0;
@@ -472,8 +462,7 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
         if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
             goto err;
 
         if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
             goto err;
 
-    p = buf;
-    j = BN_bn2bin(ret, p);      /* j is only used with no-padding mode */
+    j = BN_bn2binpad(ret, buf, num);
 
     switch (padding) {
     case RSA_PKCS1_PADDING:
 
     switch (padding) {
     case RSA_PKCS1_PADDING:
@@ -486,7 +475,7 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
         r = RSA_padding_check_SSLv23(to, num, buf, j, num);
         break;
     case RSA_NO_PADDING:
         r = RSA_padding_check_SSLv23(to, num, buf, j, num);
         break;
     case RSA_NO_PADDING:
-        r = RSA_padding_check_none(to, num, buf, j, num);
+        memcpy(to, buf, (r = j));
         break;
     default:
         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
         break;
     default:
         RSAerr(RSA_F_RSA_OSSL_PRIVATE_DECRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
@@ -509,7 +498,6 @@ static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
 {
     BIGNUM *f, *ret;
     int i, num = 0, r = -1;
 {
     BIGNUM *f, *ret;
     int i, num = 0, r = -1;
-    unsigned char *p;
     unsigned char *buf = NULL;
     BN_CTX *ctx = NULL;
 
     unsigned char *buf = NULL;
     BN_CTX *ctx = NULL;
 
@@ -574,8 +562,7 @@ static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
         if (!BN_sub(ret, rsa->n, ret))
             goto err;
 
         if (!BN_sub(ret, rsa->n, ret))
             goto err;
 
-    p = buf;
-    i = BN_bn2bin(ret, p);
+    i = BN_bn2binpad(ret, buf, num);
 
     switch (padding) {
     case RSA_PKCS1_PADDING:
 
     switch (padding) {
     case RSA_PKCS1_PADDING:
@@ -585,7 +572,7 @@ static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
         r = RSA_padding_check_X931(to, num, buf, i, num);
         break;
     case RSA_NO_PADDING:
         r = RSA_padding_check_X931(to, num, buf, i, num);
         break;
     case RSA_NO_PADDING:
-        r = RSA_padding_check_none(to, num, buf, i, num);
+        memcpy(to, buf, (r = i));
         break;
     default:
         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
         break;
     default:
         RSAerr(RSA_F_RSA_OSSL_PUBLIC_DECRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
@@ -604,7 +591,7 @@ static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
 
 static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
 {
 
 static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
 {
-    BIGNUM *r1, *m1, *vrfy, *r2, *m[RSA_MAX_PRIME_NUM];
+    BIGNUM *r1, *m1, *vrfy, *r2, *m[RSA_MAX_PRIME_NUM - 2];
     int ret = 0, i, ex_primes = 0;
     RSA_PRIME_INFO *pinfo;
 
     int ret = 0, i, ex_primes = 0;
     RSA_PRIME_INFO *pinfo;
 
@@ -618,7 +605,8 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
         goto err;
 
     if (rsa->version == RSA_ASN1_VERSION_MULTI
         goto err;
 
     if (rsa->version == RSA_ASN1_VERSION_MULTI
-        && (ex_primes = sk_RSA_PRIME_INFO_num(rsa->prime_infos)) <= 0)
+        && ((ex_primes = sk_RSA_PRIME_INFO_num(rsa->prime_infos)) <= 0
+             || ex_primes > RSA_MAX_PRIME_NUM - 2))
         goto err;
 
     {
         goto err;
 
     {