Skip to content

Commit

Permalink
Infer user-defined enum classes by checking if the class is a subtype…
Browse files Browse the repository at this point in the history
… of ``enum.Enum`` (#2277)

* Infer user-defined enum classes by checking if the class is a subtype of ``enum.Enum``.

Co-authored-by: Jacob Walls <jacobtylerwalls@gmail.com>
  • Loading branch information
mbyrnepr2 and jacobtylerwalls authored Sep 23, 2023
1 parent ea78827 commit c5352d5
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 17 deletions.
3 changes: 3 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ Release date: TBA

Closes pylint-dev/pylint#8802

* Infer user-defined enum classes by checking if the class is a subtype of ``enum.Enum``.

Closes pylint-dev/pylint#8897

* Fix false positives for ``no-member`` and ``invalid-name`` when using the ``_name_``, ``_value_`` and ``_ignore_`` sunders in Enums.

Expand Down
18 changes: 1 addition & 17 deletions astroid/brain/brain_namedtuple_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,10 @@
AstroidTypeError,
AstroidValueError,
InferenceError,
MroError,
UseInferenceDefault,
)
from astroid.manager import AstroidManager

ENUM_BASE_NAMES = {
"Enum",
"IntEnum",
"enum.Enum",
"enum.IntEnum",
"IntFlag",
"enum.IntFlag",
}
ENUM_QNAME: Final[str] = "enum.Enum"
TYPING_NAMEDTUPLE_QUALIFIED: Final = {
"typing.NamedTuple",
Expand Down Expand Up @@ -653,14 +644,7 @@ def _get_namedtuple_fields(node: nodes.Call) -> str:

def _is_enum_subclass(cls: astroid.ClassDef) -> bool:
"""Return whether cls is a subclass of an Enum."""
try:
return any(
klass.name in ENUM_BASE_NAMES
and getattr(klass.root(), "name", None) == "enum"
for klass in cls.mro()
)
except MroError:
return False
return cls.is_subtype_of("enum.Enum")


def register(manager: AstroidManager) -> None:
Expand Down
36 changes: 36 additions & 0 deletions tests/brain/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,42 @@ def __init__(self, mass, radius):
assert mars[1].name == "MARS"
assert radius[1].name == "radius"

def test_local_enum_child_class_inference(self) -> None:
"""Originally reported in /~https://github.com/pylint-dev/pylint/issues/8897
Test that a user-defined enum class is inferred when it subclasses
another user-defined enum class.
"""
enum_class_node, enum_member_value_node = astroid.extract_node(
"""
import sys
from enum import Enum
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
pass
class Color(StrEnum): #@
RED = "red"
Color.RED.value #@
"""
)
assert "RED" in enum_class_node.locals

enum_members = enum_class_node.locals["__members__"][0].items
assert len(enum_members) == 1
_, name = enum_members[0]
assert name.name == "RED"

inferred_enum_member_value_node = next(enum_member_value_node.infer())
assert inferred_enum_member_value_node.value == "red"

def test_enum_with_ignore(self) -> None:
"""Exclude ``_ignore_`` from the ``__members__`` container
Originally reported in /~https://github.com/pylint-dev/pylint/issues/9015
Expand Down
3 changes: 3 additions & 0 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4944,6 +4944,9 @@ def __class_getitem__(self, value):
"""
klass = extract_node(code)
context = InferenceContext()
# For this test, we want a fresh inference, rather than a cache hit on
# the inference done at brain time in _is_enum_subclass()
context.lookupname = "Fresh lookup!"
_ = klass.getitem(0, context=context)

assert next(iter(context.path))[0].name == "Parent"
Expand Down

0 comments on commit c5352d5

Please sign in to comment.