Skip to content

Commit

Permalink
Infer user-defined enum classes by checking if the class is a subtype…
Browse files Browse the repository at this point in the history
… of ``enum.Enum`` (#2277) (#2298)

(cherry picked from commit c5352d5)
Co-authored-by: Jacob Walls <[email protected]>

Co-authored-by: Mark Byrne <[email protected]>
  • Loading branch information
jacobtylerwalls and mbyrnepr2 authored Sep 23, 2023
1 parent 0e0dd9c commit 1f62beb
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 17 deletions.
4 changes: 4 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ Release date: TBA

Closes pylint-dev/pylint#8802

* Infer user-defined enum classes by checking if the class is a subtype of ``enum.Enum``.

Closes pylint-dev/pylint#8897

* Fix inference of functions with ``@functools.lru_cache`` decorators without
parentheses.

Expand Down
18 changes: 1 addition & 17 deletions astroid/brain/brain_namedtuple_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
AstroidTypeError,
AstroidValueError,
InferenceError,
MroError,
UseInferenceDefault,
)
from astroid.manager import AstroidManager
Expand All @@ -31,14 +30,6 @@
from typing_extensions import Final


ENUM_BASE_NAMES = {
"Enum",
"IntEnum",
"enum.Enum",
"enum.IntEnum",
"IntFlag",
"enum.IntFlag",
}
ENUM_QNAME: Final[str] = "enum.Enum"
TYPING_NAMEDTUPLE_QUALIFIED: Final = {
"typing.NamedTuple",
Expand Down Expand Up @@ -606,14 +597,7 @@ def _get_namedtuple_fields(node: nodes.Call) -> str:

def _is_enum_subclass(cls: astroid.ClassDef) -> bool:
"""Return whether cls is a subclass of an Enum."""
try:
return any(
klass.name in ENUM_BASE_NAMES
and getattr(klass.root(), "name", None) == "enum"
for klass in cls.mro()
)
except MroError:
return False
return cls.is_subtype_of("enum.Enum")


AstroidManager().register_transform(
Expand Down
36 changes: 36 additions & 0 deletions tests/brain/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,39 @@ def pear(self):
for node in (attribute_nodes[1], name_nodes[1]):
with pytest.raises(InferenceError):
node.inferred()

def test_local_enum_child_class_inference(self) -> None:
"""Originally reported in https://github.com/pylint-dev/pylint/issues/8897
Test that a user-defined enum class is inferred when it subclasses
another user-defined enum class.
"""
enum_class_node, enum_member_value_node = astroid.extract_node(
"""
import sys
from enum import Enum
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
pass
class Color(StrEnum): #@
RED = "red"
Color.RED.value #@
"""
)
assert "RED" in enum_class_node.locals

enum_members = enum_class_node.locals["__members__"][0].items
assert len(enum_members) == 1
_, name = enum_members[0]
assert name.name == "RED"

inferred_enum_member_value_node = next(enum_member_value_node.infer())
assert inferred_enum_member_value_node.value == "red"

0 comments on commit 1f62beb

Please sign in to comment.