From 00f4033b9b486e2bbb8368dd53f3aae0a727f377 Mon Sep 17 00:00:00 2001 From: Marcus Da Coregio Date: Fri, 27 Aug 2021 11:38:22 -0300 Subject: [PATCH] Update DefaultWebInvocationPrivilegeEvaluator to use current ServletContext Closes gh-10208 --- .../security/web/FilterInvocation.java | 26 +++++++++++++++++-- ...efaultWebInvocationPrivilegeEvaluator.java | 16 +++++++++--- .../security/web/FilterInvocationTests.java | 13 +++++++++- ...tWebInvocationPrivilegeEvaluatorTests.java | 18 +++++++++++++ 4 files changed, 67 insertions(+), 6 deletions(-) diff --git a/web/src/main/java/org/springframework/security/web/FilterInvocation.java b/web/src/main/java/org/springframework/security/web/FilterInvocation.java index af2c135b6c3..2b848fb50e9 100644 --- a/web/src/main/java/org/springframework/security/web/FilterInvocation.java +++ b/web/src/main/java/org/springframework/security/web/FilterInvocation.java @@ -1,5 +1,5 @@ /* - * Copyright 2004, 2005, 2006 Acegi Technology Pty Limited + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ import java.util.Map; import javax.servlet.FilterChain; +import javax.servlet.ServletContext; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; @@ -78,10 +79,19 @@ public FilterInvocation(String servletPath, String method) { } public FilterInvocation(String contextPath, String servletPath, String method) { - this(contextPath, servletPath, null, null, method); + this(contextPath, servletPath, method, null); + } + + public FilterInvocation(String contextPath, String servletPath, String method, ServletContext servletContext) { + this(contextPath, servletPath, null, null, method, servletContext); } public FilterInvocation(String contextPath, String servletPath, String pathInfo, String query, String method) { + this(contextPath, servletPath, pathInfo, query, method, null); + } + + public FilterInvocation(String contextPath, String servletPath, String pathInfo, String query, String method, + ServletContext servletContext) { DummyRequest request = new DummyRequest(); contextPath = (contextPath != null) ? contextPath : "/cp"; request.setContextPath(contextPath); @@ -90,6 +100,7 @@ public FilterInvocation(String contextPath, String servletPath, String pathInfo, request.setPathInfo(pathInfo); request.setQueryString(query); request.setMethod(method); + request.setServletContext(servletContext); this.request = request; } @@ -160,6 +171,8 @@ static class DummyRequest extends HttpServletRequestWrapper { private String method; + private ServletContext servletContext; + private final HttpHeaders headers = new HttpHeaders(); private final Map parameters = new LinkedHashMap<>(); @@ -290,6 +303,15 @@ void setParameter(String name, String... values) { this.parameters.put(name, values); } + @Override + public ServletContext getServletContext() { + return this.servletContext; + } + + void setServletContext(ServletContext servletContext) { + this.servletContext = servletContext; + } + } static final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler { diff --git a/web/src/main/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluator.java b/web/src/main/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluator.java index 7030d29c464..0563636dd5b 100644 --- a/web/src/main/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluator.java +++ b/web/src/main/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluator.java @@ -1,5 +1,5 @@ /* - * Copyright 2004, 2005, 2006 Acegi Technology Pty Limited + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ import java.util.Collection; +import javax.servlet.ServletContext; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -28,6 +30,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.web.FilterInvocation; import org.springframework.util.Assert; +import org.springframework.web.context.ServletContextAware; /** * Allows users to determine whether they have privileges for a given web URI. @@ -36,12 +39,14 @@ * @author Luke Taylor * @since 3.0 */ -public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPrivilegeEvaluator { +public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPrivilegeEvaluator, ServletContextAware { protected static final Log logger = LogFactory.getLog(DefaultWebInvocationPrivilegeEvaluator.class); private final AbstractSecurityInterceptor securityInterceptor; + private ServletContext servletContext; + public DefaultWebInvocationPrivilegeEvaluator(AbstractSecurityInterceptor securityInterceptor) { Assert.notNull(securityInterceptor, "SecurityInterceptor cannot be null"); Assert.isTrue(FilterInvocation.class.equals(securityInterceptor.getSecureObjectClass()), @@ -82,7 +87,7 @@ public boolean isAllowed(String uri, Authentication authentication) { @Override public boolean isAllowed(String contextPath, String uri, String method, Authentication authentication) { Assert.notNull(uri, "uri parameter is required"); - FilterInvocation filterInvocation = new FilterInvocation(contextPath, uri, method); + FilterInvocation filterInvocation = new FilterInvocation(contextPath, uri, method, this.servletContext); Collection attributes = this.securityInterceptor.obtainSecurityMetadataSource() .getAttributes(filterInvocation); if (attributes == null) { @@ -101,4 +106,9 @@ public boolean isAllowed(String contextPath, String uri, String method, Authenti } } + @Override + public void setServletContext(ServletContext servletContext) { + this.servletContext = servletContext; + } + } diff --git a/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java b/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java index 5f1ceed22c3..dd71d927d4c 100644 --- a/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2004, 2005, 2006 Acegi Technology Pty Limited + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockServletContext; import org.springframework.security.web.FilterInvocation.DummyRequest; import org.springframework.security.web.util.UrlUtils; @@ -131,4 +132,14 @@ public void dummyRequestIsSupportedByUrlUtils() { UrlUtils.buildRequestUrl(request); } + @Test + public void constructorWhenServletContextProvidedThenSetServletContextInRequest() { + String contextPath = ""; + String servletPath = "/path"; + String method = ""; + MockServletContext mockServletContext = new MockServletContext(); + FilterInvocation filterInvocation = new FilterInvocation(contextPath, servletPath, method, mockServletContext); + assertThat(filterInvocation.getRequest().getServletContext()).isSameAs(mockServletContext); + } + } diff --git a/web/src/test/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluatorTests.java b/web/src/test/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluatorTests.java index 54dcc2b891d..414824fce3b 100644 --- a/web/src/test/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluatorTests.java +++ b/web/src/test/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluatorTests.java @@ -18,8 +18,10 @@ import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.springframework.context.ApplicationEventPublisher; +import org.springframework.mock.web.MockServletContext; import org.springframework.security.access.AccessDecisionManager; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.intercept.RunAsManager; @@ -27,6 +29,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.web.FilterInvocation; import org.springframework.security.web.access.intercept.FilterInvocationSecurityMetadataSource; import org.springframework.security.web.access.intercept.FilterSecurityInterceptor; @@ -34,9 +37,11 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyObject; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests @@ -106,4 +111,17 @@ public void deniesAccessIfAccessDecisionManagerDoes() { assertThat(wipe.isAllowed("/foo/index.jsp", token)).isFalse(); } + @Test + public void isAllowedWhenServletContextIsSetThenPassedFilterInvocationHasServletContext() { + Authentication token = new TestingAuthenticationToken("test", "Password", "MOCK_INDEX"); + MockServletContext servletContext = new MockServletContext(); + ArgumentCaptor filterInvocationArgumentCaptor = ArgumentCaptor + .forClass(FilterInvocation.class); + DefaultWebInvocationPrivilegeEvaluator wipe = new DefaultWebInvocationPrivilegeEvaluator(this.interceptor); + wipe.setServletContext(servletContext); + wipe.isAllowed("/foo/index.jsp", token); + verify(this.adm).decide(eq(token), filterInvocationArgumentCaptor.capture(), any()); + assertThat(filterInvocationArgumentCaptor.getValue().getRequest().getServletContext()).isNotNull(); + } + }