Add more error state transitions
[openssl.git] / ssl / s3_srvr.c
index 7a399673b1d6b4f1c4c1fc0cd11fb7c4f75c25bf..e6884f3fb1554f7d6b88af40c89b1332a8a7c3f5 100644 (file)
@@ -262,6 +262,7 @@ int ssl3_accept(SSL *s)
 
             if ((s->version >> 8) != 3) {
                 SSLerr(SSL_F_SSL3_ACCEPT, ERR_R_INTERNAL_ERROR);
+                s->state = SSL_ST_ERR;
                 return -1;
             }
 
@@ -275,11 +276,13 @@ int ssl3_accept(SSL *s)
             if (s->init_buf == NULL) {
                 if ((buf = BUF_MEM_new()) == NULL) {
                     ret = -1;
+                    s->state = SSL_ST_ERR;
                     goto end;
                 }
                 if (!BUF_MEM_grow(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
                     BUF_MEM_free(buf);
                     ret = -1;
+                    s->state = SSL_ST_ERR;
                     goto end;
                 }
                 s->init_buf = buf;
@@ -287,6 +290,7 @@ int ssl3_accept(SSL *s)
 
             if (!ssl3_setup_buffers(s)) {
                 ret = -1;
+                s->state = SSL_ST_ERR;
                 goto end;
             }
 
@@ -305,6 +309,7 @@ int ssl3_accept(SSL *s)
                  */
                 if (!ssl_init_wbio_buffer(s, 1)) {
                     ret = -1;
+                    s->state = SSL_ST_ERR;
                     goto end;
                 }
 
@@ -322,6 +327,7 @@ int ssl3_accept(SSL *s)
                        SSL_R_UNSAFE_LEGACY_RENEGOTIATION_DISABLED);
                 ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
                 ret = -1;
+                s->state = SSL_ST_ERR;
                 goto end;
             } else {
                 /*
@@ -380,6 +386,7 @@ int ssl3_accept(SSL *s)
                         SSLerr(SSL_F_SSL3_ACCEPT, SSL_R_CLIENTHELLO_TLSEXT);
                     ret = SSL_TLSEXT_ERR_ALERT_FATAL;
                     ret = -1;
+                    s->state = SSL_ST_ERR;
                     goto end;
                 }
             }
@@ -530,9 +537,12 @@ int ssl3_accept(SSL *s)
                 skip = 1;
                 s->s3->tmp.cert_request = 0;
                 s->state = SSL3_ST_SW_SRVR_DONE_A;
-                if (s->s3->handshake_buffer)
-                    if (!ssl3_digest_cached_records(s))
+                if (s->s3->handshake_buffer) {
+                    if (!ssl3_digest_cached_records(s)) {
+                        s->state = SSL_ST_ERR;
                         return -1;
+                    }
+                }
             } else {
                 s->s3->tmp.cert_request = 1;
                 ret = ssl3_send_certificate_request(s);
@@ -613,6 +623,7 @@ int ssl3_accept(SSL *s)
                     break;
                 if (!s->s3->handshake_buffer) {
                     SSLerr(SSL_F_SSL3_ACCEPT, ERR_R_INTERNAL_ERROR);
+                    s->state = SSL_ST_ERR;
                     return -1;
                 }
                 /*
@@ -621,8 +632,10 @@ int ssl3_accept(SSL *s)
                  */
                 if (!(s->s3->flags & SSL_SESS_FLAG_EXTMS)) {
                     s->s3->flags |= TLS1_FLAGS_KEEP_HANDSHAKE;
-                    if (!ssl3_digest_cached_records(s))
+                    if (!ssl3_digest_cached_records(s)) {
+                        s->state = SSL_ST_ERR;
                         return -1;
+                    }
                 }
             } else {
                 int offset = 0;
@@ -637,9 +650,12 @@ int ssl3_accept(SSL *s)
                  * CertificateVerify should be generalized. But it is next
                  * step
                  */
-                if (s->s3->handshake_buffer)
-                    if (!ssl3_digest_cached_records(s))
+                if (s->s3->handshake_buffer) {
+                    if (!ssl3_digest_cached_records(s)) {
+                        s->state = SSL_ST_ERR;
                         return -1;
+                    }
+                }
                 for (dgst_num = 0; dgst_num < SSL_MAX_DIGEST; dgst_num++)
                     if (s->s3->handshake_dgst[dgst_num]) {
                         int dgst_size;
@@ -655,6 +671,7 @@ int ssl3_accept(SSL *s)
                         dgst_size =
                             EVP_MD_CTX_size(s->s3->handshake_dgst[dgst_num]);
                         if (dgst_size < 0) {
+                            s->state = SSL_ST_ERR;
                             ret = -1;
                             goto end;
                         }
@@ -769,6 +786,7 @@ int ssl3_accept(SSL *s)
             s->session->cipher = s->s3->tmp.new_cipher;
             if (!s->method->ssl3_enc->setup_key_block(s)) {
                 ret = -1;
+                s->state = SSL_ST_ERR;
                 goto end;
             }
 
@@ -785,6 +803,7 @@ int ssl3_accept(SSL *s)
                                                           SSL3_CHANGE_CIPHER_SERVER_WRITE))
             {
                 ret = -1;
+                s->state = SSL_ST_ERR;
                 goto end;
             }
 
@@ -847,6 +866,7 @@ int ssl3_accept(SSL *s)
             goto end;
             /* break; */
 
+        case SSL_ST_ERR:
         default:
             SSLerr(SSL_F_SSL3_ACCEPT, SSL_R_UNKNOWN_STATE);
             ret = -1;
@@ -1444,8 +1464,10 @@ int ssl3_get_client_hello(SSL *s)
     if (0) {
  f_err:
         ssl3_send_alert(s, SSL3_AL_FATAL, al);
-    }
  err:
+        s->state = SSL_ST_ERR;
+    }
+
     sk_SSL_CIPHER_free(ciphers);
     return ret < 0 ? -1 : ret;
 }
@@ -1462,8 +1484,10 @@ int ssl3_send_server_hello(SSL *s)
         buf = (unsigned char *)s->init_buf->data;
 #ifdef OPENSSL_NO_TLSEXT
         p = s->s3->server_random;
-        if (ssl_fill_hello_random(s, 1, p, SSL3_RANDOM_SIZE) <= 0)
+        if (ssl_fill_hello_random(s, 1, p, SSL3_RANDOM_SIZE) <= 0) {
+            s->state = SSL_ST_ERR;
             return -1;
+        }
 #endif
         /* Do the message type and length last */
         d = p = ssl_handshake_start(s);
@@ -1499,6 +1523,7 @@ int ssl3_send_server_hello(SSL *s)
         sl = s->session->session_id_length;
         if (sl > (int)sizeof(s->session->session_id)) {
             SSLerr(SSL_F_SSL3_SEND_SERVER_HELLO, ERR_R_INTERNAL_ERROR);
+            s->state = SSL_ST_ERR;
             return -1;
         }
         *(p++) = sl;
@@ -1521,6 +1546,7 @@ int ssl3_send_server_hello(SSL *s)
 #ifndef OPENSSL_NO_TLSEXT
         if (ssl_prepare_serverhello_tlsext(s) <= 0) {
             SSLerr(SSL_F_SSL3_SEND_SERVER_HELLO, SSL_R_SERVERHELLO_TLSEXT);
+            s->state = SSL_ST_ERR;
             return -1;
         }
         if ((p =
@@ -1528,6 +1554,7 @@ int ssl3_send_server_hello(SSL *s)
                                         &al)) == NULL) {
             ssl3_send_alert(s, SSL3_AL_FATAL, al);
             SSLerr(SSL_F_SSL3_SEND_SERVER_HELLO, ERR_R_INTERNAL_ERROR);
+            s->state = SSL_ST_ERR;
             return -1;
         }
 #endif
@@ -2016,6 +2043,7 @@ int ssl3_send_server_key_exchange(SSL *s)
     BN_CTX_free(bn_ctx);
 #endif
     EVP_MD_CTX_cleanup(&md_ctx);
+    s->state = SSL_ST_ERR;
     return (-1);
 }
 
@@ -2090,6 +2118,7 @@ int ssl3_send_certificate_request(SSL *s)
     /* SSL3_ST_SW_CERT_REQ_B */
     return ssl_do_write(s);
  err:
+    s->state = SSL_ST_ERR;
     return (-1);
 }
 
@@ -2916,6 +2945,7 @@ int ssl3_get_client_key_exchange(SSL *s)
     EC_KEY_free(srvr_ecdh);
     BN_CTX_free(bn_ctx);
 #endif
+    s->state = SSL_ST_ERR;
     return (-1);
 }
 
@@ -3118,6 +3148,7 @@ int ssl3_get_cert_verify(SSL *s)
     if (0) {
  f_err:
         ssl3_send_alert(s, SSL3_AL_FATAL, al);
+        s->state = SSL_ST_ERR;
     }
  end:
     BIO_free(s->s3->handshake_buffer);
@@ -3286,6 +3317,7 @@ int ssl3_get_client_certificate(SSL *s)
  f_err:
     ssl3_send_alert(s, SSL3_AL_FATAL, al);
  done:
+    s->state = SSL_ST_ERR;
     X509_free(x);
     sk_X509_pop_free(sk, X509_free);
     return (ret);
@@ -3303,12 +3335,14 @@ int ssl3_send_server_certificate(SSL *s)
                 (s->s3->tmp.new_cipher->algorithm_mkey & SSL_kKRB5)) {
                 SSLerr(SSL_F_SSL3_SEND_SERVER_CERTIFICATE,
                        ERR_R_INTERNAL_ERROR);
+                s->state = SSL_ST_ERR;
                 return (0);
             }
         }
 
         if (!ssl3_output_cert_chain(s, cpk)) {
             SSLerr(SSL_F_SSL3_SEND_SERVER_CERTIFICATE, ERR_R_INTERNAL_ERROR);
+            s->state = SSL_ST_ERR;
             return (0);
         }
         s->state = SSL3_ST_SW_CERT_B;
@@ -3342,11 +3376,15 @@ int ssl3_send_newsession_ticket(SSL *s)
          * Some length values are 16 bits, so forget it if session is too
          * long
          */
-        if (slen_full == 0 || slen_full > 0xFF00)
+        if (slen_full == 0 || slen_full > 0xFF00) {
+            s->state = SSL_ST_ERR;
             return -1;
+        }
         senc = OPENSSL_malloc(slen_full);
-        if (!senc)
+        if (!senc) {
+            s->state = SSL_ST_ERR;
             return -1;
+        }
 
         EVP_CIPHER_CTX_init(&ctx);
         HMAC_CTX_init(&hctx);
@@ -3461,6 +3499,7 @@ int ssl3_send_newsession_ticket(SSL *s)
     OPENSSL_free(senc);
     EVP_CIPHER_CTX_cleanup(&ctx);
     HMAC_CTX_cleanup(&hctx);
+    s->state = SSL_ST_ERR;
     return -1;
 }
 
@@ -3474,8 +3513,10 @@ int ssl3_send_cert_status(SSL *s)
          * 1 (ocsp response type) + 3 (ocsp response length)
          * + (ocsp response)
          */
-        if (!BUF_MEM_grow(s->init_buf, 8 + s->tlsext_ocsp_resplen))
+        if (!BUF_MEM_grow(s->init_buf, 8 + s->tlsext_ocsp_resplen)) {
+            s->state = SSL_ST_ERR;
             return -1;
+        }
 
         p = (unsigned char *)s->init_buf->data;
 
@@ -3518,6 +3559,7 @@ int ssl3_get_next_proto(SSL *s)
     if (!s->s3->next_proto_neg_seen) {
         SSLerr(SSL_F_SSL3_GET_NEXT_PROTO,
                SSL_R_GOT_NEXT_PROTO_WITHOUT_EXTENSION);
+        s->state = SSL_ST_ERR;
         return -1;
     }
 
@@ -3537,11 +3579,14 @@ int ssl3_get_next_proto(SSL *s)
      */
     if (!s->s3->change_cipher_spec) {
         SSLerr(SSL_F_SSL3_GET_NEXT_PROTO, SSL_R_GOT_NEXT_PROTO_BEFORE_A_CCS);
+        s->state = SSL_ST_ERR;
         return -1;
     }
 
-    if (n < 2)
+    if (n < 2) {
+        s->state = SSL_ST_ERR;
         return 0;               /* The body must be > 1 bytes long */
+    }
 
     p = (unsigned char *)s->init_msg;
 
@@ -3553,15 +3598,20 @@ int ssl3_get_next_proto(SSL *s)
      *   uint8 padding[padding_len];
      */
     proto_len = p[0];
-    if (proto_len + 2 > s->init_num)
+    if (proto_len + 2 > s->init_num) {
+        s->state = SSL_ST_ERR;
         return 0;
+    }
     padding_len = p[proto_len + 1];
-    if (proto_len + padding_len + 2 != s->init_num)
+    if (proto_len + padding_len + 2 != s->init_num) {
+        s->state = SSL_ST_ERR;
         return 0;
+    }
 
     s->next_proto_negotiated = OPENSSL_malloc(proto_len);
     if (!s->next_proto_negotiated) {
         SSLerr(SSL_F_SSL3_GET_NEXT_PROTO, ERR_R_MALLOC_FAILURE);
+        s->state = SSL_ST_ERR;
         return 0;
     }
     memcpy(s->next_proto_negotiated, p + 1, proto_len);