diff --git a/.gitignore b/.gitignore index 311643aa..03ed5d58 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,3 @@ build/ .idea/ .boot-releases out/ -processor/ diff --git a/src/main/java/org/openrewrite/java/template/processor/RefasterTemplateProcessor.java b/src/main/java/org/openrewrite/java/template/processor/RefasterTemplateProcessor.java new file mode 100644 index 00000000..7d3739ab --- /dev/null +++ b/src/main/java/org/openrewrite/java/template/processor/RefasterTemplateProcessor.java @@ -0,0 +1,759 @@ +/* + * Copyright 2023 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.template.processor; + +import com.sun.source.tree.*; +import com.sun.tools.javac.code.Symbol; +import com.sun.tools.javac.code.Type; +import com.sun.tools.javac.tree.JCTree; +import com.sun.tools.javac.tree.JCTree.JCCompilationUnit; +import com.sun.tools.javac.tree.TreeMaker; +import com.sun.tools.javac.tree.TreeScanner; +import com.sun.tools.javac.util.Context; +import org.jetbrains.annotations.Nullable; +import org.openrewrite.java.template.internal.FQNPretty; +import org.openrewrite.java.template.internal.ImportDetector; +import org.openrewrite.java.template.internal.JavacResolution; +import org.openrewrite.java.template.internal.UsedMethodDetector; + +import javax.annotation.processing.RoundEnvironment; +import javax.annotation.processing.SupportedAnnotationTypes; +import javax.lang.model.element.Element; +import javax.lang.model.element.Modifier; +import javax.lang.model.element.NestingKind; +import javax.lang.model.element.TypeElement; +import javax.tools.Diagnostic.Kind; +import javax.tools.JavaFileObject; +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.Writer; +import java.util.*; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Collections.singletonList; +import static org.openrewrite.java.template.processor.RefasterTemplateProcessor.AFTER_TEMPLATE; +import static org.openrewrite.java.template.processor.RefasterTemplateProcessor.BEFORE_TEMPLATE; + +/** + * For steps to debug this annotation processor, see + * this blog post. + */ +@SupportedAnnotationTypes({BEFORE_TEMPLATE, AFTER_TEMPLATE}) +public class RefasterTemplateProcessor extends TypeAwareProcessor { + static final String BEFORE_TEMPLATE = "com.google.errorprone.refaster.annotation.BeforeTemplate"; + static final String AFTER_TEMPLATE = "com.google.errorprone.refaster.annotation.AfterTemplate"; + static Set UNSUPPORTED_ANNOTATIONS = Stream.of( + "com.google.errorprone.refaster.annotation.AlsoNegation", + "com.google.errorprone.refaster.annotation.AllowCodeBetweenLines", + "com.google.errorprone.refaster.annotation.Matches", + "com.google.errorprone.refaster.annotation.MayOptionallyUse", + "com.google.errorprone.refaster.annotation.NoAutoboxing", + "com.google.errorprone.refaster.annotation.NotMatches", + "com.google.errorprone.refaster.annotation.OfKind", + "com.google.errorprone.refaster.annotation.Placeholder", + "com.google.errorprone.refaster.annotation.Repeated", + "com.google.errorprone.refaster.annotation.UseImportPolicy", + "com.google.errorprone.annotations.DoNotCall" + ).collect(Collectors.toSet()); + + static ClassValue> LST_TYPE_MAP = new ClassValue>() { + @Override + protected List computeValue(Class type) { + if (JCTree.JCUnary.class.isAssignableFrom(type)) { + return singletonList("J.Unary"); + } else if (JCTree.JCBinary.class.isAssignableFrom(type)) { + return singletonList("J.Binary"); + } else if (JCTree.JCMethodInvocation.class.isAssignableFrom(type)) { + return singletonList("J.MethodInvocation"); + } else if (JCTree.JCFieldAccess.class.isAssignableFrom(type)) { + return Arrays.asList("J.FieldAccess", "J.Identifier"); + } else if (JCTree.JCExpression.class.isAssignableFrom(type)) { + // catch all for expressions + return singletonList("Expression"); + } else if (JCTree.JCStatement.class.isAssignableFrom(type)) { + // catch all for statements + return singletonList("Statement"); + } + throw new IllegalArgumentException(type.toString()); + } + }; + + @Override + public boolean process(Set annotations, RoundEnvironment roundEnv) { + for (Element element : roundEnv.getRootElements()) { + JCCompilationUnit jcCompilationUnit = toUnit(element); + if (jcCompilationUnit != null) { + maybeGenerateTemplateSources(jcCompilationUnit); + } + } + + return true; + } + + void maybeGenerateTemplateSources(JCCompilationUnit cu) { + Context context = javacProcessingEnv.getContext(); + + new TreeScanner() { + final Map> imports = new HashMap<>(); + final Map> staticImports = new HashMap<>(); + final Map recipes = new LinkedHashMap<>(); + + @Override + public void visitClassDef(JCTree.JCClassDecl classDecl) { + super.visitClassDef(classDecl); + + TemplateDescriptor descriptor = getTemplateDescriptor(classDecl, context, cu); + if (descriptor != null) { + + TreeMaker treeMaker = TreeMaker.instance(context).forToplevel(cu); + List membersWithoutConstructor = classDecl.getMembers().stream() + .filter(m -> !(m instanceof JCTree.JCMethodDecl) || !((JCTree.JCMethodDecl) m).name.contentEquals("")) + .collect(Collectors.toList()); + JCTree.JCClassDecl copy = treeMaker.ClassDef(classDecl.mods, classDecl.name, classDecl.typarams, classDecl.extending, classDecl.implementing, com.sun.tools.javac.util.List.from(membersWithoutConstructor)); + + processingEnv.getMessager().printMessage(Kind.NOTE, "Generating template for " + descriptor.classDecl.getSimpleName()); + + String templateName = classDecl.sym.fullname.toString().substring(classDecl.sym.packge().fullname.length() + 1); + String templateFqn = classDecl.sym.fullname.toString() + "Recipe"; + String templateCode = copy.toString().trim(); + String displayName = cu.docComments.getComment(classDecl) != null ? cu.docComments.getComment(classDecl).getText().trim() : "Refaster template `" + templateName + '`'; + if (displayName.endsWith(".")) { + displayName = displayName.substring(0, displayName.length() - 1); + } + + for (JCTree.JCMethodDecl template : descriptor.beforeTemplates) { + for (Symbol anImport : ImportDetector.imports(template)) { + if (anImport instanceof Symbol.ClassSymbol) { + imports.computeIfAbsent(template, k -> new TreeSet<>()) + .add(anImport.getQualifiedName().toString().replace('$', '.')); + } else if (anImport instanceof Symbol.VarSymbol || anImport instanceof Symbol.MethodSymbol) { + staticImports.computeIfAbsent(template, k -> new TreeSet<>()) + .add(anImport.owner.getQualifiedName().toString().replace('$', '.') + '.' + anImport.flatName().toString()); + } else { + throw new AssertionError(anImport.getClass()); + } + } + } + for (Symbol anImport : ImportDetector.imports(descriptor.afterTemplate)) { + if (anImport instanceof Symbol.ClassSymbol) { + imports.computeIfAbsent(descriptor.afterTemplate, k -> new TreeSet<>()) + .add(anImport.getQualifiedName().toString().replace('$', '.')); + } else if (anImport instanceof Symbol.VarSymbol || anImport instanceof Symbol.MethodSymbol) { + staticImports.computeIfAbsent(descriptor.afterTemplate, k -> new TreeSet<>()) + .add(anImport.owner.getQualifiedName().toString().replace('$', '.') + '.' + anImport.flatName().toString()); + } else { + throw new AssertionError(anImport.getClass()); + } + } + + for (Set imports : imports.values()) { + imports.removeIf(i -> "java.lang".equals(i.substring(0, i.lastIndexOf('.')))); + imports.remove(BEFORE_TEMPLATE); + imports.remove(AFTER_TEMPLATE); + } + + Map befores = new LinkedHashMap<>(); + for (JCTree.JCMethodDecl templ : descriptor.beforeTemplates) { + String name = templ.getName().toString(); + if (befores.containsKey(name)) { + String base = name; + for (int i = 0; ; i++) { + name = base + i; + if (!befores.containsKey(name)) { + break; + } + } + } + befores.put(name, templ); + } + String after = descriptor.afterTemplate.getName().toString(); + + StringBuilder recipe = new StringBuilder(); + String recipeName = templateFqn.substring(templateFqn.lastIndexOf('.') + 1); + String modifiers = classDecl.getModifiers().getFlags().stream().map(Modifier::toString).collect(Collectors.joining(" ")); + if (!modifiers.isEmpty()) { + modifiers += " "; + } + recipe.append("@NonNullApi\n"); + recipe.append(modifiers).append("class ").append(recipeName).append(" extends Recipe {\n"); + recipe.append("\n"); + recipe.append(" @Override\n"); + recipe.append(" public String getDisplayName() {\n"); + recipe.append(" return \"").append(escape(displayName)).append("\";\n"); + recipe.append(" }\n"); + recipe.append("\n"); + recipe.append(" @Override\n"); + recipe.append(" public String getDescription() {\n"); + recipe.append(" return \"Recipe created for the following Refaster template:\\n```java\\n").append(escape(templateCode)).append("\\n```\\n.\";\n"); + recipe.append(" }\n"); + recipe.append("\n"); + recipe.append(" @Override\n"); + recipe.append(" public TreeVisitor getVisitor() {\n"); + recipe.append(" JavaVisitor javaVisitor = new AbstractRefasterJavaVisitor() {\n"); + for (Map.Entry entry : befores.entrySet()) { + recipe.append(" final Supplier ") + .append(entry.getKey()) + .append(" = memoize(() -> Semantics.") + .append(statementType(entry.getValue())) + .append("(this, \"") + .append(entry.getKey()).append("\", ") + .append(toLambda(entry.getValue())) + .append(").build());\n"); + } + recipe.append(" final Supplier ") + .append(after) + .append(" = memoize(() -> Semantics.") + .append(statementType(descriptor.afterTemplate)) + .append("(this, \"") + .append(after) + .append("\", ") + .append(toLambda(descriptor.afterTemplate)) + .append(").build());\n"); + recipe.append("\n"); + + List lstTypes = LST_TYPE_MAP.get(getType(descriptor.beforeTemplates.get(0))); + String parameters = parameters(descriptor); + for (String lstType : lstTypes) { + String methodSuffix = lstType.startsWith("J.") ? lstType.substring(2) : lstType; + recipe.append(" @Override\n"); + recipe.append(" public J visit").append(methodSuffix).append("(").append(lstType).append(" elem, ExecutionContext ctx) {\n"); + if (lstType.equals("Statement")) { + recipe.append(" if (elem instanceof J.Block) {;\n"); + recipe.append(" // FIXME workaround\n"); + recipe.append(" return elem;\n"); + recipe.append(" }\n"); + } + + recipe.append(" JavaTemplate.Matcher matcher;\n"); + for (Map.Entry entry : befores.entrySet()) { + recipe.append(" if (" + "(matcher = matcher(").append(entry.getKey()).append(", getCursor())).find()").append(") {\n"); + com.sun.tools.javac.util.List jcVariableDecls = entry.getValue().getParameters(); + for (int i = 0; i < jcVariableDecls.size(); i++) { + JCTree.JCVariableDecl param = jcVariableDecls.get(i); + com.sun.tools.javac.util.List annotations = param.getModifiers().getAnnotations(); + for (JCTree.JCAnnotation jcAnnotation : annotations) { + String annotationType = jcAnnotation.attribute.type.tsym.getQualifiedName().toString(); + if (annotationType.equals("org.openrewrite.java.template.NotMatches")) { + String matcher = ((Type.ClassType) jcAnnotation.attribute.getValue().values.get(0).snd.getValue()).tsym.getQualifiedName().toString(); + recipe.append(" if (new ").append(matcher).append("().matches((Expression) matcher.parameter(").append(i).append("))) {\n"); + recipe.append(" return super.visit").append(methodSuffix).append("(elem, ctx);\n"); + recipe.append(" }\n"); + } else if (annotationType.equals("org.openrewrite.java.template.Matches")) { + String matcher = ((Type.ClassType) jcAnnotation.attribute.getValue().values.get(0).snd.getValue()).tsym.getQualifiedName().toString(); + recipe.append(" if (!new ").append(matcher).append("().matches((Expression) matcher.parameter(").append(i).append("))) {\n"); + recipe.append(" return super.visit").append(methodSuffix).append("(elem, ctx);\n"); + recipe.append(" }\n"); + } + } + } + Set beforeImports = imports.entrySet().stream() + .filter(e -> entry.getValue().equals(e.getKey())) + .map(Map.Entry::getValue) + .flatMap(Set::stream) + .collect(Collectors.toSet()); + Set afterImports = imports.entrySet().stream() + .filter(e -> descriptor.afterTemplate == e.getKey()) + .map(Map.Entry::getValue) + .flatMap(Set::stream) + .collect(Collectors.toSet()); + maybeRemoveImport(imports, beforeImports, afterImports, recipe); + beforeImports = staticImports.entrySet().stream() + .filter(e -> entry.getValue().equals(e.getKey())) + .map(Map.Entry::getValue) + .flatMap(Set::stream) + .collect(Collectors.toSet()); + afterImports = staticImports.entrySet().stream() + .filter(e -> descriptor.afterTemplate == e.getKey()) + .map(Map.Entry::getValue) + .flatMap(Set::stream) + .collect(Collectors.toSet()); + maybeRemoveImport(staticImports, beforeImports, afterImports, recipe); + if (parameters.isEmpty()) { + recipe.append(" return embed(apply(").append(after).append(", getCursor(), elem.getCoordinates().replace()), getCursor(), ctx);\n"); + } else { + recipe.append(" return embed(\n"); + recipe.append(" apply(").append(after).append(", getCursor(), elem.getCoordinates().replace(), ").append(parameters).append("),\n"); + recipe.append(" getCursor(),\n"); + recipe.append(" ctx\n"); + recipe.append(" );\n"); + } + recipe.append(" }\n"); + } + recipe.append(" return super.visit").append(methodSuffix).append("(elem, ctx);\n"); + recipe.append(" }\n"); + recipe.append("\n"); + } + recipe.append(" };\n"); + + String preconditions = generatePreconditions(descriptor.beforeTemplates, imports, 16); + if (preconditions == null) { + recipe.append(" return javaVisitor;\n"); + } else { + recipe.append(" return Preconditions.check(\n"); + recipe.append(" ").append(preconditions).append(",\n"); + recipe.append(" javaVisitor\n"); + recipe.append(" );\n"); + } + recipe.append(" }\n"); + recipe.append("}\n"); + recipes.put(recipeName, recipe.toString()); + } + + if (classDecl.sym != null && classDecl.sym.getNestingKind() == NestingKind.TOP_LEVEL && !recipes.isEmpty()) { + boolean outerClassRequired = descriptor == null; + try { + String inputOuterFQN = outerClassRequired ? classDecl.sym.fullname.toString() : descriptor.classDecl.sym.fullname.toString(); + String className = inputOuterFQN + (outerClassRequired ? "Recipes" : "Recipe"); + JavaFileObject builderFile = processingEnv.getFiler().createSourceFile(className); + try (Writer out = new BufferedWriter(builderFile.openWriter())) { + out.write("package " + classDecl.sym.packge().toString() + ";\n"); + out.write("\n"); + out.write("import org.openrewrite.ExecutionContext;\n"); + out.write("import org.openrewrite.Preconditions;\n"); + out.write("import org.openrewrite.Recipe;\n"); + out.write("import org.openrewrite.TreeVisitor;\n"); + out.write("import org.openrewrite.internal.lang.NonNullApi;\n"); + out.write("import org.openrewrite.java.JavaTemplate;\n"); + out.write("import org.openrewrite.java.JavaVisitor;\n"); + out.write("import org.openrewrite.java.search.*;\n"); + out.write("import org.openrewrite.java.template.Primitive;\n"); + out.write("import org.openrewrite.java.template.Semantics;\n"); + out.write("import org.openrewrite.java.template.function.*;\n"); + out.write("import org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor;\n"); + out.write("import org.openrewrite.java.tree.*;\n"); + out.write("\n"); + out.write("import java.util.function.Supplier;\n"); + if (outerClassRequired) { + out.write("\n"); + out.write("import java.util.Arrays;\n"); + out.write("import java.util.List;\n"); + } + + out.write("\n"); + + if (!imports.isEmpty()) { + for (String anImport : imports.values().stream().flatMap(Set::stream).collect(Collectors.toSet())) { + out.write("import " + anImport + ";\n"); + } + out.write("\n"); + } + if (!staticImports.isEmpty()) { + for (String anImport : staticImports.values().stream().flatMap(Set::stream).collect(Collectors.toSet())) { + out.write("import static " + anImport + ";\n"); + } + out.write("\n"); + } + + if (outerClassRequired) { + String outerClassName = className.substring(className.lastIndexOf('.') + 1); + out.write("public final class " + outerClassName + " extends Recipe {\n"); + + String simpleInputOuterFQN = inputOuterFQN.substring(inputOuterFQN.lastIndexOf('.') + 1); + out.write("\n" + + " @Override\n" + + " public String getDisplayName() {\n" + + " return \"`" + simpleInputOuterFQN + "` Refaster recipes\";\n" + + " }\n" + + "\n" + + " @Override\n" + + " public String getDescription() {\n" + + " return \"Refaster template recipes for `" + inputOuterFQN + "`.\";\n" + + " }\n" + + "\n"); + String recipesAsList = recipes.keySet().stream() + .map(r -> " new " + r.substring(r.lastIndexOf('.') + 1) + "()") + .collect(Collectors.joining(",\n")); + out.write( + " @Override\n" + + " public List getRecipeList() {\n" + + " return Arrays.asList(\n" + + recipesAsList + '\n' + + " );\n" + + " }\n\n"); + + for (String r : recipes.values()) { + out.write(r.replaceAll("(?m)^(.+)$", " $1")); + out.write('\n'); + } + out.write("}\n"); + } else { + for (String r : recipes.values()) { + out.write(r); + out.write('\n'); + } + } + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + + private void maybeRemoveImport(Map> imports, Set beforeImports, Set afterImports, StringBuilder recipe) { + for (String anImport : imports.values().stream().flatMap(Set::stream).collect(Collectors.toSet())) { + if (anImport.startsWith("java.lang.")) { + continue; + } + //noinspection StatementWithEmptyBody + if (beforeImports.contains(anImport) && afterImports.contains(anImport)) { + // do nothing + } else if (beforeImports.contains(anImport)) { + recipe.append(" maybeRemoveImport(\"").append(anImport).append("\");\n"); + } + } + } + + /* Generate the minimal precondition that would allow to match each before template individually. */ + @SuppressWarnings("SameParameterValue") + @Nullable + private String generatePreconditions(List beforeTemplates, + Map> imports, + int indent) { + Map> preconditions = new LinkedHashMap<>(); + for (JCTree.JCMethodDecl beforeTemplate : beforeTemplates) { + Set usesVisitors = new LinkedHashSet<>(); + + Set localImports = imports.getOrDefault(beforeTemplate, Collections.emptySet()); + for (String anImport : localImports) { + usesVisitors.add("new UsesType<>(\"" + anImport + "\", true)"); + } + List usedMethods = UsedMethodDetector.usedMethods(beforeTemplate); + for (Symbol.MethodSymbol method : usedMethods) { + usesVisitors.add("new UsesMethod<>(\"" + method.owner.getQualifiedName().toString() + ' ' + method.name.toString() + "(..)\")"); + } + + preconditions.put(beforeTemplate, usesVisitors); + } + + if (preconditions.size() == 1) { + return joinPreconditions(preconditions.values().iterator().next(), "and", indent + 4); + } else if (preconditions.size() > 1) { + Set common = new LinkedHashSet<>(); + for (String dep : preconditions.values().iterator().next()) { + if (preconditions.values().stream().allMatch(v -> v.contains(dep))) { + common.add(dep); + } + } + common.forEach(dep -> preconditions.values().forEach(v -> v.remove(dep))); + preconditions.values().removeIf(Collection::isEmpty); + + if (common.isEmpty()) { + return joinPreconditions(preconditions.values().stream().map(v -> joinPreconditions(v, "and", indent + 4)).collect(Collectors.toList()), "or", indent + 4); + } else { + if (!preconditions.isEmpty()) { + String uniqueConditions = joinPreconditions(preconditions.values().stream().map(v -> joinPreconditions(v, "and", indent + 12)).collect(Collectors.toList()), "or", indent + 8); + common.add(uniqueConditions); + } + return joinPreconditions(common, "and", indent + 4); + } + } + return null; + } + + private String joinPreconditions(Collection preconditions, String op, int indent) { + if (preconditions.isEmpty()) { + return null; + } else if (preconditions.size() == 1) { + return preconditions.iterator().next(); + } + char[] indentChars = new char[indent]; + Arrays.fill(indentChars, ' '); + String indentStr = new String(indentChars); + return "Preconditions." + op + "(\n" + indentStr + String.join(",\n" + indentStr, preconditions) + "\n" + indentStr.substring(0, indent - 4) + ')'; + } + }.scan(cu); + } + + private String escape(String string) { + return string.replace("\"", "\\\"").replaceAll("\\R", "\\\\n"); + } + + private String parameters(TemplateDescriptor descriptor) { + List afterParams = new ArrayList<>(); + new TreeScanner() { + @Override + public void scan(JCTree jcTree) { + if (jcTree instanceof JCTree.JCIdent) { + JCTree.JCIdent jcIdent = (JCTree.JCIdent) jcTree; + if (jcIdent.sym instanceof Symbol.VarSymbol + && jcIdent.sym.owner instanceof Symbol.MethodSymbol + && ((Symbol.MethodSymbol) jcIdent.sym.owner).params.contains(jcIdent.sym)) { + afterParams.add(((Symbol.MethodSymbol) jcIdent.sym.owner).params.indexOf(jcIdent.sym)); + } + } + super.scan(jcTree); + } + }.scan(descriptor.afterTemplate.body); + + StringJoiner joiner = new StringJoiner(", "); + for (Integer param : afterParams) { + joiner.add("matcher.parameter(" + param + ")"); + } + return joiner.toString(); + } + + private Class getType(JCTree.JCMethodDecl method) { + JCTree.JCStatement statement = method.getBody().getStatements().get(0); + Class type = statement.getClass(); + if (statement instanceof JCTree.JCReturn) { + type = ((JCTree.JCReturn) statement).expr.getClass(); + } else if (statement instanceof JCTree.JCExpressionStatement) { + type = ((JCTree.JCExpressionStatement) statement).expr.getClass(); + } + return type; + } + + private String statementType(JCTree.JCMethodDecl method) { + // for now excluding assignment expressions and prefix and postfix -- and ++ + Set> expressionStatementTypes = Stream.of( + JCTree.JCMethodInvocation.class, + JCTree.JCNewClass.class).collect(Collectors.toSet()); + + Class type = getType(method); + if (expressionStatementTypes.contains(type)) { + if (type == JCTree.JCMethodInvocation.class + && method.getBody().getStatements().last() instanceof JCTree.JCExpressionStatement + && !(method.getReturnType().type instanceof Type.JCVoidType)) { + return "expression"; + } + if (method.restype.type instanceof Type.JCVoidType || !JCTree.JCExpression.class.isAssignableFrom(type)) { + return "statement"; + } + } + return "expression"; + } + + private String toLambda(JCTree.JCMethodDecl method) { + StringBuilder builder = new StringBuilder(); + + StringJoiner joiner = new StringJoiner(", ", "(", ")"); + for (JCTree.JCVariableDecl parameter : method.getParameters()) { + String paramType = parameter.getType().type.tsym.getQualifiedName().toString(); + + switch (paramType) { + case "boolean": + paramType = "@Primitive Boolean"; + break; + case "byte": + paramType = "@Primitive Byte"; + break; + case "char": + paramType = "@Primitive Character"; + break; + case "double": + paramType = "@Primitive Double"; + break; + case "float": + paramType = "@Primitive Float"; + break; + case "int": + paramType = "@Primitive Integer"; + break; + case "long": + paramType = "@Primitive Long"; + break; + case "short": + paramType = "@Primitive Short"; + break; + case "void": + paramType = "@Primitive Void"; + break; + } + + if (paramType.startsWith("java.lang.")) { + paramType = paramType.substring("java.lang.".length()); + } + joiner.add(paramType + " " + parameter.getName()); + } + builder.append(joiner); + builder.append(" -> "); + + JCTree.JCStatement statement = method.getBody().getStatements().get(0); + if (statement instanceof JCTree.JCReturn) { + builder.append(FQNPretty.toString(((JCTree.JCReturn) statement).getExpression())); + } else if (statement instanceof JCTree.JCThrow) { + String string = FQNPretty.toString(statement); + builder.append("{ ").append(string).append(" }"); + } else { + String string = FQNPretty.toString(statement); + builder.append(string); + } + return builder.toString(); + } + + @Nullable + private TemplateDescriptor getTemplateDescriptor(JCTree.JCClassDecl tree, Context context, JCCompilationUnit cu) { + TemplateDescriptor result = new TemplateDescriptor(tree); + for (JCTree member : tree.getMembers()) { + if (member instanceof JCTree.JCMethodDecl) { + JCTree.JCMethodDecl method = (JCTree.JCMethodDecl) member; + List annotations = getTemplateAnnotations(method, BEFORE_TEMPLATE::equals); + if (!annotations.isEmpty()) { + result.beforeTemplate(method); + } + annotations = getTemplateAnnotations(method, AFTER_TEMPLATE::equals); + if (!annotations.isEmpty()) { + result.afterTemplate(method); + } + } + } + return result.validate(context, cu); + } + + class TemplateDescriptor { + final JCTree.JCClassDecl classDecl; + final List beforeTemplates = new ArrayList<>(); + JCTree.JCMethodDecl afterTemplate; + + public TemplateDescriptor(JCTree.JCClassDecl classDecl) { + this.classDecl = classDecl; + } + + @Nullable + private TemplateDescriptor validate(Context context, JCCompilationUnit cu) { + if (beforeTemplates.isEmpty() || afterTemplate == null) { + return null; + } + + boolean valid = true; + for (JCTree member : classDecl.getMembers()) { + if (member instanceof JCTree.JCMethodDecl && !beforeTemplates.contains(member) && member != afterTemplate) { + for (JCTree.JCAnnotation annotation : getTemplateAnnotations(((JCTree.JCMethodDecl) member), UNSUPPORTED_ANNOTATIONS::contains)) { + processingEnv.getMessager().printMessage(Kind.NOTE, "The @" + annotation.annotationType + " is currently not supported", ((JCTree.JCMethodDecl) member).sym); + valid = false; + } + } + } + + // resolve so that we can inspect the template body + valid &= resolve(context, cu); + if (valid) { + for (JCTree.JCMethodDecl template : beforeTemplates) { + valid &= validateTemplateMethod(template); + } + valid &= validateTemplateMethod(afterTemplate); + } + return valid ? this : null; + } + + private boolean validateTemplateMethod(JCTree.JCMethodDecl template) { + boolean valid = true; + // TODO: support all Refaster method-level annotations + for (JCTree.JCAnnotation annotation : getTemplateAnnotations(template, UNSUPPORTED_ANNOTATIONS::contains)) { + processingEnv.getMessager().printMessage(Kind.NOTE, "The @" + annotation.annotationType + " is currently not supported", template.sym); + valid = false; + } + // TODO: support all Refaster parameter-level annotations + for (JCTree.JCVariableDecl parameter : template.getParameters()) { + for (JCTree.JCAnnotation annotation : getTemplateAnnotations(parameter, UNSUPPORTED_ANNOTATIONS::contains)) { + processingEnv.getMessager().printMessage(Kind.NOTE, "The @" + annotation.annotationType + " annotation is currently not supported", template.sym); + valid = false; + } + if (parameter.vartype instanceof ParameterizedTypeTree || parameter.vartype.type instanceof Type.TypeVar) { + processingEnv.getMessager().printMessage(Kind.NOTE, "Generics are currently not supported", template.sym); + valid = false; + } + } + if (template.restype instanceof ParameterizedTypeTree || template.restype.type instanceof Type.TypeVar) { + processingEnv.getMessager().printMessage(Kind.NOTE, "Generics are currently not supported", template.sym); + valid = false; + } + valid &= new TreeScanner() { + boolean valid = true; + + boolean validate(JCTree tree) { + scan(tree); + return valid; + } + + @Override + public void visitIdent(JCTree.JCIdent jcIdent) { + if (jcIdent.sym != null + && jcIdent.sym.packge().getQualifiedName().contentEquals("com.google.errorprone.refaster")) { + processingEnv.getMessager().printMessage(Kind.NOTE, jcIdent.type.tsym.getQualifiedName() + " is not supported", template.sym); + valid = false; + } + } + }.validate(template.getBody()); + return valid; + } + + public void beforeTemplate(JCTree.JCMethodDecl method) { + beforeTemplates.add(method); + } + + public void afterTemplate(JCTree.JCMethodDecl method) { + afterTemplate = method; + } + + private boolean resolve(Context context, JCCompilationUnit cu) { + try { + JavacResolution res = new JavacResolution(context); + beforeTemplates.replaceAll(key -> { + Map resolved = res.resolveAll(context, cu, singletonList(key)); + return (JCTree.JCMethodDecl) resolved.get(key); + }); + Map resolved = res.resolveAll(context, cu, singletonList(afterTemplate)); + afterTemplate = (JCTree.JCMethodDecl) resolved.get(afterTemplate); + } catch (Throwable t) { + processingEnv.getMessager().printMessage(Kind.WARNING, "Had trouble type attributing the template."); + return false; + } + return true; + } + + } + + private static List getTemplateAnnotations(MethodTree method, Predicate typePredicate) { + List result = new ArrayList<>(); + for (AnnotationTree annotation : method.getModifiers().getAnnotations()) { + Tree type = annotation.getAnnotationType(); + if (type.getKind() == Tree.Kind.IDENTIFIER && ((JCTree.JCIdent) type).sym != null + && typePredicate.test(((JCTree.JCIdent) type).sym.getQualifiedName().toString())) { + result.add((JCTree.JCAnnotation) annotation); + } else if (type.getKind() == Tree.Kind.IDENTIFIER && ((JCTree.JCAnnotation) annotation).attribute != null + && ((JCTree.JCAnnotation) annotation).attribute.type instanceof Type.ClassType + && ((JCTree.JCAnnotation) annotation).attribute.type.tsym != null + && typePredicate.test(((JCTree.JCAnnotation) annotation).attribute.type.tsym.getQualifiedName().toString())) { + result.add((JCTree.JCAnnotation) annotation); + } else if (type.getKind() == Tree.Kind.MEMBER_SELECT && type instanceof JCTree.JCFieldAccess + && ((JCTree.JCFieldAccess) type).sym != null + && typePredicate.test(((JCTree.JCFieldAccess) type).sym.getQualifiedName().toString())) { + result.add((JCTree.JCAnnotation) annotation); + } + } + return result; + } + + private static List getTemplateAnnotations(VariableTree parameter, Predicate typePredicate) { + List result = new ArrayList<>(); + for (AnnotationTree annotation : parameter.getModifiers().getAnnotations()) { + Tree type = annotation.getAnnotationType(); + if (type.getKind() == Tree.Kind.IDENTIFIER + && ((JCTree.JCIdent) type).sym != null + && typePredicate.test(((JCTree.JCIdent) type).sym.getQualifiedName().toString())) { + result.add((JCTree.JCAnnotation) annotation); + } else if (type.getKind() == Tree.Kind.MEMBER_SELECT && type instanceof JCTree.JCFieldAccess + && ((JCTree.JCFieldAccess) type).sym != null + && typePredicate.test(((JCTree.JCFieldAccess) type).sym.getQualifiedName().toString())) { + result.add((JCTree.JCAnnotation) annotation); + } + } + return result; + } +} diff --git a/src/main/java/org/openrewrite/java/template/processor/TemplateProcessor.java b/src/main/java/org/openrewrite/java/template/processor/TemplateProcessor.java new file mode 100644 index 00000000..94fa2aaa --- /dev/null +++ b/src/main/java/org/openrewrite/java/template/processor/TemplateProcessor.java @@ -0,0 +1,281 @@ +/* + * Copyright 2022 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.template.processor; + +import com.sun.source.tree.Tree; +import com.sun.source.tree.VariableTree; +import com.sun.source.util.TreePathScanner; +import com.sun.tools.javac.code.Symbol; +import com.sun.tools.javac.code.Type; +import com.sun.tools.javac.tree.JCTree; +import com.sun.tools.javac.tree.JCTree.JCCompilationUnit; +import com.sun.tools.javac.tree.TreeScanner; +import com.sun.tools.javac.util.Context; +import org.openrewrite.java.template.internal.ClasspathJarNameDetector; +import org.openrewrite.java.template.internal.ImportDetector; +import org.openrewrite.java.template.internal.JavacResolution; + +import javax.annotation.processing.RoundEnvironment; +import javax.annotation.processing.SupportedAnnotationTypes; +import javax.lang.model.SourceVersion; +import javax.lang.model.element.Element; +import javax.lang.model.element.TypeElement; +import javax.tools.Diagnostic.Kind; +import javax.tools.JavaFileObject; +import java.io.*; +import java.util.*; +import java.util.concurrent.atomic.AtomicReference; + +import static java.util.Collections.*; + +/** + * For steps to debug this annotation processor, see + * this blog post. + */ +@SupportedAnnotationTypes("*") +public class TemplateProcessor extends TypeAwareProcessor { + private static final String PRIMITIVE_ANNOTATION = "org.openrewrite.java.template.Primitive"; + private static final Map PRIMITIVE_TYPE_MAP = new HashMap<>(); + + static { + PRIMITIVE_TYPE_MAP.put(Boolean.class.getName(), boolean.class.getName()); + PRIMITIVE_TYPE_MAP.put(Byte.class.getName(), byte.class.getName()); + PRIMITIVE_TYPE_MAP.put(Character.class.getName(), char.class.getName()); + PRIMITIVE_TYPE_MAP.put(Short.class.getName(), short.class.getName()); + PRIMITIVE_TYPE_MAP.put(Integer.class.getName(), int.class.getName()); + PRIMITIVE_TYPE_MAP.put(Long.class.getName(), long.class.getName()); + PRIMITIVE_TYPE_MAP.put(Float.class.getName(), float.class.getName()); + PRIMITIVE_TYPE_MAP.put(Double.class.getName(), double.class.getName()); + PRIMITIVE_TYPE_MAP.put(Void.class.getName(), void.class.getName()); + } + + private final String javaFileContent; + + public TemplateProcessor(String javaFileContent) { + this.javaFileContent = javaFileContent; + } + + public TemplateProcessor() { + this(null); + } + + @Override + public boolean process(Set annotations, RoundEnvironment roundEnv) { + for (Element element : roundEnv.getRootElements()) { + JCCompilationUnit jcCompilationUnit = toUnit(element); + if (jcCompilationUnit != null) { + maybeGenerateTemplateSources(jcCompilationUnit); + } + } + + return true; + } + + void maybeGenerateTemplateSources(JCCompilationUnit cu) { + Context context = javacProcessingEnv.getContext(); + JavacResolution res = new JavacResolution(context); + + new TreeScanner() { + @Override + public void visitApply(JCTree.JCMethodInvocation tree) { + JCTree.JCExpression jcSelect = tree.getMethodSelect(); + String name = jcSelect instanceof JCTree.JCFieldAccess ? + ((JCTree.JCFieldAccess) jcSelect).name.toString() : + ((JCTree.JCIdent) jcSelect).getName().toString(); + + if (("expression".equals(name) || "statement".equals(name)) && tree.getArguments().size() == 3) { + JCTree.JCMethodInvocation resolvedMethod; + Map resolved; + try { + resolved = res.resolveAll(context, cu, singletonList(tree)); + resolvedMethod = (JCTree.JCMethodInvocation) resolved.get(tree); + } catch (Throwable t) { + processingEnv.getMessager().printMessage(Kind.WARNING, "Had trouble type attributing the template."); + return; + } + + JCTree.JCExpression arg2 = tree.getArguments().get(2); + if (isOfClassType(resolvedMethod.type, "org.openrewrite.java.JavaTemplate.Builder") && + (arg2 instanceof JCTree.JCLambda || arg2 instanceof JCTree.JCTypeCast && ((JCTree.JCTypeCast) arg2).getExpression() instanceof JCTree.JCLambda)) { + + JCTree.JCLambda template = arg2 instanceof JCTree.JCLambda ? (JCTree.JCLambda) arg2 : (JCTree.JCLambda) ((JCTree.JCTypeCast) arg2).getExpression(); + + NavigableMap parameterPositions; + List parameters; + if (template.getParameters().isEmpty()) { + parameterPositions = emptyNavigableMap(); + parameters = emptyList(); + } else { + parameterPositions = new TreeMap<>(); + Map parameterResolution = res.resolveAll(context, cu, template.getParameters()); + parameters = new ArrayList<>(template.getParameters().size()); + for (VariableTree p : template.getParameters()) { + parameters.add((JCTree.JCVariableDecl) parameterResolution.get((JCTree) p)); + } + JCTree.JCLambda resolvedTemplate = (JCTree.JCLambda) parameterResolution.get(template); + + new TreeScanner() { + @Override + public void visitIdent(JCTree.JCIdent ident) { + for (JCTree.JCVariableDecl parameter : parameters) { + if (parameter.sym == ident.sym) { + parameterPositions.put(ident.getStartPosition(), parameter); + } + } + } + }.scan(resolvedTemplate.getBody()); + } + + try (InputStream inputStream = javaFileContent == null ? + cu.getSourceFile().openInputStream() : new ByteArrayInputStream(javaFileContent.getBytes())) { + //noinspection ResultOfMethodCallIgnored + inputStream.skip(template.getBody().getStartPosition()); + + byte[] templateSourceBytes = new byte[template.getBody().getEndPosition(cu.endPositions) - template.getBody().getStartPosition()]; + + //noinspection ResultOfMethodCallIgnored + inputStream.read(templateSourceBytes); + + String templateSource = new String(templateSourceBytes); + templateSource = templateSource.replace("\"", "\\\""); + + for (Map.Entry paramPos : parameterPositions.descendingMap().entrySet()) { + JCTree.JCVariableDecl param = paramPos.getValue(); + String type = param.type.toString(); + for (JCTree.JCAnnotation annotation : param.getModifiers().getAnnotations()) { + if (annotation.type.tsym.getQualifiedName().contentEquals(PRIMITIVE_ANNOTATION)) { + type = PRIMITIVE_TYPE_MAP.get(param.type.toString()); + // don't generate the annotation into the source code + param.mods.annotations = com.sun.tools.javac.util.List.filter(param.mods.annotations, annotation); + } + } + templateSource = templateSource.substring(0, paramPos.getKey() - template.getBody().getStartPosition()) + + "#{any(" + type + ")}" + + templateSource.substring((paramPos.getKey() - template.getBody().getStartPosition()) + + param.name.length()); + } + + JCTree.JCLiteral templateName = (JCTree.JCLiteral) tree.getArguments().get(1); + if (templateName.value == null) { + processingEnv.getMessager().printMessage(Kind.WARNING, "Can't compile a template with a null name."); + return; + } + + // this could be a visitor in the case that the visitor is in its own file or + // named inner class, or a recipe if the visitor is defined in an anonymous class + JCTree.JCClassDecl classDecl = cursor(cu, template).stream() + .filter(JCTree.JCClassDecl.class::isInstance) + .map(JCTree.JCClassDecl.class::cast) + .reduce((next, acc) -> next) + .orElseThrow(() -> new IllegalStateException("Expected to find an enclosing class")); + + String templateFqn; + + if (isOfClassType(classDecl.type, "org.openrewrite.java.JavaVisitor")) { + templateFqn = classDecl.sym.fullname.toString() + "_" + templateName.getValue().toString(); + } else { + JCTree.JCNewClass visitorClass = cursor(cu, template).stream() + .filter(JCTree.JCNewClass.class::isInstance) + .map(JCTree.JCNewClass.class::cast) + .reduce((next, acc) -> next) + .orElse(null); + + JCTree.JCNewClass resolvedVisitorClass = (JCTree.JCNewClass) resolved.get(visitorClass); + + if (resolvedVisitorClass != null && isOfClassType(resolvedVisitorClass.clazz.type, "org.openrewrite.java.JavaVisitor")) { + templateFqn = ((Symbol.ClassSymbol) resolvedVisitorClass.type.tsym).flatname.toString() + "_" + + templateName.getValue().toString(); + } else { + processingEnv.getMessager().printMessage(Kind.WARNING, "Can't compile a template outside of a visitor or recipe."); + return; + } + } + + JavaFileObject builderFile = processingEnv.getFiler().createSourceFile(templateFqn); + try (Writer out = new BufferedWriter(builderFile.openWriter())) { + out.write("package " + classDecl.sym.packge().toString() + ";\n"); + out.write("import org.openrewrite.java.*;\n"); + + + for (JCTree.JCVariableDecl parameter : parameters) { + if (parameter.type.tsym instanceof Symbol.ClassSymbol) { + String paramType = parameter.type.tsym.getQualifiedName().toString(); + if (!paramType.startsWith("java.lang")) { + out.write("import " + paramType + ";\n"); + } + } + } + + out.write("\n"); + out.write("public class " + templateFqn.substring(templateFqn.lastIndexOf('.') + 1) + " {\n"); + out.write(" public static JavaTemplate.Builder getTemplate(JavaVisitor visitor) {\n"); + out.write(" return JavaTemplate\n"); + out.write(" .builder(\"" + templateSource + "\")"); + + List imports = ImportDetector.imports(resolved.get(template)); + String classpath = ClasspathJarNameDetector.classpathFor(resolved.get(template), imports); + if (!classpath.isEmpty()) { + out.write("\n .javaParser(JavaParser.fromJavaVersion().classpath(" + + classpath + "))"); + } + + for (Symbol anImport : imports) { + if (anImport instanceof Symbol.ClassSymbol && !anImport.getQualifiedName().toString().startsWith("java.lang.")) { + out.write("\n .imports(\"" + ((Symbol.ClassSymbol) anImport).fullname.toString().replace('$', '.') + "\")"); + } else if (anImport instanceof Symbol.VarSymbol || anImport instanceof Symbol.MethodSymbol) { + out.write("\n .staticImports(\"" + anImport.owner.getQualifiedName().toString().replace('$', '.') + '.' + anImport.flatName().toString() + "\")"); + } + } + + out.write(";\n"); + out.write(" }\n"); + out.write("}\n"); + out.flush(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + + super.visitApply(tree); + } + }.scan(cu); + } + + private boolean isOfClassType(Type type, String fqn) { + return type instanceof Type.ClassType && (((Symbol.ClassSymbol) type.tsym) + .fullname.contentEquals(fqn) || isOfClassType(((Type.ClassType) type).supertype_field, fqn)); + } + + private Stack cursor(JCCompilationUnit cu, Tree t) { + AtomicReference> matching = new AtomicReference<>(); + new TreePathScanner, Stack>() { + @Override + public Stack scan(Tree tree, Stack parent) { + Stack cursor = new Stack<>(); + cursor.addAll(parent); + cursor.push(tree); + if (tree == t) { + matching.set(cursor); + return cursor; + } + return super.scan(tree, cursor); + } + }.scan(cu, new Stack<>()); + return matching.get(); + } +} diff --git a/src/main/java/org/openrewrite/java/template/processor/TypeAwareProcessor.java b/src/main/java/org/openrewrite/java/template/processor/TypeAwareProcessor.java new file mode 100644 index 00000000..1ff11433 --- /dev/null +++ b/src/main/java/org/openrewrite/java/template/processor/TypeAwareProcessor.java @@ -0,0 +1,211 @@ +package org.openrewrite.java.template.processor; + +import com.sun.source.util.TreePath; +import com.sun.source.util.Trees; +import com.sun.tools.javac.processing.JavacProcessingEnvironment; +import com.sun.tools.javac.tree.JCTree; +import org.openrewrite.java.template.internal.Permit; +import org.openrewrite.java.template.internal.permit.Parent; +import sun.misc.Unsafe; + +import javax.annotation.processing.AbstractProcessor; +import javax.annotation.processing.ProcessingEnvironment; +import javax.lang.model.SourceVersion; +import javax.lang.model.element.Element; +import javax.tools.Diagnostic; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; + +public abstract class TypeAwareProcessor extends AbstractProcessor { + protected ProcessingEnvironment processingEnv; + protected JavacProcessingEnvironment javacProcessingEnv; + protected Trees trees; + + /** + * We just return the latest version of whatever JDK we run on. Stupid? Yeah, but it's either that + * or warnings on all versions but 1. + */ + @Override + public SourceVersion getSupportedSourceVersion() { + return SourceVersion.latest(); + } + + @Override + public synchronized void init(ProcessingEnvironment processingEnv) { + super.init(processingEnv); + this.processingEnv = processingEnv; + this.javacProcessingEnv = getJavacProcessingEnvironment(processingEnv); + if (javacProcessingEnv == null) { + return; + } + trees = Trees.instance(javacProcessingEnv); + } + + protected JCTree.JCCompilationUnit toUnit(Element element) { + TreePath path = null; + if (trees != null) { + try { + path = trees.getPath(element); + } catch (NullPointerException ignore) { + // Happens if a package-info.java doesn't contain a package declaration. + // We can safely ignore those, since they do not need any processing + } + } + if (path == null) { + return null; + } + + return (JCTree.JCCompilationUnit) path.getCompilationUnit(); + } + + /** + * This class casts the given processing environment to a JavacProcessingEnvironment. In case of + * gradle incremental compilation, the delegate ProcessingEnvironment of the gradle wrapper is returned. + */ + public JavacProcessingEnvironment getJavacProcessingEnvironment(Object procEnv) { + addOpens(); + if (procEnv instanceof JavacProcessingEnvironment) { + return (JavacProcessingEnvironment) procEnv; + } + + // try to find a "delegate" field in the object, and use this to try to obtain a JavacProcessingEnvironment + for (Class procEnvClass = procEnv.getClass(); procEnvClass != null; procEnvClass = procEnvClass.getSuperclass()) { + Object delegate = tryGetDelegateField(procEnvClass, procEnv); + if (delegate == null) { + delegate = tryGetProxyDelegateToField(procEnv); + } + if (delegate == null) { + delegate = tryGetProcessingEnvField(procEnvClass, procEnv); + } + + if (delegate != null) { + return getJavacProcessingEnvironment(delegate); + } + // delegate field was not found, try on superclass + } + + processingEnv.getMessager().printMessage(Diagnostic.Kind.WARNING, "Can't get the delegate of the gradle " + + "IncrementalProcessingEnvironment. " + + "OpenRewrite's template processor won't work."); + return null; + } + + @SuppressWarnings({"DataFlowIssue", "JavaReflectionInvocation"}) + protected static void addOpens() { + Class cModule; + try { + cModule = Class.forName("java.lang.Module"); + } catch (ClassNotFoundException e) { + return; //jdk8-; this is not needed. + } + + Unsafe unsafe = getUnsafe(); + Object jdkCompilerModule = getJdkCompilerModule(); + Object ownModule = getOwnModule(); + String[] allPkgs = { + "com.sun.tools.javac.code", + "com.sun.tools.javac.comp", + "com.sun.tools.javac.file", + "com.sun.tools.javac.main", + "com.sun.tools.javac.model", + "com.sun.tools.javac.parser", + "com.sun.tools.javac.processing", + "com.sun.tools.javac.tree", + "com.sun.tools.javac.util", + "com.sun.tools.javac.jvm", + }; + + try { + Method m = cModule.getDeclaredMethod("implAddOpens", String.class, cModule); + long firstFieldOffset = getFirstFieldOffset(unsafe); + unsafe.putBooleanVolatile(m, firstFieldOffset, true); + for (String p : allPkgs) m.invoke(jdkCompilerModule, p, ownModule); + } catch (Exception ignore) { + } + } + + protected static long getFirstFieldOffset(Unsafe unsafe) { + try { + return unsafe.objectFieldOffset(Parent.class.getDeclaredField("first")); + } catch (NoSuchFieldException e) { + // can't happen. + throw new RuntimeException(e); + } catch (SecurityException e) { + // can't happen + throw new RuntimeException(e); + } + } + + protected static Unsafe getUnsafe() { + try { + Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + return (Unsafe) theUnsafe.get(null); + } catch (Exception e) { + return null; + } + } + + protected static Object getOwnModule() { + try { + Method m = Permit.getMethod(Class.class, "getModule"); + return m.invoke(RefasterTemplateProcessor.class); + } catch (Exception e) { + return null; + } + } + + protected static Object getJdkCompilerModule() { + // call public api: ModuleLayer.boot().findModule("jdk.compiler").get(); + // but use reflection because we don't want this code to crash on jdk1.7 and below. + // In that case, none of this stuff was needed in the first place, so we just exit via + // the catch block and do nothing. + try { + Class cModuleLayer = Class.forName("java.lang.ModuleLayer"); + Method mBoot = cModuleLayer.getDeclaredMethod("boot"); + Object bootLayer = mBoot.invoke(null); + Class cOptional = Class.forName("java.util.Optional"); + Method mFindModule = cModuleLayer.getDeclaredMethod("findModule", String.class); + Object oCompilerO = mFindModule.invoke(bootLayer, "jdk.compiler"); + return cOptional.getDeclaredMethod("get").invoke(oCompilerO); + } catch (Exception e) { + return null; + } + } + + /** + * Gradle incremental processing + */ + protected Object tryGetDelegateField(Class delegateClass, Object instance) { + try { + return Permit.getField(delegateClass, "delegate").get(instance); + } catch (Exception e) { + return null; + } + } + + /** + * Kotlin incremental processing + */ + protected Object tryGetProcessingEnvField(Class delegateClass, Object instance) { + try { + return Permit.getField(delegateClass, "processingEnv").get(instance); + } catch (Exception e) { + return null; + } + } + + /** + * IntelliJ >= 2020.3 + */ + protected Object tryGetProxyDelegateToField(Object instance) { + try { + InvocationHandler handler = Proxy.getInvocationHandler(instance); + return Permit.getField(handler.getClass(), "val$delegateTo").get(handler); + } catch (Exception e) { + return null; + } + } +}