Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: typing.Annotated should be special-cased like typing.Literal #193

Merged
merged 4 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 54 additions & 10 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
ATTRS_DECORATORS,
ATTRS_IMPORTS,
BINOP_OPERAND_PROPERTY,
NAME_RE,
TC001,
TC002,
TC003,
Expand Down Expand Up @@ -96,7 +95,15 @@ def visit(self, node: ast.AST) -> None:
self.visit(node.value)
elif isinstance(node, ast.Subscript):
self.visit(node.value)
if getattr(node.value, 'id', '') != 'Literal':
if getattr(node.value, 'id', '') == 'Annotated' and isinstance(
(elts_node := node.slice.value if py38 and isinstance(node.slice, Index) else node.slice),
(ast.Tuple, ast.List),
):
if elts_node.elts:
# only visit the first element
self.visit(elts_node.elts[0])
# TODO: We may want to visit the rest as a soft-runtime use
elif getattr(node.value, 'id', '') != 'Literal':
self.visit(node.slice)
elif isinstance(node, (ast.Tuple, ast.List)):
for n in node.elts:
Expand Down Expand Up @@ -302,7 +309,9 @@ def visit_annotation_name(self, node: ast.Name) -> None:

def visit_annotation_string(self, node: ast.Constant) -> None:
"""Add all the names in the string to mapped names."""
self.mapped_names.update(NAME_RE.findall(node.value))
visitor = StringAnnotationVisitor()
visitor.parse_and_visit_string_annotation(node.value)
self.mapped_names.update(visitor.names)


class SQLAlchemyMixin:
Expand Down Expand Up @@ -423,10 +432,9 @@ def handle_sqlalchemy_annotation(self, node: ast.AST) -> None:
if not annotation.endswith(']'):
return

# if we ever do more sophisticated parsing of text annotations
# then we would want to strip the trailing `]` from inner, but
# with our simple parsing we don't care
mapped_name, inner = annotation.split('[', 1)
# strip trailing `]` from inner
inner = inner[:-1]
if mapped_name in self.mapped_aliases:
# record a use for the name
self.uses[mapped_name].append((node, self.current_scope))
Expand All @@ -448,7 +456,9 @@ def handle_sqlalchemy_annotation(self, node: ast.AST) -> None:
# add all names contained in the inner part of the annotation
# since this is not as strict as an actual runtime use, we don't
# care if we record too much here
self.mapped_names.update(NAME_RE.findall(inner))
visitor = StringAnnotationVisitor()
visitor.parse_and_visit_string_annotation(inner)
self.mapped_names.update(visitor.names)
return

# we only need to handle annotations like `Mapped[...]`
Expand Down Expand Up @@ -817,6 +827,40 @@ def lookup(self, symbol_name: str, use: HasPosition | None = None, runtime_only:
return parent.lookup(symbol_name, use, runtime_only)


class StringAnnotationVisitor(AnnotationVisitor):
"""Visit a parsed string annotation and collect all the names."""

def __init__(self) -> None:
#: All the names referenced inside the annotation
self.names: set[str] = set()

def parse_and_visit_string_annotation(self, annotation: str) -> None:
"""Parse and visit the given string as an annotation expression."""
try:
# in the future this simple approach may fail, because
# the quoted subexpression is only valid syntax in the context
# of the parent expression, in which case we would have to
# do something more clever here
module_node = ast.parse(f'_: _[{annotation}]')
except Exception:
# if we can't parse the annotation we should do nothing
return

ann_assign_node = module_node.body[0]
assert isinstance(ann_assign_node, ast.AnnAssign)
annotation_node = ann_assign_node.annotation
assert isinstance(annotation_node, ast.Subscript)
self.visit(annotation_node.slice)

def visit_annotation_name(self, node: ast.Name) -> None:
"""Remember all the visited names."""
self.names.add(node.id)

def visit_annotation_string(self, node: ast.Constant) -> None:
"""Parse and visit nested string annotations."""
self.parse_and_visit_string_annotation(node.value)


class ImportAnnotationVisitor(AnnotationVisitor):
"""Map all annotations on an AST node."""

Expand Down Expand Up @@ -870,10 +914,10 @@ def visit_annotation_string(self, node: ast.Constant) -> None:
if getattr(node, BINOP_OPERAND_PROPERTY, False):
self.invalid_binop_literals.append(node)
else:
visitor = StringAnnotationVisitor()
visitor.parse_and_visit_string_annotation(node.value)
(self.excess_wrapped_annotations if self.never_evaluates else self.wrapped_annotations).append(
WrappedAnnotation(
node.lineno, node.col_offset, node.value, set(NAME_RE.findall(node.value)), self.scope, self.type
)
WrappedAnnotation(node.lineno, node.col_offset, node.value, visitor.names, self.scope, self.type)
)


Expand Down
3 changes: 0 additions & 3 deletions flake8_type_checking/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import builtins
import re
import sys

import flake8
Expand All @@ -8,8 +7,6 @@
ANNOTATION_PROPERTY = '_flake8-type-checking__is_annotation'
BINOP_OPERAND_PROPERTY = '_flake8-type-checking__is_binop_operand'

NAME_RE = re.compile(r'(?<![\'".])\b[A-Za-z_]\w*(?![\'"])')

ATTRS_DECORATORS = [
'attrs.define',
'attrs.frozen',
Expand Down
46 changes: 26 additions & 20 deletions tests/test_name_extraction.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,38 @@
import sys

import pytest

from flake8_type_checking.constants import NAME_RE
from flake8_type_checking.checker import StringAnnotationVisitor

examples = [
('', []),
('int', ['int']),
('dict[str, int]', ['dict', 'str', 'int']),
('', set()),
('invalid_syntax]', set()),
('int', {'int'}),
('dict[str, int]', {'dict', 'str', 'int'}),
# make sure literals don't add names for their contents
('Literal["a"]', ['Literal']),
("Literal['a']", ['Literal']),
('Literal[0]', ['Literal']),
('Literal[1.0]', ['Literal']),
# booleans are a special case and difficult to reject using a RegEx
# for now it seems harmless to include them in the names, but if
# we do something more sophisticated with the names we may want to
# explicitly remove True/False from the result set
('Literal[True]', ['Literal', 'True']),
# try some potentially upcoming syntax
('*Ts | _T & S', ['Ts', '_T', 'S']),
# even when it's formatted badly
('*Ts|_T&P', ['Ts', '_T', 'P']),
('Union[Dict[str, Any], Literal["Foo", "Bar"], _T]', ['Union', 'Dict', 'str', 'Any', 'Literal', '_T']),
('Literal["a"]', {'Literal'}),
("Literal['a']", {'Literal'}),
('Literal[0]', {'Literal'}),
('Literal[1.0]', {'Literal'}),
('Literal[True]', {'Literal'}),
('T | S', {'T', 'S'}),
('Union[Dict[str, Any], Literal["Foo", "Bar"], _T]', {'Union', 'Dict', 'str', 'Any', 'Literal', '_T'}),
# for attribute access only everything up to the first dot should count
# this matches the behavior of add_annotation
('datetime.date | os.path.sep', ['datetime', 'os']),
('datetime.date | os.path.sep', {'datetime', 'os'}),
('Nested["str"]', {'Nested', 'str'}),
('Annotated[str, validator]', {'Annotated', 'str'}),
('Annotated[str, "bool"]', {'Annotated', 'str'}),
]

if sys.version_info >= (3, 11):
examples.extend([
('*Ts', {'Ts'}),
])


@pytest.mark.parametrize(('example', 'expected'), examples)
def test_name_extraction(example, expected):
assert NAME_RE.findall(example) == expected
visitor = StringAnnotationVisitor()
visitor.parse_and_visit_string_annotation(example)
assert visitor.names == expected
20 changes: 20 additions & 0 deletions tests/test_tc001_to_tc003.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,26 @@ class Migration:
'''),
set(),
),
(
textwrap.dedent(f'''
from typing import Annotated

from {import_} import Depends

x: Annotated[str, Depends]
'''),
set(),
),
(
textwrap.dedent(f'''
from typing import Annotated

from {import_} import Depends

x: Annotated[str, "Depends"]
'''),
set(),
),
]

return [
Expand Down
Loading