Skip to content

Commit

Permalink
Add functionality to allow promoting RC wheels during release
Browse files Browse the repository at this point in the history
List of changes:
1. Allow us to build an RC wheel when building release artifacts. This is done by modifying the build CLI to use the new JAX build rule and passing in the build options that control the wheel tag. A new build argument `use_new_wheel_build_rule` is introduced to the build CLI to avoid breaking anyone that uses the CLI and the old build rule. Note that this option will go away in the future when the build CLI migrates fully to the new build rule.
2. Change the upload script to upload both rc and release tagged wheels (changes internal)

PiperOrigin-RevId: 731121899
  • Loading branch information
nitins17 authored and Google-ML-Automation committed Feb 27, 2025
1 parent 6f57410 commit a65607c
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 63 deletions.
92 changes: 65 additions & 27 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@
"jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel",
}

# Dictionary with the new wheel build rule. ROCm is not supported yet. Note that
# when JAX migrates to the new wheel build rule fully, the build CLI will
# switch to the new wheel build rule by default.
#
# Keys are the normalized wheel names accepted by `--wheels` (short forms such
# as "cuda-plugin" are prefixed with "jax-" before lookup); values are the Bazel
# targets appended to the wheel build command. This mapping is used in place of
# WHEEL_BUILD_TARGET_DICT only when `--use_new_wheel_build_rule` is passed, in
# which case the targets are invoked with `bazel build` instead of `bazel run`.
WHEEL_BUILD_TARGET_DICT_NEW = {
"jax": "//:jax_wheel",
"jaxlib": "//jaxlib/tools:jaxlib_wheel",
"jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel",
"jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel",
}

def add_global_arguments(parser: argparse.ArgumentParser):
"""Adds all the global arguments that applies to all the CLI subcommands."""
Expand Down Expand Up @@ -147,6 +156,16 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
""",
)

parser.add_argument(
"--use_new_wheel_build_rule",
action="store_true",
help=
"""
Whether to use the new wheel build rule. Temporary flag and will be
removed once JAX migrates to the new wheel build rule fully.
""",
)

parser.add_argument(
"--editable",
action="store_true",
Expand Down Expand Up @@ -386,7 +405,10 @@ async def main():
for option in args.bazel_startup_options:
bazel_command_base.append(option)

bazel_command_base.append("run")
if not args.use_new_wheel_build_rule or args.command == "requirements_update":
bazel_command_base.append("run")
else:
bazel_command_base.append("build")

if args.python_version:
# Do not add --repo_env=HERMETIC_PYTHON_VERSION with default args.python_version
Expand Down Expand Up @@ -592,13 +614,22 @@ async def main():
wheel_build_command_base.append("--config=cuda_libraries_from_stubs")

with open(".jax_configure.bazelrc", "w") as f:
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list())
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule)
if not jax_configure_options:
logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.")
sys.exit(1)
f.write(jax_configure_options)
logging.info("Bazel options written to .jax_configure.bazelrc")

if args.use_new_wheel_build_rule:
if "rocm" in args.wheels:
logging.error("ROCm is not supported with the new wheel build rule.")
sys.exit(1)
logging.info("Using new wheel build rule")
wheel_build_targets = WHEEL_BUILD_TARGET_DICT_NEW
else:
wheel_build_targets = WHEEL_BUILD_TARGET_DICT

if args.configure_only:
logging.info("--configure_only is set so not running any Bazel commands.")
else:
Expand All @@ -611,12 +642,18 @@ async def main():
if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel:
wheel = "jax-" + wheel

if wheel not in WHEEL_BUILD_TARGET_DICT.keys():
logging.error(
"Incorrect wheel name provided, valid choices are jaxlib,"
" jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,"
" jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt"
)
if wheel not in wheel_build_targets.keys():
if args.use_new_wheel_build_rule:
logging.error(
"Incorrect wheel name provided, valid choices are jax, jaxlib,"
" jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt"
)
else:
logging.error(
"Incorrect wheel name provided, valid choices are jaxlib,"
" jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,"
" jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt"
)
sys.exit(1)

wheel_build_command = copy.deepcopy(wheel_build_command_base)
Expand All @@ -629,32 +666,33 @@ async def main():
)

# Append the build target to the Bazel command.
build_target = WHEEL_BUILD_TARGET_DICT[wheel]
build_target = wheel_build_targets[wheel]
wheel_build_command.append(build_target)

wheel_build_command.append("--")
if not args.use_new_wheel_build_rule:
wheel_build_command.append("--")

if args.editable:
logger.info("Building an editable build")
output_path = os.path.join(output_path, wheel)
wheel_build_command.append("--editable")
if args.editable:
logger.info("Building an editable build")
output_path = os.path.join(output_path, wheel)
wheel_build_command.append("--editable")

wheel_build_command.append(f'--output_path="{output_path}"')
wheel_build_command.append(f"--cpu={target_cpu}")
wheel_build_command.append(f'--output_path="{output_path}"')
wheel_build_command.append(f"--cpu={target_cpu}")

if "cuda" in wheel:
wheel_build_command.append("--enable-cuda=True")
if args.cuda_version:
cuda_major_version = args.cuda_version.split(".")[0]
else:
cuda_major_version = args.cuda_major_version
wheel_build_command.append(f"--platform_version={cuda_major_version}")
if "cuda" in wheel:
wheel_build_command.append("--enable-cuda=True")
if args.cuda_version:
cuda_major_version = args.cuda_version.split(".")[0]
else:
cuda_major_version = args.cuda_major_version
wheel_build_command.append(f"--platform_version={cuda_major_version}")

if "rocm" in wheel:
wheel_build_command.append("--enable-rocm=True")
wheel_build_command.append(f"--platform_version={args.rocm_version}")
if "rocm" in wheel:
wheel_build_command.append("--enable-rocm=True")
wheel_build_command.append(f"--platform_version={args.rocm_version}")

wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")
wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")

result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
# Exit with error if any wheel build fails.
Expand Down
10 changes: 7 additions & 3 deletions build/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,15 @@ def get_gcc_major_version(gcc_path: str):
return major_version


def get_jax_configure_bazel_options(bazel_command: list[str]):
def get_jax_configure_bazel_options(bazel_command: list[str], use_new_wheel_build_rule: bool):
"""Returns the bazel options to be written to .jax_configure.bazelrc."""
# Get the index of the "run" parameter. Build options will come after "run" so
# we find the index of "run" and filter everything after it.
start = bazel_command.index("run")
# we find the index of "run" and filter everything after it. If we are using
# the new wheel build rule, we will find the index of "build" instead.
if use_new_wheel_build_rule:
start = bazel_command.index("build")
else:
start = bazel_command.index("run")
jax_configure_bazel_options = ""
try:
for i in range(start + 1, len(bazel_command)):
Expand Down
90 changes: 57 additions & 33 deletions ci/build_artifacts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,52 +45,76 @@ if [[ $os =~ "msys_nt" && $arch == "x86_64" ]]; then
arch="amd64"
fi

if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=release"
elif [[ "$JAXCI_ARTIFACT_TYPE" == "nightly" ]]; then
current_date=$(date +%Y%m%d)
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=${current_date} --bazel_options=--repo_env=ML_WHEEL_TYPE=nightly"
elif [[ "$JAXCI_ARTIFACT_TYPE" == "default" ]]; then
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=custom --bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD) --bazel_options=--repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD) --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)"
else
echo "Error: Invalid artifact type: $JAXCI_ARTIFACT_TYPE. Allowed values are: release, nightly, default"
exit 1
fi

# Build the jax artifact
if [[ "$artifact" == "jax" ]]; then
python -m build --outdir $JAXCI_OUTPUT_DIR
else
if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
# Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_"
# flags in the .bazelrc depending upon the platform we are building for.
bazelrc_config="${os}_${arch}"

# Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_"
# flags in the .bazelrc depending upon the platform we are building for.
bazelrc_config="${os}_${arch}"
# On platforms with no RBE support, we can use the Bazel remote cache. Set
# it to be empty by default to avoid unbound variable errors.
bazel_remote_cache=""

# On platforms with no RBE support, we can use the Bazel remote cache. Set
# it to be empty by default to avoid unbound variable errors.
bazel_remote_cache=""
if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then
bazelrc_config="rbe_${bazelrc_config}"
else
bazelrc_config="ci_${bazelrc_config}"

if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then
bazelrc_config="rbe_${bazelrc_config}"
# Set remote cache flags. Pushes to the cache bucket are limited to JAX's
# CI system.
if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
bazel_remote_cache="--bazel_options=--config=public_cache_push"
else
bazelrc_config="ci_${bazelrc_config}"

# Set remote cache flags. Pushes to the cache bucket are limited to JAX's
# CI system.
if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
bazel_remote_cache="--bazel_options=--config=public_cache_push"
else
bazel_remote_cache="--bazel_options=--config=public_cache"
fi
bazel_remote_cache="--bazel_options=--config=public_cache"
fi
fi

# Use the "_cuda" configs when building the CUDA artifacts.
if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then
bazelrc_config="${bazelrc_config}_cuda"
fi
# Use the "_cuda" configs when building the CUDA artifacts.
if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then
bazelrc_config="${bazelrc_config}_cuda"
fi

# Build the artifact.
# Build the artifact.
python build/build.py build --wheels="$artifact" \
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
--verbose --detailed_timestamped_log --use_new_wheel_build_rule \
$artifact_tag_flags

# If building release artifacts, we also build a release candidate ("rc")
# tagged wheel.
if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then
python build/build.py build --wheels="$artifact" \
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
--verbose --detailed_timestamped_log
--verbose --detailed_timestamped_log --use_new_wheel_build_rule \
$artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION"
fi

# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
# run `auditwheel show` to verify manylinux compliance.
if [[ "$os" == "linux" ]]; then
./ci/utilities/run_auditwheel.sh
fi
# Move the built artifacts from the Bazel cache directory to the output
# directory.
if [[ "$artifact" == "jax" ]]; then
mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR"
mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR"
else
mv bazel-bin/jaxlib/tools//dist/*.whl "$JAXCI_OUTPUT_DIR"
fi

# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
# run `auditwheel show` to verify manylinux compliance.
if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then
./ci/utilities/run_auditwheel.sh
fi

else
Expand Down
9 changes: 9 additions & 0 deletions ci/envs/default.env
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0}
# flag is enabled only for CI builds.
export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0}

# Type of artifacts to build. Valid values are "default", "release", "nightly".
# This affects the wheel naming/tag.
export JAXCI_ARTIFACT_TYPE=${JAXCI_ARTIFACT_TYPE:-"default"}

# When building release artifacts, we build a release candidate wheel ("rc"
# tagged wheel) in addition to the release wheel. This environment variable
# sets the version of the release candidate ("RC") artifact to build.
export JAXCI_WHEEL_RC_VERSION=${JAXCI_WHEEL_RC_VERSION:-}

# #############################################################################
# Test script specific environment variables.
# #############################################################################
Expand Down
3 changes: 3 additions & 0 deletions ci/utilities/setup_build_environment.sh
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,6 @@ function retry {

# Retry "bazel --version" 3 times to avoid flakiness when downloading bazel.
retry "bazel --version"

# Create the output directory if it doesn't exist.
mkdir -p "$JAXCI_OUTPUT_DIR"

0 comments on commit a65607c

Please sign in to comment.