Convert libssl writing for size_t
[openssl.git] / ssl / ssl_lib.c
index 30b1d6b860bc16b1bb7fa6f7bd8d7f10da2c9b07..a26e352b29baf7b2bba45cac8a8c405760dfdaa8 100644 (file)
@@ -85,7 +85,7 @@ struct ssl_async_args {
     enum { READFUNC, WRITEFUNC, OTHERFUNC } type;
     union {
         int (*func_read) (SSL *, void *, size_t, size_t *);
-        int (*func_write) (SSL *, const void *, int);
+        int (*func_write) (SSL *, const void *, size_t, size_t *);
         int (*func_other) (SSL *);
     } f;
 };
@@ -1517,9 +1517,9 @@ static int ssl_io_intern(void *vargs)
     num = args->num;
     switch (args->type) {
     case READFUNC:
-        return args->f.func_read(s, buf, num, &s->asyncread);
+        return args->f.func_read(s, buf, num, &s->asyncrw);
     case WRITEFUNC:
-        return args->f.func_write(s, buf, num);
+        return args->f.func_write(s, buf, num, &s->asyncrw);
     case OTHERFUNC:
         return args->f.func_other(s);
     }
@@ -1571,7 +1571,7 @@ int SSL_read_ex(SSL *s, void *buf, size_t num, size_t *read)
         args.f.func_read = s->method->ssl_read;
 
         ret = ssl_start_async_job(s, &args, ssl_io_intern);
-        *read = s->asyncread;
+        *read = s->asyncrw;
         return ret;
     } else {
         return s->method->ssl_read(s, buf, num, read);
@@ -1621,7 +1621,7 @@ int SSL_peek_ex(SSL *s, void *buf, size_t num, size_t *read)
         args.f.func_read = s->method->ssl_peek;
 
         ret = ssl_start_async_job(s, &args, ssl_io_intern);
-        *read = s->asyncread;
+        *read = s->asyncrw;
         return ret;
     } else {
         return s->method->ssl_peek(s, buf, num, read);
@@ -1629,19 +1629,42 @@ int SSL_peek_ex(SSL *s, void *buf, size_t num, size_t *read)
 }
 
 int SSL_write(SSL *s, const void *buf, int num)
+{
+    int ret;
+    size_t written;
+
+    if (num < 0) {
+        SSLerr(SSL_F_SSL_WRITE, SSL_R_BAD_LENGTH);
+        return -1;
+    }
+
+    ret = SSL_write_ex(s, buf, (size_t)num, &written);
+
+    /*
+     * The cast is safe here because ret should be <= INT_MAX because num is
+     * <= INT_MAX
+     */
+    if (ret > 0)
+        ret = (int)written;
+
+    return ret;
+}
+
+int SSL_write_ex(SSL *s, const void *buf, size_t num, size_t *written)
 {
     if (s->handshake_func == NULL) {
-        SSLerr(SSL_F_SSL_WRITE, SSL_R_UNINITIALIZED);
+        SSLerr(SSL_F_SSL_WRITE_EX, SSL_R_UNINITIALIZED);
         return -1;
     }
 
     if (s->shutdown & SSL_SENT_SHUTDOWN) {
         s->rwstate = SSL_NOTHING;
-        SSLerr(SSL_F_SSL_WRITE, SSL_R_PROTOCOL_IS_SHUTDOWN);
+        SSLerr(SSL_F_SSL_WRITE_EX, SSL_R_PROTOCOL_IS_SHUTDOWN);
         return (-1);
     }
 
     if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
+        int ret;
         struct ssl_async_args args;
 
         args.s = s;
@@ -1650,9 +1673,11 @@ int SSL_write(SSL *s, const void *buf, int num)
         args.type = WRITEFUNC;
         args.f.func_write = s->method->ssl_write;
 
-        return ssl_start_async_job(s, &args, ssl_io_intern);
+        ret = ssl_start_async_job(s, &args, ssl_io_intern);
+        *written = s->asyncrw;
+        return ret;
     } else {
-        return s->method->ssl_write(s, buf, num);
+        return s->method->ssl_write(s, buf, num, written);
     }
 }
 
@@ -1751,7 +1776,7 @@ long SSL_ctrl(SSL *s, int cmd, long larg, void *parg)
             s->split_send_fragment = s->max_send_fragment;
         return 1;
     case SSL_CTRL_SET_SPLIT_SEND_FRAGMENT:
-        if ((unsigned int)larg > s->max_send_fragment || larg == 0)
+        if ((size_t)larg > s->max_send_fragment || larg == 0)
             return 0;
         s->split_send_fragment = larg;
         return 1;
@@ -1905,7 +1930,7 @@ long SSL_CTX_ctrl(SSL_CTX *ctx, int cmd, long larg, void *parg)
             ctx->split_send_fragment = ctx->max_send_fragment;
         return 1;
     case SSL_CTRL_SET_SPLIT_SEND_FRAGMENT:
-        if ((unsigned int)larg > ctx->max_send_fragment || larg == 0)
+        if ((size_t)larg > ctx->max_send_fragment || larg == 0)
             return 0;
         ctx->split_send_fragment = larg;
         return 1;