Skip to content

Commit

Permalink
Fix SpringRequestMapping to be able to handle it when multiple paths …
Browse files Browse the repository at this point in the history
…are mapped to the same method.
  • Loading branch information
sambsnyd committed Oct 8, 2024
1 parent 6e703ee commit 0d5e2d8
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@
import org.openrewrite.trait.SimpleTraitMatcher;
import org.openrewrite.trait.Trait;

import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.stream.Collectors.toList;
Expand All @@ -52,23 +51,34 @@ public String getHttpMethod() {
}

public String getPath() {
String path =
cursor.getPathAsStream()
.filter(J.ClassDeclaration.class::isInstance)
.map(classDecl -> ((J.ClassDeclaration) classDecl).getAllAnnotations().stream()
.filter(SpringRequestMapping::hasRequestMapping)
.findAny()
.flatMap(classMapping -> new Annotated(new Cursor(null, classMapping))
.getDefaultAttribute(null)
.map(Literal::getString))
.orElse(null))
.filter(Objects::nonNull)
.collect(Collectors.joining("/")) +
new Annotated(cursor)
.getDefaultAttribute(null)
.map(Literal::getString)
.orElse("");
return path.replace("//", "/");
List<String> pathPrefixes = cursor.getPathAsStream()
.filter(J.ClassDeclaration.class::isInstance)
.map(J.ClassDeclaration.class::cast)
.flatMap(classDecl -> classDecl.getLeadingAnnotations().stream()
.filter(SpringRequestMapping::hasRequestMapping)
.findAny()
.flatMap(classMapping -> new Annotated(new Cursor(null, classMapping))
.getDefaultAttribute(null)
.map(lit -> lit.getStrings().stream()))
.orElse(Stream.of("")))
.collect(toList());
List<String> pathEndings = new Annotated(cursor)
.getDefaultAttribute(null)
.map(Literal::getStrings)
.orElse(Collections.emptyList());

StringBuilder result = new StringBuilder();
for (int j = 0; j < pathPrefixes.size(); j++) {
for (int i = 0; i < pathEndings.size(); i++) {
String pathEnding = pathEndings.get(i);
String prefix = pathPrefixes.get(j);
result.append(prefix).append(pathEnding);
if(i < pathEndings.size() - 1 || j < pathPrefixes.size() - 1) {
result.append(", ");
}
}
}
return result.toString().replace("//", "/");
}

public static class Matcher extends SimpleTraitMatcher<SpringRequestMapping> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ void migrateHttpComponentsClientHttpRequestFactoryReadTimeout() {
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.web.client.RestTemplate;
class RestContextInitializer {
RestTemplate getRestTemplate() throws Exception {
Registry<ConnectionSocketFactory> socketFactoryRegistry = RegistryBuilder.<ConnectionSocketFactory>create().build();
PoolingHttpClientConnectionManager poolingConnectionManager = new PoolingHttpClientConnectionManager(socketFactoryRegistry);
return new RestTemplateBuilder()
.requestFactory(() -> {
HttpComponentsClientHttpRequestFactory clientHttpRequestFactory = new HttpComponentsClientHttpRequestFactory();
Expand All @@ -77,15 +77,15 @@ RestTemplate getRestTemplate() throws Exception {
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.web.client.RestTemplate;
import java.util.concurrent.TimeUnit;
class RestContextInitializer {
RestTemplate getRestTemplate() throws Exception {
Registry<ConnectionSocketFactory> socketFactoryRegistry = RegistryBuilder.<ConnectionSocketFactory>create().build();
PoolingHttpClientConnectionManager poolingConnectionManager = new PoolingHttpClientConnectionManager(socketFactoryRegistry);
poolingConnectionManager.setDefaultSocketConfig(SocketConfig.custom().setSoTimeout(30000, TimeUnit.MILLISECONDS).build());
return new RestTemplateBuilder()
.requestFactory(() -> {
HttpComponentsClientHttpRequestFactory clientHttpRequestFactory = new HttpComponentsClientHttpRequestFactory();
Expand Down Expand Up @@ -114,15 +114,15 @@ void migratePoolingHttpClientConnectionManagerBuilderToVariable() {
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.web.client.RestTemplate;
import javax.net.ssl.SSLContext;
class RestContextInitializer {
RestTemplate getRestTemplate() throws Exception {
SSLContext sslContext = SSLContexts.custom().loadTrustMaterial(null, (cert, authType) -> true).build();
SSLConnectionSocketFactory socketFactoryRegistry = new SSLConnectionSocketFactory(sslContext,NoopHostnameVerifier.INSTANCE);
PoolingHttpClientConnectionManager poolingConnectionManager = PoolingHttpClientConnectionManagerBuilder.create().setSSLSocketFactory(socketFactoryRegistry).build();
return new RestTemplateBuilder()
.requestFactory(() -> {
HttpComponentsClientHttpRequestFactory clientHttpRequestFactory = new HttpComponentsClientHttpRequestFactory();
Expand All @@ -144,18 +144,18 @@ RestTemplate getRestTemplate() throws Exception {
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.web.client.RestTemplate;
import javax.net.ssl.SSLContext;
import java.util.concurrent.TimeUnit;
class RestContextInitializer {
RestTemplate getRestTemplate() throws Exception {
SSLContext sslContext = SSLContexts.custom().loadTrustMaterial(null, (cert, authType) -> true).build();
SSLConnectionSocketFactory socketFactoryRegistry = new SSLConnectionSocketFactory(sslContext,NoopHostnameVerifier.INSTANCE);
PoolingHttpClientConnectionManager poolingConnectionManager = PoolingHttpClientConnectionManagerBuilder.create().setSSLSocketFactory(socketFactoryRegistry).build();
poolingConnectionManager.setDefaultSocketConfig(SocketConfig.custom().setSoTimeout(30000, TimeUnit.MILLISECONDS).build());
return new RestTemplateBuilder()
.requestFactory(() -> {
HttpComponentsClientHttpRequestFactory clientHttpRequestFactory = new HttpComponentsClientHttpRequestFactory();
Expand Down Expand Up @@ -183,15 +183,15 @@ void doNotMigrateWhenUsingPoolingHttpClientConnectionManagerBuilderInline() {
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.web.client.RestTemplate;
import javax.net.ssl.SSLContext;
class RestContextInitializer {
RestTemplate getRestTemplate() throws Exception {
SSLContext sslContext = SSLContexts.custom().loadTrustMaterial(null, (cert, authType) -> true).build();
SSLConnectionSocketFactory socketFactoryRegistry = new SSLConnectionSocketFactory(sslContext,NoopHostnameVerifier.INSTANCE);
return PoolingHttpClientConnectionManagerBuilder.create().setSSLSocketFactory(socketFactoryRegistry).build();
return new RestTemplateBuilder()
.requestFactory(() -> {
HttpComponentsClientHttpRequestFactory clientHttpRequestFactory = new HttpComponentsClientHttpRequestFactory();
Expand All @@ -211,15 +211,15 @@ RestTemplate getRestTemplate() throws Exception {
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.web.client.RestTemplate;
import javax.net.ssl.SSLContext;
class RestContextInitializer {
RestTemplate getRestTemplate() throws Exception {
SSLContext sslContext = SSLContexts.custom().loadTrustMaterial(null, (cert, authType) -> true).build();
SSLConnectionSocketFactory socketFactoryRegistry = new SSLConnectionSocketFactory(sslContext,NoopHostnameVerifier.INSTANCE);
return PoolingHttpClientConnectionManagerBuilder.create().setSSLSocketFactory(socketFactoryRegistry).build();
return new RestTemplateBuilder()
.requestFactory(() -> {
HttpComponentsClientHttpRequestFactory clientHttpRequestFactory = new HttpComponentsClientHttpRequestFactory();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,52 @@ class FindApiEndpointsTest implements RewriteTest {
@Override
public void defaults(RecipeSpec spec) {
spec.recipe(new FindApiEndpoints())
.parser(JavaParser.fromJavaVersion().classpath("spring-web"));
.parser(JavaParser.fromJavaVersion().classpath("spring-web", "spring-context"));
}

@Test
@DocumentExample
void webClient() {
void withinController() {
rewriteRun(
//language=java
java(
"""
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;
@Controller
class PersonController {
@GetMapping("/count")
int count() {
return 42;
}
}
""",
"""
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;
@Controller
class PersonController {
/*~~(GET /count)~~>*/@GetMapping("/count")
int count() {
return 42;
}
}
"""
)
);
}

@Test
@DocumentExample
void webClient() {
rewriteRun(
//language=java
java(
"""
import org.springframework.web.bind.annotation.*;
@RequestMapping("/person")
class PersonController {
@GetMapping("/count")
Expand All @@ -52,7 +86,7 @@ int count() {
""",
"""
import org.springframework.web.bind.annotation.*;
@RequestMapping("/person")
class PersonController {
/*~~(GET /person/count)~~>*/@GetMapping("/count")
Expand All @@ -64,4 +98,35 @@ int count() {
)
);
}

@Test
void multiplePathsOneMethod() {
rewriteRun(
//language=java
java(
"""
import org.springframework.web.bind.annotation.*;
@RequestMapping({"/person", "/people"})
class PersonController {
@GetMapping({"/count", "/length"})
int count() {
return 42;
}
}
""",
"""
import org.springframework.web.bind.annotation.*;
@RequestMapping({"/person", "/people"})
class PersonController {
/*~~(GET /person/count, /person/length, /people/count, /people/length)~~>*/@GetMapping({"/count", "/length"})
int count() {
return 42;
}
}
"""
)
);
}
}

0 comments on commit 0d5e2d8

Please sign in to comment.