Skip to content

Commit

Permalink
Deprecate stage_assn (#300)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamirDroubi authored Jan 7, 2023
1 parent e4d51dc commit 91d8ae7
Show file tree
Hide file tree
Showing 11 changed files with 28 additions and 95 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ Take a look at `exo/examples` for scheduling examples.
| `.expand_dim(stmt, alloc_dim, indexing)` | Expands the allocation statement `stmt` with a new dimension of size `alloc_dim`, which is indexed by `indexing`. |
| `.bind_expr(new_name, expr)` | Binds the right hand side expression `expr` to a newly allocated buffer named `new_name` |
| `.stage_mem(win_expr, new_name, stmt_start, stmt_end=None)` | Stages the buffer `win_expr` to the new window expression `new_name` in statement block (`stmt_start` to `stmt_end`), and adds an initialization loop and a write-back loop. |
| `.stage_assn(new_name, stmt)` | Binds the left hand side expression of `stmt` to a newly allocated buffer named `new_name`. |
| `.rearrange_dim(alloc, dimensions)` | Takes an allocation statement and a list of integers that maps its dimensions. It rearranges the dimensions of `alloc` in `dimensions` order. E.g., if `alloc` were `foo[N,M,K]` and `dimensions` were `[2,0,1]`, it would become `foo[K,N,M]` after this operation. |
| `.lift_alloc(alloc, n_lifts=1, keep_dims=False)` | Lifts the allocation statement `alloc` out of `n_lifts` number of scopes. If and For statements are the only statements in Exo which introduce a scope. When lifting the allocation out of a for loop, it will expand its dimension to the loop bound if `keep_dims` is True. |

Expand Down
5 changes: 3 additions & 2 deletions apps/x86/sgemm/sgemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def make_avx512_kernel(p):
# Vectorize columns
p = divide_loop(p, "j", VEC_W, ["jo", "ji"], perfect=True)
# Stage C for reduction
p = stage_assn(p, "C[_] += _", "C_reg")
p = stage_mem(p, "C[_] += _", f"C[i, {VEC_W} * jo + ji]", "C_reg")
p = set_memory(p, "C_reg", AVX512)
p = autolift_alloc(p, "C_reg: _", n_lifts=3, keep_dims=True)
p = autolift_alloc(p, "C_reg: _")
Expand Down Expand Up @@ -111,6 +111,7 @@ def stage_input(p, expr, new_buf):
p = autofission(p, p.find("mm512_set1_ps(_)").after())
# Clean up
p = simplify(p)
print(p)
return p

sgemm_kernel_avx512_Mx4[M] = make_avx512_kernel(basic_kernel_Mx4[M])
Expand Down Expand Up @@ -155,7 +156,7 @@ def make_right_panel_kernel(p=SGEMM_WINDOW):
def make_right_panel_kernel_opt(p=right_panel_kernel):
p = rename(p, "right_panel_kernel_opt")
#
p = stage_assn(p, "C[_] += _", "C_reg")
p = stage_mem(p, "C[_] += _", "C[i, j]", "C_reg")
p = divide_loop(p, "j", VEC_W, ["jo", "ji"], tail="cut")
p = bound_and_guard(p, "for ji in _: _ #1")
p = fission(p, p.find("for jo in _: _").after(), n_lifts=2)
Expand Down
14 changes: 7 additions & 7 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def rank_k_reduce_6x16(K: size, C: f32[6, 16] @ DRAM, A: f32[6, K] @ DRAM,
```

Next, please uncomment the code in the first block by deleting the multi-line string
markers (`"""`). Now, you will see that `stage_assn()` stages `C` to a buffer
markers (`"""`). Now, you will see that `stage_mem()` stages `C` to a buffer
called `C_reg`. `set_memory()` sets `C_reg`'s memory to AVX2 to use it as an AVX vector,
which is denoted by `@ AVX2`.

Expand All @@ -41,7 +41,7 @@ def rank_k_reduce_6x16_scheduled(K: size, C: f32[6, 16] @ DRAM,
for i in seq(0, 6):
for j in seq(0, 16):
for k in seq(0, K):
C_reg: R @ AVX2
C_reg: f32 @ AVX2
C_reg = C[i, j]
C_reg += A[i, k] * B[k, j]
C[i, j] = C_reg
Expand All @@ -59,7 +59,7 @@ def rank_k_reduce_6x16_scheduled(K: size, C: f32[6, 16] @ DRAM,
for i in seq(0, 6):
for jo in seq(0, 2):
for ji in seq(0, 8):
C_reg: R @ AVX2
C_reg: f32 @ AVX2
C_reg = C[i, 8 * jo + ji]
C_reg += A[i, k] * B[k, 8 * jo + ji]
C[i, 8 * jo + ji] = C_reg
Expand All @@ -75,7 +75,7 @@ Please uncomment the code in the third block. Please notice that
# Third block:
def rank_k_reduce_6x16_scheduled(K: size, C: f32[6, 16] @ DRAM,
A: f32[6, K] @ DRAM, B: f32[K, 16] @ DRAM):
C_reg: R[1 + K, 6, 2, 8] @ AVX2
C_reg: f32[1 + K, 6, 2, 8] @ AVX2
for k in seq(0, K):
for i in seq(0, 6):
for jo in seq(0, 2):
Expand All @@ -100,7 +100,7 @@ register `a_vec` by `bind_expr()` and `set_memory()`.
# Fourth block:
def rank_k_reduce_6x16_scheduled(K: size, C: f32[6, 16] @ DRAM,
A: f32[6, K] @ DRAM, B: f32[K, 16] @ DRAM):
C_reg: R[1 + K, 6, 2, 8] @ AVX2
C_reg: f32[1 + K, 6, 2, 8] @ AVX2
for k in seq(0, K):
for i in seq(0, 6):
for jo in seq(0, 2):
Expand Down Expand Up @@ -128,7 +128,7 @@ to `B`.
# Fifth block:
def rank_k_reduce_6x16_scheduled(K: size, C: f32[6, 16] @ DRAM,
A: f32[6, K] @ DRAM, B: f32[K, 16] @ DRAM):
C_reg: R[1 + K, 6, 2, 8] @ AVX2
C_reg: f32[1 + K, 6, 2, 8] @ AVX2
for k in seq(0, K):
for i in seq(0, 6):
for jo in seq(0, 2):
Expand Down Expand Up @@ -164,7 +164,7 @@ statement with the call to AVX2 instruction procedures to get the final schedule
# Sixth block:
def rank_k_reduce_6x16_scheduled(K: size, C: f32[6, 16] @ DRAM,
A: f32[6, K] @ DRAM, B: f32[K, 16] @ DRAM):
C_reg: R[1 + K, 6, 2, 8] @ AVX2
C_reg: f32[1 + K, 6, 2, 8] @ AVX2
for k in seq(0, K):
for i in seq(0, 6):
for jo in seq(0, 2):
Expand Down
2 changes: 1 addition & 1 deletion examples/x86_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def rank_k_reduce_6x16(
# First block
"""
avx = rename(rank_k_reduce_6x16, "rank_k_reduce_6x16_scheduled")
avx = stage_assn(avx, 'C[_] += _', 'C_reg')
avx = stage_mem(avx, 'C[_] += _', 'C[i, j]', 'C_reg')
avx = set_memory(avx, 'C_reg', AVX2)
print("First block:")
print(avx)
Expand Down
13 changes: 0 additions & 13 deletions src/exo/API_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1961,16 +1961,3 @@ def bound_and_guard(proc, loop):
proc_c = ic.Cursor.root(proc)

return Schedules.DoBoundAndGuard(proc_c, stmt).result()


@sched_op([AssignOrReduceCursorA, NameA])
def stage_assn(proc, stmt_cursor, buf_name):
    """
    DEPRECATED: use `stage_mem` (or a similar operation) instead.

    Stages the assignment or reduction selected by `stmt_cursor` through
    a newly allocated buffer named `buf_name`.

    Returns the result of the `Schedules.DoStageAssn` rewrite.
    """
    # Unwrap the public cursor to the internal cursor the rewrite expects.
    target = stmt_cursor._impl
    root = ic.Cursor.root(proc)

    rewrite = Schedules.DoStageAssn(root, buf_name, target)
    return rewrite.result()
42 changes: 0 additions & 42 deletions src/exo/LoopIR_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1390,47 +1390,6 @@ def map_e(self, e):
return super().map_e(e)


class _DoStageAssn(Cursor_Rewrite):
    """
    Rewrite that routes one assignment or reduction through a temporary.

    The targeted statement is replaced by an allocation of a fresh symbol
    (typed `T.R`, accessed with no indices), the statement(s) that compute
    the value into that temporary, and a final assignment that writes the
    temporary back to the original destination.
    """

    def __init__(self, proc_cursor, new_name, assn_cursor):
        target = assn_cursor._node()
        assert isinstance(target, (LoopIR.Assign, LoopIR.Reduce))
        self.assn = target
        self.new_name = Sym(new_name)

        super().__init__(proc_cursor)

        # repair effects...
        self.proc = InferEffects(self.proc).result()

    def map_s(self, sc):
        stmt = sc._node()

        # Only the targeted statement is rewritten; everything else is
        # handled by the base class.
        if stmt is not self.assn:
            return super().map_s(sc)
        if not isinstance(stmt, (LoopIR.Assign, LoopIR.Reduce)):
            return super().map_s(sc)

        staged = self.new_name
        read_staged = LoopIR.Read(staged, [], stmt.type, stmt.srcinfo)

        if isinstance(stmt, LoopIR.Assign):
            compute = [
                # tmp = rhs
                LoopIR.Assign(
                    staged, stmt.type, None, [], stmt.rhs, None, stmt.srcinfo
                ),
            ]
        else:
            read_dest = LoopIR.Read(stmt.name, stmt.idx, stmt.type, stmt.srcinfo)
            compute = [
                # tmp = lhs
                LoopIR.Assign(
                    staged, stmt.type, None, [], read_dest, None, stmt.srcinfo
                ),
                # tmp += rhs
                LoopIR.Reduce(
                    staged, stmt.type, None, [], stmt.rhs, None, stmt.srcinfo
                ),
            ]

        # tmp : R
        alloc = LoopIR.Alloc(staged, T.R, None, None, stmt.srcinfo)
        # lhs = tmp
        write_back = LoopIR.Assign(
            stmt.name, stmt.type, None, stmt.idx, read_staged, None, stmt.srcinfo
        )

        return [alloc, *compute, write_back]


# Lift if no variable dependency
class _DoLiftScope(Cursor_Rewrite):
def __init__(self, proc_cursor, if_cursor):
Expand Down Expand Up @@ -3835,7 +3794,6 @@ class Schedules:
DoCallSwap = _CallSwap
DoBindExpr = _BindExpr
DoBindConfig = _BindConfig
DoStageAssn = _DoStageAssn
DoLiftAlloc = _LiftAlloc
DoFissionLoops = _FissionLoops
DoExtractMethod = _DoExtractMethod
Expand Down
1 change: 0 additions & 1 deletion src/exo/stdlib/scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
# deprecated scheduling operations
add_unsafe_guard,
bound_and_guard,
stage_assn,
#
# to be replaced by stdlib compositions eventually
autofission,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
def simple_math_neon_sched(n: size, x: R[n] @ DRAM, y: R[n] @ DRAM):
for io in seq(0, n / 4):
xyy: R[4] @ Neon4f
xVec: R[4] @ Neon4f
neon_vld_4xf32(xVec[0:4], x[4 * io + 0:4 * io + 4])
neon_vld_4xf32(xVec[0:4], x[4 * io:4 + 4 * io])
yVec: R[4] @ Neon4f
neon_vld_4xf32(yVec[0:4], y[4 * io + 0:4 * io + 4])
neon_vld_4xf32(yVec[0:4], y[4 * io:4 + 4 * io])
xy: R[4] @ Neon4f
neon_vmul_4xf32(xy[0:4], xVec[0:4], yVec[0:4])
neon_vmul_4xf32(xyy[0:4], xy[0:4], yVec[0:4])
neon_vst_4xf32(x[4 * io + 0:4 * io + 4], xyy[0:4])
neon_vmul_4xf32(xVec[0:4], xy[0:4], yVec[0:4])
neon_vst_4xf32(x[4 * io:4 + 4 * io], xVec[0:4])
if n % 4 > 0:
for ii in seq(0, n % 4):
x[ii + n / 4 *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
def simple_math_avx2_sched(n: size, x: R[n] @ DRAM, y: R[n] @ DRAM):
for io in seq(0, n / 8):
xyy: R[8] @ AVX2
xVec: R[8] @ AVX2
mm256_loadu_ps(xVec[0:8], x[8 * io + 0:8 * io + 8])
mm256_loadu_ps(xVec[0:8], x[8 * io:8 + 8 * io])
yVec: R[8] @ AVX2
mm256_loadu_ps(yVec[0:8], y[8 * io + 0:8 * io + 8])
mm256_loadu_ps(yVec[0:8], y[8 * io:8 + 8 * io])
xy: R[8] @ AVX2
mm256_mul_ps(xy, xVec, yVec)
mm256_mul_ps(xyy, xy, yVec)
mm256_storeu_ps(x[8 * io + 0:8 * io + 8], xyy[0:8])
mm256_mul_ps(xVec, xy, yVec)
mm256_storeu_ps(x[8 * io:8 + 8 * io], xVec[0:8])
if n % 8 > 0:
for ii in seq(0, n % 8):
x[ii + n / 8 *
Expand Down
11 changes: 3 additions & 8 deletions tests/test_neon.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,8 @@ def simple_math_neon_sched(

def sched_neon(p=simple_math_neon_sched):
p = divide_loop(p, "i", 4, ["io", "ii"], tail="cut_and_guard")
p = stage_assn(p, "x[_] = _ #0", "xyy")
p = autolift_alloc(p, "xyy: _", keep_dims=True)
p = fission(p, p.find("xyy[_] = _").after())

p = bind_expr(p, "x[_]", "xVec")
p = autolift_alloc(p, "xVec: _", keep_dims=True)
p = fission(p, p.find("xVec[_] = _").after())
p = stage_mem(p, "for ii in _:_ #0", "x[4 * io : 4 * io + 4]", "xVec")

p = bind_expr(p, "y[_]", "yVec", cse=True)
p = autolift_alloc(p, "yVec: _", keep_dims=True)
Expand All @@ -166,11 +161,11 @@ def sched_neon(p=simple_math_neon_sched):
p = set_memory(p, "xVec", Neon4f)
p = set_memory(p, "yVec", Neon4f)
p = set_memory(p, "xy", Neon4f)
p = set_memory(p, "xyy", Neon4f)
p = replace(p, "for ii in _: _ #4", neon_vst_4xf32)
p = replace(p, "for i0 in _: _ #1", neon_vst_4xf32)
p = replace_all(p, neon_vld_4xf32)
p = replace_all(p, neon_vmul_4xf32)

p = simplify(p)
return p

simple_math_neon_sched = sched_neon()
Expand Down
16 changes: 6 additions & 10 deletions tests/test_x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,12 @@ def simple_math_avx2_sched(

def sched_simple_math_avx2_sched(p=simple_math_avx2_sched):
p = old_split(p, "i", 8, ["io", "ii"], tail="cut_and_guard")
p = stage_assn(p, "x[_] = _ #0", "xyy")
p = autolift_alloc(p, "xyy: _", keep_dims=True)
p = set_memory(p, "xyy", AVX2)
p = old_fission_after(p, "xyy[_] = _")

p = replace_all(p, mm256_storeu_ps)

p = bind_expr(p, "x[_]", "xVec")
p = autolift_alloc(p, "xVec: _", keep_dims=True)
p = stage_mem(p, "for ii in _:_", "x[8 * io: 8 * io + 8]", "xVec")
p = set_memory(p, "xVec", AVX2)
p = old_fission_after(p, "xVec[_] = _")

p = replace(p, "for i0 in _:_ #0", mm256_loadu_ps)
p = replace(p, "for i0 in _:_ #0", mm256_storeu_ps)

p = bind_expr(p, "y[_]", "yVec", cse=True)
p = autolift_alloc(p, "yVec: _", keep_dims=True)
Expand All @@ -125,6 +120,7 @@ def sched_simple_math_avx2_sched(p=simple_math_avx2_sched):
p = old_fission_after(p, "xy[_] = _")

p = replace_all(p, mm256_mul_ps)
p = simplify(p)
return p

simple_math_avx2_sched = sched_simple_math_avx2_sched()
Expand Down Expand Up @@ -184,7 +180,7 @@ def sgemm_6x16(
def avx2_sgemm_6x16(sgemm_6x16):
avx = rename(sgemm_6x16, "rank_k_reduce_6x16_scheduled")
print(avx)
avx = stage_assn(avx, "C[_] += _", "C_reg")
avx = stage_mem(avx, "C[_] += _", "C[i, j]", "C_reg")
avx = set_memory(avx, "C_reg", AVX2)
avx = old_split(avx, "j", 8, ["jo", "ji"], perfect=True)
avx = reorder_loops(avx, "ji k")
Expand Down

0 comments on commit 91d8ae7

Please sign in to comment.