Add evp_test fixes.
[openssl.git] / crypto / sm2 / sm2_crypt.c
index c09e4c001b8963e486157cc4329fb96ee2885109..0ae67fb22b68345d632557648578cd8ade5c7d66 100644 (file)
@@ -138,6 +138,9 @@ int sm2_encrypt(const EC_KEY *key,
     uint8_t *C3 = NULL;
     size_t field_size;
     const int C3_size = EVP_MD_size(digest);
+    EVP_MD *fetched_digest = NULL;
+    OPENSSL_CTX *libctx = ec_key_get_libctx(key);
+    const char *propq = ec_key_get0_propq(key);
 
     /* NULL these before any "goto done" */
     ctext_struct.C2 = NULL;
@@ -156,7 +159,7 @@ int sm2_encrypt(const EC_KEY *key,
 
     kG = EC_POINT_new(group);
     kP = EC_POINT_new(group);
-    ctx = BN_CTX_new();
+    ctx = BN_CTX_new_ex(libctx);
     if (kG == NULL || kP == NULL || ctx == NULL) {
         SM2err(SM2_F_SM2_ENCRYPT, ERR_R_MALLOC_FAILURE);
         goto done;
@@ -211,7 +214,7 @@ int sm2_encrypt(const EC_KEY *key,
 
     /* X9.63 with no salt happens to match the KDF used in SM2 */
     if (!ecdh_KDF_X9_63(msg_mask, msg_len, x2y2, 2 * field_size, NULL, 0,
-                        digest)) {
+                        digest, libctx, propq)) {
         SM2err(SM2_F_SM2_ENCRYPT, ERR_R_EVP_LIB);
         goto done;
     }
@@ -219,7 +222,12 @@ int sm2_encrypt(const EC_KEY *key,
     for (i = 0; i != msg_len; ++i)
         msg_mask[i] ^= msg[i];
 
-    if (EVP_DigestInit(hash, digest) == 0
+    fetched_digest = EVP_MD_fetch(libctx, EVP_MD_name(digest), propq);
+    if (fetched_digest == NULL) {
+        SM2err(SM2_F_SM2_ENCRYPT, ERR_R_INTERNAL_ERROR);
+        goto done;
+    }
+    if (EVP_DigestInit(hash, fetched_digest) == 0
             || EVP_DigestUpdate(hash, x2y2, field_size) == 0
             || EVP_DigestUpdate(hash, msg, msg_len) == 0
             || EVP_DigestUpdate(hash, x2y2 + field_size, field_size) == 0
@@ -254,6 +262,7 @@ int sm2_encrypt(const EC_KEY *key,
     rc = 1;
 
  done:
+    EVP_MD_free(fetched_digest);
     ASN1_OCTET_STRING_free(ctext_struct.C2);
     ASN1_OCTET_STRING_free(ctext_struct.C3);
     OPENSSL_free(msg_mask);
@@ -288,6 +297,8 @@ int sm2_decrypt(const EC_KEY *key,
     const uint8_t *C3 = NULL;
     int msg_len = 0;
     EVP_MD_CTX *hash = NULL;
+    OPENSSL_CTX *libctx = ec_key_get_libctx(key);
+    const char *propq = ec_key_get0_propq(key);
 
     if (field_size == 0 || hash_size <= 0)
        goto done;
@@ -310,7 +321,7 @@ int sm2_decrypt(const EC_KEY *key,
     C3 = sm2_ctext->C3->data;
     msg_len = sm2_ctext->C2->length;
 
-    ctx = BN_CTX_new();
+    ctx = BN_CTX_new_ex(libctx);
     if (ctx == NULL) {
         SM2err(SM2_F_SM2_DECRYPT, ERR_R_MALLOC_FAILURE);
         goto done;
@@ -352,7 +363,7 @@ int sm2_decrypt(const EC_KEY *key,
     if (BN_bn2binpad(x2, x2y2, field_size) < 0
             || BN_bn2binpad(y2, x2y2 + field_size, field_size) < 0
             || !ecdh_KDF_X9_63(msg_mask, msg_len, x2y2, 2 * field_size, NULL, 0,
-                               digest)) {
+                               digest, libctx, propq)) {
         SM2err(SM2_F_SM2_DECRYPT, ERR_R_INTERNAL_ERROR);
         goto done;
     }