Resolve a TODO in ssl3_dispatch_alert
[openssl.git] / ssl / record / rec_layer_s3.c
index 301742dba590dfb74b9f841f4d769cc5aed6b40d..b4435bf0201b3458ad23760ad3a1ac9d90ff52af 100644 (file)
@@ -33,8 +33,6 @@ void RECORD_LAYER_clear(RECORD_LAYER *rl)
     rl->wpend_ret = 0;
     rl->wpend_buf = NULL;
 
-    RECORD_LAYER_reset_write_sequence(rl);
-
     if (rl->rrlmethod != NULL)
         rl->rrlmethod->free(rl->rrl); /* Ignore return value */
     if (rl->wrlmethod != NULL)
@@ -50,14 +48,6 @@ void RECORD_LAYER_clear(RECORD_LAYER *rl)
         DTLS_RECORD_LAYER_clear(rl);
 }
 
-void RECORD_LAYER_release(RECORD_LAYER *rl)
-{
-    /*
-     * TODO(RECLAYER): Need a way to release the write buffers in the record
-     * layer on demand
-     */
-}
-
 /* Checks if we have unprocessed read ahead data pending */
 int RECORD_LAYER_read_pending(const RECORD_LAYER *rl)
 {
@@ -76,9 +66,60 @@ int RECORD_LAYER_write_pending(const RECORD_LAYER *rl)
     return rl->wpend_tot > 0;
 }
 
-void RECORD_LAYER_reset_write_sequence(RECORD_LAYER *rl)
+static uint32_t ossl_get_max_early_data(SSL_CONNECTION *s)
+{
+    uint32_t max_early_data;
+    SSL_SESSION *sess = s->session;
+
+    /*
+     * If we are a client then we always use the max_early_data from the
+     * session/psksession. Otherwise we go with the lowest out of the max early
+     * data set in the session and the configured max_early_data.
+     */
+    if (!s->server && sess->ext.max_early_data == 0) {
+        if (!ossl_assert(s->psksession != NULL
+                         && s->psksession->ext.max_early_data > 0)) {
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+            return 0;
+        }
+        sess = s->psksession;
+    }
+
+    if (!s->server)
+        max_early_data = sess->ext.max_early_data;
+    else if (s->ext.early_data != SSL_EARLY_DATA_ACCEPTED)
+        max_early_data = s->recv_max_early_data;
+    else
+        max_early_data = s->recv_max_early_data < sess->ext.max_early_data
+                         ? s->recv_max_early_data : sess->ext.max_early_data;
+
+    return max_early_data;
+}
+
+static int ossl_early_data_count_ok(SSL_CONNECTION *s, size_t length,
+                                    size_t overhead, int send)
 {
-    memset(rl->write_sequence, 0, sizeof(rl->write_sequence));
+    uint32_t max_early_data;
+
+    max_early_data = ossl_get_max_early_data(s);
+
+    if (max_early_data == 0) {
+        SSLfatal(s, send ? SSL_AD_INTERNAL_ERROR : SSL_AD_UNEXPECTED_MESSAGE,
+                 SSL_R_TOO_MUCH_EARLY_DATA);
+        return 0;
+    }
+
+    /* If we are dealing with ciphertext we need to allow for the overhead */
+    max_early_data += overhead;
+
+    if (s->early_data_count + length > max_early_data) {
+        SSLfatal(s, send ? SSL_AD_INTERNAL_ERROR : SSL_AD_UNEXPECTED_MESSAGE,
+                 SSL_R_TOO_MUCH_EARLY_DATA);
+        return 0;
+    }
+    s->early_data_count += length;
+
+    return 1;
 }
 
 size_t ssl3_pending(const SSL *s)
@@ -274,7 +315,7 @@ int ssl3_write_bytes(SSL *ssl, int type, const void *buf_, size_t len,
     }
 
     /* If we have an alert to send, lets send it */
-    if (s->s3.alert_dispatch) {
+    if (s->s3.alert_dispatch > 0) {
         i = ssl->method->ssl_dispatch_alert(ssl);
         if (i <= 0) {
             /* SSLfatal() already called if appropriate */
@@ -637,7 +678,7 @@ int ssl3_read_bytes(SSL *ssl, int type, int *recvd_type, unsigned char *buf,
          * doing a handshake for the first time
          */
         if (SSL_in_init(ssl) && type == SSL3_RT_APPLICATION_DATA
-            && s->enc_read_ctx == NULL) {
+                && SSL_IS_FIRST_HANDSHAKE(s)) {
             SSLfatal(s, SSL_AD_UNEXPECTED_MESSAGE, SSL_R_APP_DATA_IN_HANDSHAKE);
             return -1;
         }
@@ -1220,8 +1261,8 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
         max_early_data = ossl_get_max_early_data(s);
 
         if (max_early_data != 0)
-            *set++ = OSSL_PARAM_construct_uint(OSSL_LIBSSL_RECORD_LAYER_PARAM_MAX_EARLY_DATA,
-                                               &max_early_data);
+            *set++ = OSSL_PARAM_construct_uint32(OSSL_LIBSSL_RECORD_LAYER_PARAM_MAX_EARLY_DATA,
+                                                 &max_early_data);
     }
 
     *set = OSSL_PARAM_construct_end();
@@ -1315,11 +1356,14 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
 
     /*
      * Free the old record layer if we have one except in the case of DTLS when
-     * writing. In that case the record layer is still referenced by buffered
-     * messages for potential retransmit. Only when those buffered messages get
-     * freed do we free the record layer object (see dtls1_hm_fragment_free)
+     * writing and there are still buffered sent messages in our queue. In that
+     * case the record layer is still referenced by those buffered messages for
+     * potential retransmit. Only when those buffered messages get freed do we
+     * free the record layer object (see dtls1_hm_fragment_free)
      */
-    if (!SSL_CONNECTION_IS_DTLS(s) || direction == OSSL_RECORD_DIRECTION_READ) {
+    if (!SSL_CONNECTION_IS_DTLS(s)
+            || direction == OSSL_RECORD_DIRECTION_READ
+            || pqueue_peek(s->d1->sent_messages) == NULL) {
         if (*thismethod != NULL && !(*thismethod)->free(*thisrl)) {
             SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
             return 0;