diff --git a/error-prone-contrib/src/main/java/tech/picnic/errorprone/bugpatterns/RequestParamType.java b/error-prone-contrib/src/main/java/tech/picnic/errorprone/bugpatterns/RequestParamType.java index eb076e38673..eada0abe492 100644 --- a/error-prone-contrib/src/main/java/tech/picnic/errorprone/bugpatterns/RequestParamType.java +++ b/error-prone-contrib/src/main/java/tech/picnic/errorprone/bugpatterns/RequestParamType.java @@ -1,5 +1,6 @@ package tech.picnic.errorprone.bugpatterns; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.errorprone.BugPattern.LinkType.CUSTOM; import static com.google.errorprone.BugPattern.SeverityLevel.ERROR; import static com.google.errorprone.BugPattern.StandardTags.LIKELY_ERROR; @@ -9,17 +10,21 @@ import static com.google.errorprone.matchers.Matchers.anyOf; import static com.google.errorprone.matchers.Matchers.isSubtypeOf; import static com.google.errorprone.matchers.Matchers.isType; +import static com.google.errorprone.matchers.Matchers.not; import static tech.picnic.errorprone.bugpatterns.util.Documentation.BUG_PATTERNS_BASE_URL; import com.google.auto.service.AutoService; import com.google.common.collect.ImmutableCollection; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.errorprone.BugPattern; +import com.google.errorprone.ErrorProneFlags; import com.google.errorprone.VisitorState; import com.google.errorprone.bugpatterns.BugChecker; import com.google.errorprone.bugpatterns.BugChecker.VariableTreeMatcher; import com.google.errorprone.matchers.Description; import com.google.errorprone.matchers.Matcher; +import com.sun.source.tree.Tree; import com.sun.source.tree.VariableTree; /** A {@link BugChecker} that flags {@code @RequestParam} parameters with an unsupported type. */ @@ -32,18 +37,58 @@ tags = LIKELY_ERROR) public final class RequestParamType extends BugChecker implements VariableTreeMatcher { private static final long serialVersionUID = 1L; - private static final Matcher HAS_UNSUPPORTED_REQUEST_PARAM = - allOf( - annotations(AT_LEAST_ONE, isType("org.springframework.web.bind.annotation.RequestParam")), - anyOf(isSubtypeOf(ImmutableCollection.class), isSubtypeOf(ImmutableMap.class))); + private static final String FLAG_PREFIX = "RequestParamType:"; + private static final String INCLUDED_CLASS_FLAG = FLAG_PREFIX + "Includes"; - /** Instantiates a new {@link RequestParamType} instance. */ - public RequestParamType() {} + private final Matcher hasUnsupportedRequestParams; + + /** Instantiates a default {@link RequestParamType} instance. */ + public RequestParamType() { + this(ErrorProneFlags.empty()); + } + + /** + * Instantiates a customized {@link RequestParamType} instance. + * + * @param flags Any provided command line flags. + */ + public RequestParamType(ErrorProneFlags flags) { + hasUnsupportedRequestParams = createVariableTreeMatcher(flags); + } @Override public Description matchVariable(VariableTree tree, VisitorState state) { - return HAS_UNSUPPORTED_REQUEST_PARAM.matches(tree, state) + return hasUnsupportedRequestParams.matches(tree, state) ? describeMatch(tree) : Description.NO_MATCH; } + + private static Matcher createVariableTreeMatcher(ErrorProneFlags flags) { + return allOf( + annotations(AT_LEAST_ONE, isType("org.springframework.web.bind.annotation.RequestParam")), + anyOf(isSubtypeOf(ImmutableCollection.class), isSubtypeOf(ImmutableMap.class)), + not(anyOf(getSupportedClasses(includedClassNames(flags))))); + } + + private static ImmutableList includedClassNames(ErrorProneFlags flags) { + return flags.getList(INCLUDED_CLASS_FLAG).map(ImmutableList::copyOf).orElse(ImmutableList.of()); + } + + private static ImmutableList> getSupportedClasses( + ImmutableList inclusions) { + return inclusions.stream() + .filter(inclusion -> !inclusion.isEmpty()) + .map(String::trim) + .map(inclusion -> isSubtypeOf(createClass(inclusion))) + .collect(toImmutableList()); + } + + private static Class createClass(String className) { + try { + return Class.forName(className); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException( + String.format("Invalid class name '%s' in `%s`", className, INCLUDED_CLASS_FLAG), e); + } + } } diff --git a/error-prone-contrib/src/test/java/tech/picnic/errorprone/bugpatterns/RequestParamTypeTest.java b/error-prone-contrib/src/test/java/tech/picnic/errorprone/bugpatterns/RequestParamTypeTest.java index 45d58534c67..cb268714985 100644 --- a/error-prone-contrib/src/test/java/tech/picnic/errorprone/bugpatterns/RequestParamTypeTest.java +++ b/error-prone-contrib/src/test/java/tech/picnic/errorprone/bugpatterns/RequestParamTypeTest.java @@ -1,11 +1,22 @@ package tech.picnic.errorprone.bugpatterns; +import com.google.common.collect.ImmutableList; import com.google.errorprone.CompilationTestHelper; import org.junit.jupiter.api.Test; final class RequestParamTypeTest { private final CompilationTestHelper compilationTestHelper = CompilationTestHelper.newInstance(RequestParamType.class, getClass()); + private final CompilationTestHelper restrictedCompilationTestHelper = + CompilationTestHelper.newInstance(RequestParamType.class, getClass()) + .setArgs( + ImmutableList.of( + "-XepOpt:RequestParamType:Includes=com.google.common.collect.ImmutableCollection")); + private final CompilationTestHelper restrictedWithSubTypeCompilationTestHelper = + CompilationTestHelper.newInstance(RequestParamType.class, getClass()) + .setArgs( + ImmutableList.of( + "-XepOpt:RequestParamType:Includes=com.google.common.collect.ImmutableSet")); @Test void identification() { @@ -63,4 +74,97 @@ void identification() { "}") .doTest(); } + + @Test + void identificationOfIncludedClass() { + restrictedCompilationTestHelper + .addSourceLines( + "A.java", + "import com.google.common.collect.ImmutableBiMap;", + "import com.google.common.collect.ImmutableList;", + "import com.google.common.collect.ImmutableMap;", + "import com.google.common.collect.ImmutableSet;", + "import java.util.List;", + "import java.util.Map;", + "import java.util.Set;", + "import org.jspecify.annotations.Nullable;", + "import org.springframework.web.bind.annotation.DeleteMapping;", + "import org.springframework.web.bind.annotation.GetMapping;", + "import org.springframework.web.bind.annotation.PostMapping;", + "import org.springframework.web.bind.annotation.PutMapping;", + "import org.springframework.web.bind.annotation.RequestBody;", + "import org.springframework.web.bind.annotation.RequestParam;", + "", + "interface A {", + " @PostMapping", + " A properRequestParam(@RequestBody String body);", + "", + " @GetMapping", + " A properRequestParam(@RequestParam int param);", + "", + " @GetMapping", + " A properRequestParam(@RequestParam List param);", + "", + " @PostMapping", + " A properRequestParam(@RequestBody String body, @RequestParam Set param);", + "", + " @PutMapping", + " A properRequestParam(@RequestBody String body, @RequestParam Map param);", + "", + " @GetMapping", + " // BUG: Diagnostic contains:", + " A get(@RequestParam ImmutableBiMap param);", + "", + " @PostMapping", + " A post(@Nullable @RequestParam ImmutableList param);", + "", + " @PutMapping", + " A put(@RequestBody String body, @RequestParam ImmutableSet param);", + "", + " @DeleteMapping", + " // BUG: Diagnostic contains:", + " A delete(@RequestBody String body, @RequestParam ImmutableMap param);", + "", + " void negative(ImmutableSet set, ImmutableMap map);", + "}") + .doTest(); + } + + @Test + void identificationOfIncludedSubClass() { + restrictedWithSubTypeCompilationTestHelper + .addSourceLines( + "A.java", + "import com.google.common.collect.ImmutableBiMap;", + "import com.google.common.collect.ImmutableList;", + "import com.google.common.collect.ImmutableMap;", + "import com.google.common.collect.ImmutableSet;", + "import org.jspecify.annotations.Nullable;", + "import org.springframework.web.bind.annotation.DeleteMapping;", + "import org.springframework.web.bind.annotation.GetMapping;", + "import org.springframework.web.bind.annotation.PostMapping;", + "import org.springframework.web.bind.annotation.PutMapping;", + "import org.springframework.web.bind.annotation.RequestBody;", + "import org.springframework.web.bind.annotation.RequestParam;", + "", + "interface A {", + " @GetMapping", + " // BUG: Diagnostic contains:", + " A get(@RequestParam ImmutableBiMap param);", + "", + " @PostMapping", + " // BUG: Diagnostic contains:", + " A post(@Nullable @RequestParam ImmutableList param);", + "", + " @PutMapping", + " A put(@RequestBody String body, @RequestParam ImmutableSet param);", + "", + " @DeleteMapping", + " // BUG: Diagnostic contains:", + " A delete(@RequestBody String body, @RequestParam ImmutableMap param);", + "", + " void negative(ImmutableSet set, ImmutableMap map);", + "}") + .doTest(); + } }