Skip to content

Commit

Permalink
Add build targets for jax-rocm-plugin and jax-rocm-pjrt wheels.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 731832189
  • Loading branch information
Google-ML-Automation committed Feb 28, 2025
1 parent bb96226 commit 83d05a9
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
3 changes: 3 additions & 0 deletions jaxlib/jax.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ def jax_wheel(
build_wheel_only = True,
editable = False,
enable_cuda = False,
enable_rocm = False,
platform_version = "",
source_files = []):
"""Create jax artifact wheels.
Expand All @@ -494,6 +495,7 @@ def jax_wheel(
editable: whether to build an editable wheel
platform_independent: whether to build a wheel without platform tag
enable_cuda: whether to build a cuda wheel
enable_rocm: whether to build a rocm wheel
platform_version: the cuda version to use for the wheel
source_files: the source files to include in the wheel
Expand All @@ -509,6 +511,7 @@ def jax_wheel(
build_wheel_only = build_wheel_only,
editable = editable,
enable_cuda = enable_cuda,
enable_rocm = enable_rocm,
platform_version = platform_version,
# git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
# flag in bazel command to pass the git hash for nightly or release builds.
Expand Down
36 changes: 36 additions & 0 deletions jaxlib/tools/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,24 @@ jax_wheel(
wheel_name = "jax_cuda12_plugin",
)

jax_wheel(
name = "jax_rocm_plugin_wheel",
enable_rocm = True,
no_abi = False,
platform_version = "60",
wheel_binary = ":build_gpu_kernels_wheel",
wheel_name = "jax_rocm60_plugin",
)

jax_wheel(
name = "jax_rocm_plugin_wheel_editable",
editable = True,
enable_rocm = True,
platform_version = "60",
wheel_binary = ":build_gpu_kernels_wheel",
wheel_name = "jax_rocm60_plugin",
)

jax_wheel(
name = "jax_cuda_pjrt_wheel",
enable_cuda = True,
Expand All @@ -258,6 +276,24 @@ jax_wheel(
wheel_name = "jax_cuda12_pjrt",
)

jax_wheel(
name = "jax_rocm_pjrt_wheel",
enable_rocm = True,
no_abi = True,
platform_version = "60",
wheel_binary = ":build_gpu_plugin_wheel",
wheel_name = "jax_rocm60_pjrt",
)

jax_wheel(
name = "jax_rocm_pjrt_wheel_editable",
editable = True,
enable_rocm = True,
platform_version = "60",
wheel_binary = ":build_gpu_plugin_wheel",
wheel_name = "jax_rocm60_pjrt",
)

AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")])

PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")])
Expand Down

0 comments on commit 83d05a9

Please sign in to comment.