Support calling EVP_DigestUpdate instead of EVP_Digest[Sign|Verify]Update
[openssl.git] / crypto / evp / digest.c
index 4b9395d58bf6ae37c1652f969c84900bd2c8aac5..5ff43fdd6433c0d1ed04d9dca455ac3d331d8e90 100644 (file)
@@ -24,8 +24,19 @@ int EVP_MD_CTX_reset(EVP_MD_CTX *ctx)
     if (ctx == NULL)
         return 1;
 
-    if (ctx->digest == NULL || ctx->digest->prov == NULL)
-        goto legacy;
+#ifndef FIPS_MODE
+    /* TODO(3.0): Temporarily no support for EVP_DigestSign* in FIPS module */
+    /*
+     * pctx should be freed by the user of EVP_MD_CTX
+     * if EVP_MD_CTX_FLAG_KEEP_PKEY_CTX is set
+     */
+    if (!EVP_MD_CTX_test_flags(ctx, EVP_MD_CTX_FLAG_KEEP_PKEY_CTX))
+        EVP_PKEY_CTX_free(ctx->pctx);
+#endif
+
+    EVP_MD_free(ctx->fetched_digest);
+    ctx->fetched_digest = NULL;
+    ctx->reqdigest = NULL;
 
     if (ctx->provctx != NULL) {
         if (ctx->digest->freectx != NULL)
@@ -34,13 +45,7 @@ int EVP_MD_CTX_reset(EVP_MD_CTX *ctx)
         EVP_MD_CTX_set_flags(ctx, EVP_MD_CTX_FLAG_CLEANED);
     }
 
-    if (ctx->pctx != NULL)
-        goto legacy;
-
-    return 1;
-
     /* TODO(3.0): Remove legacy code below */
- legacy:
 
     /*
      * Don't assume ctx->md_data was cleaned in EVP_Digest_Final, because
@@ -53,19 +58,13 @@ int EVP_MD_CTX_reset(EVP_MD_CTX *ctx)
         && !EVP_MD_CTX_test_flags(ctx, EVP_MD_CTX_FLAG_REUSE)) {
         OPENSSL_clear_free(ctx->md_data, ctx->digest->ctx_size);
     }
-    /*
-     * pctx should be freed by the user of EVP_MD_CTX
-     * if EVP_MD_CTX_FLAG_KEEP_PKEY_CTX is set
-     */
-#ifndef FIPS_MODE
-    /* TODO(3.0): Temporarily no support for EVP_DigestSign* in FIPS module */
-    if (!EVP_MD_CTX_test_flags(ctx, EVP_MD_CTX_FLAG_KEEP_PKEY_CTX))
-        EVP_PKEY_CTX_free(ctx->pctx);
 
-# ifndef OPENSSL_NO_ENGINE
+#if !defined(FIPS_MODE) && !defined(OPENSSL_NO_ENGINE)
     ENGINE_finish(ctx->engine);
-# endif
 #endif
+
+    /* TODO(3.0): End of legacy code */
+
     OPENSSL_cleanse(ctx, sizeof(*ctx));
 
     return 1;
@@ -83,11 +82,6 @@ void EVP_MD_CTX_free(EVP_MD_CTX *ctx)
 
     EVP_MD_CTX_reset(ctx);
 
-    EVP_MD_free(ctx->fetched_digest);
-    ctx->fetched_digest = NULL;
-    ctx->digest = NULL;
-    ctx->reqdigest = NULL;
-
     OPENSSL_free(ctx);
     return;
 }
@@ -106,6 +100,16 @@ int EVP_DigestInit_ex(EVP_MD_CTX *ctx, const EVP_MD *type, ENGINE *impl)
 
     EVP_MD_CTX_clear_flags(ctx, EVP_MD_CTX_FLAG_CLEANED);
 
+    if (ctx->provctx != NULL) {
+        if (!ossl_assert(ctx->digest != NULL)) {
+            EVPerr(EVP_F_EVP_DIGESTINIT_EX, EVP_R_INITIALIZATION_ERROR);
+            return 0;
+        }
+        if (ctx->digest->freectx != NULL)
+            ctx->digest->freectx(ctx->provctx);
+        ctx->provctx = NULL;
+    }
+
     if (type != NULL)
         ctx->reqdigest = type;
 
@@ -136,15 +140,14 @@ int EVP_DigestInit_ex(EVP_MD_CTX *ctx, const EVP_MD *type, ENGINE *impl)
 #endif
 
     /*
-     * If there are engines involved or if we're being used as part of
-     * EVP_DigestSignInit then we should use legacy handling for now.
+     * If there are engines involved or EVP_MD_CTX_FLAG_NO_INIT is set then we
+     * should use legacy handling for now.
      */
     if (ctx->engine != NULL
             || impl != NULL
 #if !defined(OPENSSL_NO_ENGINE) && !defined(FIPS_MODE)
             || tmpimpl != NULL
 #endif
-            || ctx->pctx != NULL
             || (ctx->flags & EVP_MD_CTX_FLAG_NO_INIT) != 0) {
         if (ctx->digest == ctx->fetched_digest)
             ctx->digest = NULL;
@@ -282,6 +285,24 @@ int EVP_DigestUpdate(EVP_MD_CTX *ctx, const void *data, size_t count)
     if (count == 0)
         return 1;
 
+    if (ctx->pctx != NULL
+            && EVP_PKEY_CTX_IS_SIGNATURE_OP(ctx->pctx)
+            && ctx->pctx->op.sig.sigprovctx != NULL) {
+        /*
+         * Prior to OpenSSL 3.0 EVP_DigestSignUpdate() and
+         * EVP_DigestVerifyUpdate() were just macros for EVP_DigestUpdate().
+         * Some code calls EVP_DigestUpdate() directly even when initialised
+         * with EVP_DigestSignInit_ex() or EVP_DigestVerifyInit_ex(), so we
+         * detect that and redirect to the correct EVP_Digest*Update() function
+         */
+        if (ctx->pctx->operation == EVP_PKEY_OP_SIGNCTX)
+            return EVP_DigestSignUpdate(ctx, data, count);
+        if (ctx->pctx->operation == EVP_PKEY_OP_VERIFYCTX)
+            return EVP_DigestVerifyUpdate(ctx, data, count);
+        EVPerr(EVP_F_EVP_DIGESTUPDATE, EVP_R_UPDATE_ERROR);
+        return 0;
+    }
+
     if (ctx->digest == NULL || ctx->digest->prov == NULL)
         goto legacy;
 
@@ -331,7 +352,6 @@ int EVP_DigestFinal_ex(EVP_MD_CTX *ctx, unsigned char *md, unsigned int *isize)
         }
     }
 
-    EVP_MD_CTX_reset(ctx);
     return ret;
 
     /* TODO(3.0): Remove legacy code below */
@@ -537,8 +557,18 @@ const OSSL_PARAM *EVP_MD_gettable_params(const EVP_MD *digest)
 
 int EVP_MD_CTX_set_params(EVP_MD_CTX *ctx, const OSSL_PARAM params[])
 {
+    EVP_PKEY_CTX *pctx = ctx->pctx;
+
     if (ctx->digest != NULL && ctx->digest->set_ctx_params != NULL)
         return ctx->digest->set_ctx_params(ctx->provctx, params);
+
+    if (pctx != NULL
+            && (pctx->operation == EVP_PKEY_OP_VERIFYCTX
+                || pctx->operation == EVP_PKEY_OP_SIGNCTX)
+            && pctx->op.sig.sigprovctx != NULL
+            && pctx->op.sig.signature->set_ctx_md_params != NULL)
+        return pctx->op.sig.signature->set_ctx_md_params(pctx->op.sig.sigprovctx,
+                                                         params);
     return 0;
 }
 
@@ -551,18 +581,40 @@ const OSSL_PARAM *EVP_MD_settable_ctx_params(const EVP_MD *md)
 
 const OSSL_PARAM *EVP_MD_CTX_settable_params(EVP_MD_CTX *ctx)
 {
+    EVP_PKEY_CTX *pctx;
+
     if (ctx != NULL
             && ctx->digest != NULL
             && ctx->digest->settable_ctx_params != NULL)
         return ctx->digest->settable_ctx_params();
 
+    pctx = ctx->pctx;
+    if (pctx != NULL
+            && (pctx->operation == EVP_PKEY_OP_VERIFYCTX
+                || pctx->operation == EVP_PKEY_OP_SIGNCTX)
+            && pctx->op.sig.sigprovctx != NULL
+            && pctx->op.sig.signature->settable_ctx_md_params != NULL)
+        return pctx->op.sig.signature->settable_ctx_md_params(
+                   pctx->op.sig.sigprovctx);
+
     return NULL;
 }
 
 int EVP_MD_CTX_get_params(EVP_MD_CTX *ctx, OSSL_PARAM params[])
 {
+    EVP_PKEY_CTX *pctx = ctx->pctx;
+
     if (ctx->digest != NULL && ctx->digest->get_params != NULL)
         return ctx->digest->get_ctx_params(ctx->provctx, params);
+
+    if (pctx != NULL
+            && (pctx->operation == EVP_PKEY_OP_VERIFYCTX
+                || pctx->operation == EVP_PKEY_OP_SIGNCTX)
+            && pctx->op.sig.sigprovctx != NULL
+            && pctx->op.sig.signature->get_ctx_md_params != NULL)
+        return pctx->op.sig.signature->get_ctx_md_params(pctx->op.sig.sigprovctx,
+                                                         params);
+
     return 0;
 }
 
@@ -575,11 +627,22 @@ const OSSL_PARAM *EVP_MD_gettable_ctx_params(const EVP_MD *md)
 
 const OSSL_PARAM *EVP_MD_CTX_gettable_params(EVP_MD_CTX *ctx)
 {
+    EVP_PKEY_CTX *pctx;
+
     if (ctx != NULL
             && ctx->digest != NULL
             && ctx->digest->gettable_ctx_params != NULL)
         return ctx->digest->gettable_ctx_params();
 
+    pctx = ctx->pctx;
+    if (pctx != NULL
+            && (pctx->operation == EVP_PKEY_OP_VERIFYCTX
+                || pctx->operation == EVP_PKEY_OP_SIGNCTX)
+            && pctx->op.sig.sigprovctx != NULL
+            && pctx->op.sig.signature->gettable_ctx_md_params != NULL)
+        return pctx->op.sig.signature->gettable_ctx_md_params(
+                    pctx->op.sig.sigprovctx);
+
     return NULL;
 }
 
@@ -596,7 +659,10 @@ int EVP_MD_CTX_ctrl(EVP_MD_CTX *ctx, int cmd, int p1, void *p2)
         return 0;
     }
 
-    if (ctx->digest->prov == NULL)
+    if (ctx->digest->prov == NULL
+        && (ctx->pctx == NULL
+            || (ctx->pctx->operation != EVP_PKEY_OP_VERIFYCTX
+                && ctx->pctx->operation != EVP_PKEY_OP_SIGNCTX)))
         goto legacy;
 
     switch (cmd) {
@@ -614,10 +680,10 @@ int EVP_MD_CTX_ctrl(EVP_MD_CTX *ctx, int cmd, int p1, void *p2)
     }
 
     if (set_params)
-        ret = evp_do_md_ctx_setparams(ctx->digest, ctx->provctx, params);
+        ret = EVP_MD_CTX_set_params(ctx, params);
     else
-        ret = evp_do_md_ctx_getparams(ctx->digest, ctx->provctx, params);
-    return ret;
+        ret = EVP_MD_CTX_get_params(ctx, params);
+    goto conclude;
 
 
 /* TODO(3.0): Remove legacy code below */
@@ -628,6 +694,7 @@ int EVP_MD_CTX_ctrl(EVP_MD_CTX *ctx, int cmd, int p1, void *p2)
     }
 
     ret = ctx->digest->md_ctrl(ctx, cmd, p1, p2);
+ conclude:
     if (ret <= 0)
         return 0;
     return ret;