From c32fca971424e11559210f390fe52494b27cc2c0 Mon Sep 17 00:00:00 2001 From: Knut Wannheden Date: Fri, 12 May 2023 22:02:06 +0200 Subject: [PATCH] Better import handling Issue: #5 --- .../template/RefasterTemplateProcessor.java | 8 ++++++- .../template/internal/ImportDetector.java | 21 +++++++++++++++---- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/openrewrite/java/template/RefasterTemplateProcessor.java b/src/main/java/org/openrewrite/java/template/RefasterTemplateProcessor.java index c9c9739a..b7a35c7d 100644 --- a/src/main/java/org/openrewrite/java/template/RefasterTemplateProcessor.java +++ b/src/main/java/org/openrewrite/java/template/RefasterTemplateProcessor.java @@ -327,9 +327,12 @@ public void visitClassDef(JCTree.JCClassDecl classDecl) { } private static String lambdaCastType(Class type, JCTree.JCMethodDecl method) { + if (type == JCTree.JCMethodInvocation.class && method.getBody().getStatements().last() instanceof JCTree.JCExpressionStatement) { + return ""; + } int paramCount = method.params.size(); boolean asFunction = !(method.restype.type instanceof Type.JCVoidType) && JCTree.JCExpression.class.isAssignableFrom(type); - StringJoiner joiner = new StringJoiner(", ", "<", ">"); + StringJoiner joiner = new StringJoiner(", ", "<", ">").setEmptyValue(""); for (int i = 0; i < (asFunction ? paramCount + 1 : paramCount); i++) { joiner.add("?"); } @@ -388,6 +391,9 @@ private String toLambda(JCTree.JCMethodDecl method) { JCTree.JCStatement statement = method.getBody().getStatements().get(0); if (statement instanceof JCTree.JCReturn) { builder.append(((JCTree.JCReturn) statement).getExpression().toString()); + } else if (statement instanceof JCTree.JCThrow) { + String string = statement.toString(); + builder.append("{ ").append(string).append(" }"); } else { String string = statement.toString(); builder.append(string, 0, string.length() - 1); diff --git a/src/main/java/org/openrewrite/java/template/internal/ImportDetector.java b/src/main/java/org/openrewrite/java/template/internal/ImportDetector.java index eda81b45..5d8fea1d 100644 --- a/src/main/java/org/openrewrite/java/template/internal/ImportDetector.java +++ b/src/main/java/org/openrewrite/java/template/internal/ImportDetector.java @@ -16,6 +16,7 @@ package org.openrewrite.java.template.internal; import com.sun.tools.javac.code.Symbol; +import com.sun.tools.javac.code.Type; import com.sun.tools.javac.tree.JCTree; import com.sun.tools.javac.tree.JCTree.JCFieldAccess; import com.sun.tools.javac.tree.JCTree.JCIdent; @@ -23,7 +24,9 @@ import javax.lang.model.element.ElementKind; import java.util.ArrayList; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Set; public class ImportDetector { @@ -34,18 +37,19 @@ public class ImportDetector { * @return The list of imports to add. */ public static List imports(JCTree input) { - List imports = new ArrayList<>(); + Set imports = new LinkedHashSet<>(); new TreeScanner() { @Override public void scan(JCTree tree) { JCTree maybeFieldAccess = tree; if (maybeFieldAccess instanceof JCFieldAccess && - Character.isUpperCase(((JCFieldAccess) maybeFieldAccess).getIdentifier().toString().charAt(0))) { + ((JCFieldAccess) maybeFieldAccess).sym instanceof Symbol.ClassSymbol && + Character.isUpperCase(((JCFieldAccess) maybeFieldAccess).getIdentifier().toString().charAt(0))) { while (maybeFieldAccess instanceof JCFieldAccess) { maybeFieldAccess = ((JCFieldAccess) maybeFieldAccess).getExpression(); if (maybeFieldAccess instanceof JCIdent && - Character.isUpperCase(((JCIdent) maybeFieldAccess).getName().toString().charAt(0))) { + Character.isUpperCase(((JCIdent) maybeFieldAccess).getName().toString().charAt(0))) { // this might be a fully qualified type name, so we don't want to add an import for it // and returning will skip the nested identifier which represents just the class simple name return; @@ -64,12 +68,21 @@ public void scan(JCTree tree) { } else if (((JCIdent) tree).sym.getKind() == ElementKind.METHOD) { imports.add(((JCIdent) tree).sym); } + } else if (tree instanceof JCFieldAccess && ((JCFieldAccess) tree).sym instanceof Symbol.VarSymbol + && ((JCFieldAccess) tree).selected instanceof JCIdent + && ((JCIdent) ((JCFieldAccess) tree).selected).sym instanceof Symbol.ClassSymbol) { + imports.add(((JCIdent) ((JCFieldAccess) tree).selected).sym); + } else if (tree instanceof JCFieldAccess && ((JCFieldAccess) tree).sym instanceof Symbol.ClassSymbol + && ((JCFieldAccess) tree).selected instanceof JCIdent + && ((JCIdent) ((JCFieldAccess) tree).selected).sym instanceof Symbol.ClassSymbol + && !(((JCIdent) ((JCFieldAccess) tree).selected).sym.type instanceof Type.ErrorType)) { + imports.add(((JCIdent) ((JCFieldAccess) tree).selected).sym); } super.scan(tree); } }.scan(input); - return imports; + return new ArrayList<>(imports); } }