Extended PSK client support.
[openssl.git] / ssl / s3_clnt.c
index 04af8514d8b327b38e1508fe9e8b011548c65897..d5bcf5428032f2791bd0198739f300be0e1e4c1a 100644 (file)
@@ -331,10 +331,8 @@ int ssl3_connect(SSL *s)
 
             /* Check if it is anon DH/ECDH, SRP auth */
             /* or PSK */
-            if (!
-                (s->s3->tmp.
-                 new_cipher->algorithm_auth & (SSL_aNULL | SSL_aSRP))
-                    && !(s->s3->tmp.new_cipher->algorithm_mkey & SSL_kPSK)) {
+            if (!(s->s3->tmp.new_cipher->algorithm_auth &
+                    (SSL_aNULL | SSL_aSRP | SSL_aPSK))) {
                 ret = ssl3_get_server_certificate(s);
                 if (ret <= 0)
                     goto end;
@@ -1414,7 +1412,7 @@ int ssl3_get_key_exchange(SSL *s)
          * Can't skip server key exchange if this is an ephemeral
          * ciphersuite.
          */
-        if (alg_k & (SSL_kDHE | SSL_kECDHE)) {
+        if (alg_k & (SSL_kDHE | SSL_kECDHE | SSL_kDHEPSK | SSL_kECDHEPSK)) {
             SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE, SSL_R_UNEXPECTED_MESSAGE);
             al = SSL_AD_UNEXPECTED_MESSAGE;
             goto f_err;
@@ -1447,8 +1445,8 @@ int ssl3_get_key_exchange(SSL *s)
     al = SSL_AD_DECODE_ERROR;
 
 #ifndef OPENSSL_NO_PSK
-    if (alg_k & SSL_kPSK) {
-        char tmp_id_hint[PSK_MAX_IDENTITY_LEN + 1];
+    /* PSK ciphersuites are preceded by an identity hint */
+    if (alg_k & SSL_PSK) {
 
         param_len = 2;
         if (param_len > n) {
@@ -1475,23 +1473,24 @@ int ssl3_get_key_exchange(SSL *s)
         }
         param_len += i;
 
-        /*
-         * If received PSK identity hint contains NULL characters, the hint
-         * is truncated from the first NULL. p may not be ending with NULL,
-         * so create a NULL-terminated string.
-         */
-        memcpy(tmp_id_hint, p, i);
-        memset(tmp_id_hint + i, 0, PSK_MAX_IDENTITY_LEN + 1 - i);
         OPENSSL_free(s->session->psk_identity_hint);
-        s->session->psk_identity_hint = BUF_strdup(tmp_id_hint);
-        if (s->session->psk_identity_hint == NULL) {
-            al = SSL_AD_HANDSHAKE_FAILURE;
-            SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
-            goto f_err;
+        if (i != 0) {
+            s->session->psk_identity_hint = BUF_strndup((char *)p, i);
+            if (s->session->psk_identity_hint == NULL) {
+                al = SSL_AD_HANDSHAKE_FAILURE;
+                SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
+                goto f_err;
+            }
+        } else {
+            s->session->psk_identity_hint = NULL;
         }
 
         p += i;
         n -= param_len;
+    }
+
+    /* Nothing else to do for plain PSK or RSAPSK */
+    if (alg_k & (SSL_kPSK | SSL_kRSAPSK)) {
     } else
 #endif                          /* !OPENSSL_NO_PSK */
 #ifndef OPENSSL_NO_SRP
@@ -1661,7 +1660,7 @@ int ssl3_get_key_exchange(SSL *s)
     if (0) ;
 #endif
 #ifndef OPENSSL_NO_DH
-    else if (alg_k & SSL_kDHE) {
+    else if (alg_k & (SSL_kDHE | SSL_kDHEPSK)) {
         if ((dh = DH_new()) == NULL) {
             SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE, ERR_R_DH_LIB);
             goto err;
@@ -1742,7 +1741,7 @@ int ssl3_get_key_exchange(SSL *s)
 #endif                          /* !OPENSSL_NO_DH */
 
 #ifndef OPENSSL_NO_EC
-    else if (alg_k & SSL_kECDHE) {
+    else if (alg_k & (SSL_kECDHE | SSL_kECDHEPSK)) {
         EC_GROUP *ngroup;
         const EC_GROUP *group;
 
@@ -1945,8 +1944,8 @@ int ssl3_get_key_exchange(SSL *s)
             }
         }
     } else {
-        /* aNULL, aSRP or kPSK do not need public keys */
-        if (!(alg_a & (SSL_aNULL | SSL_aSRP)) && !(alg_k & SSL_kPSK)) {
+        /* aNULL, aSRP or PSK do not need public keys */
+        if (!(alg_a & (SSL_aNULL | SSL_aSRP)) && !(alg_k & SSL_PSK)) {
             /* Might be wrong key type, check it */
             if (ssl3_check_cert_and_algorithm(s))
                 /* Otherwise this shouldn't happen */
@@ -2329,6 +2328,9 @@ int ssl3_send_client_key_exchange(SSL *s)
 {
     unsigned char *p;
     int n;
+#ifndef OPENSSL_NO_PSK
+    size_t pskhdrlen = 0;
+#endif
     unsigned long alg_k;
 #ifndef OPENSSL_NO_RSA
     unsigned char *q;
@@ -2344,17 +2346,89 @@ int ssl3_send_client_key_exchange(SSL *s)
 #endif
     unsigned char *pms = NULL;
     size_t pmslen = 0;
+    alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
 
     if (s->state == SSL3_ST_CW_KEY_EXCH_A) {
         p = ssl_handshake_start(s);
 
-        alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
+
+#ifndef OPENSSL_NO_PSK
+        if (alg_k & SSL_PSK) {
+            int psk_err = 1;
+            /*
+             * The callback needs PSK_MAX_IDENTITY_LEN + 1 bytes to return a
+             * \0-terminated identity. The last byte is for us for simulating
+             * strnlen.
+             */
+            char identity[PSK_MAX_IDENTITY_LEN + 1];
+            size_t identitylen;
+            unsigned char psk[PSK_MAX_PSK_LEN];
+            size_t psklen;
+
+            if (s->psk_client_callback == NULL) {
+                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
+                       SSL_R_PSK_NO_CLIENT_CB);
+                goto err;
+            }
+
+            memset(identity, 0, sizeof(identity));
+
+            psklen = s->psk_client_callback(s, s->session->psk_identity_hint,
+                                            identity, sizeof(identity) - 1,
+                                            psk, sizeof(psk));
+
+            if (psklen > PSK_MAX_PSK_LEN) {
+                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
+                       ERR_R_INTERNAL_ERROR);
+                goto psk_err;
+            } else if (psklen == 0) {
+                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
+                       SSL_R_PSK_IDENTITY_NOT_FOUND);
+                goto psk_err;
+            }
+
+            OPENSSL_free(s->s3->tmp.psk);
+            s->s3->tmp.psk = BUF_memdup(psk, psklen);
+            OPENSSL_cleanse(psk, psklen);
+
+            if (s->s3->tmp.psk == NULL)
+                goto memerr;
+
+            s->s3->tmp.psklen = psklen;
+
+            identitylen = strlen(identity);
+            if (identitylen > PSK_MAX_IDENTITY_LEN) {
+                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
+                       ERR_R_INTERNAL_ERROR);
+                goto psk_err;
+            }
+            OPENSSL_free(s->session->psk_identity);
+            s->session->psk_identity = BUF_strdup(identity);
+            if (s->session->psk_identity == NULL)
+                goto memerr;
+
+            s2n(identitylen, p);
+            memcpy(p, identity, identitylen);
+            pskhdrlen = 2 + identitylen;
+            p += identitylen;
+            psk_err = 0;
+ psk_err:
+            OPENSSL_cleanse(identity, sizeof(identity));
+            if (psk_err != 0) {
+                ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
+                goto err;
+            }
+        }
+        if (alg_k & SSL_kPSK) {
+            n = 0;
+        } else
+#endif
 
         /* Fool emacs indentation */
         if (0) {
         }
 #ifndef OPENSSL_NO_RSA
-        else if (alg_k & SSL_kRSA) {
+        else if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) {
             RSA *rsa;
             pmslen = SSL_MAX_MASTER_KEY_LENGTH;
             pms = OPENSSL_malloc(pmslen);
@@ -2414,7 +2488,7 @@ int ssl3_send_client_key_exchange(SSL *s)
         }
 #endif
 #ifndef OPENSSL_NO_DH
-        else if (alg_k & (SSL_kDHE | SSL_kDHr | SSL_kDHd)) {
+        else if (alg_k & (SSL_kDHE | SSL_kDHr | SSL_kDHd | SSL_kDHEPSK)) {
             DH *dh_srvr, *dh_clnt;
             if (s->s3->peer_dh_tmp != NULL)
                 dh_srvr = s->s3->peer_dh_tmp;
@@ -2493,7 +2567,7 @@ int ssl3_send_client_key_exchange(SSL *s)
 #endif
 
 #ifndef OPENSSL_NO_EC
-        else if (alg_k & (SSL_kECDHE | SSL_kECDHr | SSL_kECDHe)) {
+        else if (alg_k & (SSL_kECDHE | SSL_kECDHr | SSL_kECDHe | SSL_kECDHEPSK)) {
             const EC_GROUP *srvr_group = NULL;
             EC_KEY *tkey;
             int ecdh_clnt_cert = 0;
@@ -2780,82 +2854,6 @@ int ssl3_send_client_key_exchange(SSL *s)
                 goto err;
             }
         }
-#endif
-#ifndef OPENSSL_NO_PSK
-        else if (alg_k & SSL_kPSK) {
-            /*
-             * The callback needs PSK_MAX_IDENTITY_LEN + 1 bytes to return a
-             * \0-terminated identity. The last byte is for us for simulating
-             * strnlen.
-             */
-            char identity[PSK_MAX_IDENTITY_LEN + 2];
-            size_t identity_len;
-            unsigned char *t = NULL;
-            unsigned int psk_len = 0;
-            int psk_err = 1;
-
-            n = 0;
-            if (s->psk_client_callback == NULL) {
-                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
-                       SSL_R_PSK_NO_CLIENT_CB);
-                goto err;
-            }
-
-            memset(identity, 0, sizeof(identity));
-            /* Allocate maximum size buffer */
-            pmslen = PSK_MAX_PSK_LEN * 2 + 4;
-            pms = OPENSSL_malloc(pmslen);
-            if (!pms)
-                goto memerr;
-
-            psk_len = s->psk_client_callback(s, s->session->psk_identity_hint,
-                                             identity, sizeof(identity) - 1,
-                                             pms, pmslen);
-            if (psk_len > PSK_MAX_PSK_LEN) {
-                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
-                       ERR_R_INTERNAL_ERROR);
-                goto psk_err;
-            } else if (psk_len == 0) {
-                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
-                       SSL_R_PSK_IDENTITY_NOT_FOUND);
-                goto psk_err;
-            }
-            /* Change pmslen to real length */
-            pmslen = 2 + psk_len + 2 + psk_len;
-            identity[PSK_MAX_IDENTITY_LEN + 1] = '\0';
-            identity_len = strlen(identity);
-            if (identity_len > PSK_MAX_IDENTITY_LEN) {
-                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
-                       ERR_R_INTERNAL_ERROR);
-                goto psk_err;
-            }
-            /* create PSK pre_master_secret */
-            t = pms;
-            memmove(pms + psk_len + 4, pms, psk_len);
-            s2n(psk_len, t);
-            memset(t, 0, psk_len);
-            t += psk_len;
-            s2n(psk_len, t);
-
-            OPENSSL_free(s->session->psk_identity);
-            s->session->psk_identity = BUF_strdup(identity);
-            if (s->session->psk_identity == NULL) {
-                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
-                       ERR_R_MALLOC_FAILURE);
-                goto psk_err;
-            }
-
-            s2n(identity_len, p);
-            memcpy(p, identity, identity_len);
-            n = 2 + identity_len;
-            psk_err = 0;
- psk_err:
-            OPENSSL_cleanse(identity, sizeof(identity));
-            if (psk_err != 0) {
-                ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
-                goto err;
-            }
-        }
 #endif
         else {
             ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
@@ -2863,6 +2861,10 @@ int ssl3_send_client_key_exchange(SSL *s)
             goto err;
         }
 
+#ifndef OPENSSL_NO_PSK
+        n += pskhdrlen;
+#endif
+
         if (!ssl_set_handshake_header(s, SSL3_MT_CLIENT_KEY_EXCHANGE, n)) {
             ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
             SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
@@ -2876,7 +2878,7 @@ int ssl3_send_client_key_exchange(SSL *s)
     n = ssl_do_write(s);
 #ifndef OPENSSL_NO_SRP
     /* Check for SRP */
-    if (s->s3->tmp.new_cipher->algorithm_mkey & SSL_kSRP) {
+    if (alg_k & SSL_kSRP) {
         /*
          * If everything written generate master key: no need to save PMS as
          * srp_generate_client_master_secret generates it internally.
@@ -2900,7 +2902,7 @@ int ssl3_send_client_key_exchange(SSL *s)
             pms = s->s3->tmp.pms;
             pmslen = s->s3->tmp.pmslen;
         }
-        if (pms == NULL) {
+        if (pms == NULL && !(alg_k & SSL_kPSK)) {
             ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
             SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
             goto err;
@@ -2923,6 +2925,10 @@ int ssl3_send_client_key_exchange(SSL *s)
     OPENSSL_free(encodedPoint);
     EC_KEY_free(clnt_ecdh);
     EVP_PKEY_free(srvr_pub_pkey);
+#endif
+#ifndef OPENSSL_NO_PSK
+    OPENSSL_clear_free(s->s3->tmp.psk, s->s3->tmp.psklen);
+    s->s3->tmp.psk = NULL;
 #endif
     s->state = SSL_ST_ERR;
     return (-1);
@@ -3261,7 +3267,7 @@ int ssl3_check_cert_and_algorithm(SSL *s)
     }
 #endif
 #ifndef OPENSSL_NO_RSA
-    if (alg_k & SSL_kRSA) {
+    if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) {
         if (!SSL_C_IS_EXPORT(s->s3->tmp.new_cipher) &&
             !has_bits(i, EVP_PK_RSA | EVP_PKT_ENC)) {
             SSLerr(SSL_F_SSL3_CHECK_CERT_AND_ALGORITHM,