TLSv13: add kTLS support
[openssl.git] / ssl / tls13_enc.c
index 1775152eeb8af9a2a216ebbeaaea571eb67b0a62..ba385f6ea2e5d7b64c0064247ff7b69e72e8469f 100644 (file)
@@ -9,6 +9,8 @@
 
 #include <stdlib.h>
 #include "ssl_local.h"
 
 #include <stdlib.h>
 #include "ssl_local.h"
+#include "internal/ktls.h"
+#include "record/record_local.h"
 #include "internal/cryptlib.h"
 #include <openssl/evp.h>
 #include <openssl/kdf.h>
 #include "internal/cryptlib.h"
 #include <openssl/evp.h>
 #include <openssl/kdf.h>
@@ -409,9 +411,9 @@ static int derive_secret_key_and_iv(SSL *s, int sending, const EVP_MD *md,
                                     const unsigned char *hash,
                                     const unsigned char *label,
                                     size_t labellen, unsigned char *secret,
                                     const unsigned char *hash,
                                     const unsigned char *label,
                                     size_t labellen, unsigned char *secret,
-                                    unsigned char *iv, EVP_CIPHER_CTX *ciph_ctx)
+                                    unsigned char *key, unsigned char *iv,
+                                    EVP_CIPHER_CTX *ciph_ctx)
 {
 {
-    unsigned char key[EVP_MAX_KEY_LENGTH];
     size_t ivlen, keylen, taglen;
     int hashleni = EVP_MD_size(md);
     size_t hashlen;
     size_t ivlen, keylen, taglen;
     int hashleni = EVP_MD_size(md);
     size_t hashlen;
@@ -420,14 +422,14 @@ static int derive_secret_key_and_iv(SSL *s, int sending, const EVP_MD *md,
     if (!ossl_assert(hashleni >= 0)) {
         SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_F_DERIVE_SECRET_KEY_AND_IV,
                  ERR_R_EVP_LIB);
     if (!ossl_assert(hashleni >= 0)) {
         SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_F_DERIVE_SECRET_KEY_AND_IV,
                  ERR_R_EVP_LIB);
-        goto err;
+        return 0;
     }
     hashlen = (size_t)hashleni;
 
     if (!tls13_hkdf_expand(s, md, insecret, label, labellen, hash, hashlen,
                            secret, hashlen, 1)) {
         /* SSLfatal() already called */
     }
     hashlen = (size_t)hashleni;
 
     if (!tls13_hkdf_expand(s, md, insecret, label, labellen, hash, hashlen,
                            secret, hashlen, 1)) {
         /* SSLfatal() already called */
-        goto err;
+        return 0;
     }
 
     /* TODO(size_t): convert me */
     }
 
     /* TODO(size_t): convert me */
@@ -447,7 +449,7 @@ static int derive_secret_key_and_iv(SSL *s, int sending, const EVP_MD *md,
         } else {
             SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_F_DERIVE_SECRET_KEY_AND_IV,
                      ERR_R_EVP_LIB);
         } else {
             SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_F_DERIVE_SECRET_KEY_AND_IV,
                      ERR_R_EVP_LIB);
-            goto err;
+            return 0;
         }
         if (algenc & (SSL_AES128CCM8 | SSL_AES256CCM8))
             taglen = EVP_CCM8_TLS_TAG_LEN;
         }
         if (algenc & (SSL_AES128CCM8 | SSL_AES256CCM8))
             taglen = EVP_CCM8_TLS_TAG_LEN;
@@ -461,7 +463,7 @@ static int derive_secret_key_and_iv(SSL *s, int sending, const EVP_MD *md,
     if (!tls13_derive_key(s, md, secret, key, keylen)
             || !tls13_derive_iv(s, md, secret, iv, ivlen)) {
         /* SSLfatal() already called */
     if (!tls13_derive_key(s, md, secret, key, keylen)
             || !tls13_derive_iv(s, md, secret, iv, ivlen)) {
         /* SSLfatal() already called */
-        goto err;
+        return 0;
     }
 
     if (EVP_CipherInit_ex(ciph_ctx, ciph, NULL, NULL, NULL, sending) <= 0
     }
 
     if (EVP_CipherInit_ex(ciph_ctx, ciph, NULL, NULL, NULL, sending) <= 0
@@ -471,13 +473,10 @@ static int derive_secret_key_and_iv(SSL *s, int sending, const EVP_MD *md,
         || EVP_CipherInit_ex(ciph_ctx, NULL, NULL, key, NULL, -1) <= 0) {
         SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_F_DERIVE_SECRET_KEY_AND_IV,
                  ERR_R_EVP_LIB);
         || EVP_CipherInit_ex(ciph_ctx, NULL, NULL, key, NULL, -1) <= 0) {
         SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_F_DERIVE_SECRET_KEY_AND_IV,
                  ERR_R_EVP_LIB);
-        goto err;
+        return 0;
     }
 
     return 1;
     }
 
     return 1;
- err:
-    OPENSSL_cleanse(key, sizeof(key));
-    return 0;
 }
 
 int tls13_change_cipher_state(SSL *s, int which)
 }
 
 int tls13_change_cipher_state(SSL *s, int which)
@@ -502,6 +501,7 @@ int tls13_change_cipher_state(SSL *s, int which)
     static const unsigned char early_exporter_master_secret[] = "e exp master";
 #endif
     unsigned char *iv;
     static const unsigned char early_exporter_master_secret[] = "e exp master";
 #endif
     unsigned char *iv;
+    unsigned char key[EVP_MAX_KEY_LENGTH];
     unsigned char secret[EVP_MAX_MD_SIZE];
     unsigned char hashval[EVP_MAX_MD_SIZE];
     unsigned char *hash = hashval;
     unsigned char secret[EVP_MAX_MD_SIZE];
     unsigned char hashval[EVP_MAX_MD_SIZE];
     unsigned char *hash = hashval;
@@ -515,6 +515,12 @@ int tls13_change_cipher_state(SSL *s, int which)
     int ret = 0;
     const EVP_MD *md = NULL;
     const EVP_CIPHER *cipher = NULL;
     int ret = 0;
     const EVP_MD *md = NULL;
     const EVP_CIPHER *cipher = NULL;
+#if !defined(OPENSSL_NO_KTLS) && defined(OPENSSL_KTLS_TLS13)
+# ifndef __FreeBSD__
+    struct tls_crypto_info_all crypto_info;
+    BIO *bio;
+# endif
+#endif
 
     if (which & SSL3_CC_READ) {
         if (s->enc_read_ctx != NULL) {
 
     if (which & SSL3_CC_READ) {
         if (s->enc_read_ctx != NULL) {
@@ -729,9 +735,13 @@ int tls13_change_cipher_state(SSL *s, int which)
         }
     }
 
         }
     }
 
+    /* check whether cipher is known */
+    if(!ossl_assert(cipher != NULL))
+        goto err;
+
     if (!derive_secret_key_and_iv(s, which & SSL3_CC_WRITE, md, cipher,
     if (!derive_secret_key_and_iv(s, which & SSL3_CC_WRITE, md, cipher,
-                                  insecret, hash, label, labellen, secret, iv,
-                                  ciph_ctx)) {
+                                  insecret, hash, label, labellen, secret, key,
+                                  iv, ciph_ctx)) {
         /* SSLfatal() already called */
         goto err;
     }
         /* SSLfatal() already called */
         goto err;
     }
@@ -772,12 +782,57 @@ int tls13_change_cipher_state(SSL *s, int which)
         s->statem.enc_write_state = ENC_WRITE_STATE_WRITE_PLAIN_ALERTS;
     else
         s->statem.enc_write_state = ENC_WRITE_STATE_VALID;
         s->statem.enc_write_state = ENC_WRITE_STATE_WRITE_PLAIN_ALERTS;
     else
         s->statem.enc_write_state = ENC_WRITE_STATE_VALID;
+#ifndef OPENSSL_NO_KTLS
+# if defined(OPENSSL_KTLS_TLS13)
+#  ifndef __FreeBSD__
+    if (!(which & SSL3_CC_WRITE) || !(which & SSL3_CC_APPLICATION)
+        || ((which & SSL3_CC_WRITE) && (s->mode & SSL_MODE_NO_KTLS_TX)))
+        goto skip_ktls;
+
+    /* ktls supports only the maximum fragment size */
+    if (ssl_get_max_send_fragment(s) != SSL3_RT_MAX_PLAIN_LENGTH)
+        goto skip_ktls;
+
+    /* ktls does not support record padding */
+    if (s->record_padding_cb != NULL)
+        goto skip_ktls;
+
+    /* check that cipher is supported */
+    if (!ktls_check_supported_cipher(cipher, ciph_ctx))
+        goto skip_ktls;
+
+    bio = s->wbio;
+
+    if (!ossl_assert(bio != NULL)) {
+        SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_F_TLS13_CHANGE_CIPHER_STATE,
+                 ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
+
+    /* All future data will get encrypted by ktls. Flush the BIO or skip ktls */
+    if (BIO_flush(bio) <= 0)
+        goto skip_ktls;
+
+    /* configure kernel crypto structure */
+    if (!ktls_configure_crypto(cipher, s->version, ciph_ctx, 
+                               RECORD_LAYER_get_write_sequence(&s->rlayer),
+                               &crypto_info, NULL, iv, key))
+        goto skip_ktls;
+
+    /* ktls works with user provided buffers directly */
+    if (BIO_set_ktls(bio, &crypto_info, which & SSL3_CC_WRITE))
+        ssl3_release_write_buffer(s);
+#  endif
+skip_ktls:
+# endif
+#endif
     ret = 1;
  err:
     if ((which & SSL3_CC_EARLY) != 0) {
         /* We up-refed this so now we need to down ref */
         ssl_evp_cipher_free(cipher);
     }
     ret = 1;
  err:
     if ((which & SSL3_CC_EARLY) != 0) {
         /* We up-refed this so now we need to down ref */
         ssl_evp_cipher_free(cipher);
     }
+    OPENSSL_cleanse(key, sizeof(key));
     OPENSSL_cleanse(secret, sizeof(secret));
     return ret;
 }
     OPENSSL_cleanse(secret, sizeof(secret));
     return ret;
 }
@@ -791,6 +846,7 @@ int tls13_update_key(SSL *s, int sending)
 #endif
     const EVP_MD *md = ssl_handshake_md(s);
     size_t hashlen = EVP_MD_size(md);
 #endif
     const EVP_MD *md = ssl_handshake_md(s);
     size_t hashlen = EVP_MD_size(md);
+    unsigned char key[EVP_MAX_KEY_LENGTH];
     unsigned char *insecret, *iv;
     unsigned char secret[EVP_MAX_MD_SIZE];
     EVP_CIPHER_CTX *ciph_ctx;
     unsigned char *insecret, *iv;
     unsigned char secret[EVP_MAX_MD_SIZE];
     EVP_CIPHER_CTX *ciph_ctx;
@@ -815,8 +871,8 @@ int tls13_update_key(SSL *s, int sending)
     if (!derive_secret_key_and_iv(s, sending, ssl_handshake_md(s),
                                   s->s3.tmp.new_sym_enc, insecret, NULL,
                                   application_traffic,
     if (!derive_secret_key_and_iv(s, sending, ssl_handshake_md(s),
                                   s->s3.tmp.new_sym_enc, insecret, NULL,
                                   application_traffic,
-                                  sizeof(application_traffic) - 1, secret, iv,
-                                  ciph_ctx)) {
+                                  sizeof(application_traffic) - 1, secret, key,
+                                  iv, ciph_ctx)) {
         /* SSLfatal() already called */
         goto err;
     }
         /* SSLfatal() already called */
         goto err;
     }
@@ -826,6 +882,7 @@ int tls13_update_key(SSL *s, int sending)
     s->statem.enc_write_state = ENC_WRITE_STATE_VALID;
     ret = 1;
  err:
     s->statem.enc_write_state = ENC_WRITE_STATE_VALID;
     ret = 1;
  err:
+    OPENSSL_cleanse(key, sizeof(key));
     OPENSSL_cleanse(secret, sizeof(secret));
     return ret;
 }
     OPENSSL_cleanse(secret, sizeof(secret));
     return ret;
 }