Move init of the WPACKET into write_state_machine()
[openssl.git] / ssl / statem / statem.c
index df3008575d19604ad0f3702f7ff489d3ac0d69dd..1ad421b07b942797a3a4f0f1b4cc4fd0b3978fd3 100644 (file)
@@ -445,6 +445,21 @@ static void init_read_state_machine(SSL *s)
     st->read_state = READ_STATE_HEADER;
 }
 
+static int grow_init_buf(SSL *s, size_t size) {
+
+    size_t msg_offset = (char *)s->init_msg - s->init_buf->data;
+
+    if (!BUF_MEM_grow_clean(s->init_buf, (int)size))
+        return 0;
+
+    if (size < msg_offset)
+        return 0;
+
+    s->init_msg = s->init_buf->data + msg_offset;
+
+    return 1;
+}
+
 /*
  * This function implements the sub-state machine when the message flow is in
  * MSG_FLOW_READING. The valid sub-states and transitions are:
@@ -542,6 +557,16 @@ static SUB_STATE_RETURN read_state_machine(SSL *s)
                 return SUB_STATE_ERROR;
             }
 
+            /* dtls_get_message already did this */
+            if (!SSL_IS_DTLS(s)
+                    && s->s3->tmp.message_size > 0
+                    && !grow_init_buf(s, s->s3->tmp.message_size
+                                         + SSL3_HM_HEADER_LENGTH)) {
+                ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
+                SSLerr(SSL_F_READ_STATE_MACHINE, ERR_R_BUF_LIB);
+                return SUB_STATE_ERROR;
+            }
+
             st->read_state = READ_STATE_BODY;
             /* Fall through */
 
@@ -590,7 +615,9 @@ static SUB_STATE_RETURN read_state_machine(SSL *s)
         case READ_STATE_POST_PROCESS:
             st->read_state_work = post_process_message(s, st->read_state_work);
             switch (st->read_state_work) {
-            default:
+            case WORK_ERROR:
+            case WORK_MORE_A:
+            case WORK_MORE_B:
                 return SUB_STATE_ERROR;
 
             case WORK_FINISHED_CONTINUE:
@@ -681,8 +708,9 @@ static SUB_STATE_RETURN write_state_machine(SSL *s)
     WRITE_TRAN(*transition) (SSL *s);
     WORK_STATE(*pre_work) (SSL *s, WORK_STATE wst);
     WORK_STATE(*post_work) (SSL *s, WORK_STATE wst);
-    int (*construct_message) (SSL *s);
+    int (*construct_message) (SSL *s, WPACKET *pkt);
     void (*cb) (const SSL *ssl, int type, int val) = NULL;
+    WPACKET pkt;
 
     cb = get_callback(s);
 
@@ -718,14 +746,16 @@ static SUB_STATE_RETURN write_state_machine(SSL *s)
                 return SUB_STATE_FINISHED;
                 break;
 
-            default:
+            case WRITE_TRAN_ERROR:
                 return SUB_STATE_ERROR;
             }
             break;
 
         case WRITE_STATE_PRE_WORK:
             switch (st->write_state_work = pre_work(s, st->write_state_work)) {
-            default:
+            case WORK_ERROR:
+            case WORK_MORE_A:
+            case WORK_MORE_B:
                 return SUB_STATE_ERROR;
 
             case WORK_FINISHED_CONTINUE:
@@ -735,8 +765,13 @@ static SUB_STATE_RETURN write_state_machine(SSL *s)
             case WORK_FINISHED_STOP:
                 return SUB_STATE_END_HANDSHAKE;
             }
-            if (construct_message(s) == 0)
+            if (!WPACKET_init(&pkt, s->init_buf)
+                    || !construct_message(s, &pkt)
+                    || !WPACKET_finish(&pkt)) {
+                WPACKET_cleanup(&pkt);
+                ossl_statem_set_error(s);
                 return SUB_STATE_ERROR;
+            }
 
             /* Fall through */
 
@@ -754,7 +789,9 @@ static SUB_STATE_RETURN write_state_machine(SSL *s)
 
         case WRITE_STATE_POST_WORK:
             switch (st->write_state_work = post_work(s, st->write_state_work)) {
-            default:
+            case WORK_ERROR:
+            case WORK_MORE_A:
+            case WORK_MORE_B:
                 return SUB_STATE_ERROR;
 
             case WORK_FINISHED_CONTINUE: