Skip to content

Commit

Permalink
More complete typing
Browse files Browse the repository at this point in the history
  • Loading branch information
msaelices committed May 18, 2024
1 parent 9b27c18 commit d6d9b2f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
6 changes: 3 additions & 3 deletions py2mojo/converters/assignment.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import ast
from functools import partial
from typing import Iterable
from typing import Callable, Iterable

from tokenize_rt import Token
from tokenize_rt import Offset, Token

from ..helpers import ast_to_offset, get_annotation_type, find_token, find_token_by_name, get_mojo_type
from ..rules import RuleSet
Expand All @@ -20,7 +20,7 @@ def _replace_assignment(tokens: list[Token], i: int, rules: RuleSet, new_type: s
tokens.insert(type_idx, Token(name='NAME', src=new_type))


def convert_assignment(node: ast.AnnAssign, rules: RuleSet) -> Iterable:
def convert_assignment(node: ast.AnnAssign, rules: RuleSet) -> Iterable[tuple[Offset, Callable]]:
"""Convert an assignment to a mojo assignment."""
curr_type = get_annotation_type(node.annotation)
new_type = get_mojo_type(curr_type, rules)
Expand Down
9 changes: 5 additions & 4 deletions py2mojo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@

import argparse
import ast
from collections.abc import Iterable
import os
import sys
import tokenize
from collections import defaultdict
from typing import Callable, Sequence
from typing import Callable, Sequence, TypeAlias

from tokenize_rt import Token, reversed_enumerate, src_to_tokens, tokens_to_src
from tokenize_rt import Offset, reversed_enumerate, src_to_tokens, tokens_to_src

from .converters import convert_assignment, convert_functiondef, convert_classdef
from .exceptions import ParseException
from .helpers import display_error, fixup_dedent_tokens
from .rules import get_rules, RuleSet


TokenFunc = Callable[[list[Token], int], None]
TokenFunc: TypeAlias = Callable[[ast.AST, RuleSet], Iterable[tuple[Offset, Callable]]]


def get_converters(klass: type) -> list[TokenFunc]:
Expand All @@ -33,7 +34,7 @@ def get_converters(klass: type) -> list[TokenFunc]:
}.get(klass, [])


def visit(tree: ast.Module, rules: RuleSet) -> list[TokenFunc]:
def visit(tree: ast.Module, rules: RuleSet) -> dict[Offset, TokenFunc]:
nodes = [tree]
ret = defaultdict(list)
while nodes:
Expand Down

0 comments on commit d6d9b2f

Please sign in to comment.