Tolerate fragmentation and interleaving in the SSL 3/TLS record layer.
[openssl.git] / ssl / s3_both.c
index 9b6766322e2151e5745dbb0c277471df833ea5ec..6236b74572714c9d581c4dc812f59859e64ca8ff 100644 (file)
@@ -123,7 +123,7 @@ int ssl3_get_finished(SSL *s, int a, int b)
 
        if (!ok) return((int)n);
 
-       /* If this occurs if we has missed a message */
+       /* If this occurs, we have missed a message */
        if (!s->s3->change_cipher_spec)
                {
                al=SSL_AD_UNEXPECTED_MESSAGE;
@@ -283,16 +283,22 @@ long ssl3_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
 
        p=(unsigned char *)s->init_buf->data;
 
-       if (s->state == st1)
+       if (s->state == st1) /* s->init_num < 4 */
                {
-               i=ssl3_read_bytes(s,SSL3_RT_HANDSHAKE,&p[s->init_num],
-                                 4-s->init_num);
-               if (i < (4-s->init_num))
+               while (s->init_num < 4)
                        {
-                       *ok=0;
-                       return(ssl3_part_read(s,i));
+                       i=ssl3_read_bytes(s,SSL3_RT_HANDSHAKE,&p[s->init_num],
+                               4-s->init_num);
+                       if (i <= 0)
+                               {
+                               s->rwstate=SSL_READING;
+                               *ok = 0;
+                               return i;
+                               }
+                       s->init_num+=i;
                        }
 
+/* XXX server may always send Hello Request */
                if ((mt >= 0) && (*p != mt))
                        {
                        al=SSL_AD_UNEXPECTED_MESSAGE;
@@ -334,17 +340,20 @@ long ssl3_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
        /* next state (stn) */
        p=(unsigned char *)s->init_buf->data;
        n=s->s3->tmp.message_size;
-       if (n > 0)
+       while (n > 0)
                {
                i=ssl3_read_bytes(s,SSL3_RT_HANDSHAKE,&p[s->init_num],n);
-               if (i != (int)n)
+               if (i <= 0)
                        {
-                       *ok=0;
-                       return(ssl3_part_read(s,i));
+                       s->rwstate=SSL_READING;
+                       *ok = 0;
+                       return i;
                        }
+               s->init_num += i;
+               n -= i;
                }
        *ok=1;
-       return(n);
+       return s->init_num;
 f_err:
        ssl3_send_alert(s,SSL3_AL_FATAL,al);
 err:
@@ -465,7 +474,7 @@ int ssl3_setup_buffers(SSL *s)
                        extra=SSL3_RT_MAX_EXTRA;
                else
                        extra=0;
-               if ((p=(unsigned char *)Malloc(SSL3_RT_MAX_PACKET_SIZE+extra))
+               if ((p=Malloc(SSL3_RT_MAX_PACKET_SIZE+extra))
                        == NULL)
                        goto err;
                s->s3->rbuf.buf=p;
@@ -473,7 +482,7 @@ int ssl3_setup_buffers(SSL *s)
 
        if (s->s3->wbuf.buf == NULL)
                {
-               if ((p=(unsigned char *)Malloc(SSL3_RT_MAX_PACKET_SIZE))
+               if ((p=Malloc(SSL3_RT_MAX_PACKET_SIZE))
                        == NULL)
                        goto err;
                s->s3->wbuf.buf=p;