EVP & TLS: Add necessary EC_KEY data extraction functions, and use them
[openssl.git] / ssl / t1_lib.c
index 76096401bed3f20100de610d80321c20394e9cf2..beadf28f111ebbdf018bde4ee1c76133657b2e6a 100644 (file)
@@ -22,6 +22,7 @@
 #include <openssl/dh.h>
 #include <openssl/bn.h>
 #include "internal/nelem.h"
+#include "internal/evp.h"
 #include "ssl_local.h"
 #include <openssl/ct.h>
 
@@ -583,7 +584,7 @@ static int tls1_check_pkey_comp(SSL *s, EVP_PKEY *pkey)
     size_t i;
 
     /* If not an EC key nothing to check */
-    if (EVP_PKEY_id(pkey) != EVP_PKEY_EC)
+    if (!EVP_PKEY_is_a(pkey, "EC"))
         return 1;
     ec = EVP_PKEY_get0_EC_KEY(pkey);
     grp = EC_KEY_get0_group(ec);
@@ -624,13 +625,11 @@ static int tls1_check_pkey_comp(SSL *s, EVP_PKEY *pkey)
 /* Return group id of a key */
 static uint16_t tls1_get_group_id(EVP_PKEY *pkey)
 {
-    EC_KEY *ec = EVP_PKEY_get0_EC_KEY(pkey);
-    const EC_GROUP *grp;
+    int curve_nid = evp_pkey_get_EC_KEY_curve_nid(pkey);
 
-    if (ec == NULL)
+    if (curve_nid == NID_undef)
         return 0;
-    grp = EC_KEY_get0_group(ec);
-    return tls1_nid2group_id(EC_GROUP_get_curve_name(grp));
+    return tls1_nid2group_id(curve_nid);
 }
 
 /*
@@ -645,7 +644,7 @@ static int tls1_check_cert_param(SSL *s, X509 *x, int check_ee_md)
     if (pkey == NULL)
         return 0;
     /* If not EC nothing to do */
-    if (EVP_PKEY_id(pkey) != EVP_PKEY_EC)
+    if (!EVP_PKEY_is_a(pkey, "EC"))
         return 1;
     /* Check compression */
     if (!tls1_check_pkey_comp(s, pkey))
@@ -1111,10 +1110,22 @@ int tls12_check_peer_sigalg(SSL *s, uint16_t sig, EVP_PKEY *pkey)
     const EVP_MD *md = NULL;
     char sigalgstr[2];
     size_t sent_sigslen, i, cidx;
-    int pkeyid = EVP_PKEY_id(pkey);
+    int pkeyid = -1;
     const SIGALG_LOOKUP *lu;
     int secbits = 0;
 
+    /*
+     * TODO(3.0) Remove this when we adapted this function for provider
+     * side keys.  We know that EVP_PKEY_get0() downgrades an EVP_PKEY
+     * to contain a legacy key.
+     *
+     * THIS IS TEMPORARY
+     */
+    EVP_PKEY_get0(pkey);
+    if (EVP_PKEY_id(pkey) == EVP_PKEY_NONE)
+        return 0;
+
+    pkeyid = EVP_PKEY_id(pkey);
     /* Should never happen */
     if (pkeyid == -1)
         return -1;
@@ -1163,8 +1174,7 @@ int tls12_check_peer_sigalg(SSL *s, uint16_t sig, EVP_PKEY *pkey)
 
         /* For TLS 1.3 or Suite B check curve matches signature algorithm */
         if (SSL_IS_TLS13(s) || tls1_suiteb(s)) {
-            EC_KEY *ec = EVP_PKEY_get0_EC_KEY(pkey);
-            int curve = EC_GROUP_get_curve_name(EC_KEY_get0_group(ec));
+            int curve = evp_pkey_get_EC_KEY_curve_nid(pkey);
 
             if (lu->curve != NID_undef && curve != lu->curve) {
                 SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER,
@@ -1521,21 +1531,29 @@ SSL_TICKET_STATUS tls_decrypt_ticket(SSL *s, const unsigned char *etick,
         if (rv == 2)
             renew_ticket = 1;
     } else {
+        EVP_CIPHER *aes256cbc = NULL;
+
         /* Check key name matches */
         if (memcmp(etick, tctx->ext.tick_key_name,
                    TLSEXT_KEYNAME_LENGTH) != 0) {
             ret = SSL_TICKET_NO_DECRYPT;
             goto end;
         }
-        if (ssl_hmac_init(hctx, tctx->ext.secure->tick_hmac_key,
-                          sizeof(tctx->ext.secure->tick_hmac_key),
-                          "SHA256") <= 0
-            || EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL,
+
+        aes256cbc = EVP_CIPHER_fetch(s->ctx->libctx, "AES-256-CBC",
+                                     s->ctx->propq);
+        if (aes256cbc == NULL
+            || ssl_hmac_init(hctx, tctx->ext.secure->tick_hmac_key,
+                             sizeof(tctx->ext.secure->tick_hmac_key),
+                             "SHA256") <= 0
+            || EVP_DecryptInit_ex(ctx, aes256cbc, NULL,
                                   tctx->ext.secure->tick_aes_key,
                                   etick + TLSEXT_KEYNAME_LENGTH) <= 0) {
+            EVP_CIPHER_free(aes256cbc);
             ret = SSL_TICKET_FATAL_ERR_OTHER;
             goto end;
         }
+        EVP_CIPHER_free(aes256cbc);
         if (SSL_IS_TLS13(s))
             renew_ticket = 1;
     }
@@ -2441,17 +2459,14 @@ int tls1_check_chain(SSL *s, X509 *x, EVP_PKEY *pk, STACK_OF(X509) *chain,
     if (!s->server && strict_mode) {
         STACK_OF(X509_NAME) *ca_dn;
         int check_type = 0;
-        switch (EVP_PKEY_id(pk)) {
-        case EVP_PKEY_RSA:
+
+        if (EVP_PKEY_is_a(pk, "RSA"))
             check_type = TLS_CT_RSA_SIGN;
-            break;
-        case EVP_PKEY_DSA:
+        else if (EVP_PKEY_is_a(pk, "DSA"))
             check_type = TLS_CT_DSS_SIGN;
-            break;
-        case EVP_PKEY_EC:
+        else if (EVP_PKEY_is_a(pk, "EC"))
             check_type = TLS_CT_ECDSA_SIGN;
-            break;
-        }
+
         if (check_type) {
             const uint8_t *ctypes = s->s3.tmp.ctype;
             size_t j;
@@ -2812,10 +2827,8 @@ static const SIGALG_LOOKUP *find_sig_alg(SSL *s, X509 *x, EVP_PKEY *pkey)
 
         if (lu->sig == EVP_PKEY_EC) {
 #ifndef OPENSSL_NO_EC
-            if (curve == -1) {
-                EC_KEY *ec = EVP_PKEY_get0_EC_KEY(tmppkey);
-                curve = EC_GROUP_get_curve_name(EC_KEY_get0_group(ec));
-            }
+            if (curve == -1)
+                curve = evp_pkey_get_EC_KEY_curve_nid(tmppkey);
             if (lu->curve != NID_undef && curve != lu->curve)
                 continue;
 #else
@@ -2874,15 +2887,13 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
             size_t i;
             if (s->s3.tmp.peer_sigalgs != NULL) {
 #ifndef OPENSSL_NO_EC
-                int curve;
+                int curve = -1;
 
                 /* For Suite B need to match signature algorithm to curve */
-                if (tls1_suiteb(s)) {
-                    EC_KEY *ec = EVP_PKEY_get0_EC_KEY(s->cert->pkeys[SSL_PKEY_ECC].privatekey);
-                    curve = EC_GROUP_get_curve_name(EC_KEY_get0_group(ec));
-                } else {
-                    curve = -1;
-                }
+                if (tls1_suiteb(s))
+                    curve =
+                        evp_pkey_get_EC_KEY_curve_nid(s->cert->pkeys[SSL_PKEY_ECC]
+                                                      .privatekey);
 #endif
 
                 /*
@@ -2956,7 +2967,7 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
                     if (!fatalerrs)
                         return 1;
                     SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_F_TLS_CHOOSE_SIGALG,
-                             ERR_R_INTERNAL_ERROR);
+                             SSL_R_NO_SUITABLE_SIGNATURE_ALGORITHM);
                     return 0;
                 }
 
@@ -2981,7 +2992,7 @@ int tls_choose_sigalg(SSL *s, int fatalerrs)
                 if (!fatalerrs)
                     return 1;
                 SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_F_TLS_CHOOSE_SIGALG,
-                         ERR_R_INTERNAL_ERROR);
+                         SSL_R_NO_SUITABLE_SIGNATURE_ALGORITHM);
                 return 0;
             }
         }