Fix uninitialized read in sigalg parsing code
[openssl.git] / ssl / t1_lib.c
index 7f896d58d31fbd6128aebf04a78163a7e56d7880..7109741a7dca1b4aa1682b330d77a9b7e667e218 100644 (file)
@@ -1131,7 +1131,8 @@ int tls1_set_server_sigalgs(SSL *s)
      * If peer sent no signature algorithms check to see if we support
      * the default algorithm for each certificate type
      */
-    if (s->s3->tmp.peer_sigalgs == NULL) {
+    if (s->s3->tmp.peer_cert_sigalgs == NULL
+            && s->s3->tmp.peer_sigalgs == NULL) {
         const uint16_t *sent_sigs;
         size_t sent_sigslen = tls12_get_psigalgs(s, 1, &sent_sigs);
 
@@ -1596,7 +1597,7 @@ int tls1_save_u16(PACKET *pkt, uint16_t **pdest, size_t *pdestlen)
     return 1;
 }
 
-int tls1_save_sigalgs(SSL *s, PACKET *pkt)
+int tls1_save_sigalgs(SSL *s, PACKET *pkt, int cert)
 {
     /* Extension ignored for inappropriate versions */
     if (!SSL_USE_SIGALGS(s))
@@ -1605,10 +1606,13 @@ int tls1_save_sigalgs(SSL *s, PACKET *pkt)
     if (s->cert == NULL)
         return 0;
 
-    return tls1_save_u16(pkt, &s->s3->tmp.peer_sigalgs,
-                         &s->s3->tmp.peer_sigalgslen);
+    if (cert)
+        return tls1_save_u16(pkt, &s->s3->tmp.peer_cert_sigalgs,
+                             &s->s3->tmp.peer_cert_sigalgslen);
+    else
+        return tls1_save_u16(pkt, &s->s3->tmp.peer_sigalgs,
+                             &s->s3->tmp.peer_sigalgslen);
 
-    return 1;
 }
 
 /* Set preferred digest for each key type */
@@ -1697,7 +1701,8 @@ int SSL_get_shared_sigalgs(SSL *s, int idx,
 
 typedef struct {
     size_t sigalgcnt;
-    int sigalgs[TLS_MAX_SIGALGCNT];
+    /* TLSEXT_SIGALG_XXX values */
+    uint16_t sigalgs[TLS_MAX_SIGALGCNT];
 } sig_cb_st;
 
 static void get_sigorhash(int *psig, int *phash, const char *str)
@@ -1723,6 +1728,7 @@ static int sig_cb(const char *elem, int len, void *arg)
 {
     sig_cb_st *sarg = arg;
     size_t i;
+    const SIGALG_LOOKUP *s;
     char etmp[TLS_MAX_SIGSTRING_LEN], *p;
     int sig_alg = NID_undef, hash_alg = NID_undef;
     if (elem == NULL)
@@ -1734,18 +1740,25 @@ static int sig_cb(const char *elem, int len, void *arg)
     memcpy(etmp, elem, len);
     etmp[len] = 0;
     p = strchr(etmp, '+');
-    /* See if we have a match for TLS 1.3 names */
+    /*
+     * We only allow SignatureSchemes listed in the sigalg_lookup_tbl;
+     * if there's no '+' in the provided name, look for the new-style combined
+     * name.  If not, match both sig+hash to find the needed SIGALG_LOOKUP.
+     * Just sig+hash is not unique since TLS 1.3 adds rsa_pss_pss_* and
+     * rsa_pss_rsae_* that differ only by public key OID; in such cases
+     * we will pick the _rsae_ variant, by virtue of them appearing earlier
+     * in the table.
+     */
     if (p == NULL) {
-        const SIGALG_LOOKUP *s;
-
         for (i = 0, s = sigalg_lookup_tbl; i < OSSL_NELEM(sigalg_lookup_tbl);
              i++, s++) {
             if (s->name != NULL && strcmp(etmp, s->name) == 0) {
-                sig_alg = s->sig;
-                hash_alg = s->hash;
+                sarg->sigalgs[sarg->sigalgcnt++] = s->sigalg;
                 break;
             }
         }
+        if (i == OSSL_NELEM(sigalg_lookup_tbl))
+            return 0;
     } else {
         *p = 0;
         p++;
@@ -1753,17 +1766,26 @@ static int sig_cb(const char *elem, int len, void *arg)
             return 0;
         get_sigorhash(&sig_alg, &hash_alg, etmp);
         get_sigorhash(&sig_alg, &hash_alg, p);
+        if (sig_alg == NID_undef || hash_alg == NID_undef)
+            return 0;
+        for (i = 0, s = sigalg_lookup_tbl; i < OSSL_NELEM(sigalg_lookup_tbl);
+             i++, s++) {
+            if (s->hash == hash_alg && s->sig == sig_alg) {
+                sarg->sigalgs[sarg->sigalgcnt++] = s->sigalg;
+                break;
+            }
+        }
+        if (i == OSSL_NELEM(sigalg_lookup_tbl))
+            return 0;
     }
 
-    if (sig_alg == NID_undef || (p != NULL && hash_alg == NID_undef))
-        return 0;
-
-    for (i = 0; i < sarg->sigalgcnt; i += 2) {
-        if (sarg->sigalgs[i] == sig_alg && sarg->sigalgs[i + 1] == hash_alg)
+    /* Reject duplicates */
+    for (i = 0; i < sarg->sigalgcnt - 1; i++) {
+        if (sarg->sigalgs[i] == sarg->sigalgs[sarg->sigalgcnt - 1]) {
+            sarg->sigalgcnt--;
             return 0;
+        }
     }
-    sarg->sigalgs[sarg->sigalgcnt++] = hash_alg;
-    sarg->sigalgs[sarg->sigalgcnt++] = sig_alg;
     return 1;
 }
 
@@ -1779,7 +1801,30 @@ int tls1_set_sigalgs_list(CERT *c, const char *str, int client)
         return 0;
     if (c == NULL)
         return 1;
-    return tls1_set_sigalgs(c, sig.sigalgs, sig.sigalgcnt, client);
+    return tls1_set_raw_sigalgs(c, sig.sigalgs, sig.sigalgcnt, client);
+}
+
+int tls1_set_raw_sigalgs(CERT *c, const uint16_t *psigs, size_t salglen,
+                     int client)
+{
+    uint16_t *sigalgs;
+
+    sigalgs = OPENSSL_malloc(salglen * sizeof(*sigalgs));
+    if (sigalgs == NULL)
+        return 0;
+    memcpy(sigalgs, psigs, salglen * sizeof(*sigalgs));
+
+    if (client) {
+        OPENSSL_free(c->client_sigalgs);
+        c->client_sigalgs = sigalgs;
+        c->client_sigalgslen = salglen;
+    } else {
+        OPENSSL_free(c->conf_sigalgs);
+        c->conf_sigalgs = sigalgs;
+        c->conf_sigalgslen = salglen;
+    }
+
+    return 1;
 }
 
 int tls1_set_sigalgs(CERT *c, const int *psig_nids, size_t salglen, int client)
@@ -1933,10 +1978,11 @@ int tls1_check_chain(SSL *s, X509 *x, EVP_PKEY *pk, STACK_OF(X509) *chain,
     if (TLS1_get_version(s) >= TLS1_2_VERSION && strict_mode) {
         int default_nid;
         int rsign = 0;
-        if (s->s3->tmp.peer_sigalgs)
+        if (s->s3->tmp.peer_cert_sigalgs != NULL
+                || s->s3->tmp.peer_sigalgs != NULL) {
             default_nid = 0;
         /* If no sigalgs extension use defaults from RFC5246 */
-        else {
+        else {
             switch (idx) {
             case SSL_PKEY_RSA:
                 rsign = EVP_PKEY_RSA;
@@ -2271,13 +2317,48 @@ static int tls12_get_cert_sigalg_idx(const SSL *s, const SIGALG_LOOKUP *lu)
     if (clu == NULL || !(clu->amask & s->s3->tmp.new_cipher->algorithm_auth))
         return -1;
 
-    /* If PSS and we have no PSS cert use RSA */
-    if (sig_idx == SSL_PKEY_RSA_PSS_SIGN && !ssl_has_cert(s, sig_idx))
-        sig_idx = SSL_PKEY_RSA;
-
     return s->s3->tmp.valid_flags[sig_idx] & CERT_PKEY_VALID ? sig_idx : -1;
 }
 
+/*
+ * Returns true if |s| has a usable certificate configured for use
+ * with signature scheme |sig|.
+ * "Usable" includes a check for presence as well as applying
+ * the signature_algorithm_cert restrictions sent by the peer (if any).
+ * Returns false if no usable certificate is found.
+ */
+static int has_usable_cert(SSL *s, const SIGALG_LOOKUP *sig, int idx)
+{
+    const SIGALG_LOOKUP *lu;
+    int mdnid, pknid;
+    size_t i;
+
+    /* TLS 1.2 callers can override lu->sig_idx, but not TLS 1.3 callers. */
+    if (idx == -1)
+        idx = sig->sig_idx;
+    if (!ssl_has_cert(s, idx))
+        return 0;
+    if (s->s3->tmp.peer_cert_sigalgs != NULL) {
+        for (i = 0; i < s->s3->tmp.peer_cert_sigalgslen; i++) {
+            lu = tls1_lookup_sigalg(s->s3->tmp.peer_cert_sigalgs[i]);
+            if (lu == NULL
+                || !X509_get_signature_info(s->cert->pkeys[idx].x509, &mdnid,
+                                            &pknid, NULL, NULL))
+                continue;
+            /*
+             * TODO this does not differentiate between the
+             * rsa_pss_pss_* and rsa_pss_rsae_* schemes since we do not
+             * have a chain here that lets us look at the key OID in the
+             * signing certificate.
+             */
+            if (mdnid == lu->hash && pknid == lu->sig)
+                return 1;
+        }
+        return 0;
+    }
+    return 1;
+}
+
 /*
  * Choose an appropriate signature algorithm based on available certificates
  * Sets chosen certificate and signature algorithm.
@@ -2314,14 +2395,9 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
                 || lu->sig == EVP_PKEY_DSA
                 || lu->sig == EVP_PKEY_RSA)
                 continue;
-            if (!tls1_lookup_md(lu, NULL))
+            /* Check that we have a cert, and signature_algorithms_cert */
+            if (!tls1_lookup_md(lu, NULL) || !has_usable_cert(s, lu, -1))
                 continue;
-            if (!ssl_has_cert(s, lu->sig_idx)) {
-                if (lu->sig_idx != SSL_PKEY_RSA_PSS_SIGN
-                        || !ssl_has_cert(s, SSL_PKEY_RSA))
-                    continue;
-                sig_idx = SSL_PKEY_RSA;
-            }
             if (lu->sig == EVP_PKEY_EC) {
 #ifndef OPENSSL_NO_EC
                 if (curve == -1) {
@@ -2340,21 +2416,8 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
             } else if (lu->sig == EVP_PKEY_RSA_PSS) {
                 /* validate that key is large enough for the signature algorithm */
                 EVP_PKEY *pkey;
-                int pkey_id;
 
-                if (sig_idx == -1)
-                    pkey = s->cert->pkeys[lu->sig_idx].privatekey;
-                else
-                    pkey = s->cert->pkeys[sig_idx].privatekey;
-                pkey_id = EVP_PKEY_id(pkey);
-                if (pkey_id != EVP_PKEY_RSA_PSS
-                    && pkey_id != EVP_PKEY_RSA)
-                    continue;
-                /*
-                 * The pkey type is EVP_PKEY_RSA_PSS or EVP_PKEY_RSA
-                 * EVP_PKEY_get0_RSA returns NULL if the type is not EVP_PKEY_RSA
-                 * so use EVP_PKEY_get0 instead
-                 */
+                pkey = s->cert->pkeys[lu->sig_idx].privatekey;
                 if (!rsa_pss_check_min_key_size(EVP_PKEY_get0(pkey), lu))
                     continue;
             }
@@ -2375,8 +2438,8 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
                 return 1;
 
         if (SSL_USE_SIGALGS(s)) {
+            size_t i;
             if (s->s3->tmp.peer_sigalgs != NULL) {
-                size_t i;
 #ifndef OPENSSL_NO_EC
                 int curve;
 
@@ -2403,21 +2466,16 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
                         int cc_idx = s->cert->key - s->cert->pkeys;
 
                         sig_idx = lu->sig_idx;
-                        if (cc_idx != sig_idx) {
-                            if (sig_idx != SSL_PKEY_RSA_PSS_SIGN
-                                || cc_idx != SSL_PKEY_RSA)
-                                continue;
-                            sig_idx = SSL_PKEY_RSA;
-                        }
+                        if (cc_idx != sig_idx)
+                            continue;
                     }
+                    /* Check that we have a cert, and sig_algs_cert */
+                    if (!has_usable_cert(s, lu, sig_idx))
+                        continue;
                     if (lu->sig == EVP_PKEY_RSA_PSS) {
                         /* validate that key is large enough for the signature algorithm */
                         EVP_PKEY *pkey = s->cert->pkeys[sig_idx].privatekey;
-                        int pkey_id = EVP_PKEY_id(pkey);
 
-                        if (pkey_id != EVP_PKEY_RSA_PSS
-                            && pkey_id != EVP_PKEY_RSA)
-                            continue;
                         if (!rsa_pss_check_min_key_size(EVP_PKEY_get0(pkey), lu))
                             continue;
                     }
@@ -2438,7 +2496,7 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
                  * If we have no sigalg use defaults
                  */
                 const uint16_t *sent_sigs;
-                size_t sent_sigslen, i;
+                size_t sent_sigslen;
 
                 if ((lu = tls1_get_legacy_sigalg(s, -1)) == NULL) {
                     if (!fatalerrs)
@@ -2451,7 +2509,8 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
                 /* Check signature matches a type we sent */
                 sent_sigslen = tls12_get_psigalgs(s, 1, &sent_sigs);
                 for (i = 0; i < sent_sigslen; i++, sent_sigs++) {
-                    if (lu->sigalg == *sent_sigs)
+                    if (lu->sigalg == *sent_sigs
+                            && has_usable_cert(s, lu, lu->sig_idx))
                         break;
                 }
                 if (i == sent_sigslen) {