diff --git a/ChangeLog b/ChangeLog index 1cfba953a2..c718af0773 100644 --- a/ChangeLog +++ b/ChangeLog @@ -20,6 +20,10 @@ Release date: TBA Closes pylint-dev/pylint#8802 +* Infer user-defined enum classes by checking if the class is a subtype of ``enum.Enum``. + + Closes pylint-dev/pylint#8897 + * Fix inference of functions with ``@functools.lru_cache`` decorators without parentheses. diff --git a/astroid/brain/brain_namedtuple_enum.py b/astroid/brain/brain_namedtuple_enum.py index 36b703610f..36ee653605 100644 --- a/astroid/brain/brain_namedtuple_enum.py +++ b/astroid/brain/brain_namedtuple_enum.py @@ -20,7 +20,6 @@ AstroidTypeError, AstroidValueError, InferenceError, - MroError, UseInferenceDefault, ) from astroid.manager import AstroidManager @@ -31,14 +30,6 @@ from typing_extensions import Final -ENUM_BASE_NAMES = { - "Enum", - "IntEnum", - "enum.Enum", - "enum.IntEnum", - "IntFlag", - "enum.IntFlag", -} ENUM_QNAME: Final[str] = "enum.Enum" TYPING_NAMEDTUPLE_QUALIFIED: Final = { "typing.NamedTuple", @@ -606,14 +597,7 @@ def _get_namedtuple_fields(node: nodes.Call) -> str: def _is_enum_subclass(cls: astroid.ClassDef) -> bool: """Return whether cls is a subclass of an Enum.""" - try: - return any( - klass.name in ENUM_BASE_NAMES - and getattr(klass.root(), "name", None) == "enum" - for klass in cls.mro() - ) - except MroError: - return False + return cls.is_subtype_of("enum.Enum") AstroidManager().register_transform( diff --git a/tests/brain/test_enum.py b/tests/brain/test_enum.py index 9d95d2ffbb..085d00c133 100644 --- a/tests/brain/test_enum.py +++ b/tests/brain/test_enum.py @@ -493,3 +493,39 @@ def pear(self): for node in (attribute_nodes[1], name_nodes[1]): with pytest.raises(InferenceError): node.inferred() + + def test_local_enum_child_class_inference(self) -> None: + """Originally reported in /~https://github.com/pylint-dev/pylint/issues/8897 + + Test that a user-defined enum class is inferred when it subclasses + another user-defined enum class. + """ + enum_class_node, enum_member_value_node = astroid.extract_node( + """ + import sys + + from enum import Enum + + if sys.version_info >= (3, 11): + from enum import StrEnum + else: + class StrEnum(str, Enum): + pass + + + class Color(StrEnum): #@ + RED = "red" + + + Color.RED.value #@ + """ + ) + assert "RED" in enum_class_node.locals + + enum_members = enum_class_node.locals["__members__"][0].items + assert len(enum_members) == 1 + _, name = enum_members[0] + assert name.name == "RED" + + inferred_enum_member_value_node = next(enum_member_value_node.infer()) + assert inferred_enum_member_value_node.value == "red"