diff --git a/py2mojo/converters/assignment.py b/py2mojo/converters/assignment.py index 1aacb86..a319042 100644 --- a/py2mojo/converters/assignment.py +++ b/py2mojo/converters/assignment.py @@ -1,8 +1,8 @@ import ast from functools import partial -from typing import Iterable +from typing import Callable, Iterable -from tokenize_rt import Token +from tokenize_rt import Offset, Token from ..helpers import ast_to_offset, get_annotation_type, find_token, find_token_by_name, get_mojo_type from ..rules import RuleSet @@ -20,7 +20,7 @@ def _replace_assignment(tokens: list[Token], i: int, rules: RuleSet, new_type: s tokens.insert(type_idx, Token(name='NAME', src=new_type)) -def convert_assignment(node: ast.AnnAssign, rules: RuleSet) -> Iterable: +def convert_assignment(node: ast.AnnAssign, rules: RuleSet) -> Iterable[tuple[Offset, Callable]]: """Convert an assignment to a mojo assignment.""" curr_type = get_annotation_type(node.annotation) new_type = get_mojo_type(curr_type, rules) diff --git a/py2mojo/main.py b/py2mojo/main.py index 1647685..cbf561e 100644 --- a/py2mojo/main.py +++ b/py2mojo/main.py @@ -2,13 +2,14 @@ import argparse import ast +from collections.abc import Iterable import os import sys import tokenize from collections import defaultdict -from typing import Callable, Sequence +from typing import Callable, Sequence, TypeAlias -from tokenize_rt import Token, reversed_enumerate, src_to_tokens, tokens_to_src +from tokenize_rt import Offset, reversed_enumerate, src_to_tokens, tokens_to_src from .converters import convert_assignment, convert_functiondef, convert_classdef from .exceptions import ParseException @@ -16,7 +17,7 @@ from .rules import get_rules, RuleSet -TokenFunc = Callable[[list[Token], int], None] +TokenFunc: TypeAlias = Callable[[ast.AST, RuleSet], Iterable[tuple[Offset, Callable]]] def get_converters(klass: type) -> list[TokenFunc]: @@ -33,7 +34,7 @@ def get_converters(klass: type) -> list[TokenFunc]: }.get(klass, []) -def visit(tree: ast.Module, rules: RuleSet) -> list[TokenFunc]: +def visit(tree: ast.Module, rules: RuleSet) -> dict[Offset, TokenFunc]: nodes = [tree] ret = defaultdict(list) while nodes: