[Mosaic GPU][NFC] Move some functions to a new file called inference_utils.py. #26776

Merged: 1 commit, Feb 28, 2025
jax/experimental/mosaic/gpu/dialect_lowering.py (36 changes: 21 additions & 15 deletions)
@@ -34,6 +34,7 @@
 import numpy as np
 
 from . import fragmented_array as fa
+from . import inference_utils
 from . import launch_context
 from . import layouts
 from . import utils
@@ -61,7 +62,9 @@ def lower_op(self, op: ir.OpView):
     lowering_rule = _lowerings[name]
 
     # TODO(bchetioui): make sure all layouts are set here.
-    if layouts.should_have_layout(op) and not layouts.has_any_layout_set(op):
+    if inference_utils.should_have_layout(
+        op
+    ) and not inference_utils.has_any_layout_set(op):
       raise ValueError(f"{op} is missing a layout and can not be lowered.")
 
     new_results = lowering_rule(self, op)
@@ -277,7 +280,7 @@ def _vector_store_op_lowering_rule(
         f"for {vector_store_op}"
     )
 
-  [to_store_layout] = layouts.in_layouts(vector_store_op)
+  [to_store_layout] = inference_utils.in_layouts(vector_store_op)
   fragmented_array = _fragmented_array_from_ir(
       vector_store_op.valueToStore, to_store_layout
   )
@@ -437,8 +440,8 @@ def _binary_op_lowering_rule(
         [fa.FragmentedArray, fa.FragmentedArray], fa.FragmentedArray
     ],
 ) -> Sequence[ir.Value]:
-  in_layouts = layouts.in_layouts(op)
-  [layout] = layouts.out_layouts(op)
+  in_layouts = inference_utils.in_layouts(op)
+  [layout] = inference_utils.out_layouts(op)
   if any(in_layout != layout for in_layout in in_layouts):
     raise ValueError("Layout mismatch")
   lhs = _fragmented_array_from_ir(op.lhs, layout, is_signed)
@@ -492,8 +495,8 @@ def _binary_op_lowering_rule(
 def _cmpi_op_lowering_rule(
     _: LoweringContext, op: arith.CmpIOp
 ) -> Sequence[ir.Value]:
-  in_layouts = layouts.in_layouts(op)
-  [layout] = layouts.out_layouts(op)
+  in_layouts = inference_utils.in_layouts(op)
+  [layout] = inference_utils.out_layouts(op)
   if any(in_layout != layout for in_layout in in_layouts):
     raise ValueError("Layout mismatch")
   impl, is_signed = CMPI_IMPLS[op.predicate.value]
@@ -516,8 +519,8 @@ def _cmpi_op_lowering_rule(
 def _cmpf_op_lowering_rule(
     _: LoweringContext, op: arith.CmpFOp
 ) -> Sequence[ir.Value]:
-  in_layouts = layouts.in_layouts(op)
-  [layout] = layouts.out_layouts(op)
+  in_layouts = inference_utils.in_layouts(op)
+  [layout] = inference_utils.out_layouts(op)
   if any(in_layout != layout for in_layout in in_layouts):
     raise ValueError("Layout mismatch")
   impl = CMPF_IMPLS[op.predicate.value]
@@ -530,7 +533,10 @@ def _cmpf_op_lowering_rule(
 def _mgpu_wgmma_op_lowering_rule(
     _: LoweringContext, wgmma_op: mgpu.WGMMAOp
 ) -> Sequence[ir.Value]:
-  fa_layouts = (*layouts.in_layouts(wgmma_op), *layouts.out_layouts(wgmma_op))
+  fa_layouts = (
+      *inference_utils.in_layouts(wgmma_op),
+      *inference_utils.out_layouts(wgmma_op),
+  )
   if not all(map(layouts.is_wgmma_fragmented_layout, fa_layouts)):
     raise ValueError("Layout mismatch")
   wgmma_layout = fa_layouts[0]
@@ -607,12 +613,12 @@ def _for_op_lowering_rule(
 def _for_op_lowering_rule(
     ctx: LoweringContext, for_op: scf.ForOp
 ) -> MlirLoweringRuleResult:
-  if not layouts.should_have_layout(for_op):
+  if not inference_utils.should_have_layout(for_op):
     return _traverse_op_lowering_rule(ctx, for_op)
-  in_layouts = layouts.in_layouts(for_op)
-  out_layouts = layouts.out_layouts(for_op)
+  in_layouts = inference_utils.in_layouts(for_op)
+  out_layouts = inference_utils.out_layouts(for_op)
   yield_op = for_op.body.operations[len(for_op.body.operations) - 1]
-  yield_layouts = layouts.in_layouts(yield_op)
+  yield_layouts = inference_utils.in_layouts(yield_op)
   if in_layouts != out_layouts or in_layouts != yield_layouts:
     raise ValueError("Layout mismatch")
   fa_layouts = in_layouts
@@ -692,7 +698,7 @@ def recreate_carry(lowered_carry):
 def _traverse_op_lowering_rule(
     ctx: LoweringContext, op: ir.OpView
 ) -> MlirLoweringRuleResult:
-  if layouts.should_have_layout(op):
+  if inference_utils.should_have_layout(op):
     raise ValueError(
         f"Rule cannot handle an op with vector operands or results: {op}"
     )
@@ -734,7 +740,7 @@ def _should_lower(op: ir.OpView) -> bool:
   """Returns 'true' if the operation should be lowered."""
   return (
       op.OPERATION_NAME.startswith("mosaic_gpu.")
-      or layouts.should_have_layout(op)
+      or inference_utils.should_have_layout(op)
       or any(bool(b) for r in op.regions for b in r)  # Does it have subblocks?
   )

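As an illustration (not part of this diff), a minimal sketch of the call-site pattern after the move; `get_layouts` and `op` are hypothetical names standing in for any lowering-rule helper and any ir.OpView reaching the pass:

from jax.experimental.mosaic.gpu import inference_utils

def get_layouts(op):
  # Only ops with vector operands or results are expected to carry layout
  # attributes; anything else has no layouts to read.
  if not inference_utils.should_have_layout(op):
    return None
  if not (inference_utils.has_in_layouts_set(op)
          and inference_utils.has_out_layouts_set(op)):
    raise ValueError(f"{op} is missing a layout and can not be lowered.")
  # Both accessors raise if the corresponding attribute is absent.
  return inference_utils.in_layouts(op), inference_utils.out_layouts(op)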
jax/experimental/mosaic/gpu/inference_utils.py (new file: 146 additions & 0 deletions)
@@ -0,0 +1,146 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Layout & transform inference convenience utils."""

from collections.abc import Callable, Sequence
import enum
import itertools
from typing import cast

from jax._src.lib.mlir import ir

MlirOperation = ir.Operation | ir.OpView

def in_layouts(op: MlirOperation) -> Sequence[ir.Attribute]:
  """Returns the in_layouts attribute of the given operation.

  Raises:
    ValueError: If the operation does not have an in_layouts attribute.
  """
  if "in_layouts" not in op.attributes:
    raise ValueError(f"{op} does not have an in_layouts attribute.")
  return op.attributes["in_layouts"]  # type: ignore


def out_layouts(op: MlirOperation) -> Sequence[ir.Attribute]:
  """Returns the out_layouts attribute of the given operation.

  Raises:
    ValueError: If the operation does not have an out_layouts attribute.
  """
  if "out_layouts" not in op.attributes:
    raise ValueError(f"{op} does not have an out_layouts attribute.")
  return op.attributes["out_layouts"]  # type: ignore


def should_have_layout(op: MlirOperation) -> bool:
  """Returns 'true' if the operation should be assigned a layout."""

  is_array = lambda v: ir.VectorType.isinstance(v.type)
  return any(map(is_array, itertools.chain(op.operands, op.results)))  # type: ignore


def has_in_layouts_set(op: MlirOperation) -> bool:
  return "in_layouts" in op.attributes


def has_out_layouts_set(op: MlirOperation) -> bool:
  return "out_layouts" in op.attributes


def has_any_layout_set(op: MlirOperation) -> bool:
  return has_in_layouts_set(op) or has_out_layouts_set(op)


def in_layout_for_operand(
    op: MlirOperation,
    operand: ir.Value,
) -> ir.Attribute | None:
  """Returns the layout of the operand in the given operation if it is set.

  Raises:
    ValueError: If `operand` is not an operand of `op`, or if `operand` is not
      a Vector.
  """
  if not ir.VectorType.isinstance(operand.type):
    raise ValueError(f"{operand} is not a vector.")

  operand_number = [
      o for o in op.operands if ir.VectorType.isinstance(o.type)
  ].index(operand)

  if not has_in_layouts_set(op):
    return None

  return in_layouts(op)[operand_number]


def value_layout(value: ir.Value) -> ir.Attribute | None:
  """Returns the layout for a given value as defined by its owner.

  Raises:
    ValueError: If `value` is not a Vector.
  """
  if not ir.VectorType.isinstance(value.type):
    raise ValueError(f"{value} is not a vector.")

  owner = value.owner
  if isinstance(owner, ir.Operation):
    if not has_out_layouts_set(owner):
      return None
    value_result_number = [
        r for r in owner.results if ir.VectorType.isinstance(r.type)
    ].index(value)
    return out_layouts(owner)[value_result_number]

  # Block case, useful when attempting to derive layouts for ops
  # depending on function parameters, or loop block arguments.
  if isinstance(owner, ir.Block):
    owner_op = owner.owner
    block = cast(ir.Block, owner)
    if not has_in_layouts_set(owner_op):
      return None
    value_arg_number = [
        r for r in block.arguments if ir.VectorType.isinstance(r.type)
    ].index(value)
    return in_layouts(owner_op)[value_arg_number]

  raise NotImplementedError(
      f"{owner} is neither a function block nor an operation."
  )


class TraversalOrder(enum.Enum):
  """Traversal orders with respect to the data flow for IR."""

  FORWARD = 1
  BACKWARDS = 2


def traverse_op(
    op: ir.OpView,
    callback: Callable[[ir.OpView], None],
    traversal_order: TraversalOrder = TraversalOrder.FORWARD,
):
  """Traverses the operation and applies the callback in the given order."""
  for region in op.operation.regions:
    for block in region:
      if traversal_order == TraversalOrder.FORWARD:
        ops_to_traverse = list(block)
      else:
        ops_to_traverse = reversed(list(block))  # type: ignore
      for block_op in ops_to_traverse:
        traverse_op(block_op, callback, traversal_order)
  callback(op)
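
A hedged usage sketch of the helpers above (not part of the PR): `layouts_agree` is a hypothetical function comparing the layout a producer assigned to a value, via value_layout, against the layout its consumer expects, via in_layout_for_operand:

from jax._src.lib.mlir import ir
from jax.experimental.mosaic.gpu import inference_utils

def layouts_agree(consumer_op: ir.OpView) -> bool:
  # For every vector operand, compare the producer's out_layout with the
  # consumer's expected in_layout. None means "not inferred yet" and is
  # treated as agreement here.
  for operand in consumer_op.operands:
    if not ir.VectorType.isinstance(operand.type):
      continue
    produced = inference_utils.value_layout(operand)
    expected = inference_utils.in_layout_for_operand(consumer_op, operand)
    if produced is not None and expected is not None and produced != expected:
      return False
  return True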
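Likewise, a minimal driver sketch for traverse_op (again an illustration; `report_missing_layouts` is a hypothetical name). Note that the callback fires on nested ops before their parent, since traverse_op recurses into regions first:

from jax._src.lib.mlir import ir
from jax.experimental.mosaic.gpu import inference_utils

def report_missing_layouts(module: ir.Module):
  def visit(op: ir.OpView):
    if (inference_utils.should_have_layout(op)
        and not inference_utils.has_any_layout_set(op)):
      print(f"no layout inferred yet for: {op}")

  # Walk every top-level op of the module in data-flow order.
  for op in module.body:
    inference_utils.traverse_op(op, visit, inference_utils.TraversalOrder.FORWARD)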