diff --git a/flake8_type_checking/checker.py b/flake8_type_checking/checker.py index 40c0670..9e678ec 100644 --- a/flake8_type_checking/checker.py +++ b/flake8_type_checking/checker.py @@ -21,7 +21,6 @@ ATTRS_DECORATORS, ATTRS_IMPORTS, BINOP_OPERAND_PROPERTY, - NAME_RE, TC001, TC002, TC003, @@ -96,7 +95,15 @@ def visit(self, node: ast.AST) -> None: self.visit(node.value) elif isinstance(node, ast.Subscript): self.visit(node.value) - if getattr(node.value, 'id', '') != 'Literal': + if getattr(node.value, 'id', '') == 'Annotated' and isinstance( + (elts_node := node.slice.value if py38 and isinstance(node.slice, Index) else node.slice), + (ast.Tuple, ast.List), + ): + if elts_node.elts: + # only visit the first element + self.visit(elts_node.elts[0]) + # TODO: We may want to visit the rest as a soft-runtime use + elif getattr(node.value, 'id', '') != 'Literal': self.visit(node.slice) elif isinstance(node, (ast.Tuple, ast.List)): for n in node.elts: @@ -302,7 +309,9 @@ def visit_annotation_name(self, node: ast.Name) -> None: def visit_annotation_string(self, node: ast.Constant) -> None: """Add all the names in the string to mapped names.""" - self.mapped_names.update(NAME_RE.findall(node.value)) + visitor = StringAnnotationVisitor() + visitor.parse_and_visit_string_annotation(node.value) + self.mapped_names.update(visitor.names) class SQLAlchemyMixin: @@ -423,10 +432,9 @@ def handle_sqlalchemy_annotation(self, node: ast.AST) -> None: if not annotation.endswith(']'): return - # if we ever do more sophisticated parsing of text annotations - # then we would want to strip the trailing `]` from inner, but - # with our simple parsing we don't care mapped_name, inner = annotation.split('[', 1) + # strip trailing `]` from inner + inner = inner[:-1] if mapped_name in self.mapped_aliases: # record a use for the name self.uses[mapped_name].append((node, self.current_scope)) @@ -448,7 +456,9 @@ def handle_sqlalchemy_annotation(self, node: ast.AST) -> None: # add all names contained in the inner part of the annotation # since this is not as strict as an actual runtime use, we don't # care if we record too much here - self.mapped_names.update(NAME_RE.findall(inner)) + visitor = StringAnnotationVisitor() + visitor.parse_and_visit_string_annotation(inner) + self.mapped_names.update(visitor.names) return # we only need to handle annotations like `Mapped[...]` @@ -817,6 +827,40 @@ def lookup(self, symbol_name: str, use: HasPosition | None = None, runtime_only: return parent.lookup(symbol_name, use, runtime_only) +class StringAnnotationVisitor(AnnotationVisitor): + """Visit a parsed string annotation and collect all the names.""" + + def __init__(self) -> None: + #: All the names referenced inside the annotation + self.names: set[str] = set() + + def parse_and_visit_string_annotation(self, annotation: str) -> None: + """Parse and visit the given string as an annotation expression.""" + try: + # in the future this simple approach may fail, because + # the quoted subexpression is only valid syntax in the context + # of the parent expression, in which case we would have to + # do something more clever here + module_node = ast.parse(f'_: _[{annotation}]') + except Exception: + # if we can't parse the annotation we should do nothing + return + + ann_assign_node = module_node.body[0] + assert isinstance(ann_assign_node, ast.AnnAssign) + annotation_node = ann_assign_node.annotation + assert isinstance(annotation_node, ast.Subscript) + self.visit(annotation_node.slice) + + def visit_annotation_name(self, node: ast.Name) -> None: + """Remember all the visited names.""" + self.names.add(node.id) + + def visit_annotation_string(self, node: ast.Constant) -> None: + """Parse and visit nested string annotations.""" + self.parse_and_visit_string_annotation(node.value) + + class ImportAnnotationVisitor(AnnotationVisitor): """Map all annotations on an AST node.""" @@ -870,10 +914,10 @@ def visit_annotation_string(self, node: ast.Constant) -> None: if getattr(node, BINOP_OPERAND_PROPERTY, False): self.invalid_binop_literals.append(node) else: + visitor = StringAnnotationVisitor() + visitor.parse_and_visit_string_annotation(node.value) (self.excess_wrapped_annotations if self.never_evaluates else self.wrapped_annotations).append( - WrappedAnnotation( - node.lineno, node.col_offset, node.value, set(NAME_RE.findall(node.value)), self.scope, self.type - ) + WrappedAnnotation(node.lineno, node.col_offset, node.value, visitor.names, self.scope, self.type) ) diff --git a/flake8_type_checking/constants.py b/flake8_type_checking/constants.py index 37e21a3..7cf3c11 100644 --- a/flake8_type_checking/constants.py +++ b/flake8_type_checking/constants.py @@ -1,5 +1,4 @@ import builtins -import re import sys import flake8 @@ -8,8 +7,6 @@ ANNOTATION_PROPERTY = '_flake8-type-checking__is_annotation' BINOP_OPERAND_PROPERTY = '_flake8-type-checking__is_binop_operand' -NAME_RE = re.compile(r'(?= (3, 11): + examples.extend([ + ('*Ts', {'Ts'}), + ]) + @pytest.mark.parametrize(('example', 'expected'), examples) def test_name_extraction(example, expected): - assert NAME_RE.findall(example) == expected + visitor = StringAnnotationVisitor() + visitor.parse_and_visit_string_annotation(example) + assert visitor.names == expected diff --git a/tests/test_tc001_to_tc003.py b/tests/test_tc001_to_tc003.py index 4b1f6b0..4a3fcf5 100644 --- a/tests/test_tc001_to_tc003.py +++ b/tests/test_tc001_to_tc003.py @@ -195,6 +195,26 @@ class Migration: '''), set(), ), + ( + textwrap.dedent(f''' + from typing import Annotated + + from {import_} import Depends + + x: Annotated[str, Depends] + '''), + set(), + ), + ( + textwrap.dedent(f''' + from typing import Annotated + + from {import_} import Depends + + x: Annotated[str, "Depends"] + '''), + set(), + ), ] return [