Cache maskHash parameter
authorDr. Stephen Henson <steve@openssl.org>
Thu, 24 Nov 2016 18:51:54 +0000 (18:51 +0000)
committerDr. Stephen Henson <steve@openssl.org>
Sun, 8 Jan 2017 01:42:46 +0000 (01:42 +0000)
Store hash algorithm used for MGF1 masks in PSS and OAEP modes in PSS and
OAEP parameter structure: this avoids the need to decode part of the ASN.1
structure every time it is used.

Reviewed-by: Rich Salz <rsalz@openssl.org>
Reviewed-by: Matt Caswell <matt@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/2177)

crypto/rsa/rsa_ameth.c
crypto/rsa/rsa_asn1.c
include/openssl/rsa.h

index 5694140..6a7a088 100644 (file)
@@ -188,37 +188,36 @@ static int rsa_priv_print(BIO *bp, const EVP_PKEY *pkey, int indent,
     return do_rsa_print(bp, pkey->pkey.rsa, indent, 1);
 }
 
-/* Given an MGF1 Algorithm ID decode to an Algorithm Identifier */
 static X509_ALGOR *rsa_mgf1_decode(X509_ALGOR *alg)
 {
-    if (alg == NULL)
-        return NULL;
     if (OBJ_obj2nid(alg->algorithm) != NID_mgf1)
         return NULL;
     return ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(X509_ALGOR),
                                      alg->parameter);
 }
 
-static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg,
-                                      X509_ALGOR **pmaskHash)
+static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg)
 {
     RSA_PSS_PARAMS *pss;
 
-    *pmaskHash = NULL;
-
     pss = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_PSS_PARAMS),
                                     alg->parameter);
 
-    if (!pss)
+    if (pss == NULL)
         return NULL;
 
-    *pmaskHash = rsa_mgf1_decode(pss->maskGenAlgorithm);
+    if (pss->maskGenAlgorithm != NULL) {
+        pss->maskHash = rsa_mgf1_decode(pss->maskGenAlgorithm);
+        if (pss->maskHash == NULL) {
+            RSA_PSS_PARAMS_free(pss);
+            return NULL;
+        }
+    }
 
     return pss;
 }
 
-static int rsa_pss_param_print(BIO *bp, RSA_PSS_PARAMS *pss,
-                               X509_ALGOR *maskHash, int indent)
+static int rsa_pss_param_print(BIO *bp, RSA_PSS_PARAMS *pss, int indent)
 {
     int rv = 0;
     if (!pss) {
@@ -252,8 +251,8 @@ static int rsa_pss_param_print(BIO *bp, RSA_PSS_PARAMS *pss,
             goto err;
         if (BIO_puts(bp, " with ") <= 0)
             goto err;
-        if (maskHash) {
-            if (i2a_ASN1_OBJECT(bp, maskHash->algorithm) <= 0)
+        if (pss->maskHash) {
+            if (i2a_ASN1_OBJECT(bp, pss->maskHash->algorithm) <= 0)
                 goto err;
         } else if (BIO_puts(bp, "INVALID") <= 0)
             goto err;
@@ -296,11 +295,9 @@ static int rsa_sig_print(BIO *bp, const X509_ALGOR *sigalg,
     if (OBJ_obj2nid(sigalg->algorithm) == NID_rsassaPss) {
         int rv;
         RSA_PSS_PARAMS *pss;
-        X509_ALGOR *maskHash;
-        pss = rsa_pss_decode(sigalg, &maskHash);
-        rv = rsa_pss_param_print(bp, pss, maskHash, indent);
+        pss = rsa_pss_decode(sigalg);
+        rv = rsa_pss_param_print(bp, pss, indent);
         RSA_PSS_PARAMS_free(pss);
-        X509_ALGOR_free(maskHash);
         if (!rv)
             return 0;
     } else if (!sig && BIO_puts(bp, "\n") <= 0)
@@ -410,29 +407,6 @@ static const EVP_MD *rsa_algor_to_md(X509_ALGOR *alg)
     return md;
 }
 
-/* convert MGF1 algorithm ID to EVP_MD, default SHA1 */
-static const EVP_MD *rsa_mgf1_to_md(X509_ALGOR *alg, X509_ALGOR *maskHash)
-{
-    const EVP_MD *md;
-    if (!alg)
-        return EVP_sha1();
-    /* Check mask and lookup mask hash algorithm */
-    if (OBJ_obj2nid(alg->algorithm) != NID_mgf1) {
-        RSAerr(RSA_F_RSA_MGF1_TO_MD, RSA_R_UNSUPPORTED_MASK_ALGORITHM);
-        return NULL;
-    }
-    if (!maskHash) {
-        RSAerr(RSA_F_RSA_MGF1_TO_MD, RSA_R_UNSUPPORTED_MASK_PARAMETER);
-        return NULL;
-    }
-    md = EVP_get_digestbyobj(maskHash->algorithm);
-    if (md == NULL) {
-        RSAerr(RSA_F_RSA_MGF1_TO_MD, RSA_R_UNKNOWN_MASK_DIGEST);
-        return NULL;
-    }
-    return md;
-}
-
 /*
  * Convert EVP_PKEY_CTX is PSS mode into corresponding algorithm parameter,
  * suitable for setting an AlgorithmIdentifier.
@@ -497,20 +471,19 @@ static int rsa_pss_to_ctx(EVP_MD_CTX *ctx, EVP_PKEY_CTX *pkctx,
     int saltlen;
     const EVP_MD *mgf1md = NULL, *md = NULL;
     RSA_PSS_PARAMS *pss;
-    X509_ALGOR *maskHash;
     /* Sanity check: make sure it is PSS */
     if (OBJ_obj2nid(sigalg->algorithm) != NID_rsassaPss) {
         RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_UNSUPPORTED_SIGNATURE_TYPE);
         return -1;
     }
     /* Decode PSS parameters */
-    pss = rsa_pss_decode(sigalg, &maskHash);
+    pss = rsa_pss_decode(sigalg);
 
     if (pss == NULL) {
         RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_INVALID_PSS_PARAMETERS);
         goto err;
     }
-    mgf1md = rsa_mgf1_to_md(pss->maskGenAlgorithm, maskHash);
+    mgf1md = rsa_algor_to_md(pss->maskHash);
     if (!mgf1md)
         goto err;
     md = rsa_algor_to_md(pss->hashAlgorithm);
@@ -568,7 +541,6 @@ static int rsa_pss_to_ctx(EVP_MD_CTX *ctx, EVP_PKEY_CTX *pkctx,
 
  err:
     RSA_PSS_PARAMS_free(pss);
-    X509_ALGOR_free(maskHash);
     return rv;
 }
 
@@ -674,22 +646,24 @@ static int rsa_item_sign(EVP_MD_CTX *ctx, const ASN1_ITEM *it, void *asn,
 }
 
 #ifndef OPENSSL_NO_CMS
-static RSA_OAEP_PARAMS *rsa_oaep_decode(const X509_ALGOR *alg,
-                                        X509_ALGOR **pmaskHash)
+static RSA_OAEP_PARAMS *rsa_oaep_decode(const X509_ALGOR *alg)
 {
-    RSA_OAEP_PARAMS *pss;
-
-    *pmaskHash = NULL;
+    RSA_OAEP_PARAMS *oaep;
 
-    pss = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_OAEP_PARAMS),
+    oaep = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_OAEP_PARAMS),
                                     alg->parameter);
 
-    if (!pss)
+    if (oaep == NULL)
         return NULL;
 
-    *pmaskHash = rsa_mgf1_decode(pss->maskGenFunc);
-
-    return pss;
+    if (oaep->maskGenFunc != NULL) {
+        oaep->maskHash = rsa_mgf1_decode(oaep->maskGenFunc);
+        if (oaep->maskHash == NULL) {
+            RSA_OAEP_PARAMS_free(oaep);
+            return NULL;
+        }
+    }
+    return oaep;
 }
 
 static int rsa_cms_decrypt(CMS_RecipientInfo *ri)
@@ -702,7 +676,6 @@ static int rsa_cms_decrypt(CMS_RecipientInfo *ri)
     int labellen = 0;
     const EVP_MD *mgf1md = NULL, *md = NULL;
     RSA_OAEP_PARAMS *oaep;
-    X509_ALGOR *maskHash;
     pkctx = CMS_RecipientInfo_get0_pkey_ctx(ri);
     if (!pkctx)
         return 0;
@@ -716,14 +689,14 @@ static int rsa_cms_decrypt(CMS_RecipientInfo *ri)
         return -1;
     }
     /* Decode OAEP parameters */
-    oaep = rsa_oaep_decode(cmsalg, &maskHash);
+    oaep = rsa_oaep_decode(cmsalg);
 
     if (oaep == NULL) {
         RSAerr(RSA_F_RSA_CMS_DECRYPT, RSA_R_INVALID_OAEP_PARAMETERS);
         goto err;
     }
 
-    mgf1md = rsa_mgf1_to_md(oaep->maskGenFunc, maskHash);
+    mgf1md = rsa_algor_to_md(oaep->maskHash);
     if (!mgf1md)
         goto err;
     md = rsa_algor_to_md(oaep->hashFunc);
@@ -760,7 +733,6 @@ static int rsa_cms_decrypt(CMS_RecipientInfo *ri)
 
  err:
     RSA_OAEP_PARAMS_free(oaep);
-    X509_ALGOR_free(maskHash);
     return rv;
 }
 
index 20f8ebf..626a479 100644 (file)
@@ -49,20 +49,42 @@ ASN1_SEQUENCE_cb(RSAPublicKey, rsa_cb) = {
         ASN1_SIMPLE(RSA, e, BIGNUM),
 } ASN1_SEQUENCE_END_cb(RSA, RSAPublicKey)
 
-ASN1_SEQUENCE(RSA_PSS_PARAMS) = {
+/* Free up maskHash */
+static int rsa_pss_cb(int operation, ASN1_VALUE **pval, const ASN1_ITEM *it,
+                      void *exarg)
+{
+    if (operation == ASN1_OP_FREE_PRE) {
+        RSA_PSS_PARAMS *pss = (RSA_PSS_PARAMS *)*pval;
+        X509_ALGOR_free(pss->maskHash);
+    }
+    return 1;
+}
+
+ASN1_SEQUENCE_cb(RSA_PSS_PARAMS, rsa_pss_cb) = {
         ASN1_EXP_OPT(RSA_PSS_PARAMS, hashAlgorithm, X509_ALGOR,0),
         ASN1_EXP_OPT(RSA_PSS_PARAMS, maskGenAlgorithm, X509_ALGOR,1),
         ASN1_EXP_OPT(RSA_PSS_PARAMS, saltLength, ASN1_INTEGER,2),
         ASN1_EXP_OPT(RSA_PSS_PARAMS, trailerField, ASN1_INTEGER,3)
-} ASN1_SEQUENCE_END(RSA_PSS_PARAMS)
+} ASN1_SEQUENCE_END_cb(RSA_PSS_PARAMS, RSA_PSS_PARAMS)
 
 IMPLEMENT_ASN1_FUNCTIONS(RSA_PSS_PARAMS)
 
-ASN1_SEQUENCE(RSA_OAEP_PARAMS) = {
+/* Free up maskHash */
+static int rsa_oaep_cb(int operation, ASN1_VALUE **pval, const ASN1_ITEM *it,
+                       void *exarg)
+{
+    if (operation == ASN1_OP_FREE_PRE) {
+        RSA_OAEP_PARAMS *oaep = (RSA_OAEP_PARAMS *)*pval;
+        X509_ALGOR_free(oaep->maskHash);
+    }
+    return 1;
+}
+
+ASN1_SEQUENCE_cb(RSA_OAEP_PARAMS, rsa_oaep_cb) = {
         ASN1_EXP_OPT(RSA_OAEP_PARAMS, hashFunc, X509_ALGOR, 0),
         ASN1_EXP_OPT(RSA_OAEP_PARAMS, maskGenFunc, X509_ALGOR, 1),
         ASN1_EXP_OPT(RSA_OAEP_PARAMS, pSourceFunc, X509_ALGOR, 2),
-} ASN1_SEQUENCE_END(RSA_OAEP_PARAMS)
+} ASN1_SEQUENCE_END_cb(RSA_OAEP_PARAMS, RSA_OAEP_PARAMS)
 
 IMPLEMENT_ASN1_FUNCTIONS(RSA_OAEP_PARAMS)
 
index d97d6e0..5d4ab4e 100644 (file)
@@ -239,6 +239,8 @@ typedef struct rsa_pss_params_st {
     X509_ALGOR *maskGenAlgorithm;
     ASN1_INTEGER *saltLength;
     ASN1_INTEGER *trailerField;
+    /* Decoded hash algorithm from maskGenAlgorithm */
+    X509_ALGOR *maskHash;
 } RSA_PSS_PARAMS;
 
 DECLARE_ASN1_FUNCTIONS(RSA_PSS_PARAMS)
@@ -247,6 +249,8 @@ typedef struct rsa_oaep_params_st {
     X509_ALGOR *hashFunc;
     X509_ALGOR *maskGenFunc;
     X509_ALGOR *pSourceFunc;
+    /* Decoded hash algorithm from maskGenFunc */
+    X509_ALGOR *maskHash;
 } RSA_OAEP_PARAMS;
 
 DECLARE_ASN1_FUNCTIONS(RSA_OAEP_PARAMS)