SSL object refactoring using SSL_CONNECTION object
[openssl.git] / ssl / ssl_sess.c
index 0a26c1345ca5c36ee21f74764c1654622242e157..0a9f025fbb19c80b941a21a7774559326c1e70d4 100644 (file)
@@ -88,13 +88,19 @@ void ssl_session_calculate_timeout(SSL_SESSION *ss)
 SSL_SESSION *SSL_get_session(const SSL *ssl)
 /* aka SSL_get0_session; gets 0 objects, just returns a copy of the pointer */
 {
-    return ssl->session;
+    const SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(ssl);
+
+    if (sc == NULL)
+        return NULL;
+
+    return sc->session;
 }
 
 SSL_SESSION *SSL_get1_session(SSL *ssl)
 /* variant of SSL_get_session: caller really gets something */
 {
     SSL_SESSION *sess;
+
     /*
      * Need to lock this all up rather than just use CRYPTO_add so that
      * somebody doesn't free ssl->session between when we check it's non-null
@@ -102,8 +108,8 @@ SSL_SESSION *SSL_get1_session(SSL *ssl)
      */
     if (!CRYPTO_THREAD_read_lock(ssl->lock))
         return NULL;
-    sess = ssl->session;
-    if (sess)
+    sess = SSL_get_session(ssl);
+    if (sess != NULL)
         SSL_SESSION_up_ref(sess);
     CRYPTO_THREAD_unlock(ssl->lock);
     return sess;
@@ -335,10 +341,11 @@ static int def_generate_session_id(SSL *ssl, unsigned char *id,
     return 0;
 }
 
-int ssl_generate_session_id(SSL *s, SSL_SESSION *ss)
+int ssl_generate_session_id(SSL_CONNECTION *s, SSL_SESSION *ss)
 {
     unsigned int tmp;
     GEN_SESSION_CB cb = def_generate_session_id;
+    SSL *ssl = SSL_CONNECTION_GET_SSL(s);
 
     switch (s->version) {
     case SSL3_VERSION:
@@ -377,10 +384,10 @@ int ssl_generate_session_id(SSL *s, SSL_SESSION *ss)
     }
 
     /* Choose which callback will set the session ID */
-    if (!CRYPTO_THREAD_read_lock(s->lock))
+    if (!CRYPTO_THREAD_read_lock(SSL_CONNECTION_GET_SSL(s)->lock))
         return 0;
     if (!CRYPTO_THREAD_read_lock(s->session_ctx->lock)) {
-        CRYPTO_THREAD_unlock(s->lock);
+        CRYPTO_THREAD_unlock(ssl->lock);
         SSLfatal(s, SSL_AD_INTERNAL_ERROR,
                  SSL_R_SESSION_ID_CONTEXT_UNINITIALIZED);
         return 0;
@@ -390,11 +397,11 @@ int ssl_generate_session_id(SSL *s, SSL_SESSION *ss)
     else if (s->session_ctx->generate_session_id)
         cb = s->session_ctx->generate_session_id;
     CRYPTO_THREAD_unlock(s->session_ctx->lock);
-    CRYPTO_THREAD_unlock(s->lock);
+    CRYPTO_THREAD_unlock(ssl->lock);
     /* Choose a session ID */
     memset(ss->session_id, 0, ss->session_id_length);
     tmp = (int)ss->session_id_length;
-    if (!cb(s, ss->session_id, &tmp)) {
+    if (!cb(ssl, ss->session_id, &tmp)) {
         /* The callback failed */
         SSLfatal(s, SSL_AD_INTERNAL_ERROR,
                  SSL_R_SSL_SESSION_ID_CALLBACK_FAILED);
@@ -412,7 +419,7 @@ int ssl_generate_session_id(SSL *s, SSL_SESSION *ss)
     }
     ss->session_id_length = tmp;
     /* Finally, check for a conflict */
-    if (SSL_has_matching_session_id(s, ss->session_id,
+    if (SSL_has_matching_session_id(ssl, ss->session_id,
                                     (unsigned int)ss->session_id_length)) {
         SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_R_SSL_SESSION_ID_CONFLICT);
         return 0;
@@ -421,7 +428,7 @@ int ssl_generate_session_id(SSL *s, SSL_SESSION *ss)
     return 1;
 }
 
-int ssl_get_new_session(SSL *s, int session)
+int ssl_get_new_session(SSL_CONNECTION *s, int session)
 {
     /* This gets used by clients and servers. */
 
@@ -434,7 +441,7 @@ int ssl_get_new_session(SSL *s, int session)
 
     /* If the context has a default timeout, use it */
     if (s->session_ctx->session_timeout == 0)
-        ss->timeout = SSL_get_default_timeout(s);
+        ss->timeout = SSL_get_default_timeout(SSL_CONNECTION_GET_SSL(s));
     else
         ss->timeout = s->session_ctx->session_timeout;
     ssl_session_calculate_timeout(ss);
@@ -443,7 +450,7 @@ int ssl_get_new_session(SSL *s, int session)
     s->session = NULL;
 
     if (session) {
-        if (SSL_IS_TLS13(s)) {
+        if (SSL_CONNECTION_IS_TLS13(s)) {
             /*
              * We generate the session id while constructing the
              * NewSessionTicket in TLSv1.3.
@@ -477,7 +484,8 @@ int ssl_get_new_session(SSL *s, int session)
     return 1;
 }
 
-SSL_SESSION *lookup_sess_in_cache(SSL *s, const unsigned char *sess_id,
+SSL_SESSION *lookup_sess_in_cache(SSL_CONNECTION *s,
+                                  const unsigned char *sess_id,
                                   size_t sess_id_len)
 {
     SSL_SESSION *ret = NULL;
@@ -508,7 +516,8 @@ SSL_SESSION *lookup_sess_in_cache(SSL *s, const unsigned char *sess_id,
     if (ret == NULL && s->session_ctx->get_session_cb != NULL) {
         int copy = 1;
 
-        ret = s->session_ctx->get_session_cb(s, sess_id, sess_id_len, &copy);
+        ret = s->session_ctx->get_session_cb(SSL_CONNECTION_GET_SSL(s),
+                                             sess_id, sess_id_len, &copy);
 
         if (ret != NULL) {
             ssl_tsan_counter(s->session_ctx,
@@ -560,7 +569,7 @@ SSL_SESSION *lookup_sess_in_cache(SSL *s, const unsigned char *sess_id,
  *   - Both for new and resumed sessions, s->ext.ticket_expected is set to 1
  *     if the server should issue a new session ticket (to 0 otherwise).
  */
-int ssl_get_prev_session(SSL *s, CLIENTHELLO_MSG *hello)
+int ssl_get_prev_session(SSL_CONNECTION *s, CLIENTHELLO_MSG *hello)
 {
     /* This is used only by servers. */
 
@@ -569,7 +578,7 @@ int ssl_get_prev_session(SSL *s, CLIENTHELLO_MSG *hello)
     int try_session_cache = 0;
     SSL_TICKET_STATUS r;
 
-    if (SSL_IS_TLS13(s)) {
+    if (SSL_CONNECTION_IS_TLS13(s)) {
         /*
          * By default we will send a new ticket. This can be overridden in the
          * ticket processing.
@@ -664,7 +673,7 @@ int ssl_get_prev_session(SSL *s, CLIENTHELLO_MSG *hello)
         goto err;
     }
 
-    if (!SSL_IS_TLS13(s)) {
+    if (!SSL_CONNECTION_IS_TLS13(s)) {
         /* We already did this for TLS1.3 */
         SSL_SESSION_free(s->session);
         s->session = ret;
@@ -678,7 +687,7 @@ int ssl_get_prev_session(SSL *s, CLIENTHELLO_MSG *hello)
     if (ret != NULL) {
         SSL_SESSION_free(ret);
         /* In TLSv1.3 s->session was already set to ret, so we NULL it out */
-        if (SSL_IS_TLS13(s))
+        if (SSL_CONNECTION_IS_TLS13(s))
             s->session = NULL;
 
         if (!try_session_cache) {
@@ -859,7 +868,12 @@ int SSL_SESSION_up_ref(SSL_SESSION *ss)
 
 int SSL_set_session(SSL *s, SSL_SESSION *session)
 {
-    ssl_clear_bad_session(s);
+    SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(s);
+
+    if (sc == NULL)
+        return 0;
+
+    ssl_clear_bad_session(sc);
     if (s->ctx->method != s->method) {
         if (!SSL_set_ssl_method(s, s->ctx->method))
             return 0;
@@ -867,10 +881,10 @@ int SSL_set_session(SSL *s, SSL_SESSION *session)
 
     if (session != NULL) {
         SSL_SESSION_up_ref(session);
-        s->verify_result = session->verify_result;
+        sc->verify_result = session->verify_result;
     }
-    SSL_SESSION_free(s->session);
-    s->session = session;
+    SSL_SESSION_free(sc->session);
+    sc->session = session;
 
     return 1;
 }
@@ -1088,42 +1102,53 @@ int SSL_set_session_secret_cb(SSL *s,
                               tls_session_secret_cb_fn tls_session_secret_cb,
                               void *arg)
 {
-    if (s == NULL)
+    SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(s);
+
+    if (sc == NULL)
         return 0;
-    s->ext.session_secret_cb = tls_session_secret_cb;
-    s->ext.session_secret_cb_arg = arg;
+
+    sc->ext.session_secret_cb = tls_session_secret_cb;
+    sc->ext.session_secret_cb_arg = arg;
     return 1;
 }
 
 int SSL_set_session_ticket_ext_cb(SSL *s, tls_session_ticket_ext_cb_fn cb,
                                   void *arg)
 {
-    if (s == NULL)
+    SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(s);
+
+    if (sc == NULL)
         return 0;
-    s->ext.session_ticket_cb = cb;
-    s->ext.session_ticket_cb_arg = arg;
+
+    sc->ext.session_ticket_cb = cb;
+    sc->ext.session_ticket_cb_arg = arg;
     return 1;
 }
 
 int SSL_set_session_ticket_ext(SSL *s, void *ext_data, int ext_len)
 {
-    if (s->version >= TLS1_VERSION) {
-        OPENSSL_free(s->ext.session_ticket);
-        s->ext.session_ticket = NULL;
-        s->ext.session_ticket =
+    SSL_CONNECTION *sc = SSL_CONNECTION_FROM_SSL(s);
+
+    if (sc == NULL)
+        return 0;
+
+    if (sc->version >= TLS1_VERSION) {
+        OPENSSL_free(sc->ext.session_ticket);
+        sc->ext.session_ticket = NULL;
+        sc->ext.session_ticket =
             OPENSSL_malloc(sizeof(TLS_SESSION_TICKET_EXT) + ext_len);
-        if (s->ext.session_ticket == NULL) {
+        if (sc->ext.session_ticket == NULL) {
             ERR_raise(ERR_LIB_SSL, ERR_R_MALLOC_FAILURE);
             return 0;
         }
 
         if (ext_data != NULL) {
-            s->ext.session_ticket->length = ext_len;
-            s->ext.session_ticket->data = s->ext.session_ticket + 1;
-            memcpy(s->ext.session_ticket->data, ext_data, ext_len);
+            sc->ext.session_ticket->length = ext_len;
+            sc->ext.session_ticket->data = sc->ext.session_ticket + 1;
+            memcpy(sc->ext.session_ticket->data, ext_data, ext_len);
         } else {
-            s->ext.session_ticket->length = 0;
-            s->ext.session_ticket->data = NULL;
+            sc->ext.session_ticket->length = 0;
+            sc->ext.session_ticket->data = NULL;
         }
 
         return 1;
@@ -1180,11 +1205,12 @@ void SSL_CTX_flush_sessions(SSL_CTX *s, long t)
     sk_SSL_SESSION_pop_free(sk, SSL_SESSION_free);
 }
 
-int ssl_clear_bad_session(SSL *s)
+int ssl_clear_bad_session(SSL_CONNECTION *s)
 {
     if ((s->session != NULL) &&
         !(s->shutdown & SSL_SENT_SHUTDOWN) &&
-        !(SSL_in_init(s) || SSL_in_before(s))) {
+        !(SSL_in_init(SSL_CONNECTION_GET_SSL(s))
+          || SSL_in_before(SSL_CONNECTION_GET_SSL(s)))) {
         SSL_CTX_remove_session(s->session_ctx, s->session);
         return 1;
     } else
@@ -1293,7 +1319,7 @@ void (*SSL_CTX_sess_get_remove_cb(SSL_CTX *ctx)) (SSL_CTX *ctx,
 }
 
 void SSL_CTX_sess_set_get_cb(SSL_CTX *ctx,
-                             SSL_SESSION *(*cb) (struct ssl_st *ssl,
+                             SSL_SESSION *(*cb) (SSL *ssl,
                                                  const unsigned char *data,
                                                  int len, int *copy))
 {