add separate PSS decode function, rename PSS parameters to RSA_PSS_PARAMS
[openssl.git] / crypto / rsa / rsa_ameth.c
index 649291ef7ec8fea520a31dcc3df8cfc3776d8fb9..e25240d3f78119e8cbc07eb8497e740dd346934e 100644 (file)
@@ -265,14 +265,48 @@ static int rsa_priv_print(BIO *bp, const EVP_PKEY *pkey, int indent,
        return do_rsa_print(bp, pkey->pkey.rsa, indent, 1);
        }
 
-static int rsa_pss_param_print(BIO *bp, RSASSA_PSS_PARAMS *pss, int indent)
+static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg,
+                                       X509_ALGOR **pmaskHash)
+       {
+       const unsigned char *p;
+       int plen;
+       RSA_PSS_PARAMS *pss;
+
+       *pmaskHash = NULL;
+
+       if (!alg->parameter || alg->parameter->type != V_ASN1_SEQUENCE)
+               return NULL;
+       p = alg->parameter->value.sequence->data;
+       plen = alg->parameter->value.sequence->length;
+       pss = d2i_RSA_PSS_PARAMS(NULL, &p, plen);
+
+       if (!pss)
+               return NULL;
+       
+       if (pss->maskGenAlgorithm)
+               {
+               ASN1_TYPE *param = pss->maskGenAlgorithm->parameter;
+               if (OBJ_obj2nid(pss->maskGenAlgorithm->algorithm) == NID_mgf1
+                       && param->type == V_ASN1_SEQUENCE)
+                       {
+                       p = param->value.sequence->data;
+                       plen = param->value.sequence->length;
+                       *pmaskHash = d2i_X509_ALGOR(NULL, &p, plen);
+                       }
+               }
+
+       return pss;
+       }
+
+static int rsa_pss_param_print(BIO *bp, RSA_PSS_PARAMS *pss, 
+                               X509_ALGOR *maskHash, int indent)
        {
        int rv = 0;
-       X509_ALGOR *maskHash = NULL;
        if (!pss)
                {
                if (BIO_puts(bp, " (INVALID PSS PARAMETERS)\n") <= 0)
                        return 0;
+               return 1;
                }
        if (BIO_puts(bp, "\n") <= 0)
                goto err;
@@ -299,18 +333,16 @@ static int rsa_pss_param_print(BIO *bp, RSASSA_PSS_PARAMS *pss, int indent)
                        goto err;
        if (pss->maskGenAlgorithm)
                {
-               ASN1_TYPE *param = pss->maskGenAlgorithm->parameter;
-               if (param->type == V_ASN1_SEQUENCE)
-                       {
-                       const unsigned char *p = param->value.sequence->data;
-                       int plen = param->value.sequence->length;
-                       maskHash = d2i_X509_ALGOR(NULL, &p, plen);
-                       }
                if (i2a_ASN1_OBJECT(bp, pss->maskGenAlgorithm->algorithm) <= 0)
                        goto err;
                if (BIO_puts(bp, " with ") <= 0)
                        goto err;
-               if (i2a_ASN1_OBJECT(bp, maskHash->algorithm) <= 0)
+               if (maskHash)
+                       {
+                       if (i2a_ASN1_OBJECT(bp, maskHash->algorithm) <= 0)
+                       goto err;
+                       }
+               else if (BIO_puts(bp, "INVALID") <= 0)
                        goto err;
                }
        else if (BIO_puts(bp, "mgf1 with sha1 (default)") <= 0)
@@ -346,9 +378,6 @@ static int rsa_pss_param_print(BIO *bp, RSASSA_PSS_PARAMS *pss, int indent)
        rv = 1;
 
        err:
-       if (maskHash)
-               X509_ALGOR_free(maskHash);
-       RSASSA_PSS_PARAMS_free(pss);
        return rv;
 
        }
@@ -359,15 +388,16 @@ static int rsa_sig_print(BIO *bp, const X509_ALGOR *sigalg,
        {
        if (OBJ_obj2nid(sigalg->algorithm) == NID_rsassaPss)
                {
-               RSASSA_PSS_PARAMS *pss = NULL;
-               ASN1_TYPE *param = sigalg->parameter;
-               if (param && param->type == V_ASN1_SEQUENCE)
-                       {
-                       const unsigned char *p = param->value.sequence->data;
-                       int plen = param->value.sequence->length;
-                       pss = d2i_RSASSA_PSS_PARAMS(NULL, &p, plen);
-                       }
-               if (!rsa_pss_param_print(bp, pss, indent))
+               int rv;
+               RSA_PSS_PARAMS *pss;
+               X509_ALGOR *maskHash;
+               pss = rsa_pss_decode(sigalg, &maskHash);
+               rv = rsa_pss_param_print(bp, pss, maskHash, indent);
+               if (pss)
+                       RSA_PSS_PARAMS_free(pss);
+               if (maskHash)
+                       X509_ALGOR_free(maskHash);
+               if (!rv)
                        return 0;
                }