diff --git a/src/fixit/tests/rule.py b/src/fixit/tests/rule.py index 5da83531..61ede2c6 100644 --- a/src/fixit/tests/rule.py +++ b/src/fixit/tests/rule.py @@ -73,6 +73,13 @@ def visit_Module(self, node: cst.Module) -> bool: def visit_ClassDef(self, node: cst.ClassDef) -> bool: self.report(node, "class def") + for d in node.decorators: + self.report(d, "class decorator") + return False + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + if node.name.value == "problem": + self.report(node, "problem function") return False def visit_Pass(self, node: cst.Pass) -> bool: @@ -280,6 +287,119 @@ class Foo(object): "class def", (5, 0), ), + ( + # before function decorators + """ + import sys + + # lint-fixme: ExerciseReport + @contextmanager + def problem(): + yield True + """, + None, + None, + ), + ( + # after function decorators + """ + import sys + + @contextmanager + # lint-fixme: ExerciseReport + def problem(): + yield True + """, + None, + None, + ), + ( + # before class decorators + """ + import dataclasses + + # lint-fixme: ExerciseReport + @dataclasses.dataclass + class C: + value = 1 + """, + None, + None, + ), + ( + # after class decorators + """ + import dataclasses + + @dataclasses.dataclass + # lint-fixme: ExerciseReport + class C: + value = 1 + """, + None, + None, + ), + ( + # above comprehension + """ + # lint-fixme: ExerciseReport + [... for _ in range(1)] + """, + None, + None, + ), + ( + # inside comprehension + """ + [ + # lint-fixme: ExerciseReport + ... for _ in range(1) + ] + """, + None, + None, + ), + ( + # after comprehension + """ + [... for _ in range(1)] # lint-fixme: ExerciseReport + """, + None, + None, + ), + ( + # trailing inline comprehension + """ + [ + ... for _ in range(1) # lint-fixme: ExerciseReport + ] + """, + None, + None, + ), + ( + # before list element + """ + [ + # lint-fixme: ExerciseReport + ..., + None, + ] + """, + None, + None, + ), + ( + # trailing list element + """ + [ + ..., # lint-fixme: ExerciseReport + None, + ] + """, + None, + None, + ), ): idx += 1 content = dedent(code).encode("utf-8")