Add `for_comp` flag when retrieving certs for compression
[openssl.git] / ssl / statem / statem_lib.c
index 9712fb37355e4b219293ca81501269d1594bc66e..b37633e37f05e2267d9a4825a76341c33d33f8d9 100644 (file)
@@ -906,25 +906,30 @@ CON_FUNC_RETURN tls_construct_change_cipher_spec(SSL_CONNECTION *s, WPACKET *pkt
 
 /* Add a certificate to the WPACKET */
 static int ssl_add_cert_to_wpacket(SSL_CONNECTION *s, WPACKET *pkt,
-                                   X509 *x, int chain)
+                                   X509 *x, int chain, int for_comp)
 {
     int len;
     unsigned char *outbytes;
+    int context = SSL_EXT_TLS1_3_CERTIFICATE;
+
+    if (for_comp)
+        context |= SSL_EXT_TLS1_3_CERTIFICATE_COMPRESSION;
 
     len = i2d_X509(x, NULL);
     if (len < 0) {
-        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_BUF_LIB);
+        if (!for_comp)
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_BUF_LIB);
         return 0;
     }
     if (!WPACKET_sub_allocate_bytes_u24(pkt, len, &outbytes)
             || i2d_X509(x, &outbytes) != len) {
-        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        if (!for_comp)
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
         return 0;
     }
 
-    if (SSL_CONNECTION_IS_TLS13(s)
-            && !tls_construct_extensions(s, pkt, SSL_EXT_TLS1_3_CERTIFICATE, x,
-                                         chain)) {
+    if ((SSL_CONNECTION_IS_TLS13(s) || for_comp)
+            && !tls_construct_extensions(s, pkt, context, x, chain)) {
         /* SSLfatal() already called */
         return 0;
     }
@@ -933,7 +938,7 @@ static int ssl_add_cert_to_wpacket(SSL_CONNECTION *s, WPACKET *pkt,
 }
 
 /* Add certificate chain to provided WPACKET */
-static int ssl_add_cert_chain(SSL_CONNECTION *s, WPACKET *pkt, CERT_PKEY *cpk)
+static int ssl_add_cert_chain(SSL_CONNECTION *s, WPACKET *pkt, CERT_PKEY *cpk, int for_comp)
 {
     int i, chain_count;
     X509 *x;
@@ -967,12 +972,14 @@ static int ssl_add_cert_chain(SSL_CONNECTION *s, WPACKET *pkt, CERT_PKEY *cpk)
                                                        sctx->propq);
 
         if (xs_ctx == NULL) {
-            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_X509_LIB);
+            if (!for_comp)
+                SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_X509_LIB);
             return 0;
         }
         if (!X509_STORE_CTX_init(xs_ctx, chain_store, x, NULL)) {
             X509_STORE_CTX_free(xs_ctx);
-            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_X509_LIB);
+            if (!for_comp)
+                SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_X509_LIB);
             return 0;
         }
         /*
@@ -994,14 +1001,15 @@ static int ssl_add_cert_chain(SSL_CONNECTION *s, WPACKET *pkt, CERT_PKEY *cpk)
             ERR_raise(ERR_LIB_SSL, SSL_R_CA_MD_TOO_WEAK);
 #endif
             X509_STORE_CTX_free(xs_ctx);
-            SSLfatal(s, SSL_AD_INTERNAL_ERROR, i);
+            if (!for_comp)
+                SSLfatal(s, SSL_AD_INTERNAL_ERROR, i);
             return 0;
         }
         chain_count = sk_X509_num(chain);
         for (i = 0; i < chain_count; i++) {
             x = sk_X509_value(chain, i);
 
-            if (!ssl_add_cert_to_wpacket(s, pkt, x, i)) {
+            if (!ssl_add_cert_to_wpacket(s, pkt, x, i, for_comp)) {
                 /* SSLfatal() already called */
                 X509_STORE_CTX_free(xs_ctx);
                 return 0;
@@ -1011,16 +1019,17 @@ static int ssl_add_cert_chain(SSL_CONNECTION *s, WPACKET *pkt, CERT_PKEY *cpk)
     } else {
         i = ssl_security_cert_chain(s, extra_certs, x, 0);
         if (i != 1) {
-            SSLfatal(s, SSL_AD_INTERNAL_ERROR, i);
+            if (!for_comp)
+                SSLfatal(s, SSL_AD_INTERNAL_ERROR, i);
             return 0;
         }
-        if (!ssl_add_cert_to_wpacket(s, pkt, x, 0)) {
+        if (!ssl_add_cert_to_wpacket(s, pkt, x, 0, for_comp)) {
             /* SSLfatal() already called */
             return 0;
         }
         for (i = 0; i < sk_X509_num(extra_certs); i++) {
             x = sk_X509_value(extra_certs, i);
-            if (!ssl_add_cert_to_wpacket(s, pkt, x, i + 1)) {
+            if (!ssl_add_cert_to_wpacket(s, pkt, x, i + 1, for_comp)) {
                 /* SSLfatal() already called */
                 return 0;
             }
@@ -1030,18 +1039,20 @@ static int ssl_add_cert_chain(SSL_CONNECTION *s, WPACKET *pkt, CERT_PKEY *cpk)
 }
 
 unsigned long ssl3_output_cert_chain(SSL_CONNECTION *s, WPACKET *pkt,
-                                     CERT_PKEY *cpk)
+                                     CERT_PKEY *cpk, int for_comp)
 {
     if (!WPACKET_start_sub_packet_u24(pkt)) {
-        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        if (!for_comp)
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
         return 0;
     }
 
-    if (!ssl_add_cert_chain(s, pkt, cpk))
+    if (!ssl_add_cert_chain(s, pkt, cpk, for_comp))
         return 0;
 
     if (!WPACKET_close(pkt)) {
-        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        if (!for_comp)
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
         return 0;
     }