diff --git a/src/exo/API.py b/src/exo/API.py index a180f8539..03d7e7292 100644 --- a/src/exo/API.py +++ b/src/exo/API.py @@ -23,7 +23,6 @@ # Moved to new file from .proc_eqv import decl_new_proc, derive_proc, assert_eqv_proc, check_eqv_proc from .pyparser import get_ast_from_python, Parser, get_src_locals -from .reflection import LoopIR_to_QAST from .typecheck import TypeChecker from . import API_cursors as C @@ -312,22 +311,6 @@ def find_alloc_or_arg(self, pattern): def find_all(self, pattern): return self.find(pattern, many=True) - def get_ast(self, pattern=None): - if pattern is None: - return LoopIR_to_QAST(self._loopir_proc).result() - - # do pattern matching - match = match_pattern(self._root(), pattern, call_depth=1) - - # convert matched sub-trees to QAST - assert isinstance(match, list) - if not match: - return None - - return [ - LoopIR_to_QAST(node._node).result() for block in match for node in block - ] - # ---------------------------------------------- # # execution / compilation operations # ---------------------------------------------- # diff --git a/src/exo/__init__.py b/src/exo/__init__.py index 998d56cb2..80cbbf15e 100644 --- a/src/exo/__init__.py +++ b/src/exo/__init__.py @@ -11,7 +11,6 @@ from .parse_fragment import ParseFragmentError from .configs import Config from .memory import Memory, DRAM -from . import query_asts as QAST from . import stdlib @@ -27,7 +26,6 @@ "Config", "Memory", "DRAM", - "QAST", "SchedulingError", "ParseFragmentError", # diff --git a/src/exo/platforms/gemmini.py b/src/exo/platforms/gemmini.py index bef9587ef..c91e4a460 100644 --- a/src/exo/platforms/gemmini.py +++ b/src/exo/platforms/gemmini.py @@ -1,10 +1,28 @@ from __future__ import annotations -from exo import proc, instr, DRAM, config, QAST +from exo import proc, instr, DRAM, config from exo.libs.memories import GEMM_SCRATCH, GEMM_ACCUM from exo.stdlib.scheduling import * +def lift_config(p, config_str): + config = p.find(config_str) + while True: + try: + try: + while True: + try: + p = reorder_stmts(p, p.forward(config).expand(1, 0)) + except: + raise Exception("Reordered to the top of the scope") + except: + p = fission(p, p.forward(config).after(), unsafe_disable_checks=True) + p = remove_loop(p, p.forward(config).parent()) + except: + break + return p + + def set_prec_mem(p, bufname, precision, memory): p = set_memory(p, bufname, memory) p = set_precision(p, bufname, precision) @@ -219,148 +237,6 @@ def tile_outer_loops(gemmini): return gemmini -class QAST_Do: - def __init__(self, proc): - self.proc = proc - - # [ self.do_fnarg(a) for a in self.proc.args ] - [self.do_e(p) for p in self.proc.assertions] - self.do_stmts(self.proc.body) - - def do_stmts(self, stmts): - [self.do_s(b) for b in stmts] - - def do_s(self, s): - if type(s) is QAST.Assign or type(s) is QAST.Reduce: - [self.do_e(e) for e in s.idx] - self.do_e(s.rhs) - elif type(s) is QAST.WriteConfig: - self.do_e(s.rhs) - elif type(s) is QAST.For: - self.do_e(s.lo) - self.do_e(s.hi) - self.do_stmts(s.body) - elif type(s) is QAST.If: - self.do_e(s.cond) - self.do_stmts(s.body) - if len(s.orelse) > 0: - self.do_stmts(s.orelse) - elif type(s) is QAST.Pass: - pass - elif type(s) is QAST.Alloc: - pass - elif type(s) is QAST.Call: - [self.do_e(e) for e in s.args] - elif type(s) is QAST.WindowStmt: - self.do_e(s.rhs) - else: - assert False, "bad case" - - def do_w_access(self, w): - if type(w) is QAST.Interval: - self.do_e(w.lo) - self.do_e(w.hi) - elif type(w) is QAST.Point: - self.do_e(w.pt) - - def do_e(self, e): - if type(e) is QAST.Read: - [self.do_e(ei) for ei in e.idx] - elif type(e) is QAST.Const: - pass - elif type(e) is QAST.USub: - self.do_e(e.arg) - elif type(e) is QAST.BinOp: - self.do_e(e.lhs) - self.do_e(e.rhs) - elif type(e) is QAST.BuiltIn: - [self.do_e(ei) for ei in e.args] - elif type(e) is QAST.WindowExpr: - [self.do_w_access(w) for w in e.idx] - elif type(e) is QAST.StrideExpr: - pass - elif type(e) is QAST.ReadConfig: - pass - else: - assert False, "bad case" - - -class CanFissionLoop(QAST_Do): - def __init__(self, proc, stmt): - self.stmt = stmt - self.result = False - super().__init__(proc) - - def result(self): - return self.result - - def do_s(self, s): - if type(s) is QAST.For: - assert len(s.body) > 0 - if s.body[0] == self.stmt: - self.result = True - - super().do_s(s) - - -class CanFissionIf(QAST_Do): - def __init__(self, proc, stmt): - self.stmt = stmt - self.result = None - super().__init__(proc) - - def result(self): - return self.result - - def do_s(self, s): - if type(s) is QAST.If: - assert len(s.body) > 0 - if s.body[0] == self.stmt: - self.result = str(s) - elif len(s.orelse) > 0 and s.orelse[0] == self.stmt: - self.result = str(s) - - super().do_s(s) - - -class CanReorder(QAST_Do): - def __init__(self, proc, stmt): - self.stmt = stmt - self.result = None - super().__init__(proc) - - def result(self): - return self.result - - def do_stmts(self, stmts): - prev = None - for b in stmts: - if b == self.stmt and prev is not None: - self.result = str(prev) - else: - self.do_s(b) - prev = b - - -def lift_config(conv, string, nth=0): - stmt = conv.get_ast(string) - stmt = stmt[nth] # Get the match - - while True: - proc = conv.get_ast() - fission_loop = CanFissionLoop(proc, stmt).result - reorder = CanReorder(proc, stmt).result - if fission_loop: - conv = old_fission_after(conv, string) - elif reorder is not None: - conv = reorder_stmts(conv, conv.find(string).expand(1, 0)) - # conv = conv.reorder_before(string) - else: - break - - return conv - - def inline_vector(conv): conv = call_eqv(conv, "ld_acc_i32_vector(_)", ld_acc_i32_vector_v2) conv = inline(conv, "ld_acc_i32_vector_v2(_)") diff --git a/src/exo/query_asts.py b/src/exo/query_asts.py deleted file mode 100644 index 1b69e2717..000000000 --- a/src/exo/query_asts.py +++ /dev/null @@ -1,316 +0,0 @@ -"""Query AST Classes - -This module contains a reflection of Exo's internal AST structures. - -They are organized into a class hierarchy of Python dataclasses as -follows. - -QueryAST - Proc ( name : str, args : list[FnArg], assertions : list[Expr], - body : list[Stmt], instruction : Optional[str] ) - FnArg ( name : str, type : Type, memory : Optional[Memory] ) - Stmt - Assign ( name : str, lhs_type : Type, idx : list[Expr], - rhs : Expr ) - Reduce ( name : str, lhs_type : Type, idx : list[Expr], - rhs : Expr ) - WriteConfig ( config : Config, field : str, - rhs : Expr ) - Pass () - If ( cond : Expr, body : list[Stmt], orelse : list[Stmt] ) - For ( name : str, lo : Expr, hi : Expr, - body : list[Stmt], is_par : bool ) - Alloc ( name : str, type : Type, memory : Optional[Memory] ) - Call ( proc : str, args : list[Expr] ) - WindowStmt( name : str, rhs : Expr ) - Expr - Read ( name : str, idx : list[Expr], type : Type ) - Const ( val : Any, type : Type ) - USub ( arg : Expr, type : Type ) - BinOp ( op : str, lhs : Expr, - rhs : Expr, type : Type ) - BuiltIn ( func : str, args : list[Expr], type : Type ) - WindowExpr( name : str, idx : list[WAccess], type : Type ) - StrideExpr( name : str, dim : int, type : Type ) - ReadConfig( config : Config, field : str, type : Type ) - WAccess - Interval( lo : Expr, hi : Expr ) - Point( pt : Expr ) - Type - R() - f16() - f32() - f64() - i8() - i32() - bool() - int() - index() - size() - stride() - tensor( hi : list[Expr], is_window : bool, type : Type ) - -""" - -from dataclasses import dataclass as _dataclass -from typing import Any as _Any -from typing import Optional as _Optional - -from .configs import Config as _Config -from .memory import Memory as _Memory - - -# ---------------------------------------------------------------------- -# -- base classes for all asts returned from the reflection interface -- - - -class QueryAST: - def __init__(self): - raise Exception("Should never try to instantiate QueryAST") - - -class Type(QueryAST): - def __init__(self): - raise Exception("Should never try to instantiate Type") - - -class Expr(QueryAST): - def __init__(self): - raise Exception("Should never try to instantiate Expr") - - -class WAccess(QueryAST): - def __init__(self): - raise Exception("Should never try to instantiate WAccess") - - -class Stmt(QueryAST): - def __init__(self): - raise Exception("Should never try to instantiate Stmt") - - -# ----------------------------------- -# -- QueryAST --> Type --> _______ -- - - -@_dataclass -class R(Type): - pass - - -@_dataclass -class f16(Type): - pass - - -@_dataclass -class f32(Type): - pass - - -@_dataclass -class f64(Type): - pass - - -@_dataclass -class i8(Type): - pass - - -@_dataclass -class i32(Type): - pass - - -@_dataclass -class bool(Type): - pass - - -@_dataclass -class int(Type): - pass - - -@_dataclass -class index(Type): - pass - - -@_dataclass -class size(Type): - pass - - -@_dataclass -class stride(Type): - pass - - -@_dataclass -class tensor(Type): - hi: list[Expr] - is_window: bool - type: Type - - -# -------------------------------------- -# -- QueryAST --> WAccess --> _______ -- - - -@_dataclass -class Interval(WAccess): - lo: Expr - hi: Expr - - -@_dataclass -class Point(WAccess): - pt: Expr - - -# ----------------------------------- -# -- QueryAST --> Expr --> _______ -- - - -@_dataclass -class Read(Expr): - name: str - idx: list[Expr] - type: Type - - -@_dataclass -class Const(Expr): - val: _Any - type: Type - - -@_dataclass -class USub(Expr): - arg: Expr - type: Type - - -@_dataclass -class BinOp(Expr): - op: str - lhs: Expr - rhs: Expr - type: Type - - -@_dataclass -class BuiltIn(Expr): - func: str - args: list[Expr] - type: Type - - -@_dataclass -class WindowExpr(Expr): - name: str - idx: list[WAccess] - type: Type - - -@_dataclass -class StrideExpr(Expr): - name: str - dim: int - type: Type - - -@_dataclass -class ReadConfig(Expr): - config: _Config - field: str - type: Type - - -# ----------------------------------- -# -- QueryAST --> Stmt --> _______ -- - - -@_dataclass -class Assign(Stmt): - name: str - lhs_type: Type - idx: list[Expr] - rhs: Expr - - -@_dataclass -class Reduce(Stmt): - name: str - lhs_type: Type - idx: list[Expr] - rhs: Expr - - -@_dataclass -class WriteConfig(Stmt): - config: _Config - field: str - rhs: Expr - - -@_dataclass -class Pass(Stmt): - pass - - -@_dataclass -class If(Stmt): - cond: Expr - body: list[Expr] - orelse: list[Expr] - - -@_dataclass -class For(Stmt): - name: str - lo: Expr - hi: Expr - body: list[Expr] - is_par: bool - - -@_dataclass -class Alloc(Stmt): - name: str - type: Type - memory: _Optional[_Memory] - - -@_dataclass -class Call(Stmt): - proc: str - args: list[Expr] - - -@_dataclass -class WindowStmt(Stmt): - name: str - rhs: Expr - - -# -------------------------- -# -- QueryAST --> _______ -- - - -@_dataclass -class FnArg(QueryAST): - name: str - type: Type - memory: _Optional[_Memory] - - -@_dataclass -class Proc(QueryAST): - name: str - args: list[FnArg] - assertions: list[Expr] - body: list[Stmt] - instruction: _Optional[str] diff --git a/src/exo/reflection.py b/src/exo/reflection.py deleted file mode 100644 index 7eda4da8d..000000000 --- a/src/exo/reflection.py +++ /dev/null @@ -1,261 +0,0 @@ -from . import query_asts as QAST -from .LoopIR import LoopIR, T -from .prelude import * - - -@extclass(QAST.Call) -def __str__(self): - return f"{self.proc}(_)" - - -del __str__ - - -@extclass(QAST.Alloc) -def __str__(self): - return f"{self.name} : _" - - -del __str__ - - -@extclass(QAST.WriteConfig) -def __str__(self): - return f"{self.config.name()}.{self.field} = {str(self.rhs)}" - - -del __str__ - - -@extclass(QAST.Assign) -@extclass(QAST.Reduce) -def __str__(self): - if len(self.idx) > 0: - return f"{self.name}[_] = {str(self.rhs)}" - else: - return f"{self.name} = {str(self.rhs)}" - - -del __str__ - - -@extclass(QAST.For) -def __str__(self): - return f"for {self.name} in par(0, {str(self.hi)}):_" - - -del __str__ - - -@extclass(QAST.If) -def __str__(self): - cond = str(self.cond) - return f"if {cond}:_" - - -del __str__ - - -@extclass(QAST.Read) -def __str__(self): - if len(self.idx) > 0: - return f"{self.name}[_]" - else: - return f"{self.name}" - - -del __str__ - - -@extclass(QAST.Const) -def __str__(self): - return f"{self.val}" - - -del __str__ - - -@extclass(QAST.USub) -def __str__(self): - return f"-{str(self.arg)}" - - -del __str__ - - -@extclass(QAST.BinOp) -def __str__(self): - lhs = str(self.lhs) - rhs = str(self.rhs) - return f"{lhs} {self.op} {rhs}" - - -del __str__ - - -# --------------------------------------------------------------------------- # -# --------------------------------------------------------------------------- # -# Conversion from LoopIR AST to QAST - - -class LoopIR_to_QAST: - def __init__(self, loopir_node): - self.loopir_node = loopir_node - - self.names = dict() - - if isinstance(loopir_node, LoopIR.proc): - self.qast = self.map_proc(loopir_node) - elif isinstance(loopir_node, list): - if len(loopir_node) == 0: - self.qast = [] - elif isinstance(loopir_node[0], LoopIR.stmt): - self.qast = self.map_stmts(loopir_node) - else: - assert False, f"cannot process list of {type(loopir_node[0])}" - elif isinstance(loopir_node, LoopIR.fnarg): - self.qast = self.map_fnarg(loopir_node) - elif isinstance(loopir_node, LoopIR.stmt): - self.qast = self.map_stmt(loopir_node) - elif isinstance(loopir_node, LoopIR.expr): - self.qast = self.map_expr(loopir_node) - elif isinstance(loopir_node, LoopIR.type): - self.qast = self.map_type(loopir_node) - - def result(self): - return self.qast - - def getname(self, sym): - return str(sym) - - def bindname(self, sym): - return str(sym) - - def map_proc(self, proc): - name = proc.name - return QAST.Proc( - name, - [self.map_fnarg(fa) for fa in proc.args], - [self.map_expr(p) for p in proc.preds], - self.map_stmts(proc.body), - proc.instr, - ) - - def map_fnarg(self, fa): - name = self.bindname(fa.name) - return QAST.FnArg(name, self.map_type(fa.type), fa.mem) - - def map_stmts(self, body): - return [self.map_stmt(s) for s in body] - - def map_stmt(self, s): - styp = type(s) - if styp is LoopIR.Assign or styp is LoopIR.Reduce: - qtyp = QAST.Assign if styp is LoopIR.Assign else QAST.Reduce - name = self.getname(s.name) - return qtyp( - name, - self.map_type(s.type), - [self.map_expr(i) for i in s.idx], - self.map_expr(s.rhs), - ) - elif styp is LoopIR.WriteConfig: - return QAST.WriteConfig(s.config, s.field, self.map_expr(s.rhs)) - elif styp is LoopIR.Pass: - return QAST.Pass() - elif styp is LoopIR.If: - return QAST.If( - self.map_expr(s.cond), self.map_stmts(s.body), self.map_stmts(s.orelse) - ) - elif styp is LoopIR.For: - name = self.bindname(s.iter) - return QAST.For( - name, - self.map_expr(s.lo), - self.map_expr(s.hi), - self.map_stmts(s.body), - False, - ) - elif styp is LoopIR.Alloc: - name = self.bindname(s.name) - return QAST.Alloc(name, self.map_type(s.type), s.mem) - elif styp is LoopIR.Call: - return QAST.Call(s.f.name, [self.map_expr(a) for a in s.args]) - elif styp is LoopIR.WindowStmt: - name = self.bindname(s.name) - return QAST.WindowStmt(name, self.map_expr(s.rhs)) - else: - assert False, f"bad case: {styp}" - - def map_expr(self, e): - etyp = type(e) - if etyp is LoopIR.Read: - return QAST.Read( - self.getname(e.name), - [self.map_expr(i) for i in e.idx], - self.map_type(e.type), - ) - elif etyp is LoopIR.Const: - return QAST.Const(e.val, self.map_type(e.type)) - elif etyp is LoopIR.USub: - return QAST.USub(self.map_expr(e.arg), self.map_type(e.type)) - elif etyp is LoopIR.BinOp: - return QAST.BinOp( - e.op, self.map_expr(e.lhs), self.map_expr(e.rhs), self.map_type(e.type) - ) - elif etyp is LoopIR.BuiltIn: - return QAST.BuiltIn( - e.f.name(), [self.map_expr(a) for a in e.args], self.map_type(e.type) - ) - elif etyp is LoopIR.WindowExpr: - name = self.getname(e.name) - - def map_w(w): - if isinstance(w, LoopIR.Interval): - return QAST.Interval(self.map_expr(w.lo), self.map_expr(w.hi)) - else: - return QAST.Point(self.map_expr(w.pt)) - - return QAST.WindowExpr( - name, [map_w(w) for w in e.idx], self.map_type(e.type) - ) - elif etyp is LoopIR.StrideExpr: - name = self.getname(e.name) - return QAST.StrideExpr(name, e.dim, self.map_type(e.type)) - elif etyp is LoopIR.ReadConfig: - return QAST.ReadConfig(e.config, e.field, self.map_type(e.type)) - else: - assert False, f"bad case: {etyp}" - - def map_type(self, typ): - if typ == T.R: - return QAST.R() - elif typ == T.f16: - return QAST.f16() - elif typ == T.f32: - return QAST.f32() - elif typ == T.f64: - return QAST.f64() - elif typ == T.i8: - return QAST.i8() - elif typ == T.i32: - return QAST.i32() - elif typ == T.bool: - return QAST.bool() - elif typ == T.int: - return QAST.int() - elif typ == T.index: - return QAST.index() - elif typ == T.size: - return QAST.size() - elif typ == T.stride: - return QAST.stride() - elif typ.is_tensor_or_window(): - as_tensor = typ.as_tensor if isinstance(typ, T.Window) else typ - return QAST.tensor( - [self.map_expr(e) for e in as_tensor.hi], - as_tensor.is_window, - self.map_type(as_tensor.type), - ) - else: - assert False, f"bad case: {type(typ)}" diff --git a/tests/test_reflection.py b/tests/test_reflection.py deleted file mode 100644 index d937a03d0..000000000 --- a/tests/test_reflection.py +++ /dev/null @@ -1,139 +0,0 @@ -from __future__ import annotations - -from exo import proc, DRAM, QAST -from exo.stdlib.scheduling import * - - -# ------- Reflection tests --------- - - -def new_sgemm(): - @proc - def sgemm_full( - N: size, - M: size, - K: size, - C: f32[N, M] @ DRAM, - A: f32[N, K] @ DRAM, - B: f32[K, M] @ DRAM, - ): - for i in seq(0, N): - for j in seq(0, M): - for k in seq(0, K): - C[i, j] += A[i, k] * B[k, j] - - return sgemm_full - - -def test_proc_name(): - sgemm = new_sgemm() - - proc = sgemm.get_ast() - assert isinstance(proc, QAST.Proc) - assert proc.name == "sgemm_full" - - sgemm = rename(sgemm, "sgemm") - - proc = sgemm.get_ast() - assert isinstance(proc, QAST.Proc) - assert proc.name == "sgemm" - - -def test_find_outer_loop(): - sgemm = new_sgemm() - - loops = sgemm.get_ast("for _ in _: _ #0") - assert isinstance(loops, list) and len(loops) == 1 - assert isinstance(loops[0], QAST.For) - - assert loops[0].name == "i" - i_body = loops[0].body - assert isinstance(i_body, list) and len(i_body) == 1 - assert isinstance(i_body[0], QAST.For) - - assert i_body[0].name == "j" - j_body = i_body[0].body - assert isinstance(j_body, list) and len(j_body) == 1 - assert isinstance(j_body[0], QAST.For) - - assert j_body[0].name == "k" - k_body = j_body[0].body - assert isinstance(k_body, list) and len(k_body) == 1 - - assert not isinstance(k_body[0], QAST.For) - - -def get_loop_nest_info(p, pattern): - loops = p.get_ast(pattern) - if loops is None: - return [] - assert isinstance(loops, list) and len(loops) > 0 - assert isinstance(loops[0], QAST.Stmt), "must call with ... #_ pattern" - - def recurse_loops(loops): - if len(loops) != 1: - return [] - stmt = loops[0] - assert isinstance(stmt, QAST.Stmt) - - if isinstance(stmt, QAST.For): - return [(stmt.name, stmt.hi)] + recurse_loops(stmt.body) - else: - return [] - - return recurse_loops(loops) - - -def test_get_outer_loop_info(): - sgemm = new_sgemm() - - info = get_loop_nest_info(sgemm, "for _ in _: _ #0") - - expect_info = [ - ("i", QAST.Read("N", [], QAST.size())), - ("j", QAST.Read("M", [], QAST.size())), - ("k", QAST.Read("K", [], QAST.size())), - ] - assert info == expect_info - - -def test_get_mid_loop_info(): - sgemm = new_sgemm() - - expect_info = [ - ("j", QAST.Read("M", [], QAST.size())), - ("k", QAST.Read("K", [], QAST.size())), - ] - - info = get_loop_nest_info(sgemm, "for j in _: _ #0") - assert info == expect_info - - info = get_loop_nest_info(sgemm, "for _ in _: _ #1") - assert info == expect_info - - -def test_get_bottom_loop_info(): - sgemm = new_sgemm() - - expect_info = [ - ("k", QAST.Read("K", [], QAST.size())), - ] - - info = get_loop_nest_info(sgemm, "for k in _: _ #0") - assert info == expect_info - - info = get_loop_nest_info(sgemm, "for _ in _: _ #2") - assert info == expect_info - - -def test_get_no_loop_info(): - sgemm = new_sgemm() - - info = get_loop_nest_info(sgemm, "for abc in _: _ #0") - assert info == [] - - -def test_show_effect(golden): - sgemm = new_sgemm() - - assert sgemm.show_effect("for j in _: _") == golden