RT3962: Check accept_count only if not unlimited
[openssl.git] / ssl / t1_lib.c
index 0420fe31b27536232c9a7382a7054a689e731822..47abf2b9f912112492152b9eeb542e742cbe83fe 100644 (file)
@@ -1111,7 +1111,7 @@ void ssl_set_client_disabled(SSL *s)
     /* with PSK there must be client callback set */
     if (!s->psk_client_callback) {
         s->s3->tmp.mask_a |= SSL_aPSK;
-        s->s3->tmp.mask_k |= SSL_kPSK;
+        s->s3->tmp.mask_k |= SSL_PSK;
     }
 #endif                         /* OPENSSL_NO_PSK */
 #ifndef OPENSSL_NO_SRP
@@ -1157,7 +1157,7 @@ unsigned char *ssl_add_clienthello_tlsext(SSL *s, unsigned char *buf,
 
             alg_k = c->algorithm_mkey;
             alg_a = c->algorithm_auth;
-            if ((alg_k & (SSL_kECDHE | SSL_kECDHr | SSL_kECDHe)
+            if ((alg_k & (SSL_kECDHE | SSL_kECDHr | SSL_kECDHe | SSL_kECDHEPSK)
                  || (alg_a & SSL_aECDSA))) {
                 using_ecc = 1;
                 break;
@@ -1940,19 +1940,23 @@ static int ssl_scan_clienthello_tlsext(SSL *s, unsigned char **p,
 
     s->srtp_profile = NULL;
 
-    if (data >= (d + n - 2))
+    if (data == d + n)
         goto ri_check;
+
+    if (data > (d + n - 2))
+        goto err;
+
     n2s(data, len);
 
     if (data > (d + n - len))
-        goto ri_check;
+        goto err;
 
     while (data <= (d + n - 4)) {
         n2s(data, type);
         n2s(data, size);
 
         if (data + size > (d + n))
-            goto ri_check;
+            goto err;
         if (s->tlsext_debug_cb)
             s->tlsext_debug_cb(s, 0, type, data, size, s->tlsext_debug_arg);
         if (type == TLSEXT_TYPE_renegotiate) {
@@ -1991,16 +1995,12 @@ static int ssl_scan_clienthello_tlsext(SSL *s, unsigned char **p,
             int servname_type;
             int dsize;
 
-            if (size < 2) {
-                *al = SSL_AD_DECODE_ERROR;
-                return 0;
-            }
+            if (size < 2)
+                goto err;
             n2s(data, dsize);
             size -= 2;
-            if (dsize > size) {
-                *al = SSL_AD_DECODE_ERROR;
-                return 0;
-            }
+            if (dsize > size)
+                goto err;
 
             sdata = data;
             while (dsize > 3) {
@@ -2008,18 +2008,16 @@ static int ssl_scan_clienthello_tlsext(SSL *s, unsigned char **p,
                 n2s(sdata, len);
                 dsize -= 3;
 
-                if (len > dsize) {
-                    *al = SSL_AD_DECODE_ERROR;
-                    return 0;
-                }
+                if (len > dsize)
+                    goto err;
+
                 if (s->servername_done == 0)
                     switch (servname_type) {
                     case TLSEXT_NAMETYPE_host_name:
                         if (!s->hit) {
-                            if (s->session->tlsext_hostname) {
-                                *al = SSL_AD_DECODE_ERROR;
-                                return 0;
-                            }
+                            if (s->session->tlsext_hostname)
+                                goto err;
+
                             if (len > TLSEXT_MAXLEN_host_name) {
                                 *al = TLS1_AD_UNRECOGNIZED_NAME;
                                 return 0;
@@ -2053,31 +2051,23 @@ static int ssl_scan_clienthello_tlsext(SSL *s, unsigned char **p,
 
                 dsize -= len;
             }
-            if (dsize != 0) {
-                *al = SSL_AD_DECODE_ERROR;
-                return 0;
-            }
+            if (dsize != 0)
+                goto err;
 
         }
 #ifndef OPENSSL_NO_SRP
         else if (type == TLSEXT_TYPE_srp) {
-            if (size == 0 || ((len = data[0])) != (size - 1)) {
-                *al = SSL_AD_DECODE_ERROR;
-                return 0;
-            }
-            if (s->srp_ctx.login != NULL) {
-                *al = SSL_AD_DECODE_ERROR;
-                return 0;
-            }
+            if (size == 0 || ((len = data[0])) != (size - 1))
+                goto err;
+            if (s->srp_ctx.login != NULL)
+                goto err;
             if ((s->srp_ctx.login = OPENSSL_malloc(len + 1)) == NULL)
                 return -1;
             memcpy(s->srp_ctx.login, &data[1], len);
             s->srp_ctx.login[len] = '\0';
 
-            if (strlen(s->srp_ctx.login) != len) {
-                *al = SSL_AD_DECODE_ERROR;
-                return 0;
-            }
+            if (strlen(s->srp_ctx.login) != len)
+                goto err;
         }
 #endif
 
@@ -2087,10 +2077,8 @@ static int ssl_scan_clienthello_tlsext(SSL *s, unsigned char **p,
             int ecpointformatlist_length = *(sdata++);
 
             if (ecpointformatlist_length != size - 1 ||
-                ecpointformatlist_length < 1) {
-                *al = TLS1_AD_DECODE_ERROR;
-                return 0;
-            }
+                ecpointformatlist_length < 1)
+                goto err;
             if (!s->hit) {
                 OPENSSL_free(s->session->tlsext_ecpointformatlist);
                 s->session->tlsext_ecpointformatlist = NULL;
@@ -2113,15 +2101,13 @@ static int ssl_scan_clienthello_tlsext(SSL *s, unsigned char **p,
             if (ellipticcurvelist_length != size - 2 ||
                 ellipticcurvelist_length < 1 ||
                 /* Each NamedCurve is 2 bytes. */
-                ellipticcurvelist_length & 1) {
-                *al = TLS1_AD_DECODE_ERROR;
-                return 0;
-            }
+                ellipticcurvelist_length & 1)
+                    goto err;
+
             if (!s->hit) {
-                if (s->session->tlsext_ellipticcurvelist) {
-                    *al = TLS1_AD_DECODE_ERROR;
-                    return 0;
-                }
+                if (s->session->tlsext_ellipticcurvelist)
+                    goto err;
+
                 s->session->tlsext_ellipticcurvelist_length = 0;
                 if ((s->session->tlsext_ellipticcurvelist =
                      OPENSSL_malloc(ellipticcurvelist_length)) == NULL) {
@@ -2145,26 +2131,18 @@ static int ssl_scan_clienthello_tlsext(SSL *s, unsigned char **p,
             }
         } else if (type == TLSEXT_TYPE_signature_algorithms) {
             int dsize;
-            if (s->s3->tmp.peer_sigalgs || size < 2) {
-                *al = SSL_AD_DECODE_ERROR;
-                return 0;
-            }
+            if (s->s3->tmp.peer_sigalgs || size < 2)
+                goto err;
             n2s(data, dsize);
             size -= 2;
-            if (dsize != size || dsize & 1 || !dsize) {
-                *al = SSL_AD_DECODE_ERROR;
-                return 0;
-            }
-            if (!tls1_save_sigalgs(s, data, dsize)) {
-                *al = SSL_AD_DECODE_ERROR;
-                return 0;
-            }
+            if (dsize != size || dsize & 1 || !dsize)
+                goto err;
+            if (!tls1_save_sigalgs(s, data, dsize))
+                goto err;
         } else if (type == TLSEXT_TYPE_status_request) {
 
-            if (size < 5) {
-                *al = SSL_AD_DECODE_ERROR;
-                return 0;
-            }
+            if (size < 5)
+                goto err;
 
             s->tlsext_status_type = *data++;
             size--;
@@ -2174,35 +2152,26 @@ static int ssl_scan_clienthello_tlsext(SSL *s, unsigned char **p,
                 /* Read in responder_id_list */
                 n2s(data, dsize);
                 size -= 2;
-                if (dsize > size) {
-                    *al = SSL_AD_DECODE_ERROR;
-                    return 0;
-                }
+                if (dsize > size)
+                    goto err;
                 while (dsize > 0) {
                     OCSP_RESPID *id;
                     int idsize;
-                    if (dsize < 4) {
-                        *al = SSL_AD_DECODE_ERROR;
-                        return 0;
-                    }
+                    if (dsize < 4)
+                        goto err;
                     n2s(data, idsize);
                     dsize -= 2 + idsize;
                     size -= 2 + idsize;
-                    if (dsize < 0) {
-                        *al = SSL_AD_DECODE_ERROR;
-                        return 0;
-                    }
+                    if (dsize < 0)
+                        goto err;
                     sdata = data;
                     data += idsize;
                     id = d2i_OCSP_RESPID(NULL, &sdata, idsize);
-                    if (!id) {
-                        *al = SSL_AD_DECODE_ERROR;
-                        return 0;
-                    }
+                    if (!id)
+                        goto err;
                     if (data != sdata) {
                         OCSP_RESPID_free(id);
-                        *al = SSL_AD_DECODE_ERROR;
-                        return 0;
+                        goto err;
                     }
                     if (!s->tlsext_ocsp_ids
                         && !(s->tlsext_ocsp_ids =
@@ -2219,26 +2188,20 @@ static int ssl_scan_clienthello_tlsext(SSL *s, unsigned char **p,
                 }
 
                 /* Read in request_extensions */
-                if (size < 2) {
-                    *al = SSL_AD_DECODE_ERROR;
-                    return 0;
-                }
+                if (size < 2)
+                    goto err;
                 n2s(data, dsize);
                 size -= 2;
-                if (dsize != size) {
-                    *al = SSL_AD_DECODE_ERROR;
-                    return 0;
-                }
+                if (dsize != size)
+                    goto err;
                 sdata = data;
                 if (dsize > 0) {
                     sk_X509_EXTENSION_pop_free(s->tlsext_ocsp_exts,
                                                X509_EXTENSION_free);
                     s->tlsext_ocsp_exts =
                         d2i_X509_EXTENSIONS(NULL, &sdata, dsize);
-                    if (!s->tlsext_ocsp_exts || (data + dsize != sdata)) {
-                        *al = SSL_AD_DECODE_ERROR;
-                        return 0;
-                    }
+                    if (!s->tlsext_ocsp_exts || (data + dsize != sdata))
+                        goto err;
                 }
             }
             /*
@@ -2329,6 +2292,10 @@ static int ssl_scan_clienthello_tlsext(SSL *s, unsigned char **p,
         data += size;
     }
 
+    /* Spurious data on the end */
+    if (data != d + n)
+        goto err;
+
     *p = data;
 
  ri_check:
@@ -2344,6 +2311,9 @@ static int ssl_scan_clienthello_tlsext(SSL *s, unsigned char **p,
     }
 
     return 1;
+err:
+    *al = SSL_AD_DECODE_ERROR;
+    return 0;
 }
 
 int ssl_parse_clienthello_tlsext(SSL *s, unsigned char **p, unsigned char *d,
@@ -3489,7 +3459,7 @@ int tls1_process_sigalgs(SSL *s)
     size_t i;
     const EVP_MD *md;
     const EVP_MD **pmd = s->s3->tmp.md;
-    int *pvalid = s->s3->tmp.valid_flags;
+    uint32_t *pvalid = s->s3->tmp.valid_flags;
     CERT *c = s->cert;
     TLS_SIGALGS *sigptr;
     if (!tls1_set_shared_sigalgs(s))
@@ -3769,12 +3739,27 @@ typedef struct {
     int sigalgs[MAX_SIGALGLEN];
 } sig_cb_st;
 
+static void get_sigorhash(int *psig, int *phash, const char *str)
+{
+    if (strcmp(str, "RSA") == 0) {
+        *psig = EVP_PKEY_RSA;
+    } else if (strcmp(str, "DSA") == 0) {
+        *psig = EVP_PKEY_DSA;
+    } else if (strcmp(str, "ECDSA") == 0) {
+        *psig = EVP_PKEY_EC;
+    } else {
+        *phash = OBJ_sn2nid(str);
+        if (*phash == NID_undef)
+            *phash = OBJ_ln2nid(str);
+    }
+}
+
 static int sig_cb(const char *elem, int len, void *arg)
 {
     sig_cb_st *sarg = arg;
     size_t i;
     char etmp[20], *p;
-    int sig_alg, hash_alg;
+    int sig_alg = NID_undef, hash_alg = NID_undef;
     if (elem == NULL)
         return 0;
     if (sarg->sigalgcnt == MAX_SIGALGLEN)
@@ -3791,19 +3776,10 @@ static int sig_cb(const char *elem, int len, void *arg)
     if (!*p)
         return 0;
 
-    if (strcmp(etmp, "RSA") == 0)
-        sig_alg = EVP_PKEY_RSA;
-    else if (strcmp(etmp, "DSA") == 0)
-        sig_alg = EVP_PKEY_DSA;
-    else if (strcmp(etmp, "ECDSA") == 0)
-        sig_alg = EVP_PKEY_EC;
-    else
-        return 0;
+    get_sigorhash(&sig_alg, &hash_alg, etmp);
+    get_sigorhash(&sig_alg, &hash_alg, p);
 
-    hash_alg = OBJ_sn2nid(p);
-    if (hash_alg == NID_undef)
-        hash_alg = OBJ_ln2nid(p);
-    if (hash_alg == NID_undef)
+    if (sig_alg == NID_undef || hash_alg == NID_undef)
         return 0;
 
     for (i = 0; i < sarg->sigalgcnt; i += 2) {
@@ -3920,7 +3896,7 @@ int tls1_check_chain(SSL *s, X509 *x, EVP_PKEY *pk, STACK_OF(X509) *chain,
     int check_flags = 0, strict_mode;
     CERT_PKEY *cpk = NULL;
     CERT *c = s->cert;
-    int *pvalid;
+    uint32_t *pvalid;
     unsigned int suiteb_flags = tls1_suiteb(s);
     /* idx == -1 means checking server chains */
     if (idx != -1) {
@@ -4189,7 +4165,7 @@ DH *ssl_get_auto_dh(SSL *s)
     int dh_secbits = 80;
     if (s->cert->dh_tmp_auto == 2)
         return DH_get_1024_160();
-    if (s->s3->tmp.new_cipher->algorithm_auth & SSL_aNULL) {
+    if (s->s3->tmp.new_cipher->algorithm_auth & (SSL_aNULL | SSL_aPSK)) {
         if (s->s3->tmp.new_cipher->strength_bits == 256)
             dh_secbits = 128;
         else