make EVP_PKEY opaque
[openssl.git] / crypto / rsa / rsa_ameth.c
index 071dbb8d680f4f01174a184417524888b75540ab..4433bf47ded393ebcc90436870a28afcc0a48809 100644 (file)
@@ -58,7 +58,7 @@
  */
 
 #include <stdio.h>
-#include "cryptlib.h"
+#include "internal/cryptlib.h"
 #include <openssl/asn1t.h>
 #include <openssl/x509.h>
 #include <openssl/rsa.h>
 #ifndef OPENSSL_NO_CMS
 # include <openssl/cms.h>
 #endif
-#include "asn1_locl.h"
+#include "internal/asn1_int.h"
+#include "internal/evp_int.h"
 
+#ifndef OPENSSL_NO_CMS
 static int rsa_cms_sign(CMS_SignerInfo *si);
 static int rsa_cms_verify(CMS_SignerInfo *si);
 static int rsa_cms_decrypt(CMS_RecipientInfo *ri);
 static int rsa_cms_encrypt(CMS_RecipientInfo *ri);
+#endif
 
 static int rsa_pub_encode(X509_PUBKEY *pk, const EVP_PKEY *pkey)
 {
@@ -93,9 +96,10 @@ static int rsa_pub_decode(EVP_PKEY *pkey, X509_PUBKEY *pubkey)
     const unsigned char *p;
     int pklen;
     RSA *rsa = NULL;
+
     if (!X509_PUBKEY_get0_param(NULL, &p, &pklen, NULL, pubkey))
         return 0;
-    if (!(rsa = d2i_RSAPublicKey(NULL, &p, pklen))) {
+    if ((rsa = d2i_RSAPublicKey(NULL, &p, pklen)) == NULL) {
         RSAerr(RSA_F_RSA_PUB_DECODE, ERR_R_RSA_LIB);
         return 0;
     }
@@ -115,7 +119,8 @@ static int old_rsa_priv_decode(EVP_PKEY *pkey,
                                const unsigned char **pder, int derlen)
 {
     RSA *rsa;
-    if (!(rsa = d2i_RSAPrivateKey(NULL, pder, derlen))) {
+
+    if ((rsa = d2i_RSAPrivateKey(NULL, pder, derlen)) == NULL) {
         RSAerr(RSA_F_OLD_RSA_PRIV_DECODE, ERR_R_RSA_LIB);
         return 0;
     }
@@ -206,7 +211,7 @@ static int do_rsa_print(BIO *bp, const RSA *x, int off, int priv)
         update_buflen(x->iqmp, &buf_len);
     }
 
-    m = (unsigned char *)OPENSSL_malloc(buf_len + 10);
+    m = OPENSSL_malloc(buf_len + 10);
     if (m == NULL) {
         RSAerr(RSA_F_DO_RSA_PRINT, ERR_R_MALLOC_FAILURE);
         goto err;
@@ -251,8 +256,7 @@ static int do_rsa_print(BIO *bp, const RSA *x, int off, int priv)
     }
     ret = 1;
  err:
-    if (m != NULL)
-        OPENSSL_free(m);
+    OPENSSL_free(m);
     return (ret);
 }
 
@@ -271,34 +275,23 @@ static int rsa_priv_print(BIO *bp, const EVP_PKEY *pkey, int indent,
 /* Given an MGF1 Algorithm ID decode to an Algorithm Identifier */
 static X509_ALGOR *rsa_mgf1_decode(X509_ALGOR *alg)
 {
-    const unsigned char *p;
-    int plen;
     if (alg == NULL)
         return NULL;
     if (OBJ_obj2nid(alg->algorithm) != NID_mgf1)
         return NULL;
-    if (alg->parameter->type != V_ASN1_SEQUENCE)
-        return NULL;
-
-    p = alg->parameter->value.sequence->data;
-    plen = alg->parameter->value.sequence->length;
-    return d2i_X509_ALGOR(NULL, &p, plen);
+    return ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(X509_ALGOR),
+                                     alg->parameter);
 }
 
 static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg,
                                       X509_ALGOR **pmaskHash)
 {
-    const unsigned char *p;
-    int plen;
     RSA_PSS_PARAMS *pss;
 
     *pmaskHash = NULL;
 
-    if (!alg->parameter || alg->parameter->type != V_ASN1_SEQUENCE)
-        return NULL;
-    p = alg->parameter->value.sequence->data;
-    plen = alg->parameter->value.sequence->length;
-    pss = d2i_RSA_PSS_PARAMS(NULL, &p, plen);
+    pss = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_PSS_PARAMS),
+                                    alg->parameter);
 
     if (!pss)
         return NULL;
@@ -390,10 +383,8 @@ static int rsa_sig_print(BIO *bp, const X509_ALGOR *sigalg,
         X509_ALGOR *maskHash;
         pss = rsa_pss_decode(sigalg, &maskHash);
         rv = rsa_pss_param_print(bp, pss, maskHash, indent);
-        if (pss)
-            RSA_PSS_PARAMS_free(pss);
-        if (maskHash)
-            X509_ALGOR_free(maskHash);
+        RSA_PSS_PARAMS_free(pss);
+        X509_ALGOR_free(maskHash);
         if (!rv)
             return 0;
     } else if (!sig && BIO_puts(bp, "\n") <= 0)
@@ -459,7 +450,7 @@ static int rsa_md_to_algor(X509_ALGOR **palg, const EVP_MD *md)
     if (EVP_MD_type(md) == NID_sha1)
         return 1;
     *palg = X509_ALGOR_new();
-    if (!*palg)
+    if (*palg == NULL)
         return 0;
     X509_ALGOR_set_md(*palg, md);
     return 1;
@@ -479,14 +470,13 @@ static int rsa_md_to_mgf1(X509_ALGOR **palg, const EVP_MD *mgf1md)
     if (!ASN1_item_pack(algtmp, ASN1_ITEM_rptr(X509_ALGOR), &stmp))
          goto err;
     *palg = X509_ALGOR_new();
-    if (!*palg)
+    if (*palg == NULL)
         goto err;
     X509_ALGOR_set0(*palg, OBJ_nid2obj(NID_mgf1), V_ASN1_SEQUENCE, stmp);
     stmp = NULL;
  err:
     ASN1_STRING_free(stmp);
-    if (algtmp)
-        X509_ALGOR_free(algtmp);
+    X509_ALGOR_free(algtmp);
     if (*palg)
         return 1;
     return 0;
@@ -553,11 +543,11 @@ static ASN1_STRING *rsa_ctx_to_pss(EVP_PKEY_CTX *pkctx)
             saltlen--;
     }
     pss = RSA_PSS_PARAMS_new();
-    if (!pss)
+    if (pss == NULL)
         goto err;
     if (saltlen != 20) {
         pss->saltLength = ASN1_INTEGER_new();
-        if (!pss->saltLength)
+        if (pss->saltLength == NULL)
             goto err;
         if (!ASN1_INTEGER_set(pss->saltLength, saltlen))
             goto err;
@@ -571,8 +561,7 @@ static ASN1_STRING *rsa_ctx_to_pss(EVP_PKEY_CTX *pkctx)
          goto err;
     rv = 1;
  err:
-    if (pss)
-        RSA_PSS_PARAMS_free(pss);
+    RSA_PSS_PARAMS_free(pss);
     if (rv)
         return os;
     ASN1_STRING_free(os);
@@ -663,11 +652,11 @@ static int rsa_pss_to_ctx(EVP_MD_CTX *ctx, EVP_PKEY_CTX *pkctx,
 
  err:
     RSA_PSS_PARAMS_free(pss);
-    if (maskHash)
-        X509_ALGOR_free(maskHash);
+    X509_ALGOR_free(maskHash);
     return rv;
 }
 
+#ifndef OPENSSL_NO_CMS
 static int rsa_cms_verify(CMS_SignerInfo *si)
 {
     int nid, nid2;
@@ -686,6 +675,7 @@ static int rsa_cms_verify(CMS_SignerInfo *si)
     }
     return 0;
 }
+#endif
 
 /*
  * Customised RSA item verification routine. This is called when a signature
@@ -708,6 +698,7 @@ static int rsa_item_verify(EVP_MD_CTX *ctx, const ASN1_ITEM *it, void *asn,
     return -1;
 }
 
+#ifndef OPENSSL_NO_CMS
 static int rsa_cms_sign(CMS_SignerInfo *si)
 {
     int pad_mode = RSA_PKCS1_PADDING;
@@ -732,13 +723,14 @@ static int rsa_cms_sign(CMS_SignerInfo *si)
     X509_ALGOR_set0(alg, OBJ_nid2obj(NID_rsassaPss), V_ASN1_SEQUENCE, os);
     return 1;
 }
+#endif
 
 static int rsa_item_sign(EVP_MD_CTX *ctx, const ASN1_ITEM *it, void *asn,
                          X509_ALGOR *alg1, X509_ALGOR *alg2,
                          ASN1_BIT_STRING *sig)
 {
     int pad_mode;
-    EVP_PKEY_CTX *pkctx = ctx->pctx;
+    EVP_PKEY_CTX *pkctx = EVP_MD_CTX_pkey_ctx(ctx);
     if (EVP_PKEY_CTX_get_rsa_padding(pkctx, &pad_mode) <= 0)
         return 0;
     if (pad_mode == RSA_PKCS1_PADDING)
@@ -765,20 +757,16 @@ static int rsa_item_sign(EVP_MD_CTX *ctx, const ASN1_ITEM *it, void *asn,
     return 2;
 }
 
+#ifndef OPENSSL_NO_CMS
 static RSA_OAEP_PARAMS *rsa_oaep_decode(const X509_ALGOR *alg,
                                         X509_ALGOR **pmaskHash)
 {
-    const unsigned char *p;
-    int plen;
     RSA_OAEP_PARAMS *pss;
 
     *pmaskHash = NULL;
 
-    if (!alg->parameter || alg->parameter->type != V_ASN1_SEQUENCE)
-        return NULL;
-    p = alg->parameter->value.sequence->data;
-    plen = alg->parameter->value.sequence->length;
-    pss = d2i_RSA_OAEP_PARAMS(NULL, &p, plen);
+    pss = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_OAEP_PARAMS),
+                                    alg->parameter);
 
     if (!pss)
         return NULL;
@@ -856,8 +844,7 @@ static int rsa_cms_decrypt(CMS_RecipientInfo *ri)
 
  err:
     RSA_OAEP_PARAMS_free(oaep);
-    if (maskHash)
-        X509_ALGOR_free(maskHash);
+    X509_ALGOR_free(maskHash);
     return rv;
 }
 
@@ -890,7 +877,7 @@ static int rsa_cms_encrypt(CMS_RecipientInfo *ri)
     if (labellen < 0)
         goto err;
     oaep = RSA_OAEP_PARAMS_new();
-    if (!oaep)
+    if (oaep == NULL)
         goto err;
     if (!rsa_md_to_algor(&oaep->hashFunc, md))
         goto err;
@@ -899,9 +886,9 @@ static int rsa_cms_encrypt(CMS_RecipientInfo *ri)
     if (labellen > 0) {
         ASN1_OCTET_STRING *los = ASN1_OCTET_STRING_new();
         oaep->pSourceFunc = X509_ALGOR_new();
-        if (!oaep->pSourceFunc)
+        if (oaep->pSourceFunc == NULL)
             goto err;
-        if (!los)
+        if (los == NULL)
             goto err;
         if (!ASN1_OCTET_STRING_set(los, label, labellen)) {
             ASN1_OCTET_STRING_free(los);
@@ -917,11 +904,11 @@ static int rsa_cms_encrypt(CMS_RecipientInfo *ri)
     os = NULL;
     rv = 1;
  err:
-    if (oaep)
-        RSA_OAEP_PARAMS_free(oaep);
+    RSA_OAEP_PARAMS_free(oaep);
     ASN1_STRING_free(os);
     return rv;
 }
+#endif
 
 const EVP_PKEY_ASN1_METHOD rsa_asn1_meths[] = {
     {