diff --git a/src/main/java/org/openrewrite/java/migrate/ReferenceCloneMethod.java b/src/main/java/org/openrewrite/java/migrate/ReferenceCloneMethod.java index f518243f42..5fa45dd207 100644 --- a/src/main/java/org/openrewrite/java/migrate/ReferenceCloneMethod.java +++ b/src/main/java/org/openrewrite/java/migrate/ReferenceCloneMethod.java @@ -24,6 +24,7 @@ import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.MethodMatcher; +import org.openrewrite.java.ShortenFullyQualifiedTypeReferences; import org.openrewrite.java.search.UsesMethod; import org.openrewrite.java.tree.J; import org.openrewrite.java.tree.TypeUtils; @@ -51,12 +52,11 @@ public TreeVisitor getVisitor() { return Preconditions.check( new UsesMethod<>(REFERENCE_CLONE), new JavaVisitor() { - private static final String REFERENCE_CLONE_REPLACED = "REFERENCE_CLONE_REPLACED"; @Override - public J visitTypeCast(J.TypeCast typeCast, ExecutionContext executionContext) { - J j = super.visitTypeCast(typeCast, executionContext); + public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) { + J j = super.visitTypeCast(typeCast, ctx); if (Boolean.TRUE.equals(getCursor().pollNearestMessage(REFERENCE_CLONE_REPLACED)) && j instanceof J.TypeCast) { J.TypeCast tc = (J.TypeCast) j; @@ -72,17 +72,14 @@ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) super.visitMethodInvocation(method, ctx); if (REFERENCE_CLONE.matches(method) && method.getSelect() instanceof J.Identifier) { J.Identifier methodRef = (J.Identifier) method.getSelect(); - String className = methodRef.getType().toString() - .replace("java.lang.ref.", "") - .replace("java.lang.", ""); - String template = "new " + className + "(" + methodRef.getSimpleName() + ", new ReferenceQueue<>())"; + String template = "new " + methodRef.getType().toString() + "(" + methodRef.getSimpleName() + ", new ReferenceQueue<>())"; getCursor().putMessageOnFirstEnclosing(J.TypeCast.class, REFERENCE_CLONE_REPLACED, true); - return JavaTemplate.builder(template) + J replacement = JavaTemplate.builder(template) .contextSensitive() - .imports( - methodRef.getType().toString(), - "java.lang.ref.ReferenceQueue") + .imports("java.lang.ref.ReferenceQueue") .build().apply(getCursor(), method.getCoordinates().replace()); + doAfterVisit(ShortenFullyQualifiedTypeReferences.modifyOnly(replacement)); + return replacement; } return method; }