Move digests to providers
[openssl.git] / crypto / evp / digest.c
index a1f0154a7fc07a765accf450fa1a679f8c7670ff..89cd5c1d006759a05e52a9e7e60d368c5d1a42cc 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 1995-2018 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 1995-2019 The OpenSSL Project Authors. All Rights Reserved.
  *
  * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
@@ -8,10 +8,12 @@
  */
 
 #include <stdio.h>
-#include "internal/cryptlib.h"
 #include <openssl/objects.h>
 #include <openssl/evp.h>
 #include <openssl/engine.h>
+#include <openssl/params.h>
+#include <openssl/core_names.h>
+#include "internal/cryptlib.h"
 #include "internal/evp_int.h"
 #include "internal/provider.h"
 #include "evp_locl.h"
@@ -149,16 +151,6 @@ int EVP_DigestInit_ex(EVP_MD_CTX *ctx, const EVP_MD *type, ENGINE *impl)
         goto legacy;
     }
 
-    if (type->prov == NULL) {
-        switch(type->type) {
-        case NID_sha256:
-        case NID_md2:
-            break;
-        default:
-            goto legacy;
-        }
-    }
-
     if (ctx->digest != NULL && ctx->digest->ctx_size > 0) {
         OPENSSL_clear_free(ctx->md_data, ctx->digest->ctx_size);
         ctx->md_data = NULL;
@@ -184,6 +176,11 @@ int EVP_DigestInit_ex(EVP_MD_CTX *ctx, const EVP_MD *type, ENGINE *impl)
 #endif
     }
 
+    if (ctx->provctx != NULL && ctx->digest != NULL && ctx->digest != type) {
+        if (ctx->digest->freectx != NULL)
+            ctx->digest->freectx(ctx->provctx);
+        ctx->provctx = NULL;
+    }
     ctx->digest = type;
     if (ctx->provctx == NULL) {
         ctx->provctx = ctx->digest->newctx(ossl_provider_ctx(type->prov));
@@ -334,7 +331,6 @@ int EVP_DigestFinal_ex(EVP_MD_CTX *ctx, unsigned char *md, unsigned int *isize)
     }
 
     EVP_MD_CTX_reset(ctx);
-
     return ret;
 
     /* TODO(3.0): Remove legacy code below */
@@ -354,12 +350,31 @@ int EVP_DigestFinal_ex(EVP_MD_CTX *ctx, unsigned char *md, unsigned int *isize)
 int EVP_DigestFinalXOF(EVP_MD_CTX *ctx, unsigned char *md, size_t size)
 {
     int ret = 0;
+    OSSL_PARAM params[2];
+    size_t i = 0;
+
+    if (ctx->digest == NULL || ctx->digest->prov == NULL)
+        goto legacy;
 
+    if (ctx->digest->dfinal == NULL) {
+        EVPerr(EVP_F_EVP_DIGESTFINALXOF, EVP_R_FINAL_ERROR);
+        return 0;
+    }
+
+    params[i++] = OSSL_PARAM_construct_size_t(OSSL_DIGEST_PARAM_XOFLEN,
+                                              &size, NULL);
+    params[i++] = OSSL_PARAM_construct_end();
+
+    if (EVP_MD_CTX_set_params(ctx, params) > 0)
+        ret = ctx->digest->dfinal(ctx->provctx, md, &size, size);
+    EVP_MD_CTX_reset(ctx);
+    return ret;
+
+legacy:
     if (ctx->digest->flags & EVP_MD_FLAG_XOF
         && size <= INT_MAX
         && ctx->digest->md_ctrl(ctx, EVP_MD_CTRL_XOF_LEN, (int)size, NULL)) {
         ret = ctx->digest->final(ctx, md);
-
         if (ctx->digest->cleanup != NULL) {
             ctx->digest->cleanup(ctx);
             EVP_MD_CTX_set_flags(ctx, EVP_MD_CTX_FLAG_CLEANED);
@@ -506,16 +521,56 @@ int EVP_Digest(const void *data, size_t count,
     return ret;
 }
 
+int EVP_MD_CTX_set_params(EVP_MD_CTX *ctx, const OSSL_PARAM params[])
+{
+    if (ctx->digest != NULL && ctx->digest->set_params != NULL)
+        return ctx->digest->set_params(ctx->provctx, params);
+    return 0;
+}
+
+int EVP_MD_CTX_get_params(EVP_MD_CTX *ctx, const OSSL_PARAM params[])
+{
+    if (ctx->digest != NULL && ctx->digest->get_params != NULL)
+        return ctx->digest->get_params(ctx->provctx, params);
+    return 0;
+}
+
+#if !OPENSSL_API_3
 int EVP_MD_CTX_ctrl(EVP_MD_CTX *ctx, int cmd, int p1, void *p2)
 {
-    if (ctx->digest && ctx->digest->md_ctrl) {
-        int ret = ctx->digest->md_ctrl(ctx, cmd, p1, p2);
-        if (ret <= 0)
-            return 0;
-        return 1;
+    if (ctx->digest != NULL) {
+        OSSL_PARAM params[2];
+        size_t i, sz, n = 0;
+
+        switch (cmd) {
+        case EVP_MD_CTRL_XOF_LEN:
+            if (ctx->digest->set_params == NULL)
+                break;
+            i = (size_t)p1;
+            params[n++] = OSSL_PARAM_construct_size_t(
+                              OSSL_DIGEST_PARAM_XOFLEN, &i, &sz);
+            params[n++] = OSSL_PARAM_construct_end();
+            return ctx->digest->set_params(ctx->provctx, params) > 0;
+        case EVP_MD_CTRL_MICALG:
+            if (ctx->digest->get_params == NULL)
+                break;
+            params[n++] = OSSL_PARAM_construct_utf8_string(
+                              OSSL_DIGEST_PARAM_MICALG, p2, p1 ? p1 : 9999,
+                              &sz);
+            params[n++] = OSSL_PARAM_construct_end();
+            return ctx->digest->get_params(ctx->provctx, params);
+        }
+        /* legacy code */
+        if (ctx->digest->md_ctrl != NULL) {
+            int ret = ctx->digest->md_ctrl(ctx, cmd, p1, p2);
+            if (ret <= 0)
+                return 0;
+            return 1;
+        }
     }
     return 0;
 }
+#endif
 
 static void *evp_md_from_dispatch(const OSSL_DISPATCH *fns,
                                   OSSL_PROVIDER *prov)
@@ -530,55 +585,59 @@ static void *evp_md_from_dispatch(const OSSL_DISPATCH *fns,
     for (; fns->function_id != 0; fns++) {
         switch (fns->function_id) {
         case OSSL_FUNC_DIGEST_NEWCTX:
-            if (md->newctx != NULL)
-                break;
-            md->newctx = OSSL_get_OP_digest_newctx(fns);
-            fncnt++;
+            if (md->newctx == NULL) {
+                md->newctx = OSSL_get_OP_digest_newctx(fns);
+                fncnt++;
+            }
             break;
         case OSSL_FUNC_DIGEST_INIT:
-            if (md->dinit != NULL)
-                break;
-            md->dinit = OSSL_get_OP_digest_init(fns);
-            fncnt++;
+            if (md->dinit == NULL) {
+                md->dinit = OSSL_get_OP_digest_init(fns);
+                fncnt++;
+            }
             break;
         case OSSL_FUNC_DIGEST_UPDATE:
-            if (md->dupdate != NULL)
-                break;
-            md->dupdate = OSSL_get_OP_digest_update(fns);
-            fncnt++;
+            if (md->dupdate == NULL) {
+                md->dupdate = OSSL_get_OP_digest_update(fns);
+                fncnt++;
+            }
             break;
         case OSSL_FUNC_DIGEST_FINAL:
-            if (md->dfinal != NULL)
-                break;
-            md->dfinal = OSSL_get_OP_digest_final(fns);
-            fncnt++;
+            if (md->dfinal == NULL) {
+                md->dfinal = OSSL_get_OP_digest_final(fns);
+                fncnt++;
+            }
             break;
         case OSSL_FUNC_DIGEST_DIGEST:
-            if (md->digest != NULL)
-                break;
-            md->digest = OSSL_get_OP_digest_digest(fns);
+            if (md->digest == NULL)
+                md->digest = OSSL_get_OP_digest_digest(fns);
             /* We don't increment fnct for this as it is stand alone */
             break;
         case OSSL_FUNC_DIGEST_FREECTX:
-            if (md->freectx != NULL)
-                break;
-            md->freectx = OSSL_get_OP_digest_freectx(fns);
-            fncnt++;
+            if (md->freectx == NULL) {
+                md->freectx = OSSL_get_OP_digest_freectx(fns);
+                fncnt++;
+            }
             break;
         case OSSL_FUNC_DIGEST_DUPCTX:
-            if (md->dupctx != NULL)
-                break;
-            md->dupctx = OSSL_get_OP_digest_dupctx(fns);
+            if (md->dupctx == NULL)
+                md->dupctx = OSSL_get_OP_digest_dupctx(fns);
             break;
         case OSSL_FUNC_DIGEST_SIZE:
-            if (md->size != NULL)
-                break;
-            md->size = OSSL_get_OP_digest_size(fns);
+            if (md->size == NULL)
+                md->size = OSSL_get_OP_digest_size(fns);
             break;
         case OSSL_FUNC_DIGEST_BLOCK_SIZE:
-            if (md->dblock_size != NULL)
-                break;
-            md->dblock_size = OSSL_get_OP_digest_block_size(fns);
+            if (md->dblock_size == NULL)
+                md->dblock_size = OSSL_get_OP_digest_block_size(fns);
+            break;
+        case OSSL_FUNC_DIGEST_SET_PARAMS:
+            if (md->set_params == NULL)
+                md->set_params = OSSL_get_OP_digest_set_params(fns);
+            break;
+        case OSSL_FUNC_DIGEST_GET_PARAMS:
+            if (md->get_params == NULL)
+                md->get_params = OSSL_get_OP_digest_get_params(fns);
             break;
         }
     }