DECODER: Allow precise result type for OSSL_DECODER_CTX_new_by_EVP_PKEY()
[openssl.git] / crypto / encode_decode / decoder_pkey.c
index 64ea4e2c3f204f81073437e62bb0dd2f16bb4c98..75c491f4acb65434e6ba9a2db0015443b44a72a3 100644 (file)
@@ -62,9 +62,8 @@ static int decoder_construct_EVP_PKEY(OSSL_DECODER_INSTANCE *decoder_inst,
                                       void *construct_data)
 {
     struct decoder_EVP_PKEY_data_st *data = construct_data;
-    OSSL_DECODER *decoder =
-        OSSL_DECODER_INSTANCE_decoder(decoder_inst);
-    void *decoderctx = OSSL_DECODER_INSTANCE_decoder_ctx(decoder_inst);
+    OSSL_DECODER *decoder = OSSL_DECODER_INSTANCE_get_decoder(decoder_inst);
+    void *decoderctx = OSSL_DECODER_INSTANCE_get_decoder_ctx(decoder_inst);
     size_t i, end_i;
     /*
      * |object_ref| points to a provider reference to an object, its exact
@@ -187,10 +186,9 @@ static void decoder_clean_EVP_PKEY_construct_arg(void *construct_data)
     }
 }
 
-DEFINE_STACK_OF_CSTRING()
-
 struct collected_data_st {
     struct decoder_EVP_PKEY_data_st *process_data;
+    const char *keytype;
     STACK_OF(OPENSSL_CSTRING) *names;
     OSSL_DECODER_CTX *ctx;
 
@@ -201,6 +199,8 @@ static void collect_keymgmt(EVP_KEYMGMT *keymgmt, void *arg)
 {
     struct collected_data_st *data = arg;
 
+    if (data->keytype != NULL && !EVP_KEYMGMT_is_a(keymgmt, data->keytype))
+        return;
     if (data->error_occured)
         return;
 
@@ -209,7 +209,7 @@ static void collect_keymgmt(EVP_KEYMGMT *keymgmt, void *arg)
     if (!EVP_KEYMGMT_up_ref(keymgmt) /* ref++ */)
         return;
     if (sk_EVP_KEYMGMT_push(data->process_data->keymgmts, keymgmt) <= 0) {
-        EVP_KEYMGMT_free(keymgmt); /* ref-- */
+        EVP_KEYMGMT_free(keymgmt);   /* ref-- */
         return;
     }
 
@@ -256,7 +256,7 @@ static void collect_decoder(OSSL_DECODER *decoder, void *arg)
 }
 
 int ossl_decoder_ctx_setup_for_EVP_PKEY(OSSL_DECODER_CTX *ctx,
-                                        EVP_PKEY **pkey,
+                                        EVP_PKEY **pkey, const char *keytype,
                                         OPENSSL_CTX *libctx,
                                         const char *propquery)
 {
@@ -267,14 +267,14 @@ int ossl_decoder_ctx_setup_for_EVP_PKEY(OSSL_DECODER_CTX *ctx,
     if ((data = OPENSSL_zalloc(sizeof(*data))) == NULL
         || (data->process_data =
             OPENSSL_zalloc(sizeof(*data->process_data))) == NULL
-        || (data->process_data->keymgmts
-            = sk_EVP_KEYMGMT_new_null()) == NULL
+        || (data->process_data->keymgmts = sk_EVP_KEYMGMT_new_null()) == NULL
         || (data->names = sk_OPENSSL_CSTRING_new_null()) == NULL) {
         ERR_raise(ERR_LIB_OSSL_DECODER, ERR_R_MALLOC_FAILURE);
         goto err;
     }
     data->process_data->object = (void **)pkey;
     data->ctx = ctx;
+    data->keytype = keytype;
 
     /* First, find all keymgmts to form goals */
     EVP_KEYMGMT_do_all_provided(libctx, collect_keymgmt, data);
@@ -303,17 +303,16 @@ int ossl_decoder_ctx_setup_for_EVP_PKEY(OSSL_DECODER_CTX *ctx,
     if (data->error_occured)
         goto err;
 
-    /* If we found no decoders to match the keymgmts, we err */
-    if (OSSL_DECODER_CTX_num_decoders(ctx) == 0)
-        goto err;
+    if (OSSL_DECODER_CTX_get_num_decoders(ctx) != 0) {
+        if (!OSSL_DECODER_CTX_set_construct(ctx, decoder_construct_EVP_PKEY)
+            || !OSSL_DECODER_CTX_set_construct_data(ctx, data->process_data)
+            || !OSSL_DECODER_CTX_set_cleanup(ctx,
+                                             decoder_clean_EVP_PKEY_construct_arg))
+            goto err;
 
-    if (!OSSL_DECODER_CTX_set_construct(ctx, decoder_construct_EVP_PKEY)
-        || !OSSL_DECODER_CTX_set_construct_data(ctx, data->process_data)
-        || !OSSL_DECODER_CTX_set_cleanup(ctx,
-                                         decoder_clean_EVP_PKEY_construct_arg))
-        goto err;
+        data->process_data = NULL; /* Avoid it being freed */
+    }
 
-    data->process_data = NULL;
     ok = 1;
  err:
     if (data != NULL) {
@@ -324,10 +323,10 @@ int ossl_decoder_ctx_setup_for_EVP_PKEY(OSSL_DECODER_CTX *ctx,
     return ok;
 }
 
-OSSL_DECODER_CTX *OSSL_DECODER_CTX_new_by_EVP_PKEY(EVP_PKEY **pkey,
-                                                   const char *input_type,
-                                                   OPENSSL_CTX *libctx,
-                                                   const char *propquery)
+OSSL_DECODER_CTX *
+OSSL_DECODER_CTX_new_by_EVP_PKEY(EVP_PKEY **pkey,
+                                 const char *input_type, const char *keytype,
+                                 OPENSSL_CTX *libctx, const char *propquery)
 {
     OSSL_DECODER_CTX *ctx = NULL;
 
@@ -336,7 +335,8 @@ OSSL_DECODER_CTX *OSSL_DECODER_CTX_new_by_EVP_PKEY(EVP_PKEY **pkey,
         return NULL;
     }
     if (OSSL_DECODER_CTX_set_input_type(ctx, input_type)
-        && ossl_decoder_ctx_setup_for_EVP_PKEY(ctx, pkey, libctx, propquery)
+        && ossl_decoder_ctx_setup_for_EVP_PKEY(ctx, pkey, keytype,
+                                               libctx, propquery)
         && OSSL_DECODER_CTX_add_extra(ctx, libctx, propquery))
         return ctx;