Updated `rsa_has()` for correct validation
[openssl.git] / providers / implementations / keymgmt / rsa_kmgmt.c
index a075c54487ce2bd70ea5e6da8edb09862dfba6f2..7e67316deb300045a24299f7ccf9a9c685ea37a8 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2019-2021 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2019-2022 The OpenSSL Project Authors. All Rights Reserved.
  *
  * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
@@ -122,13 +122,11 @@ static int rsa_has(const void *keydata, int selection)
     if ((selection & RSA_POSSIBLE_SELECTIONS) == 0)
         return 1; /* the selection is not missing */
 
-    if ((selection & OSSL_KEYMGMT_SELECT_OTHER_PARAMETERS) != 0)
-        /* This will change with OAEP */
-        ok = ok && (RSA_test_flags(rsa, RSA_FLAG_TYPE_RSASSAPSS) != 0);
+    /* OSSL_KEYMGMT_SELECT_OTHER_PARAMETERS are always available even if empty */
     if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0)
-        ok = ok && (RSA_get0_e(rsa) != NULL);
-    if ((selection & OSSL_KEYMGMT_SELECT_PUBLIC_KEY) != 0)
         ok = ok && (RSA_get0_n(rsa) != NULL);
+    if ((selection & OSSL_KEYMGMT_SELECT_PUBLIC_KEY) != 0)
+        ok = ok && (RSA_get0_e(rsa) != NULL);
     if ((selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) != 0)
         ok = ok && (RSA_get0_d(rsa) != NULL);
     return ok;
@@ -145,10 +143,30 @@ static int rsa_match(const void *keydata1, const void *keydata2, int selection)
 
     /* There is always an |e| */
     ok = ok && BN_cmp(RSA_get0_e(rsa1), RSA_get0_e(rsa2)) == 0;
-    if ((selection & OSSL_KEYMGMT_SELECT_PUBLIC_KEY) != 0)
-        ok = ok && BN_cmp(RSA_get0_n(rsa1), RSA_get0_n(rsa2)) == 0;
-    if ((selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) != 0)
-        ok = ok && BN_cmp(RSA_get0_d(rsa1), RSA_get0_d(rsa2)) == 0;
+    if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0) {
+        int key_checked = 0;
+
+        if ((selection & OSSL_KEYMGMT_SELECT_PUBLIC_KEY) != 0) {
+            const BIGNUM *pa = RSA_get0_n(rsa1);
+            const BIGNUM *pb = RSA_get0_n(rsa2);
+
+            if (pa != NULL && pb != NULL) {
+                ok = ok && BN_cmp(pa, pb) == 0;
+                key_checked = 1;
+            }
+        }
+        if (!key_checked
+            && (selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) != 0) {
+            const BIGNUM *pa = RSA_get0_d(rsa1);
+            const BIGNUM *pb = RSA_get0_d(rsa2);
+
+            if (pa != NULL && pb != NULL) {
+                ok = ok && BN_cmp(pa, pb) == 0;
+                key_checked = 1;
+            }
+        }
+        ok = ok && key_checked;
+    }
     return ok;
 }
 
@@ -172,8 +190,12 @@ static int rsa_import(void *keydata, int selection, const OSSL_PARAM params[])
                                        &pss_defaults_set,
                                        params, rsa_type,
                                        ossl_rsa_get0_libctx(rsa));
-    if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0)
-        ok = ok && ossl_rsa_fromdata(rsa, params);
+    if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0) {
+        int include_private =
+            selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY ? 1 : 0;
+
+        ok = ok && ossl_rsa_fromdata(rsa, params, include_private);
+    }
 
     return ok;
 }
@@ -200,12 +222,17 @@ static int rsa_export(void *keydata, int selection,
     if ((selection & OSSL_KEYMGMT_SELECT_OTHER_PARAMETERS) != 0)
         ok = ok && (ossl_rsa_pss_params_30_is_unrestricted(pss_params)
                     || ossl_rsa_pss_params_30_todata(pss_params, tmpl, NULL));
-    if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0)
-        ok = ok && ossl_rsa_todata(rsa, tmpl, NULL);
+    if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0) {
+        int include_private =
+            selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY ? 1 : 0;
 
-    if (!ok
-        || (params = OSSL_PARAM_BLD_to_param(tmpl)) == NULL)
+        ok = ok && ossl_rsa_todata(rsa, tmpl, NULL, include_private);
+    }
+
+    if (!ok || (params = OSSL_PARAM_BLD_to_param(tmpl)) == NULL) {
+        ok = 0;
         goto err;
+    }
 
     ok = param_callback(params, cbarg);
     OSSL_PARAM_free(params);
@@ -345,7 +372,7 @@ static int rsa_get_params(void *key, OSSL_PARAM params[])
     }
     return (rsa_type != RSA_FLAG_TYPE_RSASSAPSS
             || ossl_rsa_pss_params_30_todata(pss_params, NULL, params))
-        && ossl_rsa_todata(rsa, NULL, params);
+        && ossl_rsa_todata(rsa, NULL, params, 1);
 }
 
 static const OSSL_PARAM rsa_params[] = {
@@ -436,19 +463,24 @@ static void *gen_init(void *provctx, int selection, int rsa_type,
         gctx->libctx = libctx;
         if ((gctx->pub_exp = BN_new()) == NULL
             || !BN_set_word(gctx->pub_exp, RSA_F4)) {
-            BN_free(gctx->pub_exp);
-            OPENSSL_free(gctx);
-            return NULL;
+            goto err;
         }
         gctx->nbits = 2048;
         gctx->primes = RSA_DEFAULT_PRIME_NUM;
         gctx->rsa_type = rsa_type;
+    } else {
+        goto err;
     }
-    if (!rsa_gen_set_params(gctx, params)) {
-        OPENSSL_free(gctx);
-        return NULL;
-    }
+
+    if (!rsa_gen_set_params(gctx, params))
+        goto err;
     return gctx;
+
+err:
+    if (gctx != NULL)
+        BN_free(gctx->pub_exp);
+    OPENSSL_free(gctx);
+    return NULL;
 }
 
 static void *rsa_gen_init(void *provctx, int selection,