Only inherit the session ID context in SSL_set_SSL_CTX if the existing
[openssl.git] / ssl / ssl_lib.c
index d42f50bf60217e73807f52a8d519662d96888442..f9f91e666c09acc30c69ee774b252a9eda07ad59 100644 (file)
@@ -3191,24 +3191,31 @@ SSL_CTX *SSL_set_SSL_CTX(SSL *ssl, SSL_CTX* ctx)
                        }
                ssl_cert_free(ocert);
                }
-       CRYPTO_add(&ctx->references,1,CRYPTO_LOCK_SSL_CTX);
-       if (ssl->ctx != NULL)
-               SSL_CTX_free(ssl->ctx); /* decrement reference count */
-       ssl->ctx = ctx;
 
-       /*
-        * Inherit the session ID context as it is typically set from the
-        * parent SSL_CTX, and can vary with the CTX.
-        * Note that per-SSL SSL_set_session_id_context() will not persist
-        * if called before SSL_set_SSL_CTX.
-        */
-       ssl->sid_ctx_length = ctx->sid_ctx_length;
        /*
         * Program invariant: |sid_ctx| has fixed size (SSL_MAX_SID_CTX_LENGTH),
         * so setter APIs must prevent invalid lengths from entering the system.
         */
-       OPENSSL_assert(ssl->sid_ctx_length <= sizeof ssl->sid_ctx);
-       memcpy(&ssl->sid_ctx, &ctx->sid_ctx, sizeof(ssl->sid_ctx));
+       OPENSSL_assert(ssl->sid_ctx_length <= sizeof(ssl->sid_ctx));
+
+       /*
+        * If the session ID context matches that of the parent SSL_CTX,
+        * inherit it from the new SSL_CTX as well. If however the context does
+        * not match (i.e., it was set per-ssl with SSL_set_session_id_context),
+        * leave it unchanged.
+        */
+       if ((ssl->ctx != NULL) &&
+               (ssl->sid_ctx_length == ssl->ctx->sid_ctx_length) &&
+               (memcmp(ssl->sid_ctx, ssl->ctx->sid_ctx, ssl->sid_ctx_length) == 0))
+               {
+               ssl->sid_ctx_length = ctx->sid_ctx_length;
+               memcpy(&ssl->sid_ctx, &ctx->sid_ctx, sizeof(ssl->sid_ctx));
+               }
+
+       CRYPTO_add(&ctx->references,1,CRYPTO_LOCK_SSL_CTX);
+       if (ssl->ctx != NULL)
+               SSL_CTX_free(ssl->ctx); /* decrement reference count */
+       ssl->ctx = ctx;
 
        return(ssl->ctx);
        }