[Pallas/Mosaic GPU] Add an abstraction to obtain a slice of dynamic shared memory when using warpgroup semantics.

Explicitly make the assumption that `runtime_smem` starts at offset `0` in the Pallas
module context; Mosaic GPU is expected to enforce this.

This is in preparation for changes implementing transform inference.

PiperOrigin-RevId: 731327428
bchetioui authored and Google-ML-Automation committed Feb 28, 2025
1 parent 55263ce commit 6a14d39
Showing 4 changed files with 69 additions and 15 deletions.
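
To illustrate the abstraction, here is a minimal sketch (not part of the commit) of the two ways `scratch_view` can now carve a scratch buffer out of dynamic SMEM. It assumes an active MLIR context and insertion point; the import paths and the helper name `dynamic_smem_slice` are inferred/illustrative rather than taken from the change.

```python
# Sketch only. Assumes an active MLIR context/insertion point; import paths
# are inferred from the diff below and may differ from the actual modules.
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import gpu as gpu_dialect
from jax._src.lib.mlir.dialects import memref as memref_dialect
import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.mosaic.gpu import utils as mgpu_utils


def dynamic_smem_slice(scratch_ty: ir.MemRefType, off: int, lane_semantics: bool):
  """Returns a memref of type `scratch_ty` starting `off` bytes into dynamic SMEM."""
  i8 = ir.IntegerType.get_signless(8)
  i32 = ir.IntegerType.get_signless(32)
  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
  if lane_semantics:
    # Lane semantics: materialize the dynamic SMEM base and take a view off it.
    smem_base = gpu_dialect.dynamic_shared_memory(
        ir.MemRefType.get((mgpu_utils.DYNAMIC,), i8, memory_space=smem)
    )
    return memref_dialect.view(
        scratch_ty, smem_base, arith.constant(ir.IndexType.get(), off), []
    )
  # Warpgroup semantics: the new `mosaic_gpu.slice_smem` op takes the byte
  # offset directly. This is only equivalent to the view above because
  # `runtime_smem` is assumed to start at offset 0 of dynamic SMEM.
  return mgpu.dialect.slice_smem(scratch_ty, mgpu_utils.c(off, i32))
```

Because the warpgroup path never materializes the dynamic SMEM base memref, the byte offset is only meaningful if the runtime scratch region starts at byte 0 of dynamic SMEM, which is the invariant the commit message calls out.
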
31 changes: 22 additions & 9 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -237,7 +237,7 @@ class ModuleContext:
program_ids: Sequence[ir.Value] | None
approx_math: bool
single_wg_lane_predicate: ir.Value
runtime_smem: ir.Value # ir.MemRefType
smem_requested_bytes: int
smem_used_bytes: int
runtime_barriers: MutableMapping[
mgpu.Barrier, MutableSequence[mgpu.BarrierRef]
@@ -279,25 +279,38 @@ def scratch_view(
and the second element is a sequence of memref views into the
runtime scratch buffer.
"""
smem_scratch_bytes = math.prod(ir.MemRefType(self.runtime_smem.type).shape)

smem_base = None
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
i8 = ir.IntegerType.get_signless(8)
i32 = ir.IntegerType.get_signless(32)
if self.thread_semantics == mgpu.ThreadSemantics.Lane:
smem_base = gpu_dialect.dynamic_shared_memory(
ir.MemRefType.get((mgpu_utils.DYNAMIC,), i8, memory_space=smem)
)
views = []
off = initial_used_bytes = self.smem_used_bytes
assert off % _SMEM_ALIGNMENT == 0
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
for s in structs:
scratch_ty = ir.MemRefType.get(
s.shape,
mgpu_utils.dtype_to_ir_type(s.dtype),
memory_space=smem,
)
views.append(
memref_dialect.view(scratch_ty, self.runtime_smem, _as_index(off), [])
)
# The below code emission relies on the assumption that the first scratch
# operand provided by Mosaic GPU always begins at the beginning of
# dynamic SMEM. Mosaic GPU is expected to uphold that invariant.
if self.thread_semantics == mgpu.ThreadSemantics.Lane:
view = memref_dialect.view(
scratch_ty, smem_base, _as_index(off), []
)
else:
view = mgpu.dialect.slice_smem(scratch_ty, mgpu_utils.c(off, i32))
views.append(view)

off += _align_to(
math.prod(s.shape) * jnp.dtype(s.dtype).itemsize, _SMEM_ALIGNMENT
)
assert off <= smem_scratch_bytes, "Ran out of scoped SMEM"
assert off <= self.smem_requested_bytes, "Ran out of scoped SMEM"
assert off % _SMEM_ALIGNMENT == 0

self.smem_used_bytes = off
@@ -596,7 +609,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
[_program_id(axis, squashed_dims) for axis in range(len(grid))],
approx_math,
mgpu.single_thread_predicate(per_block=False),
runtime_smem,
smem_requested_bytes=math.prod(ir.MemRefType(runtime_smem.type).shape),
smem_used_bytes=0,
runtime_barriers=grouped_barriers,
name_stack=source_info_util.NameStack(),
23 changes: 17 additions & 6 deletions jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -28,6 +28,7 @@
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import nvvm
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
@@ -598,15 +599,25 @@ def _mgpu_wait_op_lowering_rule(
return []


@_register_lowering(WaitOp)
def _for_op_lowering_rule(
_: LoweringContext, wait_op: scf.ForOp
# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)


@_register_lowering(SliceSMEMOp)
def _mgpu_slice_smem_op_lowering_rule(
ctx: LoweringContext, op: SliceSMEMOp
) -> Sequence[ir.Value]:
del ctx
i8 = ir.IntegerType.get_signless(8)
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")

barrier = utils.BarrierRef.from_dialect_barrier_memref(wait_op.barrier)
barrier.wait_parity(wait_op.parity)
smem_base = gpu.dynamic_shared_memory(
ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem)
)

return []
offset = arith.index_cast(ir.IndexType.get(), op.offset)

return [memref.view(op.result.type, smem_base, offset, [])]


@_register_lowering(scf.ForOp)
8 changes: 8 additions & 0 deletions jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
@@ -372,6 +372,14 @@ def MosaicGPU_WGMMALayout :
let cppNamespace = "::mosaic_gpu";
}


def MosaicGPU_SliceSMEMOp : Op<MosaicGPU_Dialect, "slice_smem", []> {
let summary = "Constructs an SMEM MemRef with the requested type that begins at the specified SMEM offset address.";

let arguments = (ins I32:$offset);
let results = (outs MemRefOf<[AnyType]>);
}

def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", [InferTypeOpInterface]> {
let summary = "Multiply two matrices asyncronously using warpgroup level matrix multiply operations.";
let description = [{
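
As a rough illustration of the new `slice_smem` op (a sketch under the same assumptions as above; the shape and offset are made up), this is how it can be constructed from the Python bindings and roughly what the lowering rule above reduces it to:

```python
# Sketch only; shape and offset values are illustrative. Assumes an active
# MLIR context and the same bindings as in the sketch above.
i8 = ir.IntegerType.get_signless(8)
i32 = ir.IntegerType.get_signless(32)
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
scratch_ty = ir.MemRefType.get((128,), i8, memory_space=smem)

# An SMEM memref that begins 256 bytes into dynamic SMEM.
view = mgpu.dialect.slice_smem(scratch_ty, mgpu_utils.c(256, i32))

# Per the lowering rule in dialect_lowering.py above, lowering rewrites this
# roughly into:
#   %base = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
#   %off  = arith.index_cast %c256_i32 : i32 to index
#   %view = memref.view %base[%off][] : ... to memref<128xi8, #gpu.address_space<workgroup>>
```
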
22 changes: 22 additions & 0 deletions tests/mosaic/gpu_dialect_test.py
@@ -802,6 +802,28 @@ def test_lowering_for(self):
reg_vec_ty = ir.VectorType.get((2,), i32)
self.assertSequenceEqual(result_types, [i32, reg_vec_ty, reg_vec_ty])

def test_lowering_slice_smem_op(self):
shift = 1234
offset = None

def body():
nonlocal offset
i32 = ir.IntegerType.get_signless(32)
offset = arith.constant(i32, shift)
mgpu.dialect.slice_smem(i32, offset)

with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func()(body)

mgpu.lower_mgpu_dialect(self.module, None)
# Avoid making a change detector, only validate that lowering runs as
# expected.
self.assertEmpty(
find_if(
self.module, lambda op: isinstance(op, mgpu.dialect.SliceSMEMOp)
)
)


if __name__ == "__main__":
parameterized.absltest.main(testLoader=jtu.JaxTestLoader())
