Implement OSSL_PROVIDER_get0_provider_ctx()
[openssl.git] / crypto / provider_core.c
index 662576cd7b1dfff13e44832764438742a881ad5e..f7af51a297421eff537f5fe65a495499a7000df0 100644 (file)
@@ -42,7 +42,8 @@ struct provider_store_st;        /* Forward declaration */
 struct ossl_provider_st {
     /* Flag bits */
     unsigned int flag_initialized:1;
-    unsigned int flag_fallback:1;
+    unsigned int flag_fallback:1; /* Can be used as fallback */
+    unsigned int flag_activated_as_fallback:1;
 
     /* OpenSSL library side data */
     CRYPTO_REF_COUNT refcnt;
@@ -71,6 +72,13 @@ struct ossl_provider_st {
     OSSL_provider_get_params_fn *get_params;
     OSSL_provider_query_operation_fn *query_operation;
 
+    /*
+     * Cache of bit to indicate of query_operation() has been called on
+     * a specific operation or not.
+     */
+    unsigned char *operation_bits;
+    size_t operation_bits_sz;
+
     /* Provider side data */
     void *provctx;
 };
@@ -97,6 +105,24 @@ struct provider_store_st {
     unsigned int use_fallbacks:1;
 };
 
+/*
+ * provider_deactivate_free() is a wrapper around ossl_provider_free()
+ * that also makes sure that activated fallback providers are deactivated.
+ * This is simply done by freeing them an extra time, to compensate for the
+ * refcount that provider_activate_fallbacks() gives them.
+ * Since this is only called when the provider store is being emptied, we
+ * don't need to care about any lock.
+ */
+static void provider_deactivate_free(OSSL_PROVIDER *prov)
+{
+    int extra_free = (prov->flag_initialized
+                      && prov->flag_activated_as_fallback);
+
+    if (extra_free)
+        ossl_provider_free(prov);
+    ossl_provider_free(prov);
+}
+
 static void provider_store_free(void *vstore)
 {
     struct provider_store_st *store = vstore;
@@ -104,7 +130,7 @@ static void provider_store_free(void *vstore)
     if (store == NULL)
         return;
     OPENSSL_free(store->default_path);
-    sk_OSSL_PROVIDER_pop_free(store->providers, ossl_provider_free);
+    sk_OSSL_PROVIDER_pop_free(store->providers, provider_deactivate_free);
     CRYPTO_THREAD_lock_free(store->lock);
     OPENSSL_free(store);
 }
@@ -317,6 +343,9 @@ void ossl_provider_free(OSSL_PROVIDER *prov)
             }
 # endif
 #endif
+            OPENSSL_free(prov->operation_bits);
+            prov->operation_bits = NULL;
+            prov->operation_bits_sz = 0;
             prov->flag_initialized = 0;
         }
 
@@ -644,13 +673,22 @@ static void provider_activate_fallbacks(struct provider_store_st *store)
             OSSL_PROVIDER *prov = sk_OSSL_PROVIDER_value(store->providers, i);
 
             /*
-             * Note that we don't care if the activation succeeds or not.
-             * If it doesn't succeed, then any attempt to use any of the
-             * fallback providers will fail anyway.
+             * Activated fallback providers get an extra refcount, to
+             * simulate a regular load.
+             * Note that we don't care if the activation succeeds or not,
+             * other than to maintain a correct refcount.  If the activation
+             * doesn't succeed, then any future attempt to use the fallback
+             * provider will fail anyway.
              */
             if (prov->flag_fallback) {
-                activated_fallback_count++;
-                provider_activate(prov);
+                if (ossl_provider_up_ref(prov)) {
+                    if (!provider_activate(prov)) {
+                        ossl_provider_free(prov);
+                    } else {
+                        prov->flag_activated_as_fallback = 1;
+                        activated_fallback_count++;
+                    }
+                }
             }
         }
 
@@ -749,6 +787,14 @@ const char *ossl_provider_module_path(const OSSL_PROVIDER *prov)
 #endif
 }
 
+void *ossl_provider_prov_ctx(const OSSL_PROVIDER *prov)
+{
+    if (prov != NULL)
+        return prov->provctx;
+
+    return NULL;
+}
+
 OPENSSL_CTX *ossl_provider_library_context(const OSSL_PROVIDER *prov)
 {
     /* TODO(3.0) just: return prov->libctx; */
@@ -782,6 +828,42 @@ const OSSL_ALGORITHM *ossl_provider_query_operation(const OSSL_PROVIDER *prov,
     return prov->query_operation(prov->provctx, operation_id, no_cache);
 }
 
+int ossl_provider_set_operation_bit(OSSL_PROVIDER *provider, size_t bitnum)
+{
+    size_t byte = bitnum / 8;
+    unsigned char bit = (1 << (bitnum % 8)) & 0xFF;
+
+    if (provider->operation_bits_sz <= byte) {
+        provider->operation_bits = OPENSSL_realloc(provider->operation_bits,
+                                                   byte + 1);
+        if (provider->operation_bits == NULL) {
+            ERR_raise(ERR_LIB_CRYPTO, ERR_R_MALLOC_FAILURE);
+            return 0;
+        }
+        memset(provider->operation_bits + provider->operation_bits_sz,
+               '\0', byte + 1 - provider->operation_bits_sz);
+    }
+    provider->operation_bits[byte] |= bit;
+    return 1;
+}
+
+int ossl_provider_test_operation_bit(OSSL_PROVIDER *provider, size_t bitnum,
+                                     int *result)
+{
+    size_t byte = bitnum / 8;
+    unsigned char bit = (1 << (bitnum % 8)) & 0xFF;
+
+    if (!ossl_assert(result != NULL)) {
+        ERR_raise(ERR_LIB_CRYPTO, ERR_R_PASSED_NULL_PARAMETER);
+        return 0;
+    }
+
+    *result = 0;
+    if (provider->operation_bits_sz > byte)
+        *result = ((provider->operation_bits[byte] & bit) != 0);
+    return 1;
+}
+
 /*-
  * Core functions for the provider
  * ===============================
@@ -795,8 +877,13 @@ const OSSL_ALGORITHM *ossl_provider_query_operation(const OSSL_PROVIDER *prov,
  * never knows.
  */
 static const OSSL_PARAM param_types[] = {
-    OSSL_PARAM_DEFN("openssl-version", OSSL_PARAM_UTF8_PTR, NULL, 0),
-    OSSL_PARAM_DEFN("provider-name", OSSL_PARAM_UTF8_PTR, NULL, 0),
+    OSSL_PARAM_DEFN(OSSL_PROV_PARAM_CORE_VERSION, OSSL_PARAM_UTF8_PTR, NULL, 0),
+    OSSL_PARAM_DEFN(OSSL_PROV_PARAM_CORE_PROV_NAME, OSSL_PARAM_UTF8_PTR,
+                    NULL, 0),
+#ifndef FIPS_MODULE
+    OSSL_PARAM_DEFN(OSSL_PROV_PARAM_CORE_MODULE_FILENAME, OSSL_PARAM_UTF8_PTR,
+                    NULL, 0),
+#endif
     OSSL_PARAM_END
 };
 
@@ -833,13 +920,14 @@ static int core_get_params(const OSSL_CORE_HANDLE *handle, OSSL_PARAM params[])
      */
     OSSL_PROVIDER *prov = (OSSL_PROVIDER *)handle;
 
-    if ((p = OSSL_PARAM_locate(params, "openssl-version")) != NULL)
+    if ((p = OSSL_PARAM_locate(params, OSSL_PROV_PARAM_CORE_VERSION)) != NULL)
         OSSL_PARAM_set_utf8_ptr(p, OPENSSL_VERSION_STR);
-    if ((p = OSSL_PARAM_locate(params, "provider-name")) != NULL)
+    if ((p = OSSL_PARAM_locate(params, OSSL_PROV_PARAM_CORE_PROV_NAME)) != NULL)
         OSSL_PARAM_set_utf8_ptr(p, prov->name);
 
 #ifndef FIPS_MODULE
-    if ((p = OSSL_PARAM_locate(params, OSSL_PROV_PARAM_MODULE_FILENAME)) != NULL)
+    if ((p = OSSL_PARAM_locate(params,
+                               OSSL_PROV_PARAM_CORE_MODULE_FILENAME)) != NULL)
         OSSL_PARAM_set_utf8_ptr(p, ossl_provider_module_path(prov));
 #endif