Ensure RSA PSS correctly returns the right default digest
[openssl.git] / crypto / rsa / rsa_ameth.c
index de9e3c10776656510e73d5ec5037315006624709..bf56039b468d4dcfd05c282cefc00b9c5e9581c6 100644 (file)
@@ -1,7 +1,7 @@
 /*
 /*
- * Copyright 2006-2017 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2006-2018 The OpenSSL Project Authors. All Rights Reserved.
  *
  *
- * Licensed under the OpenSSL license (the "License").  You may not use
+ * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
  * in the file LICENSE in the source distribution or at
  * https://www.openssl.org/source/license.html
  * this file except in compliance with the License.  You can obtain a copy
  * in the file LICENSE in the source distribution or at
  * https://www.openssl.org/source/license.html
@@ -34,7 +34,7 @@ static int rsa_param_encode(const EVP_PKEY *pkey,
 
     *pstr = NULL;
     /* If RSA it's just NULL type */
 
     *pstr = NULL;
     /* If RSA it's just NULL type */
-    if (pkey->ameth->pkey_id == EVP_PKEY_RSA) {
+    if (pkey->ameth->pkey_id != EVP_PKEY_RSA_PSS) {
         *pstrtype = V_ASN1_NULL;
         return 1;
     }
         *pstrtype = V_ASN1_NULL;
         return 1;
     }
@@ -58,7 +58,7 @@ static int rsa_param_decode(RSA *rsa, const X509_ALGOR *alg)
     int algptype;
 
     X509_ALGOR_get0(&algoid, &algptype, &algp, alg);
     int algptype;
 
     X509_ALGOR_get0(&algoid, &algptype, &algp, alg);
-    if (OBJ_obj2nid(algoid) == EVP_PKEY_RSA)
+    if (OBJ_obj2nid(algoid) != EVP_PKEY_RSA_PSS)
         return 1;
     if (algptype == V_ASN1_UNDEF)
         return 1;
         return 1;
     if (algptype == V_ASN1_UNDEF)
         return 1;
@@ -109,7 +109,10 @@ static int rsa_pub_decode(EVP_PKEY *pkey, X509_PUBKEY *pubkey)
         RSA_free(rsa);
         return 0;
     }
         RSA_free(rsa);
         return 0;
     }
-    EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa);
+    if (!EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa)) {
+        RSA_free(rsa);
+        return 0;
+    }
     return 1;
 }
 
     return 1;
 }
 
@@ -444,7 +447,7 @@ static int rsa_sig_print(BIO *bp, const X509_ALGOR *sigalg,
         RSA_PSS_PARAMS_free(pss);
         if (!rv)
             return 0;
         RSA_PSS_PARAMS_free(pss);
         if (!rv)
             return 0;
-    } else if (!sig && BIO_puts(bp, "\n") <= 0) {
+    } else if (BIO_puts(bp, "\n") <= 0) {
         return 0;
     }
     if (sig)
         return 0;
     }
     if (sig)
@@ -455,6 +458,9 @@ static int rsa_sig_print(BIO *bp, const X509_ALGOR *sigalg,
 static int rsa_pkey_ctrl(EVP_PKEY *pkey, int op, long arg1, void *arg2)
 {
     X509_ALGOR *alg = NULL;
 static int rsa_pkey_ctrl(EVP_PKEY *pkey, int op, long arg1, void *arg2)
 {
     X509_ALGOR *alg = NULL;
+    const EVP_MD *md;
+    const EVP_MD *mgf1md;
+    int min_saltlen;
 
     switch (op) {
 
 
     switch (op) {
 
@@ -494,6 +500,16 @@ static int rsa_pkey_ctrl(EVP_PKEY *pkey, int op, long arg1, void *arg2)
 #endif
 
     case ASN1_PKEY_CTRL_DEFAULT_MD_NID:
 #endif
 
     case ASN1_PKEY_CTRL_DEFAULT_MD_NID:
+        if (pkey->pkey.rsa->pss != NULL) {
+            if (!rsa_pss_get_param(pkey->pkey.rsa->pss, &md, &mgf1md,
+                                   &min_saltlen)) {
+                RSAerr(0, ERR_R_INTERNAL_ERROR);
+                return 0;
+            }
+            *(int *)arg2 = EVP_MD_type(md);
+            /* Return of 2 indicates this MD is mandatory */
+            return 2;
+        }
         *(int *)arg2 = NID_sha256;
         return 1;
 
         *(int *)arg2 = NID_sha256;
         return 1;
 
@@ -580,10 +596,12 @@ static RSA_PSS_PARAMS *rsa_ctx_to_pss(EVP_PKEY_CTX *pkctx)
         return NULL;
     if (saltlen == -1) {
         saltlen = EVP_MD_size(sigmd);
         return NULL;
     if (saltlen == -1) {
         saltlen = EVP_MD_size(sigmd);
-    } else if (saltlen == -2) {
+    } else if (saltlen == -2 || saltlen == -3) {
         saltlen = EVP_PKEY_size(pk) - EVP_MD_size(sigmd) - 2;
         if ((EVP_PKEY_bits(pk) & 0x7) == 1)
             saltlen--;
         saltlen = EVP_PKEY_size(pk) - EVP_MD_size(sigmd) - 2;
         if ((EVP_PKEY_bits(pk) & 0x7) == 1)
             saltlen--;
+        if (saltlen < 0)
+            return NULL;
     }
 
     return rsa_pss_params_create(sigmd, mgf1md, saltlen);
     }
 
     return rsa_pss_params_create(sigmd, mgf1md, saltlen);