[Pten] Replace platform::Place with pten::Place #38899

Merged on Jan 17, 2022 (28 commits)

Commits (28)
e3640af
add pten::Place data structure.
jiweibo Jan 10, 2022
212ea96
update ci problem
jiweibo Jan 10, 2022
08a0263
fix ci problem
jiweibo Jan 10, 2022
30d84e7
update
jiweibo Jan 10, 2022
7c66def
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jiweibo Jan 11, 2022
21ec56d
using platform::Place=pten::Place
jiweibo Jan 11, 2022
a26f0f7
remove BOOST_GET_CONST for CPUPlace and GPUPlace
jiweibo Jan 11, 2022
2211b28
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jiweibo Jan 11, 2022
b13b038
compile pass 25%.
jiweibo Jan 12, 2022
f46f4e8
compile pass 45%
jiweibo Jan 12, 2022
62d1114
compile pass 60%
jiweibo Jan 12, 2022
afccb3c
remove boost_get for xpu npu mlu and ipu
jiweibo Jan 12, 2022
45b5f1d
compile pass on cpu and gpu.
jiweibo Jan 12, 2022
139f3ff
fix compile problem
jiweibo Jan 12, 2022
de07a8c
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jiweibo Jan 12, 2022
57cde3c
fix compile error.
jiweibo Jan 12, 2022
a4a7263
update
jiweibo Jan 12, 2022
6582a7e
fix ci problem
jiweibo Jan 13, 2022
e32da4f
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jiweibo Jan 13, 2022
991a751
update
jiweibo Jan 13, 2022
fedb225
ci approve
jiweibo Jan 13, 2022
909b81a
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jiweibo Jan 13, 2022
7930697
fix ci problem
jiweibo Jan 14, 2022
e0e593a
fix ci eager test problem
jiweibo Jan 14, 2022
bf94564
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jiweibo Jan 14, 2022
ba6ddd0
remove BOOST_GET_CONST
jiweibo Jan 14, 2022
87111ea
fix npu compile
jiweibo Jan 14, 2022
4a5e80f
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jiweibo Jan 16, 2022
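Note on the pattern in the diffs below: before this PR, platform::Place was a boost::variant over the concrete place types (CPUPlace, CUDAPlace, XPUPlace, ...), so call sites had to unwrap it with BOOST_GET_CONST before reading fields such as .device. With `using platform::Place = pten::Place` (commit 21ec56d), Place becomes a plain struct carrying a type tag and a device id, and the unwrapping disappears. The following is only a minimal sketch of that shape; the member and enum names are illustrative, not copied from the pten sources:

```cpp
#include <cstdint>

// Illustrative stand-in for pten::Place: a tagged struct instead of a
// boost::variant over CPUPlace/CUDAPlace/XPUPlace/... alternatives.
enum class AllocationType : int8_t { UNDEFINED = 0, CPU, GPU, XPU, NPU };

class Place {
 public:
  Place() = default;
  Place(AllocationType type, int8_t id) : device(id), type_(type) {}

  // Explicit type tag; replaces boost::variant's which() index.
  AllocationType GetType() const { return type_; }

  // The device id is a plain member, readable for any device place
  // without naming the concrete place type first.
  int8_t device{0};

 private:
  AllocationType type_{AllocationType::UNDEFINED};
};

// Before: int dev_id = BOOST_GET_CONST(platform::CUDAPlace, p).device;
// After:  int dev_id = p.device;
```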
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
@@ -221,8 +221,8 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(place)) {
if (framework::IsFastEagerDeletionModeEnabled()) {
-      gc.reset(new framework::UnsafeFastGPUGarbageCollector(
-          BOOST_GET_CONST(platform::CUDAPlace, place), max_memory_size));
+      gc.reset(new framework::UnsafeFastGPUGarbageCollector(place,
+                                                            max_memory_size));
}
}
#endif
59 changes: 28 additions & 31 deletions paddle/fluid/distributed/service/brpc_utils.cc
@@ -106,13 +106,12 @@ void SerializeLodTensor(framework::Variable* var,
iobuf->append(reinterpret_cast<const char*>(tensor->data()), data_len);
} else {
#ifdef PADDLE_WITH_CUDA
-    char* temp_ptr =
-        new char[tensor->numel() * framework::SizeOfType(tensor->type())];
+    char* temp_ptr = new char[tensor->numel() *
+                              framework::SizeOfType(tensor->type())];  // NOLINT
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(
-        platform::CPUPlace(), temp_ptr,
-        BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), tensor->data(),
+        platform::CPUPlace(), temp_ptr, tensor->place(), tensor->data(),
tensor->numel() * framework::SizeOfType(tensor->type()), stream);
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
@@ -148,13 +147,12 @@ void SerializeSelectedRows(framework::Variable* var,
iobuf->append(reinterpret_cast<const char*>(tensor->data()), data_len);
} else {
#ifdef PADDLE_WITH_CUDA
-    char* temp_ptr =
-        new char[tensor->numel() * framework::SizeOfType(tensor->type())];
+    char* temp_ptr = new char[tensor->numel() *
+                              framework::SizeOfType(tensor->type())];  // NOLINT
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(
-        platform::CPUPlace(), temp_ptr,
-        BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), tensor->data(),
+        platform::CPUPlace(), temp_ptr, tensor->place(), tensor->data(),
tensor->numel() * framework::SizeOfType(tensor->type()), stream);
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
@@ -204,7 +202,7 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
}

void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
-                          butil::IOBufBytesIterator& io_buffer_itr,
+                          butil::IOBufBytesIterator& io_buffer_itr,  // NOLINT
const platform::DeviceContext& ctx) {
const auto place = ctx.GetPlace();
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
@@ -229,30 +227,30 @@ void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,

// IO Buffer
if (platform::is_cpu_place(place)) {
-    unsigned long data_len;
-    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);
+    unsigned long data_len;                                 // NOLINT
+    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);  // NOLINT
io_buffer_itr.copy_and_forward(tensor_data, data_len);
} else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
-    unsigned long data_len;
-    char* temp_ptr =
-        new char[tensor->numel() * framework::SizeOfType(tensor->type())];
-    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);
-    io_buffer_itr.copy_and_forward((void*)temp_ptr, data_len);
+    unsigned long data_len;  // NOLINT
+    char* temp_ptr = new char[tensor->numel() *
+                              framework::SizeOfType(tensor->type())];  // NOLINT
+    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);     // NOLINT
+    io_buffer_itr.copy_and_forward((void*)temp_ptr, data_len);  // NOLINT
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
-    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data,
-                 platform::CPUPlace(), (void*)temp_ptr,
-                 tensor->numel() * framework::SizeOfType(tensor->type()),
-                 stream);
+    memory::Copy(
+        place, tensor_data, platform::CPUPlace(), (void*)temp_ptr,  // NOLINT
+        tensor->numel() * framework::SizeOfType(tensor->type()), stream);
delete[] temp_ptr;
#endif
}
}

-void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
-                             butil::IOBufBytesIterator& io_buffer_itr,
-                             const platform::DeviceContext& ctx) {
+void DeserializeSelectedRows(
+    framework::Variable* var, const VarMsg& msg,
+    butil::IOBufBytesIterator& io_buffer_itr,  // NOLINT
+    const platform::DeviceContext& ctx) {
const auto place = ctx.GetPlace();
auto* slr = var->GetMutable<framework::SelectedRows>();
framework::Tensor* tensor = slr->mutable_value();
Expand All @@ -269,20 +267,19 @@ void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
tensor->mutable_data(place, VarMessageToVarType(msg.data_type()));
// IO Buffer
if (platform::is_cpu_place(place)) {
-    unsigned long data_len;
-    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);
+    unsigned long data_len;                                 // NOLINT
+    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);  // NOLINT
io_buffer_itr.copy_and_forward(tensor_data, data_len);
} else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
-    char* temp_ptr =
-        new char[tensor->numel() * framework::SizeOfType(tensor->type())];
-    unsigned long data_len;
-    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);
+    char* temp_ptr = new char[tensor->numel() *
+                              framework::SizeOfType(tensor->type())];  // NOLINT
+    unsigned long data_len;                                 // NOLINT
+    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);  // NOLINT
io_buffer_itr.copy_and_forward(temp_ptr, data_len);
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
-    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data,
-                 platform::CPUPlace(), temp_ptr,
+    memory::Copy(place, tensor_data, platform::CPUPlace(), temp_ptr,
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
delete[] temp_ptr;
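A side note on the // NOLINT markers added in this file: they appear to silence cpplint (the `unsigned long` declarations trip runtime/int, and the `(void*)` casts trip readability/casting); the wire format itself is untouched. That format is simply an 8-byte length prefix followed by the raw tensor bytes. A self-contained sketch of the framing, using std::string where the real code uses butil::IOBuf (an assumption made purely for illustration):

```cpp
#include <cstdint>
#include <cstring>
#include <string>

// Write one length-prefixed payload, mirroring what SerializeLodTensor
// appends to the IOBuf: 8 bytes of length, then the data itself.
void AppendLengthPrefixed(std::string* buf, const char* data, uint64_t len) {
  buf->append(reinterpret_cast<const char*>(&len), sizeof(len));  // 8 bytes
  buf->append(data, len);
}

// Read one payload back, mirroring DeserializeLodTensor's pair of
// copy_and_forward calls; returns the position just past the payload.
size_t ReadLengthPrefixed(const std::string& buf, size_t pos,
                          std::string* out) {
  uint64_t len = 0;
  std::memcpy(&len, buf.data() + pos, sizeof(len));
  out->assign(buf.data() + pos + sizeof(len), len);
  return pos + sizeof(len) + static_cast<size_t>(len);
}
```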
3 changes: 1 addition & 2 deletions paddle/fluid/distributed/service/heter_client.cc
@@ -44,8 +44,7 @@ int GetMicroId(const platform::DeviceContext& ctx,
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(
-      platform::CPUPlace(), temp_ptr,
-      BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), tensor->data(),
+      platform::CPUPlace(), temp_ptr, tensor->place(), tensor->data(),
tensor->numel() * framework::SizeOfType(tensor->type()), stream);
float* temp_ptr_float = reinterpret_cast<float*>(temp_ptr);
micro_id = static_cast<int>(temp_ptr_float[0]);
32 changes: 16 additions & 16 deletions paddle/fluid/eager/accumulation/gradient_accumulation.cc
@@ -43,7 +43,7 @@ class TensorAddFunctor : public boost::static_visitor<> {
TensorAddFunctor(int64_t numel, const T* x, T* y)
: numel_(numel), x_(x), y_(y) {}

-  void operator()(const paddle::platform::CPUPlace& place) {
+  void operator()(const paddle::platform::CPUPlace& place) const {
paddle::platform::CPUDeviceContext* ctx =
dynamic_cast<paddle::platform::CPUDeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place));
@@ -56,7 +56,7 @@ class TensorAddFunctor : public boost::static_visitor<> {
// TODO(jiabin): Support xpu here from gradient_accumulator.cc

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  void operator()(const paddle::platform::CUDAPlace& place) {
+  void operator()(const paddle::platform::CUDAPlace& place) const {
paddle::platform::CUDADeviceContext* ctx =
dynamic_cast<paddle::platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place));
@@ -66,7 +66,7 @@ class TensorAddFunctor : public boost::static_visitor<> {
blas.AXPY(numel_, 1., x_, y_);
}
#else
-  void operator()(const paddle::platform::CUDAPlace& place) {
+  void operator()(const paddle::platform::CUDAPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
@@ -76,22 +76,22 @@ class TensorAddFunctor : public boost::static_visitor<> {

// TODO(jiabin): Support Npu here from gradient_accumulator.cc
// there is NO blas in CUDAPinnedPlace
-  void operator()(const paddle::platform::CUDAPinnedPlace& place) {
+  void operator()(const paddle::platform::CUDAPinnedPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}

#ifdef PADDLE_WITH_ASCEND_CL
-  void operator()(const paddle::platform::NPUPlace& place) {
+  void operator()(const paddle::platform::NPUPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#else
-  void operator()(const paddle::platform::NPUPlace& place) {
+  void operator()(const paddle::platform::NPUPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
@@ -100,14 +100,14 @@ class TensorAddFunctor : public boost::static_visitor<> {
#endif

#ifdef PADDLE_WITH_XPU
-  void operator()(const paddle::platform::XPUPlace& place) {
+  void operator()(const paddle::platform::XPUPlace& place) const {
paddle::platform::XPUDeviceContext* ctx =
dynamic_cast<paddle::platform::XPUDeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place));
xpu::add<T>(ctx->x_context(), x_, y_, y_, static_cast<int>(numel_));
}
#else
-  void operator()(const paddle::platform::XPUPlace& place) {
+  void operator()(const paddle::platform::XPUPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
@@ -116,14 +116,14 @@ class TensorAddFunctor : public boost::static_visitor<> {
#endif

#ifdef PADDLE_WITH_MLU
-  void operator()(const paddle::platform::MLUPlace& place) {
+  void operator()(const paddle::platform::MLUPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#else
-  void operator()(const paddle::platform::MLUPlace& place) {
+  void operator()(const paddle::platform::MLUPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
@@ -132,22 +132,22 @@ class TensorAddFunctor : public boost::static_visitor<> {
#endif

#ifdef PADDLE_WITH_IPU
-  void operator()(const paddle::platform::IPUPlace& place) {
+  void operator()(const paddle::platform::IPUPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#else
-  void operator()(const paddle::platform::IPUPlace& place) {
+  void operator()(const paddle::platform::IPUPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#endif

-  void operator()(const paddle::platform::NPUPinnedPlace& place) {
+  void operator()(const paddle::platform::NPUPinnedPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
@@ -157,7 +157,7 @@ class TensorAddFunctor : public boost::static_visitor<> {
private:
int64_t numel_;
const T* x_;
-  T* y_;
+  mutable T* y_;
};

template <typename DeviceContext, typename T>
@@ -218,7 +218,7 @@ void TensorAdd(const egr::EagerTensor& src, egr::EagerTensor* dst) {
if (data_type == paddle::framework::DataTypeTrait<cpp_type>::DataType()) { \
TensorAddFunctor<cpp_type> func(numel, src_tensor->data<cpp_type>(), \
dst_tensor->mutable_data<cpp_type>()); \
-    boost::apply_visitor(func, place);                                       \
+    paddle::platform::VisitPlace(place, func);                               \
return; \
}

@@ -294,7 +294,7 @@ void VariableAdd(const egr::EagerTensor& src, egr::EagerTensor* dst) {
TensorAddFunctor<cpp_type> func( \
numel, src_tensor.data<cpp_type>(), \
dst_tensor->mutable_data<cpp_type>(place)); \
-    boost::apply_visitor(func, place);                                       \
+    paddle::platform::VisitPlace(place, func);                               \
return; \
}

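Two changes in this file belong together: boost::apply_visitor(func, place) becomes paddle::platform::VisitPlace(place, func), and every operator() overload of TensorAddFunctor gains a const qualifier, with y_ turned mutable so the CPU/GPU overloads can still write through it. That combination is what you would expect if the new dispatcher takes the visitor by const reference and switches on the place's type tag. A minimal self-contained sketch of such a dispatcher, shown as an illustration of the mechanism rather than the PR's actual implementation:

```cpp
#include <stdexcept>

enum class AllocationType { CPU, GPU };

struct Place {
  AllocationType type;
  int device = 0;
  AllocationType GetType() const { return type; }
};

// Concrete tag types handed to the visitor's typed overloads.
struct CPUPlace { int device; };
struct GPUPlace { int device; };

// Dispatch on the tagged Place and invoke the matching overload. Taking
// `visitor` by const reference is why TensorAddFunctor's operator()
// overloads are const-qualified (and y_ mutable) after this change.
template <typename Visitor>
void VisitPlace(const Place& place, const Visitor& visitor) {
  switch (place.GetType()) {
    case AllocationType::CPU:
      visitor(CPUPlace{place.device});
      return;
    case AllocationType::GPU:
      visitor(GPUPlace{place.device});
      return;
  }
  throw std::runtime_error("unsupported place kind");
}
```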
9 changes: 3 additions & 6 deletions paddle/fluid/eager/legacy/op_runner.cc
@@ -150,24 +150,21 @@ void RunOp(const std::string& type, const NameTensorMap& ins,
VLOG(6) << "Get Device id";
if (paddle::platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    paddle::platform::SetDeviceId(
-        BOOST_GET_CONST(paddle::platform::CUDAPlace, place).device);
+    paddle::platform::SetDeviceId(place.device);
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU if use CUDAPlace."));
#endif
} else if (paddle::platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
-    paddle::platform::SetXPUDeviceId(
-        BOOST_GET_CONST(paddle::platform::XPUPlace, place).device);
+    paddle::platform::SetXPUDeviceId(place.device);
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with XPU if use XPUPlace."));
#endif
} else if (paddle::platform::is_npu_place(place)) {
#ifdef PADDLE_WITH_ASCEND_CL
-    paddle::platform::SetNPUDeviceId(
-        BOOST_GET_CONST(paddle::platform::NPUPlace, place).device);
+    paddle::platform::SetNPUDeviceId(place.device);
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with NPU if use NPUPlace."));
4 changes: 2 additions & 2 deletions paddle/fluid/eager/legacy/prepared_operator.cc
@@ -116,7 +116,7 @@ PreparedOp PrepareImpl(const NameTensorMap& ins, const NameTensorMap& outs,
auto& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_XPU
-  if (is_xpu_place(expected_kernel_key.place_) &&
+  if (paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
(kernel_iter == kernels.end() ||
!paddle::platform::is_xpu_support_op(op.Type(), expected_kernel_key) ||
paddle::platform::is_in_xpu_black_list(op.Type()))) {
Expand All @@ -129,7 +129,7 @@ PreparedOp PrepareImpl(const NameTensorMap& ins, const NameTensorMap& outs,
#endif
#ifdef PADDLE_WITH_ASCEND_CL
if (kernel_iter == kernels.end() &&
-      is_npu_place(expected_kernel_key.place_)) {
+      paddle::platform::is_npu_place(expected_kernel_key.place_)) {
VLOG(3) << "missing NPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
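Two things are visible in this hunk. First, is_xpu_place/is_npu_place now need their paddle::platform:: qualification, plausibly because argument-dependent lookup no longer reaches that namespace once the argument is a pten::Place rather than a platform type. Second, the surrounding logic is a kernel-fallback path: when no kernel is registered for the expected XPU/NPU key (or the op is blacklisted on XPU), the lookup retries with a CPU kernel key. A simplified sketch of that fallback shape, with hypothetical key and kernel types:

```cpp
#include <map>
#include <string>
#include <tuple>

enum class PlaceKind { CPU, XPU };

// Hypothetical stand-in for the framework's kernel key type.
struct KernelKey {
  PlaceKind place;
  std::string dtype;
  bool operator<(const KernelKey& o) const {
    return std::tie(place, dtype) < std::tie(o.place, o.dtype);
  }
};

using Kernel = void (*)();

// If no kernel is registered for the device key, retry with a CPU key;
// the same shape as PrepareImpl's "fallbacking to CPU one!" branch.
const Kernel* FindKernelWithCPUFallback(
    const std::map<KernelKey, Kernel>& kernels, KernelKey key) {
  auto it = kernels.find(key);
  if (it == kernels.end() && key.place != PlaceKind::CPU) {
    key.place = PlaceKind::CPU;
    it = kernels.find(key);
  }
  return it == kernels.end() ? nullptr : &it->second;
}
```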
2 changes: 1 addition & 1 deletion paddle/fluid/framework/data_device_transform.cc
@@ -22,7 +22,7 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
<< " dst_place: " << dst_place;

PADDLE_ENFORCE_NE(
-      in.place().which(), dst_place.which(),
+      in.place().GetType(), dst_place.GetType(),
platform::errors::Unavailable("Currently, model parallelism is only "
"supported between CPU and CUDA."));

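boost::variant's which() returns the index of the alternative currently held, so equal indices meant "same kind of place"; the new Place makes that check explicit by comparing AllocationType tags via GetType(). In the same spirit, the is_*_place helpers reduce to tag tests. A tiny self-contained illustration (names again illustrative, not the real definitions):

```cpp
enum class AllocationType { CPU, GPU };

struct Place {
  AllocationType type;
  AllocationType GetType() const { return type; }
};

// What is_cpu_place/is_gpu_place amount to once Place carries an explicit
// type tag instead of a variant index:
bool is_cpu_place(const Place& p) { return p.GetType() == AllocationType::CPU; }
bool is_gpu_place(const Place& p) { return p.GetType() == AllocationType::GPU; }

// The PADDLE_ENFORCE_NE above compares place kinds, not device ids:
bool same_place_kind(const Place& a, const Place& b) {
  return a.GetType() == b.GetType();
}
```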
7 changes: 4 additions & 3 deletions paddle/fluid/framework/details/all_reduce_op_handle.cc
@@ -15,6 +15,7 @@

#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
Expand Down Expand Up @@ -181,7 +182,7 @@ void AllReduceOpHandle::AllReduceFunc(
const framework::proto::VarType::Type &dtype, int64_t numel,
const std::vector<platform::Place> &places,
const std::vector<std::string> &out_var_names) {
-  if (is_gpu_place(places[0])) {
+  if (platform::is_gpu_place(places[0])) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
PADDLE_ENFORCE_NOT_NULL(nccl_ctxs_,
platform::errors::InvalidArgument(
@@ -200,7 +201,7 @@ void AllReduceOpHandle::AllReduceFunc(
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with GPU."));
#endif
-  } else if (is_xpu_place(places[0])) {
+  } else if (platform::is_xpu_place(places[0])) {
#if defined(PADDLE_WITH_XPU_BKCL)
PADDLE_ENFORCE_NOT_NULL(bkcl_ctxs_,
platform::errors::InvalidArgument(
@@ -286,7 +287,7 @@ void AllReduceOpHandle::NCCLAllReduceFunc(
void AllReduceOpHandle::SyncNCCLAllReduce() {
if (FLAGS_sync_nccl_allreduce) {
for (auto &p : places_) {
-      int dev_id = BOOST_GET_CONST(platform::CUDAPlace, p).device;
+      int dev_id = p.device;
auto *nccl_ctxs =
nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, use_hierarchical_allreduce_);
auto &nccl_ctx = nccl_ctxs->at(dev_id);
paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc
@@ -46,7 +46,7 @@ BindThreadedSSAGraphExecutor::BindThreadedSSAGraphExecutor(
}
int index = 0;
for (uint32_t i = 0; i < places.size(); i++) {
-    int id = BOOST_GET_CONST(platform::XPUPlace, places_[i]).device;
+    int id = places_[i].device;
if (place_to_index_.find(id) == place_to_index_.end()) {
place_to_index_[id] = index;
index++;
@@ -145,8 +145,7 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
RunMultiDeviceOpAsync(cur_op, op_deps.get(), ready_ops);
continue;
} else {
-      cur_place =
-          BOOST_GET_CONST(platform::XPUPlace, dev_ctxes_.begin()->first);
+      cur_place = dev_ctxes_.begin()->first;
int cur_index = place_to_index_[cur_place.device];
RunOpAsyncMainStream(cur_op, op_deps.get(), ready_ops, cur_index);
}
2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/bkcl_op_handle.h
@@ -85,7 +85,7 @@ class BKCLOpHandleBase : public OpHandleBase {
platform::errors::InvalidArgument(
"The argument run_order_ must be >= 0, but got %d.", run_order_));
auto flat_bkcl_ctxs = bkcl_ctxs_->GetFlatCtx(run_order_);
-    int dev_id = BOOST_GET_CONST(platform::XPUPlace, place).device;
+    int dev_id = place.device;
auto& bkcl_ctx = flat_bkcl_ctxs->at(dev_id);
auto comm = bkcl_ctx.comm_;
