Implement Async SSL_shutdown
[openssl.git] / ssl / ssl_lib.c
index 90de7472c70c63e1710a69814da0759ba7041be9..a43ec52736b127e0d309bec4eeb72f07dfa690cc 100644 (file)
@@ -190,10 +190,11 @@ struct ssl_async_args {
     SSL *s;
     void *buf;
     int num;
-    int type;
+    enum { READFUNC, WRITEFUNC,  OTHERFUNC} type;
     union {
-        int (*func1)(SSL *, void *, int);
-        int (*func2)(SSL *, const void *, int);
+        int (*func_read)(SSL *, void *, int);
+        int (*func_write)(SSL *, const void *, int);
+        int (*func_other)(SSL *);
     } f;
 };
 
@@ -745,6 +746,11 @@ SSL *SSL_new(SSL_CTX *ctx)
     return (NULL);
 }
 
+void SSL_up_ref(SSL *s)
+{
+    CRYPTO_add(&s->references, 1, CRYPTO_LOCK_SSL);
+}
+
 int SSL_CTX_set_session_id_context(SSL_CTX *ctx, const unsigned char *sid_ctx,
                                    unsigned int sid_ctx_len)
 {
@@ -872,18 +878,24 @@ int SSL_dane_enable(SSL *s, const char *basedomain)
         return 0;
     }
 
+    /*
+     * Default SNI name.  This rejects empty names, while set1_host below
+     * accepts them and disables host name checks.  To avoid side-effects with
+     * invalid input, set the SNI name first.
+     */
+    if (s->tlsext_hostname == NULL) {
+       if (!SSL_set_tlsext_host_name(s, basedomain)) {
+            SSLerr(SSL_F_SSL_DANE_ENABLE, SSL_R_ERROR_SETTING_TLSA_BASE_DOMAIN);
+           return -1;
+        }
+    }
+
     /* Primary RFC6125 reference identifier */
     if (!X509_VERIFY_PARAM_set1_host(s->param, basedomain, 0)) {
         SSLerr(SSL_F_SSL_DANE_ENABLE, SSL_R_ERROR_SETTING_TLSA_BASE_DOMAIN);
         return -1;
     }
 
-    /* Default SNI name */
-    if (s->tlsext_hostname == NULL) {
-       if (!SSL_set_tlsext_host_name(s, basedomain))
-           return -1;
-    }
-
     dane->mdpth = -1;
     dane->pdpth = -1;
     dane->dctx = &s->ctx->dane;
@@ -1458,10 +1470,15 @@ static int ssl_io_intern(void *vargs)
     s = args->s;
     buf = args->buf;
     num = args->num;
-    if (args->type == 1)
-        return args->f.func1(s, buf, num);
-    else
-        return args->f.func2(s, buf, num);
+    switch (args->type) {
+    case READFUNC:
+        return args->f.func_read(s, buf, num);
+    case WRITEFUNC:
+        return args->f.func_write(s, buf, num);
+    case OTHERFUNC:
+        return args->f.func_other(s);
+    }
+    return -1;
 }
 
 int SSL_read(SSL *s, void *buf, int num)
@@ -1482,8 +1499,8 @@ int SSL_read(SSL *s, void *buf, int num)
         args.s = s;
         args.buf = buf;
         args.num = num;
-        args.type = 1;
-        args.f.func1 = s->method->ssl_read;
+        args.type = READFUNC;
+        args.f.func_read = s->method->ssl_read;
 
         return ssl_start_async_job(s, &args, ssl_io_intern);
     } else {
@@ -1507,8 +1524,8 @@ int SSL_peek(SSL *s, void *buf, int num)
         args.s = s;
         args.buf = buf;
         args.num = num;
-        args.type = 1;
-        args.f.func1 = s->method->ssl_peek;
+        args.type = READFUNC;
+        args.f.func_read = s->method->ssl_peek;
 
         return ssl_start_async_job(s, &args, ssl_io_intern);
     } else {
@@ -1535,8 +1552,8 @@ int SSL_write(SSL *s, const void *buf, int num)
         args.s = s;
         args.buf = (void *)buf;
         args.num = num;
-        args.type = 2;
-        args.f.func2 = s->method->ssl_write;
+        args.type = WRITEFUNC;
+        args.f.func_write = s->method->ssl_write;
 
         return ssl_start_async_job(s, &args, ssl_io_intern);
     } else {
@@ -1558,10 +1575,19 @@ int SSL_shutdown(SSL *s)
         return -1;
     }
 
-    if (!SSL_in_init(s))
-        return (s->method->ssl_shutdown(s));
-    else
-        return (1);
+    if((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
+        struct ssl_async_args args;
+
+        args.s = s;
+        args.type = OTHERFUNC;
+        args.f.func_other = s->method->ssl_shutdown;
+
+        return ssl_start_async_job(s, &args, ssl_io_intern);
+    } else {
+        return s->method->ssl_shutdown(s);
+    }
+
+    return s->method->ssl_shutdown(s);
 }
 
 int SSL_renegotiate(SSL *s)
@@ -2345,6 +2371,11 @@ SSL_CTX *SSL_CTX_new(const SSL_METHOD *meth)
     return (NULL);
 }
 
+void SSL_CTX_up_ref(SSL_CTX *ctx)
+{
+    CRYPTO_add(&ctx->references, 1, CRYPTO_LOCK_SSL_CTX);
+}
+
 void SSL_CTX_free(SSL_CTX *a)
 {
     int i;
@@ -2494,9 +2525,8 @@ void ssl_set_masks(SSL *s, const SSL_CIPHER *cipher)
     mask_a = 0;
 
 #ifdef CIPHER_DEBUG
-    fprintf(stderr,
-            "dht=%d re=%d rs=%d ds=%d dhr=%d dhd=%d\n",
-            dh_tmp, rsa_enc, rsa_sign, dsa_sign, dh_rsa, dh_dsa);
+    fprintf(stderr, "dht=%d re=%d rs=%d ds=%d\n",
+            dh_tmp, rsa_enc, rsa_sign, dsa_sign);
 #endif
 
 #ifndef OPENSSL_NO_GOST