Skip to content

Commit

Permalink
Add SecurityContextHolderStrategy XML Configuration for Defaults
Browse files Browse the repository at this point in the history
Issue gh-11061
  • Loading branch information
jzheaux committed Jun 17, 2022
1 parent 2c09a30 commit 2a70707
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ final class AuthenticationConfigBuilder {

AuthenticationConfigBuilder(Element element, boolean forceAutoConfig, ParserContext pc,
SessionCreationPolicy sessionPolicy, BeanReference requestCache, BeanReference authenticationManager,
BeanReference authenticationFilterSecurityContextHolderStrategyRef,
BeanReference authenticationFilterSecurityContextRepositoryRef, BeanReference sessionStrategy,
BeanReference portMapper, BeanReference portResolver, BeanMetadataElement csrfLogoutHandler) {
this.httpElt = element;
Expand All @@ -247,23 +248,24 @@ final class AuthenticationConfigBuilder {
this.portMapper = portMapper;
this.portResolver = portResolver;
this.csrfLogoutHandler = csrfLogoutHandler;
createAnonymousFilter();
createAnonymousFilter(authenticationFilterSecurityContextHolderStrategyRef);
createRememberMeFilter(authenticationManager);
createBasicFilter(authenticationManager);
createBasicFilter(authenticationManager, authenticationFilterSecurityContextHolderStrategyRef);
createBearerTokenAuthenticationFilter(authenticationManager);
createFormLoginFilter(sessionStrategy, authenticationManager, authenticationFilterSecurityContextRepositoryRef);
createFormLoginFilter(sessionStrategy, authenticationManager,
authenticationFilterSecurityContextHolderStrategyRef, authenticationFilterSecurityContextRepositoryRef);
createOAuth2ClientFilters(sessionStrategy, requestCache, authenticationManager,
authenticationFilterSecurityContextRepositoryRef);
createOpenIDLoginFilter(sessionStrategy, authenticationManager,
authenticationFilterSecurityContextRepositoryRef);
createSaml2LoginFilter(authenticationManager, authenticationFilterSecurityContextRepositoryRef);
createX509Filter(authenticationManager);
createJeeFilter(authenticationManager);
createLogoutFilter();
createLogoutFilter(authenticationFilterSecurityContextHolderStrategyRef);
createSaml2LogoutFilter();
createLoginPageFilterIfNeeded();
createUserDetailsServiceFactory();
createExceptionTranslationFilter();
createExceptionTranslationFilter(authenticationFilterSecurityContextHolderStrategyRef);
}

void createRememberMeFilter(BeanReference authenticationManager) {
Expand Down Expand Up @@ -293,6 +295,7 @@ private void createRememberMeProvider(String key) {
}

void createFormLoginFilter(BeanReference sessionStrategy, BeanReference authManager,
BeanReference authenticationFilterSecurityContextHolderStrategyRef,
BeanReference authenticationFilterSecurityContextRepositoryRef) {
Element formLoginElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.FORM_LOGIN);
RootBeanDefinition formFilter = null;
Expand All @@ -313,6 +316,8 @@ void createFormLoginFilter(BeanReference sessionStrategy, BeanReference authMana
formFilter.getPropertyValues().addPropertyValue("securityContextRepository",
authenticationFilterSecurityContextRepositoryRef);
}
formFilter.getPropertyValues().addPropertyValue("securityContextHolderStrategy",
authenticationFilterSecurityContextHolderStrategyRef);
// Id is required by login page filter
this.formFilterId = this.pc.getReaderContext().generateBeanName(formFilter);
this.pc.registerBeanComponent(new BeanComponentDefinition(formFilter, this.formFilterId));
Expand Down Expand Up @@ -564,7 +569,8 @@ private void injectRememberMeServicesRef(RootBeanDefinition bean, String remembe
}
}

void createBasicFilter(BeanReference authManager) {
void createBasicFilter(BeanReference authManager,
BeanReference authenticationFilterSecurityContextHolderStrategyRef) {
Element basicAuthElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.BASIC_AUTH);
if (basicAuthElt == null && !this.autoConfig) {
// No basic auth, do nothing
Expand Down Expand Up @@ -592,6 +598,8 @@ void createBasicFilter(BeanReference authManager) {
}
filterBuilder.addConstructorArgValue(authManager);
filterBuilder.addConstructorArgValue(this.basicEntryPoint);
filterBuilder.addPropertyValue("securityContextHolderStrategy",
authenticationFilterSecurityContextHolderStrategyRef);
this.basicFilter = filterBuilder.getBeanDefinition();
}

Expand Down Expand Up @@ -739,15 +747,16 @@ void createLoginPageFilterIfNeeded() {
}
}

void createLogoutFilter() {
void createLogoutFilter(BeanReference authenticationFilterSecurityContextHolderStrategyRef) {
Element logoutElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.LOGOUT);
if (logoutElt != null || this.autoConfig) {
String formLoginPage = this.formLoginPage;
if (formLoginPage == null) {
formLoginPage = DefaultLoginPageGeneratingFilter.DEFAULT_LOGIN_PAGE_URL;
}
LogoutBeanDefinitionParser logoutParser = new LogoutBeanDefinitionParser(formLoginPage,
this.rememberMeServicesId, this.csrfLogoutHandler);
this.rememberMeServicesId, this.csrfLogoutHandler,
authenticationFilterSecurityContextHolderStrategyRef);
this.logoutFilter = logoutParser.parse(logoutElt, this.pc);
this.logoutHandlers = logoutParser.getLogoutHandlers();
this.logoutSuccessHandler = logoutParser.getLogoutSuccessHandler();
Expand Down Expand Up @@ -803,7 +812,7 @@ List<BeanDefinition> getCsrfIgnoreRequestMatchers() {
return this.csrfIgnoreRequestMatchers;
}

void createAnonymousFilter() {
void createAnonymousFilter(BeanReference authenticationFilterSecurityContextHolderStrategyRef) {
Element anonymousElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.ANONYMOUS);
if (anonymousElt != null && "false".equals(anonymousElt.getAttribute("enabled"))) {
return;
Expand Down Expand Up @@ -833,6 +842,8 @@ void createAnonymousFilter() {
this.anonymousFilter.getConstructorArgumentValues().addIndexedArgumentValue(1, username);
this.anonymousFilter.getConstructorArgumentValues().addIndexedArgumentValue(2,
AuthorityUtils.commaSeparatedStringToAuthorityList(grantedAuthority));
this.anonymousFilter.getPropertyValues().addPropertyValue("securityContextHolderStrategy",
authenticationFilterSecurityContextHolderStrategyRef);
this.anonymousFilter.setSource(source);
RootBeanDefinition anonymousProviderBean = new RootBeanDefinition(AnonymousAuthenticationProvider.class);
anonymousProviderBean.getConstructorArgumentValues().addIndexedArgumentValue(0, key);
Expand All @@ -847,14 +858,16 @@ private String createKey() {
return Long.toString(random.nextLong());
}

void createExceptionTranslationFilter() {
void createExceptionTranslationFilter(BeanReference authenticationFilterSecurityContextHolderStrategyRef) {
BeanDefinitionBuilder etfBuilder = BeanDefinitionBuilder.rootBeanDefinition(ExceptionTranslationFilter.class);
this.accessDeniedHandler = createAccessDeniedHandler(this.httpElt, this.pc);
etfBuilder.addPropertyValue("accessDeniedHandler", this.accessDeniedHandler);
Assert.state(this.requestCache != null, "No request cache found");
this.mainEntryPoint = selectEntryPoint();
etfBuilder.addConstructorArgValue(this.mainEntryPoint);
etfBuilder.addConstructorArgValue(this.requestCache);
etfBuilder.addPropertyValue("securityContextHolderStrategy",
authenticationFilterSecurityContextHolderStrategyRef);
this.etf = etfBuilder.getBeanDefinition();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.w3c.dom.Element;

import org.springframework.beans.BeanMetadataElement;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanReference;
import org.springframework.beans.factory.config.RuntimeBeanReference;
Expand All @@ -40,6 +41,8 @@
import org.springframework.security.access.vote.RoleVoter;
import org.springframework.security.config.Elements;
import org.springframework.security.config.http.GrantedAuthorityDefaultsParserUtils.AbstractGrantedAuthorityDefaultsBeanFactory;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.session.SessionRegistryImpl;
import org.springframework.security.web.access.AuthorizationManagerWebInvocationPrivilegeEvaluator;
import org.springframework.security.web.access.DefaultWebInvocationPrivilegeEvaluator;
Expand Down Expand Up @@ -106,6 +109,8 @@ class HttpConfigurationBuilder {

private static final String ATT_SESSION_AUTH_ERROR_URL = "session-authentication-error-url";

private static final String ATT_SECURITY_CONTEXT_HOLDER_STRATEGY = "security-context-holder-strategy-ref";

private static final String ATT_SECURITY_CONTEXT_REPOSITORY = "security-context-repository-ref";

private static final String ATT_SECURITY_CONTEXT_EXPLICIT_SAVE = "security-context-explicit-save";
Expand Down Expand Up @@ -156,6 +161,8 @@ class HttpConfigurationBuilder {

private BeanDefinition forceEagerSessionCreationFilter;

private BeanReference holderStrategyRef;

private BeanReference contextRepoRef;

private BeanReference sessionRegistryRef;
Expand Down Expand Up @@ -215,6 +222,7 @@ class HttpConfigurationBuilder {
String createSession = element.getAttribute(ATT_CREATE_SESSION);
this.sessionPolicy = !StringUtils.hasText(createSession) ? SessionCreationPolicy.IF_REQUIRED
: createPolicy(createSession);
createSecurityContextHolderStrategy();
createForceEagerSessionCreationFilter();
createDisableEncodeUrlFilter();
createCsrfFilter();
Expand Down Expand Up @@ -294,6 +302,10 @@ static String createPath(String path, boolean lowerCase) {
return lowerCase ? path.toLowerCase() : path;
}

BeanReference getSecurityContextHolderStrategyForAuthenticationFilters() {
return this.holderStrategyRef;
}

BeanReference getSecurityContextRepositoryForAuthenticationFilters() {
return (isExplicitSave()) ? this.contextRepoRef : null;
}
Expand Down Expand Up @@ -331,11 +343,23 @@ private void createSecurityContextPersistenceFilter() {
default:
scpf.addPropertyValue("forceEagerSessionCreation", Boolean.FALSE);
}
scpf.addPropertyValue("securityContextHolderStrategy", this.holderStrategyRef);
scpf.addConstructorArgValue(this.contextRepoRef);

this.securityContextPersistenceFilter = scpf.getBeanDefinition();
}

private void createSecurityContextHolderStrategy() {
String holderStrategyRef = this.httpElt.getAttribute(ATT_SECURITY_CONTEXT_HOLDER_STRATEGY);
if (!StringUtils.hasText(holderStrategyRef)) {
BeanDefinition holderStrategyBean = BeanDefinitionBuilder
.rootBeanDefinition(SecurityContextHolderStrategyFactory.class).getBeanDefinition();
holderStrategyRef = this.pc.getReaderContext().generateBeanName(holderStrategyBean);
this.pc.registerBeanComponent(new BeanComponentDefinition(holderStrategyBean, holderStrategyRef));
}
this.holderStrategyRef = new RuntimeBeanReference(holderStrategyRef);
}

private void createSecurityContextRepository() {
String repoRef = this.httpElt.getAttribute(ATT_SECURITY_CONTEXT_REPOSITORY);
if (!StringUtils.hasText(repoRef)) {
Expand All @@ -359,6 +383,7 @@ private void createSecurityContextRepository() {
contextRepo.addPropertyValue("disableUrlRewriting", Boolean.TRUE);
}
}
contextRepo.addPropertyValue("securityContextHolderStrategy", this.holderStrategyRef);
BeanDefinition repoBean = contextRepo.getBeanDefinition();
repoRef = this.pc.getReaderContext().generateBeanName(repoBean);
this.pc.registerBeanComponent(new BeanComponentDefinition(repoBean, repoRef));
Expand All @@ -374,6 +399,7 @@ private boolean isDisableUrlRewriting() {

private void createSecurityContextHolderFilter() {
BeanDefinitionBuilder filter = BeanDefinitionBuilder.rootBeanDefinition(SecurityContextHolderFilter.class);
filter.addPropertyValue("securityContextHolderStrategy", this.holderStrategyRef);
filter.addConstructorArgValue(this.contextRepoRef);
this.securityContextPersistenceFilter = filter.getBeanDefinition();
}
Expand Down Expand Up @@ -485,6 +511,7 @@ else if (StringUtils.hasText(sessionAuthStratRef)) {
if (StringUtils.hasText(errorUrl)) {
failureHandler.getPropertyValues().addPropertyValue("defaultFailureUrl", errorUrl);
}
sessionMgmtFilter.addPropertyValue("securityContextHolderStrategy", this.holderStrategyRef);
sessionMgmtFilter.addPropertyValue("authenticationFailureHandler", failureHandler);
sessionMgmtFilter.addConstructorArgValue(this.contextRepoRef);
if (!StringUtils.hasText(sessionAuthStratRef) && sessionFixationStrategy != null && !useChangeSessionId) {
Expand Down Expand Up @@ -744,6 +771,7 @@ private void createFilterSecurityInterceptor(BeanReference authManager) {
builder.addPropertyValue("observeOncePerRequest", Boolean.FALSE);
}
builder.addPropertyValue("securityMetadataSource", securityMds);
builder.addPropertyValue("securityContextHolderStrategy", this.holderStrategyRef);
BeanDefinition fsiBean = builder.getBeanDefinition();
String fsiId = this.pc.getReaderContext().generateBeanName(fsiBean);
this.pc.registerBeanComponent(new BeanComponentDefinition(fsiBean, fsiId));
Expand Down Expand Up @@ -883,4 +911,18 @@ public SecurityContextHolderAwareRequestFilter getBean() {

}

static class SecurityContextHolderStrategyFactory implements FactoryBean<SecurityContextHolderStrategy> {

@Override
public SecurityContextHolderStrategy getObject() throws Exception {
return SecurityContextHolder.getContextHolderStrategy();
}

@Override
public Class<?> getObjectType() {
return SecurityContextHolderStrategy.class;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ private BeanReference createFilterChain(Element element, ParserContext pc) {
httpBldr.getSecurityContextRepositoryForAuthenticationFilters();
AuthenticationConfigBuilder authBldr = new AuthenticationConfigBuilder(element, forceAutoConfig, pc,
httpBldr.getSessionCreationPolicy(), httpBldr.getRequestCache(), authenticationManager,
httpBldr.getSecurityContextHolderStrategyForAuthenticationFilters(),
httpBldr.getSecurityContextRepositoryForAuthenticationFilters(), httpBldr.getSessionStrategy(),
portMapper, portResolver, httpBldr.getCsrfLogoutHandler());
httpBldr.setLogoutHandlers(authBldr.getLogoutHandlers());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.springframework.beans.BeanMetadataElement;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanReference;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.ManagedList;
Expand Down Expand Up @@ -61,13 +62,17 @@ class LogoutBeanDefinitionParser implements BeanDefinitionParser {

private BeanMetadataElement logoutSuccessHandler;

LogoutBeanDefinitionParser(String loginPageUrl, String rememberMeServices, BeanMetadataElement csrfLogoutHandler) {
private BeanReference authenticationFilterSecurityContextHolderStrategyRef;

LogoutBeanDefinitionParser(String loginPageUrl, String rememberMeServices, BeanMetadataElement csrfLogoutHandler,
BeanReference authenticationFilterSecurityContextHolderStrategyRef) {
this.defaultLogoutUrl = loginPageUrl + "?logout";
this.rememberMeServices = rememberMeServices;
this.csrfEnabled = csrfLogoutHandler != null;
if (this.csrfEnabled) {
this.logoutHandlers.add(csrfLogoutHandler);
}
this.authenticationFilterSecurityContextHolderStrategyRef = authenticationFilterSecurityContextHolderStrategyRef;
}

@Override
Expand Down Expand Up @@ -123,6 +128,8 @@ public BeanDefinition parse(Element element, ParserContext pc) {
}
this.logoutHandlers.add(new RootBeanDefinition(LogoutSuccessEventPublishingLogoutHandler.class));
builder.addConstructorArgValue(this.logoutHandlers);
builder.addPropertyValue("securityContextHolderStrategy",
this.authenticationFilterSecurityContextHolderStrategyRef);
return builder.getBeanDefinition();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ http.attlist &=
attribute auto-config {xsd:boolean}?
http.attlist &=
use-expressions?
http.attlist &=
## A reference to a SecurityContextHolderStrategy bean. This can be used to customize how the SecurityContextHolder is stored during a request
attribute security-context-holder-strategy-ref {xsd:token}?
http.attlist &=
## Controls the eagerness with which an HTTP session is created by Spring Security classes. If not set, defaults to "ifRequired". If "stateless" is used, this implies that the application guarantees that it will not create a session. This differs from the use of "never" which means that Spring Security will not create a session, but will make use of one if the application does.
attribute create-session {"ifRequired" | "always" | "never" | "stateless"}?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,13 @@
</xs:documentation>
</xs:annotation>
</xs:attribute>
<xs:attribute name="security-context-holder-strategy-ref" type="xs:token">
<xs:annotation>
<xs:documentation>A reference to a SecurityContextHolderStrategy bean. This can be used to customize how the
SecurityContextHolder is stored during a request
</xs:documentation>
</xs:annotation>
</xs:attribute>
<xs:attribute name="create-session">
<xs:annotation>
<xs:documentation>Controls the eagerness with which an HTTP session is created by Spring Security classes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
Expand All @@ -45,6 +46,8 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.verify;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
Expand Down Expand Up @@ -147,6 +150,14 @@ public void authenticateWhenCustomUsernameAndPasswordParametersThenSucceeds() th
.andExpect(redirectedUrl("/"));
}

@Test
public void authenticateWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
this.spring.configLocations(this.xml("WithCustomSecurityContextHolderStrategy")).autowire();
SecurityContextHolderStrategy strategy = this.spring.getContext().getBean(SecurityContextHolderStrategy.class);
this.mvc.perform(post("/login").with(csrf())).andExpect(redirectedUrl("/login?error"));
verify(strategy, atLeastOnce()).getContext();
}

/**
* SEC-2919 - DefaultLoginGeneratingFilter incorrectly used if login-url="/login"
*/
Expand Down
Loading

0 comments on commit 2a70707

Please sign in to comment.