Improve the performance of d2i_AutoPrivateKey and friends
[openssl.git] / crypto / asn1 / d2i_pr.c
index 720b7fd6c0508cc95b01130c49bbef4e62b6ceda..c49f22b3e090817abaf0a61f145242ec133a4c10 100644 (file)
@@ -22,6 +22,7 @@
 #include "crypto/asn1.h"
 #include "crypto/evp.h"
 #include "internal/asn1.h"
+#include "internal/sizes.h"
 
 static EVP_PKEY *
 d2i_PrivateKey_decoder(int keytype, EVP_PKEY **a, const unsigned char **pp,
@@ -32,8 +33,12 @@ d2i_PrivateKey_decoder(int keytype, EVP_PKEY **a, const unsigned char **pp,
     EVP_PKEY *pkey = NULL, *bak_a = NULL;
     EVP_PKEY **ppkey = &pkey;
     const char *key_name = NULL;
-    const char *input_structures[] = { "type-specific", "PrivateKeyInfo", NULL };
-    int i, ret;
+    char keytypebuf[OSSL_MAX_NAME_SIZE];
+    int ret;
+    const unsigned char *p = *pp;
+    const char *structure;
+    PKCS8_PRIV_KEY_INFO *p8info;
+    const ASN1_OBJECT *algoid;
 
     if (keytype != EVP_PKEY_NONE) {
         key_name = evp_pkey_type2name(keytype);
@@ -41,34 +46,42 @@ d2i_PrivateKey_decoder(int keytype, EVP_PKEY **a, const unsigned char **pp,
             return NULL;
     }
 
-    for (i = 0;  i < (int)OSSL_NELEM(input_structures); ++i) {
-        const unsigned char *p = *pp;
+    /* This is just a probe. It might fail, so we ignore errors */
+    ERR_set_mark();
+    p8info = d2i_PKCS8_PRIV_KEY_INFO(NULL, pp, len);
+    ERR_pop_to_mark();
+    if (p8info != NULL) {
+        if (key_name == NULL
+                && PKCS8_pkey_get0(&algoid, NULL, NULL, NULL, p8info)
+                && OBJ_obj2txt(keytypebuf, sizeof(keytypebuf), algoid, 0))
+            key_name = keytypebuf;
+        structure = "PrivateKeyInfo";
+        PKCS8_PRIV_KEY_INFO_free(p8info);
+    } else {
+        structure = "type-specific";
+    }
+    *pp = p;
 
-        if (a != NULL && (bak_a = *a) != NULL)
-            ppkey = a;
-        dctx = OSSL_DECODER_CTX_new_for_pkey(ppkey, "DER",
-                                             input_structures[i], key_name,
-                                             EVP_PKEY_KEYPAIR, libctx, propq);
+    if (a != NULL && (bak_a = *a) != NULL)
+        ppkey = a;
+    dctx = OSSL_DECODER_CTX_new_for_pkey(ppkey, "DER", structure, key_name,
+                                         EVP_PKEY_KEYPAIR, libctx, propq);
+    if (a != NULL)
+        *a = bak_a;
+    if (dctx == NULL)
+        goto err;
+
+    ret = OSSL_DECODER_from_data(dctx, pp, &len);
+    OSSL_DECODER_CTX_free(dctx);
+    if (ret
+        && *ppkey != NULL
+        && evp_keymgmt_util_has(*ppkey, OSSL_KEYMGMT_SELECT_PRIVATE_KEY)) {
         if (a != NULL)
-            *a = bak_a;
-        if (dctx == NULL)
-            continue;
-
-        ret = OSSL_DECODER_from_data(dctx, pp, &len);
-        OSSL_DECODER_CTX_free(dctx);
-        if (ret) {
-            if (*ppkey != NULL
-                && evp_keymgmt_util_has(*ppkey, OSSL_KEYMGMT_SELECT_PRIVATE_KEY)) {
-                if (a != NULL)
-                    *a = *ppkey;
-                return *ppkey;
-            }
-            *pp = p;
-            goto err;
-        }
+            *a = *ppkey;
+        return *ppkey;
     }
-    /* Fall through to error if all decodes failed */
-err:
+
+ err:
     if (ppkey != a)
         EVP_PKEY_free(*ppkey);
     return NULL;