Return an error from BN_mod_inverse if n is 1 (or -1)
authorMatt Caswell <matt@openssl.org>
Fri, 27 Apr 2018 16:36:11 +0000 (17:36 +0100)
committerMatt Caswell <matt@openssl.org>
Thu, 3 May 2018 09:14:12 +0000 (10:14 +0100)
Calculating BN_mod_inverse where n is 1 (or -1) doesn't make sense. We
should return an error in that case. Instead we were returning a valid
result with value 0.

Fixes #6004

Reviewed-by: Rich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/6119)

crypto/bn/bn_gcd.c
crypto/bn/bn_mont.c

index 22f8093..6d8c565 100644 (file)
@@ -140,7 +140,14 @@ BIGNUM *int_bn_mod_inverse(BIGNUM *in,
     BIGNUM *ret = NULL;
     int sign;
 
     BIGNUM *ret = NULL;
     int sign;
 
-    if (pnoinv)
+    /* This is invalid input so we don't worry about constant time here */
+    if (BN_abs_is_word(n, 1) || BN_is_zero(n)) {
+        if (pnoinv != NULL)
+            *pnoinv = 1;
+        return NULL;
+    }
+
+    if (pnoinv != NULL)
         *pnoinv = 0;
 
     if ((BN_get_flags(a, BN_FLG_CONSTTIME) != 0)
         *pnoinv = 0;
 
     if ((BN_get_flags(a, BN_FLG_CONSTTIME) != 0)
index b85a893..5e068c4 100644 (file)
@@ -281,7 +281,9 @@ int BN_MONT_CTX_set(BN_MONT_CTX *mont, const BIGNUM *mod, BN_CTX *ctx)
         if ((buf[1] = mod->top > 1 ? mod->d[1] : 0))
             tmod.top = 2;
 
         if ((buf[1] = mod->top > 1 ? mod->d[1] : 0))
             tmod.top = 2;
 
-        if ((BN_mod_inverse(Ri, R, &tmod, ctx)) == NULL)
+        if (BN_is_one(&tmod))
+            BN_zero(Ri);
+        else if ((BN_mod_inverse(Ri, R, &tmod, ctx)) == NULL)
             goto err;
         if (!BN_lshift(Ri, Ri, 2 * BN_BITS2))
             goto err;           /* R*Ri */
             goto err;
         if (!BN_lshift(Ri, Ri, 2 * BN_BITS2))
             goto err;           /* R*Ri */
@@ -314,7 +316,9 @@ int BN_MONT_CTX_set(BN_MONT_CTX *mont, const BIGNUM *mod, BN_CTX *ctx)
         buf[1] = 0;
         tmod.top = buf[0] != 0 ? 1 : 0;
         /* Ri = R^-1 mod N */
         buf[1] = 0;
         tmod.top = buf[0] != 0 ? 1 : 0;
         /* Ri = R^-1 mod N */
-        if ((BN_mod_inverse(Ri, R, &tmod, ctx)) == NULL)
+        if (BN_is_one(&tmod))
+            BN_zero(Ri);
+        else if ((BN_mod_inverse(Ri, R, &tmod, ctx)) == NULL)
             goto err;
         if (!BN_lshift(Ri, Ri, BN_BITS2))
             goto err;           /* R*Ri */
             goto err;
         if (!BN_lshift(Ri, Ri, BN_BITS2))
             goto err;           /* R*Ri */