bn: Properly error out if aliasing return value with modulus
authorTomas Mraz <tomas@openssl.org>
Wed, 18 Oct 2023 13:50:30 +0000 (15:50 +0200)
committerHugo Landau <hlandau@openssl.org>
Thu, 26 Oct 2023 14:25:47 +0000 (15:25 +0100)
Test case amended from code initially written by Bernd Edlinger.

Fixes #21110

Reviewed-by: Dmitry Belyavskiy <beldmit@gmail.com>
Reviewed-by: Paul Dale <pauli@openssl.org>
Reviewed-by: Hugo Landau <hlandau@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/22421)

crypto/bn/bn_exp.c
crypto/bn/bn_mod.c
doc/man3/BN_add.pod
doc/man3/BN_mod_inverse.pod
test/bntest.c

index cb6d19229fe6f9e01a924d6f0d2cb93b199f6b8d..b876edbfac36e3e71a51d00465e7ddf322713351 100644 (file)
@@ -242,6 +242,14 @@ int BN_mod_exp_recp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
     wstart = bits - 1;          /* The top bit of the window */
     wend = 0;                   /* The bottom bit of the window */
 
+    if (r == p) {
+        BIGNUM *p_dup = BN_CTX_get(ctx);
+
+        if (p_dup == NULL || BN_copy(p_dup, p) == NULL)
+            goto err;
+        p = p_dup;
+    }
+
     if (!BN_one(r))
         goto err;
 
@@ -1317,6 +1325,11 @@ int BN_mod_exp_simple(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
         return 0;
     }
 
+    if (r == m) {
+        ERR_raise(ERR_LIB_BN, ERR_R_PASSED_INVALID_ARGUMENT);
+        return 0;
+    }
+
     bits = BN_num_bits(p);
     if (bits == 0) {
         /* x**0 mod 1, or x**0 mod -1 is still zero. */
@@ -1361,6 +1374,14 @@ int BN_mod_exp_simple(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
     wstart = bits - 1;          /* The top bit of the window */
     wend = 0;                   /* The bottom bit of the window */
 
+    if (r == p) {
+        BIGNUM *p_dup = BN_CTX_get(ctx);
+
+        if (p_dup == NULL || BN_copy(p_dup, p) == NULL)
+            goto err;
+        p = p_dup;
+    }
+
     if (!BN_one(r))
         goto err;
 
index 982e0e992c00b2d434781b9572b28690940f2308..d7c2f4bd5bfa9e54494d3b068ffa771e38002c09 100644 (file)
@@ -17,6 +17,11 @@ int BN_nnmod(BIGNUM *r, const BIGNUM *m, const BIGNUM *d, BN_CTX *ctx)
      * always holds)
      */
 
+    if (r == d) {
+        ERR_raise(ERR_LIB_BN, ERR_R_PASSED_INVALID_ARGUMENT);
+        return 0;
+    }
+
     if (!(BN_mod(r, m, d, ctx)))
         return 0;
     if (!r->neg)
@@ -184,6 +189,11 @@ int bn_mod_sub_fixed_top(BIGNUM *r, const BIGNUM *a, const BIGNUM *b,
 int BN_mod_sub_quick(BIGNUM *r, const BIGNUM *a, const BIGNUM *b,
                      const BIGNUM *m)
 {
+    if (r == m) {
+        ERR_raise(ERR_LIB_BN, ERR_R_PASSED_INVALID_ARGUMENT);
+        return 0;
+    }
+
     if (!BN_sub(r, a, b))
         return 0;
     if (r->neg)
index 9561d554318f1894eb7008773c5b3cc15ffebf9a..35cfdd1495fd125aae0ff8d5d784af7333004904 100644 (file)
@@ -114,6 +114,11 @@ temporary variables; see L<BN_CTX_new(3)>.
 Unless noted otherwise, the result B<BIGNUM> must be different from
 the arguments.
 
+=head1 NOTES
+
+For modular operations such as BN_nnmod() or BN_mod_exp() it is an error
+to use the same B<BIGNUM> object for the modulus as for the output.
+
 =head1 RETURN VALUES
 
 The BN_mod_sqrt() returns the result (possibly incorrect if I<p> is
index 5dbb5c3cc2d602aa3a28605091d148ad7c82ee82..f88e0e63fafa446b276dc671ea5d59af9ac2554a 100644 (file)
@@ -18,7 +18,11 @@ places the result in B<r> (C<(a*r)%n==1>). If B<r> is NULL,
 a new B<BIGNUM> is created.
 
 B<ctx> is a previously allocated B<BN_CTX> used for temporary
-variables. B<r> may be the same B<BIGNUM> as B<a> or B<n>.
+variables. B<r> may be the same B<BIGNUM> as B<a>.
+
+=head1 NOTES
+
+It is an error to use the same B<BIGNUM> as B<n>.
 
 =head1 RETURN VALUES
 
index 9c0633d7f16c5ddfed8e3a4a17077ce1f0fce595..2ffff10ef1c29f20c365df3bbd3a5ec86eea0939 100644 (file)
@@ -3165,6 +3165,108 @@ err:
     return res;
 }
 
+static int test_mod_inverse(void)
+{
+    int res = 0;
+    char *str = NULL;
+    BIGNUM *a = NULL;
+    BIGNUM *b = NULL;
+    BIGNUM *r = NULL;
+
+    if (!TEST_true(BN_dec2bn(&a, "5193817943")))
+        goto err;
+    if (!TEST_true(BN_dec2bn(&b, "3259122431")))
+        goto err;
+    if (!TEST_ptr(r = BN_new()))
+        goto err;
+    if (!TEST_ptr_eq(BN_mod_inverse(r, a, b, ctx), r))
+        goto err;
+    if (!TEST_ptr_ne(str = BN_bn2dec(r), NULL))
+        goto err;
+    if (!TEST_int_eq(strcmp(str, "2609653924"), 0))
+        goto err;
+
+    /* Note that this aliases the result with the modulus. */
+    if (!TEST_ptr_null(BN_mod_inverse(b, a, b, ctx)))
+        goto err;
+
+    res = 1;
+
+err:
+    BN_free(a);
+    BN_free(b);
+    BN_free(r);
+    OPENSSL_free(str);
+    return res;
+}
+
+static int test_mod_exp_alias(int idx)
+{
+    int res = 0;
+    char *str = NULL;
+    BIGNUM *a = NULL;
+    BIGNUM *b = NULL;
+    BIGNUM *c = NULL;
+    BIGNUM *r = NULL;
+
+    if (!TEST_true(BN_dec2bn(&a, "15")))
+        goto err;
+    if (!TEST_true(BN_dec2bn(&b, "10")))
+        goto err;
+    if (!TEST_true(BN_dec2bn(&c, "39")))
+        goto err;
+    if (!TEST_ptr(r = BN_new()))
+        goto err;
+
+    if (!TEST_int_eq((idx == 0 ? BN_mod_exp_simple
+                               : BN_mod_exp_recp)(r, a, b, c, ctx), 1))
+        goto err;
+    if (!TEST_ptr_ne(str = BN_bn2dec(r), NULL))
+        goto err;
+    if (!TEST_str_eq(str, "36"))
+        goto err;
+
+    OPENSSL_free(str);
+    str = NULL;
+
+    BN_copy(r, b);
+
+    /* Aliasing with exponent must work. */
+    if (!TEST_int_eq((idx == 0 ? BN_mod_exp_simple
+                               : BN_mod_exp_recp)(r, a, r, c, ctx), 1))
+        goto err;
+    if (!TEST_ptr_ne(str = BN_bn2dec(r), NULL))
+        goto err;
+    if (!TEST_str_eq(str, "36"))
+        goto err;
+
+    OPENSSL_free(str);
+    str = NULL;
+
+    /* Aliasing with modulus should return failure for the simple call. */
+    if (idx == 0) {
+        if (!TEST_int_eq(BN_mod_exp_simple(c, a, b, c, ctx), 0))
+            goto err;
+    } else {
+        if (!TEST_int_eq(BN_mod_exp_recp(c, a, b, c, ctx), 1))
+            goto err;
+        if (!TEST_ptr_ne(str = BN_bn2dec(c), NULL))
+            goto err;
+        if (!TEST_str_eq(str, "36"))
+            goto err;
+    }
+
+    res = 1;
+
+err:
+    BN_free(a);
+    BN_free(b);
+    BN_free(c);
+    BN_free(r);
+    OPENSSL_free(str);
+    return res;
+}
+
 static int file_test_run(STANZA *s)
 {
     static const FILETEST filetests[] = {
@@ -3274,6 +3376,8 @@ int setup_tests(void)
         ADD_ALL_TESTS(test_signed_mod_replace_ab, OSSL_NELEM(signed_mod_tests));
         ADD_ALL_TESTS(test_signed_mod_replace_ba, OSSL_NELEM(signed_mod_tests));
         ADD_TEST(test_mod);
+        ADD_TEST(test_mod_inverse);
+        ADD_ALL_TESTS(test_mod_exp_alias, 2);
         ADD_TEST(test_modexp_mont5);
         ADD_TEST(test_kronecker);
         ADD_TEST(test_rand);