Fix dh dupctx refcount error
authorslontis <shane.lontis@oracle.com>
Thu, 2 Sep 2021 06:49:37 +0000 (16:49 +1000)
committerTomas Mraz <tomas@openssl.org>
Fri, 3 Sep 2021 10:31:59 +0000 (12:31 +0200)
Reviewed-by: Paul Dale <pauli@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/16495)

providers/implementations/exchange/dh_exch.c
test/evp_test.c

index 1dffc8d1126c5b97ea387e5c527de7bf4f9786d2..ea05b3177e89b2cb2c3cc862efa6617c263b93c3 100644 (file)
@@ -238,7 +238,6 @@ static int dh_derive(void *vpdhctx, unsigned char *secret,
     return 0;
 }
 
-
 static void dh_freectx(void *vpdhctx)
 {
     PROV_DH_CTX *pdhctx = (PROV_DH_CTX *)vpdhctx;
@@ -271,12 +270,12 @@ static void *dh_dupctx(void *vpdhctx)
     dstctx->kdf_ukm = NULL;
     dstctx->kdf_cekalg = NULL;
 
-    if (dstctx->dh != NULL && !DH_up_ref(srcctx->dh))
+    if (srcctx->dh != NULL && !DH_up_ref(srcctx->dh))
         goto err;
     else
         dstctx->dh = srcctx->dh;
 
-    if (dstctx->dhpeer != NULL && !DH_up_ref(srcctx->dhpeer))
+    if (srcctx->dhpeer != NULL && !DH_up_ref(srcctx->dhpeer))
         goto err;
     else
         dstctx->dhpeer = srcctx->dhpeer;
index 075abc5ad93aa6e8a67889ffcf1846449c5f0963..eda8c827f901dac06feaac70e39e8a4a4bfc9d9b 100644 (file)
@@ -1848,11 +1848,17 @@ static int pderive_test_parse(EVP_TEST *t,
 
 static int pderive_test_run(EVP_TEST *t)
 {
+    EVP_PKEY_CTX *dctx = NULL;
     PKEY_DATA *expected = t->data;
     unsigned char *got = NULL;
     size_t got_len;
 
-    if (EVP_PKEY_derive(expected->ctx, NULL, &got_len) <= 0) {
+    if (!TEST_ptr(dctx = EVP_PKEY_CTX_dup(expected->ctx))) {
+        t->err = "DERIVE_ERROR";
+        goto err;
+    }
+
+    if (EVP_PKEY_derive(dctx, NULL, &got_len) <= 0) {
         t->err = "DERIVE_ERROR";
         goto err;
     }
@@ -1860,7 +1866,7 @@ static int pderive_test_run(EVP_TEST *t)
         t->err = "DERIVE_ERROR";
         goto err;
     }
-    if (EVP_PKEY_derive(expected->ctx, got, &got_len) <= 0) {
+    if (EVP_PKEY_derive(dctx, got, &got_len) <= 0) {
         t->err = "DERIVE_ERROR";
         goto err;
     }
@@ -1872,6 +1878,7 @@ static int pderive_test_run(EVP_TEST *t)
     t->err = NULL;
  err:
     OPENSSL_free(got);
+    EVP_PKEY_CTX_free(dctx);
     return 1;
 }