Skip to content

Commit

Permalink
Add SecurityContextHolderFilter
Browse files Browse the repository at this point in the history
Closes gh-9635
  • Loading branch information
rwinch committed Mar 12, 2022
1 parent f9619ce commit 972039e
Show file tree
Hide file tree
Showing 21 changed files with 571 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.springframework.security.web.authentication.ui.DefaultLogoutPageGeneratingFilter;
import org.springframework.security.web.authentication.www.BasicAuthenticationFilter;
import org.springframework.security.web.authentication.www.DigestAuthenticationFilter;
import org.springframework.security.web.context.SecurityContextHolderFilter;
import org.springframework.security.web.context.SecurityContextPersistenceFilter;
import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter;
import org.springframework.security.web.csrf.CsrfFilter;
Expand Down Expand Up @@ -70,6 +71,7 @@ final class FilterOrderRegistration {
put(ChannelProcessingFilter.class, order.next());
order.next(); // gh-8105
put(WebAsyncManagerIntegrationFilter.class, order.next());
put(SecurityContextHolderFilter.class, order.next());
put(SecurityContextPersistenceFilter.class, order.next());
put(HeaderWriterFilter.class, order.next());
put(CorsFilter.class, order.next());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.util.matcher.AndRequestMatcher;
import org.springframework.security.web.util.matcher.MediaTypeRequestMatcher;
Expand Down Expand Up @@ -144,6 +145,11 @@ public T loginProcessingUrl(String loginProcessingUrl) {
return getSelf();
}

public T securityContextRepository(SecurityContextRepository securityContextRepository) {
this.authFilter.setSecurityContextRepository(securityContextRepository);
return getSelf();
}

/**
* Create the {@link RequestMatcher} given a loginProcessingUrl
* @param loginProcessingUrl creates the {@link RequestMatcher} based upon the
Expand Down Expand Up @@ -285,6 +291,12 @@ public void configure(B http) throws Exception {
if (rememberMeServices != null) {
this.authFilter.setRememberMeServices(rememberMeServices);
}
SecurityContextConfigurer securityContextConfigurer = http.getConfigurer(SecurityContextConfigurer.class);
if (securityContextConfigurer != null && securityContextConfigurer.isRequireExplicitSave()) {
SecurityContextRepository securityContextRepository = securityContextConfigurer
.getSecurityContextRepository();
this.authFilter.setSecurityContextRepository(securityContextRepository);
}
F filter = postProcess(this.authFilter);
http.addFilter(filter);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextHolderFilter;
import org.springframework.security.web.context.SecurityContextPersistenceFilter;
import org.springframework.security.web.context.SecurityContextRepository;

Expand Down Expand Up @@ -62,6 +63,8 @@
public final class SecurityContextConfigurer<H extends HttpSecurityBuilder<H>>
extends AbstractHttpConfigurer<SecurityContextConfigurer<H>, H> {

private boolean requireExplicitSave;

/**
* Creates a new instance
* @see HttpSecurity#securityContext()
Expand All @@ -79,23 +82,45 @@ public SecurityContextConfigurer<H> securityContextRepository(SecurityContextRep
return this;
}

public SecurityContextConfigurer<H> requireExplicitSave(boolean requireExplicitSave) {
this.requireExplicitSave = requireExplicitSave;
return this;
}

boolean isRequireExplicitSave() {
return this.requireExplicitSave;
}

SecurityContextRepository getSecurityContextRepository() {
SecurityContextRepository securityContextRepository = getBuilder()
.getSharedObject(SecurityContextRepository.class);
if (securityContextRepository == null) {
securityContextRepository = new HttpSessionSecurityContextRepository();
}
return securityContextRepository;
}

@Override
@SuppressWarnings("unchecked")
public void configure(H http) {
SecurityContextRepository securityContextRepository = http.getSharedObject(SecurityContextRepository.class);
if (securityContextRepository == null) {
securityContextRepository = new HttpSessionSecurityContextRepository();
SecurityContextRepository securityContextRepository = getSecurityContextRepository();
if (this.requireExplicitSave) {
SecurityContextHolderFilter securityContextHolderFilter = postProcess(
new SecurityContextHolderFilter(securityContextRepository));
http.addFilter(securityContextHolderFilter);
}
SecurityContextPersistenceFilter securityContextFilter = new SecurityContextPersistenceFilter(
securityContextRepository);
SessionManagementConfigurer<?> sessionManagement = http.getConfigurer(SessionManagementConfigurer.class);
SessionCreationPolicy sessionCreationPolicy = (sessionManagement != null)
? sessionManagement.getSessionCreationPolicy() : null;
if (SessionCreationPolicy.ALWAYS == sessionCreationPolicy) {
securityContextFilter.setForceEagerSessionCreation(true);
else {
SecurityContextPersistenceFilter securityContextFilter = new SecurityContextPersistenceFilter(
securityContextRepository);
SessionManagementConfigurer<?> sessionManagement = http.getConfigurer(SessionManagementConfigurer.class);
SessionCreationPolicy sessionCreationPolicy = (sessionManagement != null)
? sessionManagement.getSessionCreationPolicy() : null;
if (SessionCreationPolicy.ALWAYS == sessionCreationPolicy) {
securityContextFilter.setForceEagerSessionCreation(true);
}
securityContextFilter = postProcess(securityContextFilter);
http.addFilter(securityContextFilter);
}
securityContextFilter = postProcess(securityContextFilter);
http.addFilter(securityContextFilter);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ final class AuthenticationConfigBuilder {

AuthenticationConfigBuilder(Element element, boolean forceAutoConfig, ParserContext pc,
SessionCreationPolicy sessionPolicy, BeanReference requestCache, BeanReference authenticationManager,
BeanReference sessionStrategy, BeanReference portMapper, BeanReference portResolver,
BeanMetadataElement csrfLogoutHandler) {
BeanReference authenticationFilterSecurityContextRepositoryRef, BeanReference sessionStrategy,
BeanReference portMapper, BeanReference portResolver, BeanMetadataElement csrfLogoutHandler) {
this.httpElt = element;
this.pc = pc;
this.requestCache = requestCache;
Expand All @@ -231,9 +231,10 @@ final class AuthenticationConfigBuilder {
createRememberMeFilter(authenticationManager);
createBasicFilter(authenticationManager);
createBearerTokenAuthenticationFilter(authenticationManager);
createFormLoginFilter(sessionStrategy, authenticationManager);
createOAuth2ClientFilters(sessionStrategy, requestCache, authenticationManager);
createSaml2LoginFilter(authenticationManager);
createFormLoginFilter(sessionStrategy, authenticationManager, authenticationFilterSecurityContextRepositoryRef);
createOAuth2ClientFilters(sessionStrategy, requestCache, authenticationManager,
authenticationFilterSecurityContextRepositoryRef);
createSaml2LoginFilter(authenticationManager, authenticationFilterSecurityContextRepositoryRef);
createX509Filter(authenticationManager);
createJeeFilter(authenticationManager);
createLogoutFilter();
Expand Down Expand Up @@ -269,7 +270,8 @@ private void createRememberMeProvider(String key) {
this.rememberMeProviderRef = new RuntimeBeanReference(id);
}

void createFormLoginFilter(BeanReference sessionStrategy, BeanReference authManager) {
void createFormLoginFilter(BeanReference sessionStrategy, BeanReference authManager,
BeanReference authenticationFilterSecurityContextRepositoryRef) {
Element formLoginElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.FORM_LOGIN);
RootBeanDefinition formFilter = null;
if (formLoginElt != null || this.autoConfig) {
Expand All @@ -285,6 +287,10 @@ void createFormLoginFilter(BeanReference sessionStrategy, BeanReference authMana
if (formFilter != null) {
formFilter.getPropertyValues().addPropertyValue("allowSessionCreation", this.allowSessionCreation);
formFilter.getPropertyValues().addPropertyValue("authenticationManager", authManager);
if (authenticationFilterSecurityContextRepositoryRef != null) {
formFilter.getPropertyValues().addPropertyValue("securityContextRepository",
authenticationFilterSecurityContextRepositoryRef);
}
// Id is required by login page filter
this.formFilterId = this.pc.getReaderContext().generateBeanName(formFilter);
this.pc.registerBeanComponent(new BeanComponentDefinition(formFilter, this.formFilterId));
Expand All @@ -293,13 +299,15 @@ void createFormLoginFilter(BeanReference sessionStrategy, BeanReference authMana
}

void createOAuth2ClientFilters(BeanReference sessionStrategy, BeanReference requestCache,
BeanReference authenticationManager) {
createOAuth2LoginFilter(sessionStrategy, authenticationManager);
createOAuth2ClientFilter(requestCache, authenticationManager);
BeanReference authenticationManager, BeanReference authenticationFilterSecurityContextRepositoryRef) {
createOAuth2LoginFilter(sessionStrategy, authenticationManager,
authenticationFilterSecurityContextRepositoryRef);
createOAuth2ClientFilter(requestCache, authenticationManager, authenticationFilterSecurityContextRepositoryRef);
registerOAuth2ClientPostProcessors();
}

void createOAuth2LoginFilter(BeanReference sessionStrategy, BeanReference authManager) {
void createOAuth2LoginFilter(BeanReference sessionStrategy, BeanReference authManager,
BeanReference authenticationFilterSecurityContextRepositoryRef) {
Element oauth2LoginElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.OAUTH2_LOGIN);
if (oauth2LoginElt == null) {
return;
Expand All @@ -311,6 +319,10 @@ void createOAuth2LoginFilter(BeanReference sessionStrategy, BeanReference authMa
BeanDefinition defaultAuthorizedClientRepository = parser.getDefaultAuthorizedClientRepository();
registerDefaultAuthorizedClientRepositoryIfNecessary(defaultAuthorizedClientRepository);
oauth2LoginFilterBean.getPropertyValues().addPropertyValue("authenticationManager", authManager);
if (authenticationFilterSecurityContextRepositoryRef != null) {
oauth2LoginFilterBean.getPropertyValues().addPropertyValue("securityContextRepository",
authenticationFilterSecurityContextRepositoryRef);
}

// retrieve the other bean result
BeanDefinition oauth2LoginAuthProvider = parser.getOAuth2LoginAuthenticationProvider();
Expand Down Expand Up @@ -340,14 +352,15 @@ void createOAuth2LoginFilter(BeanReference sessionStrategy, BeanReference authMa
this.oauth2LoginOidcAuthenticationProviderRef = new RuntimeBeanReference(oauth2LoginOidcAuthProviderId);
}

void createOAuth2ClientFilter(BeanReference requestCache, BeanReference authenticationManager) {
void createOAuth2ClientFilter(BeanReference requestCache, BeanReference authenticationManager,
BeanReference authenticationFilterSecurityContextRepositoryRef) {
Element oauth2ClientElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.OAUTH2_CLIENT);
if (oauth2ClientElt == null) {
return;
}
this.oauth2ClientEnabled = true;
OAuth2ClientBeanDefinitionParser parser = new OAuth2ClientBeanDefinitionParser(requestCache,
authenticationManager);
authenticationManager, authenticationFilterSecurityContextRepositoryRef);
parser.parse(oauth2ClientElt, this.pc);
BeanDefinition defaultAuthorizedClientRepository = parser.getDefaultAuthorizedClientRepository();
registerDefaultAuthorizedClientRepositoryIfNecessary(defaultAuthorizedClientRepository);
Expand Down Expand Up @@ -392,14 +405,16 @@ private void registerOAuth2ClientPostProcessors() {
}
}

private void createSaml2LoginFilter(BeanReference authenticationManager) {
private void createSaml2LoginFilter(BeanReference authenticationManager,
BeanReference authenticationFilterSecurityContextRepositoryRef) {
Element saml2LoginElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.SAML2_LOGIN);
if (saml2LoginElt == null) {
return;
}
Saml2LoginBeanDefinitionParser parser = new Saml2LoginBeanDefinitionParser(this.csrfIgnoreRequestMatchers,
this.portMapper, this.portResolver, this.requestCache, this.allowSessionCreation, authenticationManager,
this.authenticationProviders, this.defaultEntryPointMappings);
authenticationFilterSecurityContextRepositoryRef, this.authenticationProviders,
this.defaultEntryPointMappings);
BeanDefinition saml2WebSsoAuthenticationFilter = parser.parse(saml2LoginElt, this.pc);
this.saml2AuthorizationRequestFilter = parser.getSaml2WebSsoAuthenticationRequestFilter();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.springframework.security.web.authentication.session.SessionFixationProtectionStrategy;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.NullSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextHolderFilter;
import org.springframework.security.web.context.SecurityContextPersistenceFilter;
import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter;
import org.springframework.security.web.jaasapi.JaasApiIntegrationFilter;
Expand Down Expand Up @@ -104,6 +105,8 @@ class HttpConfigurationBuilder {

private static final String ATT_SECURITY_CONTEXT_REPOSITORY = "security-context-repository-ref";

private static final String ATT_SECURITY_CONTEXT_EXPLICIT_SAVE = "security-context-explicit-save";

private static final String ATT_INVALID_SESSION_STRATEGY_REF = "invalid-session-strategy-ref";

private static final String ATT_DISABLE_URL_REWRITING = "disable-url-rewriting";
Expand Down Expand Up @@ -202,8 +205,7 @@ class HttpConfigurationBuilder {
this.sessionPolicy = !StringUtils.hasText(createSession) ? SessionCreationPolicy.IF_REQUIRED
: createPolicy(createSession);
createCsrfFilter();
createSecurityContextRepository();
createSecurityContextPersistenceFilter();
createSecurityPersistence();
createSessionManagementFilters();
createWebAsyncManagerFilter();
createRequestCacheFilter();
Expand Down Expand Up @@ -279,9 +281,27 @@ static String createPath(String path, boolean lowerCase) {
return lowerCase ? path.toLowerCase() : path;
}

BeanReference getSecurityContextRepositoryForAuthenticationFilters() {
return (isExplicitSave()) ? this.contextRepoRef : null;
}

private void createSecurityPersistence() {
createSecurityContextRepository();
if (isExplicitSave()) {
createSecurityContextHolderFilter();
}
else {
createSecurityContextPersistenceFilter();
}
}

private boolean isExplicitSave() {
String explicitSaveAttr = this.httpElt.getAttribute(ATT_SECURITY_CONTEXT_EXPLICIT_SAVE);
return Boolean.parseBoolean(explicitSaveAttr);
}

private void createSecurityContextPersistenceFilter() {
BeanDefinitionBuilder scpf = BeanDefinitionBuilder.rootBeanDefinition(SecurityContextPersistenceFilter.class);
String disableUrlRewriting = this.httpElt.getAttribute(ATT_DISABLE_URL_REWRITING);
switch (this.sessionPolicy) {
case ALWAYS:
scpf.addPropertyValue("forceEagerSessionCreation", Boolean.TRUE);
Expand Down Expand Up @@ -332,6 +352,12 @@ private void createSecurityContextRepository() {
this.contextRepoRef = new RuntimeBeanReference(repoRef);
}

private void createSecurityContextHolderFilter() {
BeanDefinitionBuilder filter = BeanDefinitionBuilder.rootBeanDefinition(SecurityContextHolderFilter.class);
filter.addConstructorArgValue(this.contextRepoRef);
this.securityContextPersistenceFilter = filter.getBeanDefinition();
}

private void createSessionManagementFilters() {
Element sessionMgmtElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.SESSION_MANAGEMENT);
Element sessionCtrlElt = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,11 @@ private BeanReference createFilterChain(Element element, ParserContext pc) {
boolean forceAutoConfig = isDefaultHttpConfig(element);
HttpConfigurationBuilder httpBldr = new HttpConfigurationBuilder(element, forceAutoConfig, pc, portMapper,
portResolver, authenticationManager);
httpBldr.getSecurityContextRepositoryForAuthenticationFilters();
AuthenticationConfigBuilder authBldr = new AuthenticationConfigBuilder(element, forceAutoConfig, pc,
httpBldr.getSessionCreationPolicy(), httpBldr.getRequestCache(), authenticationManager,
httpBldr.getSessionStrategy(), portMapper, portResolver, httpBldr.getCsrfLogoutHandler());
httpBldr.getSecurityContextRepositoryForAuthenticationFilters(), httpBldr.getSessionStrategy(),
portMapper, portResolver, httpBldr.getCsrfLogoutHandler());
httpBldr.setLogoutHandlers(authBldr.getLogoutHandlers());
httpBldr.setEntryPoint(authBldr.getEntryPointBean());
httpBldr.setAccessDeniedHandler(authBldr.getAccessDeniedHandlerBean());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser {

private final BeanReference authenticationManager;

private final BeanReference authenticationFilterSecurityContextRepositoryRef;

private BeanDefinition defaultAuthorizedClientRepository;

private BeanDefinition authorizationRequestRedirectFilter;
Expand All @@ -58,9 +60,11 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser {

private BeanDefinition authorizationCodeAuthenticationProvider;

OAuth2ClientBeanDefinitionParser(BeanReference requestCache, BeanReference authenticationManager) {
OAuth2ClientBeanDefinitionParser(BeanReference requestCache, BeanReference authenticationManager,
BeanReference authenticationFilterSecurityContextRepositoryRef) {
this.requestCache = requestCache;
this.authenticationManager = authenticationManager;
this.authenticationFilterSecurityContextRepositoryRef = authenticationFilterSecurityContextRepositoryRef;
}

@Override
Expand Down Expand Up @@ -92,11 +96,16 @@ public BeanDefinition parse(Element element, ParserContext parserContext) {
this.authorizationRequestRedirectFilter = authorizationRequestRedirectFilterBuilder
.addPropertyValue("authorizationRequestRepository", authorizationRequestRepository)
.addPropertyValue("requestCache", this.requestCache).getBeanDefinition();
this.authorizationCodeGrantFilter = BeanDefinitionBuilder
BeanDefinitionBuilder authorizationCodeGrantFilterBldr = BeanDefinitionBuilder
.rootBeanDefinition(OAuth2AuthorizationCodeGrantFilter.class)
.addConstructorArgValue(clientRegistrationRepository).addConstructorArgValue(authorizedClientRepository)
.addConstructorArgValue(this.authenticationManager)
.addPropertyValue("authorizationRequestRepository", authorizationRequestRepository).getBeanDefinition();
.addPropertyValue("authorizationRequestRepository", authorizationRequestRepository);
if (this.authenticationFilterSecurityContextRepositoryRef != null) {
authorizationCodeGrantFilterBldr.addPropertyValue("securityContextRepository",
this.authenticationFilterSecurityContextRepositoryRef);
}
this.authorizationCodeGrantFilter = authorizationCodeGrantFilterBldr.getBeanDefinition();

BeanMetadataElement accessTokenResponseClient = getAccessTokenResponseClient(authorizationCodeGrantElt);
this.authorizationCodeAuthenticationProvider = BeanDefinitionBuilder
Expand Down
Loading

0 comments on commit 972039e

Please sign in to comment.