Replace TLS_SIGALGS with SIGALG_LOOKUP
authorDr. Stephen Henson <steve@openssl.org>
Thu, 26 Jan 2017 15:24:35 +0000 (15:24 +0000)
committerDr. Stephen Henson <steve@openssl.org>
Mon, 30 Jan 2017 13:00:17 +0000 (13:00 +0000)
Since every supported signature algorithm is now an entry in the
SIGALG_LOOKUP table we can replace shared signature algortihms with
pointers to constant table entries.

Reviewed-by: Richard Levitte <levitte@openssl.org>
Reviewed-by: Matt Caswell <matt@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/2301)

ssl/ssl_locl.h
ssl/t1_lib.c

index df24b294ad6a656c3dc6ce40f880046fd76521d8..13be4f379592bde4e1470a3575d18476ad37cf9c 100644 (file)
@@ -1509,6 +1509,25 @@ typedef struct {
     size_t meths_count;
 } custom_ext_methods;
 
+/*
+ * Structure containing table entry of values associated with the signature
+ * algorithms (signature scheme) extension
+*/
+typedef struct sigalg_lookup_st {
+    /* TLS 1.3 signature scheme name */
+    const char *name;
+    /* Raw value used in extension */
+    uint16_t sigalg;
+    /* NID of hash algorithm */
+    int hash;
+    /* NID of signature algorithm */
+    int sig;
+    /* Combined hash and signature NID, if any */
+    int sigandhash;
+    /* Required public key curve (ECDSA only) */
+    int curve;
+} SIGALG_LOOKUP;
+
 typedef struct cert_st {
     /* Current active set */
     /*
@@ -1554,7 +1573,7 @@ typedef struct cert_st {
      * Signature algorithms shared by client and server: cached because these
      * are used most often.
      */
-    TLS_SIGALGS *shared_sigalgs;
+    const SIGALG_LOOKUP **shared_sigalgs;
     size_t shared_sigalgslen;
     /*
      * Certificate setup callback: if set is called whenever a certificate
@@ -1588,18 +1607,6 @@ typedef struct cert_st {
     CRYPTO_RWLOCK *lock;
 } CERT;
 
-/* Structure containing decoded values of signature algorithms extension */
-struct tls_sigalgs_st {
-    /* NID of hash algorithm */
-    int hash_nid;
-    /* NID of signature algorithm */
-    int sign_nid;
-    /* Combined hash and signature NID */
-    int signandhash_nid;
-    /* Raw value used in extension */
-    uint16_t rsigalg;
-};
-
 # define FP_ICC  (int (*)(const void *,const void *))
 
 /*
index fc10dc1b71a9d783e0215bba8df55832e30b5aec..dbd0fb6cc6c556da974ada1c0648ceecd0a2a961 100644 (file)
@@ -607,7 +607,7 @@ static int tls1_check_cert_param(SSL *s, X509 *x, int set_ee_md)
         else
             return 0;           /* Should never happen */
         for (i = 0; i < c->shared_sigalgslen; i++)
-            if (check_md == c->shared_sigalgs[i].signandhash_nid)
+            if (check_md == c->shared_sigalgs[i]->sigandhash)
                 break;
         if (i == c->shared_sigalgslen)
             return 0;
@@ -705,15 +705,6 @@ static const uint16_t suiteb_sigalgs[] = {
 };
 #endif
 
-typedef struct sigalg_lookup_st {
-    const char *name;
-    uint16_t sigalg;
-    int hash;
-    int sig;
-    int sigandhash;
-    int curve;
-} SIGALG_LOOKUP;
-
 static const SIGALG_LOOKUP sigalg_lookup_tbl[] = {
 #ifndef OPENSSL_NO_EC
     {"ecdsa_secp256r1_sha256", TLSEXT_SIGALG_ecdsa_secp256r1_sha256,
@@ -761,32 +752,32 @@ static const SIGALG_LOOKUP sigalg_lookup_tbl[] = {
 #endif
 };
 
-static int tls_sigalg_get_hash(uint16_t sigalg)
+/* Lookup TLS signature algorithm */
+static const SIGALG_LOOKUP *tls1_lookup_sigalg(uint16_t sigalg)
 {
     size_t i;
-    const SIGALG_LOOKUP *curr;
+    const SIGALG_LOOKUP *s;
 
-    for (i = 0, curr = sigalg_lookup_tbl; i < OSSL_NELEM(sigalg_lookup_tbl);
-         i++, curr++) {
-        if (curr->sigalg == sigalg)
-            return curr->hash;
+    for (i = 0, s = sigalg_lookup_tbl; i < OSSL_NELEM(sigalg_lookup_tbl);
+         i++, s++) {
+        if (s->sigalg == sigalg)
+            return s;
     }
+    return NULL;
+}
 
-    return 0;
+static int tls_sigalg_get_hash(uint16_t sigalg)
+{
+    const SIGALG_LOOKUP *r = tls1_lookup_sigalg(sigalg);
+
+    return r != NULL ? r->hash : 0;
 }
 
 static int tls_sigalg_get_sig(uint16_t sigalg)
 {
-    size_t i;
-    const SIGALG_LOOKUP *curr;
-
-    for (i = 0, curr = sigalg_lookup_tbl; i < OSSL_NELEM(sigalg_lookup_tbl);
-         i++, curr++) {
-        if (curr->sigalg == sigalg)
-            return curr->sig;
-    }
+    const SIGALG_LOOKUP *r = tls1_lookup_sigalg(sigalg);
 
-    return 0;
+    return r != NULL ? r->sig : 0;
 }
 
 size_t tls12_get_psigalgs(SSL *s, int sent, const uint16_t **psigs)
@@ -840,6 +831,7 @@ int tls12_check_peer_sigalg(SSL *s, unsigned int sig, EVP_PKEY *pkey)
     size_t sent_sigslen, i;
     int pkeyid = EVP_PKEY_id(pkey);
     int peer_sigtype;
+
     /* Should never happen */
     if (pkeyid == -1)
         return -1;
@@ -1286,7 +1278,6 @@ int tls12_get_sigandhash(SSL *s, WPACKET *pkt, const EVP_PKEY *pk,
 {
     int md_id, sig_id;
     size_t i;
-    const TLS_SIGALGS *curr;
 
     if (md == NULL)
         return 0;
@@ -1298,17 +1289,18 @@ int tls12_get_sigandhash(SSL *s, WPACKET *pkt, const EVP_PKEY *pk,
     if (SSL_IS_TLS13(s) && sig_id == EVP_PKEY_RSA)
         sig_id = EVP_PKEY_RSA_PSS;
 
-    for (i = 0, curr = s->cert->shared_sigalgs; i < s->cert->shared_sigalgslen;
-         i++, curr++) {
+    for (i = 0; i < s->cert->shared_sigalgslen; i++) {
+        const SIGALG_LOOKUP *curr = s->cert->shared_sigalgs[i];
+
         /*
          * Look for matching key and hash. If key type is RSA also match PSS
          * signature type.
          */
-        if (curr->hash_nid == md_nid && (curr->sign_nid == sig_id
-            || (sig_id == EVP_PKEY_RSA && curr->sign_nid == EVP_PKEY_RSA_PSS))){
-            if (!WPACKET_put_bytes_u16(pkt, curr->rsigalg))
+        if (curr->hash == md_id && (curr->sig == sig_id
+            || (sig_id == EVP_PKEY_RSA && curr->sig == EVP_PKEY_RSA_PSS))){
+            if (!WPACKET_put_bytes_u16(pkt, curr->sigalg))
                 return 0;
-            *ispss = curr->sign_nid == EVP_PKEY_RSA_PSS;
+            *ispss = curr->sig == EVP_PKEY_RSA_PSS;
             return 1;
         }
     }
@@ -1393,30 +1385,6 @@ static int tls12_get_pkey_idx(int sig_nid)
     return -1;
 }
 
-/* Convert TLS 1.2 signature algorithm extension values into NIDs */
-static void tls1_lookup_sigalg(int *phash_nid, int *psign_nid,
-                               int *psignhash_nid, uint16_t data)
-{
-    int sign_nid = NID_undef, hash_nid = NID_undef;
-    if (!phash_nid && !psign_nid && !psignhash_nid)
-        return;
-    if (phash_nid || psignhash_nid) {
-        hash_nid = tls_sigalg_get_hash(data);
-        if (phash_nid)
-            *phash_nid = hash_nid;
-    }
-    if (psign_nid || psignhash_nid) {
-        sign_nid = tls_sigalg_get_sig(data);
-        if (psign_nid)
-            *psign_nid = sign_nid;
-    }
-    if (psignhash_nid) {
-        if (sign_nid == NID_undef || hash_nid == NID_undef
-            || OBJ_find_sigid_by_algs(psignhash_nid, hash_nid, sign_nid) <= 0)
-            *psignhash_nid = NID_undef;
-    }
-}
-
 /* Check to see if a signature algorithm is allowed */
 static int tls12_sigalg_allowed(SSL *s, int op, unsigned int ptmp)
 {
@@ -1500,7 +1468,7 @@ int tls12_copy_sigalgs(SSL *s, WPACKET *pkt,
 }
 
 /* Given preference and allowed sigalgs set shared sigalgs */
-static size_t tls12_shared_sigalgs(SSL *s, TLS_SIGALGS *shsig,
+static size_t tls12_shared_sigalgs(SSL *s, const SIGALG_LOOKUP **shsig,
                                    const uint16_t *pref, size_t preflen,
                                    const uint16_t *allow, size_t allowlen)
 {
@@ -1514,10 +1482,7 @@ static size_t tls12_shared_sigalgs(SSL *s, TLS_SIGALGS *shsig,
             if (*ptmp == *atmp) {
                 nmatch++;
                 if (shsig) {
-                    shsig->rsigalg = *ptmp;
-                    tls1_lookup_sigalg(&shsig->hash_nid,
-                                       &shsig->sign_nid,
-                                       &shsig->signandhash_nid, *ptmp);
+                    *shsig = tls1_lookup_sigalg(*ptmp);
                     shsig++;
                 }
                 break;
@@ -1533,7 +1498,7 @@ static int tls1_set_shared_sigalgs(SSL *s)
     const uint16_t *pref, *allow, *conf;
     size_t preflen, allowlen, conflen;
     size_t nmatch;
-    TLS_SIGALGS *salgs = NULL;
+    const SIGALG_LOOKUP **salgs = NULL;
     CERT *c = s->cert;
     unsigned int is_suiteb = tls1_suiteb(s);
 
@@ -1562,7 +1527,7 @@ static int tls1_set_shared_sigalgs(SSL *s)
     }
     nmatch = tls12_shared_sigalgs(s, NULL, pref, preflen, allow, allowlen);
     if (nmatch) {
-        salgs = OPENSSL_malloc(nmatch * sizeof(TLS_SIGALGS));
+        salgs = OPENSSL_malloc(nmatch * sizeof(*salgs));
         if (salgs == NULL)
             return 0;
         nmatch = tls12_shared_sigalgs(s, salgs, pref, preflen, allow, allowlen);
@@ -1620,18 +1585,19 @@ int tls1_process_sigalgs(SSL *s)
     const EVP_MD **pmd = s->s3->tmp.md;
     uint32_t *pvalid = s->s3->tmp.valid_flags;
     CERT *c = s->cert;
-    TLS_SIGALGS *sigptr;
+
     if (!tls1_set_shared_sigalgs(s))
         return 0;
 
-    for (i = 0, sigptr = c->shared_sigalgs;
-         i < c->shared_sigalgslen; i++, sigptr++) {
+    for (i = 0; i < c->shared_sigalgslen; i++) {
+        const SIGALG_LOOKUP *sigptr = c->shared_sigalgs[i];
+
         /* Ignore PKCS1 based sig algs in TLSv1.3 */
-        if (SSL_IS_TLS13(s) && sigptr->sign_nid == EVP_PKEY_RSA)
+        if (SSL_IS_TLS13(s) && sigptr->sig == EVP_PKEY_RSA)
             continue;
-        idx = tls12_get_pkey_idx(sigptr->sign_nid);
+        idx = tls12_get_pkey_idx(sigptr->sig);
         if (idx > 0 && pmd[idx] == NULL) {
-            md = tls12_get_hash(sigptr->hash_nid);
+            md = tls12_get_hash(sigptr->hash);
             pmd[idx] = md;
             pvalid[idx] = CERT_PKEY_EXPLICIT_SIGN;
             if (idx == SSL_PKEY_RSA_SIGN) {
@@ -1688,14 +1654,22 @@ int SSL_get_sigalgs(SSL *s, int idx,
     if (psig == NULL || numsigalgs > INT_MAX)
         return 0;
     if (idx >= 0) {
+        const SIGALG_LOOKUP *lu;
+
         if (idx >= (int)numsigalgs)
             return 0;
         psig += idx;
-        if (rhash)
+        if (rhash != NULL)
             *rhash = (unsigned char)((*psig >> 8) & 0xff);
-        if (rsig)
+        if (rsig != NULL)
             *rsig = (unsigned char)(*psig & 0xff);
-        tls1_lookup_sigalg(phash, psign, psignhash, *psig);
+        lu = tls1_lookup_sigalg(*psig);
+        if (psign != NULL)
+            *psign = lu != NULL ? lu->sig : NID_undef;
+        if (phash != NULL)
+            *phash = lu != NULL ? lu->hash : NID_undef;
+        if (psignhash != NULL)
+            *psignhash = lu != NULL ? lu->sigandhash : NID_undef;
     }
     return (int)numsigalgs;
 }
@@ -1704,21 +1678,22 @@ int SSL_get_shared_sigalgs(SSL *s, int idx,
                            int *psign, int *phash, int *psignhash,
                            unsigned char *rsig, unsigned char *rhash)
 {
-    TLS_SIGALGS *shsigalgs = s->cert->shared_sigalgs;
-    if (!shsigalgs || idx >= (int)s->cert->shared_sigalgslen
-            || s->cert->shared_sigalgslen > INT_MAX)
+    const SIGALG_LOOKUP *shsigalgs;
+    if (s->cert->shared_sigalgs == NULL
+        || idx >= (int)s->cert->shared_sigalgslen
+        || s->cert->shared_sigalgslen > INT_MAX)
         return 0;
-    shsigalgs += idx;
-    if (phash)
-        *phash = shsigalgs->hash_nid;
-    if (psign)
-        *psign = shsigalgs->sign_nid;
-    if (psignhash)
-        *psignhash = shsigalgs->signandhash_nid;
-    if (rsig)
-        *rsig = (unsigned char)(shsigalgs->rsigalg & 0xff);
-    if (rhash)
-        *rhash = (unsigned char)((shsigalgs->rsigalg >> 8) & 0xff);
+    shsigalgs = s->cert->shared_sigalgs[idx];
+    if (phash != NULL)
+        *phash = shsigalgs->hash;
+    if (psign != NULL)
+        *psign = shsigalgs->sig;
+    if (psignhash != NULL)
+        *psignhash = shsigalgs->sigandhash;
+    if (rsig != NULL)
+        *rsig = (unsigned char)(shsigalgs->sigalg & 0xff);
+    if (rhash != NULL)
+        *rhash = (unsigned char)((shsigalgs->sigalg >> 8) & 0xff);
     return (int)s->cert->shared_sigalgslen;
 }
 
@@ -1864,7 +1839,7 @@ static int tls1_check_sig_alg(CERT *c, X509 *x, int default_nid)
     if (default_nid)
         return sig_nid == default_nid ? 1 : 0;
     for (i = 0; i < c->shared_sigalgslen; i++)
-        if (sig_nid == c->shared_sigalgs[i].signandhash_nid)
+        if (sig_nid == c->shared_sigalgs[i]->sigandhash)
             return 1;
     return 0;
 }