Fix potential double free in rsa_keygen pairwise test.
[openssl.git] / crypto / rsa / rsa_gen.c
index 3d5a32a0a1224b80c1e8b0c5b8ef9e4ece4758d1..1cdc8d91e8823a88a9ee550f86acc420fd920227 100644 (file)
@@ -66,20 +66,14 @@ int RSA_generate_multi_prime_key(RSA *rsa, int bits, int primes,
         else
             return 0;
     }
-#endif /* FIPS_MODULE */
-    return rsa_keygen(NULL, rsa, bits, primes, e_value, cb, 0);
+#endif /* FIPS_MODUKE */
+    return rsa_keygen(rsa->libctx, rsa, bits, primes, e_value, cb, 0);
 }
 
-static int rsa_keygen(OPENSSL_CTX *libctx, RSA *rsa, int bits, int primes,
-                      BIGNUM *e_value, BN_GENCB *cb, int pairwise_test)
+#ifndef FIPS_MODULE
+static int rsa_multiprime_keygen(RSA *rsa, int bits, int primes,
+                                 BIGNUM *e_value, BN_GENCB *cb)
 {
-    int ok = -1;
-#ifdef FIPS_MODULE
-    if (primes != 2)
-        return 0;
-    ok = rsa_sp800_56b_generate_key(rsa, bits, e_value, cb);
-    pairwise_test = 1; /* FIPS MODE needs to always run the pairwise test */
-#else
     BIGNUM *r0 = NULL, *r1 = NULL, *r2 = NULL, *tmp, *prime;
     int n = 0, bitsr[RSA_MAX_PRIME_NUM], bitse = 0;
     int i = 0, quo = 0, rmd = 0, adj = 0, retries = 0;
@@ -88,6 +82,7 @@ static int rsa_keygen(OPENSSL_CTX *libctx, RSA *rsa, int bits, int primes,
     BN_CTX *ctx = NULL;
     BN_ULONG bitst = 0;
     unsigned long error = 0;
+    int ok = -1;
 
     if (bits < RSA_MIN_MODULUS_BITS) {
         ok = 0;             /* we set our own err */
@@ -95,6 +90,12 @@ static int rsa_keygen(OPENSSL_CTX *libctx, RSA *rsa, int bits, int primes,
         goto err;
     }
 
+    /* A bad value for e can cause infinite loops */
+    if (e_value != NULL && !rsa_check_public_exponent(e_value)) {
+        RSAerr(0, RSA_R_PUB_EXPONENT_OUT_OF_RANGE);
+        return 0;
+    }
+
     if (primes < RSA_DEFAULT_PRIME_NUM || primes > rsa_multip_cap(bits)) {
         ok = 0;             /* we set our own err */
         RSAerr(0, RSA_R_KEY_PRIME_NUM_INVALID);
@@ -125,18 +126,24 @@ static int rsa_keygen(OPENSSL_CTX *libctx, RSA *rsa, int bits, int primes,
         goto err;
     if (!rsa->d && ((rsa->d = BN_secure_new()) == NULL))
         goto err;
+    BN_set_flags(rsa->d, BN_FLG_CONSTTIME);
     if (!rsa->e && ((rsa->e = BN_new()) == NULL))
         goto err;
     if (!rsa->p && ((rsa->p = BN_secure_new()) == NULL))
         goto err;
+    BN_set_flags(rsa->p, BN_FLG_CONSTTIME);
     if (!rsa->q && ((rsa->q = BN_secure_new()) == NULL))
         goto err;
+    BN_set_flags(rsa->q, BN_FLG_CONSTTIME);
     if (!rsa->dmp1 && ((rsa->dmp1 = BN_secure_new()) == NULL))
         goto err;
+    BN_set_flags(rsa->dmp1, BN_FLG_CONSTTIME);
     if (!rsa->dmq1 && ((rsa->dmq1 = BN_secure_new()) == NULL))
         goto err;
+    BN_set_flags(rsa->dmq1, BN_FLG_CONSTTIME);
     if (!rsa->iqmp && ((rsa->iqmp = BN_secure_new()) == NULL))
         goto err;
+    BN_set_flags(rsa->iqmp, BN_FLG_CONSTTIME);
 
     /* initialize multi-prime components */
     if (primes > RSA_DEFAULT_PRIME_NUM) {
@@ -407,8 +414,29 @@ static int rsa_keygen(OPENSSL_CTX *libctx, RSA *rsa, int bits, int primes,
     }
     BN_CTX_end(ctx);
     BN_CTX_free(ctx);
+    return ok;
+}
 #endif /* FIPS_MODULE */
 
+static int rsa_keygen(OPENSSL_CTX *libctx, RSA *rsa, int bits, int primes,
+                      BIGNUM *e_value, BN_GENCB *cb, int pairwise_test)
+{
+    int ok = 0;
+
+    /*
+     * Only multi-prime keys or insecure keys with a small key length will use
+     * the older rsa_multiprime_keygen().
+     */
+    if (primes == 2 && bits >= 2048)
+        ok = rsa_sp800_56b_generate_key(rsa, bits, e_value, cb);
+#ifndef FIPS_MODULE
+    else
+        ok = rsa_multiprime_keygen(rsa, bits, primes, e_value, cb);
+#endif /* FIPS_MODULE */
+
+#ifdef FIPS_MODULE
+    pairwise_test = 1; /* FIPS MODE needs to always run the pairwise test */
+#endif
     if (pairwise_test && ok > 0) {
         OSSL_CALLBACK *stcb = NULL;
         void *stcbarg = NULL;
@@ -423,6 +451,12 @@ static int rsa_keygen(OPENSSL_CTX *libctx, RSA *rsa, int bits, int primes,
             BN_clear_free(rsa->dmp1);
             BN_clear_free(rsa->dmq1);
             BN_clear_free(rsa->iqmp);
+            rsa->d = NULL;
+            rsa->p = NULL;
+            rsa->q = NULL;
+            rsa->dmp1 = NULL;
+            rsa->dmq1 = NULL;
+            rsa->iqmp = NULL;
         }
     }
     return ok;
@@ -463,7 +497,7 @@ static int rsa_keygen_pairwise_test(RSA *rsa, OSSL_CALLBACK *cb, void *cbarg)
     if (ciphertxt_len <= 0)
         goto err;
     if (ciphertxt_len == plaintxt_len
-        && memcmp(decoded, plaintxt, plaintxt_len) == 0)
+        && memcmp(ciphertxt, plaintxt, plaintxt_len) == 0)
         goto err;
 
     OSSL_SELF_TEST_oncorrupt_byte(st, ciphertxt);