[xla:gpu] Make address computation fusion compatible with command buffers

PiperOrigin-RevId: 606892947
tyb0807 authored and tensorflower-gardener committed Feb 14, 2024
1 parent 4314cf6 commit 0445622
Showing 5 changed files with 65 additions and 45 deletions.
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/BUILD
@@ -3247,6 +3247,8 @@ cc_library(
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_pass",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
67 changes: 43 additions & 24 deletions third_party/xla/xla/service/gpu/command_buffer_scheduling.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <variant>
#include <vector>
@@ -39,6 +40,8 @@ limitations under the License.
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/variant_visitor.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
@@ -56,7 +59,8 @@ using CommandBufferConfig = CommandBufferScheduling::CommandBufferConfig;

// Returns true if HLO computation can be executed as a command buffer.
static bool IsCommand(const HloComputation* computation,
const CommandBufferConfig& config);
const CommandBufferConfig& config,
const se::DeviceDescription& device_description);

//===----------------------------------------------------------------------===//
// No-op HLO operations.
@@ -91,27 +95,30 @@ static bool IsNoOp(const HloInstruction* hlo) {
// This is a template to define pattern matching functions for HLO instructions
// that do not have a corresponding class for them.
template <HloOpcode op>
static bool IsCommand(const HloInstruction*, const CommandBufferConfig&);
static bool IsCommand(const HloInstruction*, const CommandBufferConfig&,
const se::DeviceDescription&);

// While loops can be executed inside command buffers only if condition and body
// regions can be executed as command buffers.
template <>
bool IsCommand<HloOpcode::kWhile>(const HloInstruction* hlo,
const CommandBufferConfig& config) {
bool IsCommand<HloOpcode::kWhile>(
const HloInstruction* hlo, const CommandBufferConfig& config,
const se::DeviceDescription& device_description) {
return config.contains(DebugOptions::CONDITIONALS) &&
IsCommand(hlo->while_body(), config) &&
IsCommand(hlo->while_condition(), config);
IsCommand(hlo->while_body(), config, device_description) &&
IsCommand(hlo->while_condition(), config, device_description);
}

// Conditional can be executed inside command buffers only if all regions of its
// branches can be executed as command buffers.
template <>
bool IsCommand<HloOpcode::kConditional>(const HloInstruction* hlo,
const CommandBufferConfig& config) {
bool IsCommand<HloOpcode::kConditional>(
const HloInstruction* hlo, const CommandBufferConfig& config,
const se::DeviceDescription& device_description) {
return config.contains(DebugOptions::CONDITIONALS) &&
absl::c_all_of(hlo->branch_computations(),
[&](const HloComputation* comp) {
return IsCommand(comp, config);
return IsCommand(comp, config, device_description);
});
}

@@ -127,16 +134,25 @@ static bool IsCommand(const HloCustomCallInstruction* hlo,
}

static bool IsCommand(const HloInstruction* hlo,
const CommandBufferConfig& config) {
const CommandBufferConfig& config,
const se::DeviceDescription& device_description) {
if (auto* fusion = DynCast<HloFusionInstruction>(hlo)) {
// TODO(vuson): Make address computation fusion compatible with command
// buffer
auto gpu_config = fusion->backend_config<GpuBackendConfig>();
const FusionBackendConfig& backend_config =
gpu_config->fusion_backend_config();
const auto& custom_config = backend_config.custom_fusion_config();
return custom_config.name() != "address_computation" &&
config.contains(DebugOptions::FUSION);
if (custom_config.name() == "address_computation") {
auto fusion_analysis =
HloFusionAnalysis::Create(fusion, &device_description);
const HloFusionAdaptor& adaptor = fusion_analysis.fusion();
auto custom_call_adaptor = HloFindIf(
adaptor.GetRoots(), adaptor,
[](auto node) { return node.opcode() == HloOpcode::kCustomCall; });
const auto* custom_call = static_cast<const HloCustomCallInstruction*>(
&custom_call_adaptor->instruction());
return IsCommand(custom_call, config);
}
return config.contains(DebugOptions::FUSION);
}

if (auto* sort = DynCast<HloSortInstruction>(hlo))
@@ -151,10 +167,10 @@ static bool IsCommand(const HloInstruction* hlo,
return IsCommand(custom_call, config);

if (hlo->opcode() == HloOpcode::kWhile)
return IsCommand<HloOpcode::kWhile>(hlo, config);
return IsCommand<HloOpcode::kWhile>(hlo, config, device_description);

if (hlo->opcode() == HloOpcode::kConditional)
return IsCommand<HloOpcode::kConditional>(hlo, config);
return IsCommand<HloOpcode::kConditional>(hlo, config, device_description);

return false;
}
@@ -221,11 +237,13 @@ static HloInstruction* FindAsyncDoneCommand(const HloInstruction* start) {

// Returns true if HLO computation can be executed as a command buffer.
static bool IsCommand(const HloComputation* computation,
const CommandBufferConfig& config) {
const CommandBufferConfig& config,
const se::DeviceDescription& device_description) {
return absl::c_all_of(
computation->instructions(), [&](const HloInstruction* inst) {
return IsNoOp(inst) || IsConstant(inst) || IsParameter(inst) ||
IsCommand(inst, config) || IsAsyncStartCommand(inst, config) ||
IsCommand(inst, config, device_description) ||
IsAsyncStartCommand(inst, config) ||
IsAsyncDoneCommand(inst, config);
});
}
@@ -252,7 +270,7 @@ static void RemoveTrailingNoOps(HloInstructionSequence& seq) {
std::vector<HloInstructionSequence>
CommandBufferScheduling::CollectCommandBufferSequences(
const HloInstructionSequence schedule, const CommandBufferConfig& config,
int32_t min_num_commands) {
const se::DeviceDescription& device_description, int32_t min_num_commands) {
std::vector<HloInstructionSequence> sequences;

HloInstructionSequence current_seq;
@@ -282,7 +300,7 @@ CommandBufferScheduling::CollectCommandBufferSequences(
}

// Synchronous commands always can be added to instruction sequence.
if (IsCommand(inst, config)) {
if (IsCommand(inst, config, device_description)) {
num_commands_in_current_seq++;
current_seq.push_back(inst);
continue;
@@ -574,9 +592,9 @@ absl::StatusOr<HloComputation*> CommandBufferScheduling::RewriteCommandBuffer(
//===----------------------------------------------------------------------===//

CommandBufferScheduling::CommandBufferScheduling(
const se::GpuComputeCapability& gpu_compute_comp,
const se::DeviceDescription& device_description,
int32_t gpu_toolkit_version, int32_t gpu_driver_version)
: gpu_compute_comp_(gpu_compute_comp),
: device_description_(device_description),
gpu_toolkit_version_(gpu_toolkit_version),
gpu_driver_version_(gpu_driver_version) {}

@@ -630,7 +648,8 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
return true; // check for ROCM support
};

if (std::visit(VariantVisitor{check_cuda, check_rocm}, gpu_compute_comp_)) {
if (std::visit(VariantVisitor{check_cuda, check_rocm},
device_description_.gpu_compute_capability())) {
erase(kRequireTracing); // cuStreamBeginCaptureToGraph
erase(kRequireConditionals); // on-device control flow
}
@@ -652,7 +671,7 @@

std::vector<HloInstructionSequence> sequences =
CollectCommandBufferSequences(
module->schedule().sequence(comp), config,
module->schedule().sequence(comp), config, device_description_,
debug_options.xla_gpu_graph_min_graph_size());

for (const HloInstructionSequence& seq : sequences) {
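
For readers skimming the diff above, here is a minimal, self-contained sketch of the new address-computation branch of `IsCommand`: build an `HloFusionAnalysis` for the fusion and target device, locate the custom call wrapped by the fusion with `HloFindIf`, and defer to the existing custom-call check. The wrapper name `IsAddressComputationFusionCommand` and the `has_value()` guard are illustrative additions, not part of the commit.

```cpp
// Illustrative sketch only: mirrors the address_computation branch added to
// IsCommand() in command_buffer_scheduling.cc. The free function name is
// hypothetical; the XLA types and helpers (HloFusionAnalysis, HloFusionAdaptor,
// HloFindIf) are the ones used in the diff.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/command_buffer_scheduling.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/stream_executor/device_description.h"

namespace xla::gpu {

static bool IsAddressComputationFusionCommand(
    const HloFusionInstruction* fusion,
    const CommandBufferScheduling::CommandBufferConfig& config,
    const se::DeviceDescription& device_description) {
  // Analyze the fusion for the target device.
  auto fusion_analysis = HloFusionAnalysis::Create(fusion, &device_description);
  const HloFusionAdaptor& adaptor = fusion_analysis.fusion();

  // Walk the fusion to find the custom call it wraps.
  auto custom_call_adaptor = HloFindIf(
      adaptor.GetRoots(), adaptor,
      [](auto node) { return node.opcode() == HloOpcode::kCustomCall; });
  // Defensive guard; the diff itself assumes the custom call is present.
  if (!custom_call_adaptor.has_value()) return false;

  // The fusion is a command exactly when the wrapped custom call is one;
  // IsCommand here is the existing HloCustomCallInstruction overload.
  const auto* custom_call = static_cast<const HloCustomCallInstruction*>(
      &custom_call_adaptor->instruction());
  return IsCommand(custom_call, config);
}

}  // namespace xla::gpu
```

The effect is that an address computation fusion is scheduled into a command buffer whenever the custom call it wraps (for example, a cuBLAS call) would be.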
8 changes: 4 additions & 4 deletions third_party/xla/xla/service/gpu/command_buffer_scheduling.h
@@ -24,12 +24,11 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/service/hlo_pass_interface.h"
#include "xla/status.h"
#include "xla/statusor.h"
#include "xla/stream_executor/device_description.h"

namespace xla::gpu {

@@ -75,7 +74,7 @@ class CommandBufferScheduling : public HloModulePass {
using CommandBufferConfig =
absl::flat_hash_set<DebugOptions::CommandBufferCmdType>;

CommandBufferScheduling(const se::GpuComputeCapability& gpu_compute_comp,
CommandBufferScheduling(const se::DeviceDescription& device_description,
int32_t gpu_toolkit_version,
int32_t gpu_driver_version);

@@ -90,6 +89,7 @@

static std::vector<HloInstructionSequence> CollectCommandBufferSequences(
HloInstructionSequence schedule, const CommandBufferConfig& config,
const se::DeviceDescription& device_description,
int32_t min_num_commands = 1);

// Moves kParameter and kConstant instructions in a computation to
@@ -127,7 +127,7 @@ class CommandBufferScheduling : public HloModulePass {
CommandBuffer command_buffer);

private:
se::GpuComputeCapability gpu_compute_comp_;
se::DeviceDescription device_description_;
// For NVIDIA gpus XLA can be compiled with a CUDA version that is larger than
// the version supported by the driver, e.g. we can compile for CUDA 12.3 but
// have 12.1 driver installed. When deciding what command buffer features we
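
A hypothetical call site (not part of this commit) illustrating the constructor change in this header: the pass now receives the full `se::DeviceDescription`, and the compute capability used for the CUDA/ROCm feature checks is read from it internally via `gpu_compute_capability()`, as shown in the `Run` changes above. The helper function and version constants below are assumptions for the example.

```cpp
// Sketch of constructing the pass after this change; names other than
// CommandBufferScheduling and DeviceDescription are made up for illustration.
#include <cstdint>
#include <memory>

#include "xla/service/gpu/command_buffer_scheduling.h"
#include "xla/stream_executor/device_description.h"

std::unique_ptr<xla::HloModulePass> MakeCommandBufferScheduling(
    const stream_executor::DeviceDescription& device_description) {
  constexpr int32_t kToolkitVersion = 12030;  // e.g. CUDA 12.3
  constexpr int32_t kDriverVersion = 12030;   // assume driver matches toolkit
  // The pass keeps the whole device description and pulls the
  // GpuComputeCapability out of it when deciding which command buffer
  // features to enable.
  return std::make_unique<xla::gpu::CommandBufferScheduling>(
      device_description, kToolkitVersion, kDriverVersion);
}
```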
31 changes: 15 additions & 16 deletions third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/service/hlo_parser.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/verified_hlo_module.h"
@@ -39,11 +40,8 @@ class CommandBufferSchedulingTest : public HloTestBase {
// Use CUDA 12.3 version for testing as it has all the features we rely on.
static constexpr int32_t kCudaVersion = 12030;

const auto& gpu_comp() {
return backend()
.default_stream_executor()
->GetDeviceDescription()
.gpu_compute_capability();
const se::DeviceDescription& device_desc() {
return backend().default_stream_executor()->GetDeviceDescription();
}

DebugOptions GetDebugOptionsForTest() override {
@@ -101,7 +99,7 @@ TEST_F(CommandBufferSchedulingTest, SingleCommandBuffer) {
// CHECK: })";

RunAndFilecheckHloRewrite(
hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion),
hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
expected, [](HloModule* module) {
EXPECT_TRUE(module->has_schedule());
TF_CHECK_OK(module->schedule().Verify());
@@ -179,7 +177,7 @@ TEST_F(CommandBufferSchedulingTest, MultipleCommandBuffers) {
// CHECK: })";

RunAndFilecheckHloRewrite(
hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion),
hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
expected, [](HloModule* module) {
EXPECT_TRUE(module->has_schedule());
TF_CHECK_OK(module->schedule().Verify());
@@ -218,7 +216,7 @@ TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByDone) {
CHECK: })";

RunAndFilecheckHloRewrite(
hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion),
hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
expected, [](HloModule* module) {
EXPECT_TRUE(module->has_schedule());
TF_CHECK_OK(module->schedule().Verify());
@@ -253,7 +251,7 @@ TEST_F(CommandBufferSchedulingTest, AllGatherStartFollowedByDone) {
CHECK: })";

RunAndFilecheckHloRewrite(
hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion),
hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
expected, [](HloModule* module) {
EXPECT_TRUE(module->has_schedule());
TF_CHECK_OK(module->schedule().Verify());
@@ -294,7 +292,7 @@ TEST_F(CommandBufferSchedulingTest, ReduceScatterStartFollowedByDone) {
CHECK: })";

RunAndFilecheckHloRewrite(
hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion),
hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
expected, [](HloModule* module) {
EXPECT_TRUE(module->has_schedule());
TF_CHECK_OK(module->schedule().Verify());
@@ -355,7 +353,8 @@ TEST_F(CommandBufferSchedulingTest, CollectCommandBufferSequence) {
config.insert(DebugOptions::FUSION);

std::vector<HloInstructionSequence> command_buffer_sequences =
CommandBufferScheduling::CollectCommandBufferSequences(seq, config);
CommandBufferScheduling::CollectCommandBufferSequences(seq, config,
device_desc());
EXPECT_EQ(command_buffer_sequences.size(), 2);

std::vector<HloInstruction*> seq_0 =
@@ -541,7 +540,7 @@ TEST_F(CommandBufferSchedulingTest, ForwardControlDependencies) {
CHECK: })";

RunAndFilecheckHloRewrite(
hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion),
hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
expected, [](HloModule* module) {
EXPECT_TRUE(module->has_schedule());
TF_CHECK_OK(module->schedule().Verify());
@@ -581,7 +580,7 @@ TEST_F(CommandBufferSchedulingTest, ForwardControlDependenciesToParams) {
CHECK: })";

RunAndFilecheckHloRewrite(
hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion),
hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
expected, [](HloModule* module) {
EXPECT_TRUE(module->has_schedule());
TF_CHECK_OK(module->schedule().Verify());
@@ -660,7 +659,7 @@ TEST_F(CommandBufferSchedulingTest, WhileNotCommand) {
CHECK: })";

RunAndFilecheckHloRewrite(
hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion),
hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
expected, [](HloModule* module) {
EXPECT_TRUE(module->has_schedule());
TF_CHECK_OK(module->schedule().Verify());
@@ -722,7 +721,7 @@ TEST_F(CommandBufferSchedulingTest, While) {
CHECK: })";

RunAndFilecheckHloRewrite(
hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion),
hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
expected, [](HloModule* module) {
EXPECT_TRUE(module->has_schedule());
TF_CHECK_OK(module->schedule().Verify());
@@ -798,7 +797,7 @@ TEST_F(CommandBufferSchedulingTest, Conditional) {
CHECK: })";

RunAndFilecheckHloRewrite(
hlo, CommandBufferScheduling(gpu_comp(), kCudaVersion, kCudaVersion),
hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion),
expected, [](HloModule* module) {
EXPECT_TRUE(module->has_schedule());
TF_CHECK_OK(module->schedule().Verify());
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -2123,7 +2123,7 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
constexpr int toolkit_version = TF_ROCM_VERSION;
#endif
pipeline.AddPass<CommandBufferScheduling>(
gpu_device_info.gpu_compute_capability(), toolkit_version,
gpu_device_info, toolkit_version,
driver_version.value_or(toolkit_version));
TF_RETURN_IF_ERROR(pipeline.Run(module).status());
}