Fix some client side transition logic
[openssl.git] / ssl / statem / statem_clnt.c
index c68eecf24fea8d2bfe2fc0e825a4d9b92675df50..c74ee82c1a33b0d28c929aa7e497e4ee265077c7 100644 (file)
 #endif
 
 static inline int cert_req_allowed(SSL *s);
 #endif
 
 static inline int cert_req_allowed(SSL *s);
-static inline int key_exchange_skip_allowed(SSL *s);
+static int key_exchange_expected(SSL *s);
 static int ssl_set_version(SSL *s);
 static int ca_dn_cmp(const X509_NAME *const *a, const X509_NAME *const *b);
 static int ssl_cipher_list_to_bytes(SSL *s, STACK_OF(SSL_CIPHER) *sk,
 static int ssl_set_version(SSL *s);
 static int ca_dn_cmp(const X509_NAME *const *a, const X509_NAME *const *b);
 static int ssl_cipher_list_to_bytes(SSL *s, STACK_OF(SSL_CIPHER) *sk,
@@ -190,25 +190,51 @@ static inline int cert_req_allowed(SSL *s)
 }
 
 /*
 }
 
 /*
- * Are we allowed to skip the ServerKeyExchange message?
+ * Should we expect the ServerKeyExchange message or not?
  *
  *  Return values are:
  *  1: Yes
  *  0: No
  *
  *  Return values are:
  *  1: Yes
  *  0: No
+ * -1: Error
  */
  */
-static inline int key_exchange_skip_allowed(SSL *s)
+static int key_exchange_expected(SSL *s)
 {
     long alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
 
     /*
      * Can't skip server key exchange if this is an ephemeral
 {
     long alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
 
     /*
      * Can't skip server key exchange if this is an ephemeral
-     * ciphersuite.
+     * ciphersuite or for SRP
      */
      */
-    if (alg_k & (SSL_kDHE | SSL_kECDHE)) {
-        return 0;
+    if (alg_k & (SSL_kDHE | SSL_kECDHE | SSL_kDHEPSK | SSL_kECDHEPSK
+                 | SSL_kSRP)) {
+        return 1;
     }
 
     }
 
-    return 1;
+    /*
+     * Export ciphersuites may have temporary RSA keys if the public key in the
+     * server certificate is longer than the maximum export strength
+     */
+    if ((alg_k & SSL_kRSA) && SSL_C_IS_EXPORT(s->s3->tmp.new_cipher)) {
+        EVP_PKEY *pkey;
+
+        pkey = X509_get_pubkey(s->session->peer);
+        if (pkey == NULL)
+            return -1;
+
+        /*
+         * If the public key in the certificate is shorter than or equal to the
+         * maximum export strength then a temporary RSA key is not allowed
+         */
+        if (EVP_PKEY_bits(pkey)
+                <= SSL_C_EXPORT_PKEYLENGTH(s->s3->tmp.new_cipher))
+            return 0;
+
+        EVP_PKEY_free(pkey);
+
+        return 1;
+    }
+
+    return 0;
 }
 
 /*
 }
 
 /*
@@ -224,6 +250,7 @@ static inline int key_exchange_skip_allowed(SSL *s)
 int client_read_transition(SSL *s, int mt)
 {
     STATEM *st = &s->statem;
 int client_read_transition(SSL *s, int mt)
 {
     STATEM *st = &s->statem;
+    int ske_expected;
 
     switch(st->hand_state) {
     case TLS_ST_CW_CLNT_HELLO:
 
     switch(st->hand_state) {
     case TLS_ST_CW_CLNT_HELLO:
@@ -262,18 +289,24 @@ int client_read_transition(SSL *s, int mt)
                     return 1;
                 }
             } else {
                     return 1;
                 }
             } else {
-                if (mt == SSL3_MT_SERVER_KEY_EXCHANGE) {
-                    st->hand_state = TLS_ST_CR_KEY_EXCH;
-                    return 1;
-                } else if (key_exchange_skip_allowed(s)) {
-                    if (mt == SSL3_MT_CERTIFICATE_REQUEST
+                ske_expected = key_exchange_expected(s);
+                if (ske_expected < 0)
+                    return 0;
+                /* SKE is optional for some PSK ciphersuites */
+                if (ske_expected
+                        || ((s->s3->tmp.new_cipher->algorithm_mkey & SSL_PSK)
+                            && mt == SSL3_MT_SERVER_KEY_EXCHANGE)) {
+                    if (mt == SSL3_MT_SERVER_KEY_EXCHANGE) {
+                        st->hand_state = TLS_ST_CR_KEY_EXCH;
+                        return 1;
+                    }
+                } else if (mt == SSL3_MT_CERTIFICATE_REQUEST
                             && cert_req_allowed(s)) {
                         st->hand_state = TLS_ST_CR_CERT_REQ;
                         return 1;
                             && cert_req_allowed(s)) {
                         st->hand_state = TLS_ST_CR_CERT_REQ;
                         return 1;
-                    } else if (mt == SSL3_MT_SERVER_DONE) {
+                } else if (mt == SSL3_MT_SERVER_DONE) {
                         st->hand_state = TLS_ST_CR_SRVR_DONE;
                         return 1;
                         st->hand_state = TLS_ST_CR_SRVR_DONE;
                         return 1;
-                    }
                 }
             }
         }
                 }
             }
         }
@@ -285,46 +318,35 @@ int client_read_transition(SSL *s, int mt)
                 st->hand_state = TLS_ST_CR_CERT_STATUS;
                 return 1;
             }
                 st->hand_state = TLS_ST_CR_CERT_STATUS;
                 return 1;
             }
-        } else {
+            return 0;
+        }
+        /* Fall through */
+
+    case TLS_ST_CR_CERT_STATUS:
+        ske_expected = key_exchange_expected(s);
+        if (ske_expected < 0)
+            return 0;
+        /* SKE is optional for some PSK ciphersuites */
+        if (ske_expected
+                || ((s->s3->tmp.new_cipher->algorithm_mkey & SSL_PSK)
+                    && mt == SSL3_MT_SERVER_KEY_EXCHANGE)) {
             if (mt == SSL3_MT_SERVER_KEY_EXCHANGE) {
                 st->hand_state = TLS_ST_CR_KEY_EXCH;
                 return 1;
             if (mt == SSL3_MT_SERVER_KEY_EXCHANGE) {
                 st->hand_state = TLS_ST_CR_KEY_EXCH;
                 return 1;
-            } else if (key_exchange_skip_allowed(s)) {
-                if (mt == SSL3_MT_CERTIFICATE_REQUEST && cert_req_allowed(s)) {
-                    st->hand_state = TLS_ST_CR_CERT_REQ;
-                    return 1;
-                } else if (mt == SSL3_MT_SERVER_DONE) {
-                    st->hand_state = TLS_ST_CR_SRVR_DONE;
-                    return 1;
-                }
             }
             }
+            return 0;
         }
         }
-        break;
+        /* Fall through */
 
 
-    case TLS_ST_CR_CERT_STATUS:
-        if (mt == SSL3_MT_SERVER_KEY_EXCHANGE) {
-            st->hand_state = TLS_ST_CR_KEY_EXCH;
-            return 1;
-        } else if (key_exchange_skip_allowed(s)) {
-            if (mt == SSL3_MT_CERTIFICATE_REQUEST && cert_req_allowed(s)) {
+    case TLS_ST_CR_KEY_EXCH:
+        if (mt == SSL3_MT_CERTIFICATE_REQUEST) {
+            if (cert_req_allowed(s)) {
                 st->hand_state = TLS_ST_CR_CERT_REQ;
                 return 1;
                 st->hand_state = TLS_ST_CR_CERT_REQ;
                 return 1;
-            } else if (mt == SSL3_MT_SERVER_DONE) {
-                st->hand_state = TLS_ST_CR_SRVR_DONE;
-                return 1;
             }
             }
+            return 0;
         }
         }
-        break;
-
-    case TLS_ST_CR_KEY_EXCH:
-        if (mt == SSL3_MT_CERTIFICATE_REQUEST && cert_req_allowed(s)) {
-            st->hand_state = TLS_ST_CR_CERT_REQ;
-            return 1;
-        } else if (mt == SSL3_MT_SERVER_DONE) {
-            st->hand_state = TLS_ST_CR_SRVR_DONE;
-            return 1;
-        }
-        break;
+        /* Fall through */
 
     case TLS_ST_CR_CERT_REQ:
         if (mt == SSL3_MT_SERVER_DONE) {
 
     case TLS_ST_CR_CERT_REQ:
         if (mt == SSL3_MT_SERVER_DONE) {
@@ -1721,12 +1743,6 @@ enum MSG_PROCESS_RETURN tls_process_key_exchange(SSL *s, PACKET *pkt)
             goto err;
         }
 
             goto err;
         }
 
-        if (EVP_PKEY_bits(pkey) <= SSL_C_EXPORT_PKEYLENGTH(s->s3->tmp.new_cipher)) {
-            al = SSL_AD_UNEXPECTED_MESSAGE;
-            SSLerr(SSL_F_TLS_PROCESS_KEY_EXCHANGE, SSL_R_UNEXPECTED_MESSAGE);
-            goto f_err;
-        }
-
         s->s3->peer_rsa_tmp = rsa;
         rsa = NULL;
     }
         s->s3->peer_rsa_tmp = rsa;
         rsa = NULL;
     }
@@ -2312,6 +2328,16 @@ enum MSG_PROCESS_RETURN tls_process_server_done(SSL *s, PACKET *pkt)
     }
 #endif
 
     }
 #endif
 
+    /*
+     * at this point we check that we have the required stuff from
+     * the server
+     */
+    if (!ssl3_check_cert_and_algorithm(s)) {
+        ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
+        statem_set_error(s);
+        return MSG_PROCESS_ERROR;
+    }
+
 #ifndef OPENSSL_NO_SCTP
     /* Only applies to renegotiation */
     if (SSL_IS_DTLS(s) && BIO_dgram_is_sctp(SSL_get_wbio(s))
 #ifndef OPENSSL_NO_SCTP
     /* Only applies to renegotiation */
     if (SSL_IS_DTLS(s) && BIO_dgram_is_sctp(SSL_get_wbio(s))