+static int rsa_signature_init(void *vprsactx, void *vrsa, int operation)
+{
+ PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
+
+ if (prsactx == NULL || vrsa == NULL || !RSA_up_ref(vrsa))
+ return 0;
+
+ RSA_free(prsactx->rsa);
+ prsactx->rsa = vrsa;
+ prsactx->operation = operation;
+
+ /* Maximum for sign, auto for verify */
+ prsactx->saltlen = RSA_PSS_SALTLEN_AUTO;
+ prsactx->min_saltlen = -1;
+
+ switch (RSA_test_flags(prsactx->rsa, RSA_FLAG_TYPE_MASK)) {
+ case RSA_FLAG_TYPE_RSA:
+ prsactx->pad_mode = RSA_PKCS1_PADDING;
+ break;
+ case RSA_FLAG_TYPE_RSASSAPSS:
+ prsactx->pad_mode = RSA_PKCS1_PSS_PADDING;
+
+ {
+ const RSA_PSS_PARAMS_30 *pss =
+ rsa_get0_pss_params_30(prsactx->rsa);
+
+ if (!rsa_pss_params_30_is_unrestricted(pss)) {
+ int md_nid = rsa_pss_params_30_hashalg(pss);
+ int mgf1md_nid = rsa_pss_params_30_maskgenhashalg(pss);
+ int min_saltlen = rsa_pss_params_30_saltlen(pss);
+ const char *mdname, *mgf1mdname;
+
+ mdname = rsa_oaeppss_nid2name(md_nid);
+ mgf1mdname = rsa_oaeppss_nid2name(mgf1md_nid);
+ prsactx->min_saltlen = min_saltlen;
+
+ if (mdname == NULL) {
+ ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
+ "PSS restrictions lack hash algorithm");
+ return 0;
+ }
+ if (mgf1mdname == NULL) {
+ ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
+ "PSS restrictions lack MGF1 hash algorithm");
+ return 0;
+ }
+
+ strncpy(prsactx->mdname, mdname, sizeof(prsactx->mdname));
+ strncpy(prsactx->mgf1_mdname, mgf1mdname,
+ sizeof(prsactx->mgf1_mdname));
+ prsactx->saltlen = min_saltlen;
+
+ return rsa_setup_md(prsactx, mdname, prsactx->propq)
+ && rsa_setup_mgf1_md(prsactx, mgf1mdname, prsactx->propq);
+ }
+ }
+
+ break;
+ default:
+ ERR_raise(ERR_LIB_RSA, PROV_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
+ return 0;
+ }
+
+ return 1;
+}
+