Tighten sanity checks when calling early data functions
[openssl.git] / ssl / ssl_lib.c
index d3dddafad7bf6269dad45eaf5acecb712ca68546..b675c2eeadca335093234bbfa76509ae79c45082 100644 (file)
@@ -1545,6 +1545,14 @@ int ssl_read_internal(SSL *s, void *buf, size_t num, size_t *readbytes)
         return 0;
     }
 
+    if (s->early_data_state != SSL_EARLY_DATA_NONE
+            && s->early_data_state != SSL_EARLY_DATA_FINISHED_WRITING
+            && s->early_data_state != SSL_EARLY_DATA_FINISHED_READING
+            && s->early_data_state != SSL_EARLY_DATA_READING) {
+        SSLerr(SSL_F_SSL_READ_INTERNAL, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        return 0;
+    }
+
     if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
         struct ssl_async_args args;
         int ret;
@@ -1594,6 +1602,76 @@ int SSL_read_ex(SSL *s, void *buf, size_t num, size_t *readbytes)
     return ret;
 }
 
+int SSL_read_early(SSL *s, void *buf, size_t num, size_t *readbytes)
+{
+    int ret;
+
+    if (!s->server) {
+        SSLerr(SSL_F_SSL_READ_EARLY, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        return SSL_READ_EARLY_ERROR;
+    }
+
+    switch (s->early_data_state) {
+    case SSL_EARLY_DATA_NONE:
+        if (!SSL_in_before(s)) {
+            SSLerr(SSL_F_SSL_READ_EARLY, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+            return SSL_READ_EARLY_ERROR;
+        }
+        /* fall through */
+
+    case SSL_EARLY_DATA_ACCEPT_RETRY:
+        s->early_data_state = SSL_EARLY_DATA_ACCEPTING;
+        ret = SSL_accept(s);
+        if (ret <= 0) {
+            /* NBIO or error */
+            s->early_data_state = SSL_EARLY_DATA_ACCEPT_RETRY;
+            return SSL_READ_EARLY_ERROR;
+        }
+        /* fall through */
+
+    case SSL_EARLY_DATA_READ_RETRY:
+        if (s->ext.early_data == SSL_EARLY_DATA_ACCEPTED) {
+            s->early_data_state = SSL_EARLY_DATA_READING;
+            ret = SSL_read_ex(s, buf, num, readbytes);
+            /*
+             * Record layer will call ssl_end_of_early_data_seen() if we see
+             * that alert - which updates the early_data_state to
+             * SSL_EARLY_DATA_FINISHED_READING
+             */
+            if (ret > 0 || (ret <= 0 && s->early_data_state
+                                        != SSL_EARLY_DATA_FINISHED_READING)) {
+                s->early_data_state = SSL_EARLY_DATA_READ_RETRY;
+                return ret > 0 ? SSL_READ_EARLY_SUCCESS : SSL_READ_EARLY_ERROR;
+            }
+        } else {
+            s->early_data_state = SSL_EARLY_DATA_FINISHED_READING;
+        }
+        *readbytes = 0;
+        ossl_statem_set_in_init(s, 1);
+        return SSL_READ_EARLY_FINISH;
+
+    default:
+        SSLerr(SSL_F_SSL_READ_EARLY, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        return SSL_READ_EARLY_ERROR;
+    }
+}
+
+int ssl_end_of_early_data_seen(SSL *s)
+{
+    if (s->early_data_state == SSL_EARLY_DATA_READING) {
+        s->early_data_state = SSL_EARLY_DATA_FINISHED_READING;
+        ossl_statem_finish_early_data(s);
+        return 1;
+    }
+
+    return 0;
+}
+
+int SSL_get_early_data_status(const SSL *s)
+{
+    return s->ext.early_data;
+}
+
 static int ssl_peek_internal(SSL *s, void *buf, size_t num, size_t *readbytes)
 {
     if (s->handshake_func == NULL) {
@@ -1667,6 +1745,14 @@ int ssl_write_internal(SSL *s, const void *buf, size_t num, size_t *written)
         return -1;
     }
 
+    if (s->early_data_state != SSL_EARLY_DATA_NONE
+            && s->early_data_state != SSL_EARLY_DATA_FINISHED_WRITING
+            && s->early_data_state != SSL_EARLY_DATA_FINISHED_READING
+            && s->early_data_state != SSL_EARLY_DATA_WRITING) {
+        SSLerr(SSL_F_SSL_WRITE_INTERNAL, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        return 0;
+    }
+
     if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
         int ret;
         struct ssl_async_args args;
@@ -1716,6 +1802,73 @@ int SSL_write_ex(SSL *s, const void *buf, size_t num, size_t *written)
     return ret;
 }
 
+int SSL_write_early(SSL *s, const void *buf, size_t num, size_t *written)
+{
+    int ret;
+
+    if (s->server) {
+        SSLerr(SSL_F_SSL_WRITE_EARLY, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        return 0;
+    }
+
+    switch (s->early_data_state) {
+    case SSL_EARLY_DATA_NONE:
+        if (!SSL_in_before(s)
+                || s->session == NULL
+                || s->session->ext.max_early_data == 0) {
+            SSLerr(SSL_F_SSL_WRITE_EARLY, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+            return 0;
+        }
+        /* fall through */
+
+    case SSL_EARLY_DATA_CONNECT_RETRY:
+        s->early_data_state = SSL_EARLY_DATA_CONNECTING;
+        ret = SSL_connect(s);
+        if (ret <= 0) {
+            /* NBIO or error */
+            s->early_data_state = SSL_EARLY_DATA_CONNECT_RETRY;
+            return 0;
+        }
+        /* fall through */
+
+    case SSL_EARLY_DATA_WRITE_RETRY:
+        s->early_data_state = SSL_EARLY_DATA_WRITING;
+        ret = SSL_write_ex(s, buf, num, written);
+        s->early_data_state = SSL_EARLY_DATA_WRITE_RETRY;
+        return ret;
+
+    default:
+        SSLerr(SSL_F_SSL_WRITE_EARLY, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        return 0;
+    }
+}
+
+int SSL_write_early_finish(SSL *s)
+{
+    int ret;
+
+    if (s->early_data_state != SSL_EARLY_DATA_WRITE_RETRY) {
+        SSLerr(SSL_F_SSL_WRITE_EARLY_FINISH, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        return 0;
+    }
+
+    s->early_data_state = SSL_EARLY_DATA_WRITING;
+    ret = ssl3_send_alert(s, SSL3_AL_WARNING, SSL_AD_END_OF_EARLY_DATA);
+    if (ret <= 0) {
+        s->early_data_state = SSL_EARLY_DATA_WRITE_RETRY;
+        return 0;
+    }
+    s->early_data_state = SSL_EARLY_DATA_FINISHED_WRITING;
+    /*
+     * We set the enc_write_ctx back to NULL because we may end up writing
+     * in cleartext again if we get a HelloRetryRequest from the server.
+     */
+    EVP_CIPHER_CTX_free(s->enc_write_ctx);
+    s->enc_write_ctx = NULL;
+    ossl_statem_set_in_init(s, 1);
+    return 1;
+}
+
 int SSL_shutdown(SSL *s)
 {
     /*
@@ -2623,6 +2776,12 @@ SSL_CTX *SSL_CTX_new(const SSL_METHOD *meth)
 
     ret->ext.status_type = TLSEXT_STATUSTYPE_nothing;
 
+    /*
+     * Default max early data is a fully loaded single record. Could be split
+     * across multiple records in practice
+     */
+    ret->max_early_data = SSL3_RT_MAX_PLAIN_LENGTH;
+
     return ret;
  err:
     SSLerr(SSL_F_SSL_CTX_NEW, ERR_R_MALLOC_FAILURE);
@@ -3073,6 +3232,10 @@ int SSL_do_handshake(SSL *s)
         return -1;
     }
 
+    if (s->early_data_state == SSL_EARLY_DATA_WRITE_RETRY
+            || s->early_data_state == SSL_EARLY_DATA_CONNECT_RETRY)
+        return -1;
+
     s->method->ssl_renegotiate_check(s, 0);
 
     if (SSL_in_init(s) || SSL_in_before(s)) {
@@ -4666,7 +4829,7 @@ int SSL_CTX_set_max_early_data(SSL_CTX *ctx, uint32_t max_early_data)
     return 1;
 }
 
-uint32_t SSL_CTX_get_max_early_data(SSL_CTX *ctx)
+uint32_t SSL_CTX_get_max_early_data(const SSL_CTX *ctx)
 {
     return ctx->max_early_data;
 }
@@ -4678,7 +4841,7 @@ int SSL_set_max_early_data(SSL *s, uint32_t max_early_data)
     return 1;
 }
 
-uint32_t SSL_get_max_early_data(SSL_CTX *s)
+uint32_t SSL_get_max_early_data(const SSL_CTX *s)
 {
     return s->max_early_data;
 }