Add KEM (Key encapsulation mechanism) support to providers
[openssl.git] / crypto / evp / kem.c
1 /*
2  * Copyright 2020 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include <openssl/objects.h>
13 #include <openssl/evp.h>
14 #include "internal/cryptlib.h"
15 #include "crypto/evp.h"
16 #include "internal/provider.h"
17 #include "evp_local.h"
18
19 static int evp_kem_init(EVP_PKEY_CTX *ctx, int operation)
20 {
21     int ret = 0;
22     EVP_KEM *kem = NULL;
23     EVP_KEYMGMT *tmp_keymgmt = NULL;
24     void *provkey = NULL;
25     const char *supported_kem = NULL;
26
27     if (ctx == NULL || ctx->keytype == NULL) {
28         ERR_raise(ERR_LIB_EVP, EVP_R_INITIALIZATION_ERROR);
29         return 0;
30     }
31
32     evp_pkey_ctx_free_old_ops(ctx);
33     ctx->operation = operation;
34
35     /*
36      * Ensure that the key is provided, either natively, or as a cached export.
37      */
38     tmp_keymgmt = ctx->keymgmt;
39     provkey = evp_pkey_export_to_provider(ctx->pkey, ctx->libctx,
40                                           &tmp_keymgmt, ctx->propquery);
41     if (provkey == NULL
42         || !EVP_KEYMGMT_up_ref(tmp_keymgmt)) {
43         ERR_raise(ERR_LIB_EVP, EVP_R_INITIALIZATION_ERROR);
44         goto err;
45     }
46     EVP_KEYMGMT_free(ctx->keymgmt);
47     ctx->keymgmt = tmp_keymgmt;
48
49     if (ctx->keymgmt->query_operation_name != NULL)
50         supported_kem = ctx->keymgmt->query_operation_name(OSSL_OP_KEM);
51
52     /*
53      * If we didn't get a supported kem, assume there is one with the
54      * same name as the key type.
55      */
56     if (supported_kem == NULL)
57         supported_kem = ctx->keytype;
58
59     kem = EVP_KEM_fetch(ctx->libctx, supported_kem, ctx->propquery);
60     if (kem == NULL
61         || (EVP_KEYMGMT_provider(ctx->keymgmt) != EVP_KEM_provider(kem))) {
62         ERR_raise(ERR_LIB_EVP, EVP_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
63         ret = -2;
64         goto err;
65     }
66
67     ctx->op.encap.kem = kem;
68     ctx->op.encap.kemprovctx = kem->newctx(ossl_provider_ctx(kem->prov));
69     if (ctx->op.encap.kemprovctx == NULL) {
70         /* The provider key can stay in the cache */
71         ERR_raise(ERR_LIB_EVP, EVP_R_INITIALIZATION_ERROR);
72         goto err;
73     }
74
75     switch (operation) {
76     case EVP_PKEY_OP_ENCAPSULATE:
77         if (kem->encapsulate_init == NULL) {
78             ERR_raise(ERR_LIB_EVP, EVP_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
79             ret = -2;
80             goto err;
81         }
82         ret = kem->encapsulate_init(ctx->op.encap.kemprovctx, provkey);
83         break;
84     case EVP_PKEY_OP_DECAPSULATE:
85         if (kem->decapsulate_init == NULL) {
86             ERR_raise(ERR_LIB_EVP, EVP_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
87             ret = -2;
88             goto err;
89         }
90         ret = kem->decapsulate_init(ctx->op.encap.kemprovctx, provkey);
91         break;
92     default:
93         ERR_raise(ERR_LIB_EVP, EVP_R_INITIALIZATION_ERROR);
94         goto err;
95     }
96
97     if (ret > 0)
98         return 1;
99  err:
100     if (ret <= 0) {
101         evp_pkey_ctx_free_old_ops(ctx);
102         ctx->operation = EVP_PKEY_OP_UNDEFINED;
103     }
104     return ret;
105 }
106
107 int EVP_PKEY_encapsulate_init(EVP_PKEY_CTX *ctx)
108 {
109     return evp_kem_init(ctx, EVP_PKEY_OP_ENCAPSULATE);
110 }
111
112 int EVP_PKEY_encapsulate(EVP_PKEY_CTX *ctx,
113                          unsigned char *out, size_t *outlen,
114                          unsigned char *secret, size_t *secretlen)
115 {
116     if (ctx == NULL)
117         return 0;
118
119     if (ctx->operation != EVP_PKEY_OP_ENCAPSULATE) {
120         EVPerr(0, EVP_R_OPERATON_NOT_INITIALIZED);
121         return -1;
122     }
123
124     if (ctx->op.encap.kemprovctx == NULL) {
125         EVPerr(0, EVP_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
126         return -2;
127     }
128
129     if (out != NULL && secret == NULL)
130         return 0;
131
132     return ctx->op.encap.kem->encapsulate(ctx->op.encap.kemprovctx,
133                                           out, outlen, secret, secretlen);
134 }
135
136 int EVP_PKEY_decapsulate_init(EVP_PKEY_CTX *ctx)
137 {
138     return evp_kem_init(ctx, EVP_PKEY_OP_DECAPSULATE);
139 }
140
141 int EVP_PKEY_decapsulate(EVP_PKEY_CTX *ctx,
142                          unsigned char *secret, size_t *secretlen,
143                          const unsigned char *in, size_t inlen)
144 {
145     if (ctx == NULL
146         || (in == NULL || inlen == 0)
147         || (secret == NULL && secretlen == NULL))
148         return 0;
149
150     if (ctx->operation != EVP_PKEY_OP_DECAPSULATE) {
151         EVPerr(0, EVP_R_OPERATON_NOT_INITIALIZED);
152         return -1;
153     }
154
155     if (ctx->op.encap.kemprovctx == NULL) {
156         EVPerr(0, EVP_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE);
157         return -2;
158     }
159     return ctx->op.encap.kem->decapsulate(ctx->op.encap.kemprovctx,
160                                           secret, secretlen, in, inlen);
161 }
162
163 static EVP_KEM *evp_kem_new(OSSL_PROVIDER *prov)
164 {
165     EVP_KEM *kem = OPENSSL_zalloc(sizeof(EVP_KEM));
166
167     if (kem == NULL) {
168         ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
169         return NULL;
170     }
171
172     kem->lock = CRYPTO_THREAD_lock_new();
173     if (kem->lock == NULL) {
174         ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
175         OPENSSL_free(kem);
176         return NULL;
177     }
178     kem->prov = prov;
179     ossl_provider_up_ref(prov);
180     kem->refcnt = 1;
181
182     return kem;
183 }
184
185 static void *evp_kem_from_dispatch(int name_id, const OSSL_DISPATCH *fns,
186                                    OSSL_PROVIDER *prov)
187 {
188     EVP_KEM *kem = NULL;
189     int ctxfncnt = 0, encfncnt = 0, decfncnt = 0;
190     int gparamfncnt = 0, sparamfncnt = 0;
191
192     if ((kem = evp_kem_new(prov)) == NULL) {
193         ERR_raise(ERR_LIB_EVP, ERR_R_MALLOC_FAILURE);
194         goto err;
195     }
196
197     kem->name_id = name_id;
198
199     for (; fns->function_id != 0; fns++) {
200         switch (fns->function_id) {
201         case OSSL_FUNC_KEM_NEWCTX:
202             if (kem->newctx != NULL)
203                 break;
204             kem->newctx = OSSL_FUNC_kem_newctx(fns);
205             ctxfncnt++;
206             break;
207         case OSSL_FUNC_KEM_ENCAPSULATE_INIT:
208             if (kem->encapsulate_init != NULL)
209                 break;
210             kem->encapsulate_init = OSSL_FUNC_kem_encapsulate_init(fns);
211             encfncnt++;
212             break;
213         case OSSL_FUNC_KEM_ENCAPSULATE:
214             if (kem->encapsulate != NULL)
215                 break;
216             kem->encapsulate = OSSL_FUNC_kem_encapsulate(fns);
217             encfncnt++;
218             break;
219         case OSSL_FUNC_KEM_DECAPSULATE_INIT:
220             if (kem->decapsulate_init != NULL)
221                 break;
222             kem->decapsulate_init = OSSL_FUNC_kem_decapsulate_init(fns);
223             decfncnt++;
224             break;
225         case OSSL_FUNC_KEM_DECAPSULATE:
226             if (kem->decapsulate != NULL)
227                 break;
228             kem->decapsulate = OSSL_FUNC_kem_decapsulate(fns);
229             decfncnt++;
230             break;
231         case OSSL_FUNC_KEM_FREECTX:
232             if (kem->freectx != NULL)
233                 break;
234             kem->freectx = OSSL_FUNC_kem_freectx(fns);
235             ctxfncnt++;
236             break;
237         case OSSL_FUNC_KEM_DUPCTX:
238             if (kem->dupctx != NULL)
239                 break;
240             kem->dupctx = OSSL_FUNC_kem_dupctx(fns);
241             break;
242         case OSSL_FUNC_KEM_GET_CTX_PARAMS:
243             if (kem->get_ctx_params != NULL)
244                 break;
245             kem->get_ctx_params
246                 = OSSL_FUNC_kem_get_ctx_params(fns);
247             gparamfncnt++;
248             break;
249         case OSSL_FUNC_KEM_GETTABLE_CTX_PARAMS:
250             if (kem->gettable_ctx_params != NULL)
251                 break;
252             kem->gettable_ctx_params
253                 = OSSL_FUNC_kem_gettable_ctx_params(fns);
254             gparamfncnt++;
255             break;
256         case OSSL_FUNC_KEM_SET_CTX_PARAMS:
257             if (kem->set_ctx_params != NULL)
258                 break;
259             kem->set_ctx_params
260                 = OSSL_FUNC_kem_set_ctx_params(fns);
261             sparamfncnt++;
262             break;
263         case OSSL_FUNC_KEM_SETTABLE_CTX_PARAMS:
264             if (kem->settable_ctx_params != NULL)
265                 break;
266             kem->settable_ctx_params
267                 = OSSL_FUNC_kem_settable_ctx_params(fns);
268             sparamfncnt++;
269             break;
270         }
271     }
272     if (ctxfncnt != 2
273         || (encfncnt != 0 && encfncnt != 2)
274         || (decfncnt != 0 && decfncnt != 2)
275         || (encfncnt != 2 && decfncnt != 2)
276         || (gparamfncnt != 0 && gparamfncnt != 2)
277         || (sparamfncnt != 0 && sparamfncnt != 2)) {
278         /*
279          * In order to be a consistent set of functions we must have at least
280          * a set of context functions (newctx and freectx) as well as a pair of
281          * "kem" functions: (encapsulate_init, encapsulate) or
282          * (decapsulate_init, decapsulate). set_ctx_params and settable_ctx_params are
283          * optional, but if one of them is present then the other one must also
284          * be present. The same applies to get_ctx_params and
285          * gettable_ctx_params. The dupctx function is optional.
286          */
287         ERR_raise(ERR_LIB_EVP, EVP_R_INVALID_PROVIDER_FUNCTIONS);
288         goto err;
289     }
290
291     return kem;
292  err:
293     EVP_KEM_free(kem);
294     return NULL;
295 }
296
297 void EVP_KEM_free(EVP_KEM *kem)
298 {
299     if (kem != NULL) {
300         int i;
301
302         CRYPTO_DOWN_REF(&kem->refcnt, &i, kem->lock);
303         if (i > 0)
304             return;
305         ossl_provider_free(kem->prov);
306         CRYPTO_THREAD_lock_free(kem->lock);
307         OPENSSL_free(kem);
308     }
309 }
310
311 int EVP_KEM_up_ref(EVP_KEM *kem)
312 {
313     int ref = 0;
314
315     CRYPTO_UP_REF(&kem->refcnt, &ref, kem->lock);
316     return 1;
317 }
318
319 OSSL_PROVIDER *EVP_KEM_provider(const EVP_KEM *kem)
320 {
321     return kem->prov;
322 }
323
324 EVP_KEM *EVP_KEM_fetch(OPENSSL_CTX *ctx, const char *algorithm,
325                        const char *properties)
326 {
327     return evp_generic_fetch(ctx, OSSL_OP_KEM, algorithm, properties,
328                              evp_kem_from_dispatch,
329                              (int (*)(void *))EVP_KEM_up_ref,
330                              (void (*)(void *))EVP_KEM_free);
331 }
332
333 int EVP_KEM_is_a(const EVP_KEM *kem, const char *name)
334 {
335     return evp_is_a(kem->prov, kem->name_id, NULL, name);
336 }
337
338 int EVP_KEM_number(const EVP_KEM *kem)
339 {
340     return kem->name_id;
341 }
342
343 void EVP_KEM_do_all_provided(OPENSSL_CTX *libctx,
344                              void (*fn)(EVP_KEM *kem, void *arg),
345                              void *arg)
346 {
347     evp_generic_do_all(libctx, OSSL_OP_KEM, (void (*)(void *, void *))fn, arg,
348                        evp_kem_from_dispatch,
349                        (void (*)(void *))EVP_KEM_free);
350 }
351
352
353 void EVP_KEM_names_do_all(const EVP_KEM *kem,
354                           void (*fn)(const char *name, void *data),
355                           void *data)
356 {
357     if (kem->prov != NULL)
358         evp_names_do_all(kem->prov, kem->name_id, fn, data);
359 }