Skip to content

Commit

Permalink
Defer CsrfFilter Session Access
Browse files Browse the repository at this point in the history
Closes gh-11456
  • Loading branch information
rwinch committed Aug 16, 2022
2 parents 002a770 + 5b64526 commit c1a6cea
Show file tree
Hide file tree
Showing 11 changed files with 185 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>

private SessionAuthenticationStrategy sessionAuthenticationStrategy;

private String csrfRequestAttributeName;

private final ApplicationContext context;

/**
Expand Down Expand Up @@ -124,6 +126,16 @@ public CsrfConfigurer<H> requireCsrfProtectionMatcher(RequestMatcher requireCsrf
return this;
}

/**
* Sets the {@link CsrfFilter#setCsrfRequestAttributeName(String)}
* @param csrfRequestAttributeName the attribute name to set the CsrfToken on.
* @return the {@link CsrfConfigurer} for further customizations.
*/
public CsrfConfigurer<H> csrfRequestAttributeName(String csrfRequestAttributeName) {
this.csrfRequestAttributeName = csrfRequestAttributeName;
return this;
}

/**
* <p>
* Allows specifying {@link HttpServletRequest} that should not use CSRF Protection
Expand Down Expand Up @@ -202,6 +214,9 @@ public CsrfConfigurer<H> sessionAuthenticationStrategy(
@Override
public void configure(H http) {
CsrfFilter filter = new CsrfFilter(this.csrfTokenRepository);
if (this.csrfRequestAttributeName != null) {
filter.setCsrfRequestAttributeName(this.csrfRequestAttributeName);
}
RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher();
if (requireCsrfProtectionMatcher != null) {
filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,14 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {

private static final String DISPATCHER_SERVLET_CLASS_NAME = "org.springframework.web.servlet.DispatcherServlet";

private static final String ATT_REQUEST_ATTRIBUTE_NAME = "request-attribute-name";

private static final String ATT_MATCHER = "request-matcher-ref";

private static final String ATT_REPOSITORY = "token-repository-ref";

private String requestAttributeName;

private String csrfRepositoryRef;

private BeanDefinition csrfFilter;
Expand All @@ -94,6 +98,7 @@ public BeanDefinition parse(Element element, ParserContext pc) {
}
if (element != null) {
this.csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY);
this.requestAttributeName = element.getAttribute(ATT_REQUEST_ATTRIBUTE_NAME);
this.requestMatcherRef = element.getAttribute(ATT_MATCHER);
}
if (!StringUtils.hasText(this.csrfRepositoryRef)) {
Expand All @@ -110,6 +115,9 @@ public BeanDefinition parse(Element element, ParserContext pc) {
if (StringUtils.hasText(this.requestMatcherRef)) {
builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef);
}
if (StringUtils.hasText(this.requestAttributeName)) {
builder.addPropertyValue("csrfRequestAttributeName", this.requestAttributeName);
}
this.csrfFilter = builder.getBeanDefinition();
return this.csrfFilter;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,9 @@ csrf =
csrf-options.attlist &=
## Specifies if csrf protection should be disabled. Default false (i.e. CSRF protection is enabled).
attribute disabled {xsd:boolean}?
csrf-options.attlist &=
## The request attribute name the CsrfToken is set on. Default is to set to CsrfToken.parameterName
attribute request-attribute-name { xsd:token }?
csrf-options.attlist &=
## The RequestMatcher instance to be used to determine if CSRF should be applied. Default is any HTTP method except "GET", "TRACE", "HEAD", "OPTIONS"
attribute request-matcher-ref { xsd:token }?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3217,6 +3217,13 @@
</xs:documentation>
</xs:annotation>
</xs:attribute>
<xs:attribute name="request-attribute-name" type="xs:token">
<xs:annotation>
<xs:documentation>The request attribute name the CsrfToken is set on. Default is to set to
CsrfToken.parameterName
</xs:documentation>
</xs:annotation>
</xs:attribute>
<xs:attribute name="request-matcher-ref" type="xs:token">
<xs:annotation>
<xs:documentation>The RequestMatcher instance to be used to determine if CSRF should be applied. Default is
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,15 @@ public void postWhenUsingCsrfAndCustomAccessDeniedHandlerThenTheHandlerIsAppropr
// @formatter:on
}

@Test
public void getWhenUsingCsrfAndCustomRequestAttributeThenSetUsingCsrfAttrName() throws Exception {
this.spring.configLocations(this.xml("WithRequestAttrName")).autowire();
// @formatter:off
MvcResult result = this.mvc.perform(get("/ok")).andReturn();
assertThat(result.getRequest().getAttribute("csrf-attribute-name")).isInstanceOf(CsrfToken.class);
// @formatter:on
}

@Test
public void postWhenHasCsrfTokenButSessionExpiresThenRequestIsCancelledAfterSuccessfulAuthentication()
throws Exception {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Copyright 2002-2018 the original author or authors.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ https://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->

<b:beans xmlns:b="http://www.springframework.org/schema/beans"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://www.springframework.org/schema/security"
xsi:schemaLocation="http://www.springframework.org/schema/security https://www.springframework.org/schema/security/spring-security.xsd
http://www.springframework.org/schema/beans https://www.springframework.org/schema/beans/spring-beans.xsd">

<http auto-config="true">
<csrf request-attribute-name="csrf-attribute-name"/>
</http>

<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
</b:beans>
4 changes: 4 additions & 0 deletions docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,10 @@ It is highly recommended to leave CSRF protection enabled.
The CsrfTokenRepository to use.
The default is `HttpSessionCsrfTokenRepository`.

[[nsa-csrf-request-attribute-name]]
* **request-attribute-name**
Optional attribute that specifies the request attribute name to set the `CsrfToken` on.
The default is `CsrfToken.parameterName`.

[[nsa-csrf-request-matcher-ref]]
* **request-matcher-ref**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ public final class CsrfFilter extends OncePerRequestFilter {

private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl();

private String csrfRequestAttributeName;

public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
this.tokenRepository = csrfTokenRepository;
Expand All @@ -108,7 +110,9 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
this.tokenRepository.saveToken(csrfToken, request, response);
}
request.setAttribute(CsrfToken.class.getName(), csrfToken);
request.setAttribute(csrfToken.getParameterName(), csrfToken);
String csrfAttrName = (this.csrfRequestAttributeName != null) ? this.csrfRequestAttributeName
: csrfToken.getParameterName();
request.setAttribute(csrfAttrName, csrfToken);
if (!this.requireCsrfProtectionMatcher.matches(request)) {
if (this.logger.isTraceEnabled()) {
this.logger.trace("Did not protect against CSRF since request did not match "
Expand Down Expand Up @@ -167,6 +171,18 @@ public void setAccessDeniedHandler(AccessDeniedHandler accessDeniedHandler) {
this.accessDeniedHandler = accessDeniedHandler;
}

/**
* The {@link CsrfToken} is available as a request attribute named
* {@code CsrfToken.class.getName()}. By default, an additional request attribute that
* is the same as {@link CsrfToken#getParameterName()} is set. This attribute allows
* overriding the additional attribute.
* @param csrfRequestAttributeName the name of an additional request attribute with
* the value of the CsrfToken. Default is {@link CsrfToken#getParameterName()}
*/
public void setCsrfRequestAttributeName(String csrfRequestAttributeName) {
this.csrfRequestAttributeName = csrfRequestAttributeName;
}

/**
* Constant time comparison to prevent against timing attacks.
* @param expected
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository {

private final CsrfTokenRepository delegate;

private boolean deferLoadToken;

/**
* Creates a new instance
* @param delegate the {@link CsrfTokenRepository} to use. Cannot be null
Expand All @@ -48,6 +50,15 @@ public LazyCsrfTokenRepository(CsrfTokenRepository delegate) {
this.delegate = delegate;
}

/**
* Determines if {@link #loadToken(HttpServletRequest)} should be lazily loaded.
* @param deferLoadToken true if should lazily load
* {@link #loadToken(HttpServletRequest)}. Default false.
*/
public void setDeferLoadToken(boolean deferLoadToken) {
this.deferLoadToken = deferLoadToken;
}

/**
* Generates a new token
* @param request the {@link HttpServletRequest} to use. The
Expand Down Expand Up @@ -77,6 +88,9 @@ public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletRe
*/
@Override
public CsrfToken loadToken(HttpServletRequest request) {
if (this.deferLoadToken) {
return new LazyLoadCsrfToken(request, this.delegate);
}
return this.delegate.loadToken(request);
}

Expand All @@ -92,6 +106,55 @@ private HttpServletResponse getResponse(HttpServletRequest request) {
return response;
}

private final class LazyLoadCsrfToken implements CsrfToken {

private final HttpServletRequest request;

private final CsrfTokenRepository tokenRepository;

private CsrfToken token;

private LazyLoadCsrfToken(HttpServletRequest request, CsrfTokenRepository tokenRepository) {
this.request = request;
this.tokenRepository = tokenRepository;
}

private CsrfToken getDelegate() {
if (this.token != null) {
return this.token;
}
// load from the delegate repository
this.token = LazyCsrfTokenRepository.this.delegate.loadToken(this.request);
if (this.token == null) {
// return a generated token that is lazily saved since
// LazyCsrfTokenRepository#loadToken always returns a value
this.token = generateToken(this.request);
}
return this.token;
}

@Override
public String getHeaderName() {
return getDelegate().getHeaderName();
}

@Override
public String getParameterName() {
return getDelegate().getParameterName();
}

@Override
public String getToken() {
return getDelegate().getToken();
}

@Override
public String toString() {
return "LazyLoadCsrfToken{" + "token=" + this.token + '}';
}

}

private static final class SaveOnAccessCsrfToken implements CsrfToken {

private transient CsrfTokenRepository tokenRepository;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;

/**
Expand Down Expand Up @@ -344,6 +345,23 @@ public void setAccessDeniedHandlerNull() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAccessDeniedHandler(null));
}

// This ensures that the HttpSession on get requests unless the CsrfToken is used
@Test
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet()
throws ServletException, IOException {
CsrfFilter filter = createCsrfFilter(this.tokenRepository);
String csrfAttrName = "_csrf";
filter.setCsrfRequestAttributeName(csrfAttrName);
CsrfToken expectedCsrfToken = mock(CsrfToken.class);
given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken);

filter.doFilter(this.request, this.response, this.filterChain);

verifyNoInteractions(expectedCsrfToken);
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
assertThat(tokenFromRequest).isEqualTo(expectedCsrfToken);
}

private static CsrfTokenAssert assertToken(Object token) {
return new CsrfTokenAssert((CsrfToken) token);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;

/**
Expand Down Expand Up @@ -98,4 +99,15 @@ public void loadTokenDelegates() {
verify(this.delegate).loadToken(this.request);
}

@Test
public void loadTokenWhenDeferLoadToken() {
given(this.delegate.loadToken(this.request)).willReturn(this.token);
this.repository.setDeferLoadToken(true);
CsrfToken loadToken = this.repository.loadToken(this.request);
verifyNoInteractions(this.delegate);
assertThat(loadToken.getToken()).isEqualTo(this.token.getToken());
assertThat(loadToken.getHeaderName()).isEqualTo(this.token.getHeaderName());
assertThat(loadToken.getParameterName()).isEqualTo(this.token.getParameterName());
}

}

0 comments on commit c1a6cea

Please sign in to comment.