fix memory allocation and reference counting issues
[openssl.git] / crypto / rsa / rsa_lib.c
index 70eaa59a8b386df0a1d96bfbd0754e0b86f24e65..1601e92ddb006a9dc02a5def4671b28bb94e8c25 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 1995-2021 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 1995-2022 The OpenSSL Project Authors. All Rights Reserved.
  *
  * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
@@ -15,7 +15,9 @@
 
 #include <openssl/crypto.h>
 #include <openssl/core_names.h>
-#include <openssl/engine.h>
+#ifndef FIPS_MODULE
+# include <openssl/engine.h>
+#endif
 #include <openssl/evp.h>
 #include <openssl/param_build.h>
 #include "internal/cryptlib.h"
@@ -74,15 +76,18 @@ static RSA *rsa_new_intern(ENGINE *engine, OSSL_LIB_CTX *libctx)
 {
     RSA *ret = OPENSSL_zalloc(sizeof(*ret));
 
-    if (ret == NULL) {
-        ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
+    if (ret == NULL)
         return NULL;
-    }
 
-    ret->references = 1;
     ret->lock = CRYPTO_THREAD_lock_new();
     if (ret->lock == NULL) {
-        ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
+        ERR_raise(ERR_LIB_RSA, ERR_R_CRYPTO_LIB);
+        OPENSSL_free(ret);
+        return NULL;
+    }
+
+    if (!CRYPTO_NEW_REF(&ret->references, 1)) {
+        CRYPTO_THREAD_lock_free(ret->lock);
         OPENSSL_free(ret);
         return NULL;
     }
@@ -135,7 +140,7 @@ void RSA_free(RSA *r)
     if (r == NULL)
         return;
 
-    CRYPTO_DOWN_REF(&r->references, &i, r->lock);
+    CRYPTO_DOWN_REF(&r->references, &i);
     REF_PRINT_COUNT("RSA", r);
     if (i > 0)
         return;
@@ -152,6 +157,7 @@ void RSA_free(RSA *r)
 #endif
 
     CRYPTO_THREAD_lock_free(r->lock);
+    CRYPTO_FREE_REF(&r->references);
 
     BN_free(r->n);
     BN_free(r->e);
@@ -179,7 +185,7 @@ int RSA_up_ref(RSA *r)
 {
     int i;
 
-    if (CRYPTO_UP_REF(&r->references, &i, r->lock) <= 0)
+    if (CRYPTO_UP_REF(&r->references, &i) <= 0)
         return 0;
 
     REF_PRINT_COUNT("RSA", r);
@@ -785,10 +791,8 @@ int ossl_rsa_set0_all_params(RSA *r, const STACK_OF(BIGNUM) *primes,
                 goto err;
 
             /* Using ossl_rsa_multip_info_new() is wasteful, so allocate directly */
-            if ((pinfo = OPENSSL_zalloc(sizeof(*pinfo))) == NULL) {
-                ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
+            if ((pinfo = OPENSSL_zalloc(sizeof(*pinfo))) == NULL)
                 goto err;
-            }
 
             pinfo->r = prime;
             pinfo->d = exp;
@@ -1082,6 +1086,7 @@ int EVP_PKEY_CTX_get_rsa_mgf1_md(EVP_PKEY_CTX *ctx, const EVP_MD **md)
 int EVP_PKEY_CTX_set0_rsa_oaep_label(EVP_PKEY_CTX *ctx, void *label, int llen)
 {
     OSSL_PARAM rsa_params[2], *p = rsa_params;
+    int ret;
 
     if (ctx == NULL || !EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)) {
         ERR_raise(ERR_LIB_EVP, EVP_R_COMMAND_NOT_SUPPORTED);
@@ -1098,10 +1103,11 @@ int EVP_PKEY_CTX_set0_rsa_oaep_label(EVP_PKEY_CTX *ctx, void *label, int llen)
                                              (void *)label, (size_t)llen);
     *p++ = OSSL_PARAM_construct_end();
 
-    if (!evp_pkey_ctx_set_params_strict(ctx, rsa_params))
-        return 0;
+    ret = evp_pkey_ctx_set_params_strict(ctx, rsa_params);
+    if (ret <= 0)
+        return ret;
 
-    /* Ownership is supposed to be transfered to the callee. */
+    /* Ownership is supposed to be transferred to the callee. */
     OPENSSL_free(label);
     return 1;
 }
@@ -1242,8 +1248,11 @@ int EVP_PKEY_CTX_set1_rsa_keygen_pubexp(EVP_PKEY_CTX *ctx, BIGNUM *pubexp)
      * When we're dealing with a provider, there's no need to duplicate
      * pubexp, as it gets copied when transforming to an OSSL_PARAM anyway.
      */
-    if (evp_pkey_ctx_is_legacy(ctx))
+    if (evp_pkey_ctx_is_legacy(ctx)) {
         pubexp = BN_dup(pubexp);
+        if (pubexp == NULL)
+            return 0;
+    }
     ret = EVP_PKEY_CTX_ctrl(ctx, EVP_PKEY_RSA, EVP_PKEY_OP_KEYGEN,
                             EVP_PKEY_CTRL_RSA_KEYGEN_PUBEXP, 0, pubexp);
     if (evp_pkey_ctx_is_legacy(ctx) && ret <= 0)