From 8f57b8167b950d63041e9c53d0a1cf3e693597bd Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 28 Feb 2025 08:36:08 -0800 Subject: [PATCH] Add build targets for `jax-rocm-plugin` and `jax-rocm-pjrt` wheels. PiperOrigin-RevId: 732149495 --- jaxlib/jax.bzl | 3 +++ jaxlib/tools/BUILD.bazel | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 6958f1dbaedf..eb016da80546 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -479,6 +479,7 @@ def jax_wheel( build_wheel_only = True, editable = False, enable_cuda = False, + enable_rocm = False, platform_version = "", source_files = []): """Create jax artifact wheels. @@ -494,6 +495,7 @@ def jax_wheel( editable: whether to build an editable wheel platform_independent: whether to build a wheel without platform tag enable_cuda: whether to build a cuda wheel + enable_rocm: whether to build a rocm wheel platform_version: the cuda version to use for the wheel source_files: the source files to include in the wheel @@ -509,6 +511,7 @@ def jax_wheel( build_wheel_only = build_wheel_only, editable = editable, enable_cuda = enable_cuda, + enable_rocm = enable_rocm, platform_version = platform_version, # git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)` # flag in bazel command to pass the git hash for nightly or release builds. diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index de9f636ed8d5..b95483b22712 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -238,6 +238,24 @@ jax_wheel( wheel_name = "jax_cuda12_plugin", ) +jax_wheel( + name = "jax_rocm_plugin_wheel", + enable_rocm = True, + no_abi = False, + platform_version = "60", + wheel_binary = ":build_gpu_kernels_wheel", + wheel_name = "jax_rocm60_plugin", +) + +jax_wheel( + name = "jax_rocm_plugin_wheel_editable", + editable = True, + enable_rocm = True, + platform_version = "60", + wheel_binary = ":build_gpu_kernels_wheel", + wheel_name = "jax_rocm60_plugin", +) + jax_wheel( name = "jax_cuda_pjrt_wheel", enable_cuda = True, @@ -258,6 +276,24 @@ jax_wheel( wheel_name = "jax_cuda12_pjrt", ) +jax_wheel( + name = "jax_rocm_pjrt_wheel", + enable_rocm = True, + no_abi = True, + platform_version = "60", + wheel_binary = ":build_gpu_plugin_wheel", + wheel_name = "jax_rocm60_pjrt", +) + +jax_wheel( + name = "jax_rocm_pjrt_wheel_editable", + editable = True, + enable_rocm = True, + platform_version = "60", + wheel_binary = ":build_gpu_plugin_wheel", + wheel_name = "jax_rocm60_pjrt", +) + AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")]) PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")])