Add support for application defined signature algorithms for use with
[openssl.git] / ssl / t1_lib.c
index 85a5681f87a41df497151d545be6c9014acafc95..dcfecf4f5b79b077215bfd268d286bc5fc210c7f 100644 (file)
@@ -629,9 +629,29 @@ static unsigned char tls12_sigalgs[] = {
 #endif
 };
 
-int tls12_get_req_sig_algs(SSL *s, unsigned char *p)
+size_t tls12_get_sig_algs(SSL *s, unsigned char *p)
        {
-       size_t slen = sizeof(tls12_sigalgs);
+       TLS_SIGALGS *sptr = s->cert->conf_sigalgs;
+       size_t slen;
+
+       /* Use custom signature algorithms if any are set */
+
+       if (sptr)
+               {
+               slen = s->cert->conf_sigalgslen;
+               if (p)
+                       {
+                       size_t i;
+                       for (i = 0; i < slen; i++, sptr++)
+                               {
+                               *p++ = sptr->rhash;
+                               *p++ = sptr->rsign;
+                               }
+                       }
+               return slen * 2;
+               }
+               
+       slen = sizeof(tls12_sigalgs);
 #ifdef OPENSSL_FIPS
        /* If FIPS mode don't include MD5 which is last */
        if (FIPS_mode())
@@ -639,7 +659,7 @@ int tls12_get_req_sig_algs(SSL *s, unsigned char *p)
 #endif
        if (p)
                memcpy(p, tls12_sigalgs, slen);
-       return (int)slen;
+       return slen;
        }
 
 /* byte_compare is a compare function for qsort(3) that compares bytes. */
@@ -874,13 +894,15 @@ unsigned char *ssl_add_clienthello_tlsext(SSL *s, unsigned char *p, unsigned cha
 
        if (TLS1_get_client_version(s) >= TLS1_2_VERSION)
                {
-               if ((size_t)(limit - ret) < sizeof(tls12_sigalgs) + 6)
+               size_t salglen;
+               salglen = tls12_get_sig_algs(s, NULL);
+               if ((size_t)(limit - ret) < salglen + 6)
                        return NULL; 
                s2n(TLSEXT_TYPE_signature_algorithms,ret);
-               s2n(sizeof(tls12_sigalgs) + 2, ret);
-               s2n(sizeof(tls12_sigalgs), ret);
-               memcpy(ret, tls12_sigalgs, sizeof(tls12_sigalgs));
-               ret += sizeof(tls12_sigalgs);
+               s2n(salglen + 2, ret);
+               s2n(salglen, ret);
+               tls12_get_sig_algs(s, ret);
+               ret += salglen;
                }
 
 #ifdef TLSEXT_TYPE_opaque_prf_input
@@ -1781,6 +1803,8 @@ static int ssl_scan_clienthello_tlsext(SSL *s, unsigned char **p, unsigned char
                        if (!s->hit)
                                {
                                size_t i;
+                               if (s->s3->tlsext_authz_client_types != NULL)
+                                       OPENSSL_free(s->s3->tlsext_authz_client_types);
                                s->s3->tlsext_authz_client_types =
                                        OPENSSL_malloc(server_authz_dataformatlist_length);
                                if (!s->s3->tlsext_authz_client_types)
@@ -2857,14 +2881,14 @@ int tls1_process_sigalgs(SSL *s, const unsigned char *data, int dsize)
        c->pkeys[SSL_PKEY_RSA_ENC].digest = NULL;
        c->pkeys[SSL_PKEY_ECC].digest = NULL;
 
-       if (c->sigalgs)
-               OPENSSL_free(c->sigalgs);
-       c->sigalgs = OPENSSL_malloc((dsize/2) * sizeof(TLS_SIGALGS));
-       if (!c->sigalgs)
+       if (c->peer_sigalgs)
+               OPENSSL_free(c->peer_sigalgs);
+       c->peer_sigalgs = OPENSSL_malloc((dsize/2) * sizeof(TLS_SIGALGS));
+       if (!c->peer_sigalgs)
                return 0;
-       c->sigalgslen = dsize/2;
+       c->peer_sigalgslen = dsize/2;
 
-       for (i = 0, sigptr = c->sigalgs; i < dsize; i += 2, sigptr++)
+       for (i = 0, sigptr = c->peer_sigalgs; i < dsize; i += 2, sigptr++)
                {
                sigptr->rhash = data[i];
                sigptr->rsign = data[i + 1];
@@ -2938,14 +2962,14 @@ int SSL_get_sigalgs(SSL *s, int idx,
                        int *psign, int *phash, int *psignandhash,
                        unsigned char *rsig, unsigned char *rhash)
        {
-       if (s->cert->sigalgs == NULL)
+       if (s->cert->peer_sigalgs == NULL)
                return 0;
        if (idx >= 0)
                {
                TLS_SIGALGS *psig;
-               if (idx >= (int)s->cert->sigalgslen)
+               if (idx >= (int)s->cert->peer_sigalgslen)
                        return 0;
-               psig = s->cert->sigalgs + idx;
+               psig = s->cert->peer_sigalgs + idx;
                if (psign)
                        *psign = psig->sign_nid;
                if (phash)
@@ -2957,7 +2981,7 @@ int SSL_get_sigalgs(SSL *s, int idx,
                if (rhash)
                        *rhash = psig->rhash;
                }
-       return s->cert->sigalgslen;
+       return s->cert->peer_sigalgslen;
        }
        
 
@@ -3105,3 +3129,110 @@ tls1_heartbeat(SSL *s)
        return ret;
        }
 #endif
+
+#define MAX_SIGALGLEN  (TLSEXT_hash_num * TLSEXT_signature_num *2)
+
+typedef struct
+       {
+       size_t sigalgcnt;
+       int sigalgs[MAX_SIGALGLEN];
+       } sig_cb_st;
+
+static int sig_cb(const char *elem, int len, void *arg)
+       {
+       sig_cb_st *sarg = arg;
+       size_t i;
+       char etmp[20], *p;
+       int sig_alg, hash_alg;
+       if (sarg->sigalgcnt == MAX_SIGALGLEN)
+               return 0;
+       if (len > (int)(sizeof(etmp) - 1))
+               return 0;
+       memcpy(etmp, elem, len);
+       etmp[len] = 0;
+       p = strchr(etmp, '+');
+       if (!p)
+               return 0;
+       *p = 0;
+       p++;
+       if (!*p)
+               return 0;
+
+       if (!strcmp(etmp, "RSA"))
+               sig_alg = EVP_PKEY_RSA;
+       else if (!strcmp(etmp, "DSA"))
+               sig_alg = EVP_PKEY_DSA;
+       else if (!strcmp(etmp, "ECDSA"))
+               sig_alg = EVP_PKEY_EC;
+       else return 0;
+
+       hash_alg = OBJ_sn2nid(p);
+       if (hash_alg == NID_undef)
+               hash_alg = OBJ_ln2nid(p);
+       if (hash_alg == NID_undef)
+               return 0;
+
+       for (i = 0; i < sarg->sigalgcnt; i+=2)
+               {
+               if (sarg->sigalgs[i] == sig_alg
+                       && sarg->sigalgs[i + 1] == hash_alg)
+                       return 0;
+               }
+       sarg->sigalgs[sarg->sigalgcnt++] = hash_alg;
+       sarg->sigalgs[sarg->sigalgcnt++] = sig_alg;
+       return 1;
+       }
+
+/* Set suppored signature algorithms based on a colon separated list
+ * of the form sig+hash e.g. RSA+SHA512:DSA+SHA512 */
+int tls1_set_sigalgs_list(CERT *c, const char *str)
+       {
+       sig_cb_st sig;
+       sig.sigalgcnt = 0;
+       if (!CONF_parse_list(str, ':', 1, sig_cb, &sig))
+               return 0;
+       return tls1_set_sigalgs(c, sig.sigalgs, sig.sigalgcnt);
+       }
+
+int tls1_set_sigalgs(CERT *c, const int *salg, size_t salglen)
+       {
+       TLS_SIGALGS *sigalgs, *sptr;
+       int rhash, rsign;
+       size_t i;
+       if (salglen & 1)
+               return 0;
+       salglen /= 2;
+       sigalgs = OPENSSL_malloc(sizeof(TLS_SIGALGS) * salglen);
+       if (sigalgs == NULL)
+               return 0;
+       for (i = 0, sptr = sigalgs; i < salglen; i++, sptr++)
+               {
+               sptr->hash_nid = *salg++;
+               sptr->sign_nid = *salg++;
+               rhash = tls12_find_id(sptr->hash_nid, tls12_md,
+                                       sizeof(tls12_md)/sizeof(tls12_lookup));
+               rsign = tls12_find_id(sptr->sign_nid, tls12_sig,
+                               sizeof(tls12_sig)/sizeof(tls12_lookup));
+
+               if (rhash == -1 || rsign == -1)
+                       goto err;
+
+               if (!OBJ_find_sigid_by_algs(&sptr->signandhash_nid,
+                                               sptr->hash_nid,
+                                               sptr->sign_nid))
+                       sptr->signandhash_nid = NID_undef;
+               sptr->rhash = rhash;
+               sptr->rsign = rsign;
+               }
+
+       if (c->conf_sigalgs)
+               OPENSSL_free(c->conf_sigalgs);
+
+       c->conf_sigalgs = sigalgs;
+       c->conf_sigalgslen = salglen;
+       return 1;
+
+       err:
+       OPENSSL_free(sigalgs);
+       return 0;
+       }