Skip to content

Commit

Permalink
[Mosaic GPU] Add support for warpgroup lowering of loops with vector …
Browse files Browse the repository at this point in the history
…carries

PiperOrigin-RevId: 731260912
  • Loading branch information
apaszke authored and Google-ML-Automation committed Feb 26, 2025
1 parent 1de2f83 commit 99a12ef
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 4 deletions.
84 changes: 83 additions & 1 deletion jax/experimental/mosaic/gpu/dialect_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,10 +603,91 @@ def _for_op_lowering_rule(
return []


@_register_lowering(scf.ForOp)
def _for_op_lowering_rule(
ctx: LoweringContext, for_op: scf.ForOp
) -> MlirLoweringRuleResult:
if not layouts.should_have_layout(for_op):
return _traverse_op_lowering_rule(ctx, for_op)
in_layouts = layouts.in_layouts(for_op)
out_layouts = layouts.out_layouts(for_op)
yield_op = for_op.body.operations[len(for_op.body.operations) - 1]
yield_layouts = layouts.in_layouts(yield_op)
if in_layouts != out_layouts or in_layouts != yield_layouts:
raise ValueError("Layout mismatch")
fa_layouts = in_layouts

fa_layouts_it = iter(fa_layouts)
arg_template = [
(_fragmented_array_from_ir(arg, next(fa_layouts_it)), arg.type)
if ir.VectorType.isinstance(arg.type)
else (arg, arg.type)
for arg in for_op.initArgs
]
def lower_carry(carry):
fa_layouts_it = iter(fa_layouts)
carry_with_fas = [
_fragmented_array_from_ir(arg, next(fa_layouts_it))
if ir.VectorType.isinstance(arg.type)
else arg
for arg in carry
]
lowered_carry = []
for c in carry_with_fas:
if isinstance(c, fa.FragmentedArray):
lowered_carry.extend(c.registers.flat)
else:
lowered_carry.append(c)
return lowered_carry

def recreate_carry(lowered_carry):
recreated_carry = []
arg_it = iter(lowered_carry)
for arg_value, arg_type in arg_template:
if isinstance(arg_value, fa.FragmentedArray):
carry_registers = np.asarray(
[next(arg_it) for _ in arg_value.registers.flat], dtype=object
)
carry_registers = carry_registers.reshape(arg_value.registers.shape)
carry = fa.FragmentedArray(
_registers=carry_registers,
_layout=arg_value.layout,
_is_signed=arg_value.is_signed,
)
recreated_carry.append(_fragmented_array_to_ir(carry, arg_type))
else:
recreated_carry.append(next(arg_it))
return recreated_carry

new_for_op = scf.ForOp(
for_op.lowerBound,
for_op.upperBound,
for_op.step,
lower_carry(for_op.initArgs),
)
with ir.InsertionPoint(new_for_op.body):
recreated_carry = recreate_carry(new_for_op.body.arguments[1:])
ops_to_lower = []
for op in for_op.body:
if op == yield_op:
continue
mgpu.private_operation_remove_from_parent(op)
mgpu.private_block_append_owned_operation(new_for_op.body, op)
ops_to_lower.append(op)
new_args = (new_for_op.induction_variable, *recreated_carry)
for old_carry, new_carry in zip(for_op.body.arguments, new_args, strict=True):
old_carry.replace_all_uses_with(new_carry)
for op in ops_to_lower:
ctx.lower_op(op)
new_yield_operands = lower_carry(yield_op.operands)
yield_op.erase()
scf.yield_(new_yield_operands)
return recreate_carry(new_for_op.results)


@_register_lowering(func.FuncOp)
@_register_lowering(gpu.LaunchOp)
@_register_lowering(scf.IfOp) # TODO(apaszke,bchetioui): Add a proper rule.
@_register_lowering(scf.ForOp) # TODO(apaszke,bchetioui): Add a proper rule.
@_register_lowering(scf.IndexSwitchOp) # TODO(apaszke,bchetioui): Add a proper rule.
def _traverse_op_lowering_rule(
ctx: LoweringContext, op: ir.OpView
Expand Down Expand Up @@ -661,6 +742,7 @@ def _should_lower(op: ir.OpView) -> bool:
def lower_mgpu_dialect(
module: ir.Module, launch_context: launch_context.LaunchContext | None
):
# TODO(apaszke,bchetioui): Make sure the layouts match.
# TODO(bchetioui): rethink this API. It doesn't make sense to pass in a full
# module and to traverse all `gpu.LaunchOp`s if we have a `LaunchContext` that
# references a single `gpu.LaunchOp`.
Expand Down
2 changes: 2 additions & 0 deletions jax/experimental/mosaic/gpu/layout_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts:
if not ir.VectorType.isinstance(result.type):
continue
if (layout := _value_layout(result)) is not None:
if layouts_lib.is_splat_fragmented_layout(layout):
return None
layouts.append(layout)
else:
# Not all layouts could be inferred for vector ops. Return for now.
Expand Down
2 changes: 2 additions & 0 deletions jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ NB_MODULE(_mosaic_gpu_ext, m) {
}
},
nb::arg("context"), nb::arg("load") = true);
m.def("private_operation_remove_from_parent", mlirOperationRemoveFromParent);
m.def("private_block_append_owned_operation", mlirBlockAppendOwnedOperation);

mlir::python::nanobind_adaptors::mlir_attribute_subclass(
m, "TileTransformAttr", mlirMosaicGpuIsATileTransformAttr)
Expand Down
41 changes: 41 additions & 0 deletions tests/mosaic/gpu_dialect_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,47 @@ def check_type(ty: ir.Type):
check_type(store1.valueToStore.type)
check_type(store2.valueToStore.type)

def test_lowering_for(self):
shape = (4, 128)
i32 = ir.IntegerType.get_signless(32)
vec_ty = ir.VectorType.get(shape, i32)
splat_layout_attr = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape))
strided_layout_attr = layouts.to_layout_attr(
mgpu.WGStridedFragLayout.from_shaped_type(vec_ty)
)
with ir.InsertionPoint(self.module.body):
i1 = arith.constant(ir.IndexType.get(), 1)
c1 = arith.constant(i32, 1)
splat = vector.SplatOp(
ir.VectorType.get(shape, i32), arith.constant(i32, 1234),
)
splat.attributes["out_layouts"] = ir.ArrayAttr.get([
splat_layout_attr
])
ptr = llvm.mlir_undef(ir.Type.parse("!llvm.ptr"))
ref = mgpu_utils.ptr_as_memref(ptr, ir.MemRefType.get(shape, i32))
i0 = arith.constant(ir.IndexType.get(), 0)
other_vec = vector.LoadOp(vec_ty, ref, [i0, i0])
other_vec.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout_attr])
for_op = scf.ForOp(i1, i1, i1, [c1, splat.result])
for_op.attributes["in_layouts"] = ir.ArrayAttr.get([strided_layout_attr])
for_op.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout_attr])
with ir.InsertionPoint(for_op.body):
i, int_carry, vec_carry = for_op.body.arguments
new_int_carry = arith.addi(int_carry, arith.index_castui(i32, i))
new_vec_carry = arith.AddIOp(vec_carry, other_vec)
new_vec_carry.attributes["in_layouts"] = ir.ArrayAttr.get([strided_layout_attr] * 2)
new_vec_carry.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout_attr])
yield_op = scf.YieldOp([new_int_carry, new_vec_carry])
yield_op.attributes["in_layouts"] = ir.ArrayAttr.get([strided_layout_attr])

mgpu.lower_mgpu_dialect(self.module, None)
self.module.operation.verify()
[for_op] = find_if(self.module, lambda op: isinstance(op, scf.ForOp))
result_types = [r.type for r in for_op.results]
reg_vec_ty = ir.VectorType.get((2,), i32)
self.assertSequenceEqual(result_types, [i32, reg_vec_ty, reg_vec_ty])


if __name__ == "__main__":
parameterized.absltest.main(testLoader=jtu.JaxTestLoader())
5 changes: 2 additions & 3 deletions tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,9 +902,8 @@ def kernel(x_ref, o_ref):
force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics]
)
def test_fori_loop_array(self, force_while, thread_semantics):
if thread_semantics == plgpu.ThreadSemantics.Warpgroup:
# TODO(apaszke,bchetioui,slebedev): Support while + array carries.
self.skipTest("WG semantics unsupported")
if force_while and thread_semantics == plgpu.ThreadSemantics.Warpgroup:
self.skipTest("WG semantics does not support force_while.")

@functools.partial(
pl.pallas_call,
Expand Down

0 comments on commit 99a12ef

Please sign in to comment.