EVP: Fix method to determine if a PKEY is legacy or not
[openssl.git] / crypto / evp / pmeth_lib.c
index 6bbe025bfc8ff7a4e1520a2a8084ca64b2aca84a..8b49baf6abdba93e2e0eb56ccaf7a728b797dc2c 100644 (file)
@@ -15,6 +15,7 @@
 #include <openssl/x509v3.h>
 #include <openssl/core_names.h>
 #include <openssl/dh.h>
+#include <openssl/rsa.h>
 #include "internal/cryptlib.h"
 #include "crypto/asn1.h"
 #include "crypto/evp.h"
@@ -126,11 +127,28 @@ static EVP_PKEY_CTX *int_ctx_new(OPENSSL_CTX *libctx,
     if (pkey == NULL && e == NULL && id == -1)
         goto common;
 
+    /*
+     * If the key doesn't contain anything legacy, then it must be provided,
+     * so we extract the necessary information and use that.
+     */
+    if (pkey != NULL && pkey->ameth == NULL) {
+        /* If we have an engine, something went wrong somewhere... */
+        if (!ossl_assert(e == NULL))
+            return NULL;
+        name = evp_first_name(pkey->pkeys[0].keymgmt->prov,
+                              pkey->pkeys[0].keymgmt->name_id);
+        /*
+         * TODO: I wonder if the EVP_PKEY should have the name and propquery
+         * that were used when building it....  /RL
+         */
+        goto common;
+    }
+
     /* TODO(3.0) Legacy code should be removed when all is provider based */
     /* BEGIN legacy */
     if (id == -1) {
         if (pkey == NULL)
-            return 0;
+            return NULL;
         id = pkey->type;
     }
 
@@ -701,6 +719,33 @@ static int legacy_ctrl_to_param(EVP_PKEY_CTX *ctx, int keytype, int optype,
         return EVP_PKEY_CTX_set_signature_md(ctx, p2);
     case EVP_PKEY_CTRL_GET_MD:
         return EVP_PKEY_CTX_get_signature_md(ctx, p2);
+    case EVP_PKEY_CTRL_RSA_PADDING:
+        return EVP_PKEY_CTX_set_rsa_padding(ctx, p1);
+    case EVP_PKEY_CTRL_GET_RSA_PADDING:
+        return EVP_PKEY_CTX_get_rsa_padding(ctx, p2);
+    case EVP_PKEY_CTRL_RSA_OAEP_MD:
+        return EVP_PKEY_CTX_set_rsa_oaep_md(ctx, p2);
+    case EVP_PKEY_CTRL_GET_RSA_OAEP_MD:
+        return EVP_PKEY_CTX_get_rsa_oaep_md(ctx, p2);
+    case EVP_PKEY_CTRL_RSA_MGF1_MD:
+        return EVP_PKEY_CTX_set_rsa_oaep_md(ctx, p2);
+    case EVP_PKEY_CTRL_GET_RSA_MGF1_MD:
+        return EVP_PKEY_CTX_get_rsa_oaep_md(ctx, p2);
+    case EVP_PKEY_CTRL_RSA_OAEP_LABEL:
+        return EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, p2, p1);
+    case EVP_PKEY_CTRL_GET_RSA_OAEP_LABEL:
+        return EVP_PKEY_CTX_get0_rsa_oaep_label(ctx, (unsigned char **)p2);
+    case EVP_PKEY_CTRL_PKCS7_ENCRYPT:
+    case EVP_PKEY_CTRL_PKCS7_DECRYPT:
+#ifndef OPENSSL_NO_CMS
+    case EVP_PKEY_CTRL_CMS_DECRYPT:
+    case EVP_PKEY_CTRL_CMS_ENCRYPT:
+#endif
+        if (ctx->pmeth->pkey_id != EVP_PKEY_RSA_PSS)
+            return 1;
+        ERR_raise(ERR_LIB_EVP,
+                  EVP_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
+        return -2;
     }
     return 0;
 }
@@ -784,6 +829,52 @@ static int legacy_ctrl_str_to_param(EVP_PKEY_CTX *ctx, const char *name,
         return ret;
     }
 
+    if (strcmp(name, "rsa_padding_mode") == 0) {
+        int pm;
+
+        if (strcmp(value, "pkcs1") == 0) {
+            pm = RSA_PKCS1_PADDING;
+        } else if (strcmp(value, "sslv23") == 0) {
+            pm = RSA_SSLV23_PADDING;
+        } else if (strcmp(value, "none") == 0) {
+            pm = RSA_NO_PADDING;
+        } else if (strcmp(value, "oeap") == 0) {
+            pm = RSA_PKCS1_OAEP_PADDING;
+        } else if (strcmp(value, "oaep") == 0) {
+            pm = RSA_PKCS1_OAEP_PADDING;
+        } else if (strcmp(value, "x931") == 0) {
+            pm = RSA_X931_PADDING;
+        } else if (strcmp(value, "pss") == 0) {
+            pm = RSA_PKCS1_PSS_PADDING;
+        } else {
+            ERR_raise(ERR_LIB_RSA, RSA_R_UNKNOWN_PADDING_TYPE);
+            return -2;
+        }
+        return EVP_PKEY_CTX_set_rsa_padding(ctx, pm);
+    }
+
+    if (strcmp(name, "rsa_mgf1_md") == 0)
+        return EVP_PKEY_CTX_set_rsa_mgf1_md_name(ctx, value, NULL);
+
+    if (strcmp(name, "rsa_oaep_md") == 0)
+        return EVP_PKEY_CTX_set_rsa_oaep_md_name(ctx, value, NULL);
+
+    if (strcmp(name, "rsa_oaep_label") == 0) {
+        unsigned char *lab;
+        long lablen;
+        int ret;
+
+        lab = OPENSSL_hexstr2buf(value, &lablen);
+        if (lab == NULL)
+            return 0;
+        ret = EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, lab, lablen);
+        if (ret <= 0)
+            OPENSSL_free(lab);
+        return ret;
+    }
+
+
+
     return 0;
 }
 
@@ -1037,6 +1128,21 @@ void EVP_PKEY_meth_set_ctrl(EVP_PKEY_METHOD *pmeth,
     pmeth->ctrl_str = ctrl_str;
 }
 
+void EVP_PKEY_meth_set_digestsign(EVP_PKEY_METHOD *pmeth,
+    int (*digestsign) (EVP_MD_CTX *ctx, unsigned char *sig, size_t *siglen,
+                       const unsigned char *tbs, size_t tbslen))
+{
+    pmeth->digestsign = digestsign;
+}
+
+void EVP_PKEY_meth_set_digestverify(EVP_PKEY_METHOD *pmeth,
+    int (*digestverify) (EVP_MD_CTX *ctx, const unsigned char *sig,
+                         size_t siglen, const unsigned char *tbs,
+                         size_t tbslen))
+{
+    pmeth->digestverify = digestverify;
+}
+
 void EVP_PKEY_meth_set_check(EVP_PKEY_METHOD *pmeth,
                              int (*check) (EVP_PKEY *pkey))
 {
@@ -1229,6 +1335,23 @@ void EVP_PKEY_meth_get_ctrl(const EVP_PKEY_METHOD *pmeth,
         *pctrl_str = pmeth->ctrl_str;
 }
 
+void EVP_PKEY_meth_get_digestsign(EVP_PKEY_METHOD *pmeth,
+    int (**digestsign) (EVP_MD_CTX *ctx, unsigned char *sig, size_t *siglen,
+                        const unsigned char *tbs, size_t tbslen))
+{
+    if (digestsign)
+        *digestsign = pmeth->digestsign;
+}
+
+void EVP_PKEY_meth_get_digestverify(EVP_PKEY_METHOD *pmeth,
+    int (**digestverify) (EVP_MD_CTX *ctx, const unsigned char *sig,
+                          size_t siglen, const unsigned char *tbs,
+                          size_t tbslen))
+{
+    if (digestverify)
+        *digestverify = pmeth->digestverify;
+}
+
 void EVP_PKEY_meth_get_check(const EVP_PKEY_METHOD *pmeth,
                              int (**pcheck) (EVP_PKEY *pkey))
 {