For TLS 1.3 retrieve previously set certificate index
[openssl.git] / ssl / ssl_lib.c
index ddc2ff78e7d0209230c074a3402e05fcc5f38d34..e4eec4a949a4ca72b6594c0acb606721912ee73b 100644 (file)
@@ -599,7 +599,7 @@ SSL *SSL_new(SSL_CTX *ctx)
     s->ext.ocsp.resp = NULL;
     s->ext.ocsp.resp_len = 0;
     SSL_CTX_up_ref(ctx);
-    s->initial_ctx = ctx;
+    s->session_ctx = ctx;
 #ifndef OPENSSL_NO_EC
     if (ctx->ext.ecpointformats) {
         s->ext.ecpointformats =
@@ -995,7 +995,7 @@ void SSL_free(SSL *s)
     /* Free up if allocated */
 
     OPENSSL_free(s->ext.hostname);
-    SSL_CTX_free(s->initial_ctx);
+    SSL_CTX_free(s->session_ctx);
 #ifndef OPENSSL_NO_EC
     OPENSSL_free(s->ext.ecpointformats);
     OPENSSL_free(s->ext.supportedgroups);
@@ -1716,6 +1716,13 @@ int SSL_shutdown(SSL *s)
 
 int SSL_renegotiate(SSL *s)
 {
+    /*
+     * TODO(TLS1.3): Return an error for now. Perhaps we should do a KeyUpdate
+     * instead when we support that?
+     */
+    if (SSL_IS_TLS13(s))
+        return 0;
+
     if (s->renegotiate == 0)
         s->renegotiate = 1;
 
@@ -1726,6 +1733,13 @@ int SSL_renegotiate(SSL *s)
 
 int SSL_renegotiate_abbreviated(SSL *s)
 {
+    /*
+     * TODO(TLS1.3): Return an error for now. Perhaps we should do a KeyUpdate
+     * instead when we support that?
+     */
+    if (SSL_IS_TLS13(s))
+        return 0;
+
     if (s->renegotiate == 0)
         s->renegotiate = 1;
 
@@ -2267,10 +2281,7 @@ void SSL_get0_next_proto_negotiated(const SSL *s, const unsigned char **data,
  * ServerHello.
  */
 void SSL_CTX_set_npn_advertised_cb(SSL_CTX *ctx,
-                                   int (*cb) (SSL *ssl,
-                                              const unsigned char **out,
-                                              unsigned int *outlen,
-                                              void *arg),
+                                   SSL_CTX_npn_advertised_cb_func cb,
                                    void *arg)
 {
     ctx->ext.npn_advertised_cb = cb;
@@ -2288,11 +2299,7 @@ void SSL_CTX_set_npn_advertised_cb(SSL_CTX *ctx,
  * a value other than SSL_TLSEXT_ERR_OK.
  */
 void SSL_CTX_set_npn_select_cb(SSL_CTX *ctx,
-                               int (*cb) (SSL *s, unsigned char **out,
-                                          unsigned char *outlen,
-                                          const unsigned char *in,
-                                          unsigned int inlen,
-                                          void *arg),
+                               SSL_CTX_npn_select_cb_func cb,
                                void *arg)
 {
     ctx->ext.npn_select_cb = cb;
@@ -2344,12 +2351,8 @@ int SSL_set_alpn_protos(SSL *ssl, const unsigned char *protos,
  * from the client's list of offered protocols.
  */
 void SSL_CTX_set_alpn_select_cb(SSL_CTX *ctx,
-                                int (*cb) (SSL *ssl,
-                                           const unsigned char **out,
-                                           unsigned char *outlen,
-                                           const unsigned char *in,
-                                           unsigned int inlen,
-                                           void *arg), void *arg)
+                                SSL_CTX_alpn_select_cb_func cb,
+                                void *arg)
 {
     ctx->ext.alpn_select_cb = cb;
     ctx->ext.alpn_select_cb_arg = arg;
@@ -2834,6 +2837,15 @@ int ssl_check_srvr_ecc_cert_and_alg(X509 *x, SSL *s)
 static int ssl_get_server_cert_index(const SSL *s)
 {
     int idx;
+
+    if (SSL_IS_TLS13(s)) {
+        if (s->s3->tmp.sigalg == NULL) {
+            SSLerr(SSL_F_SSL_GET_SERVER_CERT_INDEX, ERR_R_INTERNAL_ERROR);
+            return -1;
+        }
+        return s->s3->tmp.cert_idx;
+    }
+
     idx = ssl_cipher_get_cert_index(s->s3->tmp.new_cipher);
     if (idx == SSL_PKEY_RSA_ENC && !s->cert->pkeys[SSL_PKEY_RSA_ENC].x509)
         idx = SSL_PKEY_RSA_SIGN;
@@ -3092,7 +3104,7 @@ int SSL_do_handshake(SSL *s)
         return -1;
     }
 
-    s->method->ssl_renegotiate_check(s);
+    s->method->ssl_renegotiate_check(s, 0);
 
     if (SSL_in_init(s) || SSL_in_before(s)) {
         if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
@@ -3474,7 +3486,7 @@ SSL_CTX *SSL_set_SSL_CTX(SSL *ssl, SSL_CTX *ctx)
     if (ssl->ctx == ctx)
         return ssl->ctx;
     if (ctx == NULL)
-        ctx = ssl->initial_ctx;
+        ctx = ssl->session_ctx;
     new_cert = ssl_cert_dup(ctx->cert);
     if (new_cert == NULL) {
         return NULL;
@@ -3726,46 +3738,22 @@ const char *SSL_get_psk_identity(const SSL *s)
     return (s->session->psk_identity);
 }
 
-void SSL_set_psk_client_callback(SSL *s,
-                                 unsigned int (*cb) (SSL *ssl,
-                                                     const char *hint,
-                                                     char *identity,
-                                                     unsigned int
-                                                     max_identity_len,
-                                                     unsigned char *psk,
-                                                     unsigned int max_psk_len))
+void SSL_set_psk_client_callback(SSL *s, SSL_psk_client_cb_func cb)
 {
     s->psk_client_callback = cb;
 }
 
-void SSL_CTX_set_psk_client_callback(SSL_CTX *ctx,
-                                     unsigned int (*cb) (SSL *ssl,
-                                                         const char *hint,
-                                                         char *identity,
-                                                         unsigned int
-                                                         max_identity_len,
-                                                         unsigned char *psk,
-                                                         unsigned int
-                                                         max_psk_len))
+void SSL_CTX_set_psk_client_callback(SSL_CTX *ctx, SSL_psk_client_cb_func cb)
 {
     ctx->psk_client_callback = cb;
 }
 
-void SSL_set_psk_server_callback(SSL *s,
-                                 unsigned int (*cb) (SSL *ssl,
-                                                     const char *identity,
-                                                     unsigned char *psk,
-                                                     unsigned int max_psk_len))
+void SSL_set_psk_server_callback(SSL *s, SSL_psk_server_cb_func cb)
 {
     s->psk_server_callback = cb;
 }
 
-void SSL_CTX_set_psk_server_callback(SSL_CTX *ctx,
-                                     unsigned int (*cb) (SSL *ssl,
-                                                         const char *identity,
-                                                         unsigned char *psk,
-                                                         unsigned int
-                                                         max_psk_len))
+void SSL_CTX_set_psk_server_callback(SSL_CTX *ctx, SSL_psk_server_cb_func cb)
 {
     ctx->psk_server_callback = cb;
 }
@@ -3807,8 +3795,8 @@ void SSL_set_not_resumable_session_callback(SSL *ssl,
 /*
  * Allocates new EVP_MD_CTX and sets pointer to it into given pointer
  * variable, freeing EVP_MD_CTX previously stored in that variable, if any.
- * If EVP_MD pointer is passed, initializes ctx with this md Returns newly
- * allocated ctx;
+ * If EVP_MD pointer is passed, initializes ctx with this md.
+ * Returns the newly allocated ctx;
  */
 
 EVP_MD_CTX *ssl_replace_hash(EVP_MD_CTX **hash, const EVP_MD *md)
@@ -4373,3 +4361,99 @@ const CTLOG_STORE *SSL_CTX_get0_ctlog_store(const SSL_CTX *ctx)
 }
 
 #endif
+
+void SSL_CTX_set_keylog_callback(SSL_CTX *ctx, SSL_CTX_keylog_cb_func cb)
+{
+    ctx->keylog_callback = cb;
+}
+
+SSL_CTX_keylog_cb_func SSL_CTX_get_keylog_callback(const SSL_CTX *ctx)
+{
+    return ctx->keylog_callback;
+}
+
+static int nss_keylog_int(const char *prefix,
+                          SSL *ssl,
+                          const uint8_t *parameter_1,
+                          size_t parameter_1_len,
+                          const uint8_t *parameter_2,
+                          size_t parameter_2_len)
+{
+    char *out = NULL;
+    char *cursor = NULL;
+    size_t out_len = 0;
+    size_t i;
+    size_t prefix_len;
+
+    if (ssl->ctx->keylog_callback == NULL) return 1;
+
+    /*
+     * Our output buffer will contain the following strings, rendered with
+     * space characters in between, terminated by a NULL character: first the
+     * prefix, then the first parameter, then the second parameter. The
+     * meaning of each parameter depends on the specific key material being
+     * logged. Note that the first and second parameters are encoded in
+     * hexadecimal, so we need a buffer that is twice their lengths.
+     */
+    prefix_len = strlen(prefix);
+    out_len = prefix_len + (2*parameter_1_len) + (2*parameter_2_len) + 3;
+    if ((out = cursor = OPENSSL_malloc(out_len)) == NULL) {
+        SSLerr(SSL_F_NSS_KEYLOG_INT, ERR_R_MALLOC_FAILURE);
+        return 0;
+    }
+
+    strcpy(cursor, prefix);
+    cursor += prefix_len;
+    *cursor++ = ' ';
+
+    for (i = 0; i < parameter_1_len; i++) {
+        sprintf(cursor, "%02x", parameter_1[i]);
+        cursor += 2;
+    }
+    *cursor++ = ' ';
+
+    for (i = 0; i < parameter_2_len; i++) {
+        sprintf(cursor, "%02x", parameter_2[i]);
+        cursor += 2;
+    }
+    *cursor = '\0';
+
+    ssl->ctx->keylog_callback(ssl, (const char *)out);
+    OPENSSL_free(out);
+    return 1;
+
+}
+
+int ssl_log_rsa_client_key_exchange(SSL *ssl,
+                                    const uint8_t *encrypted_premaster,
+                                    size_t encrypted_premaster_len,
+                                    const uint8_t *premaster,
+                                    size_t premaster_len)
+{
+    if (encrypted_premaster_len < 8) {
+        SSLerr(SSL_F_SSL_LOG_RSA_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
+
+    /* We only want the first 8 bytes of the encrypted premaster as a tag. */
+    return nss_keylog_int("RSA",
+                          ssl,
+                          encrypted_premaster,
+                          8,
+                          premaster,
+                          premaster_len);
+}
+
+int ssl_log_secret(SSL *ssl,
+                   const char *label,
+                   const uint8_t *secret,
+                   size_t secret_len)
+{
+    return nss_keylog_int(label,
+                          ssl,
+                          ssl->s3->client_random,
+                          SSL3_RANDOM_SIZE,
+                          secret,
+                          secret_len);
+}
+