add ssl_has_cert
[openssl.git] / ssl / ssl_lib.c
index 65e3ba1824dbc3b8c252982582a322ddf3fe3d70..c92875f2d9dfe8ad6bf356752fbd6a1c9c172b96 100644 (file)
@@ -589,49 +589,46 @@ SSL *SSL_new(SSL_CTX *ctx)
 
     SSL_CTX_up_ref(ctx);
     s->ctx = ctx;
-    s->tlsext_debug_cb = 0;
-    s->tlsext_debug_arg = NULL;
-    s->tlsext_ticket_expected = 0;
-    s->tlsext_status_type = ctx->tlsext_status_type;
-    s->tlsext_status_expected = 0;
-    s->tlsext_ocsp_ids = NULL;
-    s->tlsext_ocsp_exts = NULL;
-    s->tlsext_ocsp_resp = NULL;
-    s->tlsext_ocsp_resplen = 0;
+    s->ext.debug_cb = 0;
+    s->ext.debug_arg = NULL;
+    s->ext.ticket_expected = 0;
+    s->ext.status_type = ctx->ext.status_type;
+    s->ext.status_expected = 0;
+    s->ext.ocsp.ids = NULL;
+    s->ext.ocsp.exts = NULL;
+    s->ext.ocsp.resp = NULL;
+    s->ext.ocsp.resp_len = 0;
     SSL_CTX_up_ref(ctx);
-    s->initial_ctx = ctx;
+    s->session_ctx = ctx;
 #ifndef OPENSSL_NO_EC
-    if (ctx->tlsext_ecpointformatlist) {
-        s->tlsext_ecpointformatlist =
-            OPENSSL_memdup(ctx->tlsext_ecpointformatlist,
-                           ctx->tlsext_ecpointformatlist_length);
-        if (!s->tlsext_ecpointformatlist)
+    if (ctx->ext.ecpointformats) {
+        s->ext.ecpointformats =
+            OPENSSL_memdup(ctx->ext.ecpointformats,
+                           ctx->ext.ecpointformats_len);
+        if (!s->ext.ecpointformats)
             goto err;
-        s->tlsext_ecpointformatlist_length =
-            ctx->tlsext_ecpointformatlist_length;
-    }
-    if (ctx->tlsext_supportedgroupslist) {
-        s->tlsext_supportedgroupslist =
-            OPENSSL_memdup(ctx->tlsext_supportedgroupslist,
-                           ctx->tlsext_supportedgroupslist_length);
-        if (!s->tlsext_supportedgroupslist)
+        s->ext.ecpointformats_len =
+            ctx->ext.ecpointformats_len;
+    }
+    if (ctx->ext.supportedgroups) {
+        s->ext.supportedgroups =
+            OPENSSL_memdup(ctx->ext.supportedgroups,
+                           ctx->ext.supportedgroups_len);
+        if (!s->ext.supportedgroups)
             goto err;
-        s->tlsext_supportedgroupslist_length =
-            ctx->tlsext_supportedgroupslist_length;
+        s->ext.supportedgroups_len = ctx->ext.supportedgroups_len;
     }
 #endif
 #ifndef OPENSSL_NO_NEXTPROTONEG
-    s->next_proto_negotiated = NULL;
+    s->ext.npn = NULL;
 #endif
 
-    if (s->ctx->alpn_client_proto_list) {
-        s->alpn_client_proto_list =
-            OPENSSL_malloc(s->ctx->alpn_client_proto_list_len);
-        if (s->alpn_client_proto_list == NULL)
+    if (s->ctx->ext.alpn) {
+        s->ext.alpn = OPENSSL_malloc(s->ctx->ext.alpn_len);
+        if (s->ext.alpn == NULL)
             goto err;
-        memcpy(s->alpn_client_proto_list, s->ctx->alpn_client_proto_list,
-               s->ctx->alpn_client_proto_list_len);
-        s->alpn_client_proto_list_len = s->ctx->alpn_client_proto_list_len;
+        memcpy(s->ext.alpn, s->ctx->ext.alpn, s->ctx->ext.alpn_len);
+        s->ext.alpn_len = s->ctx->ext.alpn_len;
     }
 
     s->verified_chain = NULL;
@@ -838,7 +835,7 @@ int SSL_dane_enable(SSL *s, const char *basedomain)
      * accepts them and disables host name checks.  To avoid side-effects with
      * invalid input, set the SNI name first.
      */
-    if (s->tlsext_hostname == NULL) {
+    if (s->ext.hostname == NULL) {
         if (!SSL_set_tlsext_host_name(s, basedomain)) {
             SSLerr(SSL_F_SSL_DANE_ENABLE, SSL_R_ERROR_SETTING_TLSA_BASE_DOMAIN);
             return -1;
@@ -997,22 +994,22 @@ void SSL_free(SSL *s)
     ssl_cert_free(s->cert);
     /* Free up if allocated */
 
-    OPENSSL_free(s->tlsext_hostname);
-    SSL_CTX_free(s->initial_ctx);
+    OPENSSL_free(s->ext.hostname);
+    SSL_CTX_free(s->session_ctx);
 #ifndef OPENSSL_NO_EC
-    OPENSSL_free(s->tlsext_ecpointformatlist);
-    OPENSSL_free(s->tlsext_supportedgroupslist);
+    OPENSSL_free(s->ext.ecpointformats);
+    OPENSSL_free(s->ext.supportedgroups);
 #endif                          /* OPENSSL_NO_EC */
-    sk_X509_EXTENSION_pop_free(s->tlsext_ocsp_exts, X509_EXTENSION_free);
+    sk_X509_EXTENSION_pop_free(s->ext.ocsp.exts, X509_EXTENSION_free);
 #ifndef OPENSSL_NO_OCSP
-    sk_OCSP_RESPID_pop_free(s->tlsext_ocsp_ids, OCSP_RESPID_free);
+    sk_OCSP_RESPID_pop_free(s->ext.ocsp.ids, OCSP_RESPID_free);
 #endif
 #ifndef OPENSSL_NO_CT
     SCT_LIST_free(s->scts);
-    OPENSSL_free(s->tlsext_scts);
+    OPENSSL_free(s->ext.scts);
 #endif
-    OPENSSL_free(s->tlsext_ocsp_resp);
-    OPENSSL_free(s->alpn_client_proto_list);
+    OPENSSL_free(s->ext.ocsp.resp);
+    OPENSSL_free(s->ext.alpn);
 
     sk_X509_NAME_pop_free(s->client_CA, X509_NAME_free);
 
@@ -1028,7 +1025,7 @@ void SSL_free(SSL *s)
     ASYNC_WAIT_CTX_free(s->waitctx);
 
 #if !defined(OPENSSL_NO_NEXTPROTONEG)
-    OPENSSL_free(s->next_proto_negotiated);
+    OPENSSL_free(s->ext.npn);
 #endif
 
 #ifndef OPENSSL_NO_SRTP
@@ -1719,6 +1716,13 @@ int SSL_shutdown(SSL *s)
 
 int SSL_renegotiate(SSL *s)
 {
+    /*
+     * TODO(TLS1.3): Return an error for now. Perhaps we should do a KeyUpdate
+     * instead when we support that?
+     */
+    if (SSL_IS_TLS13(s))
+        return 0;
+
     if (s->renegotiate == 0)
         s->renegotiate = 1;
 
@@ -1729,6 +1733,13 @@ int SSL_renegotiate(SSL *s)
 
 int SSL_renegotiate_abbreviated(SSL *s)
 {
+    /*
+     * TODO(TLS1.3): Return an error for now. Perhaps we should do a KeyUpdate
+     * instead when we support that?
+     */
+    if (SSL_IS_TLS13(s))
+        return 0;
+
     if (s->renegotiate == 0)
         s->renegotiate = 1;
 
@@ -2168,15 +2179,15 @@ const char *SSL_get_servername(const SSL *s, const int type)
     if (type != TLSEXT_NAMETYPE_host_name)
         return NULL;
 
-    return s->session && !s->tlsext_hostname ?
-        s->session->tlsext_hostname : s->tlsext_hostname;
+    return s->session && !s->ext.hostname ?
+        s->session->ext.hostname : s->ext.hostname;
 }
 
 int SSL_get_servername_type(const SSL *s)
 {
     if (s->session
-        && (!s->tlsext_hostname ? s->session->
-            tlsext_hostname : s->tlsext_hostname))
+        && (!s->ext.hostname ? s->session->
+            ext.hostname : s->ext.hostname))
         return TLSEXT_NAMETYPE_host_name;
     return -1;
 }
@@ -2251,16 +2262,16 @@ int SSL_select_next_proto(unsigned char **out, unsigned char *outlen,
 void SSL_get0_next_proto_negotiated(const SSL *s, const unsigned char **data,
                                     unsigned *len)
 {
-    *data = s->next_proto_negotiated;
+    *data = s->ext.npn;
     if (!*data) {
         *len = 0;
     } else {
-        *len = (unsigned int)s->next_proto_negotiated_len;
+        *len = (unsigned int)s->ext.npn_len;
     }
 }
 
 /*
- * SSL_CTX_set_next_protos_advertised_cb sets a callback that is called when
+ * SSL_CTX_set_npn_advertised_cb sets a callback that is called when
  * a TLS server needs a list of supported protocols for Next Protocol
  * Negotiation. The returned list must be in wire format.  The list is
  * returned by setting |out| to point to it and |outlen| to its length. This
@@ -2269,15 +2280,12 @@ void SSL_get0_next_proto_negotiated(const SSL *s, const unsigned char **data,
  * wishes to advertise. Otherwise, no such extension will be included in the
  * ServerHello.
  */
-void SSL_CTX_set_next_protos_advertised_cb(SSL_CTX *ctx,
-                                           int (*cb) (SSL *ssl,
-                                                      const unsigned char
-                                                      **out,
-                                                      unsigned int *outlen,
-                                                      void *arg), void *arg)
+void SSL_CTX_set_npn_advertised_cb(SSL_CTX *ctx,
+                                   SSL_CTX_npn_advertised_cb_func cb,
+                                   void *arg)
 {
-    ctx->next_protos_advertised_cb = cb;
-    ctx->next_protos_advertised_cb_arg = arg;
+    ctx->ext.npn_advertised_cb = cb;
+    ctx->ext.npn_advertised_cb_arg = arg;
 }
 
 /*
@@ -2290,15 +2298,12 @@ void SSL_CTX_set_next_protos_advertised_cb(SSL_CTX *ctx,
  * select a protocol. It is fatal to the connection if this callback returns
  * a value other than SSL_TLSEXT_ERR_OK.
  */
-void SSL_CTX_set_next_proto_select_cb(SSL_CTX *ctx,
-                                      int (*cb) (SSL *s, unsigned char **out,
-                                                 unsigned char *outlen,
-                                                 const unsigned char *in,
-                                                 unsigned int inlen,
-                                                 void *arg), void *arg)
+void SSL_CTX_set_npn_select_cb(SSL_CTX *ctx,
+                               SSL_CTX_npn_select_cb_func cb,
+                               void *arg)
 {
-    ctx->next_proto_select_cb = cb;
-    ctx->next_proto_select_cb_arg = arg;
+    ctx->ext.npn_select_cb = cb;
+    ctx->ext.npn_select_cb_arg = arg;
 }
 #endif
 
@@ -2310,13 +2315,13 @@ void SSL_CTX_set_next_proto_select_cb(SSL_CTX *ctx,
 int SSL_CTX_set_alpn_protos(SSL_CTX *ctx, const unsigned char *protos,
                             unsigned int protos_len)
 {
-    OPENSSL_free(ctx->alpn_client_proto_list);
-    ctx->alpn_client_proto_list = OPENSSL_memdup(protos, protos_len);
-    if (ctx->alpn_client_proto_list == NULL) {
+    OPENSSL_free(ctx->ext.alpn);
+    ctx->ext.alpn = OPENSSL_memdup(protos, protos_len);
+    if (ctx->ext.alpn == NULL) {
         SSLerr(SSL_F_SSL_CTX_SET_ALPN_PROTOS, ERR_R_MALLOC_FAILURE);
         return 1;
     }
-    ctx->alpn_client_proto_list_len = protos_len;
+    ctx->ext.alpn_len = protos_len;
 
     return 0;
 }
@@ -2329,13 +2334,13 @@ int SSL_CTX_set_alpn_protos(SSL_CTX *ctx, const unsigned char *protos,
 int SSL_set_alpn_protos(SSL *ssl, const unsigned char *protos,
                         unsigned int protos_len)
 {
-    OPENSSL_free(ssl->alpn_client_proto_list);
-    ssl->alpn_client_proto_list = OPENSSL_memdup(protos, protos_len);
-    if (ssl->alpn_client_proto_list == NULL) {
+    OPENSSL_free(ssl->ext.alpn);
+    ssl->ext.alpn = OPENSSL_memdup(protos, protos_len);
+    if (ssl->ext.alpn == NULL) {
         SSLerr(SSL_F_SSL_SET_ALPN_PROTOS, ERR_R_MALLOC_FAILURE);
         return 1;
     }
-    ssl->alpn_client_proto_list_len = protos_len;
+    ssl->ext.alpn_len = protos_len;
 
     return 0;
 }
@@ -2346,15 +2351,11 @@ int SSL_set_alpn_protos(SSL *ssl, const unsigned char *protos,
  * from the client's list of offered protocols.
  */
 void SSL_CTX_set_alpn_select_cb(SSL_CTX *ctx,
-                                int (*cb) (SSL *ssl,
-                                           const unsigned char **out,
-                                           unsigned char *outlen,
-                                           const unsigned char *in,
-                                           unsigned int inlen,
-                                           void *arg), void *arg)
+                                SSL_CTX_alpn_select_cb_func cb,
+                                void *arg)
 {
-    ctx->alpn_select_cb = cb;
-    ctx->alpn_select_cb_arg = arg;
+    ctx->ext.alpn_select_cb = cb;
+    ctx->ext.alpn_select_cb_arg = arg;
 }
 
 /*
@@ -2390,13 +2391,21 @@ int SSL_export_keying_material(SSL *s, unsigned char *out, size_t olen,
 
 static unsigned long ssl_session_hash(const SSL_SESSION *a)
 {
+    const unsigned char *session_id = a->session_id;
     unsigned long l;
+    unsigned char tmp_storage[4];
+
+    if (a->session_id_length < sizeof(tmp_storage)) {
+        memset(tmp_storage, 0, sizeof(tmp_storage));
+        memcpy(tmp_storage, a->session_id, a->session_id_length);
+        session_id = tmp_storage;
+    }
 
     l = (unsigned long)
-        ((unsigned int)a->session_id[0]) |
-        ((unsigned int)a->session_id[1] << 8L) |
-        ((unsigned long)a->session_id[2] << 16L) |
-        ((unsigned long)a->session_id[3] << 24L);
+        ((unsigned long)session_id[0]) |
+        ((unsigned long)session_id[1] << 8L) |
+        ((unsigned long)session_id[2] << 16L) |
+        ((unsigned long)session_id[3] << 24L);
     return (l);
 }
 
@@ -2513,12 +2522,12 @@ SSL_CTX *SSL_CTX_new(const SSL_METHOD *meth)
     ret->split_send_fragment = SSL3_RT_MAX_PLAIN_LENGTH;
 
     /* Setup RFC5077 ticket keys */
-    if ((RAND_bytes(ret->tlsext_tick_key_name,
-                    sizeof(ret->tlsext_tick_key_name)) <= 0)
-        || (RAND_bytes(ret->tlsext_tick_hmac_key,
-                       sizeof(ret->tlsext_tick_hmac_key)) <= 0)
-        || (RAND_bytes(ret->tlsext_tick_aes_key,
-                       sizeof(ret->tlsext_tick_aes_key)) <= 0))
+    if ((RAND_bytes(ret->ext.tick_key_name,
+                    sizeof(ret->ext.tick_key_name)) <= 0)
+        || (RAND_bytes(ret->ext.tick_hmac_key,
+                       sizeof(ret->ext.tick_hmac_key)) <= 0)
+        || (RAND_bytes(ret->ext.tick_aes_key,
+                       sizeof(ret->ext.tick_aes_key)) <= 0))
         ret->options |= SSL_OP_NO_TICKET;
 
 #ifndef OPENSSL_NO_SRP
@@ -2556,7 +2565,7 @@ SSL_CTX *SSL_CTX_new(const SSL_METHOD *meth)
      */
     ret->options |= SSL_OP_NO_COMPRESSION;
 
-    ret->tlsext_status_type = TLSEXT_STATUSTYPE_nothing;
+    ret->ext.status_type = TLSEXT_STATUSTYPE_nothing;
 
     return ret;
  err:
@@ -2629,10 +2638,10 @@ void SSL_CTX_free(SSL_CTX *a)
 #endif
 
 #ifndef OPENSSL_NO_EC
-    OPENSSL_free(a->tlsext_ecpointformatlist);
-    OPENSSL_free(a->tlsext_supportedgroupslist);
+    OPENSSL_free(a->ext.ecpointformats);
+    OPENSSL_free(a->ext.supportedgroups);
 #endif
-    OPENSSL_free(a->alpn_client_proto_list);
+    OPENSSL_free(a->ext.alpn);
 
     CRYPTO_THREAD_lock_free(a->lock);
 
@@ -2711,16 +2720,12 @@ void SSL_set_cert_cb(SSL *s, int (*cb) (SSL *ssl, void *arg), void *arg)
 
 void ssl_set_masks(SSL *s)
 {
-#if !defined(OPENSSL_NO_EC) || !defined(OPENSSL_NO_GOST)
-    CERT_PKEY *cpk;
-#endif
     CERT *c = s->cert;
     uint32_t *pvalid = s->s3->tmp.valid_flags;
     int rsa_enc, rsa_sign, dh_tmp, dsa_sign;
     unsigned long mask_k, mask_a;
 #ifndef OPENSSL_NO_EC
     int have_ecc_cert, ecdsa_ok;
-    X509 *x = NULL;
 #endif
     if (c == NULL)
         return;
@@ -2731,8 +2736,8 @@ void ssl_set_masks(SSL *s)
     dh_tmp = 0;
 #endif
 
-    rsa_enc = pvalid[SSL_PKEY_RSA_ENC] & CERT_PKEY_VALID;
-    rsa_sign = pvalid[SSL_PKEY_RSA_SIGN] & CERT_PKEY_SIGN;
+    rsa_enc = pvalid[SSL_PKEY_RSA] & CERT_PKEY_VALID;
+    rsa_sign = pvalid[SSL_PKEY_RSA] & CERT_PKEY_SIGN;
     dsa_sign = pvalid[SSL_PKEY_DSA_SIGN] & CERT_PKEY_SIGN;
 #ifndef OPENSSL_NO_EC
     have_ecc_cert = pvalid[SSL_PKEY_ECC] & CERT_PKEY_VALID;
@@ -2746,18 +2751,15 @@ void ssl_set_masks(SSL *s)
 #endif
 
 #ifndef OPENSSL_NO_GOST
-    cpk = &(c->pkeys[SSL_PKEY_GOST12_512]);
-    if (cpk->x509 != NULL && cpk->privatekey != NULL) {
+    if (ssl_has_cert(s, SSL_PKEY_GOST12_512)) {
         mask_k |= SSL_kGOST;
         mask_a |= SSL_aGOST12;
     }
-    cpk = &(c->pkeys[SSL_PKEY_GOST12_256]);
-    if (cpk->x509 != NULL && cpk->privatekey != NULL) {
+    if (ssl_has_cert(s, SSL_PKEY_GOST12_256)) {
         mask_k |= SSL_kGOST;
         mask_a |= SSL_aGOST12;
     }
-    cpk = &(c->pkeys[SSL_PKEY_GOST01]);
-    if (cpk->x509 != NULL && cpk->privatekey != NULL) {
+    if (ssl_has_cert(s, SSL_PKEY_GOST01)) {
         mask_k |= SSL_kGOST;
         mask_a |= SSL_aGOST01;
     }
@@ -2786,9 +2788,7 @@ void ssl_set_masks(SSL *s)
 #ifndef OPENSSL_NO_EC
     if (have_ecc_cert) {
         uint32_t ex_kusage;
-        cpk = &c->pkeys[SSL_PKEY_ECC];
-        x = cpk->x509;
-        ex_kusage = X509_get_key_usage(x);
+        ex_kusage = X509_get_key_usage(c->pkeys[SSL_PKEY_ECC].x509);
         ecdsa_ok = ex_kusage & X509v3_KU_DIGITAL_SIGNATURE;
         if (!(pvalid[SSL_PKEY_ECC] & CERT_PKEY_SIGN))
             ecdsa_ok = 0;
@@ -2836,9 +2836,16 @@ int ssl_check_srvr_ecc_cert_and_alg(X509 *x, SSL *s)
 static int ssl_get_server_cert_index(const SSL *s)
 {
     int idx;
+
+    if (SSL_IS_TLS13(s)) {
+        if (s->s3->tmp.sigalg == NULL) {
+            SSLerr(SSL_F_SSL_GET_SERVER_CERT_INDEX, ERR_R_INTERNAL_ERROR);
+            return -1;
+        }
+        return s->s3->tmp.cert_idx;
+    }
+
     idx = ssl_cipher_get_cert_index(s->s3->tmp.new_cipher);
-    if (idx == SSL_PKEY_RSA_ENC && !s->cert->pkeys[SSL_PKEY_RSA_ENC].x509)
-        idx = SSL_PKEY_RSA_SIGN;
     if (idx == SSL_PKEY_GOST_EC) {
         if (s->cert->pkeys[SSL_PKEY_GOST12_512].x509)
             idx = SSL_PKEY_GOST12_512;
@@ -2884,15 +2891,12 @@ EVP_PKEY *ssl_get_sign_pkey(SSL *s, const SSL_CIPHER *cipher,
     alg_a = cipher->algorithm_auth;
     c = s->cert;
 
-    if ((alg_a & SSL_aDSS) && (c->pkeys[SSL_PKEY_DSA_SIGN].privatekey != NULL))
+    if (alg_a & SSL_aDSS && c->pkeys[SSL_PKEY_DSA_SIGN].privatekey != NULL)
         idx = SSL_PKEY_DSA_SIGN;
-    else if (alg_a & SSL_aRSA) {
-        if (c->pkeys[SSL_PKEY_RSA_SIGN].privatekey != NULL)
-            idx = SSL_PKEY_RSA_SIGN;
-        else if (c->pkeys[SSL_PKEY_RSA_ENC].privatekey != NULL)
-            idx = SSL_PKEY_RSA_ENC;
-    } else if ((alg_a & SSL_aECDSA) &&
-               (c->pkeys[SSL_PKEY_ECC].privatekey != NULL))
+    else if (alg_a & SSL_aRSA && c->pkeys[SSL_PKEY_RSA].privatekey != NULL)
+            idx = SSL_PKEY_RSA;
+    else if (alg_a & SSL_aECDSA &&
+             c->pkeys[SSL_PKEY_ECC].privatekey != NULL)
         idx = SSL_PKEY_ECC;
     if (idx == -1) {
         SSLerr(SSL_F_SSL_GET_SIGN_PKEY, ERR_R_INTERNAL_ERROR);
@@ -3094,7 +3098,7 @@ int SSL_do_handshake(SSL *s)
         return -1;
     }
 
-    s->method->ssl_renegotiate_check(s);
+    s->method->ssl_renegotiate_check(s, 0);
 
     if (SSL_in_init(s) || SSL_in_before(s)) {
         if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
@@ -3476,7 +3480,7 @@ SSL_CTX *SSL_set_SSL_CTX(SSL *ssl, SSL_CTX *ctx)
     if (ssl->ctx == ctx)
         return ssl->ctx;
     if (ctx == NULL)
-        ctx = ssl->initial_ctx;
+        ctx = ssl->session_ctx;
     new_cert = ssl_cert_dup(ctx->cert);
     if (new_cert == NULL) {
         return NULL;
@@ -3728,46 +3732,22 @@ const char *SSL_get_psk_identity(const SSL *s)
     return (s->session->psk_identity);
 }
 
-void SSL_set_psk_client_callback(SSL *s,
-                                 unsigned int (*cb) (SSL *ssl,
-                                                     const char *hint,
-                                                     char *identity,
-                                                     unsigned int
-                                                     max_identity_len,
-                                                     unsigned char *psk,
-                                                     unsigned int max_psk_len))
+void SSL_set_psk_client_callback(SSL *s, SSL_psk_client_cb_func cb)
 {
     s->psk_client_callback = cb;
 }
 
-void SSL_CTX_set_psk_client_callback(SSL_CTX *ctx,
-                                     unsigned int (*cb) (SSL *ssl,
-                                                         const char *hint,
-                                                         char *identity,
-                                                         unsigned int
-                                                         max_identity_len,
-                                                         unsigned char *psk,
-                                                         unsigned int
-                                                         max_psk_len))
+void SSL_CTX_set_psk_client_callback(SSL_CTX *ctx, SSL_psk_client_cb_func cb)
 {
     ctx->psk_client_callback = cb;
 }
 
-void SSL_set_psk_server_callback(SSL *s,
-                                 unsigned int (*cb) (SSL *ssl,
-                                                     const char *identity,
-                                                     unsigned char *psk,
-                                                     unsigned int max_psk_len))
+void SSL_set_psk_server_callback(SSL *s, SSL_psk_server_cb_func cb)
 {
     s->psk_server_callback = cb;
 }
 
-void SSL_CTX_set_psk_server_callback(SSL_CTX *ctx,
-                                     unsigned int (*cb) (SSL *ssl,
-                                                         const char *identity,
-                                                         unsigned char *psk,
-                                                         unsigned int
-                                                         max_psk_len))
+void SSL_CTX_set_psk_server_callback(SSL_CTX *ctx, SSL_psk_server_cb_func cb)
 {
     ctx->psk_server_callback = cb;
 }
@@ -3809,8 +3789,8 @@ void SSL_set_not_resumable_session_callback(SSL *ssl,
 /*
  * Allocates new EVP_MD_CTX and sets pointer to it into given pointer
  * variable, freeing EVP_MD_CTX previously stored in that variable, if any.
- * If EVP_MD pointer is passed, initializes ctx with this md Returns newly
- * allocated ctx;
+ * If EVP_MD pointer is passed, initializes ctx with this md.
+ * Returns the newly allocated ctx;
  */
 
 EVP_MD_CTX *ssl_replace_hash(EVP_MD_CTX **hash, const EVP_MD *md)
@@ -4040,9 +4020,9 @@ static int ct_extract_tls_extension_scts(SSL *s)
 {
     int scts_extracted = 0;
 
-    if (s->tlsext_scts != NULL) {
-        const unsigned char *p = s->tlsext_scts;
-        STACK_OF(SCT) *scts = o2i_SCT_LIST(NULL, &p, s->tlsext_scts_len);
+    if (s->ext.scts != NULL) {
+        const unsigned char *p = s->ext.scts;
+        STACK_OF(SCT) *scts = o2i_SCT_LIST(NULL, &p, s->ext.scts_len);
 
         scts_extracted = ct_move_scts(&s->scts, scts, SCT_SOURCE_TLS_EXTENSION);
 
@@ -4070,11 +4050,11 @@ static int ct_extract_ocsp_response_scts(SSL *s)
     STACK_OF(SCT) *scts = NULL;
     int i;
 
-    if (s->tlsext_ocsp_resp == NULL || s->tlsext_ocsp_resplen == 0)
+    if (s->ext.ocsp.resp == NULL || s->ext.ocsp.resp_len == 0)
         goto err;
 
-    p = s->tlsext_ocsp_resp;
-    rsp = d2i_OCSP_RESPONSE(NULL, &p, (int)s->tlsext_ocsp_resplen);
+    p = s->ext.ocsp.resp;
+    rsp = d2i_OCSP_RESPONSE(NULL, &p, (int)s->ext.ocsp.resp_len);
     if (rsp == NULL)
         goto err;
 
@@ -4375,3 +4355,99 @@ const CTLOG_STORE *SSL_CTX_get0_ctlog_store(const SSL_CTX *ctx)
 }
 
 #endif
+
+void SSL_CTX_set_keylog_callback(SSL_CTX *ctx, SSL_CTX_keylog_cb_func cb)
+{
+    ctx->keylog_callback = cb;
+}
+
+SSL_CTX_keylog_cb_func SSL_CTX_get_keylog_callback(const SSL_CTX *ctx)
+{
+    return ctx->keylog_callback;
+}
+
+static int nss_keylog_int(const char *prefix,
+                          SSL *ssl,
+                          const uint8_t *parameter_1,
+                          size_t parameter_1_len,
+                          const uint8_t *parameter_2,
+                          size_t parameter_2_len)
+{
+    char *out = NULL;
+    char *cursor = NULL;
+    size_t out_len = 0;
+    size_t i;
+    size_t prefix_len;
+
+    if (ssl->ctx->keylog_callback == NULL) return 1;
+
+    /*
+     * Our output buffer will contain the following strings, rendered with
+     * space characters in between, terminated by a NULL character: first the
+     * prefix, then the first parameter, then the second parameter. The
+     * meaning of each parameter depends on the specific key material being
+     * logged. Note that the first and second parameters are encoded in
+     * hexadecimal, so we need a buffer that is twice their lengths.
+     */
+    prefix_len = strlen(prefix);
+    out_len = prefix_len + (2*parameter_1_len) + (2*parameter_2_len) + 3;
+    if ((out = cursor = OPENSSL_malloc(out_len)) == NULL) {
+        SSLerr(SSL_F_NSS_KEYLOG_INT, ERR_R_MALLOC_FAILURE);
+        return 0;
+    }
+
+    strcpy(cursor, prefix);
+    cursor += prefix_len;
+    *cursor++ = ' ';
+
+    for (i = 0; i < parameter_1_len; i++) {
+        sprintf(cursor, "%02x", parameter_1[i]);
+        cursor += 2;
+    }
+    *cursor++ = ' ';
+
+    for (i = 0; i < parameter_2_len; i++) {
+        sprintf(cursor, "%02x", parameter_2[i]);
+        cursor += 2;
+    }
+    *cursor = '\0';
+
+    ssl->ctx->keylog_callback(ssl, (const char *)out);
+    OPENSSL_free(out);
+    return 1;
+
+}
+
+int ssl_log_rsa_client_key_exchange(SSL *ssl,
+                                    const uint8_t *encrypted_premaster,
+                                    size_t encrypted_premaster_len,
+                                    const uint8_t *premaster,
+                                    size_t premaster_len)
+{
+    if (encrypted_premaster_len < 8) {
+        SSLerr(SSL_F_SSL_LOG_RSA_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
+
+    /* We only want the first 8 bytes of the encrypted premaster as a tag. */
+    return nss_keylog_int("RSA",
+                          ssl,
+                          encrypted_premaster,
+                          8,
+                          premaster,
+                          premaster_len);
+}
+
+int ssl_log_secret(SSL *ssl,
+                   const char *label,
+                   const uint8_t *secret,
+                   size_t secret_len)
+{
+    return nss_keylog_int(label,
+                          ssl,
+                          ssl->s3->client_random,
+                          SSL3_RANDOM_SIZE,
+                          secret,
+                          secret_len);
+}
+