add OSSL_STACK_OF_X509_free() for commonly used pattern
[openssl.git] / crypto / ts / ts_rsp_verify.c
index 6798fc8263bc32b7783a3ac62bb1e6fb3b419c90..410f6882559db69e3448e9988f5e289a7debd9e3 100644 (file)
@@ -8,17 +8,18 @@
  */
 
 #include <stdio.h>
-#include "internal/cryptlib.h"
 #include <openssl/objects.h>
 #include <openssl/ts.h>
 #include <openssl/pkcs7.h>
-#include "ts_local.h"
+#include "internal/cryptlib.h"
+#include "internal/sizes.h"
 #include "crypto/ess.h"
+#include "ts_local.h"
 
 static int ts_verify_cert(X509_STORE *store, STACK_OF(X509) *untrusted,
                           X509 *signer, STACK_OF(X509) **chain);
-static int ts_check_signing_certs(PKCS7_SIGNER_INFO *si,
-                                  STACK_OF(X509) *chain);
+static int ts_check_signing_certs(const PKCS7_SIGNER_INFO *si,
+                                  const STACK_OF(X509) *chain);
 
 static int int_ts_RESP_verify_token(TS_VERIFY_CTX *ctx,
                                     PKCS7 *token, TS_TST_INFO *tst_info);
@@ -157,7 +158,7 @@ int TS_RESP_verify_signature(PKCS7 *token, STACK_OF(X509) *certs,
  err:
     BIO_free_all(p7bio);
     sk_X509_free(untrusted);
-    sk_X509_pop_free(chain, X509_free);
+    OSSL_STACK_OF_X509_free(chain);
     sk_X509_free(signers);
 
     return ret;
@@ -202,37 +203,38 @@ end:
     return ret;
 }
 
-static int ts_check_signing_certs(PKCS7_SIGNER_INFO *si,
-                                  STACK_OF(X509) *chain)
+static ESS_SIGNING_CERT *ossl_ess_get_signing_cert(const PKCS7_SIGNER_INFO *si)
 {
-    ESS_SIGNING_CERT *ss = ossl_ess_signing_cert_get(si);
-    ESS_SIGNING_CERT_V2 *ssv2 = ossl_ess_signing_cert_v2_get(si);
-    int i, j;
-    int ret = 0;
+    ASN1_TYPE *attr;
+    const unsigned char *p;
+
+    attr = PKCS7_get_signed_attribute(si, NID_id_smime_aa_signingCertificate);
+    if (attr == NULL)
+        return NULL;
+    p = attr->value.sequence->data;
+    return d2i_ESS_SIGNING_CERT(NULL, &p, attr->value.sequence->length);
+}
 
-    /*
-     * Check if first ESSCertIDs matches signer cert
-     * and each further ESSCertIDs matches any cert in the chain.
-     */
-    if (ss != NULL)
-        for (i = 0; i < sk_ESS_CERT_ID_num(ss->cert_ids); i++) {
-            j = ossl_ess_find_cid(chain, sk_ESS_CERT_ID_value(ss->cert_ids, i),
-                                  NULL);
-            if (j < 0 || (i == 0 && j != 0))
-                goto err;
-        }
-    if (ssv2 != NULL)
-        for (i = 0; i < sk_ESS_CERT_ID_V2_num(ssv2->cert_ids); i++) {
-            j = ossl_ess_find_cid(chain, NULL,
-                                  sk_ESS_CERT_ID_V2_value(ssv2->cert_ids, i));
-            if (j < 0 || (i == 0 && j != 0))
-                goto err;
-        }
-    ret = 1;
+static
+ESS_SIGNING_CERT_V2 *ossl_ess_get_signing_cert_v2(const PKCS7_SIGNER_INFO *si)
+{
+    ASN1_TYPE *attr;
+    const unsigned char *p;
+
+    attr = PKCS7_get_signed_attribute(si, NID_id_smime_aa_signingCertificateV2);
+    if (attr == NULL)
+        return NULL;
+    p = attr->value.sequence->data;
+    return d2i_ESS_SIGNING_CERT_V2(NULL, &p, attr->value.sequence->length);
+}
+
+static int ts_check_signing_certs(const PKCS7_SIGNER_INFO *si,
+                                  const STACK_OF(X509) *chain)
+{
+    ESS_SIGNING_CERT *ss = ossl_ess_get_signing_cert(si);
+    ESS_SIGNING_CERT_V2 *ssv2 = ossl_ess_get_signing_cert_v2(si);
+    int ret = OSSL_ESS_check_signing_certs(ss, ssv2, chain, 1) > 0;
 
- err:
-    if (!ret)
-        ERR_raise(ERR_LIB_TS, TS_R_ESS_SIGNING_CERTIFICATE_ERROR);
     ESS_SIGNING_CERT_free(ss);
     ESS_SIGNING_CERT_V2_free(ssv2);
     return ret;
@@ -397,7 +399,7 @@ static int ts_check_status_info(TS_RESP *response)
 
 static char *ts_get_status_text(STACK_OF(ASN1_UTF8STRING) *text)
 {
-    return sk_ASN1_UTF8STRING2text(text, "/", TS_MAX_STATUS_LENGTH);
+    return ossl_sk_ASN1_UTF8STRING2text(text, "/", TS_MAX_STATUS_LENGTH);
 }
 
 static int ts_check_policy(const ASN1_OBJECT *req_oid,
@@ -419,9 +421,10 @@ static int ts_compute_imprint(BIO *data, TS_TST_INFO *tst_info,
 {
     TS_MSG_IMPRINT *msg_imprint = tst_info->msg_imprint;
     X509_ALGOR *md_alg_resp = msg_imprint->hash_algo;
-    const EVP_MD *md;
+    EVP_MD *md = NULL;
     EVP_MD_CTX *md_ctx = NULL;
     unsigned char buffer[4096];
+    char name[OSSL_MAX_NAME_SIZE];
     int length;
 
     *md_alg = NULL;
@@ -429,11 +432,22 @@ static int ts_compute_imprint(BIO *data, TS_TST_INFO *tst_info,
 
     if ((*md_alg = X509_ALGOR_dup(md_alg_resp)) == NULL)
         goto err;
-    if ((md = EVP_get_digestbyobj((*md_alg)->algorithm)) == NULL) {
-        ERR_raise(ERR_LIB_TS, TS_R_UNSUPPORTED_MD_ALGORITHM);
+
+    OBJ_obj2txt(name, sizeof(name), md_alg_resp->algorithm, 0);
+
+    (void)ERR_set_mark();
+    md = EVP_MD_fetch(NULL, name, NULL);
+
+    if (md == NULL)
+        md = (EVP_MD *)EVP_get_digestbyname(name);
+
+    if (md == NULL) {
+        (void)ERR_clear_last_mark();
         goto err;
     }
-    length = EVP_MD_size(md);
+    (void)ERR_pop_to_mark();
+
+    length = EVP_MD_get_size(md);
     if (length < 0)
         goto err;
     *imprint_len = length;
@@ -449,6 +463,8 @@ static int ts_compute_imprint(BIO *data, TS_TST_INFO *tst_info,
     }
     if (!EVP_DigestInit(md_ctx, md))
         goto err;
+    EVP_MD_free(md);
+    md = NULL;
     while ((length = BIO_read(data, buffer, sizeof(buffer))) > 0) {
         if (!EVP_DigestUpdate(md_ctx, buffer, length))
             goto err;
@@ -460,7 +476,9 @@ static int ts_compute_imprint(BIO *data, TS_TST_INFO *tst_info,
     return 1;
  err:
     EVP_MD_CTX_free(md_ctx);
+    EVP_MD_free(md);
     X509_ALGOR_free(*md_alg);
+    *md_alg = NULL;
     OPENSSL_free(*imprint);
     *imprint_len = 0;
     *imprint = 0;