Skip to content

Commit

Permalink
[CustomDevice] add c_identity op (#52982) (#53013)
Browse files Browse the repository at this point in the history
* [CustomDevice] add c_identity op

* fix use_calc_stream handling
  • Loading branch information
ronny1996 authored Apr 20, 2023
1 parent 585f9d6 commit d131e67
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 19 deletions.
119 changes: 101 additions & 18 deletions paddle/fluid/distributed/collective/process_group_custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
Fn fn,
CommType op_type) {
CommType op_type,
bool sync_op,
bool use_calc_stream) {
const auto places = GetPlaceList(inputs);
const auto key = GetKeyFromPlaces(places);

Expand All @@ -199,20 +201,28 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
}

auto& ccl_comms = places_to_customcomm_[key];
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
if (!use_calc_stream) {
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
}
auto task = CreateTask(places, rank_, op_type, inputs);
task->SetOutputs(outputs);

for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
const auto& ccl_stream = places_to_ctx_[key][i]->stream();
const auto& ccl_stream =
use_calc_stream ? reinterpret_cast<phi::CustomContext*>(
phi::DeviceContextPool::Instance().Get(places[i]))
->stream()
: places_to_ctx_[key][i]->stream();
phi::stream::Stream stream(places[i], ccl_stream);
fn(inputs[i], outputs[i], ccl_comms[i]->GetCustomCCLComm(), stream);
}

for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]);
if (!use_calc_stream) {
for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]);
}
}
return task;
}
Expand Down Expand Up @@ -280,7 +290,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
comm,
stream);
},
CommType::ALLGATHER);
CommType::ALLGATHER,
sync_op,
use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
Expand Down Expand Up @@ -322,7 +334,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
comm,
stream);
},
CommType::ALLGATHER);
CommType::ALLGATHER,
false,
false);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
Expand All @@ -333,7 +347,36 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
bool use_calc_stream) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return AllReduce(in_wrapper, out_wrapper, opts);
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_wrapper, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_wrapper, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
return phi::DeviceManager::CCLAllReduce(
device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
ToCustomCCLRedType(opts.reduce_op),
comm,
stream);
},
CommType::ALLREDUCE,
sync_op,
use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
Expand All @@ -342,9 +385,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
const AllreduceOptions& opts,
bool sync_op // for compatibility, no use now
) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return AllReduce(in_wrapper, out_wrapper, opts);
return AllReduce(out_tensor, in_tensor, opts, sync_op, false);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
Expand Down Expand Up @@ -378,7 +419,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
comm,
stream);
},
CommType::ALLREDUCE);
CommType::ALLREDUCE,
false,
false);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
Expand All @@ -389,17 +432,55 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
bool use_calc_stream) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Broadcast(in_wrapper, out_wrapper, opts);
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_wrapper, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_wrapper, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
int root = opts.source_rank * in_wrapper.size() + opts.source_root;
if (rank_ == root) {
return phi::DeviceManager::CCLBroadcast(
device_type_,
input.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
root,
comm,
stream);
} else {
return phi::DeviceManager::CCLBroadcast(
device_type_,
output.data(),
output.numel(),
phi::ccl::ToCCLDataType(output.dtype()),
root,
comm,
stream);
}
},
CommType::BROADCAST,
sync_op,
use_calc_stream);
}

// Broadcast overload that takes no `use_calc_stream` argument.
//
// Delegates to the five-argument Broadcast overload, forwarding `out_tensor`,
// `in_tensor`, `opts` and `sync_op` unchanged and passing `false` for
// `use_calc_stream`. Returns the task produced by that overload.
//
// The previous body wrapped the tensors in temporary vectors and returned via
// the vector-based Broadcast before this delegating call could run; those
// statements are removed so the delegation is actually reached and both
// single-tensor overloads share one implementation.
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const BroadcastOptions& opts,
    bool sync_op) {
  return Broadcast(out_tensor, in_tensor, opts, sync_op, false);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
Expand Down Expand Up @@ -489,7 +570,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
stream);
}
},
CommType::BROADCAST);
CommType::BROADCAST,
false,
false);
}

std::shared_ptr<ProcessGroupCustom>
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/collective/process_group_custom.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
Fn fn,
CommType op_type);
CommType op_type,
bool sync_op,
bool use_calc_stream);

void CreateCustomManagerCache(const std::string& places_key,
const std::vector<Place>& places);
Expand Down
16 changes: 16 additions & 0 deletions paddle/fluid/operators/custom_device_common_op_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/custom_device_common_op_registry.h"
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/operators/collective/c_concat_op.h"
#include "paddle/fluid/operators/collective/c_identity_op.h"
#include "paddle/fluid/operators/load_combine_op.h"
#include "paddle/fluid/operators/run_program_op.h"
#include "paddle/fluid/operators/save_combine_op.h"
Expand Down Expand Up @@ -589,6 +590,21 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
paddle::platform::CustomDeviceContext,
paddle::platform::float16>) {}

// Registers the `c_identity` op kernel for `device_type`, instantiating
// CIdentityOpKernel on CustomDeviceContext for float, double, int, int64_t
// and float16.
// NOTE(review): `device_type` is not declared in the visible part of this
// function (the visible parameter is named `dev_type`); presumably it is
// defined by earlier lines not shown here — confirm.
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
c_identity,
device_type,
paddle::operators::
CIdentityOpKernel<float, paddle::platform::CustomDeviceContext>,
paddle::operators::
CIdentityOpKernel<double, paddle::platform::CustomDeviceContext>,
paddle::operators::
CIdentityOpKernel<int, paddle::platform::CustomDeviceContext>,
paddle::operators::
CIdentityOpKernel<int64_t, paddle::platform::CustomDeviceContext>,
paddle::operators::CIdentityOpKernel<
paddle::platform::float16,
paddle::platform::CustomDeviceContext>) {}

#endif
}

Expand Down

0 comments on commit d131e67

Please sign in to comment.