diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py
index 1c7265fd698d..290df2af6b06 100644
--- a/jax/experimental/mosaic/gpu/dialect_lowering.py
+++ b/jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -603,10 +603,91 @@ def _for_op_lowering_rule(
   return []
 
 
+@_register_lowering(scf.ForOp)
+def _for_op_lowering_rule(
+    ctx: LoweringContext, for_op: scf.ForOp
+) -> MlirLoweringRuleResult:
+  if not layouts.should_have_layout(for_op):
+    return _traverse_op_lowering_rule(ctx, for_op)
+  in_layouts = layouts.in_layouts(for_op)
+  out_layouts = layouts.out_layouts(for_op)
+  yield_op = for_op.body.operations[len(for_op.body.operations) - 1]
+  yield_layouts = layouts.in_layouts(yield_op)
+  if in_layouts != out_layouts or in_layouts != yield_layouts:
+    raise ValueError("Layout mismatch")
+  fa_layouts = in_layouts
+
+  fa_layouts_it = iter(fa_layouts)
+  arg_template = [
+      (_fragmented_array_from_ir(arg, next(fa_layouts_it)), arg.type)
+      if ir.VectorType.isinstance(arg.type)
+      else (arg, arg.type)
+      for arg in for_op.initArgs
+  ]
+  def lower_carry(carry):
+    fa_layouts_it = iter(fa_layouts)
+    carry_with_fas = [
+        _fragmented_array_from_ir(arg, next(fa_layouts_it))
+        if ir.VectorType.isinstance(arg.type)
+        else arg
+        for arg in carry
+    ]
+    lowered_carry = []
+    for c in carry_with_fas:
+      if isinstance(c, fa.FragmentedArray):
+        lowered_carry.extend(c.registers.flat)
+      else:
+        lowered_carry.append(c)
+    return lowered_carry
+
+  def recreate_carry(lowered_carry):
+    recreated_carry = []
+    arg_it = iter(lowered_carry)
+    for arg_value, arg_type in arg_template:
+      if isinstance(arg_value, fa.FragmentedArray):
+        carry_registers = np.asarray(
+            [next(arg_it) for _ in arg_value.registers.flat], dtype=object
+        )
+        carry_registers = carry_registers.reshape(arg_value.registers.shape)
+        carry = fa.FragmentedArray(
+            _registers=carry_registers,
+            _layout=arg_value.layout,
+            _is_signed=arg_value.is_signed,
+        )
+        recreated_carry.append(_fragmented_array_to_ir(carry, arg_type))
+      else:
+        recreated_carry.append(next(arg_it))
+    return recreated_carry
+
+  new_for_op = scf.ForOp(
+      for_op.lowerBound,
+      for_op.upperBound,
+      for_op.step,
+      lower_carry(for_op.initArgs),
+  )
+  with ir.InsertionPoint(new_for_op.body):
+    recreated_carry = recreate_carry(new_for_op.body.arguments[1:])
+    ops_to_lower = []
+    for op in for_op.body:
+      if op == yield_op:
+        continue
+      mgpu.private_operation_remove_from_parent(op)
+      mgpu.private_block_append_owned_operation(new_for_op.body, op)
+      ops_to_lower.append(op)
+    new_args = (new_for_op.induction_variable, *recreated_carry)
+    for old_carry, new_carry in zip(for_op.body.arguments, new_args, strict=True):
+      old_carry.replace_all_uses_with(new_carry)
+    for op in ops_to_lower:
+      ctx.lower_op(op)
+    new_yield_operands = lower_carry(yield_op.operands)
+    yield_op.erase()
+    scf.yield_(new_yield_operands)
+  return recreate_carry(new_for_op.results)
+
+
 @_register_lowering(func.FuncOp)
 @_register_lowering(gpu.LaunchOp)
 @_register_lowering(scf.IfOp)  # TODO(apaszke,bchetioui): Add a proper rule.
-@_register_lowering(scf.ForOp)  # TODO(apaszke,bchetioui): Add a proper rule.
 @_register_lowering(scf.IndexSwitchOp)  # TODO(apaszke,bchetioui): Add a proper rule.
 def _traverse_op_lowering_rule(
     ctx: LoweringContext, op: ir.OpView
@@ -661,6 +742,7 @@ def _should_lower(op: ir.OpView) -> bool:
 def lower_mgpu_dialect(
     module: ir.Module, launch_context: launch_context.LaunchContext | None
 ):
+  # TODO(apaszke,bchetioui): Make sure the layouts match.
   # TODO(bchetioui): rethink this API. It doesn't make sense to pass in a full
   # module and to traverse all `gpu.LaunchOp`s if we have a `LaunchContext` that
   # references a single `gpu.LaunchOp`.
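For context on the rule above, here is a minimal standalone sketch of the carry round-trip it performs: lower_carry flattens every fragmented vector carry into its individual per-thread registers before the new scf.for is built, and recreate_carry rebuilds the fragmented arrays from the loop's block arguments and results. FakeFragmentedArray and the string "registers" are illustrative stand-ins only; the real rule operates on fa.FragmentedArray and MLIR values.

# Illustrative sketch, not part of the patch: plain Python objects stand in for
# MLIR values, and FakeFragmentedArray stands in for fa.FragmentedArray.
import numpy as np


class FakeFragmentedArray:
  """Stand-in for fa.FragmentedArray: only the `registers` ndarray matters here."""

  def __init__(self, registers):
    self.registers = np.asarray(registers, dtype=object)


def lower_carry(carry):
  """Flattens each fragmented carry into its individual per-thread registers."""
  lowered = []
  for c in carry:
    if isinstance(c, FakeFragmentedArray):
      lowered.extend(c.registers.flat)
    else:
      lowered.append(c)
  return lowered


def recreate_carry(template, lowered):
  """Rebuilds the original carry structure from the flat register list."""
  it = iter(lowered)
  rebuilt = []
  for t in template:
    if isinstance(t, FakeFragmentedArray):
      regs = np.asarray([next(it) for _ in t.registers.flat], dtype=object)
      rebuilt.append(FakeFragmentedArray(regs.reshape(t.registers.shape)))
    else:
      rebuilt.append(next(it))
  return rebuilt


carry = [7, FakeFragmentedArray([["r0", "r1"], ["r2", "r3"]])]
flat = lower_carry(carry)
assert flat == [7, "r0", "r1", "r2", "r3"]
assert recreate_carry(carry, flat)[1].registers.shape == (2, 2)
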
diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py
index 520e83456816..41c7c02a97ec 100644
--- a/jax/experimental/mosaic/gpu/layout_inference.py
+++ b/jax/experimental/mosaic/gpu/layout_inference.py
@@ -317,6 +317,8 @@ def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts:
     if not ir.VectorType.isinstance(result.type):
       continue
     if (layout := _value_layout(result)) is not None:
+      if layouts_lib.is_splat_fragmented_layout(layout):
+        return None
       layouts.append(layout)
     else:
       # Not all layouts could be inferred for vector ops. Return for now.
diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc
index 4a09b19658d1..c73084abc99d 100644
--- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc
+++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc
@@ -35,6 +35,8 @@ NB_MODULE(_mosaic_gpu_ext, m) {
         }
       },
      nb::arg("context"), nb::arg("load") = true);
+  m.def("private_operation_remove_from_parent", mlirOperationRemoveFromParent);
+  m.def("private_block_append_owned_operation", mlirBlockAppendOwnedOperation);
 
   mlir::python::nanobind_adaptors::mlir_attribute_subclass(
       m, "TileTransformAttr", mlirMosaicGpuIsATileTransformAttr)
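The two m.def lines above expose existing MLIR C API functions to Python so that the new scf.ForOp rule can re-parent the old loop body's operations into the freshly created loop. A hedged usage sketch follows; `ext` is assumed to be the loaded _mosaic_gpu_ext module, and `op` and `block` are assumed to be MLIR Python objects accepted by the existing casters.

# Sketch of how the lowering rule uses the new bindings (see dialect_lowering.py
# above); helper name and the `ext` parameter are illustrative.
def move_into_block(ext, op, block):
  # Detach `op` from whatever block currently owns it...
  ext.private_operation_remove_from_parent(op)
  # ...then append it to `block`, transferring ownership, mirroring the calls
  # made in _for_op_lowering_rule.
  ext.private_block_append_owned_operation(block, op)
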
diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py
index 1707f9238c07..1bee3123dbad 100644
--- a/tests/mosaic/gpu_dialect_test.py
+++ b/tests/mosaic/gpu_dialect_test.py
@@ -761,6 +761,47 @@ def check_type(ty: ir.Type):
     check_type(store1.valueToStore.type)
     check_type(store2.valueToStore.type)
 
+  def test_lowering_for(self):
+    shape = (4, 128)
+    i32 = ir.IntegerType.get_signless(32)
+    vec_ty = ir.VectorType.get(shape, i32)
+    splat_layout_attr = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape))
+    strided_layout_attr = layouts.to_layout_attr(
+        mgpu.WGStridedFragLayout.from_shaped_type(vec_ty)
+    )
+    with ir.InsertionPoint(self.module.body):
+      i1 = arith.constant(ir.IndexType.get(), 1)
+      c1 = arith.constant(i32, 1)
+      splat = vector.SplatOp(
+          ir.VectorType.get(shape, i32), arith.constant(i32, 1234),
+      )
+      splat.attributes["out_layouts"] = ir.ArrayAttr.get([
+          splat_layout_attr
+      ])
+      ptr = llvm.mlir_undef(ir.Type.parse("!llvm.ptr"))
+      ref = mgpu_utils.ptr_as_memref(ptr, ir.MemRefType.get(shape, i32))
+      i0 = arith.constant(ir.IndexType.get(), 0)
+      other_vec = vector.LoadOp(vec_ty, ref, [i0, i0])
+      other_vec.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout_attr])
+      for_op = scf.ForOp(i1, i1, i1, [c1, splat.result])
+      for_op.attributes["in_layouts"] = ir.ArrayAttr.get([strided_layout_attr])
+      for_op.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout_attr])
+      with ir.InsertionPoint(for_op.body):
+        i, int_carry, vec_carry = for_op.body.arguments
+        new_int_carry = arith.addi(int_carry, arith.index_castui(i32, i))
+        new_vec_carry = arith.AddIOp(vec_carry, other_vec)
+        new_vec_carry.attributes["in_layouts"] = ir.ArrayAttr.get([strided_layout_attr] * 2)
+        new_vec_carry.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout_attr])
+        yield_op = scf.YieldOp([new_int_carry, new_vec_carry])
+        yield_op.attributes["in_layouts"] = ir.ArrayAttr.get([strided_layout_attr])
+
+    mgpu.lower_mgpu_dialect(self.module, None)
+    self.module.operation.verify()
+    [for_op] = find_if(self.module, lambda op: isinstance(op, scf.ForOp))
+    result_types = [r.type for r in for_op.results]
+    reg_vec_ty = ir.VectorType.get((2,), i32)
+    self.assertSequenceEqual(result_types, [i32, reg_vec_ty, reg_vec_ty])
+
 
 if __name__ == "__main__":
   parameterized.absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 51d769d445b1..7dae608a41cb 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -902,9 +902,8 @@ def kernel(x_ref, o_ref):
       force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics]
   )
   def test_fori_loop_array(self, force_while, thread_semantics):
-    if thread_semantics == plgpu.ThreadSemantics.Warpgroup:
-      # TODO(apaszke,bchetioui,slebedev): Support while + array carries.
-      self.skipTest("WG semantics unsupported")
+    if force_while and thread_semantics == plgpu.ThreadSemantics.Warpgroup:
+      self.skipTest("WG semantics does not support force_while.")
 
     @functools.partial(
         pl.pallas_call,
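
A back-of-the-envelope check of the result types asserted in test_lowering_for above: the (4, 128) i32 carry holds 512 elements, which the 128 threads of a warpgroup split into 4 elements per thread; assuming the strided layout packs them into two-element register vectors (consistent with the vector<2xi32> the test expects), the vector carry lowers to two registers per thread, so the lowered scf.for yields [i32, vector<2xi32>, vector<2xi32>].

# Sketch of the register-count arithmetic; WARPGROUP_SIZE = 128 and a
# two-element vec_size are assumptions consistent with the asserted types.
elements = 4 * 128            # carry shape (4, 128), i32
per_thread = elements // 128  # 4 elements per thread in the warpgroup
registers = per_thread // 2   # 2 registers of vector<2xi32> per thread
assert registers == 2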