Flag RSA secret BNs as consttime on keygen and checks
[openssl.git] / crypto / rsa / rsa_sp800_56b_check.c
index c4c0b6a95b70b9ff9e384f7089ad1a2bf0ae3567..9840d082851954539430e973799edc4ec54912ab 100644 (file)
@@ -1,8 +1,8 @@
 /*
- * Copyright 2018-2019 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2018-2020 The OpenSSL Project Authors. All Rights Reserved.
  * Copyright (c) 2018-2019, Oracle and/or its affiliates.  All rights reserved.
  *
- * Licensed under the OpenSSL license (the "License").  You may not use
+ * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
  * in the file LICENSE in the source distribution or at
  * https://www.openssl.org/source/license.html
@@ -37,7 +37,15 @@ int rsa_check_crt_components(const RSA *rsa, BN_CTX *ctx)
     r = BN_CTX_get(ctx);
     p1 = BN_CTX_get(ctx);
     q1 = BN_CTX_get(ctx);
-    ret = (q1 != NULL)
+    if (q1 != NULL) {
+        BN_set_flags(r, BN_FLG_CONSTTIME);
+        BN_set_flags(p1, BN_FLG_CONSTTIME);
+        BN_set_flags(q1, BN_FLG_CONSTTIME);
+        ret = 1;
+    } else {
+        ret = 0;
+    }
+    ret = ret
           /* p1 = p -1 */
           && (BN_copy(p1, rsa->p) != NULL)
           && BN_sub_word(p1, 1)
@@ -62,6 +70,7 @@ int rsa_check_crt_components(const RSA *rsa, BN_CTX *ctx)
           /* (f) 1 = (qInv . q) mod p */
           && BN_mod_mul(r, rsa->iqmp, rsa->q, rsa->p, ctx)
           && BN_is_one(r);
+    BN_clear(r);
     BN_clear(p1);
     BN_clear(q1);
     BN_CTX_end(ctx);
@@ -101,7 +110,7 @@ int rsa_check_prime_factor_range(const BIGNUM *p, int nbits, BN_CTX *ctx)
     if (shift >= 0) {
         /*
          * We don't have all the bits. bn_inv_sqrt_2 contains a rounded up
-         * value, so there is a very low probabilty that we'll reject a valid
+         * value, so there is a very low probability that we'll reject a valid
          * value.
          */
         if (!BN_lshift(low, low, shift))
@@ -138,7 +147,14 @@ int rsa_check_prime_factor(BIGNUM *p, BIGNUM *e, int nbits, BN_CTX *ctx)
     BN_CTX_start(ctx);
     p1 = BN_CTX_get(ctx);
     gcd = BN_CTX_get(ctx);
-    ret = (gcd != NULL)
+    if (gcd != NULL) {
+        BN_set_flags(p1, BN_FLG_CONSTTIME);
+        BN_set_flags(gcd, BN_FLG_CONSTTIME);
+        ret = 1;
+    } else {
+        ret = 0;
+    }
+    ret = ret
           /* (Step 5d) GCD(p-1, e) = 1 */
           && (BN_copy(p1, p) != NULL)
           && BN_sub_word(p1, 1)
@@ -172,7 +188,18 @@ int rsa_check_private_exponent(const RSA *rsa, int nbits, BN_CTX *ctx)
     lcm = BN_CTX_get(ctx);
     p1q1 = BN_CTX_get(ctx);
     gcd = BN_CTX_get(ctx);
-    ret = (gcd != NULL
+    if (gcd != NULL) {
+        BN_set_flags(r, BN_FLG_CONSTTIME);
+        BN_set_flags(p1, BN_FLG_CONSTTIME);
+        BN_set_flags(q1, BN_FLG_CONSTTIME);
+        BN_set_flags(lcm, BN_FLG_CONSTTIME);
+        BN_set_flags(p1q1, BN_FLG_CONSTTIME);
+        BN_set_flags(gcd, BN_FLG_CONSTTIME);
+        ret = 1;
+    } else {
+        ret = 0;
+    }
+    ret = (ret
           /* LCM(p - 1, q - 1) */
           && (rsa_get_lcm(ctx, rsa->p, rsa->q, lcm, gcd, p1, q1, p1q1) == 1)
           /* (Step 6a) d < LCM(p - 1, q - 1) */
@@ -181,6 +208,7 @@ int rsa_check_private_exponent(const RSA *rsa, int nbits, BN_CTX *ctx)
           && BN_mod_mul(r, rsa->e, rsa->d, lcm, ctx)
           && BN_is_one(r));
 
+    BN_clear(r);
     BN_clear(p1);
     BN_clear(q1);
     BN_clear(lcm);
@@ -189,12 +217,30 @@ int rsa_check_private_exponent(const RSA *rsa, int nbits, BN_CTX *ctx)
     return ret;
 }
 
+#ifndef FIPS_MODULE
+static int bn_is_three(const BIGNUM *bn)
+{
+    BIGNUM *num = BN_dup(bn);
+    int ret = (num != NULL && BN_sub_word(num, 3) && BN_is_zero(num));
+
+    BN_free(num);
+    return ret;
+}
+#endif /* FIPS_MODULE */
+
 /* Check exponent is odd, and has a bitlen ranging from [17..256] */
 int rsa_check_public_exponent(const BIGNUM *e)
 {
-    int bitlen = BN_num_bits(e);
+    int bitlen;
+
+    /* For legacy purposes RSA_3 is allowed in non fips mode */
+#ifndef FIPS_MODULE
+    if (bn_is_three(e))
+        return 1;
+#endif /* FIPS_MODULE */
 
-    return (BN_is_odd(e) &&  bitlen > 16 && bitlen < 257);
+    bitlen = BN_num_bits(e);
+    return (BN_is_odd(e) && bitlen > 16 && bitlen < 257);
 }
 
 /*
@@ -218,7 +264,12 @@ int rsa_check_pminusq_diff(BIGNUM *diff, const BIGNUM *p, const BIGNUM *q,
     return (BN_num_bits(diff) > bitlen);
 }
 
-/* return LCM(p-1, q-1) */
+/*
+ * return LCM(p-1, q-1)
+ *
+ * Caller should ensure that lcm, gcd, p1, q1, p1q1 are flagged with
+ * BN_FLG_CONSTTIME.
+ */
 int rsa_get_lcm(BN_CTX *ctx, const BIGNUM *p, const BIGNUM *q,
                 BIGNUM *lcm, BIGNUM *gcd, BIGNUM *p1, BIGNUM *q1,
                 BIGNUM *p1q1)
@@ -237,13 +288,17 @@ int rsa_get_lcm(BN_CTX *ctx, const BIGNUM *p, const BIGNUM *q,
  */
 int rsa_sp800_56b_check_public(const RSA *rsa)
 {
-    int ret = 0, nbits, status;
+    int ret = 0, status;
+#ifdef FIPS_MODULE
+    int nbits;
+#endif
     BN_CTX *ctx = NULL;
     BIGNUM *gcd = NULL;
 
     if (rsa->n == NULL || rsa->e == NULL)
         return 0;
 
+#ifdef FIPS_MODULE
     /*
      * (Step a): modulus must be 2048 or 3072 (caveat from SP800-56Br1)
      * NOTE: changed to allow keys >= 2048
@@ -253,11 +308,11 @@ int rsa_sp800_56b_check_public(const RSA *rsa)
         RSAerr(RSA_F_RSA_SP800_56B_CHECK_PUBLIC, RSA_R_INVALID_KEY_LENGTH);
         return 0;
     }
+#endif
     if (!BN_is_odd(rsa->n)) {
         RSAerr(RSA_F_RSA_SP800_56B_CHECK_PUBLIC, RSA_R_INVALID_MODULUS);
         return 0;
     }
-
     /* (Steps b-c): 2^16 < e < 2^256, n and e must be odd */
     if (!rsa_check_public_exponent(rsa->e)) {
         RSAerr(RSA_F_RSA_SP800_56B_CHECK_PUBLIC,
@@ -265,7 +320,7 @@ int rsa_sp800_56b_check_public(const RSA *rsa)
         return 0;
     }
 
-    ctx = BN_CTX_new();
+    ctx = BN_CTX_new_ex(rsa->libctx);
     gcd = BN_new();
     if (ctx == NULL || gcd == NULL)
         goto err;
@@ -354,7 +409,7 @@ int rsa_sp800_56b_check_keypair(const RSA *rsa, const BIGNUM *efixed,
         return 0;
     }
 
-    ctx = BN_CTX_new();
+    ctx = BN_CTX_new_ex(rsa->libctx);
     if (ctx == NULL)
         return 0;