Fix PKCS12_newpass() to work with PBES2.
authorslontis <shane.lontis@oracle.com>
Wed, 25 Jan 2023 01:25:33 +0000 (11:25 +1000)
committerPauli <pauli@openssl.org>
Tue, 14 Mar 2023 21:49:03 +0000 (08:49 +1100)
Fixes #19092

The code looks like it was written to work with PBES1.
As it had no tests, this would of then broken when PBES2
was introduced at a later point.

Also added libctx and propq support.

This affects the shroudedkeybag object.

Reviewed-by: Tomas Mraz <tomas@openssl.org>
Reviewed-by: Paul Dale <pauli@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/20134)

crypto/pkcs12/p12_npas.c
test/pkcs12_api_test.c

index 62230bc6187ff57d3e2a6a7bd4bd0467150f6c32..90139100c6dbf96256ce1decabf7aeed24261c65 100644 (file)
 
 static int newpass_p12(PKCS12 *p12, const char *oldpass, const char *newpass);
 static int newpass_bags(STACK_OF(PKCS12_SAFEBAG) *bags, const char *oldpass,
-                        const char *newpass);
+                        const char *newpass,
+                        OSSL_LIB_CTX *libctx, const char *propq);
 static int newpass_bag(PKCS12_SAFEBAG *bag, const char *oldpass,
-                        const char *newpass);
+                        const char *newpass,
+                        OSSL_LIB_CTX *libctx, const char *propq);
 static int alg_get(const X509_ALGOR *alg, int *pnid, int *piter,
-                   int *psaltlen);
+                   int *psaltlen, int *cipherid);
 
 /*
  * Change the password on a PKCS#12 structure.
@@ -39,12 +41,12 @@ int PKCS12_newpass(PKCS12 *p12, const char *oldpass, const char *newpass)
     }
 
     /* Check the mac */
-
-    if (!PKCS12_verify_mac(p12, oldpass, -1)) {
-        ERR_raise(ERR_LIB_PKCS12, PKCS12_R_MAC_VERIFY_FAILURE);
-        return 0;
+    if (p12->mac != NULL) {
+        if (!PKCS12_verify_mac(p12, oldpass, -1)) {
+            ERR_raise(ERR_LIB_PKCS12, PKCS12_R_MAC_VERIFY_FAILURE);
+            return 0;
+        }
     }
-
     if (!newpass_p12(p12, oldpass, newpass)) {
         ERR_raise(ERR_LIB_PKCS12, PKCS12_R_PARSE_ERROR);
         return 0;
@@ -59,7 +61,7 @@ static int newpass_p12(PKCS12 *p12, const char *oldpass, const char *newpass)
 {
     STACK_OF(PKCS7) *asafes = NULL, *newsafes = NULL;
     STACK_OF(PKCS12_SAFEBAG) *bags = NULL;
-    int i, bagnid, pbe_nid = 0, pbe_iter = 0, pbe_saltlen = 0;
+    int i, bagnid, pbe_nid = 0, pbe_iter = 0, pbe_saltlen = 0, cipherid = NID_undef;
     PKCS7 *p7, *p7new;
     ASN1_OCTET_STRING *p12_data_tmp = NULL, *macoct = NULL;
     unsigned char mac[EVP_MAX_MD_SIZE];
@@ -72,27 +74,30 @@ static int newpass_p12(PKCS12 *p12, const char *oldpass, const char *newpass)
         goto err;
     for (i = 0; i < sk_PKCS7_num(asafes); i++) {
         p7 = sk_PKCS7_value(asafes, i);
+
         bagnid = OBJ_obj2nid(p7->type);
         if (bagnid == NID_pkcs7_data) {
             bags = PKCS12_unpack_p7data(p7);
         } else if (bagnid == NID_pkcs7_encrypted) {
             bags = PKCS12_unpack_p7encdata(p7, oldpass, -1);
             if (!alg_get(p7->d.encrypted->enc_data->algorithm,
-                         &pbe_nid, &pbe_iter, &pbe_saltlen))
+                         &pbe_nid, &pbe_iter, &pbe_saltlen, &cipherid))
                 goto err;
         } else {
             continue;
         }
         if (bags == NULL)
             goto err;
-        if (!newpass_bags(bags, oldpass, newpass))
+        if (!newpass_bags(bags, oldpass, newpass,
+                          p7->ctx.libctx, p7->ctx.propq))
             goto err;
         /* Repack bag in same form with new password */
         if (bagnid == NID_pkcs7_data)
             p7new = PKCS12_pack_p7data(bags);
         else
-            p7new = PKCS12_pack_p7encdata(pbe_nid, newpass, -1, NULL,
-                                          pbe_saltlen, pbe_iter, bags);
+            p7new = PKCS12_pack_p7encdata_ex(pbe_nid, newpass, -1, NULL,
+                                             pbe_saltlen, pbe_iter, bags,
+                                             p7->ctx.libctx, p7->ctx.propq);
         if (p7new == NULL || !sk_PKCS7_push(newsafes, p7new))
             goto err;
         sk_PKCS12_SAFEBAG_pop_free(bags, PKCS12_SAFEBAG_free);
@@ -107,11 +112,13 @@ static int newpass_p12(PKCS12 *p12, const char *oldpass, const char *newpass)
     if (!PKCS12_pack_authsafes(p12, newsafes))
         goto err;
 
-    if (!PKCS12_gen_mac(p12, newpass, -1, mac, &maclen))
-        goto err;
-    X509_SIG_getm(p12->mac->dinfo, NULL, &macoct);
-    if (!ASN1_OCTET_STRING_set(macoct, mac, maclen))
-        goto err;
+    if (p12->mac != NULL) {
+        if (!PKCS12_gen_mac(p12, newpass, -1, mac, &maclen))
+            goto err;
+        X509_SIG_getm(p12->mac->dinfo, NULL, &macoct);
+        if (!ASN1_OCTET_STRING_set(macoct, mac, maclen))
+            goto err;
+    }
 
     rv = 1;
 
@@ -130,11 +137,13 @@ err:
 }
 
 static int newpass_bags(STACK_OF(PKCS12_SAFEBAG) *bags, const char *oldpass,
-                        const char *newpass)
+                        const char *newpass,
+                        OSSL_LIB_CTX *libctx, const char *propq)
 {
     int i;
     for (i = 0; i < sk_PKCS12_SAFEBAG_num(bags); i++) {
-        if (!newpass_bag(sk_PKCS12_SAFEBAG_value(bags, i), oldpass, newpass))
+        if (!newpass_bag(sk_PKCS12_SAFEBAG_value(bags, i), oldpass, newpass,
+                         libctx, propq))
             return 0;
     }
     return 1;
@@ -143,26 +152,37 @@ static int newpass_bags(STACK_OF(PKCS12_SAFEBAG) *bags, const char *oldpass,
 /* Change password of safebag: only needs handle shrouded keybags */
 
 static int newpass_bag(PKCS12_SAFEBAG *bag, const char *oldpass,
-                       const char *newpass)
+                       const char *newpass,
+                       OSSL_LIB_CTX *libctx, const char *propq)
 {
+    EVP_CIPHER *cipher = NULL;
     PKCS8_PRIV_KEY_INFO *p8;
     X509_SIG *p8new;
-    int p8_nid, p8_saltlen, p8_iter;
+    int p8_nid, p8_saltlen, p8_iter, cipherid = 0;
     const X509_ALGOR *shalg;
 
     if (PKCS12_SAFEBAG_get_nid(bag) != NID_pkcs8ShroudedKeyBag)
         return 1;
 
-    if ((p8 = PKCS8_decrypt(bag->value.shkeybag, oldpass, -1)) == NULL)
+    if ((p8 = PKCS8_decrypt_ex(bag->value.shkeybag, oldpass, -1,
+                               libctx, propq)) == NULL)
         return 0;
     X509_SIG_get0(bag->value.shkeybag, &shalg, NULL);
-    if (!alg_get(shalg, &p8_nid, &p8_iter, &p8_saltlen)) {
+    if (!alg_get(shalg, &p8_nid, &p8_iter, &p8_saltlen, &cipherid)) {
         PKCS8_PRIV_KEY_INFO_free(p8);
         return 0;
     }
-    p8new = PKCS8_encrypt(p8_nid, NULL, newpass, -1, NULL, p8_saltlen,
-                          p8_iter, p8);
+    if (cipherid != NID_undef) {
+        cipher = EVP_CIPHER_fetch(libctx, OBJ_nid2sn(cipherid), propq);
+        if (cipher == NULL) {
+            PKCS8_PRIV_KEY_INFO_free(p8);
+            return 0;
+        }
+    }
+    p8new = PKCS8_encrypt_ex(p8_nid, cipher, newpass, -1, NULL, p8_saltlen,
+                             p8_iter, p8, libctx, propq);
     PKCS8_PRIV_KEY_INFO_free(p8);
+    EVP_CIPHER_free(cipher);
     if (p8new == NULL)
         return 0;
     X509_SIG_free(bag->value.shkeybag);
@@ -171,16 +191,69 @@ static int newpass_bag(PKCS12_SAFEBAG *bag, const char *oldpass,
 }
 
 static int alg_get(const X509_ALGOR *alg, int *pnid, int *piter,
-                   int *psaltlen)
+                   int *psaltlen, int *cipherid)
 {
-    PBEPARAM *pbe;
+    int ret = 0, pbenid, aparamtype;
+    int encnid, prfnid;
+    const ASN1_OBJECT *aoid;
+    const void *aparam;
+    PBEPARAM *pbe = NULL;
+    PBE2PARAM *pbe2 = NULL;
+    PBKDF2PARAM *kdf = NULL;
 
-    pbe = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(PBEPARAM), alg->parameter);
-    if (pbe == NULL)
-        return 0;
-    *pnid = OBJ_obj2nid(alg->algorithm);
-    *piter = ASN1_INTEGER_get(pbe->iter);
-    *psaltlen = pbe->salt->length;
-    PBEPARAM_free(pbe);
-    return 1;
+    X509_ALGOR_get0(&aoid, &aparamtype, &aparam, alg);
+    pbenid = OBJ_obj2nid(aoid);
+
+    switch (pbenid) {
+    case NID_pbes2:
+        if (aparamtype == V_ASN1_SEQUENCE)
+            pbe2 = ASN1_item_unpack(aparam, ASN1_ITEM_rptr(PBE2PARAM));
+        if (pbe2 == NULL)
+            goto done;
+
+        X509_ALGOR_get0(&aoid, &aparamtype, &aparam, pbe2->keyfunc);
+        pbenid = OBJ_obj2nid(aoid);
+        X509_ALGOR_get0(&aoid, NULL, NULL, pbe2->encryption);
+        encnid = OBJ_obj2nid(aoid);
+
+        if (aparamtype == V_ASN1_SEQUENCE)
+            kdf = ASN1_item_unpack(aparam, ASN1_ITEM_rptr(PBKDF2PARAM));
+        if (kdf == NULL)
+            goto done;
+
+        /* Only OCTET_STRING is supported */
+        if (kdf->salt->type != V_ASN1_OCTET_STRING)
+            goto done;
+
+        if (kdf->prf == NULL) {
+            prfnid = NID_hmacWithSHA1;
+        } else {
+            X509_ALGOR_get0(&aoid, NULL, NULL, kdf->prf);
+            prfnid = OBJ_obj2nid(aoid);
+        }
+        *psaltlen = kdf->salt->value.octet_string->length;
+        *piter = ASN1_INTEGER_get(kdf->iter);
+        *pnid = prfnid;
+        *cipherid = encnid;
+        break;
+    default:
+        pbe = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(PBEPARAM), alg->parameter);
+        if (pbe == NULL)
+            goto done;
+        *pnid = OBJ_obj2nid(alg->algorithm);
+        *piter = ASN1_INTEGER_get(pbe->iter);
+        *psaltlen = pbe->salt->length;
+        *cipherid = NID_undef;
+        ret = 1;
+        break;
+    }
+    ret = 1;
+done:
+    if (kdf != NULL)
+        PBKDF2PARAM_free(kdf);
+    if (pbe2 != NULL)
+        PBE2PARAM_free(pbe2);
+    if (pbe != NULL)
+        PBEPARAM_free(pbe);
+    return ret;
 }
index 71867844637cd79d9ccee7d3a5599965f04e0743..da023f364d08b2410c10900d23d6cf78df4594f4 100644 (file)
@@ -60,6 +60,46 @@ static const char *in_pass = "";
 static int has_key = 0;
 static int has_cert = 0;
 static int has_ca = 0;
+
+static int changepass(PKCS12 *p12, EVP_PKEY *key, X509 *cert, STACK_OF(X509) *ca)
+{
+    int ret = 0;
+    PKCS12 *p12new = NULL;
+    EVP_PKEY *key2 = NULL;
+    X509 *cert2 = NULL;
+    STACK_OF(X509) *ca2 = NULL;
+    BIO *bio = NULL;
+
+    if (!TEST_true(PKCS12_newpass(p12, in_pass, "NEWPASS")))
+        goto err;
+    if (!TEST_ptr(bio = BIO_new(BIO_s_mem())))
+        goto err;
+    if (!TEST_true(i2d_PKCS12_bio(bio, p12)))
+        goto err;
+    if (!TEST_ptr(p12new = PKCS12_init_ex(NID_pkcs7_data, testctx, "provider=default")))
+        goto err;
+    if (!TEST_ptr(d2i_PKCS12_bio(bio, &p12new)))
+        goto err;
+    if (!TEST_true(PKCS12_parse(p12new, "NEWPASS", &key2, &cert2, &ca2)))
+        goto err;
+    if (has_key) {
+        if (!TEST_ptr(key2) || !TEST_int_eq(EVP_PKEY_eq(key, key2), 1))
+            goto err;
+    }
+    if (has_cert) {
+        if (!TEST_ptr(cert2) || !TEST_int_eq(X509_cmp(cert, cert2), 0))
+            goto err;
+    }
+    ret = 1;
+err:
+    BIO_free(bio);
+    PKCS12_free(p12new);
+    EVP_PKEY_free(key2);
+    X509_free(cert2);
+    OSSL_STACK_OF_X509_free(ca2);
+    return ret;
+}
+
 static int pkcs12_parse_test(void)
 {
     int ret = 0;
@@ -82,8 +122,9 @@ static int pkcs12_parse_test(void)
             goto err;
         if ((has_ca && !TEST_ptr(ca)) || (!has_ca && !TEST_ptr_null(ca)))
             goto err;
+        if (has_key && !changepass(p12, key, cert, ca))
+            goto err;
     }
-
     ret = 1;
 err:
     PKCS12_free(p12);