QUIC: Refine SSL_shutdown and begin to implement SSL_shutdown_ex
[openssl.git] / ssl / ssl_lib.c
index 9db3c28073095e3f1701df867e46261c12c3d124..b927e283fee8357b3900b2da478e19727c9a59a3 100644 (file)
@@ -553,17 +553,18 @@ static int clear_record_layer(SSL_CONNECTION *s)
                                    SSL_CONNECTION_IS_DTLS(s) ? DTLS_ANY_VERSION
                                                              : TLS_ANY_VERSION,
                                    OSSL_RECORD_DIRECTION_READ,
-                                   OSSL_RECORD_PROTECTION_LEVEL_NONE,
+                                   OSSL_RECORD_PROTECTION_LEVEL_NONE, NULL, 0,
                                    NULL, 0, NULL, 0, NULL,  0, NULL, 0,
-                                   NID_undef, NULL, NULL);
+                                   NID_undef, NULL, NULL, NULL);
 
     ret &= ssl_set_new_record_layer(s,
                                     SSL_CONNECTION_IS_DTLS(s) ? DTLS_ANY_VERSION
                                                               : TLS_ANY_VERSION,
                                     OSSL_RECORD_DIRECTION_WRITE,
-                                    OSSL_RECORD_PROTECTION_LEVEL_NONE,
+                                    OSSL_RECORD_PROTECTION_LEVEL_NONE, NULL, 0,
                                     NULL, 0, NULL, 0, NULL,  0, NULL, 0,
-                                    NID_undef, NULL, NULL);
+                                    NID_undef, NULL, NULL, NULL);
+
     /* SSLfatal already called in the event of failure */
     return ret;
 }
@@ -644,9 +645,9 @@ int ossl_ssl_connection_reset(SSL *s)
      * Check to see if we were changed into a different method, if so, revert
      * back.
      */
-    if (s->method != SSL_CONNECTION_GET_CTX(sc)->method) {
+    if (s->method != s->defltmeth) {
         s->method->ssl_deinit(s);
-        s->method = SSL_CONNECTION_GET_CTX(sc)->method;
+        s->method = s->defltmeth;
         if (!s->method->ssl_init(s))
             return 0;
     } else {
@@ -702,7 +703,7 @@ SSL *SSL_new(SSL_CTX *ctx)
     return ctx->method->ssl_new(ctx);
 }
 
-int ossl_ssl_init(SSL *ssl, SSL_CTX *ctx, int type)
+int ossl_ssl_init(SSL *ssl, SSL_CTX *ctx, const SSL_METHOD *method, int type)
 {
     ssl->type = type;
 
@@ -714,7 +715,7 @@ int ossl_ssl_init(SSL *ssl, SSL_CTX *ctx, int type)
     SSL_CTX_up_ref(ctx);
     ssl->ctx = ctx;
 
-    ssl->method = ctx->method;
+    ssl->defltmeth = ssl->method = method;
 
     if (!CRYPTO_new_ex_data(CRYPTO_EX_INDEX_SSL, ssl, &ssl->ex_data))
         return 0;
@@ -722,7 +723,7 @@ int ossl_ssl_init(SSL *ssl, SSL_CTX *ctx, int type)
     return 1;
 }
 
-SSL *ossl_ssl_connection_new(SSL_CTX *ctx)
+SSL *ossl_ssl_connection_new_int(SSL_CTX *ctx, const SSL_METHOD *method)
 {
     SSL_CONNECTION *s;
     SSL *ssl;
@@ -732,17 +733,12 @@ SSL *ossl_ssl_connection_new(SSL_CTX *ctx)
         return NULL;
 
     ssl = &s->ssl;
-    if (!ossl_ssl_init(ssl, ctx, SSL_TYPE_SSL_CONNECTION)) {
+    if (!ossl_ssl_init(ssl, ctx, method, SSL_TYPE_SSL_CONNECTION)) {
         OPENSSL_free(s);
         s = NULL;
         goto sslerr;
     }
 
-#ifndef OPENSSL_NO_QUIC
-    /* set the parent (user visible) ssl to self */
-    s->user_ssl = ssl;
-#endif
-
     RECORD_LAYER_init(&s->rlayer, s);
 
     s->options = ctx->options;
@@ -860,12 +856,12 @@ SSL *ossl_ssl_connection_new(SSL_CTX *ctx)
     s->allow_early_data_cb = ctx->allow_early_data_cb;
     s->allow_early_data_cb_data = ctx->allow_early_data_cb_data;
 
-    if (!ssl->method->ssl_init(ssl))
+    if (!method->ssl_init(ssl))
         goto sslerr;
 
-    s->server = (ctx->method->ssl_accept == ssl_undefined_function) ? 0 : 1;
+    s->server = (method->ssl_accept == ssl_undefined_function) ? 0 : 1;
 
-    if (!SSL_clear(ssl))
+    if (!method->ssl_reset(ssl))
         goto sslerr;
 
 #ifndef OPENSSL_NO_PSK
@@ -904,6 +900,11 @@ SSL *ossl_ssl_connection_new(SSL_CTX *ctx)
     return NULL;
 }
 
+SSL *ossl_ssl_connection_new(SSL_CTX *ctx)
+{
+    return ossl_ssl_connection_new_int(ctx, ctx->method);
+}
+
 int SSL_is_dtls(const SSL *s)
 {
     SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(s);
@@ -2637,6 +2638,12 @@ int SSL_shutdown(SSL *s)
      * (see ssl3_shutdown).
      */
     SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(s);
+#ifndef OPENSSL_NO_QUIC
+    QUIC_CONNECTION *qc = QUIC_CONNECTION_FROM_SSL(s);
+
+    if (qc != NULL)
+        return ossl_quic_conn_shutdown(qc, 0, NULL, 0);
+#endif
 
     if (sc == NULL)
         return -1;
@@ -2887,13 +2894,13 @@ long SSL_ctrl(SSL *s, int cmd, long larg, void *parg)
             return 0;
     case SSL_CTRL_SET_MIN_PROTO_VERSION:
         return ssl_check_allowed_versions(larg, sc->max_proto_version)
-               && ssl_set_version_bound(s->ctx->method->version, (int)larg,
+               && ssl_set_version_bound(s->defltmeth->version, (int)larg,
                                         &sc->min_proto_version);
     case SSL_CTRL_GET_MIN_PROTO_VERSION:
         return sc->min_proto_version;
     case SSL_CTRL_SET_MAX_PROTO_VERSION:
         return ssl_check_allowed_versions(sc->min_proto_version, larg)
-               && ssl_set_version_bound(s->ctx->method->version, (int)larg,
+               && ssl_set_version_bound(s->defltmeth->version, (int)larg,
                                         &sc->max_proto_version);
     case SSL_CTRL_GET_MAX_PROTO_VERSION:
         return sc->max_proto_version;
@@ -4411,6 +4418,14 @@ int SSL_get_error(const SSL *s, int i)
     if (i > 0)
         return SSL_ERROR_NONE;
 
+#ifndef OPENSSL_NO_QUIC
+    if (qc != NULL) {
+        reason = ossl_quic_get_error(qc, i);
+        if (reason != SSL_ERROR_NONE)
+            return reason;
+    }
+#endif
+
     if (sc == NULL)
         return SSL_ERROR_SSL;
 
@@ -4425,14 +4440,6 @@ int SSL_get_error(const SSL *s, int i)
             return SSL_ERROR_SSL;
     }
 
-#ifndef OPENSSL_NO_QUIC
-    if (qc != NULL) {
-        reason = ossl_quic_get_error(qc, i);
-        if (reason != SSL_ERROR_NONE)
-            return reason;
-    }
-#endif
-
 #ifndef OPENSSL_NO_QUIC
     if (qc == NULL)
 #endif
@@ -7105,7 +7112,7 @@ int SSL_get_wpoll_descriptor(SSL *s, BIO_POLL_DESCRIPTOR *desc)
 #endif
 }
 
-int SSL_want_net_read(SSL *s)
+int SSL_net_read_desired(SSL *s)
 {
 #ifndef OPENSSL_NO_QUIC
     QUIC_CONNECTION *qc = QUIC_CONNECTION_FROM_SSL(s);
@@ -7113,13 +7120,13 @@ int SSL_want_net_read(SSL *s)
     if (qc == NULL)
         return 0;
 
-    return ossl_quic_get_want_net_read(qc);
+    return ossl_quic_get_net_read_desired(qc);
 #else
     return 0;
 #endif
 }
 
-int SSL_want_net_write(SSL *s)
+int SSL_net_write_desired(SSL *s)
 {
 #ifndef OPENSSL_NO_QUIC
     QUIC_CONNECTION *qc = QUIC_CONNECTION_FROM_SSL(s);
@@ -7127,7 +7134,7 @@ int SSL_want_net_write(SSL *s)
     if (qc == NULL)
         return 0;
 
-    return ossl_quic_get_want_net_write(qc);
+    return ossl_quic_get_net_write_desired(qc);
 #else
     return 0;
 #endif
@@ -7167,10 +7174,26 @@ int SSL_set_initial_peer_addr(SSL *s, const BIO_ADDR *peer_addr)
     QUIC_CONNECTION *qc = QUIC_CONNECTION_FROM_SSL(s);
 
     if (qc == NULL)
-        return -1;
+        return 0;
 
     return ossl_quic_conn_set_initial_peer_addr(qc, peer_addr);
 #else
-    return -1;
+    return 0;
+#endif
+}
+
+int SSL_shutdown_ex(SSL *ssl, uint64_t flags,
+                    const SSL_SHUTDOWN_EX_ARGS *args,
+                    size_t args_len)
+{
+#ifndef OPENSSL_NO_QUIC
+    QUIC_CONNECTION *qc = QUIC_CONNECTION_FROM_SSL(ssl);
+
+    if (qc == NULL)
+        return SSL_shutdown(ssl);
+
+    return ossl_quic_conn_shutdown(qc, flags, args, args_len);
+#else
+    return SSL_shutdown(ssl);
 #endif
 }