Skip to content

Commit

Permalink
Refine requestMatcher Validation Rules
Browse files Browse the repository at this point in the history
Closes gh-13850
  • Loading branch information
jzheaux committed Oct 12, 2023
1 parent 914ebd6 commit 5a6a1bf
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,6 +26,7 @@
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletRegistration;
import jakarta.servlet.http.HttpServletRequest;

import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext;
Expand Down Expand Up @@ -203,11 +204,30 @@ public C requestMatchers(HttpMethod method, String... patterns) {
if (!hasDispatcherServlet(registrations)) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
if (registrations.size() > 1) {
String errorMessage = computeErrorMessage(registrations.values());
throw new IllegalArgumentException(errorMessage);
ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations);
if (dispatcherServlet != null) {
if (registrations.size() == 1) {
return requestMatchers(createMvcMatchers(method, patterns).toArray(RequestMatcher[]::new));
}
List<RequestMatcher> matchers = new ArrayList<>();
for (String pattern : patterns) {
AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null);
MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0);
matchers.add(new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext));
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
}
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations);
if (dispatcherServlet != null) {
String mapping = dispatcherServlet.getMappings().iterator().next();
List<MvcRequestMatcher> matchers = createMvcMatchers(method, patterns);
for (MvcRequestMatcher matcher : matchers) {
matcher.setServletPath(mapping.substring(0, mapping.length() - 2));
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
}
String errorMessage = computeErrorMessage(registrations.values());
throw new IllegalArgumentException(errorMessage);
}

private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) {
Expand All @@ -225,22 +245,66 @@ private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration>
if (registrations == null) {
return false;
}
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
null);
for (ServletRegistration registration : registrations.values()) {
try {
Class<?> clazz = Class.forName(registration.getClassName());
if (dispatcherServlet.isAssignableFrom(clazz)) {
return true;
}
}
catch (ClassNotFoundException ex) {
return false;
if (isDispatcherServlet(registration)) {
return true;
}
}
return false;
}

private ServletRegistration requireOneRootDispatcherServlet(
Map<String, ? extends ServletRegistration> registrations) {
ServletRegistration rootDispatcherServlet = null;
for (ServletRegistration registration : registrations.values()) {
if (!isDispatcherServlet(registration)) {
continue;
}
if (registration.getMappings().size() > 1) {
return null;
}
if (!"/".equals(registration.getMappings().iterator().next())) {
return null;
}
rootDispatcherServlet = registration;
}
return rootDispatcherServlet;
}

private ServletRegistration requireOnlyPathMappedDispatcherServlet(
Map<String, ? extends ServletRegistration> registrations) {
ServletRegistration pathDispatcherServlet = null;
for (ServletRegistration registration : registrations.values()) {
if (!isDispatcherServlet(registration)) {
return null;
}
if (registration.getMappings().size() > 1) {
return null;
}
String mapping = registration.getMappings().iterator().next();
if (!mapping.startsWith("/") || !mapping.endsWith("/*")) {
return null;
}
if (pathDispatcherServlet != null) {
return null;
}
pathDispatcherServlet = registration;
}
return pathDispatcherServlet;
}

private boolean isDispatcherServlet(ServletRegistration registration) {
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
null);
try {
Class<?> clazz = Class.forName(registration.getClassName());
return dispatcherServlet.isAssignableFrom(clazz);
}
catch (ClassNotFoundException ex) {
return false;
}
}

private String computeErrorMessage(Collection<? extends ServletRegistration> registrations) {
String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
+ "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); "
Expand Down Expand Up @@ -380,4 +444,55 @@ static List<RequestMatcher> regexMatchers(String... regexPatterns) {

}

static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {

private final AntPathRequestMatcher ant;

private final MvcRequestMatcher mvc;

private final ServletContext servletContext;

DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc,
ServletContext servletContext) {
this.ant = ant;
this.mvc = mvc;
this.servletContext = servletContext;
}

@Override
public boolean matches(HttpServletRequest request) {
String name = request.getHttpServletMapping().getServletName();
ServletRegistration registration = this.servletContext.getServletRegistration(name);
Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
if (isDispatcherServlet(registration)) {
return this.mvc.matches(request);
}
return this.ant.matches(request);
}

@Override
public MatchResult matcher(HttpServletRequest request) {
String name = request.getHttpServletMapping().getServletName();
ServletRegistration registration = this.servletContext.getServletRegistration(name);
Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
if (isDispatcherServlet(registration)) {
return this.mvc.matcher(request);
}
return this.ant.matcher(request);
}

private boolean isDispatcherServlet(ServletRegistration registration) {
Class<?> dispatcherServlet = ClassUtils
.resolveClassName("org.springframework.web.servlet.DispatcherServlet", null);
try {
Class<?> clazz = Class.forName(registration.getClassName());
return dispatcherServlet.isAssignableFrom(clazz);
}
catch (ClassNotFoundException ex) {
return false;
}
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ public ServletRegistration.Dynamic addServlet(@NonNull String servletName, Class
return this.registrations;
}

@Override
public ServletRegistration getServletRegistration(String servletName) {
return this.registrations.get(servletName);
}

private static class MockServletRegistration implements ServletRegistration.Dynamic {

private final String name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,32 @@
* limitations under the License.
*/

package org.springframework.security.config.annotation.web.configurers;
package org.springframework.security.config;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.MappingMatch;

import org.springframework.mock.web.MockHttpServletMapping;

final class TestMockHttpServletMappings {
public final class TestMockHttpServletMappings {

private TestMockHttpServletMappings() {

}

static MockHttpServletMapping extension(HttpServletRequest request, String extension) {
public static MockHttpServletMapping extension(HttpServletRequest request, String extension) {
String uri = request.getRequestURI();
String matchValue = uri.substring(0, uri.lastIndexOf(extension));
return new MockHttpServletMapping(matchValue, "*" + extension, "extension", MappingMatch.EXTENSION);
}

static MockHttpServletMapping path(HttpServletRequest request, String path) {
public static MockHttpServletMapping path(HttpServletRequest request, String path) {
String uri = request.getRequestURI();
String matchValue = uri.substring(path.length());
return new MockHttpServletMapping(matchValue, path + "/*", "path", MappingMatch.PATH);
}

static MockHttpServletMapping defaultMapping() {
public static MockHttpServletMapping defaultMapping() {
return new MockHttpServletMapping("", "/", "default", MappingMatch.DEFAULT);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,8 +26,11 @@
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext;
import org.springframework.http.HttpMethod;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.config.MockServletContext;
import org.springframework.security.config.TestMockHttpServletMappings;
import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry.DispatcherServletDelegatingRequestMatcher;
import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher;
Expand All @@ -40,6 +43,9 @@
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;

/**
* Tests for {@link AbstractRequestMatcherRegistry}.
Expand Down Expand Up @@ -159,6 +165,8 @@ public void requestMatchersWhenMvcPresentInClassPathAndMvcIntrospectorBeanNotAva
public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("servletOne", Servlet.class).addMapping("/one");
servletContext.addServlet("servletTwo", Servlet.class).addMapping("/two");
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
assertThat(requestMatchers).isNotEmpty();
assertThat(requestMatchers).hasSize(1);
Expand All @@ -170,7 +178,26 @@ public void requestMatchersWhenAmbiguousServletsThenException() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
servletContext.addServlet("servletTwo", Servlet.class).addMapping("/servlet/**");
servletContext.addServlet("servletTwo", DispatcherServlet.class).addMapping("/servlet/*");
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
}

@Test
public void requestMatchersWhenMultipleDispatcherServletMappingsThenException() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/", "/mvc/*");
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
}

@Test
public void requestMatchersWhenPathDispatcherServletAndOtherServletsThenException() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*");
servletContext.addServlet("default", Servlet.class).addMapping("/");
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
}
Expand All @@ -187,6 +214,67 @@ public void requestMatchersWhenUnmappableServletsThenSkips() {
assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class);
}

@Test
public void requestMatchersWhenOnlyDispatcherServletThenAllows() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*");
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
assertThat(requestMatchers).hasSize(1);
assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class);
}

@Test
public void requestMatchersWhenImplicitServletsThenAllows() {
mockMvcIntrospector(true);
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("defaultServlet", Servlet.class);
servletContext.addServlet("jspServlet", Servlet.class).addMapping("*.jsp", "*.jspx");
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
assertThat(requestMatchers).hasSize(1);
assertThat(requestMatchers.get(0)).isInstanceOf(DispatcherServletDelegatingRequestMatcher.class);
}

@Test
public void requestMatchersWhenPathBasedNonDispatcherServletThenAllows() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("path", Servlet.class).addMapping("/services/*");
servletContext.addServlet("default", DispatcherServlet.class).addMapping("/");
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/services/*");
assertThat(requestMatchers).hasSize(1);
assertThat(requestMatchers.get(0)).isInstanceOf(DispatcherServletDelegatingRequestMatcher.class);
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint");
request.setHttpServletMapping(TestMockHttpServletMappings.defaultMapping());
assertThat(requestMatchers.get(0).matcher(request).isMatch()).isTrue();
request.setHttpServletMapping(TestMockHttpServletMappings.path(request, "/services"));
request.setServletPath("/services");
request.setPathInfo("/endpoint");
assertThat(requestMatchers.get(0).matcher(request).isMatch()).isTrue();
}

@Test
public void matchesWhenDispatcherServletThenMvc() {
MockServletContext servletContext = new MockServletContext();
servletContext.addServlet("default", DispatcherServlet.class).addMapping("/");
servletContext.addServlet("path", Servlet.class).addMapping("/services/*");
MvcRequestMatcher mvc = mock(MvcRequestMatcher.class);
AntPathRequestMatcher ant = mock(AntPathRequestMatcher.class);
DispatcherServletDelegatingRequestMatcher requestMatcher = new DispatcherServletDelegatingRequestMatcher(ant,
mvc, servletContext);
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint");
request.setHttpServletMapping(TestMockHttpServletMappings.defaultMapping());
assertThat(requestMatcher.matches(request)).isFalse();
verify(mvc).matches(request);
verifyNoInteractions(ant);
request.setHttpServletMapping(TestMockHttpServletMappings.path(request, "/services"));
assertThat(requestMatcher.matches(request)).isFalse();
verify(ant).matches(request);
verifyNoMoreInteractions(mvc);
}

private void mockMvcIntrospector(boolean isPresent) {
ApplicationContext context = this.matcherRegistry.getApplicationContext();
given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.springframework.security.authorization.AuthorizationEventPublisher;
import org.springframework.security.authorization.AuthorizationManager;
import org.springframework.security.config.MockServletContext;
import org.springframework.security.config.TestMockHttpServletMappings;
import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.junit.jupiter.api.Test;

import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.config.TestMockHttpServletMappings;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down

0 comments on commit 5a6a1bf

Please sign in to comment.