Rename SSL_write_early() to SSL_write_early_data()
[openssl.git] / ssl / ssl_lib.c
index e3e7853d602997103ff1365d4ee8b8c38ef406f9..c3496e7b48e85ff57868f85179a54b6033a604e0 100644 (file)
@@ -105,6 +105,8 @@ static const struct {
     },
 };
 
+static int ssl_write_early_finish(SSL *s);
+
 static int dane_ctx_enable(struct dane_ctx_st *dctx)
 {
     const EVP_MD **mdevp;
@@ -1545,6 +1547,17 @@ int ssl_read_internal(SSL *s, void *buf, size_t num, size_t *readbytes)
         return 0;
     }
 
+    if (s->early_data_state == SSL_EARLY_DATA_CONNECT_RETRY
+                || s->early_data_state == SSL_EARLY_DATA_ACCEPT_RETRY) {
+        SSLerr(SSL_F_SSL_READ_INTERNAL, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        return 0;
+    }
+    /*
+     * If we are a client and haven't received the ServerHello etc then we
+     * better do that
+     */
+    ossl_statem_check_finish_init(s, 0);
+
     if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
         struct ssl_async_args args;
         int ret;
@@ -1594,25 +1607,21 @@ int SSL_read_ex(SSL *s, void *buf, size_t num, size_t *readbytes)
     return ret;
 }
 
-int SSL_read_early(SSL *s, void *buf, size_t num, size_t *readbytes)
+int SSL_read_early_data(SSL *s, void *buf, size_t num, size_t *readbytes)
 {
     int ret;
 
     if (!s->server) {
-        SSLerr(SSL_F_SSL_READ_EARLY, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
-        return SSL_READ_EARLY_ERROR;
+        SSLerr(SSL_F_SSL_READ_EARLY_DATA, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        return SSL_READ_EARLY_DATA_ERROR;
     }
 
-    /*
-     * TODO(TLS1.3): Somehow we need to check that we're not receiving too much
-     * data
-     */
-
     switch (s->early_data_state) {
     case SSL_EARLY_DATA_NONE:
         if (!SSL_in_before(s)) {
-            SSLerr(SSL_F_SSL_READ_EARLY, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
-            return SSL_READ_EARLY_ERROR;
+            SSLerr(SSL_F_SSL_READ_EARLY_DATA,
+                   ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+            return SSL_READ_EARLY_DATA_ERROR;
         }
         /* fall through */
 
@@ -1622,7 +1631,7 @@ int SSL_read_early(SSL *s, void *buf, size_t num, size_t *readbytes)
         if (ret <= 0) {
             /* NBIO or error */
             s->early_data_state = SSL_EARLY_DATA_ACCEPT_RETRY;
-            return SSL_READ_EARLY_ERROR;
+            return SSL_READ_EARLY_DATA_ERROR;
         }
         /* fall through */
 
@@ -1638,31 +1647,38 @@ int SSL_read_early(SSL *s, void *buf, size_t num, size_t *readbytes)
             if (ret > 0 || (ret <= 0 && s->early_data_state
                                         != SSL_EARLY_DATA_FINISHED_READING)) {
                 s->early_data_state = SSL_EARLY_DATA_READ_RETRY;
-                return ret > 0 ? SSL_READ_EARLY_SUCCESS : SSL_READ_EARLY_ERROR;
+                return ret > 0 ? SSL_READ_EARLY_DATA_SUCCESS
+                               : SSL_READ_EARLY_DATA_ERROR;
             }
         } else {
             s->early_data_state = SSL_EARLY_DATA_FINISHED_READING;
         }
         *readbytes = 0;
-        ossl_statem_set_in_init(s, 1);
-        return SSL_READ_EARLY_FINISH;
+        return SSL_READ_EARLY_DATA_FINISH;
 
     default:
-        SSLerr(SSL_F_SSL_READ_EARLY, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
-        return SSL_READ_EARLY_ERROR;
+        SSLerr(SSL_F_SSL_READ_EARLY_DATA, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        return SSL_READ_EARLY_DATA_ERROR;
     }
 }
 
 int ssl_end_of_early_data_seen(SSL *s)
 {
-    if (s->early_data_state == SSL_EARLY_DATA_READING) {
+    if (s->early_data_state == SSL_EARLY_DATA_READING
+            || s->early_data_state == SSL_EARLY_DATA_READ_RETRY) {
         s->early_data_state = SSL_EARLY_DATA_FINISHED_READING;
+        ossl_statem_finish_early_data(s);
         return 1;
     }
 
     return 0;
 }
 
+int SSL_get_early_data_status(const SSL *s)
+{
+    return s->ext.early_data;
+}
+
 static int ssl_peek_internal(SSL *s, void *buf, size_t num, size_t *readbytes)
 {
     if (s->handshake_func == NULL) {
@@ -1736,9 +1752,20 @@ int ssl_write_internal(SSL *s, const void *buf, size_t num, size_t *written)
         return -1;
     }
 
-    if (s->early_data_state == SSL_EARLY_DATA_WRITE_RETRY
-            || s->early_data_state == SSL_EARLY_DATA_CONNECT_RETRY)
+    if (s->early_data_state == SSL_EARLY_DATA_WRITE_RETRY) {
+        /*
+         * We're still writing early data. We need to stop that so we can write
+         * normal data
+         */
+        if (!ssl_write_early_finish(s))
+            return 0;
+    } else if (s->early_data_state == SSL_EARLY_DATA_CONNECT_RETRY
+                || s->early_data_state == SSL_EARLY_DATA_ACCEPT_RETRY) {
+        SSLerr(SSL_F_SSL_WRITE_INTERNAL, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
         return 0;
+    }
+    /* If we are a client and haven't sent the Finished we better do that */
+    ossl_statem_check_finish_init(s, 1);
 
     if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
         int ret;
@@ -1789,7 +1816,7 @@ int SSL_write_ex(SSL *s, const void *buf, size_t num, size_t *written)
     return ret;
 }
 
-int SSL_write_early(SSL *s, const void *buf, size_t num, size_t *written)
+int SSL_write_early_data(SSL *s, const void *buf, size_t num, size_t *written)
 {
     int ret;
 
@@ -1798,14 +1825,11 @@ int SSL_write_early(SSL *s, const void *buf, size_t num, size_t *written)
         return 0;
     }
 
-    /*
-     * TODO(TLS1.3): Somehow we need to check that we're not sending too much
-     * data
-     */
-
     switch (s->early_data_state) {
     case SSL_EARLY_DATA_NONE:
-        if (!SSL_in_before(s)) {
+        if (!SSL_in_before(s)
+                || s->session == NULL
+                || s->session->ext.max_early_data == 0) {
             SSLerr(SSL_F_SSL_WRITE_EARLY, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
             return 0;
         }
@@ -1833,7 +1857,7 @@ int SSL_write_early(SSL *s, const void *buf, size_t num, size_t *written)
     }
 }
 
-int SSL_write_early_finish(SSL *s)
+static int ssl_write_early_finish(SSL *s)
 {
     int ret;
 
@@ -2766,6 +2790,12 @@ SSL_CTX *SSL_CTX_new(const SSL_METHOD *meth)
 
     ret->ext.status_type = TLSEXT_STATUSTYPE_nothing;
 
+    /*
+     * Default max early data is a fully loaded single record. Could be split
+     * across multiple records in practice
+     */
+    ret->max_early_data = SSL3_RT_MAX_PLAIN_LENGTH;
+
     return ret;
  err:
     SSLerr(SSL_F_SSL_CTX_NEW, ERR_R_MALLOC_FAILURE);
@@ -3216,9 +3246,14 @@ int SSL_do_handshake(SSL *s)
         return -1;
     }
 
-    if (s->early_data_state == SSL_EARLY_DATA_WRITE_RETRY
-            || s->early_data_state == SSL_EARLY_DATA_CONNECT_RETRY)
-        return -1;
+    if (s->early_data_state == SSL_EARLY_DATA_WRITE_RETRY) {
+        int edfin;
+
+        edfin = ssl_write_early_finish(s);
+        if (edfin <= 0)
+            return edfin;
+    }
+    ossl_statem_check_finish_init(s, -1);
 
     s->method->ssl_renegotiate_check(s, 0);
 
@@ -4813,7 +4848,7 @@ int SSL_CTX_set_max_early_data(SSL_CTX *ctx, uint32_t max_early_data)
     return 1;
 }
 
-uint32_t SSL_CTX_get_max_early_data(SSL_CTX *ctx)
+uint32_t SSL_CTX_get_max_early_data(const SSL_CTX *ctx)
 {
     return ctx->max_early_data;
 }
@@ -4825,7 +4860,7 @@ int SSL_set_max_early_data(SSL *s, uint32_t max_early_data)
     return 1;
 }
 
-uint32_t SSL_get_max_early_data(SSL_CTX *s)
+uint32_t SSL_get_max_early_data(const SSL_CTX *s)
 {
     return s->max_early_data;
 }