Add the ability to mutate TLS handshake messages before they are written
[openssl.git] / ssl / statem / statem_lib.c
index 88f3b94f2e2b4d23e314b5779a329b19248a8e46..ebedbeefbb2c4f1f563fbaeb52b25ec5278e561a 100644 (file)
@@ -36,6 +36,23 @@ const unsigned char hrrrandom[] = {
     0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c
 };
 
     0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c
 };
 
+int ossl_statem_set_mutator(SSL *s,
+                            ossl_statem_mutate_handshake_cb mutate_handshake_cb,
+                            ossl_statem_finish_mutate_handshake_cb finish_mutate_handshake_cb,
+                            void *mutatearg)
+{
+    SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(s);
+
+    if (sc == NULL)
+        return 0;
+
+    sc->statem.mutate_handshake_cb = mutate_handshake_cb;
+    sc->statem.mutatearg = mutatearg;
+    sc->statem.finish_mutate_handshake_cb = finish_mutate_handshake_cb;
+
+    return 1;
+}
+
 /*
  * send s->init_buf in records of type 'type' (SSL3_RT_HANDSHAKE or
  * SSL3_RT_CHANGE_CIPHER_SPEC)
 /*
  * send s->init_buf in records of type 'type' (SSL3_RT_HANDSHAKE or
  * SSL3_RT_CHANGE_CIPHER_SPEC)
@@ -46,6 +63,32 @@ int ssl3_do_write(SSL_CONNECTION *s, int type)
     size_t written = 0;
     SSL *ssl = SSL_CONNECTION_GET_SSL(s);
 
     size_t written = 0;
     SSL *ssl = SSL_CONNECTION_GET_SSL(s);
 
+    /*
+     * If we're running the test suite then we may need to mutate the message
+     * we've been asked to write. Does not happen in normal operation.
+     */
+    if (s->statem.mutate_handshake_cb != NULL
+            && !s->statem.write_in_progress
+            && type == SSL3_RT_HANDSHAKE
+            && s->init_num >= SSL3_HM_HEADER_LENGTH) {
+        unsigned char *msg;
+        size_t msglen;
+
+        if (!s->statem.mutate_handshake_cb((unsigned char *)s->init_buf->data,
+                                           s->init_num,
+                                           &msg, &msglen,
+                                           s->statem.mutatearg))
+            return -1;
+        if (msglen < SSL3_HM_HEADER_LENGTH
+                || !BUF_MEM_grow(s->init_buf, msglen))
+            return -1;
+        memcpy(s->init_buf->data, msg, msglen);
+        s->init_num = msglen;
+        s->init_msg = s->init_buf->data + SSL3_HM_HEADER_LENGTH;
+        s->statem.finish_mutate_handshake_cb(s->statem.mutatearg);
+        s->statem.write_in_progress = 1;
+    }
+
     ret = ssl3_write_bytes(ssl, type, &s->init_buf->data[s->init_off],
                            s->init_num, &written);
     if (ret < 0)
     ret = ssl3_write_bytes(ssl, type, &s->init_buf->data[s->init_off],
                            s->init_num, &written);
     if (ret < 0)
@@ -65,6 +108,7 @@ int ssl3_do_write(SSL_CONNECTION *s, int type)
                                  written))
                 return -1;
     if (written == s->init_num) {
                                  written))
                 return -1;
     if (written == s->init_num) {
+        s->statem.write_in_progress = 0;
         if (s->msg_callback)
             s->msg_callback(1, s->version, type, s->init_buf->data,
                             (size_t)(s->init_off + s->init_num), ssl,
         if (s->msg_callback)
             s->msg_callback(1, s->version, type, s->init_buf->data,
                             (size_t)(s->init_off + s->init_num), ssl,