Rename start_async_job to ssl_start_async_job
[openssl.git] / ssl / ssl_lib.c
index ec852569b1994bd2909c53312f1d7d47a3e61920..44374b47e9a7421b2e2c280138aa42deb575db39 100644 (file)
 #ifndef OPENSSL_NO_ENGINE
 # include <openssl/engine.h>
 #endif
+#include <openssl/async.h>
 
 const char SSL_version_str[] = OPENSSL_VERSION_TEXT;
 
@@ -186,6 +187,17 @@ SSL3_ENC_METHOD ssl3_undef_enc_method = {
              int use_context))ssl_undefined_function,
 };
 
+struct ssl_async_args {
+    SSL *s;
+    void *buf;
+    int num;
+    int type;
+    union {
+        int (*func1)(SSL *, void *, int);
+        int (*func2)(SSL *, const void *, int);
+    } f;
+};
+
 static void clear_ciphers(SSL *s)
 {
     /* clear the current cipher */
@@ -311,7 +323,7 @@ SSL *SSL_new(SSL_CTX *ctx)
     s->generate_session_id = ctx->generate_session_id;
 
     s->param = X509_VERIFY_PARAM_new();
-    if (!s->param)
+    if (s->param == NULL)
         goto err;
     X509_VERIFY_PARAM_inherit(s->param, ctx->param);
     s->quiet_shutdown = ctx->quiet_shutdown;
@@ -366,6 +378,9 @@ SSL *SSL_new(SSL_CTX *ctx)
 
     s->verify_result = X509_V_OK;
 
+    s->default_passwd_callback = ctx->default_passwd_callback;
+    s->default_passwd_callback_userdata = ctx->default_passwd_callback_userdata;
+
     s->method = ctx->method;
 
     if (!s->method->ssl_new(s))
@@ -383,6 +398,8 @@ SSL *SSL_new(SSL_CTX *ctx)
     s->psk_server_callback = ctx->psk_server_callback;
 #endif
 
+    s->job = NULL;
+
     return (s);
  err:
     SSL_free(s);
@@ -911,22 +928,40 @@ int SSL_check_private_key(const SSL *ssl)
                                    ssl->cert->key->privatekey));
 }
 
+int SSL_waiting_for_async(SSL *s)
+{
+    if(s->job)
+        return 1;
+
+    return 0;
+}
+
+int SSL_get_async_wait_fd(SSL *s)
+{
+    if (!s->job)
+        return -1;
+
+    return ASYNC_get_wait_fd(s->job);
+}
+
 int SSL_accept(SSL *s)
 {
-    if (s->handshake_func == 0)
+    if (s->handshake_func == 0) {
         /* Not properly initialized yet */
         SSL_set_accept_state(s);
+    }
 
-    return (s->method->ssl_accept(s));
+    return SSL_do_handshake(s);
 }
 
 int SSL_connect(SSL *s)
 {
-    if (s->handshake_func == 0)
+    if (s->handshake_func == 0) {
         /* Not properly initialized yet */
         SSL_set_connect_state(s);
+    }
 
-    return (s->method->ssl_connect(s));
+    return SSL_do_handshake(s);
 }
 
 long SSL_get_default_timeout(const SSL *s)
@@ -934,6 +969,46 @@ long SSL_get_default_timeout(const SSL *s)
     return (s->method->get_timeout());
 }
 
+static int ssl_start_async_job(SSL *s, struct ssl_async_args *args,
+                          int (*func)(void *)) {
+    int ret;
+    switch(ASYNC_start_job(&s->job, &ret, func, args,
+        sizeof(struct ssl_async_args))) {
+    case ASYNC_ERR:
+        s->rwstate = SSL_NOTHING;
+        SSLerr(SSL_F_SSL_START_ASYNC_JOB, SSL_R_FAILED_TO_INIT_ASYNC);
+        return -1;
+    case ASYNC_PAUSE:
+        s->rwstate = SSL_ASYNC_PAUSED;
+        return -1;
+    case ASYNC_FINISH:
+        s->job = NULL;
+        return ret;
+    default:
+        s->rwstate = SSL_NOTHING;
+        SSLerr(SSL_F_SSL_START_ASYNC_JOB, ERR_R_INTERNAL_ERROR);
+        /* Shouldn't happen */
+        return -1;
+    }
+}
+
+static int ssl_io_intern(void *vargs)
+{
+    struct ssl_async_args *args;
+    SSL *s;
+    void *buf;
+    int num;
+
+    args = (struct ssl_async_args *)vargs;
+    s = args->s;
+    buf = args->buf;
+    num = args->num;
+    if (args->type == 1)
+        return args->f.func1(s, buf, num);
+    else
+        return args->f.func2(s, buf, num);
+}
+
 int SSL_read(SSL *s, void *buf, int num)
 {
     if (s->handshake_func == 0) {
@@ -945,7 +1020,20 @@ int SSL_read(SSL *s, void *buf, int num)
         s->rwstate = SSL_NOTHING;
         return (0);
     }
-    return (s->method->ssl_read(s, buf, num));
+
+    if((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
+        struct ssl_async_args args;
+
+        args.s = s;
+        args.buf = buf;
+        args.num = num;
+        args.type = 1;
+        args.f.func1 = s->method->ssl_read;
+
+        return ssl_start_async_job(s, &args, ssl_io_intern);
+    } else {
+        return s->method->ssl_read(s, buf, num);
+    }
 }
 
 int SSL_peek(SSL *s, void *buf, int num)
@@ -958,7 +1046,19 @@ int SSL_peek(SSL *s, void *buf, int num)
     if (s->shutdown & SSL_RECEIVED_SHUTDOWN) {
         return (0);
     }
-    return (s->method->ssl_peek(s, buf, num));
+    if((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
+        struct ssl_async_args args;
+
+        args.s = s;
+        args.buf = buf;
+        args.num = num;
+        args.type = 1;
+        args.f.func1 = s->method->ssl_peek;
+
+        return ssl_start_async_job(s, &args, ssl_io_intern);
+    } else {
+        return s->method->ssl_peek(s, buf, num);
+    }
 }
 
 int SSL_write(SSL *s, const void *buf, int num)
@@ -973,7 +1073,20 @@ int SSL_write(SSL *s, const void *buf, int num)
         SSLerr(SSL_F_SSL_WRITE, SSL_R_PROTOCOL_IS_SHUTDOWN);
         return (-1);
     }
-    return (s->method->ssl_write(s, buf, num));
+
+    if((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
+        struct ssl_async_args args;
+
+        args.s = s;
+        args.buf = (void *)buf;
+        args.num = num;
+        args.type = 2;
+        args.f.func2 = s->method->ssl_write;
+
+        return ssl_start_async_job(s, &args, ssl_io_intern);
+    } else {
+        return s->method->ssl_write(s, buf, num);
+    }
 }
 
 int SSL_shutdown(SSL *s)
@@ -1547,7 +1660,7 @@ int SSL_CTX_set_alpn_protos(SSL_CTX *ctx, const unsigned char *protos,
 {
     OPENSSL_free(ctx->alpn_client_proto_list);
     ctx->alpn_client_proto_list = OPENSSL_malloc(protos_len);
-    if (!ctx->alpn_client_proto_list)
+    if (ctx->alpn_client_proto_list == NULL)
         return 1;
     memcpy(ctx->alpn_client_proto_list, protos, protos_len);
     ctx->alpn_client_proto_list_len = protos_len;
@@ -1565,7 +1678,7 @@ int SSL_set_alpn_protos(SSL *ssl, const unsigned char *protos,
 {
     OPENSSL_free(ssl->alpn_client_proto_list);
     ssl->alpn_client_proto_list = OPENSSL_malloc(protos_len);
-    if (!ssl->alpn_client_proto_list)
+    if (ssl->alpn_client_proto_list == NULL)
         return 1;
     memcpy(ssl->alpn_client_proto_list, protos, protos_len);
     ssl->alpn_client_proto_list_len = protos_len;
@@ -1708,7 +1821,7 @@ SSL_CTX *SSL_CTX_new(const SSL_METHOD *meth)
     }
 
     ret->param = X509_VERIFY_PARAM_new();
-    if (!ret->param)
+    if (ret->param == NULL)
         goto err;
 
     if ((ret->md5 = EVP_get_digestbyname("ssl3-md5")) == NULL) {
@@ -1846,6 +1959,16 @@ void SSL_CTX_set_default_passwd_cb_userdata(SSL_CTX *ctx, void *u)
     ctx->default_passwd_callback_userdata = u;
 }
 
+void SSL_set_default_passwd_cb(SSL *s, pem_password_cb *cb)
+{
+    s->default_passwd_callback = cb;
+}
+
+void SSL_set_default_passwd_cb_userdata(SSL *s, void *u)
+{
+    s->default_passwd_callback_userdata = u;
+}
+
 void SSL_CTX_set_cert_verify_callback(SSL_CTX *ctx,
                                       int (*cb) (X509_STORE_CTX *, void *),
                                       void *arg)
@@ -2360,6 +2483,9 @@ int SSL_get_error(const SSL *s, int i)
     if ((i < 0) && SSL_want_x509_lookup(s)) {
         return (SSL_ERROR_WANT_X509_LOOKUP);
     }
+    if ((i < 0) && SSL_want_async(s)) {
+        return SSL_ERROR_WANT_ASYNC;
+    }
 
     if (i == 0) {
         if ((s->shutdown & SSL_RECEIVED_SHUTDOWN) &&
@@ -2369,21 +2495,40 @@ int SSL_get_error(const SSL *s, int i)
     return (SSL_ERROR_SYSCALL);
 }
 
+static int ssl_do_handshake_intern(void *vargs)
+{
+    struct ssl_async_args *args;
+    SSL *s;
+
+    args = (struct ssl_async_args *)vargs;
+    s = args->s;
+
+    return s->handshake_func(s);
+}
+
 int SSL_do_handshake(SSL *s)
 {
     int ret = 1;
 
     if (s->handshake_func == NULL) {
         SSLerr(SSL_F_SSL_DO_HANDSHAKE, SSL_R_CONNECTION_TYPE_NOT_SET);
-        return (-1);
+        return -1;
     }
 
     s->method->ssl_renegotiate_check(s);
 
     if (SSL_in_init(s) || SSL_in_before(s)) {
-        ret = s->handshake_func(s);
+        if((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
+            struct ssl_async_args args;
+
+            args.s = s;
+
+            ret = ssl_start_async_job(s, &args, ssl_do_handshake_intern);
+        } else {
+            ret = s->handshake_func(s);
+        }
     }
-    return (ret);
+    return ret;
 }
 
 void SSL_set_accept_state(SSL *s)
@@ -2419,8 +2564,6 @@ int ssl_undefined_void_function(void)
 
 int ssl_undefined_const_function(const SSL *s)
 {
-    SSLerr(SSL_F_SSL_UNDEFINED_CONST_FUNCTION,
-           ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
     return (0);
 }
 
@@ -2535,6 +2678,9 @@ SSL *SSL_dup(SSL *s)
                                  * ret->init_off */
     ret->hit = s->hit;
 
+    ret->default_passwd_callback = s->default_passwd_callback;
+    ret->default_passwd_callback_userdata = s->default_passwd_callback_userdata;
+
     X509_VERIFY_PARAM_inherit(ret->param, s->param);
 
     /* dup the cipher_list and cipher_list_by_id stacks */
@@ -3149,8 +3295,11 @@ EVP_MD_CTX *ssl_replace_hash(EVP_MD_CTX **hash, const EVP_MD *md)
 {
     ssl_clear_hash_ctx(hash);
     *hash = EVP_MD_CTX_create();
-    if (md)
-        EVP_DigestInit_ex(*hash, md, NULL);
+    if (*hash == NULL || (md && EVP_DigestInit_ex(*hash, md, NULL) <= 0)) {
+        EVP_MD_CTX_destroy(*hash);
+        *hash = NULL;
+        return NULL;
+    }
     return *hash;
 }