Extended PSK server support.
authorDr. Stephen Henson <steve@openssl.org>
Sun, 28 Jun 2015 16:23:13 +0000 (17:23 +0100)
committerDr. Stephen Henson <steve@openssl.org>
Thu, 30 Jul 2015 13:43:35 +0000 (14:43 +0100)
Add support for RSAPSK, DHEPSK and ECDHEPSK server side.

Update various checks to ensure certificate and server key exchange messages
are only sent when required.

Update message handling. PSK server key exchange parsing now include an
identity hint prefix for all PSK server key exchange messages. PSK
client key exchange message expects PSK identity and requests key for
all PSK key exchange ciphersuites.

Update flags for RSA, DH and ECDH so they are also used in PSK.

Reviewed-by: Matt Caswell <matt@openssl.org>
ssl/s3_srvr.c
ssl/ssl_locl.h

index cbe80eb..caf45d1 100644 (file)
@@ -403,10 +403,8 @@ int ssl3_accept(SSL *s)
         case SSL3_ST_SW_CERT_B:
             /* Check if it is anon DH or anon ECDH, */
             /* normal PSK or SRP */
-            if (!
-                (s->s3->tmp.
-                 new_cipher->algorithm_auth & (SSL_aNULL | SSL_aSRP))
-&& !(s->s3->tmp.new_cipher->algorithm_mkey & SSL_kPSK)) {
+            if (!(s->s3->tmp.new_cipher->algorithm_auth &
+                 (SSL_aNULL | SSL_aSRP | SSL_aPSK))) {
                 ret = ssl3_send_server_certificate(s);
                 if (ret <= 0)
                     goto end;
@@ -446,7 +444,10 @@ int ssl3_accept(SSL *s)
                  * provided
                  */
 #ifndef OPENSSL_NO_PSK
-                || ((alg_k & SSL_kPSK) && s->ctx->psk_identity_hint)
+                /* Only send SKE if we have identity hint for plain PSK */
+                || ((alg_k & (SSL_kPSK | SSL_kRSAPSK)) && s->ctx->psk_identity_hint)
+                /* For other PSK always send SKE */
+                || (alg_k & (SSL_PSK & (SSL_kDHEPSK | SSL_kECDHEPSK)))
 #endif
 #ifndef OPENSSL_NO_SRP
                 /* SRP: send ServerKeyExchange */
@@ -1722,6 +1723,19 @@ int ssl3_send_server_key_exchange(SSL *s)
 
         r[0] = r[1] = r[2] = r[3] = NULL;
         n = 0;
+#ifndef OPENSSL_NO_PSK
+        if (type & SSL_PSK) {
+            /*
+             * reserve size for record length and PSK identity hint
+             */
+            n += 2;
+            if (s->ctx->psk_identity_hint)
+                n += strlen(s->ctx->psk_identity_hint);
+        }
+        /* Plain PSK or RSAPSK nothing to do */
+        if (type & (SSL_kPSK | SSL_kRSAPSK)) {
+        } else
+#endif                          /* !OPENSSL_NO_PSK */
 #ifndef OPENSSL_NO_RSA
         if (type & SSL_kRSA) {
             rsa = cert->rsa_tmp;
@@ -1752,7 +1766,7 @@ int ssl3_send_server_key_exchange(SSL *s)
         } else
 #endif
 #ifndef OPENSSL_NO_DH
-        if (type & SSL_kDHE) {
+        if (type & (SSL_kDHE | SSL_kDHEPSK)) {
             if (s->cert->dh_tmp_auto) {
                 dhp = ssl_get_auto_dh(s);
                 if (dhp == NULL) {
@@ -1817,7 +1831,7 @@ int ssl3_send_server_key_exchange(SSL *s)
         } else
 #endif
 #ifndef OPENSSL_NO_EC
-        if (type & SSL_kECDHE) {
+        if (type & (SSL_kECDHE | SSL_kECDHEPSK)) {
             const EC_GROUP *group;
 
             ecdhp = cert->ecdh_tmp;
@@ -1933,7 +1947,7 @@ int ssl3_send_server_key_exchange(SSL *s)
              * additional bytes to encode the entire ServerECDHParams
              * structure.
              */
-            n = 4 + encodedlen;
+            n += 4 + encodedlen;
 
             /*
              * We'll generate the serverKeyExchange message explicitly so we
@@ -1945,14 +1959,6 @@ int ssl3_send_server_key_exchange(SSL *s)
             r[3] = NULL;
         } else
 #endif                          /* !OPENSSL_NO_EC */
-#ifndef OPENSSL_NO_PSK
-        if (type & SSL_kPSK) {
-            /*
-             * reserve size for record length and PSK identity hint
-             */
-            n += 2 + strlen(s->ctx->psk_identity_hint);
-        } else
-#endif                          /* !OPENSSL_NO_PSK */
 #ifndef OPENSSL_NO_SRP
         if (type & SSL_kSRP) {
             if ((s->srp_ctx.N == NULL) ||
@@ -1984,8 +1990,8 @@ int ssl3_send_server_key_exchange(SSL *s)
                 n += 2 + nr[i];
         }
 
-        if (!(s->s3->tmp.new_cipher->algorithm_auth & (SSL_aNULL | SSL_aSRP))
-            && !(s->s3->tmp.new_cipher->algorithm_mkey & SSL_kPSK)) {
+        if (!(s->s3->tmp.new_cipher->algorithm_auth & (SSL_aNULL|SSL_aSRP))
+            && !(s->s3->tmp.new_cipher->algorithm_mkey & SSL_PSK)) {
             if ((pkey = ssl_get_sign_pkey(s, s->s3->tmp.new_cipher, &md))
                 == NULL) {
                 al = SSL_AD_DECODE_ERROR;
@@ -2003,6 +2009,20 @@ int ssl3_send_server_key_exchange(SSL *s)
         }
         d = p = ssl_handshake_start(s);
 
+#ifndef OPENSSL_NO_PSK
+        if (type & SSL_PSK) {
+            /* copy PSK identity hint */
+            if (s->ctx->psk_identity_hint) {
+                s2n(strlen(s->ctx->psk_identity_hint), p);
+                strncpy((char *)p, s->ctx->psk_identity_hint,
+                        strlen(s->ctx->psk_identity_hint));
+                p += strlen(s->ctx->psk_identity_hint);
+            } else {
+                s2n(0, p);
+            }
+        }
+#endif
+
         for (i = 0; i < 4 && r[i] != NULL; i++) {
 #ifndef OPENSSL_NO_SRP
             if ((i == 2) && (type & SSL_kSRP)) {
@@ -2016,7 +2036,7 @@ int ssl3_send_server_key_exchange(SSL *s)
         }
 
 #ifndef OPENSSL_NO_EC
-        if (type & SSL_kECDHE) {
+        if (type & (SSL_kECDHE | SSL_kECDHEPSK)) {
             /*
              * XXX: For now, we only support named (not generic) curves. In
              * this situation, the serverKeyExchange message has: [1 byte
@@ -2038,16 +2058,6 @@ int ssl3_send_server_key_exchange(SSL *s)
         }
 #endif
 
-#ifndef OPENSSL_NO_PSK
-        if (type & SSL_kPSK) {
-            /* copy PSK identity hint */
-            s2n(strlen(s->ctx->psk_identity_hint), p);
-            strncpy((char *)p, s->ctx->psk_identity_hint,
-                    strlen(s->ctx->psk_identity_hint));
-            p += strlen(s->ctx->psk_identity_hint);
-        }
-#endif
-
         /* not anonymous */
         if (pkey != NULL) {
             /*
@@ -2249,8 +2259,94 @@ int ssl3_get_client_key_exchange(SSL *s)
 
     alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
 
+#ifndef OPENSSL_NO_PSK
+    /* For PSK parse and retrieve identity, obtain PSK key */
+    if (alg_k & SSL_PSK) {
+        unsigned char psk[PSK_MAX_PSK_LEN];
+        size_t psklen;
+        if (n < 2) {
+            al = SSL_AD_DECODE_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, SSL_R_LENGTH_MISMATCH);
+            goto f_err;
+        }
+        n2s(p, i);
+        if (i + 2 > n) {
+            al = SSL_AD_DECODE_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, SSL_R_LENGTH_MISMATCH);
+            goto f_err;
+        }
+        if (i > PSK_MAX_IDENTITY_LEN) {
+            al = SSL_AD_DECODE_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
+                   SSL_R_DATA_LENGTH_TOO_LONG);
+            goto f_err;
+        }
+        if (s->psk_server_callback == NULL) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
+                   SSL_R_PSK_NO_SERVER_CB);
+            goto f_err;
+        }
+
+        OPENSSL_free(s->session->psk_identity);
+        s->session->psk_identity = BUF_strndup((char *)p, i);
+
+        if (s->session->psk_identity == NULL) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
+                   ERR_R_MALLOC_FAILURE);
+            goto f_err;
+        }
+
+        psklen = s->psk_server_callback(s, s->session->psk_identity,
+                                         psk, sizeof(psk));
+
+        if (psklen > PSK_MAX_PSK_LEN) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
+            goto f_err;
+        } else if (psklen == 0) {
+            /*
+             * PSK related to the given identity not found
+             */
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
+                   SSL_R_PSK_IDENTITY_NOT_FOUND);
+            al = SSL_AD_UNKNOWN_PSK_IDENTITY;
+            goto f_err;
+        }
+
+        OPENSSL_free(s->s3->tmp.psk);
+        s->s3->tmp.psk = BUF_memdup(psk, psklen);
+        OPENSSL_cleanse(psk, psklen);
+
+        if (s->s3->tmp.psk == NULL) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
+            goto f_err;
+        }
+
+        s->s3->tmp.psklen = psklen;
+
+        n -= i + 2;
+        p += i;
+    }
+    if (alg_k & SSL_kPSK) {
+        /* Identity extracted earlier: should be nothing left */
+        if (n != 0) {
+            al = SSL_AD_HANDSHAKE_FAILURE;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, SSL_R_LENGTH_MISMATCH);
+            goto f_err;
+        }
+        /* PSK handled by ssl_generate_master_secret */
+        if (!ssl_generate_master_secret(s, NULL, 0, 0)) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
+            goto f_err;
+        }
+    } else
+#endif
 #ifndef OPENSSL_NO_RSA
-    if (alg_k & SSL_kRSA) {
+    if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) {
         unsigned char rand_premaster_secret[SSL_MAX_MASTER_KEY_LENGTH];
         int decrypt_len;
         unsigned char decrypt_good, version_good;
@@ -2389,13 +2485,13 @@ int ssl3_get_client_key_exchange(SSL *s)
     } else
 #endif
 #ifndef OPENSSL_NO_DH
-    if (alg_k & (SSL_kDHE | SSL_kDHr | SSL_kDHd)) {
+    if (alg_k & (SSL_kDHE | SSL_kDHr | SSL_kDHd | SSL_kDHEPSK)) {
         int idx = -1;
         EVP_PKEY *skey = NULL;
         if (n > 1) {
             n2s(p, i);
         } else {
-            if (alg_k & SSL_kDHE) {
+            if (alg_k & (SSL_kDHE | SSL_kDHEPSK)) {
                 al = SSL_AD_HANDSHAKE_FAILURE;
                 SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
                        SSL_R_DH_PUBLIC_VALUE_LENGTH_IS_WRONG);
@@ -2483,7 +2579,7 @@ int ssl3_get_client_key_exchange(SSL *s)
 #endif
 
 #ifndef OPENSSL_NO_EC
-    if (alg_k & (SSL_kECDHE | SSL_kECDHr | SSL_kECDHe)) {
+    if (alg_k & (SSL_kECDHE | SSL_kECDHr | SSL_kECDHe | SSL_kECDHEPSK)) {
         int ret = 1;
         int field_size = 0;
         const EC_KEY *tkey;
@@ -2526,7 +2622,7 @@ int ssl3_get_client_key_exchange(SSL *s)
         if (n == 0L) {
             /* Client Publickey was in Client Certificate */
 
-            if (alg_k & SSL_kECDHE) {
+            if (alg_k & (SSL_kECDHE | SSL_kECDHEPSK)) {
                 al = SSL_AD_HANDSHAKE_FAILURE;
                 SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
                        SSL_R_MISSING_TMP_ECDH_KEY);
@@ -2612,92 +2708,6 @@ int ssl3_get_client_key_exchange(SSL *s)
         return (ret);
     } else
 #endif
-#ifndef OPENSSL_NO_PSK
-    if (alg_k & SSL_kPSK) {
-        unsigned char *t = NULL;
-        unsigned char psk_or_pre_ms[PSK_MAX_PSK_LEN * 2 + 4];
-        unsigned int pre_ms_len = 0, psk_len = 0;
-        int psk_err = 1;
-        char tmp_id[PSK_MAX_IDENTITY_LEN + 1];
-
-        al = SSL_AD_HANDSHAKE_FAILURE;
-
-        n2s(p, i);
-        if (n != i + 2) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, SSL_R_LENGTH_MISMATCH);
-            goto psk_err;
-        }
-        if (i > PSK_MAX_IDENTITY_LEN) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
-                   SSL_R_DATA_LENGTH_TOO_LONG);
-            goto psk_err;
-        }
-        if (s->psk_server_callback == NULL) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
-                   SSL_R_PSK_NO_SERVER_CB);
-            goto psk_err;
-        }
-
-        /*
-         * Create guaranteed NULL-terminated identity string for the callback
-         */
-        memcpy(tmp_id, p, i);
-        memset(tmp_id + i, 0, PSK_MAX_IDENTITY_LEN + 1 - i);
-        psk_len = s->psk_server_callback(s, tmp_id,
-                                         psk_or_pre_ms,
-                                         sizeof(psk_or_pre_ms));
-        OPENSSL_cleanse(tmp_id, sizeof(tmp_id));
-
-        if (psk_len > PSK_MAX_PSK_LEN) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
-            goto psk_err;
-        } else if (psk_len == 0) {
-            /*
-             * PSK related to the given identity not found
-             */
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
-                   SSL_R_PSK_IDENTITY_NOT_FOUND);
-            al = SSL_AD_UNKNOWN_PSK_IDENTITY;
-            goto psk_err;
-        }
-
-        /* create PSK pre_master_secret */
-        pre_ms_len = 2 + psk_len + 2 + psk_len;
-        t = psk_or_pre_ms;
-        memmove(psk_or_pre_ms + psk_len + 4, psk_or_pre_ms, psk_len);
-        s2n(psk_len, t);
-        memset(t, 0, psk_len);
-        t += psk_len;
-        s2n(psk_len, t);
-
-        OPENSSL_free(s->session->psk_identity);
-        s->session->psk_identity = BUF_strdup((char *)p);
-        if (s->session->psk_identity == NULL) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
-            goto psk_err;
-        }
-
-        OPENSSL_free(s->session->psk_identity_hint);
-        s->session->psk_identity_hint = BUF_strdup(s->ctx->psk_identity_hint);
-        if (s->ctx->psk_identity_hint != NULL &&
-            s->session->psk_identity_hint == NULL) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
-            goto psk_err;
-        }
-
-        if (!ssl_generate_master_secret(s, psk_or_pre_ms, pre_ms_len, 0)) {
-            al = SSL_AD_INTERNAL_ERROR;
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
-            goto f_err;
-        }
-        psk_err = 0;
- psk_err:
-        if (psk_err != 0) {
-            OPENSSL_cleanse(psk_or_pre_ms, sizeof(psk_or_pre_ms));
-            goto f_err;
-        }
-    } else
-#endif
 #ifndef OPENSSL_NO_SRP
     if (alg_k & SSL_kSRP) {
         int param_len;
@@ -2819,6 +2829,10 @@ int ssl3_get_client_key_exchange(SSL *s)
     EC_POINT_free(clnt_ecpoint);
     EC_KEY_free(srvr_ecdh);
     BN_CTX_free(bn_ctx);
+#endif
+#ifndef OPENSSL_NO_PSK
+    OPENSSL_clear_free(s->s3->tmp.psk, s->s3->tmp.psklen);
+    s->s3->tmp.psk = NULL;
 #endif
     s->state = SSL_ST_ERR;
     return (-1);
index db2341c..c75219b 100644 (file)
@@ -1277,9 +1277,11 @@ typedef struct ssl3_state_st {
         /* Temporary storage for premaster secret */
         unsigned char *pms;
         size_t pmslen;
+#ifndef OPENSSL_NO_PSK
         /* Temporary storage for PSK key */
         unsigned char *psk;
         size_t psklen;
+#endif
         /*
          * signature algorithms peer reports: e.g. supported signature
          * algorithms extension for server or as part of a certificate