diff --git a/error-prone-contrib/src/main/java/tech/picnic/errorprone/bugpatterns/PrimitiveComparisonCheck.java b/error-prone-contrib/src/main/java/tech/picnic/errorprone/bugpatterns/PrimitiveComparisonCheck.java index 8cbb99f53a..3750631c4c 100644 --- a/error-prone-contrib/src/main/java/tech/picnic/errorprone/bugpatterns/PrimitiveComparisonCheck.java +++ b/error-prone-contrib/src/main/java/tech/picnic/errorprone/bugpatterns/PrimitiveComparisonCheck.java @@ -78,12 +78,12 @@ public Description matchMethodInvocation(MethodInvocationTree tree, VisitorState } return getPotentiallyBoxedReturnType(tree.getArguments().get(0)) - .flatMap(cmpType -> tryMakeMethodCallMorePrecise(tree, cmpType, isStatic, state)) + .flatMap(cmpType -> attemptMethodInvocationReplacement(tree, cmpType, isStatic, state)) .map(fix -> describeMatch(tree, fix)) .orElse(Description.NO_MATCH); } - private static Optional tryMakeMethodCallMorePrecise( + private static Optional attemptMethodInvocationReplacement( MethodInvocationTree tree, Type cmpType, boolean isStatic, VisitorState state) { return Optional.ofNullable(ASTHelpers.getSymbol(tree)) .map(methodSymbol -> methodSymbol.getSimpleName().toString()) @@ -93,26 +93,35 @@ private static Optional tryMakeMethodCallMorePrecise( .filter(not(actualMethodName::equals))) .map( preferredMethodName -> - mayPrefixWithTypeArguments(preferredMethodName, tree, cmpType, state)) + prefixTypeArgumentsIfRelevant(preferredMethodName, tree, cmpType, state)) .map(preferredMethodName -> suggestFix(tree, preferredMethodName, state)); } - private static String mayPrefixWithTypeArguments( + /** + * Prefixes the given method name with generic type parameters if it replaces a {@code + * Comparator#comparing{,Double,Long,Int}} method which also has generic type parameters. + * + *

Such type parameters are retained as they are likely required. + * + *

Note that any type parameter to {@code Comparator#thenComparing} is likely redundant, and in + * any case becomes obsolete once that method is replaced with {@code + * Comparator#thenComparing{Double,Long,Int}}. Conversion in the opposite direction does not + * require the introduction of a generic type parameter. + */ + private static String prefixTypeArgumentsIfRelevant( String preferredMethodName, MethodInvocationTree tree, Type cmpType, VisitorState state) { - int typeArgumentsCount = tree.getTypeArguments().size(); - boolean methodNameIsComparing = "comparing".equals(preferredMethodName); - - if (typeArgumentsCount == 0 || (typeArgumentsCount == 1 && !methodNameIsComparing)) { + if (tree.getTypeArguments().isEmpty() || preferredMethodName.startsWith("then")) { return preferredMethodName; } - String typeArgument = + String typeArguments = Stream.concat( Stream.of(Util.treeToString(tree.getTypeArguments().get(0), state)), - Stream.of(cmpType.tsym.getSimpleName()).filter(u -> methodNameIsComparing)) - .collect(joining(",")); + Stream.of(cmpType.tsym.getSimpleName()) + .filter(u -> "comparing".equals(preferredMethodName))) + .collect(joining(", ", "<", ">")); - return String.format("<%s>%s", typeArgument, preferredMethodName); + return typeArguments + preferredMethodName; } private static String getPreferredMethod(Type cmpType, boolean isStatic, VisitorState state) { diff --git a/error-prone-contrib/src/test/java/tech/picnic/errorprone/bugpatterns/PrimitiveComparisonCheckTest.java b/error-prone-contrib/src/test/java/tech/picnic/errorprone/bugpatterns/PrimitiveComparisonCheckTest.java index dfbd5ded15..90e9cf9727 100644 --- a/error-prone-contrib/src/test/java/tech/picnic/errorprone/bugpatterns/PrimitiveComparisonCheckTest.java +++ b/error-prone-contrib/src/test/java/tech/picnic/errorprone/bugpatterns/PrimitiveComparisonCheckTest.java @@ -434,21 +434,35 @@ void replacementWithPrimitiveVariants() { "import java.util.Comparator;", "", "interface A extends Comparable {", - " Comparator bCmp = Comparator.comparing(o -> (byte) 0);", - " Comparator cCmp = Comparator.comparing(o -> (char) 0);", - " Comparator sCmp = Comparator.comparing(o -> (short) 0);", - " Comparator iCmp = Comparator.comparing(o -> 0);", - " Comparator lCmp = Comparator.comparing(o -> 0L);", - " Comparator fCmp = Comparator.comparing(o -> 0.0f);", - " Comparator dCmp = Comparator.comparing(o -> 0.0);", + " Comparator bCmp = Comparator.comparing(o -> (byte) 0);", + " Comparator bCmp2 = Comparator.comparing(o -> (byte) 0);", + " Comparator cCmp = Comparator.comparing(o -> (char) 0);", + " Comparator cCmp2 = Comparator.comparing(o -> (char) 0);", + " Comparator sCmp = Comparator.comparing(o -> (short) 0);", + " Comparator sCmp2 = Comparator.comparing(o -> (short) 0);", + " Comparator iCmp = Comparator.comparing(o -> 0);", + " Comparator iCmp2 = Comparator.comparing(o -> 0);", + " Comparator lCmp = Comparator.comparing(o -> 0L);", + " Comparator lCmp2 = Comparator.comparing(o -> 0L);", + " Comparator fCmp = Comparator.comparing(o -> 0.0f);", + " Comparator fCmp2 = Comparator.comparing(o -> 0.0f);", + " Comparator dCmp = Comparator.comparing(o -> 0.0);", + " Comparator dCmp2 = Comparator.comparing(o -> 0.0);", "", " default void m() {", + " bCmp.thenComparing(o -> (byte) 0);", " bCmp.thenComparing(o -> (byte) 0);", + " cCmp.thenComparing(o -> (char) 0);", " cCmp.thenComparing(o -> (char) 0);", + " sCmp.thenComparing(o -> (short) 0);", " sCmp.thenComparing(o -> (short) 0);", + " iCmp.thenComparing(o -> 0);", " iCmp.thenComparing(o -> 0);", + " lCmp.thenComparing(o -> 0L);", " lCmp.thenComparing(o -> 0L);", + " fCmp.thenComparing(o -> 0.0f);", " fCmp.thenComparing(o -> 0.0f);", + " dCmp.thenComparing(o -> 0.0);", " dCmp.thenComparing(o -> 0.0);", " }", "}") @@ -457,22 +471,36 @@ void replacementWithPrimitiveVariants() { "import java.util.Comparator;", "", "interface A extends Comparable {", - " Comparator bCmp = Comparator.comparingInt(o -> (byte) 0);", - " Comparator cCmp = Comparator.comparingInt(o -> (char) 0);", - " Comparator sCmp = Comparator.comparingInt(o -> (short) 0);", - " Comparator iCmp = Comparator.comparingInt(o -> 0);", - " Comparator lCmp = Comparator.comparingLong(o -> 0L);", - " Comparator fCmp = Comparator.comparingDouble(o -> 0.0f);", - " Comparator dCmp = Comparator.comparingDouble(o -> 0.0);", + " Comparator bCmp = Comparator.comparingInt(o -> (byte) 0);", + " Comparator bCmp2 = Comparator.comparingInt(o -> (byte) 0);", + " Comparator cCmp = Comparator.comparingInt(o -> (char) 0);", + " Comparator cCmp2 = Comparator.comparingInt(o -> (char) 0);", + " Comparator sCmp = Comparator.comparingInt(o -> (short) 0);", + " Comparator sCmp2 = Comparator.comparingInt(o -> (short) 0);", + " Comparator iCmp = Comparator.comparingInt(o -> 0);", + " Comparator iCmp2 = Comparator.comparingInt(o -> 0);", + " Comparator lCmp = Comparator.comparingLong(o -> 0L);", + " Comparator lCmp2 = Comparator.comparingLong(o -> 0L);", + " Comparator fCmp = Comparator.comparingDouble(o -> 0.0f);", + " Comparator fCmp2 = Comparator.comparingDouble(o -> 0.0f);", + " Comparator dCmp = Comparator.comparingDouble(o -> 0.0);", + " Comparator dCmp2 = Comparator.comparingDouble(o -> 0.0);", "", " default void m() {", " bCmp.thenComparingInt(o -> (byte) 0);", + " bCmp.thenComparingInt(o -> (byte) 0);", + " cCmp.thenComparingInt(o -> (char) 0);", " cCmp.thenComparingInt(o -> (char) 0);", " sCmp.thenComparingInt(o -> (short) 0);", + " sCmp.thenComparingInt(o -> (short) 0);", + " iCmp.thenComparingInt(o -> 0);", " iCmp.thenComparingInt(o -> 0);", " lCmp.thenComparingLong(o -> 0L);", + " lCmp.thenComparingLong(o -> 0L);", + " fCmp.thenComparingDouble(o -> 0.0f);", " fCmp.thenComparingDouble(o -> 0.0f);", " dCmp.thenComparingDouble(o -> 0.0);", + " dCmp.thenComparingDouble(o -> 0.0);", " }", "}") .doTest(TestMode.TEXT_MATCH); @@ -486,13 +514,20 @@ void replacementWithBoxedVariants() { "import java.util.Comparator;", "", "interface A extends Comparable {", - " Comparator bCmp = Comparator.comparingInt(o -> Byte.valueOf((byte) 0));", - " Comparator cCmp = Comparator.comparingInt(o -> Character.valueOf((char) 0));", - " Comparator sCmp = Comparator.comparingInt(o -> Short.valueOf((short) 0));", - " Comparator iCmp = Comparator.comparingInt(o -> Integer.valueOf(0));", - " Comparator lCmp = Comparator.comparingLong(o -> Long.valueOf(0));", - " Comparator fCmp = Comparator.comparingDouble(o -> Float.valueOf(0));", - " Comparator dCmp = Comparator.comparingDouble(o -> Double.valueOf(0));", + " Comparator bCmp = Comparator.comparingInt(o -> Byte.valueOf((byte) 0));", + " Comparator bCmp2 = Comparator.comparingInt(o -> Byte.valueOf((byte) 0));", + " Comparator cCmp = Comparator.comparingInt(o -> Character.valueOf((char) 0));", + " Comparator cCmp2 = Comparator.comparingInt(o -> Character.valueOf((char) 0));", + " Comparator sCmp = Comparator.comparingInt(o -> Short.valueOf((short) 0));", + " Comparator sCmp2 = Comparator.comparingInt(o -> Short.valueOf((short) 0));", + " Comparator iCmp = Comparator.comparingInt(o -> Integer.valueOf(0));", + " Comparator iCmp2 = Comparator.comparingInt(o -> Integer.valueOf(0));", + " Comparator lCmp = Comparator.comparingLong(o -> Long.valueOf(0));", + " Comparator lCmp2 = Comparator.comparingLong(o -> Long.valueOf(0));", + " Comparator fCmp = Comparator.comparingDouble(o -> Float.valueOf(0));", + " Comparator fCmp2 = Comparator.comparingDouble(o -> Float.valueOf(0));", + " Comparator dCmp = Comparator.comparingDouble(o -> Double.valueOf(0));", + " Comparator dCmp2 = Comparator.comparingDouble(o -> Double.valueOf(0));", "", " default void m() {", " bCmp.thenComparingInt(o -> Byte.valueOf((byte) 0));", @@ -509,13 +544,20 @@ void replacementWithBoxedVariants() { "import java.util.Comparator;", "", "interface A extends Comparable {", - " Comparator bCmp = Comparator.comparing(o -> Byte.valueOf((byte) 0));", - " Comparator cCmp = Comparator.comparing(o -> Character.valueOf((char) 0));", - " Comparator sCmp = Comparator.comparing(o -> Short.valueOf((short) 0));", - " Comparator iCmp = Comparator.comparing(o -> Integer.valueOf(0));", - " Comparator lCmp = Comparator.comparing(o -> Long.valueOf(0));", - " Comparator fCmp = Comparator.comparing(o -> Float.valueOf(0));", - " Comparator dCmp = Comparator.comparing(o -> Double.valueOf(0));", + " Comparator bCmp = Comparator.comparing(o -> Byte.valueOf((byte) 0));", + " Comparator bCmp2 = Comparator.comparing(o -> Byte.valueOf((byte) 0));", + " Comparator cCmp = Comparator.comparing(o -> Character.valueOf((char) 0));", + " Comparator cCmp2 = Comparator.comparing(o -> Character.valueOf((char) 0));", + " Comparator sCmp = Comparator.comparing(o -> Short.valueOf((short) 0));", + " Comparator sCmp2 = Comparator.comparing(o -> Short.valueOf((short) 0));", + " Comparator iCmp = Comparator.comparing(o -> Integer.valueOf(0));", + " Comparator iCmp2 = Comparator.comparing(o -> Integer.valueOf(0));", + " Comparator lCmp = Comparator.comparing(o -> Long.valueOf(0));", + " Comparator lCmp2 = Comparator.comparing(o -> Long.valueOf(0));", + " Comparator fCmp = Comparator.comparing(o -> Float.valueOf(0));", + " Comparator fCmp2 = Comparator.comparing(o -> Float.valueOf(0));", + " Comparator dCmp = Comparator.comparing(o -> Double.valueOf(0));", + " Comparator dCmp2 = Comparator.comparing(o -> Double.valueOf(0));", "", " default void m() {", " bCmp.thenComparing(o -> Byte.valueOf((byte) 0));",