PROV & SIGNATURE: Adapt the RSA signature code for PSS-parameters
[openssl.git] / providers / implementations / signature / rsa.c
index 27e35be..4dc3a89 100644 (file)
@@ -157,9 +157,6 @@ static int rsa_get_md_nid(const EVP_MD *md)
         }
     }
 
-    if (mdnid == NID_undef)
-        ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_DIGEST);
-
  end:
     return mdnid;
 }
@@ -181,47 +178,45 @@ static int rsa_check_padding(int mdnid, int padding)
     return 1;
 }
 
+static int rsa_check_parameters(EVP_MD *md, PROV_RSA_CTX *prsactx)
+{
+    if (prsactx->pad_mode == RSA_PKCS1_PSS_PADDING) {
+        int max_saltlen;
+
+        /* See if minimum salt length exceeds maximum possible */
+        max_saltlen = RSA_size(prsactx->rsa) - EVP_MD_size(md);
+        if ((RSA_bits(prsactx->rsa) & 0x7) == 1)
+            max_saltlen--;
+        if (prsactx->min_saltlen > max_saltlen) {
+            ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_SALT_LENGTH);
+            return 0;
+        }
+    }
+    return 1;
+}
+
 static void *rsa_newctx(void *provctx, const char *propq)
 {
-    PROV_RSA_CTX *prsactx = OPENSSL_zalloc(sizeof(PROV_RSA_CTX));
+    PROV_RSA_CTX *prsactx = NULL;
+    char *propq_copy = NULL;
 
-    if (prsactx == NULL)
+    if ((prsactx = OPENSSL_zalloc(sizeof(PROV_RSA_CTX))) == NULL
+        || (propq != NULL
+            && (propq_copy = OPENSSL_strdup(propq)) == NULL)) {
+        OPENSSL_free(prsactx);
+        ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
         return NULL;
+    }
 
     prsactx->libctx = PROV_LIBRARY_CONTEXT_OF(provctx);
     prsactx->flag_allow_md = 1;
-    if (propq != NULL && (prsactx->propq = OPENSSL_strdup(propq)) == NULL) {
-        OPENSSL_free(prsactx);
-        prsactx = NULL;
-        ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
-    }
+    prsactx->propq = propq_copy;
     return prsactx;
 }
 
 /* True if PSS parameters are restricted */
 #define rsa_pss_restricted(prsactx) (prsactx->min_saltlen != -1)
 
-static int rsa_signature_init(void *vprsactx, void *vrsa, int operation)
-{
-    PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
-
-    if (prsactx == NULL || vrsa == NULL || !RSA_up_ref(vrsa))
-        return 0;
-
-    RSA_free(prsactx->rsa);
-    prsactx->rsa = vrsa;
-    prsactx->operation = operation;
-    if (RSA_get0_pss_params(prsactx->rsa) != NULL)
-        prsactx->pad_mode = RSA_PKCS1_PSS_PADDING;
-    else
-        prsactx->pad_mode = RSA_PKCS1_PADDING;
-    /* Maximum for sign, auto for verify */
-    prsactx->saltlen = RSA_PSS_SALTLEN_AUTO;
-    prsactx->min_saltlen = -1;
-
-    return 1;
-}
-
 static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
                         const char *mdprops)
 {
@@ -235,7 +230,14 @@ static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
 
         if (md == NULL
             || md_nid == NID_undef
-            || !rsa_check_padding(md_nid, ctx->pad_mode)) {
+            || !rsa_check_padding(md_nid, ctx->pad_mode)
+            || !rsa_check_parameters(md, ctx)) {
+            if (md == NULL)
+                ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
+                               "%s could not be fetched", mdname);
+            if (md_nid == NID_undef)
+                ERR_raise_data(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED,
+                               "digest=%s", mdname);
             EVP_MD_free(md);
             return 0;
         }
@@ -277,13 +279,82 @@ static int rsa_setup_mgf1_md(PROV_RSA_CTX *ctx, const char *mdname,
     if (ctx->mgf1_mdname[0] != '\0')
         EVP_MD_free(ctx->mgf1_md);
 
-    if ((ctx->mgf1_md = EVP_MD_fetch(ctx->libctx, mdname, mdprops)) == NULL)
+    if ((ctx->mgf1_md = EVP_MD_fetch(ctx->libctx, mdname, mdprops)) == NULL) {
+        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
+                       "%s could not be fetched", mdname);
         return 0;
+    }
     OPENSSL_strlcpy(ctx->mgf1_mdname, mdname, sizeof(ctx->mgf1_mdname));
 
     return 1;
 }
 
+static int rsa_signature_init(void *vprsactx, void *vrsa, int operation)
+{
+    PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
+
+    if (prsactx == NULL || vrsa == NULL || !RSA_up_ref(vrsa))
+        return 0;
+
+    RSA_free(prsactx->rsa);
+    prsactx->rsa = vrsa;
+    prsactx->operation = operation;
+
+    /* Maximum for sign, auto for verify */
+    prsactx->saltlen = RSA_PSS_SALTLEN_AUTO;
+    prsactx->min_saltlen = -1;
+
+    switch (RSA_test_flags(prsactx->rsa, RSA_FLAG_TYPE_MASK)) {
+    case RSA_FLAG_TYPE_RSA:
+        prsactx->pad_mode = RSA_PKCS1_PADDING;
+        break;
+    case RSA_FLAG_TYPE_RSASSAPSS:
+        prsactx->pad_mode = RSA_PKCS1_PSS_PADDING;
+
+        {
+            const RSA_PSS_PARAMS_30 *pss =
+                rsa_get0_pss_params_30(prsactx->rsa);
+
+            if (!rsa_pss_params_30_is_unrestricted(pss)) {
+                int md_nid = rsa_pss_params_30_hashalg(pss);
+                int mgf1md_nid = rsa_pss_params_30_maskgenhashalg(pss);
+                int min_saltlen = rsa_pss_params_30_saltlen(pss);
+                const char *mdname, *mgf1mdname;
+
+                mdname = rsa_oaeppss_nid2name(md_nid);
+                mgf1mdname = rsa_oaeppss_nid2name(mgf1md_nid);
+                prsactx->min_saltlen = min_saltlen;
+
+                if (mdname == NULL) {
+                    ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
+                                   "PSS restrictions lack hash algorithm");
+                    return 0;
+                }
+                if (mgf1mdname == NULL) {
+                    ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
+                                   "PSS restrictions lack MGF1 hash algorithm");
+                    return 0;
+                }
+
+                strncpy(prsactx->mdname, mdname, sizeof(prsactx->mdname));
+                strncpy(prsactx->mgf1_mdname, mgf1mdname,
+                        sizeof(prsactx->mgf1_mdname));
+                prsactx->saltlen = min_saltlen;
+
+                return rsa_setup_md(prsactx, mdname, prsactx->propq)
+                    && rsa_setup_mgf1_md(prsactx, mgf1mdname, prsactx->propq);
+            }
+        }
+
+        break;
+    default:
+        ERR_raise(ERR_LIB_RSA, PROV_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
+        return 0;
+    }
+
+    return 1;
+}
+
 static int setup_tbuf(PROV_RSA_CTX *ctx)
 {
     if (ctx->tbuf != NULL)
@@ -303,7 +374,8 @@ static void clean_tbuf(PROV_RSA_CTX *ctx)
 
 static void free_tbuf(PROV_RSA_CTX *ctx)
 {
-    OPENSSL_clear_free(ctx->tbuf, RSA_size(ctx->rsa));
+    clean_tbuf(ctx);
+    OPENSSL_free(ctx->tbuf);
     ctx->tbuf = NULL;
 }
 
@@ -325,8 +397,11 @@ static int rsa_sign(void *vprsactx, unsigned char *sig, size_t *siglen,
         return 1;
     }
 
-    if (sigsize < (size_t)rsasize)
+    if (sigsize < rsasize) {
+        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_SIGNATURE_SIZE,
+                       "is %zu, should be at least %zu", sigsize, rsasize);
         return 0;
+    }
 
     if (mdsize != 0) {
         if (tbslen != mdsize) {
@@ -357,7 +432,9 @@ static int rsa_sign(void *vprsactx, unsigned char *sig, size_t *siglen,
         switch (prsactx->pad_mode) {
         case RSA_X931_PADDING:
             if ((size_t)RSA_size(prsactx->rsa) < tbslen + 1) {
-                ERR_raise(ERR_LIB_PROV, PROV_R_KEY_SIZE_TOO_SMALL);
+                ERR_raise_data(ERR_LIB_PROV, PROV_R_KEY_SIZE_TOO_SMALL,
+                               "RSA key size = %d, expected minimum = %d",
+                               RSA_size(prsactx->rsa), tbslen + 1);
                 return 0;
             }
             if (!setup_tbuf(prsactx)) {
@@ -391,14 +468,24 @@ static int rsa_sign(void *vprsactx, unsigned char *sig, size_t *siglen,
                 switch (prsactx->saltlen) {
                 case RSA_PSS_SALTLEN_DIGEST:
                     if (prsactx->min_saltlen > EVP_MD_size(prsactx->md)) {
-                        ERR_raise(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL);
+                        ERR_raise_data(ERR_LIB_PROV,
+                                       PROV_R_PSS_SALTLEN_TOO_SMALL,
+                                       "minimum salt length set to %d, "
+                                       "but the digest only gives %d",
+                                       prsactx->min_saltlen,
+                                       EVP_MD_size(prsactx->md));
                         return 0;
                     }
                     /* FALLTHRU */
                 default:
                     if (prsactx->saltlen >= 0
                         && prsactx->saltlen < prsactx->min_saltlen) {
-                        ERR_raise(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL);
+                        ERR_raise_data(ERR_LIB_PROV,
+                                       PROV_R_PSS_SALTLEN_TOO_SMALL,
+                                       "minimum salt length set to %d, but the"
+                                       "actual salt length is only set to %d",
+                                       prsactx->min_saltlen,
+                                       prsactx->saltlen);
                         return 0;
                     }
                     break;
@@ -485,7 +572,9 @@ static int rsa_verify_recover(void *vprsactx,
 
             *routlen = ret;
             if (routsize < (size_t)ret) {
-                ERR_raise(ERR_LIB_PROV, PROV_R_OUTPUT_BUFFER_TOO_SMALL);
+                ERR_raise_data(ERR_LIB_PROV, PROV_R_OUTPUT_BUFFER_TOO_SMALL,
+                               "buffer size is %d, should be %d",
+                               routsize, ret);
                 return 0;
             }
             memcpy(rout, prsactx->tbuf, ret);
@@ -610,12 +699,14 @@ static int rsa_digest_signverify_init(void *vprsactx, const char *mdname,
 
     prsactx->flag_allow_md = 0;
     if (!rsa_signature_init(vprsactx, vrsa, operation)
-        || !rsa_setup_md(prsactx, mdname, NULL))
+        || !rsa_setup_md(prsactx, mdname, NULL)) /* TODO RL */
         return 0;
 
     prsactx->mdctx = EVP_MD_CTX_new();
-    if (prsactx->mdctx == NULL)
+    if (prsactx->mdctx == NULL) {
+        ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
         goto error;
+    }
 
     if (!EVP_DigestInit_ex(prsactx->mdctx, prsactx->md, NULL))
         goto error;
@@ -643,9 +734,9 @@ static int rsa_digest_signverify_update(void *vprsactx,
 }
 
 static int rsa_digest_sign_init(void *vprsactx, const char *mdname,
-                                const char *props, void *vrsa)
+                                void *vrsa)
 {
-    return rsa_digest_signverify_init(vprsactx, mdname, props, vrsa,
+    return rsa_digest_signverify_init(vprsactx, mdname, vrsa,
                                       EVP_PKEY_OP_SIGN);
 }
 
@@ -678,9 +769,9 @@ static int rsa_digest_sign_final(void *vprsactx, unsigned char *sig,
 }
 
 static int rsa_digest_verify_init(void *vprsactx, const char *mdname,
-                                  const char *props, void *vrsa)
+                                  void *vrsa)
 {
-    return rsa_digest_signverify_init(vprsactx, mdname, props, vrsa,
+    return rsa_digest_signverify_init(vprsactx, mdname, vrsa,
                                       EVP_PKEY_OP_VERIFY);
 }
 
@@ -729,8 +820,10 @@ static void *rsa_dupctx(void *vprsactx)
     PROV_RSA_CTX *dstctx;
 
     dstctx = OPENSSL_zalloc(sizeof(*srcctx));
-    if (dstctx == NULL)
+    if (dstctx == NULL) {
+        ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
         return NULL;
+    }
 
     *dstctx = *srcctx;
     dstctx->rsa = NULL;
@@ -888,7 +981,6 @@ static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
                                              &pmdprops, sizeof(mdprops)))
             return 0;
 
-        /* TODO(3.0) PSS check needs more work */
         if (rsa_pss_restricted(prsactx)) {
             /* TODO(3.0) figure out what to do for prsactx->md == NULL */
             if (prsactx->md == NULL || EVP_MD_is_a(prsactx->md, mdname))
@@ -948,9 +1040,6 @@ static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
             }
             if (prsactx->md == NULL
                 && !rsa_setup_md(prsactx, OSSL_DIGEST_NAME_SHA1, NULL)) {
-                ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
-                               "%s could not be fetched",
-                               OSSL_DIGEST_NAME_SHA1);
                 return 0;
             }
             break;
@@ -966,7 +1055,8 @@ static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
         case RSA_X931_PADDING:
             err_extra_text = "X.931 padding not allowed with RSA-PSS";
         cont:
-            if (RSA_get0_pss_params(prsactx->rsa) == NULL)
+            if (RSA_test_flags(prsactx->rsa,
+                               RSA_FLAG_TYPE_MASK) == RSA_FLAG_TYPE_RSA)
                 break;
             /* FALLTHRU */
         default: