diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py
index 290df2af6b06..4602ba99e111 100644
--- a/jax/experimental/mosaic/gpu/dialect_lowering.py
+++ b/jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -34,6 +34,7 @@ import numpy as np

 from . import fragmented_array as fa
+from . import inference_utils
 from . import launch_context
 from . import layouts
 from . import utils
@@ -61,7 +62,9 @@ def lower_op(self, op: ir.OpView):
     lowering_rule = _lowerings[name]

     # TODO(bchetioui): make sure all layouts are set here.
-    if layouts.should_have_layout(op) and not layouts.has_any_layout_set(op):
+    if inference_utils.should_have_layout(
+        op
+    ) and not inference_utils.has_any_layout_set(op):
       raise ValueError(f"{op} is missing a layout and can not be lowered.")

     new_results = lowering_rule(self, op)
@@ -277,7 +280,7 @@ def _vector_store_op_lowering_rule(
         f"for {vector_store_op}"
     )

-  [to_store_layout] = layouts.in_layouts(vector_store_op)
+  [to_store_layout] = inference_utils.in_layouts(vector_store_op)
   fragmented_array = _fragmented_array_from_ir(
       vector_store_op.valueToStore, to_store_layout
   )
@@ -437,8 +440,8 @@ def _binary_op_lowering_rule(
         [fa.FragmentedArray, fa.FragmentedArray], fa.FragmentedArray
     ],
 ) -> Sequence[ir.Value]:
-  in_layouts = layouts.in_layouts(op)
-  [layout] = layouts.out_layouts(op)
+  in_layouts = inference_utils.in_layouts(op)
+  [layout] = inference_utils.out_layouts(op)
   if any(in_layout != layout for in_layout in in_layouts):
     raise ValueError("Layout mismatch")
   lhs = _fragmented_array_from_ir(op.lhs, layout, is_signed)
@@ -492,8 +495,8 @@ def _binary_op_lowering_rule(
 def _cmpi_op_lowering_rule(
     _: LoweringContext, op: arith.CmpIOp
 ) -> Sequence[ir.Value]:
-  in_layouts = layouts.in_layouts(op)
-  [layout] = layouts.out_layouts(op)
+  in_layouts = inference_utils.in_layouts(op)
+  [layout] = inference_utils.out_layouts(op)
   if any(in_layout != layout for in_layout in in_layouts):
     raise ValueError("Layout mismatch")
   impl, is_signed = CMPI_IMPLS[op.predicate.value]
@@ -516,8 +519,8 @@ def _cmpi_op_lowering_rule(
 def _cmpf_op_lowering_rule(
     _: LoweringContext, op: arith.CmpFOp
 ) -> Sequence[ir.Value]:
-  in_layouts = layouts.in_layouts(op)
-  [layout] = layouts.out_layouts(op)
+  in_layouts = inference_utils.in_layouts(op)
+  [layout] = inference_utils.out_layouts(op)
   if any(in_layout != layout for in_layout in in_layouts):
     raise ValueError("Layout mismatch")
   impl = CMPF_IMPLS[op.predicate.value]
@@ -530,7 +533,10 @@ def _cmpf_op_lowering_rule(
 def _mgpu_wgmma_op_lowering_rule(
     _: LoweringContext, wgmma_op: mgpu.WGMMAOp
 ) -> Sequence[ir.Value]:
-  fa_layouts = (*layouts.in_layouts(wgmma_op), *layouts.out_layouts(wgmma_op))
+  fa_layouts = (
+      *inference_utils.in_layouts(wgmma_op),
+      *inference_utils.out_layouts(wgmma_op),
+  )
   if not all(map(layouts.is_wgmma_fragmented_layout, fa_layouts)):
     raise ValueError("Layout mismatch")
   wgmma_layout = fa_layouts[0]
@@ -607,12 +613,12 @@ def _for_op_lowering_rule(
 def _for_op_lowering_rule(
     ctx: LoweringContext, for_op: scf.ForOp
 ) -> MlirLoweringRuleResult:
-  if not layouts.should_have_layout(for_op):
+  if not inference_utils.should_have_layout(for_op):
     return _traverse_op_lowering_rule(ctx, for_op)

-  in_layouts = layouts.in_layouts(for_op)
-  out_layouts = layouts.out_layouts(for_op)
+  in_layouts = inference_utils.in_layouts(for_op)
+  out_layouts = inference_utils.out_layouts(for_op)
   yield_op = for_op.body.operations[len(for_op.body.operations) - 1]
-  yield_layouts = layouts.in_layouts(yield_op)
+  yield_layouts = inference_utils.in_layouts(yield_op)
   if in_layouts != out_layouts or in_layouts != yield_layouts:
     raise ValueError("Layout mismatch")
   fa_layouts = in_layouts
@@ -692,7 +698,7 @@ def recreate_carry(lowered_carry):
 def _traverse_op_lowering_rule(
     ctx: LoweringContext, op: ir.OpView
 ) -> MlirLoweringRuleResult:
-  if layouts.should_have_layout(op):
+  if inference_utils.should_have_layout(op):
     raise ValueError(
         f"Rule cannot handle an op with vector operands or results: {op}"
     )
@@ -734,7 +740,7 @@ def _should_lower(op: ir.OpView) -> bool:
   """Returns 'true' if the operation should be lowered."""
   return (
       op.OPERATION_NAME.startswith("mosaic_gpu.")
-      or layouts.should_have_layout(op)
+      or inference_utils.should_have_layout(op)
       or any(bool(b) for r in op.regions for b in r)  # Does it have subblocks?
   )
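The lowering rules above all consume layouts the same way: read `in_layouts`/`out_layouts` through `inference_utils` and reject mismatches. A minimal sketch of that check (not part of the patch): `FakeOp` is a hypothetical stand-in for `ir.OpView`, which suffices here because the helpers only inspect `op.attributes`.

    # Sketch only; FakeOp is hypothetical. in_layouts/out_layouts read nothing
    # but `op.attributes`, so a plain mapping illustrates the contract.
    from jax.experimental.mosaic.gpu import inference_utils

    class FakeOp:
      def __init__(self, attributes):
        self.attributes = attributes

    op = FakeOp({"in_layouts": ["L", "L"], "out_layouts": ["L"]})
    in_layouts = inference_utils.in_layouts(op)
    [layout] = inference_utils.out_layouts(op)
    # The same consistency check performed by _binary_op_lowering_rule et al.:
    if any(in_layout != layout for in_layout in in_layouts):
      raise ValueError("Layout mismatch")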
+ """ + if not ir.VectorType.isinstance(operand.type): + raise ValueError(f"{operand} is not a vector.") + + operand_number = [ + o for o in op.operands if ir.VectorType.isinstance(o.type) + ].index(operand) + + if not has_in_layouts_set(op): + return None + + return in_layouts(op)[operand_number] + + +def value_layout(value: ir.Value) -> ir.Attribute | None: + """Returns the layout for a given value as defined by its owner. + + Raises: + ValueError: If `result` is not a Vector. + """ + if not ir.VectorType.isinstance(value.type): + raise ValueError(f"{value} is not a vector.") + + owner = value.owner + if isinstance(owner, ir.Operation): + if not has_out_layouts_set(owner): + return None + value_result_number = [ + r for r in owner.results if ir.VectorType.isinstance(r.type) + ].index(value) + return out_layouts(owner)[value_result_number] + + # Block case, useful when attempting to derive layouts for ops + # depending on function parameters, or loop block arguments. + if isinstance(owner, ir.Block): + owner_op = owner.owner + block = cast(ir.Block, owner) + if not has_in_layouts_set(owner_op): + return None + value_arg_number = [ + r for r in block.arguments if ir.VectorType.isinstance(r.type) + ].index(value) + return in_layouts(owner_op)[value_arg_number] + + raise NotImplementedError( + f"{owner} is not a function block nor an operation." + ) + + +class TraversalOrder(enum.Enum): + """Traversal orders with respect to the data flow for IR.""" + + FORWARD = 1 + BACKWARDS = 2 + + +def traverse_op( + op: ir.OpView, + callback: Callable[[ir.OpView], None], + traversal_order: TraversalOrder = TraversalOrder.FORWARD, +): + """Traverses the operation and applies the callback in the given order.""" + for region in op.operation.regions: + for block in region: + if traversal_order == TraversalOrder.FORWARD: + ops_to_traverse = list(block) + else: + ops_to_traverse = reversed(list(block)) # type: ignore + for block_op in ops_to_traverse: + traverse_op(block_op, callback, traversal_order) + callback(op) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 41c7c02a97ec..d5afeb69ac8e 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -26,6 +26,7 @@ from jax._src.lib.mlir.dialects import vector from . import fragmented_array as fa +from . import inference_utils from . import layouts as layouts_lib # mypy: ignore-errors @@ -120,63 +121,6 @@ def _choose_representative_layout( return layouts_lib.to_layout_attr(splat_layout) -def _in_layout_for_operand( - op: ir.OpView, - operand: ir.Value, -) -> ir.Attribute | None: - """Returns the layout of the operand in the given operation if it is set. - - Raises: - ValueError: If `operand` is not an operand of `op`, or if `operand` is not a - Vector. - """ - if not ir.VectorType.isinstance(operand.type): - raise ValueError(f"{operand} is not a vector.") - - operand_number = [ - o for o in op.operands if ir.VectorType.isinstance(o.type) - ].index(operand) - - if not layouts_lib.has_in_layouts_set(op): - return None - - return layouts_lib.in_layouts(op)[operand_number] - - -def _value_layout(value: ir.Value) -> ir.Attribute | None: - """Returns the layout for a given value as defined by its owner. - - Raises: - ValueError: If `result` is not a Vector. 
- """ - if not ir.VectorType.isinstance(value.type): - raise ValueError(f"{value} is not a vector.") - - owner = value.owner - if isinstance(owner, ir.Operation): - if not layouts_lib.has_out_layouts_set(owner): - return None - value_result_number = [ - r for r in owner.results if ir.VectorType.isinstance(r.type) - ].index(value) - return layouts_lib.out_layouts(owner)[value_result_number] - - # Block case, useful when attempting to derive layouts for ops - # depending on function parameters, or loop block arguments. - if isinstance(owner, ir.Block): - owner_op = owner.owner - block = cast(ir.Block, owner) - if not layouts_lib.has_in_layouts_set(owner_op): - return None - value_arg_number = [ - r for r in block.arguments if ir.VectorType.isinstance(r.type) - ].index(value) - return layouts_lib.in_layouts(owner_op)[value_arg_number] - - raise NotImplementedError( - f"{owner} is not a function block nor an operation.") - - def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts: def is_array(v: ir.Value) -> bool: @@ -185,14 +129,14 @@ def is_array(v: ir.Value) -> bool: num_vector_operands = len([o for o in op.operands if is_array(o)]) num_vector_results = len([r for r in op.results if is_array(r)]) - if layouts_lib.has_in_layouts_set(op): - op_in_layouts = layouts_lib.in_layouts(op) + if inference_utils.has_in_layouts_set(op): + op_in_layouts = inference_utils.in_layouts(op) if op_in_layouts: layout = op_in_layouts[0] return (num_vector_operands * [layout], num_vector_results * [layout]) - if layouts_lib.has_out_layouts_set(op): - op_out_layouts = layouts_lib.out_layouts(op) + if inference_utils.has_out_layouts_set(op): + op_out_layouts = inference_utils.out_layouts(op) if op_out_layouts: layout = op_out_layouts[0] return (num_vector_operands * [layout], num_vector_results * [layout]) @@ -209,7 +153,7 @@ def is_array(v: ir.Value) -> bool: for operand in op.operands: if not ir.VectorType.isinstance(operand.type): continue - if (layout := _value_layout(operand)) is not None: + if (layout := inference_utils.value_layout(operand)) is not None: layouts.add(layout) else: all_inputs_have_layout = False @@ -224,7 +168,7 @@ def is_array(v: ir.Value) -> bool: for op_operand_use in cast(ir.OpResult, op_result).uses: consumer = op_operand_use.owner op_user = consumer.operands[op_operand_use.operand_number] - layout = _in_layout_for_operand(consumer, op_user) + layout = inference_utils.in_layout_for_operand(consumer, op_user) if layout is not None: layouts.add(layout) @@ -295,7 +239,7 @@ def _infer_constant_op_layout(constant_op: arith.ConstantOp) -> OptionalLayouts: for use in cast(ir.OpResult, constant_op.result).uses: consumer = use.owner operand = consumer.operands[use.operand_number] - layout = _in_layout_for_operand(consumer, operand) + layout = inference_utils.in_layout_for_operand(consumer, operand) if layout is not None: break @@ -316,7 +260,7 @@ def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts: for result in op.results_: if not ir.VectorType.isinstance(result.type): continue - if (layout := _value_layout(result)) is not None: + if (layout := inference_utils.value_layout(result)) is not None: if layouts_lib.is_splat_fragmented_layout(layout): return None layouts.append(layout) @@ -332,8 +276,8 @@ def _infer_for_op_layout(op: scf.ForOp) -> OptionalLayouts: yield_op = op.body.operations[len(op.body.operations) - 1] assert isinstance(yield_op, scf.YieldOp) - if layouts_lib.has_in_layouts_set(yield_op): - yield_layouts = list(layouts_lib.in_layouts(yield_op)) + if 
inference_utils.has_in_layouts_set(yield_op): + yield_layouts = list(inference_utils.in_layouts(yield_op)) if any( layouts_lib.is_splat_fragmented_layout(layout) for layout in yield_layouts @@ -394,7 +338,7 @@ def traverse_op( def infer_layout(module: ir.Module): def inference_step(op: ir.Operation): - if not layouts_lib.should_have_layout(op): + if not inference_utils.should_have_layout(op): return elif inference_rule := _layout_inference_rules.get(op.OPERATION_NAME, None): # pytype: disable=attribute-error pass @@ -419,11 +363,15 @@ def inference_step(op: ir.Operation): # # Backwards pass for op in module.body: - traverse_op(op, inference_step, TraversalOrder.BACKWARDS) + inference_utils.traverse_op( + op, inference_step, inference_utils.TraversalOrder.BACKWARDS + ) # Forward pass for op in module.body: - traverse_op(op, inference_step, TraversalOrder.FORWARD) + inference_utils.traverse_op( + op, inference_step, inference_utils.TraversalOrder.FORWARD + ) # At this point, layouts have been propagated as far as they could be # propagated. However, it is possible for some operations to remain @@ -437,8 +385,9 @@ def to_default_layout(ty: ir.Type) -> ir.Attribute | None: return layouts_lib.to_strided_fragmented_layout_attr(layout) def set_default_layout(op: ir.OpView): - if (layouts_lib.should_have_layout(op) and - not layouts_lib.has_any_layout_set(op)): + if inference_utils.should_have_layout( + op + ) and not inference_utils.has_any_layout_set(op): in_layouts = [] for operand in op.operands: if (layout := to_default_layout(operand.type)) is not None: diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py index 9d77f401931c..334ebeddd005 100644 --- a/jax/experimental/mosaic/gpu/layouts.py +++ b/jax/experimental/mosaic/gpu/layouts.py @@ -14,8 +14,6 @@ """Layout utilities.""" -from collections.abc import Sequence -import itertools import re from jax._src.lib.mlir import ir @@ -159,44 +157,3 @@ def from_layout_attr( raise NotImplementedError( f"Unsupported layout for conversion from MLIR attribute: {attr}" ) - - -def in_layouts(op: ir.OpView) -> Sequence[ir.Attribute]: - """Returns the in_layouts attribute of the given operation. - - Raises: - ValueError: If the operation does not have an in_layouts attribute. - """ - if "in_layouts" not in op.attributes: - raise ValueError(f"{op} does not have an in_layouts attribute.") - return op.attributes["in_layouts"] # type: ignore - - -def out_layouts(op: ir.OpView) -> Sequence[ir.Attribute]: - """Returns the out_layouts attribute of the given operation. - - Raises: - ValueError: If the operation does not have an out_layouts attribute. - """ - if "out_layouts" not in op.attributes: - raise ValueError(f"{op} does not have an out_layouts attribute.") - return op.attributes["out_layouts"] # type: ignore - - -def should_have_layout(op: ir.OpView) -> bool: - """Returns 'true' if the operation should be assigned a layout.""" - - is_array = lambda v: ir.VectorType.isinstance(v.type) - return any(map(is_array, itertools.chain(op.operands, op.results))) # type: ignore - - -def has_in_layouts_set(op: ir.OpView) -> bool: - return "in_layouts" in op.attributes - - -def has_out_layouts_set(op: ir.OpView) -> bool: - return "out_layouts" in op.attributes - - -def has_any_layout_set(op: ir.OpView) -> bool: - return has_in_layouts_set(op) or has_out_layouts_set(op)
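For code outside these files the refactor is import-only: the helpers keep their names and signatures but move from `layouts` to the new `inference_utils` module. A hypothetical downstream caller, for illustration:

    # Hypothetical caller, not from the patch; the signatures are unchanged
    # from the versions deleted from layouts.py.
    from jax.experimental.mosaic.gpu import inference_utils

    def layouts_if_ready(op):
      # Previously: layouts.has_in_layouts_set(op), layouts.in_layouts(op), ...
      if inference_utils.has_in_layouts_set(op) and inference_utils.has_out_layouts_set(op):
        return inference_utils.in_layouts(op), inference_utils.out_layouts(op)
      return None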