diff --git a/src/main/java/org/openrewrite/java/spring/data/MigrateAuditorAwareToOptional.java b/src/main/java/org/openrewrite/java/spring/data/MigrateAuditorAwareToOptional.java index b78e67b5e..fad42dee7 100644 --- a/src/main/java/org/openrewrite/java/spring/data/MigrateAuditorAwareToOptional.java +++ b/src/main/java/org/openrewrite/java/spring/data/MigrateAuditorAwareToOptional.java @@ -15,6 +15,7 @@ */ package org.openrewrite.java.spring.data; +import lombok.RequiredArgsConstructor; import org.jspecify.annotations.Nullable; import org.openrewrite.*; import org.openrewrite.java.*; @@ -41,18 +42,14 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesType<>("org.springframework.data.domain.AuditorAware", true), new TreeVisitor() { - + ImplementationVisitor implementationVisitor = new ImplementationVisitor(); + FunctionalVisitor functionalVisitor = new FunctionalVisitor(implementationVisitor); + return Preconditions.check(new UsesType<>("org.springframework.data.domain.AuditorAware", true), new JavaIsoVisitor() { @Override - public @Nullable Tree visit(@Nullable Tree tree, ExecutionContext ctx, Cursor parent) { - if (!(tree instanceof SourceFile)) { - return tree; - } - - ImplementationVisitor implementationVisitor = new ImplementationVisitor(); + public @Nullable J visit(@Nullable Tree tree, ExecutionContext ctx) { tree = implementationVisitor.visit(tree, ctx); - tree = new FunctionalVisitor(implementationVisitor).visit(tree, ctx); - return tree; + tree = functionalVisitor.visit(tree, ctx); + return (J) tree; } }); } @@ -92,13 +89,10 @@ public J.Return visitReturn(J.Return return_, ExecutionContext ctx) { } } + @RequiredArgsConstructor private static class FunctionalVisitor extends JavaIsoVisitor { private final JavaIsoVisitor implementationVisitor; - public FunctionalVisitor(JavaIsoVisitor implementationVisitor) { - this.implementationVisitor = implementationVisitor; - } - @Override public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) { if (!isAuditorAware.matches(method.getReturnTypeExpression()) || method.getBody() == null || method.getBody().getStatements().size() != 1) { @@ -108,14 +102,11 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex if (!(statement instanceof J.Return)) { return method; } - return super.visitMethodDeclaration(method, ctx); } - @Override public J.Return visitReturn(J.Return return_, ExecutionContext ctx) { - Expression expression = return_.getExpression(); if (expression instanceof J.MemberReference) { J.MemberReference memberReference = (J.MemberReference) expression; @@ -123,7 +114,6 @@ public J.Return visitReturn(J.Return return_, ExecutionContext ctx) { if (methodType == null || isOptional.matches(methodType.getReturnType())) { return return_; } - expression = (Expression) new MemberReferenceToMethodInvocation().visitNonNull(memberReference, ctx, new Cursor(getCursor(), expression).getParent()); } if (expression instanceof J.Lambda) { diff --git a/src/test/java/org/openrewrite/java/spring/data/MigrateAuditorAwareToOptionalTest.java b/src/test/java/org/openrewrite/java/spring/data/MigrateAuditorAwareToOptionalTest.java index f3f5e1e8f..d81161802 100644 --- a/src/test/java/org/openrewrite/java/spring/data/MigrateAuditorAwareToOptionalTest.java +++ b/src/test/java/org/openrewrite/java/spring/data/MigrateAuditorAwareToOptionalTest.java @@ -131,9 +131,6 @@ public AuditorAware auditorAware() { @Test void rewriteInterfaceInstantiation() { - //TODO Question for TIM: how to get rid of the types? I have the imports. - //- public Optional getCurrentAuditor() { - //+ public java.util.Optional getCurrentAuditor() { rewriteRun( //language=java java(