Remove unnecessary loop in pkey_rsa_decrypt.
[openssl.git] / crypto / rsa / rsa_pmeth.c
index 55f1f28d38fd53baa2837735aaf77a10552999c2..4ba713910c2f4141970f6e4f82583200c65099c9 100644 (file)
@@ -50,6 +50,7 @@ typedef struct {
 static int pkey_rsa_init(EVP_PKEY_CTX *ctx)
 {
     RSA_PKEY_CTX *rctx = OPENSSL_zalloc(sizeof(*rctx));
+
     if (rctx == NULL)
         return 0;
     rctx->nbits = 1024;
@@ -57,7 +58,8 @@ static int pkey_rsa_init(EVP_PKEY_CTX *ctx)
         rctx->pad_mode = RSA_PKCS1_PSS_PADDING;
     else
         rctx->pad_mode = RSA_PKCS1_PADDING;
-    rctx->saltlen = -2;
+    /* Maximum for sign, auto for verify */
+    rctx->saltlen = RSA_PSS_SALTLEN_AUTO;
     rctx->min_saltlen = -1;
     ctx->data = rctx;
     ctx->keygen_info = rctx->gentmp;
@@ -314,19 +316,14 @@ static int pkey_rsa_decrypt(EVP_PKEY_CTX *ctx,
     RSA_PKEY_CTX *rctx = ctx->data;
 
     if (rctx->pad_mode == RSA_PKCS1_OAEP_PADDING) {
-        int i;
         if (!setup_tbuf(rctx, ctx))
             return -1;
         ret = RSA_private_decrypt(inlen, in, rctx->tbuf,
                                   ctx->pkey->pkey.rsa, RSA_NO_PADDING);
         if (ret <= 0)
             return ret;
-        for (i = 0; i < ret; i++) {
-            if (rctx->tbuf[i])
-                break;
-        }
-        ret = RSA_padding_check_PKCS1_OAEP_mgf1(out, ret, rctx->tbuf + i,
-                                                ret - i, ret,
+        ret = RSA_padding_check_PKCS1_OAEP_mgf1(out, ret, rctx->tbuf,
+                                                ret, ret,
                                                 rctx->oaep_label,
                                                 rctx->oaep_labellen,
                                                 rctx->md, rctx->mgf1md);
@@ -429,11 +426,20 @@ static int pkey_rsa_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
         if (type == EVP_PKEY_CTRL_GET_RSA_PSS_SALTLEN) {
             *(int *)p2 = rctx->saltlen;
         } else {
-            if (p1 < -2)
+            if (p1 < RSA_PSS_SALTLEN_MAX)
                 return -2;
-            if (rsa_pss_restricted(rctx) && p1 < rctx->min_saltlen) {
-                RSAerr(RSA_F_PKEY_RSA_CTRL, RSA_R_PSS_SALTLEN_TOO_SMALL);
-                return 0;
+            if (rsa_pss_restricted(rctx)) {
+                if (p1 == RSA_PSS_SALTLEN_AUTO
+                    && ctx->operation == EVP_PKEY_OP_VERIFY) {
+                    RSAerr(RSA_F_PKEY_RSA_CTRL, RSA_R_INVALID_PSS_SALTLEN);
+                    return -2;
+                }
+                if ((p1 == RSA_PSS_SALTLEN_DIGEST
+                     && rctx->min_saltlen > EVP_MD_size(rctx->md))
+                    || (p1 >= 0 && p1 < rctx->min_saltlen)) {
+                    RSAerr(RSA_F_PKEY_RSA_CTRL, RSA_R_PSS_SALTLEN_TOO_SMALL);
+                    return 0;
+                }
             }
             rctx->saltlen = p1;
         }
@@ -542,9 +548,10 @@ static int pkey_rsa_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
 #ifndef OPENSSL_NO_CMS
     case EVP_PKEY_CTRL_CMS_DECRYPT:
     case EVP_PKEY_CTRL_CMS_ENCRYPT:
+#endif
     if (!pkey_ctx_is_pss(ctx))
         return 1;
-#endif
+    /* fall through */
     case EVP_PKEY_CTRL_PEER_KEY:
         RSAerr(RSA_F_PKEY_RSA_CTRL,
                RSA_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
@@ -588,7 +595,14 @@ static int pkey_rsa_ctrl_str(EVP_PKEY_CTX *ctx,
 
     if (strcmp(type, "rsa_pss_saltlen") == 0) {
         int saltlen;
-        saltlen = atoi(value);
+        if (!strcmp(value, "digest"))
+            saltlen = RSA_PSS_SALTLEN_DIGEST;
+        else if (!strcmp(value, "max"))
+            saltlen = RSA_PSS_SALTLEN_MAX;
+        else if (!strcmp(value, "auto"))
+            saltlen = RSA_PSS_SALTLEN_AUTO;
+        else
+            saltlen = atoi(value);
         return EVP_PKEY_CTX_set_rsa_pss_saltlen(ctx, saltlen);
     }
 
@@ -625,8 +639,8 @@ static int pkey_rsa_ctrl_str(EVP_PKEY_CTX *ctx,
                                    EVP_PKEY_CTRL_MD, value);
 
         if (strcmp(type, "rsa_pss_keygen_saltlen") == 0) {
-            int saltlen;
-            saltlen = atoi(value);
+            int saltlen = atoi(value);
+
             return EVP_PKEY_CTX_set_rsa_pss_keygen_saltlen(ctx, saltlen);
         }
     }
@@ -751,7 +765,7 @@ static int pkey_pss_init(EVP_PKEY_CTX *ctx)
     RSA_PKEY_CTX *rctx = ctx->data;
     const EVP_MD *md;
     const EVP_MD *mgf1md;
-    int min_saltlen;
+    int min_saltlen, max_saltlen;
 
     /* Should never happen */
     if (!pkey_ctx_is_pss(ctx))
@@ -764,6 +778,15 @@ static int pkey_pss_init(EVP_PKEY_CTX *ctx)
     if (!rsa_pss_get_param(rsa->pss, &md, &mgf1md, &min_saltlen))
         return 0;
 
+    /* See if minumum salt length exceeds maximum possible */
+    max_saltlen = RSA_size(rsa) - EVP_MD_size(md);
+    if ((RSA_bits(rsa) & 0x7) == 1)
+        max_saltlen--;
+    if (min_saltlen > max_saltlen) {
+        RSAerr(RSA_F_PKEY_PSS_INIT, RSA_R_INVALID_SALT_LENGTH);
+        return 0;
+    }
+
     rctx->min_saltlen = min_saltlen;
 
     /*