Add new ctrl to retrieve client certificate types, print out
[openssl.git] / ssl / s3_lib.c
index 2c6e1ad..457a5c7 100644 (file)
@@ -3089,6 +3089,8 @@ static char * MS_CALLBACK srp_password_from_info_cb(SSL *s, void *arg)
        }
 #endif
 
+static int ssl3_set_req_cert_type(CERT *c, const unsigned char *p, size_t len);
+
 long ssl3_ctrl(SSL *s, int cmd, long larg, void *parg)
        {
        int ret=0;
@@ -3426,6 +3428,27 @@ long ssl3_ctrl(SSL *s, int cmd, long larg, void *parg)
        case SSL_CTRL_SET_CLIENT_SIGALGS_LIST:
                return tls1_set_sigalgs_list(s->cert, parg, 1);
 
+       case SSL_CTRL_GET_CLIENT_CERT_TYPES:
+               {
+               const unsigned char **pctype = parg;
+               if (s->server || !s->s3->tmp.cert_req)
+                       return 0;
+               if (s->cert->ctypes)
+                       {
+                       if (pctype)
+                               *pctype = s->cert->ctypes;
+                       return (int)s->cert->ctype_num;
+                       }
+               if (pctype)
+                       *pctype = (unsigned char *)s->s3->tmp.ctype;
+               return s->s3->tmp.ctype_num;
+               }
+
+       case SSL_CTRL_SET_CLIENT_CERT_TYPES:
+               if (!s->server)
+                       return 0;
+               return ssl3_set_req_cert_type(s->cert, parg, larg);
+
        default:
                break;
                }
@@ -3720,6 +3743,9 @@ long ssl3_ctx_ctrl(SSL_CTX *ctx, int cmd, long larg, void *parg)
        case SSL_CTRL_SET_CLIENT_SIGALGS_LIST:
                return tls1_set_sigalgs_list(ctx->cert, parg, 1);
 
+       case SSL_CTRL_SET_CLIENT_CERT_TYPES:
+               return ssl3_set_req_cert_type(ctx->cert, parg, larg);
+
        case SSL_CTRL_SET_TLSEXT_AUTHZ_SERVER_AUDIT_PROOF_CB_ARG:
                ctx->tlsext_authz_server_audit_proof_cb_arg = parg;
                break;
@@ -4014,8 +4040,61 @@ SSL_CIPHER *ssl3_choose_cipher(SSL *s, STACK_OF(SSL_CIPHER) *clnt,
 int ssl3_get_req_cert_type(SSL *s, unsigned char *p)
        {
        int ret=0;
+       const unsigned char *sig;
+       size_t siglen;
+       int have_rsa_sign = 0, have_dsa_sign = 0, have_ecdsa_sign = 0;
+       int nostrict = 1;
        unsigned long alg_k;
 
+       /* If we have custom certificate types set, use them */
+       if (s->cert->ctypes)
+               {
+               memcpy(p, s->cert->ctypes, s->cert->ctype_num);
+               return (int)s->cert->ctype_num;
+               }
+       /* Else see if we have any signature algorithms configured */
+       if (s->cert->client_sigalgs)
+               {
+               sig = s->cert->client_sigalgs;
+               siglen = s->cert->client_sigalgslen;
+               }
+       else
+               {
+               sig = s->cert->conf_sigalgs;
+               siglen = s->cert->conf_sigalgslen;
+               }
+       /* If we have sigalgs work out if we can sign with RSA, DSA, ECDSA */
+       if (sig)
+               {
+               size_t i;
+               if (s->cert->cert_flags & SSL_CERT_FLAG_TLS_STRICT)
+                       nostrict = 0;
+               for (i = 0; i < siglen; i+=2, sig+=2)
+                       {
+                       switch(sig[1])
+                               {
+                       case TLSEXT_signature_rsa:
+                               have_rsa_sign = 1;
+                               break;
+
+                       case TLSEXT_signature_dsa:
+                               have_dsa_sign = 1;
+                               break;
+
+                       case TLSEXT_signature_ecdsa:
+                               have_ecdsa_sign = 1;
+                               break;
+                               }
+                       }
+               }
+       /* Otherwise allow anything */
+       else
+               {
+               have_rsa_sign = 1;
+               have_dsa_sign = 1;
+               have_ecdsa_sign = 1;
+               }
+
        alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
 
 #ifndef OPENSSL_NO_GOST
@@ -4034,10 +4113,15 @@ int ssl3_get_req_cert_type(SSL *s, unsigned char *p)
        if (alg_k & (SSL_kDHr|SSL_kEDH))
                {
 #  ifndef OPENSSL_NO_RSA
-               p[ret++]=SSL3_CT_RSA_FIXED_DH;
+               /* Since this refers to a certificate signed with an RSA
+                * algorithm, only check for rsa signing in strict mode.
+                */
+               if (nostrict || have_rsa_sign)
+                       p[ret++]=SSL3_CT_RSA_FIXED_DH;
 #  endif
 #  ifndef OPENSSL_NO_DSA
-               p[ret++]=SSL3_CT_DSS_FIXED_DH;
+               if (nostrict || have_dsa_sign)
+                       p[ret++]=SSL3_CT_DSS_FIXED_DH;
 #  endif
                }
        if ((s->version == SSL3_VERSION) &&
@@ -4052,16 +4136,20 @@ int ssl3_get_req_cert_type(SSL *s, unsigned char *p)
                }
 #endif /* !OPENSSL_NO_DH */
 #ifndef OPENSSL_NO_RSA
-       p[ret++]=SSL3_CT_RSA_SIGN;
+       if (have_rsa_sign)
+               p[ret++]=SSL3_CT_RSA_SIGN;
 #endif
 #ifndef OPENSSL_NO_DSA
-       p[ret++]=SSL3_CT_DSS_SIGN;
+       if (have_dsa_sign)
+               p[ret++]=SSL3_CT_DSS_SIGN;
 #endif
 #ifndef OPENSSL_NO_ECDH
        if ((alg_k & (SSL_kECDHr|SSL_kECDHe)) && (s->version >= TLS1_VERSION))
                {
-               p[ret++]=TLS_CT_RSA_FIXED_ECDH;
-               p[ret++]=TLS_CT_ECDSA_FIXED_ECDH;
+               if (nostrict || have_rsa_sign)
+                       p[ret++]=TLS_CT_RSA_FIXED_ECDH;
+               if (nostrict || have_ecdsa_sign)
+                       p[ret++]=TLS_CT_ECDSA_FIXED_ECDH;
                }
 #endif
 
@@ -4071,12 +4159,32 @@ int ssl3_get_req_cert_type(SSL *s, unsigned char *p)
         */
        if (s->version >= TLS1_VERSION)
                {
-               p[ret++]=TLS_CT_ECDSA_SIGN;
+               if (have_ecdsa_sign)
+                       p[ret++]=TLS_CT_ECDSA_SIGN;
                }
 #endif 
        return(ret);
        }
 
+static int ssl3_set_req_cert_type(CERT *c, const unsigned char *p, size_t len)
+       {
+       if (c->ctypes)
+               {
+               OPENSSL_free(c->ctypes);
+               c->ctypes = NULL;
+               }
+       if (!p || !len)
+               return 1;
+       if (len > 0xff)
+               return 0;
+       c->ctypes = OPENSSL_malloc(len);
+       if (!c->ctypes)
+               return 0;
+       memcpy(c->ctypes, p, len);
+       c->ctype_num = len;
+       return 1;
+       }
+
 int ssl3_shutdown(SSL *s)
        {
        int ret;