Convert dtls_write_records to use standard record layer functions
[openssl.git] / ssl / record / rec_layer_s3.c
index 0318b07a9fb63dd3fdd37234cfecb5fef9ab2066..04f130bc2eba140185c0c5205945f124bda295bd 100644 (file)
@@ -1062,6 +1062,7 @@ static const OSSL_DISPATCH rlayer_dispatch[] = {
 };
 
 static const OSSL_RECORD_METHOD *ssl_select_next_record_layer(SSL_CONNECTION *s,
+                                                              int direction,
                                                               int level)
 {
 
@@ -1081,7 +1082,8 @@ static const OSSL_RECORD_METHOD *ssl_select_next_record_layer(SSL_CONNECTION *s,
 #endif
 
     /* Default to the current OSSL_RECORD_METHOD */
-    return s->rlayer.rrlmethod;
+    return direction == OSSL_RECORD_DIRECTION_READ ? s->rlayer.rrlmethod
+                                                   : s->rlayer.wrlmethod;
 }
 
 static int ssl_post_record_layer_select(SSL_CONNECTION *s, int direction)
@@ -1133,11 +1135,14 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
     SSL_CTX *sctx = SSL_CONNECTION_GET_CTX(s);
     const OSSL_RECORD_METHOD *meth;
     int use_etm, stream_mac = 0, tlstree = 0;
-    unsigned int maxfrag = SSL3_RT_MAX_PLAIN_LENGTH;
+    unsigned int maxfrag = (direction == OSSL_RECORD_DIRECTION_WRITE)
+                           ? ssl_get_max_send_fragment(s)
+                           : SSL3_RT_MAX_PLAIN_LENGTH;
     int use_early_data = 0;
     uint32_t max_early_data;
+    COMP_METHOD *compm = (comp == NULL) ? NULL : comp->method;
 
-    meth = ssl_select_next_record_layer(s, level);
+    meth = ssl_select_next_record_layer(s, direction, level);
 
     if (direction == OSSL_RECORD_DIRECTION_READ) {
         thismethod = &s->rlayer.rrlmethod;
@@ -1202,9 +1207,16 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
         *set++ = OSSL_PARAM_construct_int(OSSL_LIBSSL_RECORD_LAYER_PARAM_TLSTREE,
                                           &tlstree);
 
-    if (s->session != NULL && USE_MAX_FRAGMENT_LENGTH_EXT(s->session))
+    /*
+     * We only need to do this for the read side. The write side should already
+     * have the correct value due to the ssl_get_max_send_fragment() call above
+     */
+    if (direction == OSSL_RECORD_DIRECTION_READ
+            && s->session != NULL
+            && USE_MAX_FRAGMENT_LENGTH_EXT(s->session))
         maxfrag = GET_MAX_FRAGMENT_LENGTH(s->session);
 
+
     if (maxfrag != SSL3_RT_MAX_PLAIN_LENGTH)
         *set++ = OSSL_PARAM_construct_uint(OSSL_LIBSSL_RECORD_LAYER_PARAM_MAX_FRAG_LEN,
                                            &maxfrag);
@@ -1214,7 +1226,6 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
      * using the early keys. A server also needs to worry about rejected early
      * data that might arrive when the handshake keys are in force.
      */
-    /* TODO(RECLAYER): Check this when doing the "write" record layer */
     if (s->server && direction == OSSL_RECORD_DIRECTION_READ) {
         use_early_data = (level == OSSL_RECORD_PROTECTION_LEVEL_EARLY
                           || level == OSSL_RECORD_PROTECTION_LEVEL_HANDSHAKE);
@@ -1256,6 +1267,10 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
                 return 0;
             }
             s->rlayer.rrlnext = next;
+        } else {
+            if (SSL_CONNECTION_IS_DTLS(s)
+                    && level != OSSL_RECORD_PROTECTION_LEVEL_NONE)
+                epoch =  DTLS_RECORD_LAYER_get_w_epoch(&s->rlayer) + 1; /* new epoch */
         }
 
         /*
@@ -1282,7 +1297,7 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
                                        s->server, direction, level, epoch,
                                        key, keylen, iv, ivlen, mackey,
                                        mackeylen, ciph, taglen, mactype, md,
-                                       comp, prev, thisbio, next, NULL, NULL,
+                                       compm, prev, thisbio, next, NULL, NULL,
                                        settings, options, rlayer_dispatch_tmp,
                                        s, &newrl);
         BIO_free(prev);
@@ -1314,9 +1329,17 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
         break;
     }
 
-    if (*thismethod != NULL && !(*thismethod)->free(*thisrl)) {
-        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
-        return 0;
+    /*
+     * Free the old record layer if we have one except in the case of DTLS when
+     * writing. In that case the record layer is still referenced by buffered
+     * messages for potential retransmit. Only when those buffered messages get
+     * freed do we free the record layer object (see dtls1_hm_fragment_free)
+     */
+    if (!SSL_CONNECTION_IS_DTLS(s) || direction == OSSL_RECORD_DIRECTION_READ) {
+        if (*thismethod != NULL && !(*thismethod)->free(*thisrl)) {
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+            return 0;
+        }
     }
 
     *thisrl = newrl;