[Mosaic GPU][NFC] Move some functions to a new file called `inference_utils.py`.

The intent is to move utils that are useful for both layout inference and
transform inference to a shared location.

PiperOrigin-RevId: 731340075
bchetioui authored and Google-ML-Automation committed Feb 28, 2025
1 parent 5a77070 commit 37d6a15
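
For context, a minimal sketch of the sharing pattern this refactoring enables. The two pass skeletons below are hypothetical stand-ins for the real layout- and transform-inference passes; only the `inference_utils` helpers come from this commit.

```python
from jax._src.lib.mlir import ir
from jax.experimental.mosaic.gpu import inference_utils

def infer_layouts(op: ir.OpView):
  # Hypothetical layout-inference pass: both passes can now share the same
  # vector-operand predicate instead of importing it from `layouts`.
  if inference_utils.should_have_layout(op):
    ...  # attach `in_layouts`/`out_layouts` attributes to `op` here

def infer_transforms(op: ir.OpView):
  # Hypothetical transform-inference pass reusing the same shared helper.
  if inference_utils.should_have_layout(op):
    ...  # derive transform attributes for `op` here
```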
Showing 4 changed files with 188 additions and 130 deletions.
36 changes: 21 additions & 15 deletions jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -34,6 +34,7 @@
 import numpy as np

 from . import fragmented_array as fa
+from . import inference_utils
 from . import launch_context
 from . import layouts
 from . import utils
@@ -61,7 +62,9 @@ def lower_op(self, op: ir.OpView):
     lowering_rule = _lowerings[name]

     # TODO(bchetioui): make sure all layouts are set here.
-    if layouts.should_have_layout(op) and not layouts.has_any_layout_set(op):
+    if inference_utils.should_have_layout(
+        op
+    ) and not inference_utils.has_any_layout_set(op):
       raise ValueError(f"{op} is missing a layout and can not be lowered.")

     new_results = lowering_rule(self, op)
@@ -277,7 +280,7 @@ def _vector_store_op_lowering_rule(
           f"for {vector_store_op}"
       )

-  [to_store_layout] = layouts.in_layouts(vector_store_op)
+  [to_store_layout] = inference_utils.in_layouts(vector_store_op)
   fragmented_array = _fragmented_array_from_ir(
       vector_store_op.valueToStore, to_store_layout
   )
@@ -437,8 +440,8 @@ def _binary_op_lowering_rule(
         [fa.FragmentedArray, fa.FragmentedArray], fa.FragmentedArray
     ],
 ) -> Sequence[ir.Value]:
-  in_layouts = layouts.in_layouts(op)
-  [layout] = layouts.out_layouts(op)
+  in_layouts = inference_utils.in_layouts(op)
+  [layout] = inference_utils.out_layouts(op)
   if any(in_layout != layout for in_layout in in_layouts):
     raise ValueError("Layout mismatch")
   lhs = _fragmented_array_from_ir(op.lhs, layout, is_signed)
@@ -492,8 +495,8 @@ def _binary_op_lowering_rule(
 def _cmpi_op_lowering_rule(
     _: LoweringContext, op: arith.CmpIOp
 ) -> Sequence[ir.Value]:
-  in_layouts = layouts.in_layouts(op)
-  [layout] = layouts.out_layouts(op)
+  in_layouts = inference_utils.in_layouts(op)
+  [layout] = inference_utils.out_layouts(op)
   if any(in_layout != layout for in_layout in in_layouts):
     raise ValueError("Layout mismatch")
   impl, is_signed = CMPI_IMPLS[op.predicate.value]
@@ -516,8 +519,8 @@ def _cmpi_op_lowering_rule(
 def _cmpf_op_lowering_rule(
     _: LoweringContext, op: arith.CmpFOp
 ) -> Sequence[ir.Value]:
-  in_layouts = layouts.in_layouts(op)
-  [layout] = layouts.out_layouts(op)
+  in_layouts = inference_utils.in_layouts(op)
+  [layout] = inference_utils.out_layouts(op)
   if any(in_layout != layout for in_layout in in_layouts):
     raise ValueError("Layout mismatch")
   impl = CMPF_IMPLS[op.predicate.value]
@@ -530,7 +533,10 @@ def _cmpf_op_lowering_rule(
 def _mgpu_wgmma_op_lowering_rule(
     _: LoweringContext, wgmma_op: mgpu.WGMMAOp
 ) -> Sequence[ir.Value]:
-  fa_layouts = (*layouts.in_layouts(wgmma_op), *layouts.out_layouts(wgmma_op))
+  fa_layouts = (
+      *inference_utils.in_layouts(wgmma_op),
+      *inference_utils.out_layouts(wgmma_op),
+  )
   if not all(map(layouts.is_wgmma_fragmented_layout, fa_layouts)):
     raise ValueError("Layout mismatch")
   wgmma_layout = fa_layouts[0]
@@ -607,12 +613,12 @@ def _for_op_lowering_rule(
 def _for_op_lowering_rule(
     ctx: LoweringContext, for_op: scf.ForOp
 ) -> MlirLoweringRuleResult:
-  if not layouts.should_have_layout(for_op):
+  if not inference_utils.should_have_layout(for_op):
     return _traverse_op_lowering_rule(ctx, for_op)
-  in_layouts = layouts.in_layouts(for_op)
-  out_layouts = layouts.out_layouts(for_op)
+  in_layouts = inference_utils.in_layouts(for_op)
+  out_layouts = inference_utils.out_layouts(for_op)
   yield_op = for_op.body.operations[len(for_op.body.operations) - 1]
-  yield_layouts = layouts.in_layouts(yield_op)
+  yield_layouts = inference_utils.in_layouts(yield_op)
   if in_layouts != out_layouts or in_layouts != yield_layouts:
     raise ValueError("Layout mismatch")
   fa_layouts = in_layouts
@@ -692,7 +698,7 @@ def recreate_carry(lowered_carry):
 def _traverse_op_lowering_rule(
     ctx: LoweringContext, op: ir.OpView
 ) -> MlirLoweringRuleResult:
-  if layouts.should_have_layout(op):
+  if inference_utils.should_have_layout(op):
     raise ValueError(
         f"Rule cannot handle an op with vector operands or results: {op}"
     )
@@ -734,7 +740,7 @@ def _should_lower(op: ir.OpView) -> bool:
   """Returns 'true' if the operation should be lowered."""
   return (
       op.OPERATION_NAME.startswith("mosaic_gpu.")
-      or layouts.should_have_layout(op)
+      or inference_utils.should_have_layout(op)
       or any(bool(b) for r in op.regions for b in r)  # Does it have subblocks?
   )

146 changes: 146 additions & 0 deletions jax/experimental/mosaic/gpu/inference_utils.py
@@ -0,0 +1,146 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Layout & transform inference convenience utils."""

from collections.abc import Callable, Sequence
import enum
import itertools
from typing import cast

from jax._src.lib.mlir import ir

MlirOperation = ir.Operation | ir.OpView

def in_layouts(op: MlirOperation) -> Sequence[ir.Attribute]:
  """Returns the in_layouts attribute of the given operation.

  Raises:
    ValueError: If the operation does not have an in_layouts attribute.
  """
  if "in_layouts" not in op.attributes:
    raise ValueError(f"{op} does not have an in_layouts attribute.")
  return op.attributes["in_layouts"]  # type: ignore


def out_layouts(op: MlirOperation) -> Sequence[ir.Attribute]:
  """Returns the out_layouts attribute of the given operation.

  Raises:
    ValueError: If the operation does not have an out_layouts attribute.
  """
  if "out_layouts" not in op.attributes:
    raise ValueError(f"{op} does not have an out_layouts attribute.")
  return op.attributes["out_layouts"]  # type: ignore


def should_have_layout(op: MlirOperation) -> bool:
  """Returns 'true' if the operation should be assigned a layout."""

  is_array = lambda v: ir.VectorType.isinstance(v.type)
  return any(map(is_array, itertools.chain(op.operands, op.results)))  # type: ignore


def has_in_layouts_set(op: MlirOperation) -> bool:
  return "in_layouts" in op.attributes


def has_out_layouts_set(op: MlirOperation) -> bool:
  return "out_layouts" in op.attributes


def has_any_layout_set(op: MlirOperation) -> bool:
  return has_in_layouts_set(op) or has_out_layouts_set(op)


def in_layout_for_operand(
    op: MlirOperation,
    operand: ir.Value,
) -> ir.Attribute | None:
  """Returns the layout of the operand in the given operation if it is set.

  Raises:
    ValueError: If `operand` is not an operand of `op`, or if `operand` is not
      a Vector.
  """
  if not ir.VectorType.isinstance(operand.type):
    raise ValueError(f"{operand} is not a vector.")

  operand_number = [
      o for o in op.operands if ir.VectorType.isinstance(o.type)
  ].index(operand)

  if not has_in_layouts_set(op):
    return None

  return in_layouts(op)[operand_number]


def value_layout(value: ir.Value) -> ir.Attribute | None:
  """Returns the layout for a given value as defined by its owner.

  Raises:
    ValueError: If `value` is not a Vector.
  """
  if not ir.VectorType.isinstance(value.type):
    raise ValueError(f"{value} is not a vector.")

  owner = value.owner
  if isinstance(owner, ir.Operation):
    if not has_out_layouts_set(owner):
      return None
    value_result_number = [
        r for r in owner.results if ir.VectorType.isinstance(r.type)
    ].index(value)
    return out_layouts(owner)[value_result_number]

  # Block case, useful when attempting to derive layouts for ops
  # depending on function parameters, or loop block arguments.
  if isinstance(owner, ir.Block):
    owner_op = owner.owner
    block = cast(ir.Block, owner)
    if not has_in_layouts_set(owner_op):
      return None
    value_arg_number = [
        r for r in block.arguments if ir.VectorType.isinstance(r.type)
    ].index(value)
    return in_layouts(owner_op)[value_arg_number]

  raise NotImplementedError(
      f"{owner} is not a function block nor an operation."
  )


class TraversalOrder(enum.Enum):
  """Traversal orders with respect to the data flow for IR."""

  FORWARD = 1
  BACKWARDS = 2


def traverse_op(
    op: ir.OpView,
    callback: Callable[[ir.OpView], None],
    traversal_order: TraversalOrder = TraversalOrder.FORWARD,
):
  """Traverses the operation and applies the callback in the given order."""
  for region in op.operation.regions:
    for block in region:
      if traversal_order == TraversalOrder.FORWARD:
        ops_to_traverse = list(block)
      else:
        ops_to_traverse = reversed(list(block))  # type: ignore
      for block_op in ops_to_traverse:
        traverse_op(block_op, callback, traversal_order)
  callback(op)
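
As a usage sketch for `traverse_op`: the helper visits nested operations before their parent, as shown below. The module text and the unregistered `test` dialect are illustrative choices, made only to keep the example self-contained.

```python
from jax._src.lib.mlir import ir
from jax.experimental.mosaic.gpu import inference_utils

with ir.Context() as ctx:
  ctx.allow_unregistered_dialects = True
  module = ir.Module.parse("""
    "test.outer"() ({
      "test.inner"() : () -> ()
    }) : () -> ()
  """)
  visited = []
  inference_utils.traverse_op(
      module.body.operations[0],  # the "test.outer" op
      lambda op: visited.append(op.operation.name),
      inference_utils.TraversalOrder.FORWARD,
  )
  print(visited)  # ['test.inner', 'test.outer']
```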