Extended PSK client support.
authorDr. Stephen Henson <steve@openssl.org>
Sun, 28 Jun 2015 16:15:10 +0000 (17:15 +0100)
committerDr. Stephen Henson <steve@openssl.org>
Thu, 30 Jul 2015 13:43:35 +0000 (14:43 +0100)
Add support for RSAPSK, DHEPSK and ECDHEPSK client side.

Update various checks to ensure certificate and server key exchange messages
are only expected when required.

Update message handling. PSK server key exchange parsing now expects an
identity hint prefix for all PSK server key exchange messages. PSK
client key exchange message requests PSK identity and key for all PSK
key exchange ciphersuites and includes identity in message.

Update flags for RSA, DH and ECDH so they are also used in PSK.

Reviewed-by: Matt Caswell <matt@openssl.org>
ssl/s3_clnt.c

index 04af851..d5bcf54 100644 (file)
@@ -331,10 +331,8 @@ int ssl3_connect(SSL *s)
 
             /* Check if it is anon DH/ECDH, SRP auth */
             /* or PSK */
-            if (!
-                (s->s3->tmp.
-                 new_cipher->algorithm_auth & (SSL_aNULL | SSL_aSRP))
-                    && !(s->s3->tmp.new_cipher->algorithm_mkey & SSL_kPSK)) {
+            if (!(s->s3->tmp.new_cipher->algorithm_auth &
+                    (SSL_aNULL | SSL_aSRP | SSL_aPSK))) {
                 ret = ssl3_get_server_certificate(s);
                 if (ret <= 0)
                     goto end;
@@ -1414,7 +1412,7 @@ int ssl3_get_key_exchange(SSL *s)
          * Can't skip server key exchange if this is an ephemeral
          * ciphersuite.
          */
-        if (alg_k & (SSL_kDHE | SSL_kECDHE)) {
+        if (alg_k & (SSL_kDHE | SSL_kECDHE | SSL_kDHEPSK | SSL_kECDHEPSK)) {
             SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE, SSL_R_UNEXPECTED_MESSAGE);
             al = SSL_AD_UNEXPECTED_MESSAGE;
             goto f_err;
@@ -1447,8 +1445,8 @@ int ssl3_get_key_exchange(SSL *s)
     al = SSL_AD_DECODE_ERROR;
 
 #ifndef OPENSSL_NO_PSK
-    if (alg_k & SSL_kPSK) {
-        char tmp_id_hint[PSK_MAX_IDENTITY_LEN + 1];
+    /* PSK ciphersuites are preceded by an identity hint */
+    if (alg_k & SSL_PSK) {
 
         param_len = 2;
         if (param_len > n) {
@@ -1475,23 +1473,24 @@ int ssl3_get_key_exchange(SSL *s)
         }
         param_len += i;
 
-        /*
-         * If received PSK identity hint contains NULL characters, the hint
-         * is truncated from the first NULL. p may not be ending with NULL,
-         * so create a NULL-terminated string.
-         */
-        memcpy(tmp_id_hint, p, i);
-        memset(tmp_id_hint + i, 0, PSK_MAX_IDENTITY_LEN + 1 - i);
         OPENSSL_free(s->session->psk_identity_hint);
-        s->session->psk_identity_hint = BUF_strdup(tmp_id_hint);
-        if (s->session->psk_identity_hint == NULL) {
-            al = SSL_AD_HANDSHAKE_FAILURE;
-            SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
-            goto f_err;
+        if (i != 0) {
+            s->session->psk_identity_hint = BUF_strndup((char *)p, i);
+            if (s->session->psk_identity_hint == NULL) {
+                al = SSL_AD_HANDSHAKE_FAILURE;
+                SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
+                goto f_err;
+            }
+        } else {
+            s->session->psk_identity_hint = NULL;
         }
 
         p += i;
         n -= param_len;
+    }
+
+    /* Nothing else to do for plain PSK or RSAPSK */
+    if (alg_k & (SSL_kPSK | SSL_kRSAPSK)) {
     } else
 #endif                          /* !OPENSSL_NO_PSK */
 #ifndef OPENSSL_NO_SRP
@@ -1661,7 +1660,7 @@ int ssl3_get_key_exchange(SSL *s)
     if (0) ;
 #endif
 #ifndef OPENSSL_NO_DH
-    else if (alg_k & SSL_kDHE) {
+    else if (alg_k & (SSL_kDHE | SSL_kDHEPSK)) {
         if ((dh = DH_new()) == NULL) {
             SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE, ERR_R_DH_LIB);
             goto err;
@@ -1742,7 +1741,7 @@ int ssl3_get_key_exchange(SSL *s)
 #endif                          /* !OPENSSL_NO_DH */
 
 #ifndef OPENSSL_NO_EC
-    else if (alg_k & SSL_kECDHE) {
+    else if (alg_k & (SSL_kECDHE | SSL_kECDHEPSK)) {
         EC_GROUP *ngroup;
         const EC_GROUP *group;
 
@@ -1945,8 +1944,8 @@ int ssl3_get_key_exchange(SSL *s)
             }
         }
     } else {
-        /* aNULL, aSRP or kPSK do not need public keys */
-        if (!(alg_a & (SSL_aNULL | SSL_aSRP)) && !(alg_k & SSL_kPSK)) {
+        /* aNULL, aSRP or PSK do not need public keys */
+        if (!(alg_a & (SSL_aNULL | SSL_aSRP)) && !(alg_k & SSL_PSK)) {
             /* Might be wrong key type, check it */
             if (ssl3_check_cert_and_algorithm(s))
                 /* Otherwise this shouldn't happen */
@@ -2329,6 +2328,9 @@ int ssl3_send_client_key_exchange(SSL *s)
 {
     unsigned char *p;
     int n;
+#ifndef OPENSSL_NO_PSK
+    size_t pskhdrlen = 0;
+#endif
     unsigned long alg_k;
 #ifndef OPENSSL_NO_RSA
     unsigned char *q;
@@ -2344,17 +2346,89 @@ int ssl3_send_client_key_exchange(SSL *s)
 #endif
     unsigned char *pms = NULL;
     size_t pmslen = 0;
+    alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
 
     if (s->state == SSL3_ST_CW_KEY_EXCH_A) {
         p = ssl_handshake_start(s);
 
-        alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
+
+#ifndef OPENSSL_NO_PSK
+        if (alg_k & SSL_PSK) {
+            int psk_err = 1;
+            /*
+             * The callback needs PSK_MAX_IDENTITY_LEN + 1 bytes to return a
+             * \0-terminated identity. The last byte is for us for simulating
+             * strnlen.
+             */
+            char identity[PSK_MAX_IDENTITY_LEN + 1];
+            size_t identitylen;
+            unsigned char psk[PSK_MAX_PSK_LEN];
+            size_t psklen;
+
+            if (s->psk_client_callback == NULL) {
+                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
+                       SSL_R_PSK_NO_CLIENT_CB);
+                goto err;
+            }
+
+            memset(identity, 0, sizeof(identity));
+
+            psklen = s->psk_client_callback(s, s->session->psk_identity_hint,
+                                            identity, sizeof(identity) - 1,
+                                            psk, sizeof(psk));
+
+            if (psklen > PSK_MAX_PSK_LEN) {
+                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
+                       ERR_R_INTERNAL_ERROR);
+                goto psk_err;
+            } else if (psklen == 0) {
+                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
+                       SSL_R_PSK_IDENTITY_NOT_FOUND);
+                goto psk_err;
+            }
+
+            OPENSSL_free(s->s3->tmp.psk);
+            s->s3->tmp.psk = BUF_memdup(psk, psklen);
+            OPENSSL_cleanse(psk, psklen);
+
+            if (s->s3->tmp.psk == NULL)
+                goto memerr;
+
+            s->s3->tmp.psklen = psklen;
+
+            identitylen = strlen(identity);
+            if (identitylen > PSK_MAX_IDENTITY_LEN) {
+                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
+                       ERR_R_INTERNAL_ERROR);
+                goto psk_err;
+            }
+            OPENSSL_free(s->session->psk_identity);
+            s->session->psk_identity = BUF_strdup(identity);
+            if (s->session->psk_identity == NULL)
+                goto memerr;
+
+            s2n(identitylen, p);
+            memcpy(p, identity, identitylen);
+            pskhdrlen = 2 + identitylen;
+            p += identitylen;
+            psk_err = 0;
+ psk_err:
+            OPENSSL_cleanse(identity, sizeof(identity));
+            if (psk_err != 0) {
+                ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
+                goto err;
+            }
+        }
+        if (alg_k & SSL_kPSK) {
+            n = 0;
+        } else
+#endif
 
         /* Fool emacs indentation */
         if (0) {
         }
 #ifndef OPENSSL_NO_RSA
-        else if (alg_k & SSL_kRSA) {
+        else if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) {
             RSA *rsa;
             pmslen = SSL_MAX_MASTER_KEY_LENGTH;
             pms = OPENSSL_malloc(pmslen);
@@ -2414,7 +2488,7 @@ int ssl3_send_client_key_exchange(SSL *s)
         }
 #endif
 #ifndef OPENSSL_NO_DH
-        else if (alg_k & (SSL_kDHE | SSL_kDHr | SSL_kDHd)) {
+        else if (alg_k & (SSL_kDHE | SSL_kDHr | SSL_kDHd | SSL_kDHEPSK)) {
             DH *dh_srvr, *dh_clnt;
             if (s->s3->peer_dh_tmp != NULL)
                 dh_srvr = s->s3->peer_dh_tmp;
@@ -2493,7 +2567,7 @@ int ssl3_send_client_key_exchange(SSL *s)
 #endif
 
 #ifndef OPENSSL_NO_EC
-        else if (alg_k & (SSL_kECDHE | SSL_kECDHr | SSL_kECDHe)) {
+        else if (alg_k & (SSL_kECDHE | SSL_kECDHr | SSL_kECDHe | SSL_kECDHEPSK)) {
             const EC_GROUP *srvr_group = NULL;
             EC_KEY *tkey;
             int ecdh_clnt_cert = 0;
@@ -2780,82 +2854,6 @@ int ssl3_send_client_key_exchange(SSL *s)
                 goto err;
             }
         }
-#endif
-#ifndef OPENSSL_NO_PSK
-        else if (alg_k & SSL_kPSK) {
-            /*
-             * The callback needs PSK_MAX_IDENTITY_LEN + 1 bytes to return a
-             * \0-terminated identity. The last byte is for us for simulating
-             * strnlen.
-             */
-            char identity[PSK_MAX_IDENTITY_LEN + 2];
-            size_t identity_len;
-            unsigned char *t = NULL;
-            unsigned int psk_len = 0;
-            int psk_err = 1;
-
-            n = 0;
-            if (s->psk_client_callback == NULL) {
-                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
-                       SSL_R_PSK_NO_CLIENT_CB);
-                goto err;
-            }
-
-            memset(identity, 0, sizeof(identity));
-            /* Allocate maximum size buffer */
-            pmslen = PSK_MAX_PSK_LEN * 2 + 4;
-            pms = OPENSSL_malloc(pmslen);
-            if (!pms)
-                goto memerr;
-
-            psk_len = s->psk_client_callback(s, s->session->psk_identity_hint,
-                                             identity, sizeof(identity) - 1,
-                                             pms, pmslen);
-            if (psk_len > PSK_MAX_PSK_LEN) {
-                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
-                       ERR_R_INTERNAL_ERROR);
-                goto psk_err;
-            } else if (psk_len == 0) {
-                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
-                       SSL_R_PSK_IDENTITY_NOT_FOUND);
-                goto psk_err;
-            }
-            /* Change pmslen to real length */
-            pmslen = 2 + psk_len + 2 + psk_len;
-            identity[PSK_MAX_IDENTITY_LEN + 1] = '\0';
-            identity_len = strlen(identity);
-            if (identity_len > PSK_MAX_IDENTITY_LEN) {
-                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
-                       ERR_R_INTERNAL_ERROR);
-                goto psk_err;
-            }
-            /* create PSK pre_master_secret */
-            t = pms;
-            memmove(pms + psk_len + 4, pms, psk_len);
-            s2n(psk_len, t);
-            memset(t, 0, psk_len);
-            t += psk_len;
-            s2n(psk_len, t);
-
-            OPENSSL_free(s->session->psk_identity);
-            s->session->psk_identity = BUF_strdup(identity);
-            if (s->session->psk_identity == NULL) {
-                SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE,
-                       ERR_R_MALLOC_FAILURE);
-                goto psk_err;
-            }
-
-            s2n(identity_len, p);
-            memcpy(p, identity, identity_len);
-            n = 2 + identity_len;
-            psk_err = 0;
- psk_err:
-            OPENSSL_cleanse(identity, sizeof(identity));
-            if (psk_err != 0) {
-                ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
-                goto err;
-            }
-        }
 #endif
         else {
             ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
@@ -2863,6 +2861,10 @@ int ssl3_send_client_key_exchange(SSL *s)
             goto err;
         }
 
+#ifndef OPENSSL_NO_PSK
+        n += pskhdrlen;
+#endif
+
         if (!ssl_set_handshake_header(s, SSL3_MT_CLIENT_KEY_EXCHANGE, n)) {
             ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
             SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
@@ -2876,7 +2878,7 @@ int ssl3_send_client_key_exchange(SSL *s)
     n = ssl_do_write(s);
 #ifndef OPENSSL_NO_SRP
     /* Check for SRP */
-    if (s->s3->tmp.new_cipher->algorithm_mkey & SSL_kSRP) {
+    if (alg_k & SSL_kSRP) {
         /*
          * If everything written generate master key: no need to save PMS as
          * srp_generate_client_master_secret generates it internally.
@@ -2900,7 +2902,7 @@ int ssl3_send_client_key_exchange(SSL *s)
             pms = s->s3->tmp.pms;
             pmslen = s->s3->tmp.pmslen;
         }
-        if (pms == NULL) {
+        if (pms == NULL && !(alg_k & SSL_kPSK)) {
             ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
             SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
             goto err;
@@ -2923,6 +2925,10 @@ int ssl3_send_client_key_exchange(SSL *s)
     OPENSSL_free(encodedPoint);
     EC_KEY_free(clnt_ecdh);
     EVP_PKEY_free(srvr_pub_pkey);
+#endif
+#ifndef OPENSSL_NO_PSK
+    OPENSSL_clear_free(s->s3->tmp.psk, s->s3->tmp.psklen);
+    s->s3->tmp.psk = NULL;
 #endif
     s->state = SSL_ST_ERR;
     return (-1);
@@ -3261,7 +3267,7 @@ int ssl3_check_cert_and_algorithm(SSL *s)
     }
 #endif
 #ifndef OPENSSL_NO_RSA
-    if (alg_k & SSL_kRSA) {
+    if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) {
         if (!SSL_C_IS_EXPORT(s->s3->tmp.new_cipher) &&
             !has_bits(i, EVP_PK_RSA | EVP_PKT_ENC)) {
             SSLerr(SSL_F_SSL3_CHECK_CERT_AND_ALGORITHM,