Add support for the TLS 1.3 signature_algorithms_cert extension
[openssl.git] / ssl / t1_lib.c
index 00ac5133b366244c5ce2a886a649178683e67aad..d4c9086e5af1357a2fb3cfb1b65c1816bf97da1b 100644 (file)
@@ -1131,7 +1131,8 @@ int tls1_set_server_sigalgs(SSL *s)
      * If peer sent no signature algorithms check to see if we support
      * the default algorithm for each certificate type
      */
-    if (s->s3->tmp.peer_sigalgs == NULL) {
+    if (s->s3->tmp.peer_cert_sigalgs == NULL
+            && s->s3->tmp.peer_sigalgs == NULL) {
         const uint16_t *sent_sigs;
         size_t sent_sigslen = tls12_get_psigalgs(s, 1, &sent_sigs);
 
@@ -1596,7 +1597,7 @@ int tls1_save_u16(PACKET *pkt, uint16_t **pdest, size_t *pdestlen)
     return 1;
 }
 
-int tls1_save_sigalgs(SSL *s, PACKET *pkt)
+int tls1_save_sigalgs(SSL *s, PACKET *pkt, int cert)
 {
     /* Extension ignored for inappropriate versions */
     if (!SSL_USE_SIGALGS(s))
@@ -1605,10 +1606,13 @@ int tls1_save_sigalgs(SSL *s, PACKET *pkt)
     if (s->cert == NULL)
         return 0;
 
-    return tls1_save_u16(pkt, &s->s3->tmp.peer_sigalgs,
-                         &s->s3->tmp.peer_sigalgslen);
+    if (cert)
+        return tls1_save_u16(pkt, &s->s3->tmp.peer_cert_sigalgs,
+                             &s->s3->tmp.peer_cert_sigalgslen);
+    else
+        return tls1_save_u16(pkt, &s->s3->tmp.peer_sigalgs,
+                             &s->s3->tmp.peer_sigalgslen);
 
-    return 1;
 }
 
 /* Set preferred digest for each key type */
@@ -1974,10 +1978,11 @@ int tls1_check_chain(SSL *s, X509 *x, EVP_PKEY *pk, STACK_OF(X509) *chain,
     if (TLS1_get_version(s) >= TLS1_2_VERSION && strict_mode) {
         int default_nid;
         int rsign = 0;
-        if (s->s3->tmp.peer_sigalgs)
+        if (s->s3->tmp.peer_cert_sigalgs != NULL
+                || s->s3->tmp.peer_sigalgs != NULL) {
             default_nid = 0;
         /* If no sigalgs extension use defaults from RFC5246 */
-        else {
+        else {
             switch (idx) {
             case SSL_PKEY_RSA:
                 rsign = EVP_PKEY_RSA;
@@ -2312,13 +2317,48 @@ static int tls12_get_cert_sigalg_idx(const SSL *s, const SIGALG_LOOKUP *lu)
     if (clu == NULL || !(clu->amask & s->s3->tmp.new_cipher->algorithm_auth))
         return -1;
 
-    /* If PSS and we have no PSS cert use RSA */
-    if (sig_idx == SSL_PKEY_RSA_PSS_SIGN && !ssl_has_cert(s, sig_idx))
-        sig_idx = SSL_PKEY_RSA;
-
     return s->s3->tmp.valid_flags[sig_idx] & CERT_PKEY_VALID ? sig_idx : -1;
 }
 
+/*
+ * Returns true if |s| has a usable certificate configured for use
+ * with signature scheme |sig|.
+ * "Usable" includes a check for presence as well as applying
+ * the signature_algorithm_cert restrictions sent by the peer (if any).
+ * Returns false if no usable certificate is found.
+ */
+static int has_usable_cert(SSL *s, const SIGALG_LOOKUP *sig, int idx)
+{
+    const SIGALG_LOOKUP *lu;
+    int mdnid, pknid;
+    size_t i;
+
+    /* TLS 1.2 callers can override lu->sig_idx, but not TLS 1.3 callers. */
+    if (idx == -1)
+        idx = sig->sig_idx;
+    if (!ssl_has_cert(s, idx))
+        return 0;
+    if (s->s3->tmp.peer_cert_sigalgs != NULL) {
+        for (i = 0; i < s->s3->tmp.peer_cert_sigalgslen; i++) {
+            lu = tls1_lookup_sigalg(s->s3->tmp.peer_cert_sigalgs[i]);
+            if (lu == NULL
+                || !X509_get_signature_info(s->cert->pkeys[idx].x509, &mdnid,
+                                            &pknid, NULL, NULL))
+                continue;
+            /*
+             * TODO this does not differentiate between the
+             * rsa_pss_pss_* and rsa_pss_rsae_* schemes since we do not
+             * have a chain here that lets us look at the key OID in the
+             * signing certificate.
+             */
+            if (mdnid == lu->hash && pknid == lu->sig)
+                return 1;
+        }
+        return 0;
+    }
+    return 1;
+}
+
 /*
  * Choose an appropriate signature algorithm based on available certificates
  * Sets chosen certificate and signature algorithm.
@@ -2355,14 +2395,9 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
                 || lu->sig == EVP_PKEY_DSA
                 || lu->sig == EVP_PKEY_RSA)
                 continue;
-            if (!tls1_lookup_md(lu, NULL))
+            /* Check that we have a cert, and signature_algorithms_cert */
+            if (!tls1_lookup_md(lu, NULL) || !has_usable_cert(s, lu, -1))
                 continue;
-            if (!ssl_has_cert(s, lu->sig_idx)) {
-                if (lu->sig_idx != SSL_PKEY_RSA_PSS_SIGN
-                        || !ssl_has_cert(s, SSL_PKEY_RSA))
-                    continue;
-                sig_idx = SSL_PKEY_RSA;
-            }
             if (lu->sig == EVP_PKEY_EC) {
 #ifndef OPENSSL_NO_EC
                 if (curve == -1) {
@@ -2381,21 +2416,8 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
             } else if (lu->sig == EVP_PKEY_RSA_PSS) {
                 /* validate that key is large enough for the signature algorithm */
                 EVP_PKEY *pkey;
-                int pkey_id;
 
-                if (sig_idx == -1)
-                    pkey = s->cert->pkeys[lu->sig_idx].privatekey;
-                else
-                    pkey = s->cert->pkeys[sig_idx].privatekey;
-                pkey_id = EVP_PKEY_id(pkey);
-                if (pkey_id != EVP_PKEY_RSA_PSS
-                    && pkey_id != EVP_PKEY_RSA)
-                    continue;
-                /*
-                 * The pkey type is EVP_PKEY_RSA_PSS or EVP_PKEY_RSA
-                 * EVP_PKEY_get0_RSA returns NULL if the type is not EVP_PKEY_RSA
-                 * so use EVP_PKEY_get0 instead
-                 */
+                pkey = s->cert->pkeys[lu->sig_idx].privatekey;
                 if (!rsa_pss_check_min_key_size(EVP_PKEY_get0(pkey), lu))
                     continue;
             }
@@ -2416,8 +2438,8 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
                 return 1;
 
         if (SSL_USE_SIGALGS(s)) {
+            size_t i;
             if (s->s3->tmp.peer_sigalgs != NULL) {
-                size_t i;
 #ifndef OPENSSL_NO_EC
                 int curve;
 
@@ -2444,21 +2466,16 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
                         int cc_idx = s->cert->key - s->cert->pkeys;
 
                         sig_idx = lu->sig_idx;
-                        if (cc_idx != sig_idx) {
-                            if (sig_idx != SSL_PKEY_RSA_PSS_SIGN
-                                || cc_idx != SSL_PKEY_RSA)
-                                continue;
-                            sig_idx = SSL_PKEY_RSA;
-                        }
+                        if (cc_idx != sig_idx)
+                            continue;
                     }
+                    /* Check that we have a cert, and sig_algs_cert */
+                    if (!has_usable_cert(s, lu, sig_idx))
+                        continue;
                     if (lu->sig == EVP_PKEY_RSA_PSS) {
                         /* validate that key is large enough for the signature algorithm */
                         EVP_PKEY *pkey = s->cert->pkeys[sig_idx].privatekey;
-                        int pkey_id = EVP_PKEY_id(pkey);
 
-                        if (pkey_id != EVP_PKEY_RSA_PSS
-                            && pkey_id != EVP_PKEY_RSA)
-                            continue;
                         if (!rsa_pss_check_min_key_size(EVP_PKEY_get0(pkey), lu))
                             continue;
                     }
@@ -2479,7 +2496,7 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
                  * If we have no sigalg use defaults
                  */
                 const uint16_t *sent_sigs;
-                size_t sent_sigslen, i;
+                size_t sent_sigslen;
 
                 if ((lu = tls1_get_legacy_sigalg(s, -1)) == NULL) {
                     if (!fatalerrs)
@@ -2492,7 +2509,8 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
                 /* Check signature matches a type we sent */
                 sent_sigslen = tls12_get_psigalgs(s, 1, &sent_sigs);
                 for (i = 0; i < sent_sigslen; i++, sent_sigs++) {
-                    if (lu->sigalg == *sent_sigs)
+                    if (lu->sigalg == *sent_sigs
+                            && has_usable_cert(s, lu, lu->sig_idx))
                         break;
                 }
                 if (i == sent_sigslen) {