`EVP_PKEY_CTX_dup` segmentation fault fix
[openssl.git] / crypto / evp / pmeth_lib.c
index 954166caae0168b1d4504d59e62931308d813567..caf10b2d5c968d2e3ab003696a50af5f9ecae4bc 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2006-2021 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2006-2022 The OpenSSL Project Authors. All Rights Reserved.
  *
  * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
@@ -128,10 +128,8 @@ EVP_PKEY_METHOD *EVP_PKEY_meth_new(int id, int flags)
     EVP_PKEY_METHOD *pmeth;
 
     pmeth = OPENSSL_zalloc(sizeof(*pmeth));
-    if (pmeth == NULL) {
-        ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
+    if (pmeth == NULL)
         return NULL;
-    }
 
     pmeth->pkey_id = id;
     pmeth->flags = flags | EVP_PKEY_FLAG_DYNAMIC;
@@ -265,7 +263,20 @@ static EVP_PKEY_CTX *int_ctx_new(OSSL_LIB_CTX *libctx,
      * fetching a provider implementation.
      */
     if (e == NULL && app_pmeth == NULL && keytype != NULL) {
-        keymgmt = EVP_KEYMGMT_fetch(libctx, keytype, propquery);
+        /*
+         * If |pkey| is given and is provided, we take a reference to its
+         * keymgmt.  Otherwise, we fetch one for the keytype we got. This
+         * is to ensure that operation init functions can access what they
+         * need through this single pointer.
+         */
+        if (pkey != NULL && pkey->keymgmt != NULL) {
+            if (!EVP_KEYMGMT_up_ref(pkey->keymgmt))
+                ERR_raise(ERR_LIB_EVP, EVP_R_INITIALIZATION_ERROR);
+            else
+                keymgmt = pkey->keymgmt;
+        } else {
+            keymgmt = EVP_KEYMGMT_fetch(libctx, keytype, propquery);
+        }
         if (keymgmt == NULL)
             return NULL;   /* EVP_KEYMGMT_fetch() recorded an error */
 
@@ -303,8 +314,6 @@ static EVP_PKEY_CTX *int_ctx_new(OSSL_LIB_CTX *libctx,
         ERR_raise(ERR_LIB_EVP, EVP_R_UNSUPPORTED_ALGORITHM);
     } else {
         ret = OPENSSL_zalloc(sizeof(*ret));
-        if (ret == NULL)
-            ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
     }
 
 #if !defined(OPENSSL_NO_ENGINE) && !defined(FIPS_MODULE)
@@ -468,10 +477,8 @@ EVP_PKEY_CTX *EVP_PKEY_CTX_dup(const EVP_PKEY_CTX *pctx)
     }
 # endif
     rctx = OPENSSL_zalloc(sizeof(*rctx));
-    if (rctx == NULL) {
-        ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
+    if (rctx == NULL)
         return NULL;
-    }
 
     if (pctx->pkey != NULL)
         EVP_PKEY_up_ref(pctx->pkey);
@@ -496,10 +503,14 @@ EVP_PKEY_CTX *EVP_PKEY_CTX_dup(const EVP_PKEY_CTX *pctx)
         if (pctx->op.kex.algctx != NULL) {
             if (!ossl_assert(pctx->op.kex.exchange != NULL))
                 goto err;
-            rctx->op.kex.algctx
-                = pctx->op.kex.exchange->dupctx(pctx->op.kex.algctx);
+
+            if (pctx->op.kex.exchange->dupctx != NULL)
+                rctx->op.kex.algctx
+                    = pctx->op.kex.exchange->dupctx(pctx->op.kex.algctx);
+
             if (rctx->op.kex.algctx == NULL) {
                 EVP_KEYEXCH_free(rctx->op.kex.exchange);
+                rctx->op.kex.exchange = NULL;
                 goto err;
             }
             return rctx;
@@ -513,10 +524,14 @@ EVP_PKEY_CTX *EVP_PKEY_CTX_dup(const EVP_PKEY_CTX *pctx)
         if (pctx->op.sig.algctx != NULL) {
             if (!ossl_assert(pctx->op.sig.signature != NULL))
                 goto err;
-            rctx->op.sig.algctx
-                = pctx->op.sig.signature->dupctx(pctx->op.sig.algctx);
+
+            if (pctx->op.sig.signature->dupctx != NULL)
+                rctx->op.sig.algctx
+                    = pctx->op.sig.signature->dupctx(pctx->op.sig.algctx);
+
             if (rctx->op.sig.algctx == NULL) {
                 EVP_SIGNATURE_free(rctx->op.sig.signature);
+                rctx->op.sig.signature = NULL;
                 goto err;
             }
             return rctx;
@@ -530,10 +545,14 @@ EVP_PKEY_CTX *EVP_PKEY_CTX_dup(const EVP_PKEY_CTX *pctx)
         if (pctx->op.ciph.algctx != NULL) {
             if (!ossl_assert(pctx->op.ciph.cipher != NULL))
                 goto err;
-            rctx->op.ciph.algctx
-                = pctx->op.ciph.cipher->dupctx(pctx->op.ciph.algctx);
+
+            if (pctx->op.ciph.cipher->dupctx != NULL)
+                rctx->op.ciph.algctx
+                    = pctx->op.ciph.cipher->dupctx(pctx->op.ciph.algctx);
+
             if (rctx->op.ciph.algctx == NULL) {
                 EVP_ASYM_CIPHER_free(rctx->op.ciph.cipher);
+                rctx->op.ciph.cipher = NULL;
                 goto err;
             }
             return rctx;
@@ -547,10 +566,14 @@ EVP_PKEY_CTX *EVP_PKEY_CTX_dup(const EVP_PKEY_CTX *pctx)
         if (pctx->op.encap.algctx != NULL) {
             if (!ossl_assert(pctx->op.encap.kem != NULL))
                 goto err;
-            rctx->op.encap.algctx
-                = pctx->op.encap.kem->dupctx(pctx->op.encap.algctx);
+
+            if (pctx->op.encap.kem->dupctx != NULL)
+                rctx->op.encap.algctx
+                    = pctx->op.encap.kem->dupctx(pctx->op.encap.algctx);
+
             if (rctx->op.encap.algctx == NULL) {
                 EVP_KEM_free(rctx->op.encap.kem);
+                rctx->op.encap.kem = NULL;
                 goto err;
             }
             return rctx;
@@ -597,13 +620,13 @@ int EVP_PKEY_meth_add0(const EVP_PKEY_METHOD *pmeth)
 {
     if (app_pkey_methods == NULL) {
         app_pkey_methods = sk_EVP_PKEY_METHOD_new(pmeth_cmp);
-        if (app_pkey_methods == NULL){
-            ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
+        if (app_pkey_methods == NULL) {
+            ERR_raise(ERR_LIB_EVP, ERR_R_CRYPTO_LIB);
             return 0;
         }
     }
     if (!sk_EVP_PKEY_METHOD_push(app_pkey_methods, pmeth)) {
-        ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
+        ERR_raise(ERR_LIB_EVP, ERR_R_CRYPTO_LIB);
         return 0;
     }
     sk_EVP_PKEY_METHOD_sort(app_pkey_methods);
@@ -849,7 +872,7 @@ int evp_pkey_ctx_set_params_strict(EVP_PKEY_CTX *ctx, OSSL_PARAM *params)
 
         for (p = params; p->key != NULL; p++) {
             /* Check the ctx actually understands this parameter */
-            if (OSSL_PARAM_locate_const(settable, p->key) == NULL )
+            if (OSSL_PARAM_locate_const(settable, p->key) == NULL)
                 return -2;
         }
     }
@@ -872,9 +895,9 @@ int evp_pkey_ctx_get_params_strict(EVP_PKEY_CTX *ctx, OSSL_PARAM *params)
         const OSSL_PARAM *gettable = EVP_PKEY_CTX_gettable_params(ctx);
         const OSSL_PARAM *p;
 
-        for (p = params; p->key != NULL; p++ ) {
+        for (p = params; p->key != NULL; p++) {
             /* Check the ctx actually understands this parameter */
-            if (OSSL_PARAM_locate_const(gettable, p->key) == NULL )
+            if (OSSL_PARAM_locate_const(gettable, p->key) == NULL)
                 return -2;
         }
     }
@@ -1468,17 +1491,13 @@ static int evp_pkey_ctx_store_cached_data(EVP_PKEY_CTX *ctx,
         evp_pkey_ctx_free_cached_data(ctx, cmd, name);
         if (name != NULL) {
             ctx->cached_parameters.dist_id_name = OPENSSL_strdup(name);
-            if (ctx->cached_parameters.dist_id_name == NULL) {
-                ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
+            if (ctx->cached_parameters.dist_id_name == NULL)
                 return 0;
-            }
         }
         if (data_len > 0) {
             ctx->cached_parameters.dist_id = OPENSSL_memdup(data, data_len);
-            if (ctx->cached_parameters.dist_id == NULL) {
-                ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
+            if (ctx->cached_parameters.dist_id == NULL)
                 return 0;
-            }
         }
         ctx->cached_parameters.dist_id_set = 1;
         ctx->cached_parameters.dist_id_len = data_len;