Make the EVP Key Exchange code provider aware
[openssl.git] / crypto / evp / evp_lib.c
index 615206bdd03e7cb1cd4c07ace465cb0414100d21..3e64a1f93eed305f1d7d5c0b12d961317bb09a46 100644 (file)
@@ -13,6 +13,7 @@
 #include <openssl/objects.h>
 #include <openssl/params.h>
 #include <openssl/core_names.h>
+#include <openssl/dh.h>
 #include "internal/evp_int.h"
 #include "internal/provider.h"
 #include "evp_locl.h"
@@ -726,3 +727,133 @@ int EVP_hex2ctrl(int (*cb)(void *ctx, int cmd, void *buf, size_t buflen),
     OPENSSL_free(bin);
     return rv;
 }
+
+#ifndef FIPS_MODE
+/*
+ * TODO(3.0): Temporarily unavailable in FIPS mode. This will need to be added
+ * in later.
+ */
+
+#define MAX_PARAMS 10
+typedef struct {
+    /* Number of the current param */
+    size_t curr;
+    struct {
+        /* Key for the current param */
+        const char *key;
+        /* Value for the current param */
+        const BIGNUM *bnparam;
+        /* Size of the buffer required for the BN */
+        size_t bufsz;
+    } params[MAX_PARAMS];
+    /* Running count of the total size required */
+    size_t totsz;
+    int ispublic;
+} PARAMS_TEMPLATE;
+
+static int push_param_bn(PARAMS_TEMPLATE *tmpl, const char *key,
+                         const BIGNUM *bn)
+{
+    int sz;
+
+    sz = BN_num_bytes(bn);
+    if (sz <= 0)
+        return 0;
+    tmpl->params[tmpl->curr].key = key;
+    tmpl->params[tmpl->curr].bnparam = bn;
+    tmpl->params[tmpl->curr++].bufsz = (size_t)sz;
+    tmpl->totsz += sizeof(OSSL_PARAM) + (size_t)sz;
+
+    return 1;
+}
+
+static OSSL_PARAM *param_template_to_param(PARAMS_TEMPLATE *tmpl, size_t *sz)
+{
+    size_t i;
+    void *buf;
+    OSSL_PARAM *param = NULL;
+    unsigned char *currbuf = NULL;
+
+    if (tmpl->totsz == 0)
+        return NULL;
+
+    /* Add some space for the end of OSSL_PARAM marker */
+    tmpl->totsz += sizeof(*param);
+
+    if (tmpl->ispublic)
+        buf = OPENSSL_zalloc(tmpl->totsz);
+    else
+        buf = OPENSSL_secure_zalloc(tmpl->totsz);
+    if (buf == NULL)
+        return NULL;
+    param = buf;
+
+    currbuf = (unsigned char *)buf + (sizeof(*param) * (tmpl->curr + 1));
+
+    for (i = 0; i < tmpl->curr; i++) {
+        if (!ossl_assert((currbuf - (unsigned char *)buf )
+                         + tmpl->params[i].bufsz <= tmpl->totsz))
+            goto err;
+        if (BN_bn2nativepad(tmpl->params[i].bnparam, currbuf,
+                            tmpl->params[i].bufsz) < 0)
+            goto err;
+        param[i] = OSSL_PARAM_construct_BN(tmpl->params[i].key, currbuf,
+                                           tmpl->params[i].bufsz);
+        currbuf += tmpl->params[i].bufsz;
+    }
+    param[i] = OSSL_PARAM_construct_end();
+
+    if (sz != NULL)
+        *sz = tmpl->totsz;
+    return param;
+
+ err:
+    if (tmpl->ispublic)
+        OPENSSL_free(param);
+    else
+        OPENSSL_clear_free(param, tmpl->totsz);
+    return NULL;
+}
+
+static OSSL_PARAM *evp_pkey_dh_to_param(EVP_PKEY *pkey, size_t *sz)
+{
+    DH *dh = pkey->pkey.dh;
+    PARAMS_TEMPLATE tmpl = {0};
+    const BIGNUM *p = DH_get0_p(dh), *g = DH_get0_g(dh), *q = DH_get0_q(dh);
+    const BIGNUM *pub_key = DH_get0_pub_key(dh);
+    const BIGNUM *priv_key = DH_get0_priv_key(dh);
+
+    if (p == NULL || g == NULL || pub_key == NULL)
+        return NULL;
+
+    if (!push_param_bn(&tmpl, OSSL_PKEY_PARAM_DH_P, p)
+            || !push_param_bn(&tmpl, OSSL_PKEY_PARAM_DH_G, g)
+            || !push_param_bn(&tmpl, OSSL_PKEY_PARAM_DH_PUB_KEY, pub_key))
+        return NULL;
+
+    if (q != NULL) {
+        if (!push_param_bn(&tmpl, OSSL_PKEY_PARAM_DH_Q, q))
+            return NULL;
+    }
+
+    if (priv_key != NULL) {
+        if (!push_param_bn(&tmpl, OSSL_PKEY_PARAM_DH_PRIV_KEY, priv_key))
+            return NULL;
+    } else {
+        tmpl.ispublic = 1;
+    }
+
+    return param_template_to_param(&tmpl, sz);
+}
+
+OSSL_PARAM *evp_pkey_to_param(EVP_PKEY *pkey, size_t *sz)
+{
+    switch (pkey->type) {
+    case EVP_PKEY_DH:
+        return evp_pkey_dh_to_param(pkey, sz);
+    default:
+        return NULL;
+    }
+}
+
+#endif /* FIPS_MODE */