-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[XPU][Fleet] Support multi-card infer for xpu (#50490)
* support xpu multi-card infer * add ut * clean code * clean code * fix * fix * fix * fix
- Loading branch information
1 parent
3b6ebc9
commit 517d807
Showing
21 changed files
with
485 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
73 changes: 73 additions & 0 deletions
73
paddle/fluid/inference/tests/api/analyzer_dist_model_xpu_tester.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "gtest/gtest.h" | ||
#include "paddle/fluid/framework/block_desc.h" | ||
#include "paddle/fluid/framework/op_desc.h" | ||
#include "paddle/fluid/framework/program_desc.h" | ||
#include "paddle/fluid/framework/scope.h" | ||
#include "paddle/fluid/inference/tests/api/tester_helper.h" | ||
#include "paddle/fluid/inference/utils/singleton.h" | ||
|
||
namespace paddle { | ||
namespace inference { | ||
|
||
TEST(test_dist_model_xpu, dist_model_xpu) {
  // Smoke test: run one inference pass of the DistModel path on XPU device 0
  // (single rank, rank id 0) with an all-zero input tensor, then copy the
  // output back to the host. Passes if no step crashes.
  std::cout << "Analysis Predictor DistModel XPU test." << std::endl;
  AnalysisConfig config;
  config.SetModel(FLAGS_infer_model + "/__model__",
                  FLAGS_infer_model + "/__params__");
  config.SwitchUseFeedFetchOps(false);
  config.EnableXpu();
  config.SetXpuDeviceId(0);
  DistConfig dist_config;
  dist_config.SetRanks(1, 0);  // world size 1, current rank 0
  dist_config.EnableDistModel(true);
  dist_config.SetEndpoints({""}, "");
  config.SetDistConfig(dist_config);

  auto predictor = paddle_infer::CreatePredictor(config);
  const int batch_size = 1;
  const int channels = 1;
  const int height = 48;
  const int width = 512;
  const int nums = batch_size * channels * height * width;
  std::cout << "Created predictor." << std::endl;

  // RAII buffer instead of raw new[]/delete[]: the original leaked the
  // allocation if any call below threw or a gtest assertion aborted the test
  // body before the trailing delete[].
  std::vector<float> input(nums, 0.0f);
  auto input_names = predictor->GetInputNames();

  auto input_t = predictor->GetInputHandle(input_names[0]);
  input_t->Reshape({batch_size, channels, height, width});
  input_t->CopyFromCpu(input.data());
  std::cout << "Input data." << std::endl;

  predictor->Run();
  std::cout << "Zero Copy Run." << std::endl;

  auto output_names = predictor->GetOutputNames();
  auto output_t = predictor->GetOutputHandle(output_names[0]);
  std::vector<int> output_shape = output_t->shape();
  // Total element count = product of the output dims.
  int out_num = std::accumulate(
      output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
  std::vector<float> out_data(out_num);
  output_t->CopyToCpu(out_data.data());
  std::cout << "Output data." << std::endl;
}
|
||
} // namespace inference | ||
} // namespace paddle |
118 changes: 118 additions & 0 deletions
118
paddle/fluid/operators/collective/c_broadcast_op_xpu.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/operators/collective/c_broadcast_op.h" | ||
|
||
#ifdef PADDLE_WITH_XPU_BKCL | ||
#include "paddle/fluid/platform/collective_helper.h" | ||
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" | ||
#endif | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
// XPU kernel for the c_broadcast collective op: broadcasts tensor "X" from
// the rank selected by the "root" attribute to every rank in the BKCL
// communicator identified by "ring_id", writing the result into "Out".
template <typename T>
class CBroadcastOpXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_XPU_BKCL)
    auto x = ctx.Input<phi::DenseTensor>("X");
    auto out = ctx.Output<phi::DenseTensor>("Out");
    size_t numel = x->numel();

    // Map the tensor's dtype to the corresponding BKCL data type enum.
    BKCLDataType dtype =
        platform::ToBKCLDataType(framework::TransToProtoVarType(x->dtype()));
    int ring_id = ctx.Attr<int>("ring_id");
    auto place = ctx.GetPlace();
    // Look up the process-wide BKCL communicator for this ring on this device.
    auto comm =
        paddle::platform::BKCLCommContext::Instance().Get(ring_id, place);

    // "use_calc_stream" selects the device's compute stream (so the collective
    // is ordered with surrounding kernels); otherwise use the communicator's
    // own stream.
    XPUStream stream = nullptr;
    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
    if (ctx.Attr<bool>("use_calc_stream")) {
      stream = static_cast<platform::XPUDeviceContext*>(dev_ctx)
                   ->x_context()
                   ->xpu_stream;
    } else {
      stream = comm->stream();
    }

    int root = ctx.Attr<int>("root");
    VLOG(3) << "begin bkcl broadcast, parameter is: "
            << "root " << root << ", comm: " << comm->comm()
            << ", stream: " << stream;
    void* send_recv_buffer = nullptr;
    if (root == comm->rank()) {
      // This rank is the broadcast source: send in-place from X's buffer
      // (sendbuf == recvbuf).
      // API: BKCLResult_t bkcl_broadcast(const BKCLContext_t ctx,
      //                                  const void* sendbuf,
      //                                  void* recvbuf,
      //                                  size_t count, BKCLDataType datatype,
      //                                  int root,
      //                                  XPUStream stream);
      send_recv_buffer = reinterpret_cast<void*>(const_cast<T*>(x->data<T>()));
      auto ret = bkcl_broadcast(comm->comm(),
                                send_recv_buffer,
                                send_recv_buffer,
                                numel,
                                dtype,
                                root,
                                stream);
      PADDLE_ENFORCE_EQ(ret,
                        BKCL_SUCCESS,
                        platform::errors::PreconditionNotMet(
                            "XPU BKCL c_broadcast execute failed"));
      // On the root, Out must also hold the data; copy only when Out does not
      // alias X.
      if (out != x) {
        framework::TensorCopy(
            *static_cast<const phi::DenseTensor*>(x),
            place,
            *platform::DeviceContextPool::Instance().Get(place),
            static_cast<phi::DenseTensor*>(out));
      }
    } else {
      // Receiving rank: allocate Out's buffer and receive into it in-place.
      auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
      dev_ctx.template Alloc<T>(out);
      send_recv_buffer = out->data<T>();
      auto ret = bkcl_broadcast(comm->comm(),
                                send_recv_buffer,
                                send_recv_buffer,
                                numel,
                                dtype,
                                root,
                                stream);
      PADDLE_ENFORCE_EQ(ret,
                        BKCL_SUCCESS,
                        platform::errors::PreconditionNotMet(
                            "XPU BKCL c_broadcast execute failed"));
    }

    VLOG(3) << "rank " << comm->rank() << " invoke Bcast. received "
            << phi::product(out->dims());
    // Propagate the sender-side shape and LoD metadata to Out.
    out->Resize(x->dims());
    out->set_lod(x->lod());
#else
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "PaddlePaddle should be compiled with XPU and BKCL."));
#endif
  }
};
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the XPU kernels for c_broadcast, instantiated for float and
// float16 element types.
REGISTER_OP_XPU_KERNEL(c_broadcast,
                       ops::CBroadcastOpXPUKernel<float>,
                       ops::CBroadcastOpXPUKernel<plat::float16>);
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.