From 4ca54683ae9166f9f59f18111e111d26b2a8e901 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Thu, 16 Nov 2023 15:45:44 -0700 Subject: [PATCH] Defer requestMatchers Validation to Runtime Closes gh-13794 --- .../web/AbstractRequestMatcherRegistry.java | 87 +++++++++++++++---- .../AbstractRequestMatcherRegistryTests.java | 23 +++++ 2 files changed, 95 insertions(+), 15 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java index f2618aaa1e2..92c3d812936 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java @@ -22,6 +22,8 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; import javax.servlet.DispatcherType; import javax.servlet.ServletContext; @@ -42,6 +44,7 @@ import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.function.SingletonSupplier; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.handler.HandlerMappingIntrospector; @@ -315,34 +318,51 @@ public C requestMatchers(HttpMethod method, String... patterns) { if (servletContext == null) { return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); } + boolean isProgrammaticApiAvailable = isProgrammaticApiAvailable(servletContext); + List matchers = new ArrayList<>(); + for (String pattern : patterns) { + AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null); + MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0); + if (isProgrammaticApiAvailable) { + matchers.add(resolve(ant, mvc, servletContext)); + } + else { + matchers.add(new DeferredRequestMatcher(() -> resolve(ant, mvc, servletContext), mvc, ant)); + } + } + return requestMatchers(matchers.toArray(new RequestMatcher[0])); + } + + private static boolean isProgrammaticApiAvailable(ServletContext servletContext) { + try { + servletContext.getServletRegistrations(); + return true; + } + catch (UnsupportedOperationException ex) { + return false; + } + } + + private RequestMatcher resolve(AntPathRequestMatcher ant, MvcRequestMatcher mvc, ServletContext servletContext) { Map registrations = mappableServletRegistrations(servletContext); if (registrations.isEmpty()) { - return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); + return ant; } if (!hasDispatcherServlet(registrations)) { - return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); + return ant; } ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations); if (dispatcherServlet != null) { if (registrations.size() == 1) { - return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0])); + return mvc; } - List matchers = new ArrayList<>(); - for (String pattern : patterns) { - AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null); - MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0); - matchers.add(new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext)); - } - return requestMatchers(matchers.toArray(new RequestMatcher[0])); + return new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext); } dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations); if (dispatcherServlet != null) { String mapping = dispatcherServlet.getMappings().iterator().next(); - List matchers = createMvcMatchers(method, patterns); - for (MvcRequestMatcher matcher : matchers) { - matcher.setServletPath(mapping.substring(0, mapping.length() - 2)); - } - return requestMatchers(matchers.toArray(new RequestMatcher[0])); + mvc.setServletPath(mapping.substring(0, mapping.length() - 2)); + return mvc; } String errorMessage = computeErrorMessage(registrations.values()); throw new IllegalArgumentException(errorMessage); @@ -562,6 +582,38 @@ static List regexMatchers(String... regexPatterns) { } + static class DeferredRequestMatcher implements RequestMatcher { + + final Supplier requestMatcher; + + final AtomicReference description = new AtomicReference<>(); + + DeferredRequestMatcher(Supplier resolver, RequestMatcher... candidates) { + this.requestMatcher = SingletonSupplier.of(() -> { + RequestMatcher matcher = resolver.get(); + this.description.set(matcher.toString()); + return matcher; + }); + this.description.set("Deferred " + candidates); + } + + @Override + public boolean matches(HttpServletRequest request) { + return this.requestMatcher.get().matches(request); + } + + @Override + public MatchResult matcher(HttpServletRequest request) { + return this.requestMatcher.get().matcher(request); + } + + @Override + public String toString() { + return this.description.get(); + } + + } + static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher { private final AntPathRequestMatcher ant; @@ -611,6 +663,11 @@ private boolean isDispatcherServlet(ServletRegistration registration) { } } + @Override + public String toString() { + return "DispatcherServletDelegating [" + "ant = " + this.ant + ", mvc = " + this.mvc + "]"; + } + } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java index fb06997c4f0..53e21c1a3c3 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java @@ -18,6 +18,7 @@ import java.lang.reflect.Field; import java.lang.reflect.Modifier; +import java.util.ArrayList; import java.util.List; import javax.servlet.DispatcherType; @@ -202,6 +203,7 @@ public void requestMatchersWhenMvcPresentInClassPathAndMvcIntrospectorBeanNotAva @Test public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() { + mockMvcIntrospector(true); MockServletContext servletContext = new MockServletContext(); given(this.context.getServletContext()).willReturn(servletContext); servletContext.addServlet("servletOne", Servlet.class).addMapping("/one"); @@ -220,6 +222,7 @@ public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType( @Test public void requestMatchersWhenAmbiguousServletsThenException() { + mockMvcIntrospector(true); MockServletContext servletContext = new MockServletContext(); given(this.context.getServletContext()).willReturn(servletContext); servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/"); @@ -230,6 +233,7 @@ public void requestMatchersWhenAmbiguousServletsThenException() { @Test public void requestMatchersWhenMultipleDispatcherServletMappingsThenException() { + mockMvcIntrospector(true); MockServletContext servletContext = new MockServletContext(); given(this.context.getServletContext()).willReturn(servletContext); servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/", "/mvc/*"); @@ -239,6 +243,7 @@ public void requestMatchersWhenMultipleDispatcherServletMappingsThenException() @Test public void requestMatchersWhenPathDispatcherServletAndOtherServletsThenException() { + mockMvcIntrospector(true); MockServletContext servletContext = new MockServletContext(); given(this.context.getServletContext()).willReturn(servletContext); servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*"); @@ -366,11 +371,29 @@ public List mvcMatchers(HttpMethod method, String... mvcPatterns return null; } + @Override + public List requestMatchers(RequestMatcher... requestMatchers) { + return unwrap(super.requestMatchers(requestMatchers)); + } + @Override protected List chainRequestMatchers(List requestMatchers) { return requestMatchers; } + private static List unwrap(List wrappedMatchers) { + List requestMatchers = new ArrayList<>(); + for (RequestMatcher requestMatcher : wrappedMatchers) { + if (requestMatcher instanceof AbstractRequestMatcherRegistry.DeferredRequestMatcher) { + requestMatchers.add(((DeferredRequestMatcher) requestMatcher).requestMatcher.get()); + } + else { + requestMatchers.add(requestMatcher); + } + } + return requestMatchers; + } + } }