Provide functions to write early data
authorMatt Caswell <matt@openssl.org>
Tue, 21 Feb 2017 09:22:22 +0000 (09:22 +0000)
committerMatt Caswell <matt@openssl.org>
Thu, 2 Mar 2017 17:44:14 +0000 (17:44 +0000)
We provide SSL_write_early() which *must* be called first on a connection
(prior to any other IO function including SSL_connect()/SSL_do_handshake()).
Also SSL_write_early_finish() which signals the end of early data.

Reviewed-by: Rich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/2737)

12 files changed:
include/openssl/ssl.h
include/openssl/tls1.h
ssl/record/rec_layer_s3.c
ssl/record/ssl3_record_tls13.c
ssl/s3_msg.c
ssl/ssl_err.c
ssl/ssl_lib.c
ssl/ssl_locl.h
ssl/statem/statem.c
ssl/statem/statem_clnt.c
ssl/tls13_enc.c
util/libssl.num

index 7f84e0a265c6bd0c1eac869d88539894bec0283e..7d3ac4e2533f69dcc7dea355723690e67cc5e8e7 100644 (file)
@@ -1025,6 +1025,7 @@ DECLARE_PEM_rw(SSL_SESSION, SSL_SESSION)
 # define SSL_AD_INTERNAL_ERROR           TLS1_AD_INTERNAL_ERROR
 # define SSL_AD_USER_CANCELLED           TLS1_AD_USER_CANCELLED
 # define SSL_AD_NO_RENEGOTIATION         TLS1_AD_NO_RENEGOTIATION
+# define SSL_AD_END_OF_EARLY_DATA        TLS13_AD_END_OF_EARLY_DATA
 # define SSL_AD_MISSING_EXTENSION        TLS13_AD_MISSING_EXTENSION
 # define SSL_AD_UNSUPPORTED_EXTENSION    TLS1_AD_UNSUPPORTED_EXTENSION
 # define SSL_AD_CERTIFICATE_UNOBTAINABLE TLS1_AD_CERTIFICATE_UNOBTAINABLE
@@ -1614,6 +1615,9 @@ __owur int SSL_peek(SSL *ssl, void *buf, int num);
 __owur int SSL_peek_ex(SSL *ssl, void *buf, size_t num, size_t *readbytes);
 __owur int SSL_write(SSL *ssl, const void *buf, int num);
 __owur int SSL_write_ex(SSL *s, const void *buf, size_t num, size_t *written);
+__owur int SSL_write_early(SSL *s, const void *buf, size_t num,
+                           size_t *written);
+__owur int SSL_write_early_finish(SSL *s);
 long SSL_ctrl(SSL *ssl, int cmd, long larg, void *parg);
 long SSL_callback_ctrl(SSL *, int, void (*)(void));
 long SSL_CTX_ctrl(SSL_CTX *ctx, int cmd, long larg, void *parg);
@@ -2290,6 +2294,8 @@ int ERR_load_SSL_strings(void);
 # define SSL_F_SSL_VALIDATE_CT                            400
 # define SSL_F_SSL_VERIFY_CERT_CHAIN                      207
 # define SSL_F_SSL_WRITE                                  208
+# define SSL_F_SSL_WRITE_EARLY                            526
+# define SSL_F_SSL_WRITE_EARLY_FINISH                     527
 # define SSL_F_SSL_WRITE_EX                               433
 # define SSL_F_SSL_WRITE_INTERNAL                         524
 # define SSL_F_STATE_MACHINE                              353
@@ -2382,7 +2388,7 @@ int ERR_load_SSL_strings(void);
 # define SSL_F_TLS_PARSE_CTOS_PSK                         505
 # define SSL_F_TLS_PARSE_CTOS_RENEGOTIATE                 464
 # define SSL_F_TLS_PARSE_CTOS_USE_SRTP                    465
-# define SSL_F_TLS_PARSE_STOC_EARLY_DATA_INFO             520
+# define SSL_F_TLS_PARSE_STOC_EARLY_DATA_INFO             528
 # define SSL_F_TLS_PARSE_STOC_KEY_SHARE                   445
 # define SSL_F_TLS_PARSE_STOC_PSK                         502
 # define SSL_F_TLS_PARSE_STOC_RENEGOTIATE                 448
index 4fb5a560ad87f20035372e13a4262080fdd16049..63e9ee35cb09a4b44ca5e40d770a084dda5884e1 100644 (file)
@@ -104,6 +104,7 @@ extern "C" {
 # define TLS1_AD_USER_CANCELLED          90
 # define TLS1_AD_NO_RENEGOTIATION        100
 /* TLSv1.3 alerts */
+# define TLS13_AD_END_OF_EARLY_DATA      1
 # define TLS13_AD_MISSING_EXTENSION      109 /* fatal */
 /* codes 110-114 are from RFC3546 */
 # define TLS1_AD_UNSUPPORTED_EXTENSION   110
index 5aea4b31bdd6262327432c361a8ade2faef3110f..d7b98e055b06d288396777e0e99a683ba0299c8f 100644 (file)
@@ -745,7 +745,7 @@ int do_ssl3_write(SSL *s, int type, const unsigned char *buf,
     }
 
     /* Explicit IV length, block ciphers appropriate version flag */
-    if (s->enc_write_ctx && SSL_USE_EXPLICIT_IV(s)) {
+    if (s->enc_write_ctx && SSL_USE_EXPLICIT_IV(s) && !SSL_TREAT_AS_TLS13(s)) {
         int mode = EVP_CIPHER_CTX_mode(s->enc_write_ctx);
         if (mode == EVP_CIPH_CBC_MODE) {
             /* TODO(size_t): Convert me */
@@ -764,7 +764,7 @@ int do_ssl3_write(SSL *s, int type, const unsigned char *buf,
     /* Clear our SSL3_RECORD structures */
     memset(wr, 0, sizeof wr);
     for (j = 0; j < numpipes; j++) {
-        unsigned int version = SSL_IS_TLS13(s) ? TLS1_VERSION : s->version;
+        unsigned int version = SSL_TREAT_AS_TLS13(s) ? TLS1_VERSION : s->version;
         unsigned char *compressdata = NULL;
         size_t maxcomplen;
         unsigned int rectype;
@@ -777,7 +777,7 @@ int do_ssl3_write(SSL *s, int type, const unsigned char *buf,
          * In TLSv1.3, once encrypting, we always use application data for the
          * record type
          */
-        if (SSL_IS_TLS13(s) && s->enc_write_ctx != NULL)
+        if (SSL_TREAT_AS_TLS13(s) && s->enc_write_ctx != NULL)
             rectype = SSL3_RT_APPLICATION_DATA;
         else
             rectype = type;
@@ -835,7 +835,7 @@ int do_ssl3_write(SSL *s, int type, const unsigned char *buf,
             SSL3_RECORD_reset_input(&wr[j]);
         }
 
-        if (SSL_IS_TLS13(s) && s->enc_write_ctx != NULL) {
+        if (SSL_TREAT_AS_TLS13(s) && s->enc_write_ctx != NULL) {
             if (!WPACKET_put_bytes_u8(thispkt, type)) {
                 SSLerr(SSL_F_DO_SSL3_WRITE, ERR_R_INTERNAL_ERROR);
                 goto err;
@@ -887,8 +887,17 @@ int do_ssl3_write(SSL *s, int type, const unsigned char *buf,
         SSL3_RECORD_set_length(thiswr, len);
     }
 
-    if (s->method->ssl3_enc->enc(s, wr, numpipes, 1) < 1)
-        goto err;
+    if (s->early_data_state == SSL_EARLY_DATA_WRITING) {
+        /*
+         * We haven't actually negotiated the version yet, but we're trying to
+         * send early data - so we need to use the the tls13enc function.
+         */
+        if (tls13_enc(s, wr, numpipes, 1) < 1)
+            goto err;
+    } else {
+        if (s->method->ssl3_enc->enc(s, wr, numpipes, 1) < 1)
+            goto err;
+    }
 
     for (j = 0; j < numpipes; j++) {
         size_t origlen;
index d96a042ff92c41f803d252d7759a06656691f6bb..87041df2c75a76ed90483a158bea27919a6a18f0 100644 (file)
@@ -56,14 +56,18 @@ int tls13_enc(SSL *s, SSL3_RECORD *recs, size_t n_recs, int send)
 
     ivlen = EVP_CIPHER_CTX_iv_length(ctx);
 
-    /*
-     * To get here we must have selected a ciphersuite - otherwise ctx would
-     * be NULL
-     */
-    assert(s->s3->tmp.new_cipher != NULL);
-    if (s->s3->tmp.new_cipher == NULL)
-        return -1;
-    alg_enc = s->s3->tmp.new_cipher->algorithm_enc;
+    if (s->early_data_state == SSL_EARLY_DATA_WRITING) {
+        alg_enc = s->session->cipher->algorithm_enc;
+    } else {
+        /*
+         * To get here we must have selected a ciphersuite - otherwise ctx would
+         * be NULL
+         */
+        assert(s->s3->tmp.new_cipher != NULL);
+        if (s->s3->tmp.new_cipher == NULL)
+            return -1;
+        alg_enc = s->s3->tmp.new_cipher->algorithm_enc;
+    }
 
     if (alg_enc & SSL_AESCCM) {
         if (alg_enc & (SSL_AES128CCM8 | SSL_AES256CCM8))
index 743a02b8d17a31d0bfe5ec269903d6866eb896bc..7af2f99e05a12c34642420f6657296eca97ae2d9 100644 (file)
@@ -63,7 +63,10 @@ int ssl3_do_change_cipher_spec(SSL *s)
 int ssl3_send_alert(SSL *s, int level, int desc)
 {
     /* Map tls/ssl alert value to correct one */
-    desc = s->method->ssl3_enc->alert_value(desc);
+    if (SSL_TREAT_AS_TLS13(s))
+        desc = tls13_alert_code(desc);
+    else
+        desc = s->method->ssl3_enc->alert_value(desc);
     if (s->version == SSL3_VERSION && desc == SSL_AD_PROTOCOL_VERSION)
         desc = SSL_AD_HANDSHAKE_FAILURE; /* SSL 3.0 does not have
                                           * protocol_version alerts */
index 69d9adc9d744c235a780b7df81c78c511df39ffd..41d4a69dd8cd8523a16c3c7cc5eca5e2a2b14f30 100644 (file)
@@ -253,6 +253,8 @@ static ERR_STRING_DATA SSL_str_functs[] = {
     {ERR_FUNC(SSL_F_SSL_VALIDATE_CT), "ssl_validate_ct"},
     {ERR_FUNC(SSL_F_SSL_VERIFY_CERT_CHAIN), "ssl_verify_cert_chain"},
     {ERR_FUNC(SSL_F_SSL_WRITE), "SSL_write"},
+    {ERR_FUNC(SSL_F_SSL_WRITE_EARLY), "SSL_write_early"},
+    {ERR_FUNC(SSL_F_SSL_WRITE_EARLY_FINISH), "SSL_write_early_finish"},
     {ERR_FUNC(SSL_F_SSL_WRITE_EX), "SSL_write_ex"},
     {ERR_FUNC(SSL_F_SSL_WRITE_INTERNAL), "ssl_write_internal"},
     {ERR_FUNC(SSL_F_STATE_MACHINE), "state_machine"},
index d3dddafad7bf6269dad45eaf5acecb712ca68546..8e28786447c86feeb0e1dd0c432f54969a33ddc5 100644 (file)
@@ -1667,6 +1667,10 @@ int ssl_write_internal(SSL *s, const void *buf, size_t num, size_t *written)
         return -1;
     }
 
+    if (s->early_data_state == SSL_EARLY_DATA_WRITE_RETRY
+            || s->early_data_state == SSL_EARLY_DATA_CONNECT_RETRY)
+        return 0;
+
     if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
         int ret;
         struct ssl_async_args args;
@@ -1716,6 +1720,76 @@ int SSL_write_ex(SSL *s, const void *buf, size_t num, size_t *written)
     return ret;
 }
 
+int SSL_write_early(SSL *s, const void *buf, size_t num, size_t *written)
+{
+    int ret;
+
+    if (s->server) {
+        SSLerr(SSL_F_SSL_WRITE_EARLY, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        return 0;
+    }
+
+    /*
+     * TODO(TLS1.3): Somehow we need to check that we're not sending too much
+     * data
+     */
+
+    switch (s->early_data_state) {
+    case SSL_EARLY_DATA_NONE:
+        if (!SSL_in_before(s)) {
+            SSLerr(SSL_F_SSL_WRITE_EARLY, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+            return 0;
+        }
+        /* fall through */
+
+    case SSL_EARLY_DATA_CONNECT_RETRY:
+        s->early_data_state = SSL_EARLY_DATA_CONNECTING;
+        ret = SSL_connect(s);
+        if (ret <= 0) {
+            /* NBIO or error */
+            s->early_data_state = SSL_EARLY_DATA_CONNECT_RETRY;
+            return 0;
+        }
+        /* fall through */
+
+    case SSL_EARLY_DATA_WRITE_RETRY:
+        s->early_data_state = SSL_EARLY_DATA_WRITING;
+        ret = SSL_write_ex(s, buf, num, written);
+        s->early_data_state = SSL_EARLY_DATA_WRITE_RETRY;
+        return ret;
+
+    default:
+        SSLerr(SSL_F_SSL_WRITE_EARLY, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        return 0;
+    }
+}
+
+int SSL_write_early_finish(SSL *s)
+{
+    int ret;
+
+    if (s->early_data_state != SSL_EARLY_DATA_WRITE_RETRY) {
+        SSLerr(SSL_F_SSL_WRITE_EARLY_FINISH, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+        return 0;
+    }
+
+    s->early_data_state = SSL_EARLY_DATA_WRITING;
+    ret = ssl3_send_alert(s, SSL3_AL_WARNING, SSL_AD_END_OF_EARLY_DATA);
+    if (ret <= 0) {
+        s->early_data_state = SSL_EARLY_DATA_WRITE_RETRY;
+        return 0;
+    }
+    s->early_data_state = SSL_EARLY_DATA_FINISHED_WRITING;
+    /*
+     * We set the enc_write_ctx back to NULL because we may end up writing
+     * in cleartext again if we get a HelloRetryRequest from the server.
+     */
+    EVP_CIPHER_CTX_free(s->enc_write_ctx);
+    s->enc_write_ctx = NULL;
+    ossl_statem_set_in_init(s, 1);
+    return 1;
+}
+
 int SSL_shutdown(SSL *s)
 {
     /*
@@ -3073,6 +3147,10 @@ int SSL_do_handshake(SSL *s)
         return -1;
     }
 
+    if (s->early_data_state == SSL_EARLY_DATA_WRITE_RETRY
+            || s->early_data_state == SSL_EARLY_DATA_CONNECT_RETRY)
+        return -1;
+
     s->method->ssl_renegotiate_check(s, 0);
 
     if (SSL_in_init(s) || SSL_in_before(s)) {
index a7fa0b52c28022452d95729108de466d2ed77e6c..86603a07c5168e94cdb1d669dd8ef151b1eac099 100644 (file)
                           && (s)->method->version >= TLS1_3_VERSION \
                           && (s)->method->version != TLS_ANY_VERSION)
 
+# define SSL_TREAT_AS_TLS13(s) \
+    (SSL_IS_TLS13(s) || (s)->early_data_state == SSL_EARLY_DATA_WRITING)
+
 # define SSL_IS_FIRST_HANDSHAKE(S) ((s)->s3->tmp.finish_md_len == 0)
 
 /* See if we need explicit IV */
@@ -609,6 +612,15 @@ typedef struct srp_ctx_st {
 
 # endif
 
+typedef enum {
+    SSL_EARLY_DATA_NONE = 0,
+    SSL_EARLY_DATA_CONNECT_RETRY,
+    SSL_EARLY_DATA_CONNECTING,
+    SSL_EARLY_DATA_WRITE_RETRY,
+    SSL_EARLY_DATA_WRITING,
+    SSL_EARLY_DATA_FINISHED_WRITING
+} SSL_EARLY_DATA_STATE;
+
 #define MAX_COMPRESSIONS_SIZE   255
 
 struct ssl_comp_st {
@@ -976,6 +988,7 @@ struct ssl_st {
     int shutdown;
     /* where we are */
     OSSL_STATEM statem;
+    SSL_EARLY_DATA_STATE early_data_state;
     BUF_MEM *init_buf;          /* buffer used during init */
     void *init_msg;             /* pointer to handshake message body, set by
                                  * ssl3_get_message() */
index a1c5a2152297ca7f22d63561e68035b2373aa2d5..10d794ede7fdb35dce999690e6f597e570cb73a4 100644 (file)
@@ -313,7 +313,9 @@ static int state_machine(SSL *s, int server)
                 goto end;
             }
 
-        if (SSL_IS_FIRST_HANDSHAKE(s) || s->renegotiate) {
+        if ((SSL_IS_FIRST_HANDSHAKE(s)
+                    && s->early_data_state != SSL_EARLY_DATA_FINISHED_WRITING)
+                || s->renegotiate) {
             if (!tls_setup_handshake(s)) {
                 ossl_statem_set_error(s);
                 goto end;
index abddc0ace375f7da3b69c3cc50785b28937d1ad4..6507fc7d598b7f688b4f4eb2bf970177c3c13507 100644 (file)
@@ -196,6 +196,11 @@ static int ossl_statem_client13_read_transition(SSL *s, int mt)
         break;
 
     case TLS_ST_OK:
+        if (s->early_data_state == SSL_EARLY_DATA_FINISHED_WRITING
+                && mt == SSL3_MT_SERVER_HELLO) {
+            st->hand_state = TLS_ST_CR_SRVR_HELLO;
+            return 1;
+        }
         if (mt == SSL3_MT_NEWSESSION_TICKET) {
             st->hand_state = TLS_ST_CR_SESSION_TICKET;
             return 1;
@@ -382,7 +387,21 @@ int ossl_statem_client_read_transition(SSL *s, int mt)
         break;
 
     case TLS_ST_OK:
-        if (mt == SSL3_MT_HELLO_REQUEST) {
+        if (s->early_data_state == SSL_EARLY_DATA_FINISHED_WRITING) {
+            /*
+             * We've not actually selected TLSv1.3 yet, but we have sent early
+             * data. The only thing allowed now is a ServerHello or a
+             * HelloRetryRequest.
+             */
+            if (mt == SSL3_MT_SERVER_HELLO) {
+                st->hand_state = TLS_ST_CR_SRVR_HELLO;
+                return 1;
+            }
+            if (mt == SSL3_MT_HELLO_RETRY_REQUEST) {
+                st->hand_state = TLS_ST_CR_HELLO_RETRY_REQUEST;
+                return 1;
+            }
+        } else if (mt == SSL3_MT_HELLO_REQUEST) {
             st->hand_state = TLS_ST_CR_HELLO_REQ;
             return 1;
         }
@@ -485,6 +504,13 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL *s)
         return WRITE_TRAN_ERROR;
 
     case TLS_ST_OK:
+        if (s->early_data_state == SSL_EARLY_DATA_FINISHED_WRITING) {
+            /*
+             * We are assuming this is a TLSv1.3 connection, although we haven't
+             * actually selected a version yet.
+             */
+            return WRITE_TRAN_FINISHED;
+        }
         if (!s->renegotiate) {
             /*
              * We haven't requested a renegotiation ourselves so we must have
@@ -498,6 +524,15 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL *s)
         return WRITE_TRAN_CONTINUE;
 
     case TLS_ST_CW_CLNT_HELLO:
+        if (s->early_data_state == SSL_EARLY_DATA_CONNECTING) {
+            /*
+             * We are assuming this is a TLSv1.3 connection, although we haven't
+             * actually selected a version yet.
+             */
+            st->hand_state = TLS_ST_OK;
+            ossl_statem_set_in_init(s, 0);
+            return WRITE_TRAN_CONTINUE;
+        }
         /*
          * No transition at the end of writing because we don't know what
          * we will be sent
@@ -999,9 +1034,6 @@ int tls_construct_client_hello(SSL *s, WPACKET *pkt)
     }
     /* else use the pre-loaded session */
 
-    /* This is a real handshake so make sure we clean it up at the end */
-    s->statem.cleanuphand = 1;
-
     p = s->s3->client_random;
 
     /*
@@ -1185,6 +1217,12 @@ MSG_PROCESS_RETURN tls_process_server_hello(SSL *s, PACKET *pkt)
     SSL_COMP *comp;
 #endif
 
+    /*
+     * This is a real handshake so make sure we clean it up at the end. We set
+     * this here so that we are after any early_data
+     */
+    s->statem.cleanuphand = 1;
+
     if (!PACKET_get_net_2(pkt, &sversion)) {
         al = SSL_AD_DECODE_ERROR;
         SSLerr(SSL_F_TLS_PROCESS_SERVER_HELLO, SSL_R_LENGTH_MISMATCH);
index 6faa4e0838777dcfc146b9981881c31544eba365..7f786e150e61ffc645752a1719a85aa5e709d16c 100644 (file)
@@ -500,7 +500,7 @@ int tls13_update_key(SSL *s, int send)
 
 int tls13_alert_code(int code)
 {
-    if (code == SSL_AD_MISSING_EXTENSION)
+    if (code == SSL_AD_MISSING_EXTENSION || code == SSL_AD_END_OF_EARLY_DATA)
         return code;
 
     return tls1_alert_code(code);
index d3f7a4f596608c2bb77148dca200fe6494347a6a..52ae727b373c0a5590012419045b8feb143c7e63 100644 (file)
@@ -428,3 +428,5 @@ SSL_set_max_early_data                  428 1_1_1   EXIST::FUNCTION:
 SSL_CTX_set_max_early_data              429    1_1_1   EXIST::FUNCTION:
 SSL_get_max_early_data                  430    1_1_1   EXIST::FUNCTION:
 SSL_CTX_get_max_early_data              431    1_1_1   EXIST::FUNCTION:
+SSL_write_early                         432    1_1_1   EXIST::FUNCTION:
+SSL_write_early_finish                  433    1_1_1   EXIST::FUNCTION: