diff --git a/flake8_builtins.py b/flake8_builtins.py index e95d1db..2afc5d7 100644 --- a/flake8_builtins.py +++ b/flake8_builtins.py @@ -140,7 +140,9 @@ def run(self): def check_assignment(self, statement): msg = self.assign_msg - if type(statement.__flake8_builtins_parent) is ast.ClassDef: + if isinstance(statement.__flake8_builtins_parent, ast.ClassDef): + if "TypedDict" in {a.id for a in statement.__flake8_builtins_parent.bases if isinstance(a, ast.Name)}: + return msg = self.class_attribute_msg if isinstance(statement, ast.Assign): diff --git a/run_tests.py b/run_tests.py index 26064e0..26312a0 100644 --- a/run_tests.py +++ b/run_tests.py @@ -529,3 +529,77 @@ def test_module_name_ignore_module(): def test_module_name_not_builtin(): source = '' check_code(source, filename='log_config') + + + +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason='Typed dicts are introduced in Python 3.9', +) +def test_typed_dicts_dont_shadow_builtins(): + source = """ + from typing import Any, TypedDict + + + class PaginatedResponse(TypedDict): + count: int + next: str + prev: str + data: list[Any] + """ + check_code(source) + + +@pytest.mark.xfail(reason="N-deep inheritence not supported") +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason='Typed dicts are introduced in Python 3.9', +) +def test_inherited_typed_dicts_dont_shadow_builtins(): + source = """ + from typing import Any, TypedDict + + + class ApiResponseDict(TypedDict): + status: int + okay: bool + + + class PaginatedResponse(ApiResponseDict): + count: int + next: str + prev: str + data: list[Any] + """ + check_code(source) + + +@pytest.mark.xfail(reason="N-deep inheritence not supported") +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason='Typed dicts are introduced in Python 3.9', +) +def test_n_deep_inherited_typed_dicts_dont_shadow_builtins(): + source = """ + from typing import Any, TypedDict + + + class ApiResponseDict(TypedDict): + status: int + okay: bool + + + class Something(ApiResponseDict): + okay: int + + + class SomethingElse(Something): + is_an_oof: bool + + class PaginatedResponse(SomethingElse): + count: int + next: str + prev: str + data: list[Any] + """ + check_code(source)