Skip to content

Commit

Permalink
Apply SecurityContextHolderFilter to all dispatcher types
Browse files Browse the repository at this point in the history
Closes gh-11962
  • Loading branch information
marcusdacoregio committed Dec 12, 2022
1 parent 88d50a5 commit 99d6d21
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.filter.GenericFilterBean;

/**
* A {@link javax.servlet.Filter} that uses the {@link SecurityContextRepository} to
Expand All @@ -40,17 +42,18 @@
* mechanisms to choose individually if authentication should be persisted.
*
* @author Rob Winch
* @author Marcus da Coregio
* @since 5.7
*/
public class SecurityContextHolderFilter extends OncePerRequestFilter {
public class SecurityContextHolderFilter extends GenericFilterBean {

private static final String FILTER_APPLIED = SecurityContextHolderFilter.class.getName() + ".APPLIED";

private final SecurityContextRepository securityContextRepository;

private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();

private boolean shouldNotFilterErrorDispatch;

/**
* Creates a new instance.
* @param securityContextRepository the repository to use. Cannot be null.
Expand All @@ -61,23 +64,29 @@ public SecurityContextHolderFilter(SecurityContextRepository securityContextRepo
}

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
}

private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws ServletException, IOException {
if (request.getAttribute(FILTER_APPLIED) != null) {
chain.doFilter(request, response);
return;
}
request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
Supplier<SecurityContext> deferredContext = this.securityContextRepository.loadDeferredContext(request);
try {
this.securityContextHolderStrategy.setDeferredContext(deferredContext);
filterChain.doFilter(request, response);
chain.doFilter(request, response);
}
finally {
this.securityContextHolderStrategy.clearContext();
request.removeAttribute(FILTER_APPLIED);
}
}

@Override
protected boolean shouldNotFilterErrorDispatch() {
return this.shouldNotFilterErrorDispatch;
}

/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
Expand All @@ -89,13 +98,4 @@ public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy secur
this.securityContextHolderStrategy = securityContextHolderStrategy;
}

/**
* Disables {@link SecurityContextHolderFilter} for error dispatch.
* @param shouldNotFilterErrorDispatch if the Filter should be disabled for error
* dispatch. Default is false.
*/
public void setShouldNotFilterErrorDispatch(boolean shouldNotFilterErrorDispatch) {
this.shouldNotFilterErrorDispatch = shouldNotFilterErrorDispatch;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.function.Supplier;

import javax.servlet.DispatcherType;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
Expand All @@ -26,11 +27,15 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import org.springframework.mock.web.MockFilterChain;
import org.springframework.security.authentication.TestAuthentication;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
Expand All @@ -40,11 +45,17 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;

@ExtendWith(MockitoExtension.class)
class SecurityContextHolderFilterTests {

private static final String FILTER_APPLIED = "org.springframework.security.web.context.SecurityContextHolderFilter.APPLIED";

@Mock
private SecurityContextRepository repository;

Expand Down Expand Up @@ -105,14 +116,38 @@ void doFilterThenSetsAndClearsSecurityContextHolderStrategy() throws Exception {
}

@Test
void shouldNotFilterErrorDispatchWhenDefault() {
assertThat(this.filter.shouldNotFilterErrorDispatch()).isFalse();
void doFilterWhenFilterAppliedThenDoNothing() throws Exception {
given(this.request.getAttribute(FILTER_APPLIED)).willReturn(true);
this.filter.doFilter(this.request, this.response, new MockFilterChain());
verify(this.request, times(1)).getAttribute(FILTER_APPLIED);
verifyNoInteractions(this.repository, this.response);
}

@Test
void shouldNotFilterErrorDispatchWhenOverridden() {
this.filter.setShouldNotFilterErrorDispatch(true);
assertThat(this.filter.shouldNotFilterErrorDispatch()).isTrue();
void doFilterWhenNotAppliedThenSetsAndRemovesAttribute() throws Exception {
given(this.repository.loadDeferredContext(this.requestArg.capture())).willReturn(
new SupplierDeferredSecurityContext(SecurityContextHolder::createEmptyContext, this.strategy));

this.filter.doFilter(this.request, this.response, new MockFilterChain());

InOrder inOrder = inOrder(this.request, this.repository);
inOrder.verify(this.request).setAttribute(FILTER_APPLIED, true);
inOrder.verify(this.repository).loadDeferredContext(this.request);
inOrder.verify(this.request).removeAttribute(FILTER_APPLIED);
}

@ParameterizedTest
@EnumSource(DispatcherType.class)
void doFilterWhenAnyDispatcherTypeThenFilter(DispatcherType dispatcherType) throws Exception {
lenient().when(this.request.getDispatcherType()).thenReturn(dispatcherType);
Authentication authentication = TestAuthentication.authenticatedUser();
SecurityContext expectedContext = new SecurityContextImpl(authentication);
given(this.repository.loadDeferredContext(this.requestArg.capture()))
.willReturn(new SupplierDeferredSecurityContext(() -> expectedContext, this.strategy));
FilterChain filterChain = (request, response) -> assertThat(SecurityContextHolder.getContext())
.isEqualTo(expectedContext);

this.filter.doFilter(this.request, this.response, filterChain);
}

}

0 comments on commit 99d6d21

Please sign in to comment.