Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate QAST #670

Merged
merged 2 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 0 additions & 17 deletions src/exo/API.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
# Moved to new file
from .proc_eqv import decl_new_proc, derive_proc, assert_eqv_proc, check_eqv_proc
from .pyparser import get_ast_from_python, Parser, get_src_locals
from .reflection import LoopIR_to_QAST
from .typecheck import TypeChecker

from . import API_cursors as C
Expand Down Expand Up @@ -312,22 +311,6 @@ def find_alloc_or_arg(self, pattern):
def find_all(self, pattern):
return self.find(pattern, many=True)

def get_ast(self, pattern=None):
if pattern is None:
return LoopIR_to_QAST(self._loopir_proc).result()

# do pattern matching
match = match_pattern(self._root(), pattern, call_depth=1)

# convert matched sub-trees to QAST
assert isinstance(match, list)
if not match:
return None

return [
LoopIR_to_QAST(node._node).result() for block in match for node in block
]

# ---------------------------------------------- #
# execution / compilation operations
# ---------------------------------------------- #
Expand Down
2 changes: 0 additions & 2 deletions src/exo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from .parse_fragment import ParseFragmentError
from .configs import Config
from .memory import Memory, DRAM
from . import query_asts as QAST

from . import stdlib

Expand All @@ -27,7 +26,6 @@
"Config",
"Memory",
"DRAM",
"QAST",
"SchedulingError",
"ParseFragmentError",
#
Expand Down
162 changes: 19 additions & 143 deletions src/exo/platforms/gemmini.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,28 @@
from __future__ import annotations

from exo import proc, instr, DRAM, config, QAST
from exo import proc, instr, DRAM, config
from exo.libs.memories import GEMM_SCRATCH, GEMM_ACCUM
from exo.stdlib.scheduling import *


def lift_config(p, config_str):
config = p.find(config_str)
while True:
try:
try:
while True:
try:
p = reorder_stmts(p, p.forward(config).expand(1, 0))
except:
raise Exception("Reordered to the top of the scope")
except:
p = fission(p, p.forward(config).after(), unsafe_disable_checks=True)
p = remove_loop(p, p.forward(config).parent())
except:
break
return p


def set_prec_mem(p, bufname, precision, memory):
p = set_memory(p, bufname, memory)
p = set_precision(p, bufname, precision)
Expand Down Expand Up @@ -219,148 +237,6 @@ def tile_outer_loops(gemmini):
return gemmini


class QAST_Do:
def __init__(self, proc):
self.proc = proc

# [ self.do_fnarg(a) for a in self.proc.args ]
[self.do_e(p) for p in self.proc.assertions]
self.do_stmts(self.proc.body)

def do_stmts(self, stmts):
[self.do_s(b) for b in stmts]

def do_s(self, s):
if type(s) is QAST.Assign or type(s) is QAST.Reduce:
[self.do_e(e) for e in s.idx]
self.do_e(s.rhs)
elif type(s) is QAST.WriteConfig:
self.do_e(s.rhs)
elif type(s) is QAST.For:
self.do_e(s.lo)
self.do_e(s.hi)
self.do_stmts(s.body)
elif type(s) is QAST.If:
self.do_e(s.cond)
self.do_stmts(s.body)
if len(s.orelse) > 0:
self.do_stmts(s.orelse)
elif type(s) is QAST.Pass:
pass
elif type(s) is QAST.Alloc:
pass
elif type(s) is QAST.Call:
[self.do_e(e) for e in s.args]
elif type(s) is QAST.WindowStmt:
self.do_e(s.rhs)
else:
assert False, "bad case"

def do_w_access(self, w):
if type(w) is QAST.Interval:
self.do_e(w.lo)
self.do_e(w.hi)
elif type(w) is QAST.Point:
self.do_e(w.pt)

def do_e(self, e):
if type(e) is QAST.Read:
[self.do_e(ei) for ei in e.idx]
elif type(e) is QAST.Const:
pass
elif type(e) is QAST.USub:
self.do_e(e.arg)
elif type(e) is QAST.BinOp:
self.do_e(e.lhs)
self.do_e(e.rhs)
elif type(e) is QAST.BuiltIn:
[self.do_e(ei) for ei in e.args]
elif type(e) is QAST.WindowExpr:
[self.do_w_access(w) for w in e.idx]
elif type(e) is QAST.StrideExpr:
pass
elif type(e) is QAST.ReadConfig:
pass
else:
assert False, "bad case"


class CanFissionLoop(QAST_Do):
def __init__(self, proc, stmt):
self.stmt = stmt
self.result = False
super().__init__(proc)

def result(self):
return self.result

def do_s(self, s):
if type(s) is QAST.For:
assert len(s.body) > 0
if s.body[0] == self.stmt:
self.result = True

super().do_s(s)


class CanFissionIf(QAST_Do):
def __init__(self, proc, stmt):
self.stmt = stmt
self.result = None
super().__init__(proc)

def result(self):
return self.result

def do_s(self, s):
if type(s) is QAST.If:
assert len(s.body) > 0
if s.body[0] == self.stmt:
self.result = str(s)
elif len(s.orelse) > 0 and s.orelse[0] == self.stmt:
self.result = str(s)

super().do_s(s)


class CanReorder(QAST_Do):
def __init__(self, proc, stmt):
self.stmt = stmt
self.result = None
super().__init__(proc)

def result(self):
return self.result

def do_stmts(self, stmts):
prev = None
for b in stmts:
if b == self.stmt and prev is not None:
self.result = str(prev)
else:
self.do_s(b)
prev = b


def lift_config(conv, string, nth=0):
stmt = conv.get_ast(string)
stmt = stmt[nth] # Get the match

while True:
proc = conv.get_ast()
fission_loop = CanFissionLoop(proc, stmt).result
reorder = CanReorder(proc, stmt).result
if fission_loop:
conv = old_fission_after(conv, string)
elif reorder is not None:
conv = reorder_stmts(conv, conv.find(string).expand(1, 0))
# conv = conv.reorder_before(string)
else:
break

return conv


def inline_vector(conv):
conv = call_eqv(conv, "ld_acc_i32_vector(_)", ld_acc_i32_vector_v2)
conv = inline(conv, "ld_acc_i32_vector_v2(_)")
Expand Down
Loading
Loading