Extended PSK server support.
[openssl.git] / ssl / s3_srvr.c
index cbe80eb8eb933f53e8cc212039c09d1b2cfb8bf0..caf45d1afbf526668d4116541424d3f1f38a1bb8 100644 (file)
@@ -403,10 +403,8 @@ int ssl3_accept(SSL *s)
         case SSL3_ST_SW_CERT_B:
             /* Check if it is anon DH or anon ECDH, */
             /* normal PSK or SRP */
-            if (!
-                (s->s3->tmp.
-                 new_cipher->algorithm_auth & (SSL_aNULL | SSL_aSRP))
-&& !(s->s3->tmp.new_cipher->algorithm_mkey & SSL_kPSK)) {
+            if (!(s->s3->tmp.new_cipher->algorithm_auth &
+                 (SSL_aNULL | SSL_aSRP | SSL_aPSK))) {
                 ret = ssl3_send_server_certificate(s);
                 if (ret <= 0)
                     goto end;
@@ -446,7 +444,10 @@ int ssl3_accept(SSL *s)
                  * provided
                  */
 #ifndef OPENSSL_NO_PSK
-                || ((alg_k & SSL_kPSK) && s->ctx->psk_identity_hint)
+                /* Only send SKE if we have identity hint for plain PSK */
+                || ((alg_k & (SSL_kPSK | SSL_kRSAPSK)) && s->ctx->psk_identity_hint)
+                /* For other PSK always send SKE */
+                || (alg_k & (SSL_PSK & (SSL_kDHEPSK | SSL_kECDHEPSK)))
 #endif
 #ifndef OPENSSL_NO_SRP
                 /* SRP: send ServerKeyExchange */
@@ -1722,6 +1723,19 @@ int ssl3_send_server_key_exchange(SSL *s)
 
         r[0] = r[1] = r[2] = r[3] = NULL;
         n = 0;
+#ifndef OPENSSL_NO_PSK
+        if (type & SSL_PSK) {
+            /*
+             * reserve size for record length and PSK identity hint
+             */
+            n += 2;
+            if (s->ctx->psk_identity_hint)
+                n += strlen(s->ctx->psk_identity_hint);
+        }
+        /* Plain PSK or RSAPSK nothing to do */
+        if (type & (SSL_kPSK | SSL_kRSAPSK)) {
+        } else
+#endif                          /* !OPENSSL_NO_PSK */
 #ifndef OPENSSL_NO_RSA
         if (type & SSL_kRSA) {
             rsa = cert->rsa_tmp;
@@ -1752,7 +1766,7 @@ int ssl3_send_server_key_exchange(SSL *s)
         } else
 #endif
 #ifndef OPENSSL_NO_DH
-        if (type & SSL_kDHE) {
+        if (type & (SSL_kDHE | SSL_kDHEPSK)) {
             if (s->cert->dh_tmp_auto) {
                 dhp = ssl_get_auto_dh(s);
                 if (dhp == NULL) {
@@ -1817,7 +1831,7 @@ int ssl3_send_server_key_exchange(SSL *s)
         } else
 #endif
 #ifndef OPENSSL_NO_EC
-        if (type & SSL_kECDHE) {
+        if (type & (SSL_kECDHE | SSL_kECDHEPSK)) {
             const EC_GROUP *group;
 
             ecdhp = cert->ecdh_tmp;
@@ -1933,7 +1947,7 @@ int ssl3_send_server_key_exchange(SSL *s)
              * additional bytes to encode the entire ServerECDHParams
              * structure.
              */
-            n = 4 + encodedlen;
+            n += 4 + encodedlen;
 
             /*
              * We'll generate the serverKeyExchange message explicitly so we
@@ -1945,14 +1959,6 @@ int ssl3_send_server_key_exchange(SSL *s)
             r[3] = NULL;
         } else
 #endif                          /* !OPENSSL_NO_EC */
-#ifndef OPENSSL_NO_PSK
-        if (type & SSL_kPSK) {
-            /*
-             * reserve size for record length and PSK identity hint
-             */
-            n += 2 + strlen(s->ctx->psk_identity_hint);
-        } else
-#endif                          /* !OPENSSL_NO_PSK */
 #ifndef OPENSSL_NO_SRP
         if (type & SSL_kSRP) {
             if ((s->srp_ctx.N == NULL) ||
@@ -1984,8 +1990,8 @@ int ssl3_send_server_key_exchange(SSL *s)
                 n += 2 + nr[i];
         }
 
-        if (!(s->s3->tmp.new_cipher->algorithm_auth & (SSL_aNULL | SSL_aSRP))
-            && !(s->s3->tmp.new_cipher->algorithm_mkey & SSL_kPSK)) {
+        if (!(s->s3->tmp.new_cipher->algorithm_auth & (SSL_aNULL|SSL_aSRP))
+            && !(s->s3->tmp.new_cipher->algorithm_mkey & SSL_PSK)) {
             if ((pkey = ssl_get_sign_pkey(s, s->s3->tmp.new_cipher, &md))
                 == NULL) {
                 al = SSL_AD_DECODE_ERROR;
@@ -2003,6 +2009,20 @@ int ssl3_send_server_key_exchange(SSL *s)
         }
         d = p = ssl_handshake_start(s);
 
+#ifndef OPENSSL_NO_PSK
+        if (type & SSL_PSK) {
+            /* copy PSK identity hint */
+            if (s->ctx->psk_identity_hint) {
+                s2n(strlen(s->ctx->psk_identity_hint), p);
+                strncpy((char *)p, s->ctx->psk_identity_hint,
+                        strlen(s->ctx->psk_identity_hint));
+                p += strlen(s->ctx->psk_identity_hint);
+            } else {
+                s2n(0, p);
+            }
+        }
+#endif
+
         for (i = 0; i < 4 && r[i] != NULL; i++) {
 #ifndef OPENSSL_NO_SRP
             if ((i == 2) && (type & SSL_kSRP)) {
@@ -2016,7 +2036,7 @@ int ssl3_send_server_key_exchange(SSL *s)
         }
 
 #ifndef OPENSSL_NO_EC
-        if (type & SSL_kECDHE) {
+        if (type & (SSL_kECDHE | SSL_kECDHEPSK)) {
             /*
              * XXX: For now, we only support named (not generic) curves. In
              * this situation, the serverKeyExchange message has: [1 byte
@@ -2038,16 +2058,6 @@ int ssl3_send_server_key_exchange(SSL *s)
         }
 #endif
 
-#ifndef OPENSSL_NO_PSK
-        if (type & SSL_kPSK) {
-            /* copy PSK identity hint */
-            s2n(strlen(s->ctx->psk_identity_hint), p);
-            strncpy((char *)p, s->ctx->psk_identity_hint,
-                    strlen(s->ctx->psk_identity_hint));
-            p += strlen(s->ctx->psk_identity_hint);
-        }
-#endif
-
         /* not anonymous */
         if (pkey != NULL) {
             /*
@@ -2249,8 +2259,94 @@ int ssl3_get_client_key_exchange(SSL *s)
 
     alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
 
+#ifndef OPENSSL_NO_PSK
+    /* For PSK parse and retrieve identity, obtain PSK key */
+    if (alg_k & SSL_PSK) {
+        unsigned char psk[PSK_MAX_PSK_LEN];
+        size_t psklen;
+        if (n < 2) {
+            al = SSL_AD_DECODE_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, SSL_R_LENGTH_MISMATCH);
+            goto f_err;
+        }
+        n2s(p, i);
+        if (i + 2 > n) {
+            al = SSL_AD_DECODE_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, SSL_R_LENGTH_MISMATCH);
+            goto f_err;
+        }
+        if (i > PSK_MAX_IDENTITY_LEN) {
+            al = SSL_AD_DECODE_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
+                   SSL_R_DATA_LENGTH_TOO_LONG);
+            goto f_err;
+        }
+        if (s->psk_server_callback == NULL) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
+                   SSL_R_PSK_NO_SERVER_CB);
+            goto f_err;
+        }
+
+        OPENSSL_free(s->session->psk_identity);
+        s->session->psk_identity = BUF_strndup((char *)p, i);
+
+        if (s->session->psk_identity == NULL) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
+                   ERR_R_MALLOC_FAILURE);
+            goto f_err;
+        }
+
+        psklen = s->psk_server_callback(s, s->session->psk_identity,
+                                         psk, sizeof(psk));
+
+        if (psklen > PSK_MAX_PSK_LEN) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
+            goto f_err;
+        } else if (psklen == 0) {
+            /*
+             * PSK related to the given identity not found
+             */
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
+                   SSL_R_PSK_IDENTITY_NOT_FOUND);
+            al = SSL_AD_UNKNOWN_PSK_IDENTITY;
+            goto f_err;
+        }
+
+        OPENSSL_free(s->s3->tmp.psk);
+        s->s3->tmp.psk = BUF_memdup(psk, psklen);
+        OPENSSL_cleanse(psk, psklen);
+
+        if (s->s3->tmp.psk == NULL) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
+            goto f_err;
+        }
+
+        s->s3->tmp.psklen = psklen;
+
+        n -= i + 2;
+        p += i;
+    }
+    if (alg_k & SSL_kPSK) {
+        /* Identity extracted earlier: should be nothing left */
+        if (n != 0) {
+            al = SSL_AD_HANDSHAKE_FAILURE;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, SSL_R_LENGTH_MISMATCH);
+            goto f_err;
+        }
+        /* PSK handled by ssl_generate_master_secret */
+        if (!ssl_generate_master_secret(s, NULL, 0, 0)) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
+            goto f_err;
+        }
+    } else
+#endif
 #ifndef OPENSSL_NO_RSA
-    if (alg_k & SSL_kRSA) {
+    if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) {
         unsigned char rand_premaster_secret[SSL_MAX_MASTER_KEY_LENGTH];
         int decrypt_len;
         unsigned char decrypt_good, version_good;
@@ -2389,13 +2485,13 @@ int ssl3_get_client_key_exchange(SSL *s)
     } else
 #endif
 #ifndef OPENSSL_NO_DH
-    if (alg_k & (SSL_kDHE | SSL_kDHr | SSL_kDHd)) {
+    if (alg_k & (SSL_kDHE | SSL_kDHr | SSL_kDHd | SSL_kDHEPSK)) {
         int idx = -1;
         EVP_PKEY *skey = NULL;
         if (n > 1) {
             n2s(p, i);
         } else {
-            if (alg_k & SSL_kDHE) {
+            if (alg_k & (SSL_kDHE | SSL_kDHEPSK)) {
                 al = SSL_AD_HANDSHAKE_FAILURE;
                 SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
                        SSL_R_DH_PUBLIC_VALUE_LENGTH_IS_WRONG);
@@ -2483,7 +2579,7 @@ int ssl3_get_client_key_exchange(SSL *s)
 #endif
 
 #ifndef OPENSSL_NO_EC
-    if (alg_k & (SSL_kECDHE | SSL_kECDHr | SSL_kECDHe)) {
+    if (alg_k & (SSL_kECDHE | SSL_kECDHr | SSL_kECDHe | SSL_kECDHEPSK)) {
         int ret = 1;
         int field_size = 0;
         const EC_KEY *tkey;
@@ -2526,7 +2622,7 @@ int ssl3_get_client_key_exchange(SSL *s)
         if (n == 0L) {
             /* Client Publickey was in Client Certificate */
 
-            if (alg_k & SSL_kECDHE) {
+            if (alg_k & (SSL_kECDHE | SSL_kECDHEPSK)) {
                 al = SSL_AD_HANDSHAKE_FAILURE;
                 SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
                        SSL_R_MISSING_TMP_ECDH_KEY);
@@ -2612,92 +2708,6 @@ int ssl3_get_client_key_exchange(SSL *s)
         return (ret);
     } else
 #endif
-#ifndef OPENSSL_NO_PSK
-    if (alg_k & SSL_kPSK) {
-        unsigned char *t = NULL;
-        unsigned char psk_or_pre_ms[PSK_MAX_PSK_LEN * 2 + 4];
-        unsigned int pre_ms_len = 0, psk_len = 0;
-        int psk_err = 1;
-        char tmp_id[PSK_MAX_IDENTITY_LEN + 1];
-
-        al = SSL_AD_HANDSHAKE_FAILURE;
-
-        n2s(p, i);
-        if (n != i + 2) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, SSL_R_LENGTH_MISMATCH);
-            goto psk_err;
-        }
-        if (i > PSK_MAX_IDENTITY_LEN) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
-                   SSL_R_DATA_LENGTH_TOO_LONG);
-            goto psk_err;
-        }
-        if (s->psk_server_callback == NULL) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
-                   SSL_R_PSK_NO_SERVER_CB);
-            goto psk_err;
-        }
-
-        /*
-         * Create guaranteed NULL-terminated identity string for the callback
-         */
-        memcpy(tmp_id, p, i);
-        memset(tmp_id + i, 0, PSK_MAX_IDENTITY_LEN + 1 - i);
-        psk_len = s->psk_server_callback(s, tmp_id,
-                                         psk_or_pre_ms,
-                                         sizeof(psk_or_pre_ms));
-        OPENSSL_cleanse(tmp_id, sizeof(tmp_id));
-
-        if (psk_len > PSK_MAX_PSK_LEN) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
-            goto psk_err;
-        } else if (psk_len == 0) {
-            /*
-             * PSK related to the given identity not found
-             */
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE,
-                   SSL_R_PSK_IDENTITY_NOT_FOUND);
-            al = SSL_AD_UNKNOWN_PSK_IDENTITY;
-            goto psk_err;
-        }
-
-        /* create PSK pre_master_secret */
-        pre_ms_len = 2 + psk_len + 2 + psk_len;
-        t = psk_or_pre_ms;
-        memmove(psk_or_pre_ms + psk_len + 4, psk_or_pre_ms, psk_len);
-        s2n(psk_len, t);
-        memset(t, 0, psk_len);
-        t += psk_len;
-        s2n(psk_len, t);
-
-        OPENSSL_free(s->session->psk_identity);
-        s->session->psk_identity = BUF_strdup((char *)p);
-        if (s->session->psk_identity == NULL) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
-            goto psk_err;
-        }
-
-        OPENSSL_free(s->session->psk_identity_hint);
-        s->session->psk_identity_hint = BUF_strdup(s->ctx->psk_identity_hint);
-        if (s->ctx->psk_identity_hint != NULL &&
-            s->session->psk_identity_hint == NULL) {
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
-            goto psk_err;
-        }
-
-        if (!ssl_generate_master_secret(s, psk_or_pre_ms, pre_ms_len, 0)) {
-            al = SSL_AD_INTERNAL_ERROR;
-            SSLerr(SSL_F_SSL3_GET_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
-            goto f_err;
-        }
-        psk_err = 0;
- psk_err:
-        if (psk_err != 0) {
-            OPENSSL_cleanse(psk_or_pre_ms, sizeof(psk_or_pre_ms));
-            goto f_err;
-        }
-    } else
-#endif
 #ifndef OPENSSL_NO_SRP
     if (alg_k & SSL_kSRP) {
         int param_len;
@@ -2819,6 +2829,10 @@ int ssl3_get_client_key_exchange(SSL *s)
     EC_POINT_free(clnt_ecpoint);
     EC_KEY_free(srvr_ecdh);
     BN_CTX_free(bn_ctx);
+#endif
+#ifndef OPENSSL_NO_PSK
+    OPENSSL_clear_free(s->s3->tmp.psk, s->s3->tmp.psklen);
+    s->s3->tmp.psk = NULL;
 #endif
     s->state = SSL_ST_ERR;
     return (-1);