RSA: Better synchronisation between ASN1 PSS params and RSA_PSS_PARAMS_30
[openssl.git] / crypto / rsa / rsa_ameth.c
index 22c06a2139e0e8a939e7fbd5f2fca61641e64455..f5911ad233677391430c5abc16ca5edd12666924 100644 (file)
@@ -34,6 +34,7 @@ static int rsa_cms_encrypt(CMS_RecipientInfo *ri);
 #endif
 
 static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg);
+static int rsa_sync_to_pss_params_30(RSA *rsa);
 
 /* Set any parameters associated with pkey */
 static int rsa_param_encode(const EVP_PKEY *pkey,
@@ -78,6 +79,8 @@ static int rsa_param_decode(RSA *rsa, const X509_ALGOR *alg)
     rsa->pss = rsa_pss_decode(alg);
     if (rsa->pss == NULL)
         return 0;
+    if (!rsa_sync_to_pss_params_30(rsa))
+        return 0;
     return 1;
 }
 
@@ -118,6 +121,20 @@ static int rsa_pub_decode(EVP_PKEY *pkey, const X509_PUBKEY *pubkey)
         RSA_free(rsa);
         return 0;
     }
+
+    RSA_clear_flags(rsa, RSA_FLAG_TYPE_MASK);
+    switch (pkey->ameth->pkey_id) {
+    case EVP_PKEY_RSA:
+        RSA_set_flags(rsa, RSA_FLAG_TYPE_RSA);
+        break;
+    case EVP_PKEY_RSA_PSS:
+        RSA_set_flags(rsa, RSA_FLAG_TYPE_RSASSAPSS);
+        break;
+    default:
+        /* Leave the type bits zero */
+        break;
+    }
+
     if (!EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa)) {
         RSA_free(rsa);
         return 0;
@@ -729,9 +746,34 @@ static int rsa_pss_to_ctx(EVP_MD_CTX *ctx, EVP_PKEY_CTX *pkctx,
     return rv;
 }
 
-int rsa_pss_get_param(const RSA_PSS_PARAMS *pss, const EVP_MD **pmd,
-                      const EVP_MD **pmgf1md, int *psaltlen)
+static int rsa_pss_verify_param(const EVP_MD **pmd, const EVP_MD **pmgf1md,
+                                int *psaltlen, int *ptrailerField)
 {
+    if (psaltlen != NULL && *psaltlen < 0) {
+        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_SALT_LENGTH);
+        return 0;
+    }
+    /*
+     * low-level routines support only trailer field 0xbc (value 1) and
+     * PKCS#1 says we should reject any other value anyway.
+     */
+    if (ptrailerField != NULL && *ptrailerField != 1) {
+        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_TRAILER);
+        return 0;
+    }
+    return 1;
+}
+
+static int rsa_pss_get_param_unverified(const RSA_PSS_PARAMS *pss,
+                                        const EVP_MD **pmd,
+                                        const EVP_MD **pmgf1md,
+                                        int *psaltlen, int *ptrailerField)
+{
+    RSA_PSS_PARAMS_30 pss_params;
+
+    /* Get the defaults from the ONE place */
+    (void)rsa_pss_params_30_set_defaults(&pss_params);
+
     if (pss == NULL)
         return 0;
     *pmd = rsa_algor_to_md(pss->hashAlgorithm);
@@ -740,25 +782,65 @@ int rsa_pss_get_param(const RSA_PSS_PARAMS *pss, const EVP_MD **pmd,
     *pmgf1md = rsa_algor_to_md(pss->maskHash);
     if (*pmgf1md == NULL)
         return 0;
-    if (pss->saltLength) {
+    if (pss->saltLength)
         *psaltlen = ASN1_INTEGER_get(pss->saltLength);
-        if (*psaltlen < 0) {
-            RSAerr(RSA_F_RSA_PSS_GET_PARAM, RSA_R_INVALID_SALT_LENGTH);
-            return 0;
-        }
-    } else {
-        *psaltlen = 20;
-    }
+    else
+        *psaltlen = rsa_pss_params_30_saltlen(&pss_params);
+    if (pss->trailerField)
+        *ptrailerField = ASN1_INTEGER_get(pss->trailerField);
+    else
+        *ptrailerField = rsa_pss_params_30_trailerfield(&pss_params);;
+
+    return 1;
+}
 
+int rsa_pss_get_param(const RSA_PSS_PARAMS *pss, const EVP_MD **pmd,
+                      const EVP_MD **pmgf1md, int *psaltlen)
+{
     /*
-     * low-level routines support only trailer field 0xbc (value 1) and
-     * PKCS#1 says we should reject any other value anyway.
+     * Callers do not care about the trailer field, and yet, we must
+     * pass it from get_param to verify_param, since the latter checks
+     * its value.
+     *
+     * When callers start caring, it's a simple thing to add another
+     * argument to this function.
      */
-    if (pss->trailerField && ASN1_INTEGER_get(pss->trailerField) != 1) {
-        RSAerr(RSA_F_RSA_PSS_GET_PARAM, RSA_R_INVALID_TRAILER);
-        return 0;
-    }
+    int trailerField = 0;
+
+    return rsa_pss_get_param_unverified(pss, pmd, pmgf1md, psaltlen,
+                                        &trailerField)
+        && rsa_pss_verify_param(pmd, pmgf1md, psaltlen, &trailerField);
+}
+
+static int rsa_sync_to_pss_params_30(RSA *rsa)
+{
+    if (rsa != NULL && rsa->pss != NULL) {
+        const EVP_MD *md = NULL, *mgf1md = NULL;
+        int md_nid, mgf1md_nid, saltlen, trailerField;
+        RSA_PSS_PARAMS_30 pss_params;
 
+        /*
+         * We don't care about the validity of the fields here, we just
+         * want to synchronise values.  Verifying here makes it impossible
+         * to even read a key with invalid values, making it hard to test
+         * a bad situation.
+         *
+         * Other routines use rsa_pss_get_param(), so the values will be
+         * checked, eventually.
+         */
+        if (!rsa_pss_get_param_unverified(rsa->pss, &md, &mgf1md,
+                                          &saltlen, &trailerField))
+            return 0;
+        md_nid = EVP_MD_type(md);
+        mgf1md_nid = EVP_MD_type(mgf1md);
+        if (!rsa_pss_params_30_set_defaults(&pss_params)
+            || !rsa_pss_params_30_set_hashalg(&pss_params, md_nid)
+            || !rsa_pss_params_30_set_maskgenhashalg(&pss_params, mgf1md_nid)
+            || !rsa_pss_params_30_set_saltlen(&pss_params, saltlen)
+            || !rsa_pss_params_30_set_trailerfield(&pss_params, trailerField))
+            return 0;
+        rsa->pss_params = pss_params;
+    }
     return 1;
 }