corrected fix to PR#2711 and also cover mime_param_cmp
[openssl.git] / crypto / asn1 / d2i_pr.c
index b4e47d48197f5080e9270d3cc5b8a702fdc196d5..28289447772c061736055ee88472c7124e01bf95 100644 (file)
 #include <openssl/bn.h>
 #include <openssl/evp.h>
 #include <openssl/objects.h>
+#ifndef OPENSSL_NO_ENGINE
+#include <openssl/engine.h>
+#endif
+#include <openssl/x509.h>
 #include <openssl/asn1.h>
 #include "asn1_locl.h"
 
@@ -77,26 +81,43 @@ EVP_PKEY *d2i_PrivateKey(int type, EVP_PKEY **a, const unsigned char **pp,
                        return(NULL);
                        }
                }
-       else    ret= *a;
-
-       ret->save_type=type;
-       ret->type=EVP_PKEY_type(type);
-       ret->ameth = EVP_PKEY_asn1_find(type);
-       if (ret->ameth)
+       else
                {
-               if (!ret->ameth->old_priv_decode ||
-                       !ret->ameth->old_priv_decode(ret, pp, length))
+               ret= *a;
+#ifndef OPENSSL_NO_ENGINE
+               if (ret->engine)
                        {
-                       ASN1err(ASN1_F_D2I_PRIVATEKEY,ERR_R_ASN1_LIB);
-                       goto err;
+                       ENGINE_finish(ret->engine);
+                       ret->engine = NULL;
                        }
+#endif
                }
-       else
+
+       if (!EVP_PKEY_set_type(ret, type))
                {
                ASN1err(ASN1_F_D2I_PRIVATEKEY,ASN1_R_UNKNOWN_PUBLIC_KEY_TYPE);
                goto err;
-               /* break; */
                }
+
+       if (!ret->ameth->old_priv_decode ||
+                       !ret->ameth->old_priv_decode(ret, pp, length))
+               {
+               if (ret->ameth->priv_decode) 
+                       {
+                       PKCS8_PRIV_KEY_INFO *p8=NULL;
+                       p8=d2i_PKCS8_PRIV_KEY_INFO(NULL,pp,length);
+                       if (!p8) goto err;
+                       EVP_PKEY_free(ret);
+                       ret = EVP_PKCS82PKEY(p8);
+                       PKCS8_PRIV_KEY_INFO_free(p8);
+
+                       } 
+               else 
+                       {
+                       ASN1err(ASN1_F_D2I_PRIVATEKEY,ERR_R_ASN1_LIB);
+                       goto err;
+                       }
+               }       
        if (a != NULL) (*a)=ret;
        return(ret);
 err:
@@ -117,8 +138,7 @@ EVP_PKEY *d2i_AutoPrivateKey(EVP_PKEY **a, const unsigned char **pp,
         * by analyzing it we can determine the passed structure: this
         * assumes the input is surrounded by an ASN1 SEQUENCE.
         */
-       inkey = d2i_ASN1_SET_OF_ASN1_TYPE(NULL, &p, length, d2i_ASN1_TYPE, 
-                       ASN1_TYPE_free, V_ASN1_SEQUENCE, V_ASN1_UNIVERSAL);
+       inkey = d2i_ASN1_SEQUENCE_ANY(NULL, &p, length);
        /* Since we only need to discern "traditional format" RSA and DSA
         * keys we can just count the elements.
          */
@@ -126,6 +146,24 @@ EVP_PKEY *d2i_AutoPrivateKey(EVP_PKEY **a, const unsigned char **pp,
                keytype = EVP_PKEY_DSA;
        else if (sk_ASN1_TYPE_num(inkey) == 4)
                keytype = EVP_PKEY_EC;
+       else if (sk_ASN1_TYPE_num(inkey) == 3)  
+               { /* This seems to be PKCS8, not traditional format */
+                       PKCS8_PRIV_KEY_INFO *p8 = d2i_PKCS8_PRIV_KEY_INFO(NULL,pp,length);
+                       EVP_PKEY *ret;
+
+                       sk_ASN1_TYPE_pop_free(inkey, ASN1_TYPE_free);
+                       if (!p8) 
+                               {
+                               ASN1err(ASN1_F_D2I_AUTOPRIVATEKEY,ASN1_R_UNSUPPORTED_PUBLIC_KEY_TYPE);
+                               return NULL;
+                               }
+                       ret = EVP_PKCS82PKEY(p8);
+                       PKCS8_PRIV_KEY_INFO_free(p8);
+                       if (a) {
+                               *a = ret;
+                       }       
+                       return ret;
+               }
        else keytype = EVP_PKEY_RSA;
        sk_ASN1_TYPE_pop_free(inkey, ASN1_TYPE_free);
        return d2i_PrivateKey(keytype, a, pp, length);