SSL test framework: port NPN and ALPN tests
[openssl.git] / test / handshake_helper.c
index 8a8dab02bba5d86d778349555e1e009ef21d9674..77852ad586a642f272029450e1cbd6e85bfc0b0f 100644 (file)
 
 #include "handshake_helper.h"
 
+HANDSHAKE_RESULT *HANDSHAKE_RESULT_new()
+{
+    HANDSHAKE_RESULT *ret;
+    ret = OPENSSL_zalloc(sizeof(*ret));
+    OPENSSL_assert(ret != NULL);
+    return ret;
+}
+
+void HANDSHAKE_RESULT_free(HANDSHAKE_RESULT *result)
+{
+    OPENSSL_free(result->client_npn_negotiated);
+    OPENSSL_free(result->server_npn_negotiated);
+    OPENSSL_free(result->client_alpn_negotiated);
+    OPENSSL_free(result->server_alpn_negotiated);
+    OPENSSL_free(result);
+}
+
 /*
  * Since there appears to be no way to extract the sent/received alert
  * from the SSL object directly, we use the info callback and stash
@@ -27,6 +44,22 @@ typedef struct handshake_ex_data {
     ssl_servername_t servername;
 } HANDSHAKE_EX_DATA;
 
+typedef struct ctx_data {
+    unsigned char *npn_protocols;
+    size_t npn_protocols_len;
+    unsigned char *alpn_protocols;
+    size_t alpn_protocols_len;
+} CTX_DATA;
+
+/* |ctx_data| itself is stack-allocated. */
+static void ctx_data_free_data(CTX_DATA *ctx_data)
+{
+    OPENSSL_free(ctx_data->npn_protocols);
+    ctx_data->npn_protocols = NULL;
+    OPENSSL_free(ctx_data->alpn_protocols);
+    ctx_data->alpn_protocols = NULL;
+}
+
 static int ex_data_idx;
 
 static void info_cb(const SSL *s, int where, int ret)
@@ -42,8 +75,7 @@ static void info_cb(const SSL *s, int where, int ret)
     }
 }
 
-/*
- * Select the appropriate server CTX.
+/* Select the appropriate server CTX.
  * Returns SSL_TLSEXT_ERR_OK if a match was found.
  * If |ignore| is 1, returns SSL_TLSEXT_ERR_NOACK on mismatch.
  * Otherwise, returns SSL_TLSEXT_ERR_ALERT_FATAL on mismatch.
@@ -115,13 +147,13 @@ static int verify_accept_cb(X509_STORE_CTX *ctx, void *arg) {
     return 1;
 }
 
-static int broken_session_ticket_cb(SSL* s, unsigned char* key_name, unsigned char *iv,
+static int broken_session_ticket_cb(SSL *s, unsigned char *key_name, unsigned char *iv,
                                     EVP_CIPHER_CTX *ctx, HMAC_CTX *hctx, int enc)
 {
     return 0;
 }
 
-static int do_not_call_session_ticket_cb(SSL* s, unsigned char* key_name,
+static int do_not_call_session_ticket_cb(SSL *s, unsigned char *key_name,
                                          unsigned char *iv,
                                          EVP_CIPHER_CTX *ctx,
                                          HMAC_CTX *hctx, int enc)
@@ -132,13 +164,114 @@ static int do_not_call_session_ticket_cb(SSL* s, unsigned char* key_name,
     return 0;
 }
 
+/* Parse the comma-separated list into TLS format. */
+static void parse_protos(const char *protos, unsigned char **out, size_t *outlen)
+{
+    size_t len, i, prefix;
+
+    len = strlen(protos);
+
+    /* Should never have reuse. */
+    OPENSSL_assert(*out == NULL);
+
+    /* Test values are small, so we omit length limit checks. */
+    *out = OPENSSL_malloc(len + 1);
+    OPENSSL_assert(*out != NULL);
+    *outlen = len + 1;
+
+    /*
+     * foo => '3', 'f', 'o', 'o'
+     * foo,bar => '3', 'f', 'o', 'o', '3', 'b', 'a', 'r'
+     */
+    memcpy(*out + 1, protos, len);
+
+    prefix = 0;
+    i = prefix + 1;
+    while (i <= len) {
+        if ((*out)[i] == ',') {
+            OPENSSL_assert(i - 1 - prefix > 0);
+            (*out)[prefix] = i - 1 - prefix;
+            prefix = i;
+        }
+        i++;
+    }
+    OPENSSL_assert(len - prefix > 0);
+    (*out)[prefix] = len - prefix;
+}
+
+/*
+ * The client SHOULD select the first protocol advertised by the server that it
+ * also supports.  In the event that the client doesn't support any of server's
+ * protocols, or the server doesn't advertise any, it SHOULD select the first
+ * protocol that it supports.
+ */
+static int client_npn_cb(SSL *s, unsigned char **out, unsigned char *outlen,
+                         const unsigned char *in, unsigned int inlen,
+                         void *arg)
+{
+    CTX_DATA *ctx_data = (CTX_DATA*)(arg);
+    int ret;
+
+    ret = SSL_select_next_proto(out, outlen, in, inlen,
+                                ctx_data->npn_protocols,
+                                ctx_data->npn_protocols_len);
+    /* Accept both OPENSSL_NPN_NEGOTIATED and OPENSSL_NPN_NO_OVERLAP. */
+    OPENSSL_assert(ret == OPENSSL_NPN_NEGOTIATED
+                   || ret == OPENSSL_NPN_NO_OVERLAP);
+    return SSL_TLSEXT_ERR_OK;
+}
+
+static int server_npn_cb(SSL *s, const unsigned char **data,
+                         unsigned int *len, void *arg)
+{
+    CTX_DATA *ctx_data = (CTX_DATA*)(arg);
+    *data = ctx_data->npn_protocols;
+    *len = ctx_data->npn_protocols_len;
+    return SSL_TLSEXT_ERR_OK;
+}
+
+/*
+ * The server SHOULD select the most highly preferred protocol that it supports
+ * and that is also advertised by the client.  In the event that the server
+ * supports no protocols that the client advertises, then the server SHALL
+ * respond with a fatal "no_application_protocol" alert.
+ */
+static int server_alpn_cb(SSL *s, const unsigned char **out,
+                          unsigned char *outlen, const unsigned char *in,
+                          unsigned int inlen, void *arg)
+{
+    CTX_DATA *ctx_data = (CTX_DATA*)(arg);
+    int ret;
+
+    /* SSL_select_next_proto isn't const-correct... */
+    unsigned char *tmp_out;
+
+    /*
+     * The result points either to |in| or to |ctx_data->alpn_protocols|.
+     * The callback is allowed to point to |in| or to a long-lived buffer,
+     * so we can return directly without storing a copy.
+     */
+    ret = SSL_select_next_proto(&tmp_out, outlen,
+                                ctx_data->alpn_protocols,
+                                ctx_data->alpn_protocols_len, in, inlen);
+
+    *out = tmp_out;
+    /* Unlike NPN, we don't tolerate a mismatch. */
+    return ret == OPENSSL_NPN_NEGOTIATED ? SSL_TLSEXT_ERR_OK
+        : SSL_TLSEXT_ERR_NOACK;
+}
+
+
 /*
  * Configure callbacks and other properties that can't be set directly
  * in the server/client CONF.
  */
 static void configure_handshake_ctx(SSL_CTX *server_ctx, SSL_CTX *server2_ctx,
                                     SSL_CTX *client_ctx,
-                                    const SSL_TEST_CTX *test_ctx)
+                                    const SSL_TEST_CTX *test_ctx,
+                                    CTX_DATA *server_ctx_data,
+                                    CTX_DATA *server2_ctx_data,
+                                    CTX_DATA *client_ctx_data)
 {
     switch (test_ctx->client_verify_callback) {
     case SSL_TEST_VERIFY_ACCEPT_ALL:
@@ -179,12 +312,55 @@ static void configure_handshake_ctx(SSL_CTX *server_ctx, SSL_CTX *server2_ctx,
     if (test_ctx->session_ticket_expected == SSL_TEST_SESSION_TICKET_BROKEN) {
         SSL_CTX_set_tlsext_ticket_key_cb(server_ctx, broken_session_ticket_cb);
     }
+
+    if (test_ctx->server_npn_protocols != NULL) {
+        parse_protos(test_ctx->server_npn_protocols,
+                     &server_ctx_data->npn_protocols,
+                     &server_ctx_data->npn_protocols_len);
+        SSL_CTX_set_next_protos_advertised_cb(server_ctx, server_npn_cb,
+                                              server_ctx_data);
+    }
+    if (test_ctx->server2_npn_protocols != NULL) {
+        parse_protos(test_ctx->server2_npn_protocols,
+                     &server2_ctx_data->npn_protocols,
+                     &server2_ctx_data->npn_protocols_len);
+        OPENSSL_assert(server2_ctx != NULL);
+        SSL_CTX_set_next_protos_advertised_cb(server2_ctx, server_npn_cb,
+                                              server2_ctx_data);
+    }
+    if (test_ctx->client_npn_protocols != NULL) {
+        parse_protos(test_ctx->client_npn_protocols,
+                     &client_ctx_data->npn_protocols,
+                     &client_ctx_data->npn_protocols_len);
+        SSL_CTX_set_next_proto_select_cb(client_ctx, client_npn_cb,
+                                         client_ctx_data);
+    }
+    if (test_ctx->server_alpn_protocols != NULL) {
+        parse_protos(test_ctx->server_alpn_protocols,
+                     &server_ctx_data->alpn_protocols,
+                     &server_ctx_data->alpn_protocols_len);
+        SSL_CTX_set_alpn_select_cb(server_ctx, server_alpn_cb, server_ctx_data);
+    }
+    if (test_ctx->server2_alpn_protocols != NULL) {
+        OPENSSL_assert(server2_ctx != NULL);
+        parse_protos(test_ctx->server2_alpn_protocols,
+                     &server2_ctx_data->alpn_protocols,
+                     &server2_ctx_data->alpn_protocols_len);
+        SSL_CTX_set_alpn_select_cb(server2_ctx, server_alpn_cb, server2_ctx_data);
+    }
+    if (test_ctx->client_alpn_protocols != NULL) {
+        unsigned char *alpn_protos = NULL;
+        size_t alpn_protos_len;
+        parse_protos(test_ctx->client_alpn_protocols,
+                     &alpn_protos, &alpn_protos_len);
+        /* Reversed return value convention... */
+        OPENSSL_assert(SSL_CTX_set_alpn_protos(client_ctx, alpn_protos,
+                                               alpn_protos_len) == 0);
+        OPENSSL_free(alpn_protos);
+    }
 }
 
-/*
- * Configure callbacks and other properties that can't be set directly
- * in the server/client CONF.
- */
+/* Configure per-SSL callbacks and other properties. */
 static void configure_handshake_ssl(SSL *server, SSL *client,
                                     const SSL_TEST_CTX *test_ctx)
 {
@@ -293,21 +469,45 @@ static handshake_status_t handshake_status(peer_status_t last_status,
     return INTERNAL_ERROR;
 }
 
-HANDSHAKE_RESULT do_handshake(SSL_CTX *server_ctx, SSL_CTX *server2_ctx,
-                              SSL_CTX *client_ctx, const SSL_TEST_CTX *test_ctx)
+/* Convert unsigned char buf's that shouldn't contain any NUL-bytes to char. */
+static char *dup_str(const unsigned char *in, size_t len)
+{
+    char *ret;
+
+    if(len == 0)
+        return NULL;
+
+    /* Assert that the string does not contain NUL-bytes. */
+    OPENSSL_assert(OPENSSL_strnlen((const char*)(in), len) == len);
+    ret = OPENSSL_strndup((const char*)(in), len);
+    OPENSSL_assert(ret != NULL);
+    return ret;
+}
+
+HANDSHAKE_RESULT *do_handshake(SSL_CTX *server_ctx, SSL_CTX *server2_ctx,
+                               SSL_CTX *client_ctx, const SSL_TEST_CTX *test_ctx)
 {
     SSL *server, *client;
     BIO *client_to_server, *server_to_client;
     HANDSHAKE_EX_DATA server_ex_data, client_ex_data;
-    HANDSHAKE_RESULT ret;
+    CTX_DATA client_ctx_data, server_ctx_data, server2_ctx_data;
+    HANDSHAKE_RESULT *ret = HANDSHAKE_RESULT_new();
     int client_turn = 1;
     peer_status_t client_status = PEER_RETRY, server_status = PEER_RETRY;
     handshake_status_t status = HANDSHAKE_RETRY;
     unsigned char* tick = NULL;
-    size_t len = 0;
+    size_t tick_len = 0;
     SSL_SESSION* sess = NULL;
+    const unsigned char *proto = NULL;
+    /* API dictates unsigned int rather than size_t. */
+    unsigned int proto_len = 0;
 
-    configure_handshake_ctx(server_ctx, server2_ctx, client_ctx, test_ctx);
+    memset(&server_ctx_data, 0, sizeof(server_ctx_data));
+    memset(&server2_ctx_data, 0, sizeof(server2_ctx_data));
+    memset(&client_ctx_data, 0, sizeof(client_ctx_data));
+
+    configure_handshake_ctx(server_ctx, server2_ctx, client_ctx, test_ctx,
+                            &server_ctx_data, &server2_ctx_data, &client_ctx_data);
 
     server = SSL_new(server_ctx);
     client = SSL_new(client_ctx);
@@ -317,8 +517,8 @@ HANDSHAKE_RESULT do_handshake(SSL_CTX *server_ctx, SSL_CTX *server2_ctx,
 
     memset(&server_ex_data, 0, sizeof(server_ex_data));
     memset(&client_ex_data, 0, sizeof(client_ex_data));
-    memset(&ret, 0, sizeof(ret));
-    ret.result = SSL_TEST_INTERNAL_ERROR;
+
+    ret->result = SSL_TEST_INTERNAL_ERROR;
 
     client_to_server = BIO_new(BIO_s_mem());
     server_to_client = BIO_new(BIO_s_mem());
@@ -370,16 +570,16 @@ HANDSHAKE_RESULT do_handshake(SSL_CTX *server_ctx, SSL_CTX *server2_ctx,
 
         switch (status) {
         case HANDSHAKE_SUCCESS:
-            ret.result = SSL_TEST_SUCCESS;
+            ret->result = SSL_TEST_SUCCESS;
             goto err;
         case CLIENT_ERROR:
-            ret.result = SSL_TEST_CLIENT_FAIL;
+            ret->result = SSL_TEST_CLIENT_FAIL;
             goto err;
         case SERVER_ERROR:
-            ret.result = SSL_TEST_SERVER_FAIL;
+            ret->result = SSL_TEST_SERVER_FAIL;
             goto err;
         case INTERNAL_ERROR:
-            ret.result = SSL_TEST_INTERNAL_ERROR;
+            ret->result = SSL_TEST_INTERNAL_ERROR;
             goto err;
         case HANDSHAKE_RETRY:
             /* Continue. */
@@ -388,21 +588,36 @@ HANDSHAKE_RESULT do_handshake(SSL_CTX *server_ctx, SSL_CTX *server2_ctx,
         }
     }
  err:
-    ret.server_alert_sent = server_ex_data.alert_sent;
-    ret.server_alert_received = client_ex_data.alert_received;
-    ret.client_alert_sent = client_ex_data.alert_sent;
-    ret.client_alert_received = server_ex_data.alert_received;
-    ret.server_protocol = SSL_version(server);
-    ret.client_protocol = SSL_version(client);
-    ret.servername = server_ex_data.servername;
+    ret->server_alert_sent = server_ex_data.alert_sent;
+    ret->server_alert_received = client_ex_data.alert_received;
+    ret->client_alert_sent = client_ex_data.alert_sent;
+    ret->client_alert_received = server_ex_data.alert_received;
+    ret->server_protocol = SSL_version(server);
+    ret->client_protocol = SSL_version(client);
+    ret->servername = server_ex_data.servername;
     if ((sess = SSL_get0_session(client)) != NULL)
-        SSL_SESSION_get0_ticket(sess, &tick, &len);
-    if (tick == NULL || len == 0)
-        ret.session_ticket = SSL_TEST_SESSION_TICKET_NO;
+        SSL_SESSION_get0_ticket(sess, &tick, &tick_len);
+    if (tick == NULL || tick_len == 0)
+        ret->session_ticket = SSL_TEST_SESSION_TICKET_NO;
     else
-        ret.session_ticket = SSL_TEST_SESSION_TICKET_YES;
-    ret.session_ticket_do_not_call = server_ex_data.session_ticket_do_not_call;
+        ret->session_ticket = SSL_TEST_SESSION_TICKET_YES;
+    ret->session_ticket_do_not_call = server_ex_data.session_ticket_do_not_call;
+
+    SSL_get0_next_proto_negotiated(client, &proto, &proto_len);
+    ret->client_npn_negotiated = dup_str(proto, proto_len);
+
+    SSL_get0_next_proto_negotiated(server, &proto, &proto_len);
+    ret->server_npn_negotiated = dup_str(proto, proto_len);
+
+    SSL_get0_alpn_selected(client, &proto, &proto_len);
+    ret->client_alpn_negotiated = dup_str(proto, proto_len);
+
+    SSL_get0_alpn_selected(server, &proto, &proto_len);
+    ret->server_alpn_negotiated = dup_str(proto, proto_len);
 
+    ctx_data_free_data(&server_ctx_data);
+    ctx_data_free_data(&server2_ctx_data);
+    ctx_data_free_data(&client_ctx_data);
     SSL_free(server);
     SSL_free(client);
     return ret;