diff --git a/src/main/java/org/openrewrite/java/spring/framework/MigrateWebMvcConfigurerAdapter.java b/src/main/java/org/openrewrite/java/spring/framework/MigrateWebMvcConfigurerAdapter.java index 5c4f4b2fc..479781793 100644 --- a/src/main/java/org/openrewrite/java/spring/framework/MigrateWebMvcConfigurerAdapter.java +++ b/src/main/java/org/openrewrite/java/spring/framework/MigrateWebMvcConfigurerAdapter.java @@ -30,6 +30,9 @@ import org.openrewrite.java.tree.TypeUtils; public class MigrateWebMvcConfigurerAdapter extends Recipe { + private static final String WEB_MVC_CONFIGURER_ADAPTER = "org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter"; + private static final String WEB_MVC_CONFIGURER = "org.springframework.web.servlet.config.annotation.WebMvcConfigurer"; + @Override public String getDisplayName() { return "Replace `WebMvcConfigurerAdapter` with `WebMvcConfigurer`"; @@ -43,11 +46,13 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesType<>("org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter", false), new JavaIsoVisitor() { + return Preconditions.check(new UsesType<>(WEB_MVC_CONFIGURER_ADAPTER, false), new JavaIsoVisitor() { + private final JavaType WEB_MVC_CONFIGURER_TYPE = JavaType.buildType(WEB_MVC_CONFIGURER); + @Override public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) { J.ClassDeclaration cd = super.visitClassDeclaration(classDecl, ctx); - if (cd.getExtends() != null && TypeUtils.isOfClassType(cd.getExtends().getType(), "org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter")) { + if (cd.getExtends() != null && TypeUtils.isOfClassType(cd.getExtends().getType(), WEB_MVC_CONFIGURER_ADAPTER)) { cd = cd.withExtends(null); updateCursor(cd); // This is an interesting one... WebMvcConfigurerAdapter implements WebMvcConfigurer @@ -59,18 +64,51 @@ public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, Ex } cd = JavaTemplate.builder("WebMvcConfigurer") .contextSensitive() - .imports("org.springframework.web.servlet.config.annotation.WebMvcConfigurer") + .imports(WEB_MVC_CONFIGURER) .javaParser(JavaParser.fromJavaVersion() - .classpathFromResources(ctx, "spring-webmvc-5.*")) + .classpathFromResources(ctx, "spring-webmvc-5")) .build().apply(getCursor(), cd.getCoordinates().addImplementsClause()); updateCursor(cd); cd = (J.ClassDeclaration) new RemoveSuperStatementVisitor().visitNonNull(cd, ctx, getCursor().getParentOrThrow()); - maybeRemoveImport("org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter"); - maybeAddImport("org.springframework.web.servlet.config.annotation.WebMvcConfigurer"); + maybeRemoveImport(WEB_MVC_CONFIGURER_ADAPTER); + maybeAddImport(WEB_MVC_CONFIGURER); } return cd; } + @Override + public J.NewClass visitNewClass(J.NewClass newClass, ExecutionContext ctx) { + if (newClass.getClazz() != null && TypeUtils.isOfClassType(newClass.getClazz().getType(), WEB_MVC_CONFIGURER_ADAPTER)) { + if (newClass.getClazz() instanceof J.Identifier) { + J.Identifier identifier = (J.Identifier) newClass.getClazz(); + newClass = newClass.withClazz(identifier + .withType(WEB_MVC_CONFIGURER_TYPE) + .withSimpleName(((JavaType.ShallowClass) WEB_MVC_CONFIGURER_TYPE).getClassName()) + ); + } + maybeRemoveImport(WEB_MVC_CONFIGURER_ADAPTER); + maybeAddImport(WEB_MVC_CONFIGURER); + } + return super.visitNewClass(newClass, ctx); + } + + @Override + public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration md, ExecutionContext ctx) { + if (md.getMethodType() != null && TypeUtils.isOfClassType(md.getType(), WEB_MVC_CONFIGURER_ADAPTER)) { + if (md.getReturnTypeExpression() instanceof J.Identifier) { + J.Identifier identifier = (J.Identifier) md.getReturnTypeExpression(); + md = md.withReturnTypeExpression(identifier + .withType(WEB_MVC_CONFIGURER_TYPE) + .withSimpleName(((JavaType.ShallowClass) WEB_MVC_CONFIGURER_TYPE).getClassName()) + ); + } + + maybeRemoveImport(WEB_MVC_CONFIGURER_ADAPTER); + maybeAddImport(WEB_MVC_CONFIGURER); + } + return super.visitMethodDeclaration(md, ctx); + } + class RemoveSuperStatementVisitor extends JavaIsoVisitor { final MethodMatcher wm = new MethodMatcher("org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter *(..)"); diff --git a/src/testWithSpringBoot_2_1/java/org/openrewrite/java/spring/framework/MigrateWebMvcConfigurerAdapterTest.java b/src/testWithSpringBoot_2_1/java/org/openrewrite/java/spring/framework/MigrateWebMvcConfigurerAdapterTest.java index 42e9f3c76..92e90253c 100644 --- a/src/testWithSpringBoot_2_1/java/org/openrewrite/java/spring/framework/MigrateWebMvcConfigurerAdapterTest.java +++ b/src/testWithSpringBoot_2_1/java/org/openrewrite/java/spring/framework/MigrateWebMvcConfigurerAdapterTest.java @@ -28,8 +28,8 @@ class MigrateWebMvcConfigurerAdapterTest implements RewriteTest { @Override public void defaults(RecipeSpec spec) { - spec.parser(JavaParser.fromJavaVersion() - .classpathFromResources(new InMemoryExecutionContext(), "spring-webmvc-5.*", "spring-core-5.*", "spring-web-5.*")) + spec.parser(JavaParser.fromJavaVersion().classpathFromResources(new InMemoryExecutionContext(), + "spring-webmvc-5", "spring-core-5", "spring-web-5")) .recipe(new MigrateWebMvcConfigurerAdapter()); } @@ -39,30 +39,55 @@ void transformSimple() { rewriteRun( //language=java java( - """ - package a.b.c; - - import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter; - - public class CustomMvcConfigurer extends WebMvcConfigurerAdapter { - private final String someArg; - public CustomMvcConfigurer(String someArg) { - super(); - this.someArg = someArg; - } - } - """, """ - package a.b.c; - - import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; - - public class CustomMvcConfigurer implements WebMvcConfigurer { - private final String someArg; - public CustomMvcConfigurer(String someArg) { - this.someArg = someArg; - } - } - """) + """ + import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter; + + public class CustomMvcConfigurer extends WebMvcConfigurerAdapter { + private final String someArg; + public CustomMvcConfigurer(String someArg) { + super(); + this.someArg = someArg; + } + } + """, """ + import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + + public class CustomMvcConfigurer implements WebMvcConfigurer { + private final String someArg; + public CustomMvcConfigurer(String someArg) { + this.someArg = someArg; + } + } + """) + ); + } + + @Test + void transformBean() { + // language=java + rewriteRun( + java( + """ + import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter; + + class WebConfig { + WebMvcConfigurerAdapter forwardToIndex() { + return new WebMvcConfigurerAdapter() { + }; + } + } + """, + """ + import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + + class WebConfig { + WebMvcConfigurer forwardToIndex() { + return new WebMvcConfigurer() { + }; + } + } + """ + ) ); } }