Extended PSK server support.
[openssl.git] / ssl / s3_srvr.c
index 3e7f44ada9a19ee65800db48c687541e7980ddac..caf45d1afbf526668d4116541424d3f1f38a1bb8 100644 (file)
@@ -387,19 +387,15 @@ int ssl3_accept(SSL *s)
             ret = ssl3_send_server_hello(s);
             if (ret <= 0)
                 goto end;
-#ifndef OPENSSL_NO_TLSEXT
+
             if (s->hit) {
                 if (s->tlsext_ticket_expected)
                     s->state = SSL3_ST_SW_SESSION_TICKET_A;
                 else
                     s->state = SSL3_ST_SW_CHANGE_A;
-            }
-#else
-            if (s->hit)
-                s->state = SSL3_ST_SW_CHANGE_A;
-#endif
-            else
+            } else {
                 s->state = SSL3_ST_SW_CERT_A;
+            }
             s->init_num = 0;
             break;
 
@@ -407,14 +403,12 @@ int ssl3_accept(SSL *s)
         case SSL3_ST_SW_CERT_B:
             /* Check if it is anon DH or anon ECDH, */
             /* normal PSK or SRP */
-            if (!
-                (s->s3->tmp.
-                 new_cipher->algorithm_auth & (SSL_aNULL | SSL_aSRP))
-&& !(s->s3->tmp.new_cipher->algorithm_mkey & SSL_kPSK)) {
+            if (!(s->s3->tmp.new_cipher->algorithm_auth &
+                 (SSL_aNULL | SSL_aSRP | SSL_aPSK))) {
                 ret = ssl3_send_server_certificate(s);
                 if (ret <= 0)
                     goto end;
-#ifndef OPENSSL_NO_TLSEXT
+
                 if (s->tlsext_status_expected)
                     s->state = SSL3_ST_SW_CERT_STATUS_A;
                 else
@@ -423,12 +417,6 @@ int ssl3_accept(SSL *s)
                 skip = 1;
                 s->state = SSL3_ST_SW_KEY_EXCH_A;
             }
-#else
-            } else
-                skip = 1;
-
-            s->state = SSL3_ST_SW_KEY_EXCH_A;
-#endif
             s->init_num = 0;
             break;
 
@@ -456,7 +444,10 @@ int ssl3_accept(SSL *s)
                  * provided
                  */
 #ifndef OPENSSL_NO_PSK
-                || ((alg_k & SSL_kPSK) && s->ctx->psk_identity_hint)
+                /* Only send SKE if we have identity hint for plain PSK */
+                || ((alg_k & (SSL_kPSK | SSL_kRSAPSK)) && s->ctx->psk_identity_hint)
+                /* For other PSK always send SKE */
+                || (alg_k & (SSL_PSK & (SSL_kDHEPSK | SSL_kECDHEPSK)))
 #endif
 #ifndef OPENSSL_NO_SRP
                 /* SRP: send ServerKeyExchange */
@@ -517,11 +508,9 @@ int ssl3_accept(SSL *s)
                 skip = 1;
                 s->s3->tmp.cert_request = 0;
                 s->state = SSL3_ST_SW_SRVR_DONE_A;
-                if (s->s3->handshake_buffer) {
-                    if (!ssl3_digest_cached_records(s)) {
-                        s->state = SSL_ST_ERR;
-                        return -1;
-                    }
+                if (!ssl3_digest_cached_records(s, 0)) {
+                    s->state = SSL_ST_ERR;
+                    return -1;
                 }
             } else {
                 s->s3->tmp.cert_request = 1;
@@ -587,7 +576,7 @@ int ssl3_accept(SSL *s)
                  * not sent. Also for GOST ciphersuites when the client uses
                  * its key from the certificate for key exchange.
                  */
-#if defined(OPENSSL_NO_TLSEXT) || defined(OPENSSL_NO_NEXTPROTONEG)
+#if defined(OPENSSL_NO_NEXTPROTONEG)
                 s->state = SSL3_ST_SR_FINISHED_A;
 #else
                 if (s->s3->next_proto_neg_seen)
@@ -608,14 +597,11 @@ int ssl3_accept(SSL *s)
                 }
                 /*
                  * For sigalgs freeze the handshake buffer. If we support
-                 * extms we've done this already.
+                 * extms we've done this already so this is a no-op
                  */
-                if (!(s->s3->flags & SSL_SESS_FLAG_EXTMS)) {
-                    s->s3->flags |= TLS1_FLAGS_KEEP_HANDSHAKE;
-                    if (!ssl3_digest_cached_records(s)) {
-                        s->state = SSL_ST_ERR;
-                        return -1;
-                    }
+                if (!ssl3_digest_cached_records(s, 1)) {
+                    s->state = SSL_ST_ERR;
+                    return -1;
                 }
             } else {
                 int offset = 0;
@@ -630,11 +616,9 @@ int ssl3_accept(SSL *s)
                  * CertificateVerify should be generalized. But it is next
                  * step
                  */
-                if (s->s3->handshake_buffer) {
-                    if (!ssl3_digest_cached_records(s)) {
-                        s->state = SSL_ST_ERR;
-                        return -1;
-                    }
+                if (!ssl3_digest_cached_records(s, 0)) {
+                    s->state = SSL_ST_ERR;
+                    return -1;
                 }
                 for (dgst_num = 0; dgst_num < SSL_MAX_DIGEST; dgst_num++)
                     if (s->s3->handshake_dgst[dgst_num]) {
@@ -666,7 +650,7 @@ int ssl3_accept(SSL *s)
             if (ret <= 0)
                 goto end;
 
-#if defined(OPENSSL_NO_TLSEXT) || defined(OPENSSL_NO_NEXTPROTONEG)
+#if defined(OPENSSL_NO_NEXTPROTONEG)
             s->state = SSL3_ST_SR_FINISHED_A;
 #else
             if (s->s3->next_proto_neg_seen)
@@ -677,7 +661,7 @@ int ssl3_accept(SSL *s)
             s->init_num = 0;
             break;
 
-#if !defined(OPENSSL_NO_TLSEXT) && !defined(OPENSSL_NO_NEXTPROTONEG)
+#if !defined(OPENSSL_NO_NEXTPROTONEG)
         case SSL3_ST_SR_NEXT_PROTO_A:
         case SSL3_ST_SR_NEXT_PROTO_B:
             /*
@@ -718,16 +702,13 @@ int ssl3_accept(SSL *s)
                 goto end;
             if (s->hit)
                 s->state = SSL_ST_OK;
-#ifndef OPENSSL_NO_TLSEXT
             else if (s->tlsext_ticket_expected)
                 s->state = SSL3_ST_SW_SESSION_TICKET_A;
-#endif
             else
                 s->state = SSL3_ST_SW_CHANGE_A;
             s->init_num = 0;
             break;
 
-#ifndef OPENSSL_NO_TLSEXT
         case SSL3_ST_SW_SESSION_TICKET_A:
         case SSL3_ST_SW_SESSION_TICKET_B:
             ret = ssl3_send_newsession_ticket(s);
@@ -746,8 +727,6 @@ int ssl3_accept(SSL *s)
             s->init_num = 0;
             break;
 
-#endif
-
         case SSL3_ST_SW_CHANGE_A:
         case SSL3_ST_SW_CHANGE_B:
 
@@ -790,7 +769,7 @@ int ssl3_accept(SSL *s)
                 goto end;
             s->state = SSL3_ST_SW_FLUSH;
             if (s->hit) {
-#if defined(OPENSSL_NO_TLSEXT) || defined(OPENSSL_NO_NEXTPROTONEG)
+#if defined(OPENSSL_NO_NEXTPROTONEG)
                 s->s3->tmp.next_state = SSL3_ST_SR_FINISHED_A;
 #else
                 if (s->s3->next_proto_neg_seen) {
@@ -883,7 +862,7 @@ int ssl3_send_hello_request(SSL *s)
 
 int ssl3_get_client_hello(SSL *s)
 {
-    int i, j, ok, al = SSL_AD_INTERNAL_ERROR, ret = -1;
+    int i, complen, j, ok, al = SSL_AD_INTERNAL_ERROR, ret = -1;
     unsigned int cookie_len;
     long n;
     unsigned long id;
@@ -921,13 +900,7 @@ int ssl3_get_client_hello(SSL *s)
     d = p = (unsigned char *)s->init_msg;
 
     /* First lets get s->client_version set correctly */
-    if (!s->read_hash && !s->enc_read_ctx
-            && RECORD_LAYER_is_sslv2_record(&s->rlayer)) {
-        if (n < MIN_SSL2_RECORD_LEN) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_HELLO, SSL_R_RECORD_LENGTH_MISMATCH);
-            al = SSL_AD_DECODE_ERROR;
-            goto f_err;
-        }
+    if (RECORD_LAYER_is_sslv2_record(&s->rlayer)) {
         /*-
          * An SSLv3/TLSv1 backwards-compatible CLIENT-HELLO in an SSLv2
          * header is sent directly on the wire, not wrapped as a TLS
@@ -986,11 +959,10 @@ int ssl3_get_client_hello(SSL *s)
     /* Do SSL/TLS version negotiation if applicable */
     if (!SSL_IS_DTLS(s)) {
         if (s->version != TLS_ANY_VERSION) {
-            if (s->client_version >= s->version
-                && (((s->client_version >> 8) & 0xff) == SSL3_VERSION_MAJOR)) {
+            if (s->client_version >= s->version) {
                 protverr = 0;
             }
-        } else if (((s->client_version >> 8) & 0xff) == SSL3_VERSION_MAJOR) {
+        } else if (s->client_version >= SSL3_VERSION) {
             switch(s->client_version) {
             default:
             case TLS1_2_VERSION:
@@ -1018,17 +990,20 @@ int ssl3_get_client_hello(SSL *s)
                 }
                 /* Deliberately fall through */
             case SSL3_VERSION:
+#ifndef OPENSSL_NO_SSL3
                 if(!(s->options & SSL_OP_NO_SSLv3)) {
                     s->version = SSL3_VERSION;
                     s->method = SSLv3_server_method();
                     protverr = 0;
                     break;
                 }
+#else
+                break;
+#endif
             }
         }
-    } else if (((s->client_version >> 8) & 0xff) == DTLS1_VERSION_MAJOR &&
-                (s->client_version <= s->version
-                || s->method->version == DTLS_ANY_VERSION)) {
+    } else if (s->client_version <= s->version
+                || s->method->version == DTLS_ANY_VERSION) {
         /*
          * For DTLS we just check versions are potentially compatible. Version
          * negotiation comes later.
@@ -1096,8 +1071,8 @@ int ssl3_get_client_hello(SSL *s)
         /* Set p to end of packet to ensure we don't look for extensions */
         p = d + n;
 
-        /* No compression, so set i to 0 */
-        i = 0;
+        /* No compression, so set complen to 0 */
+        complen = 0;
     } else {
         /* If we get here we've got SSLv3+ in an SSLv3+ record */
 
@@ -1341,8 +1316,8 @@ int ssl3_get_client_hello(SSL *s)
         }
 
         /* compression */
-        i = *(p++);
-        if ((p + i) > (d + n)) {
+        complen = *(p++);
+        if ((p + complen) > (d + n)) {
             /* not enough data */
             al = SSL_AD_DECODE_ERROR;
             SSLerr(SSL_F_SSL3_GET_CLIENT_HELLO, SSL_R_LENGTH_MISMATCH);
@@ -1351,13 +1326,13 @@ int ssl3_get_client_hello(SSL *s)
 #ifndef OPENSSL_NO_COMP
         q = p;
 #endif
-        for (j = 0; j < i; j++) {
+        for (j = 0; j < complen; j++) {
             if (p[j] == 0)
                 break;
         }
 
-        p += i;
-        if (j >= i) {
+        p += complen;
+        if (j >= complen) {
             /* no compress */
             al = SSL_AD_DECODE_ERROR;
             SSLerr(SSL_F_SSL3_GET_CLIENT_HELLO, SSL_R_NO_COMPRESSION_SPECIFIED);
@@ -1365,7 +1340,6 @@ int ssl3_get_client_hello(SSL *s)
         }
     }
 
-#ifndef OPENSSL_NO_TLSEXT
     /* TLS extensions */
     if (s->version >= SSL3_VERSION) {
         if (!ssl_parse_clienthello_tlsext(s, &p, d, n)) {
@@ -1422,11 +1396,10 @@ int ssl3_get_client_hello(SSL *s)
             s->cipher_list_by_id = sk_SSL_CIPHER_dup(s->session->ciphers);
         }
     }
-#endif
 
     /*
      * Worst case, we will use the NULL compression, but if we have other
-     * options, we will now look for them.  We have i-1 compression
+     * options, we will now look for them.  We have complen-1 compression
      * algorithms from the client, starting at q.
      */
     s->s3->tmp.new_compression = NULL;
@@ -1455,11 +1428,11 @@ int ssl3_get_client_hello(SSL *s)
             goto f_err;
         }
         /* Look for resumed method in compression list */
-        for (m = 0; m < i; m++) {
+        for (m = 0; m < complen; m++) {
             if (q[m] == comp_id)
                 break;
         }
-        if (m >= i) {
+        if (m >= complen) {
             al = SSL_AD_ILLEGAL_PARAMETER;
             SSLerr(SSL_F_SSL3_GET_CLIENT_HELLO,
                    SSL_R_REQUIRED_COMPRESSSION_ALGORITHM_MISSING);
@@ -1475,7 +1448,7 @@ int ssl3_get_client_hello(SSL *s)
         for (m = 0; m < nn; m++) {
             comp = sk_SSL_COMP_value(s->ctx->comp_methods, m);
             v = comp->id;
-            for (o = 0; o < i; o++) {
+            for (o = 0; o < complen; o++) {
                 if (v == q[o]) {
                     done = 1;
                     break;
@@ -1559,7 +1532,7 @@ int ssl3_get_client_hello(SSL *s)
     }
 
     if (!SSL_USE_SIGALGS(s) || !(s->verify_mode & SSL_VERIFY_PEER)) {
-        if (!ssl3_digest_cached_records(s))
+        if (!ssl3_digest_cached_records(s, 0))
             goto f_err;
     }
 
@@ -1606,13 +1579,13 @@ int ssl3_send_server_hello(SSL *s)
 
     if (s->state == SSL3_ST_SW_SRVR_HELLO_A) {
         buf = (unsigned char *)s->init_buf->data;
-#ifdef OPENSSL_NO_TLSEXT
+
         p = s->s3->server_random;
         if (ssl_fill_hello_random(s, 1, p, SSL3_RANDOM_SIZE) <= 0) {
             s->state = SSL_ST_ERR;
             return -1;
         }
-#endif
+
         /* Do the message type and length last */
         d = p = ssl_handshake_start(s);
 
@@ -1667,7 +1640,7 @@ int ssl3_send_server_hello(SSL *s)
         else
             *(p++) = s->s3->tmp.new_compression->id;
 #endif
-#ifndef OPENSSL_NO_TLSEXT
+
         if (ssl_prepare_serverhello_tlsext(s) <= 0) {
             SSLerr(SSL_F_SSL3_SEND_SERVER_HELLO, SSL_R_SERVERHELLO_TLSEXT);
             s->state = SSL_ST_ERR;
@@ -1681,7 +1654,7 @@ int ssl3_send_server_hello(SSL *s)
             s->state = SSL_ST_ERR;
             return -1;
         }
-#endif
+
         /* do the header */
         l = (p - d);
         if (!ssl_set_handshake_header(s, SSL3_MT_SERVER_HELLO, l)) {
@@ -1750,6 +1723,19 @@ int ssl3_send_server_key_exchange(SSL *s)
 
         r[0] = r[1] = r[2] = r[3] = NULL;
         n = 0;
+#ifndef OPENSSL_NO_PSK
+        if (type & SSL_PSK) {
+            /*
+             * reserve size for record length and PSK identity hint
+             */
+            n += 2;
+            if (s->ctx->psk_identity_hint)
+                n += strlen(s->ctx->psk_identity_hint);
+        }
+        /* Plain PSK or RSAPSK nothing to do */
+        if (type & (SSL_kPSK | SSL_kRSAPSK)) {
+        } else
+#endif                          /* !OPENSSL_NO_PSK */
 #ifndef OPENSSL_NO_RSA
         if (type & SSL_kRSA) {
             rsa = cert->rsa_tmp;
@@ -1780,7 +1766,7 @@ int ssl3_send_server_key_exchange(SSL *s)
         } else
 #endif
 #ifndef OPENSSL_NO_DH
-        if (type & SSL_kDHE) {
+        if (type & (SSL_kDHE | SSL_kDHEPSK)) {
             if (s->cert->dh_tmp_auto) {
                 dhp = ssl_get_auto_dh(s);
                 if (dhp == NULL) {
@@ -1845,7 +1831,7 @@ int ssl3_send_server_key_exchange(SSL *s)
         } else
 #endif
 #ifndef OPENSSL_NO_EC
-        if (type & SSL_kECDHE) {
+        if (type & (SSL_kECDHE | SSL_kECDHEPSK)) {
             const EC_GROUP *group;
 
             ecdhp = cert->ecdh_tmp;
@@ -1961,7 +1947,7 @@ int ssl3_send_server_key_exchange(SSL *s)
              * additional bytes to encode the entire ServerECDHParams
              * structure.
              */
-            n = 4 + encodedlen;
+            n += 4 + encodedlen;
 
             /*
              * We'll generate the serverKeyExchange message explicitly so we
@@ -1973,14 +1959,6 @@ int ssl3_send_server_key_exchange(SSL *s)
             r[3] = NULL;
         } else
 #endif                          /* !OPENSSL_NO_EC */
-#ifndef OPENSSL_NO_PSK
-        if (type & SSL_kPSK) {
-            /*
-             * reserve size for record length and PSK identity hint
-             */
-            n += 2 + strlen(s->ctx->psk_identity_hint);
-        } else
-#endif                          /* !OPENSSL_NO_PSK */
 #ifndef OPENSSL_NO_SRP
         if (type & SSL_kSRP) {
             if ((s->srp_ctx.N == NULL) ||
@@ -2012,8 +1990,8 @@ int ssl3_send_server_key_exchange(SSL *s)
                 n += 2 + nr[i];
         }
 
-        if (!(s->s3->tmp.new_cipher->algorithm_auth & (SSL_aNULL | SSL_aSRP))
-            && !(s->s3->tmp.new_cipher->algorithm_mkey & SSL_kPSK)) {
+        if (!(s->s3->tmp.new_cipher->algorithm_auth & (SSL_aNULL|SSL_aSRP))
+            && !(s->s3->tmp.new_cipher->algorithm_mkey & SSL_PSK)) {
             if ((pkey = ssl_get_sign_pkey(s, s->s3->tmp.new_cipher, &md))
                 == NULL) {
                 al = SSL_AD_DECODE_ERROR;
@@ -2031,6 +2009,20 @@ int ssl3_send_server_key_exchange(SSL *s)
         }
         d = p = ssl_handshake_start(s);
 
+#ifndef OPENSSL_NO_PSK
+        if (type & SSL_PSK) {
+            /* copy PSK identity hint */
+            if (s->ctx->psk_identity_hint) {
+                s2n(strlen(s->ctx->psk_identity_hint), p);
+                strncpy((char *)p, s->ctx->psk_identity_hint,
+                        strlen(s->ctx->psk_identity_hint));
+                p += strlen(s->ctx->psk_identity_hint);
+            } else {
+                s2n(0, p);
+            }
+        }
+#endif
+
         for (i = 0; i < 4 && r[i] != NULL; i++) {
 #ifndef OPENSSL_NO_SRP
             if ((i == 2) && (type & SSL_kSRP)) {
@@ -2044,7 +2036,7 @@ int ssl3_send_server_key_exchange(SSL *s)
         }
 
 #ifndef OPENSSL_NO_EC
-        if (type & SSL_kECDHE) {
+        if (type & (SSL_kECDHE | SSL_kECDHEPSK)) {
             /*
              * XXX: For now, we only support named (not generic) curves. In
              * this situation, the serverKeyExchange message has: [1 byte
@@ -2066,16 +2058,6 @@ int ssl3_send_server_key_exchange(SSL *s)
         }
 #endif
 
-#ifndef OPENSSL_NO_PSK
-        if (type & SSL_kPSK) {
-            /* copy PSK identity hint */
-            s2n(strlen(s->ctx->psk_identity_hint), p);
-            strncpy((char *)p, s->ctx->psk_identity_hint,
-                    strlen(s->ctx->psk_identity_hint));
-            p += strlen(s->ctx->psk_identity_hint);
-        }
-#endif
-
         /* not anonymous */
         if (pkey != NULL) {
             /*
@@ -2259,7 +2241,6 @@ int ssl3_get_client_key_exchange(SSL *s)
     BIGNUM *pub = NULL;
     DH *dh_srvr, *dh_clnt = NULL;
 #endif
-
 #ifndef OPENSSL_NO_EC
     EC_KEY *srvr_ecdh = NULL;
     EVP_PKEY *clnt_pub_pkey = NULL;
@@ -2278,8 +2259,94 @@ int ssl3_get_client_key_exchange(SSL *s)
 
     alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
 
+#ifndef OPENSSL_NO_PSK
+    /* For PSK parse and retrieve identity, obtain PSK key */
+    if (alg_k & SSL_PSK) {
+        unsigned char psk[PSK_MAX_PSK_LEN];
+        size_t psklen;
+        if (n < 2) {
+            al = SSL_AD_DECODE_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, SSL_R_LENGTH_MISMATCH);
+            goto f_err;
+        }
+        n2s(p, i);
+        if (i + 2 > n) {
+            al = SSL_AD_DECODE_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, SSL_R_LENGTH_MISMATCH);
+            goto f_err;
+        }
+        if (i > PSK_MAX_IDENTITY_LEN) {
+            al = SSL_AD_DECODE_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
+                   SSL_R_DATA_LENGTH_TOO_LONG);
+            goto f_err;
+        }
+        if (s->psk_server_callback == NULL) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
+                   SSL_R_PSK_NO_SERVER_CB);
+            goto f_err;
+        }
+
+        OPENSSL_free(s->session->psk_identity);
+        s->session->psk_identity = BUF_strndup((char *)p, i);
+
+        if (s->session->psk_identity == NULL) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
+                   ERR_R_MALLOC_FAILURE);
+            goto f_err;
+        }
+
+        psklen = s->psk_server_callback(s, s->session->psk_identity,
+                                         psk, sizeof(psk));
+
+        if (psklen > PSK_MAX_PSK_LEN) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
+            goto f_err;
+        } else if (psklen == 0) {
+            /*
+             * PSK related to the given identity not found
+             */
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
+                   SSL_R_PSK_IDENTITY_NOT_FOUND);
+            al = SSL_AD_UNKNOWN_PSK_IDENTITY;
+            goto f_err;
+        }
+
+        OPENSSL_free(s->s3->tmp.psk);
+        s->s3->tmp.psk = BUF_memdup(psk, psklen);
+        OPENSSL_cleanse(psk, psklen);
+
+        if (s->s3->tmp.psk == NULL) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
+            goto f_err;
+        }
+
+        s->s3->tmp.psklen = psklen;
+
+        n -= i + 2;
+        p += i;
+    }
+    if (alg_k & SSL_kPSK) {
+        /* Identity extracted earlier: should be nothing left */
+        if (n != 0) {
+            al = SSL_AD_HANDSHAKE_FAILURE;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, SSL_R_LENGTH_MISMATCH);
+            goto f_err;
+        }
+        /* PSK handled by ssl_generate_master_secret */
+        if (!ssl_generate_master_secret(s, NULL, 0, 0)) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
+            goto f_err;
+        }
+    } else
+#endif
 #ifndef OPENSSL_NO_RSA
-    if (alg_k & SSL_kRSA) {
+    if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) {
         unsigned char rand_premaster_secret[SSL_MAX_MASTER_KEY_LENGTH];
         int decrypt_len;
         unsigned char decrypt_good, version_good;
@@ -2410,15 +2477,7 @@ int ssl3_get_client_key_exchange(SSL *s)
                                           rand_premaster_secret[j]);
         }
 
-        s->session->master_key_length =
-            s->method->ssl3_enc->generate_master_secret(s,
-                                                        s->
-                                                        session->master_key,
-                                                        p,
-                                                        sizeof
-                                                        (rand_premaster_secret));
-        OPENSSL_cleanse(p, sizeof(rand_premaster_secret));
-        if (s->session->master_key_length < 0) {
+        if (!ssl_generate_master_secret(s, p, sizeof(rand_premaster_secret), 0)) {
             al = SSL_AD_INTERNAL_ERROR;
             SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
             goto f_err;
@@ -2426,13 +2485,13 @@ int ssl3_get_client_key_exchange(SSL *s)
     } else
 #endif
 #ifndef OPENSSL_NO_DH
-    if (alg_k & (SSL_kDHE | SSL_kDHr | SSL_kDHd)) {
+    if (alg_k & (SSL_kDHE | SSL_kDHr | SSL_kDHd | SSL_kDHEPSK)) {
         int idx = -1;
         EVP_PKEY *skey = NULL;
         if (n > 1) {
             n2s(p, i);
         } else {
-            if (alg_k & SSL_kDHE) {
+            if (alg_k & (SSL_kDHE | SSL_kDHEPSK)) {
                 al = SSL_AD_HANDSHAKE_FAILURE;
                 SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
                        SSL_R_DH_PUBLIC_VALUE_LENGTH_IS_WRONG);
@@ -2509,13 +2568,7 @@ int ssl3_get_client_key_exchange(SSL *s)
         else
             BN_clear_free(pub);
         pub = NULL;
-        s->session->master_key_length =
-            s->method->ssl3_enc->generate_master_secret(s,
-                                                        s->
-                                                        session->master_key,
-                                                        p, i);
-        OPENSSL_cleanse(p, i);
-        if (s->session->master_key_length < 0) {
+        if (!ssl_generate_master_secret(s, p, i, 0)) {
             al = SSL_AD_INTERNAL_ERROR;
             SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
             goto f_err;
@@ -2526,7 +2579,7 @@ int ssl3_get_client_key_exchange(SSL *s)
 #endif
 
 #ifndef OPENSSL_NO_EC
-    if (alg_k & (SSL_kECDHE | SSL_kECDHr | SSL_kECDHe)) {
+    if (alg_k & (SSL_kECDHE | SSL_kECDHr | SSL_kECDHe | SSL_kECDHEPSK)) {
         int ret = 1;
         int field_size = 0;
         const EC_KEY *tkey;
@@ -2569,7 +2622,7 @@ int ssl3_get_client_key_exchange(SSL *s)
         if (n == 0L) {
             /* Client Publickey was in Client Certificate */
 
-            if (alg_k & SSL_kECDHE) {
+            if (alg_k & (SSL_kECDHE | SSL_kECDHEPSK)) {
                 al = SSL_AD_HANDSHAKE_FAILURE;
                 SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
                        SSL_R_MISSING_TMP_ECDH_KEY);
@@ -2647,15 +2700,7 @@ int ssl3_get_client_key_exchange(SSL *s)
         EC_KEY_free(s->s3->tmp.ecdh);
         s->s3->tmp.ecdh = NULL;
 
-        /* Compute the master secret */
-        s->session->master_key_length =
-            s->method->ssl3_enc->generate_master_secret(s,
-                                                        s->
-                                                        session->master_key,
-                                                        p, i);
-
-        OPENSSL_cleanse(p, i);
-        if (s->session->master_key_length < 0) {
+        if (!ssl_generate_master_secret(s, p, i, 0)) {
             al = SSL_AD_INTERNAL_ERROR;
             SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
             goto f_err;
@@ -2663,97 +2708,6 @@ int ssl3_get_client_key_exchange(SSL *s)
         return (ret);
     } else
 #endif
-#ifndef OPENSSL_NO_PSK
-    if (alg_k & SSL_kPSK) {
-        unsigned char *t = NULL;
-        unsigned char psk_or_pre_ms[PSK_MAX_PSK_LEN * 2 + 4];
-        unsigned int pre_ms_len = 0, psk_len = 0;
-        int psk_err = 1;
-        char tmp_id[PSK_MAX_IDENTITY_LEN + 1];
-
-        al = SSL_AD_HANDSHAKE_FAILURE;
-
-        n2s(p, i);
-        if (n != i + 2) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, SSL_R_LENGTH_MISMATCH);
-            goto psk_err;
-        }
-        if (i > PSK_MAX_IDENTITY_LEN) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
-                   SSL_R_DATA_LENGTH_TOO_LONG);
-            goto psk_err;
-        }
-        if (s->psk_server_callback == NULL) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
-                   SSL_R_PSK_NO_SERVER_CB);
-            goto psk_err;
-        }
-
-        /*
-         * Create guaranteed NULL-terminated identity string for the callback
-         */
-        memcpy(tmp_id, p, i);
-        memset(tmp_id + i, 0, PSK_MAX_IDENTITY_LEN + 1 - i);
-        psk_len = s->psk_server_callback(s, tmp_id,
-                                         psk_or_pre_ms,
-                                         sizeof(psk_or_pre_ms));
-        OPENSSL_cleanse(tmp_id, PSK_MAX_IDENTITY_LEN + 1);
-
-        if (psk_len > PSK_MAX_PSK_LEN) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
-            goto psk_err;
-        } else if (psk_len == 0) {
-            /*
-             * PSK related to the given identity not found
-             */
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
-                   SSL_R_PSK_IDENTITY_NOT_FOUND);
-            al = SSL_AD_UNKNOWN_PSK_IDENTITY;
-            goto psk_err;
-        }
-
-        /* create PSK pre_master_secret */
-        pre_ms_len = 2 + psk_len + 2 + psk_len;
-        t = psk_or_pre_ms;
-        memmove(psk_or_pre_ms + psk_len + 4, psk_or_pre_ms, psk_len);
-        s2n(psk_len, t);
-        memset(t, 0, psk_len);
-        t += psk_len;
-        s2n(psk_len, t);
-
-        OPENSSL_free(s->session->psk_identity);
-        s->session->psk_identity = BUF_strdup((char *)p);
-        if (s->session->psk_identity == NULL) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
-            goto psk_err;
-        }
-
-        OPENSSL_free(s->session->psk_identity_hint);
-        s->session->psk_identity_hint = BUF_strdup(s->ctx->psk_identity_hint);
-        if (s->ctx->psk_identity_hint != NULL &&
-            s->session->psk_identity_hint == NULL) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
-            goto psk_err;
-        }
-
-        s->session->master_key_length =
-            s->method->ssl3_enc->generate_master_secret(s,
-                                                        s->
-                                                        session->master_key,
-                                                        psk_or_pre_ms,
-                                                        pre_ms_len);
-        if (s->session->master_key_length < 0) {
-            al = SSL_AD_INTERNAL_ERROR;
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
-            goto psk_err;
-        }
-        psk_err = 0;
- psk_err:
-        OPENSSL_cleanse(psk_or_pre_ms, sizeof(psk_or_pre_ms));
-        if (psk_err != 0)
-            goto f_err;
-    } else
-#endif
 #ifndef OPENSSL_NO_SRP
     if (alg_k & SSL_kSRP) {
         int param_len;
@@ -2784,9 +2738,7 @@ int ssl3_get_client_key_exchange(SSL *s)
             goto err;
         }
 
-        if ((s->session->master_key_length =
-             SRP_generate_server_master_secret(s,
-                                               s->session->master_key)) < 0) {
+        if (!srp_generate_server_master_secret(s)) {
             SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
             goto err;
         }
@@ -2842,12 +2794,8 @@ int ssl3_get_client_key_exchange(SSL *s)
             goto gerr;
         }
         /* Generate master secret */
-        s->session->master_key_length =
-            s->method->ssl3_enc->generate_master_secret(s,
-                                                        s->
-                                                        session->master_key,
-                                                        premaster_secret, 32);
-        if (s->session->master_key_length < 0) {
+        if (!ssl_generate_master_secret(s, premaster_secret,
+                                        sizeof(premaster_secret), 0)) {
             al = SSL_AD_INTERNAL_ERROR;
             SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
             goto f_err;
@@ -2881,6 +2829,10 @@ int ssl3_get_client_key_exchange(SSL *s)
     EC_POINT_free(clnt_ecpoint);
     EC_KEY_free(srvr_ecdh);
     BN_CTX_free(bn_ctx);
+#endif
+#ifndef OPENSSL_NO_PSK
+    OPENSSL_clear_free(s->s3->tmp.psk, s->s3->tmp.psklen);
+    s->s3->tmp.psk = NULL;
 #endif
     s->state = SSL_ST_ERR;
     return (-1);
@@ -3076,7 +3028,6 @@ int ssl3_get_cert_verify(SSL *s)
  end:
     BIO_free(s->s3->handshake_buffer);
     s->s3->handshake_buffer = NULL;
-    s->s3->flags &= ~TLS1_FLAGS_KEEP_HANDSHAKE;
     EVP_MD_CTX_cleanup(&mctx);
     EVP_PKEY_free(pkey);
     return (ret);
@@ -3184,7 +3135,7 @@ int ssl3_get_client_certificate(SSL *s)
             goto f_err;
         }
         /* No client certificate so digest cached records */
-        if (s->s3->handshake_buffer && !ssl3_digest_cached_records(s)) {
+        if (s->s3->handshake_buffer && !ssl3_digest_cached_records(s, 0)) {
             al = SSL_AD_INTERNAL_ERROR;
             goto f_err;
         }
@@ -3216,19 +3167,8 @@ int ssl3_get_client_certificate(SSL *s)
     s->session->peer = sk_X509_shift(sk);
     s->session->verify_result = s->verify_result;
 
-    /*
-     * With the current implementation, sess_cert will always be NULL when we
-     * arrive here.
-     */
-    if (s->session->sess_cert == NULL) {
-        s->session->sess_cert = ssl_sess_cert_new();
-        if (s->session->sess_cert == NULL) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_CERTIFICATE, ERR_R_MALLOC_FAILURE);
-            goto done;
-        }
-    }
-    sk_X509_pop_free(s->session->sess_cert->cert_chain, X509_free);
-    s->session->sess_cert->cert_chain = sk;
+    sk_X509_pop_free(s->session->peer_chain, X509_free);
+    s->session->peer_chain = sk;
     /*
      * Inconsistency alert: cert_chain does *not* include the peer's own
      * certificate, while we do include it in s3_clnt.c
@@ -3270,7 +3210,6 @@ int ssl3_send_server_certificate(SSL *s)
     return ssl_do_write(s);
 }
 
-#ifndef OPENSSL_NO_TLSEXT
 /* send a new session ticket (not necessarily for a new session) */
 int ssl3_send_newsession_ticket(SSL *s)
 {
@@ -3458,7 +3397,7 @@ int ssl3_send_cert_status(SSL *s)
     return (ssl3_do_write(s, SSL3_RT_HANDSHAKE));
 }
 
-# ifndef OPENSSL_NO_NEXTPROTONEG
+#ifndef OPENSSL_NO_NEXTPROTONEG
 /*
  * ssl3_get_next_proto reads a Next Protocol Negotiation handshake message.
  * It sets the next_proto member in s if found
@@ -3537,8 +3476,6 @@ int ssl3_get_next_proto(SSL *s)
 
     return 1;
 }
-# endif
-
 #endif
 
 #define SSLV2_CIPHER_LEN    3
@@ -3576,13 +3513,13 @@ STACK_OF(SSL_CIPHER) *ssl_bytes_to_cipher_list(SSL *s, unsigned char *p,
         sk_SSL_CIPHER_zero(sk);
     }
 
-    OPENSSL_free(s->cert->ciphers_raw);
-    s->cert->ciphers_raw = BUF_memdup(p, num);
-    if (s->cert->ciphers_raw == NULL) {
+    OPENSSL_free(s->s3->tmp.ciphers_raw);
+    s->s3->tmp.ciphers_raw = BUF_memdup(p, num);
+    if (s->s3->tmp.ciphers_raw == NULL) {
         SSLerr(SSL_F_SSL_BYTES_TO_CIPHER_LIST, ERR_R_MALLOC_FAILURE);
         goto err;
     }
-    s->cert->ciphers_rawlen = (size_t)num;
+    s->s3->tmp.ciphers_rawlen = (size_t)num;
 
     for (i = 0; i < num; i += n) {
         /* Check for TLS_EMPTY_RENEGOTIATION_INFO_SCSV */