Updates following review feedback
[openssl.git] / ssl / ssl_lib.c
index 58873456c8134693e1f0eb7a89138f7e7cea50b8..cb5e0cfbc9bce8b67e8013974fff44214b5de339 100644 (file)
@@ -471,6 +471,8 @@ int SSL_clear(SSL *s)
     clear_ciphers(s);
     s->first_packet = 0;
 
+    s->key_update = SSL_KEY_UPDATE_NONE;
+
     /* Reset DANE verification result state */
     s->dane.mdpth = -1;
     s->dane.pdpth = -1;
@@ -599,7 +601,7 @@ SSL *SSL_new(SSL_CTX *ctx)
     s->ext.ocsp.resp = NULL;
     s->ext.ocsp.resp_len = 0;
     SSL_CTX_up_ref(ctx);
-    s->initial_ctx = ctx;
+    s->session_ctx = ctx;
 #ifndef OPENSSL_NO_EC
     if (ctx->ext.ecpointformats) {
         s->ext.ecpointformats =
@@ -639,6 +641,8 @@ SSL *SSL_new(SSL_CTX *ctx)
 
     s->method = ctx->method;
 
+    s->key_update = SSL_KEY_UPDATE_NONE;
+
     if (!s->method->ssl_new(s))
         goto err;
 
@@ -995,7 +999,7 @@ void SSL_free(SSL *s)
     /* Free up if allocated */
 
     OPENSSL_free(s->ext.hostname);
-    SSL_CTX_free(s->initial_ctx);
+    SSL_CTX_free(s->session_ctx);
 #ifndef OPENSSL_NO_EC
     OPENSSL_free(s->ext.ecpointformats);
     OPENSSL_free(s->ext.supportedgroups);
@@ -1714,8 +1718,46 @@ int SSL_shutdown(SSL *s)
     }
 }
 
+int SSL_key_update(SSL *s, SSL_KEY_UPDATE updatetype)
+{
+    /*
+     * TODO(TLS1.3): How will applications know whether TLSv1.3+ has been
+     * negotiated, and that it is appropriate to call SSL_key_update() instead
+     * of SSL_renegotiate().
+     */
+    if (!SSL_IS_TLS13(s)) {
+        SSLerr(SSL_F_SSL_KEY_UPDATE, SSL_R_WRONG_SSL_VERSION);
+        return 0;
+    }
+
+    if (updatetype != SSL_KEY_UPDATE_NOT_REQUESTED
+            && updatetype != SSL_KEY_UPDATE_REQUESTED) {
+        SSLerr(SSL_F_SSL_KEY_UPDATE, SSL_R_INVALID_KEY_UPDATE_TYPE);
+        return 0;
+    }
+
+    if (!SSL_is_init_finished(s)) {
+        SSLerr(SSL_F_SSL_KEY_UPDATE, SSL_R_STILL_IN_INIT);
+        return 0;
+    }
+
+    ossl_statem_set_in_init(s, 1);
+    s->key_update = updatetype;
+    return 1;
+}
+
+SSL_KEY_UPDATE SSL_get_key_update_type(SSL *s)
+{
+    return s->key_update;
+}
+
 int SSL_renegotiate(SSL *s)
 {
+    if (SSL_IS_TLS13(s)) {
+        SSLerr(SSL_F_SSL_RENEGOTIATE, SSL_R_WRONG_SSL_VERSION);
+        return 0;
+    }
+
     if (s->renegotiate == 0)
         s->renegotiate = 1;
 
@@ -1726,6 +1768,9 @@ int SSL_renegotiate(SSL *s)
 
 int SSL_renegotiate_abbreviated(SSL *s)
 {
+    if (SSL_IS_TLS13(s))
+        return 0;
+
     if (s->renegotiate == 0)
         s->renegotiate = 1;
 
@@ -2377,13 +2422,21 @@ int SSL_export_keying_material(SSL *s, unsigned char *out, size_t olen,
 
 static unsigned long ssl_session_hash(const SSL_SESSION *a)
 {
+    const unsigned char *session_id = a->session_id;
     unsigned long l;
+    unsigned char tmp_storage[4];
+
+    if (a->session_id_length < sizeof(tmp_storage)) {
+        memset(tmp_storage, 0, sizeof(tmp_storage));
+        memcpy(tmp_storage, a->session_id, a->session_id_length);
+        session_id = tmp_storage;
+    }
 
     l = (unsigned long)
-        ((unsigned int)a->session_id[0]) |
-        ((unsigned int)a->session_id[1] << 8L) |
-        ((unsigned long)a->session_id[2] << 16L) |
-        ((unsigned long)a->session_id[3] << 24L);
+        ((unsigned long)session_id[0]) |
+        ((unsigned long)session_id[1] << 8L) |
+        ((unsigned long)session_id[2] << 16L) |
+        ((unsigned long)session_id[3] << 24L);
     return (l);
 }
 
@@ -2698,16 +2751,12 @@ void SSL_set_cert_cb(SSL *s, int (*cb) (SSL *ssl, void *arg), void *arg)
 
 void ssl_set_masks(SSL *s)
 {
-#if !defined(OPENSSL_NO_EC) || !defined(OPENSSL_NO_GOST)
-    CERT_PKEY *cpk;
-#endif
     CERT *c = s->cert;
     uint32_t *pvalid = s->s3->tmp.valid_flags;
     int rsa_enc, rsa_sign, dh_tmp, dsa_sign;
     unsigned long mask_k, mask_a;
 #ifndef OPENSSL_NO_EC
     int have_ecc_cert, ecdsa_ok;
-    X509 *x = NULL;
 #endif
     if (c == NULL)
         return;
@@ -2718,8 +2767,8 @@ void ssl_set_masks(SSL *s)
     dh_tmp = 0;
 #endif
 
-    rsa_enc = pvalid[SSL_PKEY_RSA_ENC] & CERT_PKEY_VALID;
-    rsa_sign = pvalid[SSL_PKEY_RSA_SIGN] & CERT_PKEY_SIGN;
+    rsa_enc = pvalid[SSL_PKEY_RSA] & CERT_PKEY_VALID;
+    rsa_sign = pvalid[SSL_PKEY_RSA] & CERT_PKEY_SIGN;
     dsa_sign = pvalid[SSL_PKEY_DSA_SIGN] & CERT_PKEY_SIGN;
 #ifndef OPENSSL_NO_EC
     have_ecc_cert = pvalid[SSL_PKEY_ECC] & CERT_PKEY_VALID;
@@ -2733,18 +2782,15 @@ void ssl_set_masks(SSL *s)
 #endif
 
 #ifndef OPENSSL_NO_GOST
-    cpk = &(c->pkeys[SSL_PKEY_GOST12_512]);
-    if (cpk->x509 != NULL && cpk->privatekey != NULL) {
+    if (ssl_has_cert(s, SSL_PKEY_GOST12_512)) {
         mask_k |= SSL_kGOST;
         mask_a |= SSL_aGOST12;
     }
-    cpk = &(c->pkeys[SSL_PKEY_GOST12_256]);
-    if (cpk->x509 != NULL && cpk->privatekey != NULL) {
+    if (ssl_has_cert(s, SSL_PKEY_GOST12_256)) {
         mask_k |= SSL_kGOST;
         mask_a |= SSL_aGOST12;
     }
-    cpk = &(c->pkeys[SSL_PKEY_GOST01]);
-    if (cpk->x509 != NULL && cpk->privatekey != NULL) {
+    if (ssl_has_cert(s, SSL_PKEY_GOST01)) {
         mask_k |= SSL_kGOST;
         mask_a |= SSL_aGOST01;
     }
@@ -2773,9 +2819,7 @@ void ssl_set_masks(SSL *s)
 #ifndef OPENSSL_NO_EC
     if (have_ecc_cert) {
         uint32_t ex_kusage;
-        cpk = &c->pkeys[SSL_PKEY_ECC];
-        x = cpk->x509;
-        ex_kusage = X509_get_key_usage(x);
+        ex_kusage = X509_get_key_usage(c->pkeys[SSL_PKEY_ECC].x509);
         ecdsa_ok = ex_kusage & X509v3_KU_DIGITAL_SIGNATURE;
         if (!(pvalid[SSL_PKEY_ECC] & CERT_PKEY_SIGN))
             ecdsa_ok = 0;
@@ -2820,93 +2864,17 @@ int ssl_check_srvr_ecc_cert_and_alg(X509 *x, SSL *s)
 
 #endif
 
-static int ssl_get_server_cert_index(const SSL *s)
-{
-    int idx;
-    idx = ssl_cipher_get_cert_index(s->s3->tmp.new_cipher);
-    if (idx == SSL_PKEY_RSA_ENC && !s->cert->pkeys[SSL_PKEY_RSA_ENC].x509)
-        idx = SSL_PKEY_RSA_SIGN;
-    if (idx == SSL_PKEY_GOST_EC) {
-        if (s->cert->pkeys[SSL_PKEY_GOST12_512].x509)
-            idx = SSL_PKEY_GOST12_512;
-        else if (s->cert->pkeys[SSL_PKEY_GOST12_256].x509)
-            idx = SSL_PKEY_GOST12_256;
-        else if (s->cert->pkeys[SSL_PKEY_GOST01].x509)
-            idx = SSL_PKEY_GOST01;
-        else
-            idx = -1;
-    }
-    if (idx == -1)
-        SSLerr(SSL_F_SSL_GET_SERVER_CERT_INDEX, ERR_R_INTERNAL_ERROR);
-    return idx;
-}
-
-CERT_PKEY *ssl_get_server_send_pkey(SSL *s)
-{
-    CERT *c;
-    int i;
-
-    c = s->cert;
-    if (!s->s3 || !s->s3->tmp.new_cipher)
-        return NULL;
-    ssl_set_masks(s);
-
-    i = ssl_get_server_cert_index(s);
-
-    /* This may or may not be an error. */
-    if (i < 0)
-        return NULL;
-
-    /* May be NULL. */
-    return &c->pkeys[i];
-}
-
-EVP_PKEY *ssl_get_sign_pkey(SSL *s, const SSL_CIPHER *cipher,
-                            const EVP_MD **pmd)
-{
-    unsigned long alg_a;
-    CERT *c;
-    int idx = -1;
-
-    alg_a = cipher->algorithm_auth;
-    c = s->cert;
-
-    if ((alg_a & SSL_aDSS) && (c->pkeys[SSL_PKEY_DSA_SIGN].privatekey != NULL))
-        idx = SSL_PKEY_DSA_SIGN;
-    else if (alg_a & SSL_aRSA) {
-        if (c->pkeys[SSL_PKEY_RSA_SIGN].privatekey != NULL)
-            idx = SSL_PKEY_RSA_SIGN;
-        else if (c->pkeys[SSL_PKEY_RSA_ENC].privatekey != NULL)
-            idx = SSL_PKEY_RSA_ENC;
-    } else if ((alg_a & SSL_aECDSA) &&
-               (c->pkeys[SSL_PKEY_ECC].privatekey != NULL))
-        idx = SSL_PKEY_ECC;
-    if (idx == -1) {
-        SSLerr(SSL_F_SSL_GET_SIGN_PKEY, ERR_R_INTERNAL_ERROR);
-        return (NULL);
-    }
-    if (pmd)
-        *pmd = s->s3->tmp.md[idx];
-    return c->pkeys[idx].privatekey;
-}
-
 int ssl_get_server_cert_serverinfo(SSL *s, const unsigned char **serverinfo,
                                    size_t *serverinfo_length)
 {
-    CERT *c = NULL;
-    int i = 0;
+    CERT_PKEY *cpk = s->s3->tmp.cert;
     *serverinfo_length = 0;
 
-    c = s->cert;
-    i = ssl_get_server_cert_index(s);
-
-    if (i == -1)
-        return 0;
-    if (c->pkeys[i].serverinfo == NULL)
+    if (cpk == NULL || cpk->serverinfo == NULL)
         return 0;
 
-    *serverinfo = c->pkeys[i].serverinfo;
-    *serverinfo_length = c->pkeys[i].serverinfo_length;
+    *serverinfo = cpk->serverinfo;
+    *serverinfo_length = cpk->serverinfo_length;
     return 1;
 }
 
@@ -3081,7 +3049,7 @@ int SSL_do_handshake(SSL *s)
         return -1;
     }
 
-    s->method->ssl_renegotiate_check(s);
+    s->method->ssl_renegotiate_check(s, 0);
 
     if (SSL_in_init(s) || SSL_in_before(s)) {
         if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
@@ -3463,7 +3431,7 @@ SSL_CTX *SSL_set_SSL_CTX(SSL *ssl, SSL_CTX *ctx)
     if (ssl->ctx == ctx)
         return ssl->ctx;
     if (ctx == NULL)
-        ctx = ssl->initial_ctx;
+        ctx = ssl->session_ctx;
     new_cert = ssl_cert_dup(ctx->cert);
     if (new_cert == NULL) {
         return NULL;
@@ -3772,8 +3740,8 @@ void SSL_set_not_resumable_session_callback(SSL *ssl,
 /*
  * Allocates new EVP_MD_CTX and sets pointer to it into given pointer
  * variable, freeing EVP_MD_CTX previously stored in that variable, if any.
- * If EVP_MD pointer is passed, initializes ctx with this md Returns newly
- * allocated ctx;
+ * If EVP_MD pointer is passed, initializes ctx with this md.
+ * Returns the newly allocated ctx;
  */
 
 EVP_MD_CTX *ssl_replace_hash(EVP_MD_CTX **hash, const EVP_MD *md)
@@ -4338,3 +4306,99 @@ const CTLOG_STORE *SSL_CTX_get0_ctlog_store(const SSL_CTX *ctx)
 }
 
 #endif
+
+void SSL_CTX_set_keylog_callback(SSL_CTX *ctx, SSL_CTX_keylog_cb_func cb)
+{
+    ctx->keylog_callback = cb;
+}
+
+SSL_CTX_keylog_cb_func SSL_CTX_get_keylog_callback(const SSL_CTX *ctx)
+{
+    return ctx->keylog_callback;
+}
+
+static int nss_keylog_int(const char *prefix,
+                          SSL *ssl,
+                          const uint8_t *parameter_1,
+                          size_t parameter_1_len,
+                          const uint8_t *parameter_2,
+                          size_t parameter_2_len)
+{
+    char *out = NULL;
+    char *cursor = NULL;
+    size_t out_len = 0;
+    size_t i;
+    size_t prefix_len;
+
+    if (ssl->ctx->keylog_callback == NULL) return 1;
+
+    /*
+     * Our output buffer will contain the following strings, rendered with
+     * space characters in between, terminated by a NULL character: first the
+     * prefix, then the first parameter, then the second parameter. The
+     * meaning of each parameter depends on the specific key material being
+     * logged. Note that the first and second parameters are encoded in
+     * hexadecimal, so we need a buffer that is twice their lengths.
+     */
+    prefix_len = strlen(prefix);
+    out_len = prefix_len + (2*parameter_1_len) + (2*parameter_2_len) + 3;
+    if ((out = cursor = OPENSSL_malloc(out_len)) == NULL) {
+        SSLerr(SSL_F_NSS_KEYLOG_INT, ERR_R_MALLOC_FAILURE);
+        return 0;
+    }
+
+    strcpy(cursor, prefix);
+    cursor += prefix_len;
+    *cursor++ = ' ';
+
+    for (i = 0; i < parameter_1_len; i++) {
+        sprintf(cursor, "%02x", parameter_1[i]);
+        cursor += 2;
+    }
+    *cursor++ = ' ';
+
+    for (i = 0; i < parameter_2_len; i++) {
+        sprintf(cursor, "%02x", parameter_2[i]);
+        cursor += 2;
+    }
+    *cursor = '\0';
+
+    ssl->ctx->keylog_callback(ssl, (const char *)out);
+    OPENSSL_free(out);
+    return 1;
+
+}
+
+int ssl_log_rsa_client_key_exchange(SSL *ssl,
+                                    const uint8_t *encrypted_premaster,
+                                    size_t encrypted_premaster_len,
+                                    const uint8_t *premaster,
+                                    size_t premaster_len)
+{
+    if (encrypted_premaster_len < 8) {
+        SSLerr(SSL_F_SSL_LOG_RSA_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
+
+    /* We only want the first 8 bytes of the encrypted premaster as a tag. */
+    return nss_keylog_int("RSA",
+                          ssl,
+                          encrypted_premaster,
+                          8,
+                          premaster,
+                          premaster_len);
+}
+
+int ssl_log_secret(SSL *ssl,
+                   const char *label,
+                   const uint8_t *secret,
+                   size_t secret_len)
+{
+    return nss_keylog_int(label,
+                          ssl,
+                          ssl->s3->client_random,
+                          SSL3_RANDOM_SIZE,
+                          secret,
+                          secret_len);
+}
+