[cherry-pick] Refactor Model Parallel in eager dygraph mode (#41761) #41960

Merged: 3 commits, Apr 20, 2022
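Every collective kernel touched below picks up the same dispatch: look up the op's ring_id in distributed::ProcessGroupMapFromGid; if an eager-mode ProcessGroup is registered for that id, run the collective through the ProcessGroup and wait on the returned task, otherwise fall back to the existing NCCLCommContext path on the global calculate stream (c_softmax_with_cross_entropy routes to a separate ProcessGroup functor instead of an inline else branch). A minimal sketch of the recurring pattern, condensed from the diffs below (here in and out stand in for the kernel's tensors and are not real variable names):

    auto map = distributed::ProcessGroupMapFromGid::getInstance();
    if (map->has(rid)) {
      // eager dygraph mode: issue the collective through the ProcessGroup
      distributed::ProcessGroup* pg = map->get(rid);
      std::vector<phi::DenseTensor> in_tensor{in};
      std::vector<phi::DenseTensor> out_tensor{out};
      distributed::AllreduceOptions opts;
      opts.reduce_op = distributed::ReduceOp::SUM;
      pg->AllReduce(in_tensor, out_tensor, opts)->Wait();
    } else {
      // legacy path: NCCLCommContext comm plus a collective on the calculate stream
    }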
43 changes: 30 additions & 13 deletions paddle/fluid/operators/class_center_sample_op.cu
@@ -27,8 +27,10 @@ namespace cub = hipcub;
 #include <iterator>
 #include <random>
 #include "paddle/fluid/operators/class_center_sample_op.h"
+#include "paddle/phi/api/include/tensor.h"
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/distributed/collective/ProcessGroup.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -328,19 +330,34 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     if (nranks > 1) {
-      const auto& comm =
-          platform::NCCLCommContext::Instance().Get(rid, ctx.GetPlace());
-      // use global calculate stream
-      const auto calcu_stream =
-          static_cast<platform::CUDADeviceContext*>(
-              platform::DeviceContextPool::Instance().Get(ctx.GetPlace()))
-              ->stream();
-      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
-          num_classes_per_device_ptr, num_classes_per_device_ptr,
-          num_classes_per_device.numel(),
-          platform::ToNCCLDataType(
-              framework::TransToProtoVarType(num_classes_per_device.dtype())),
-          ncclSum, comm->comm(), calcu_stream));
+      auto map = distributed::ProcessGroupMapFromGid::getInstance();
+      if (map->has(rid)) {
+        // Use ProcessGroup
+        distributed::ProcessGroup* pg = map->get(rid);
+        std::vector<phi::DenseTensor> in_tensor;
+        std::vector<phi::DenseTensor> out_tensor;
+        in_tensor.push_back(num_classes_per_device);
+        out_tensor.push_back(num_classes_per_device);
+
+        distributed::AllreduceOptions opts;
+        opts.reduce_op = distributed::ReduceOp::SUM;
+        auto task = pg->AllReduce(in_tensor, out_tensor, opts);
+        task->Wait();
+      } else {
+        const auto& comm =
+            platform::NCCLCommContext::Instance().Get(rid, ctx.GetPlace());
+        // use global calculate stream
+        const auto calcu_stream =
+            static_cast<platform::CUDADeviceContext*>(
+                platform::DeviceContextPool::Instance().Get(ctx.GetPlace()))
+                ->stream();
+        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
+            num_classes_per_device_ptr, num_classes_per_device_ptr,
+            num_classes_per_device.numel(),
+            platform::ToNCCLDataType(
+                framework::TransToProtoVarType(num_classes_per_device.dtype())),
+            ncclSum, comm->comm(), calcu_stream));
+      }
     }
 #endif
 
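Both branches above perform the same element-wise SUM all-reduce on num_classes_per_device, so every rank ends up holding the sum of all ranks' vectors. A small worked example (the offset interpretation is my reading of the op, not stated in the diff): with two ranks contributing [0, 3, 0] and [0, 0, 4], each rank afterwards holds [0, 3, 4] and can derive every device's class-count offsets from the shared vector.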
41 changes: 40 additions & 1 deletion paddle/fluid/operators/collective/c_allreduce_op.h
@@ -16,12 +16,14 @@ limitations under the License. */
 
 #include <string>
 
+#include "paddle/fluid/distributed/collective/ProcessGroup.h"
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
+#include "paddle/phi/api/include/tensor.h"
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
     defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_XPU_BKCL) || \
@@ -351,6 +353,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     auto in = ctx.Input<framework::Tensor>("X");
     auto out = ctx.Output<framework::Tensor>("Out");
+    int rid = ctx.Attr<int>("ring_id");
 
     auto place = ctx.GetPlace();
     ncclDataType_t dtype =
@@ -360,7 +363,43 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
     out->Resize(in->dims());
     void* recvbuff = out->mutable_data<T>(place);
 
-    int rid = ctx.Attr<int>("ring_id");
+    auto map = distributed::ProcessGroupMapFromGid::getInstance();
+    if (map->has(rid)) {
+      // Use ProcessGroup
+      distributed::ProcessGroup* pg = map->get(rid);
+      std::vector<phi::DenseTensor> in_tensor;
+      std::vector<phi::DenseTensor> out_tensor;
+      in_tensor.push_back(*in);
+      out_tensor.push_back(*out);
+
+      distributed::AllreduceOptions opts;
+      switch (red_type) {
+        case kRedSum:
+          opts.reduce_op = distributed::ReduceOp::SUM;
+          break;
+
+        case kRedMax:
+          opts.reduce_op = distributed::ReduceOp::MAX;
+          break;
+
+        case kRedMin:
+          opts.reduce_op = distributed::ReduceOp::MIN;
+          break;
+
+        case kRedProd:
+          opts.reduce_op = distributed::ReduceOp::PRODUCT;
+          break;
+
+        default:
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "Invalid reduce type: %d", red_type));
+      }
+
+      auto task = pg->AllReduce(in_tensor, out_tensor, opts);
+      task->Wait();
+      return;
+    }
+
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
 
     gpuStream_t stream = nullptr;
47 changes: 31 additions & 16 deletions paddle/fluid/operators/collective/c_concat_op.cu.cc
@@ -16,8 +16,10 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/collective/c_concat_op.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
+#include "paddle/phi/api/include/tensor.h"
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/distributed/collective/ProcessGroup.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -55,26 +57,39 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
                           rank, nranks));
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
-    PADDLE_ENFORCE_EQ(
-        nranks, comm->nranks(),
-        platform::errors::InvalidArgument("nranks: %s should equal to %s",
-                                          nranks, comm->nranks()));
-
     framework::Tensor temp_out;
     framework::DDim temp_out_dims = x->dims();
     temp_out_dims[0] *= nranks;
     temp_out.mutable_data<T>(temp_out_dims, place);
-    int64_t send_numel = x->numel();
-    const T* send_buff = x->data<T>();
-    T* recv_buff = temp_out.data<T>();
-    gpuStream_t stream = nullptr;
-    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-    stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
-
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
-        send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
-        comm->comm(), stream));
+
+    auto map = distributed::ProcessGroupMapFromGid::getInstance();
+    if (map->has(rid)) {
+      // Use ProcessGroup
+      distributed::ProcessGroup* pg = map->get(rid);
+      std::vector<phi::DenseTensor> in_tensor;
+      std::vector<phi::DenseTensor> out_tensor;
+      in_tensor.push_back(*x);
+      out_tensor.push_back(temp_out);
+      auto task = pg->AllGather(in_tensor, out_tensor);
+      task->Wait();
+    } else {
+      auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
+      PADDLE_ENFORCE_EQ(
+          nranks, comm->nranks(),
+          platform::errors::InvalidArgument("nranks: %s should equal to %s",
+                                            nranks, comm->nranks()));
+
+      int64_t send_numel = x->numel();
+      const T* send_buff = x->data<T>();
+      T* recv_buff = temp_out.data<T>();
+      gpuStream_t stream = nullptr;
+      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+
+      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
+          send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
+          comm->comm(), stream));
+    }
 
     std::vector<framework::Tensor> inputs;
     int axis = x->dims().size() - 1;
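Shape bookkeeping for the kernel, as I read it (the final concat happens in the code collapsed above): each rank holds x of shape [N, D]; the all-gather (ProcessGroup or ncclAllGather) fills temp_out of shape [nranks * N, D] with every rank's shard; splitting temp_out back into nranks blocks and concatenating them along the last axis produces the output of shape [N, nranks * D]:

$X_r \in \mathbb{R}^{N \times D} \;\Rightarrow\; \text{out} = [\,X_0 \mid X_1 \mid \dots \mid X_{nranks-1}\,] \in \mathbb{R}^{N \times (nranks \cdot D)}$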
139 changes: 139 additions & 0 deletions paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #include "paddle/fluid/string/string_helper.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
 
 namespace paddle {
@@ -73,6 +74,21 @@ template <typename T>
 class CSoftmaxWithCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    const int rid = ctx.Attr<int>("ring_id");
+    auto map = distributed::ProcessGroupMapFromGid::getInstance();
+    if (map->has(rid)) {
+      CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> functor_;
+      functor_(ctx);
+    } else {
+      CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> functor_;
+      functor_(ctx);
+    }
+  }
+};
+
+template <typename T>
+struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
+  void operator()(const framework::ExecutionContext& ctx) {
     const Tensor* logits = ctx.Input<Tensor>("Logits");
     const Tensor* labels = ctx.Input<Tensor>("Label");
     Tensor* softmax = ctx.Output<Tensor>("Softmax");
Expand Down Expand Up @@ -201,6 +217,129 @@ class CSoftmaxWithCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
}
};

template <typename T>
struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
void operator()(const framework::ExecutionContext& ctx) {
const Tensor* logits = ctx.Input<Tensor>("Logits");
const Tensor* labels = ctx.Input<Tensor>("Label");
Tensor* softmax = ctx.Output<Tensor>("Softmax");
Tensor* loss = ctx.Output<Tensor>("Loss");

const int rid = ctx.Attr<int>("ring_id");
const int nranks = ctx.Attr<int>("nranks");
const int rank = ctx.Attr<int>("rank");

const auto& place = ctx.GetPlace();
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

auto map = distributed::ProcessGroupMapFromGid::getInstance();
distributed::ProcessGroup* pg = map->get(rid);
distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;

// allocate memory on device.
softmax->mutable_data<T>(place);
loss->mutable_data<T>(place);

const auto& logits_dims = logits->dims();
const auto& labels_dims = labels->dims();

const int axis = logits_dims.size() - 1;
const int N = phi::funcs::SizeToAxis(axis, logits_dims);
const int D = phi::funcs::SizeFromAxis(axis, logits_dims);

Tensor logits_2d, softmax_2d, loss_2d;
logits_2d.ShareDataWith(*logits).Resize({N, D});
softmax_2d.ShareDataWith(*softmax).Resize({N, D});
loss_2d.ShareDataWith(*loss).Resize({N, 1});

auto eigen_logits = math::EigenMatrix<T>::From(logits_2d);
auto eigen_softmax = math::EigenMatrix<T>::From(softmax_2d);

// step 1, obtain logit_max
Tensor logits_max;
logits_max =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);

auto eigen_logits_max = math::EigenMatrix<T>::From(logits_max);
Eigen::DSizes<int, 1> along_axis(1);
eigen_logits_max.device(*dev_ctx.eigen_device()) =
eigen_logits.maximum(along_axis);

std::vector<phi::DenseTensor> in_out;
in_out.push_back(logits_max);
pg->AllReduce(in_out, in_out, opts)->Synchronize();

// step 2, obtain logit - logit_max
Eigen::DSizes<int, 2> batch_by_one(N, 1);
Eigen::DSizes<int, 2> one_by_class(1, D);

eigen_softmax.device(*dev_ctx.eigen_device()) =
(eigen_logits -
eigen_logits_max.reshape(batch_by_one).broadcast(one_by_class))
.unaryExpr(math::ValueClip<T>());

// step 3, obtain predict target
Tensor predicted_logits;
predicted_logits =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
predicted_logits.mutable_data<T>(place);

auto t = framework::EigenVector<T>::Flatten(predicted_logits);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));

const int start_index = rank * D;
const int end_index = start_index + D;

int blocks = NumBlocks(N);
int threads = kNumCUDAThreads;
const auto& label_type = framework::TransToProtoVarType(labels->dtype());

if (label_type == framework::proto::VarType::INT32) {
MaskLabelByIndex<T, int32_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
predicted_logits.data<T>(), softmax_2d.data<T>(),
labels->data<int32_t>(), start_index, end_index, N, D, nranks);
} else if (label_type == framework::proto::VarType::INT64) {
MaskLabelByIndex<T, int64_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
predicted_logits.data<T>(), softmax_2d.data<T>(),
labels->data<int64_t>(), start_index, end_index, N, D, nranks);
}

in_out.clear();
in_out.push_back(predicted_logits);
pg->AllReduce(in_out, in_out, opts)->Synchronize();

// step 4, obtain exp(logit)
eigen_softmax.device(*dev_ctx.eigen_device()) = eigen_softmax.exp();

// step 5, obtain sum_exp_logits
Tensor sum_exp_logits;
sum_exp_logits =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
void* sum_exp_logits_buff = sum_exp_logits.mutable_data<T>(place);

auto eigen_sum_exp_logits = math::EigenMatrix<T>::From(sum_exp_logits);
eigen_sum_exp_logits.device(*dev_ctx.eigen_device()) =
eigen_softmax.sum(along_axis);

in_out.clear();
in_out.push_back(sum_exp_logits);
pg->AllReduce(in_out, in_out, opts)->Synchronize();

auto eigen_loss = math::EigenMatrix<T>::From(loss_2d);
auto eigen_predicted_logits = math::EigenMatrix<T>::From(predicted_logits);

eigen_loss.device(*dev_ctx.eigen_device()) =
(eigen_sum_exp_logits.log().unaryExpr(math::TolerableValue<T>()) -
eigen_predicted_logits)
.unaryExpr(math::TolerableValue<T>());

eigen_softmax.device(*dev_ctx.eigen_device()) =
(eigen_softmax *
eigen_sum_exp_logits.inverse().broadcast(one_by_class));
}
};

template <typename T>
class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
public:
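In both functors the class dimension of the logits is sharded across the group, so each rank holds an [N, D] slice of the full [N, nranks * D] logits; steps 1 to 5 then compute a numerically stable softmax and cross entropy with three all-reduces. Written out as I read the added code, with m_i the all-reduced per-rank row maxima (any per-row shift cancels in the loss and only serves numerical stability):

$$
\begin{aligned}
Z_i &= \operatorname{allreduce\_sum}\Big(\sum_{j \in \text{local shard}} e^{x_{ij}-m_i}\Big), \\
\text{softmax}_{ij} &= \frac{e^{x_{ij}-m_i}}{Z_i}, \\
\text{loss}_i &= \log Z_i - \operatorname{allreduce\_sum}\big((x_{i,y_i}-m_i)\,\mathbb{1}[y_i \in \text{local shard}]\big).
\end{aligned}
$$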
paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h
@@ -18,11 +18,13 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 
+#include "paddle/fluid/distributed/collective/ProcessGroup.h"
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/math/softmax.h"
+#include "paddle/phi/api/include/tensor.h"
 
 namespace paddle {
 namespace operators {
@@ -36,5 +38,15 @@ class CSoftmaxWithCrossEntropyOpCPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename Context, typename T>
+struct CSoftmaxWithCrossEntropyFunctor {
+  void operator()(const framework::ExecutionContext& ctx);
+};
+
+template <typename Context, typename T>
+struct CSoftmaxWithCrossEntropyProcessGroupFunctor {
+  void operator()(const framework::ExecutionContext& ctx);
+};
+
 }  // namespace operators
 }  // namespace paddle
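The header only declares the primary functor templates; the CUDA file above supplies the phi::GPUContext specializations, so the kernel's Compute picks the right body at compile time through ordinary template specialization. A minimal, self-contained illustration of that pattern (toy types, not Paddle code):

    #include <iostream>

    struct CPUContext {};
    struct GPUContext {};

    // primary template: declared in the "header", defined per backend elsewhere
    template <typename Context, typename T>
    struct SoftmaxFunctor;

    // backend-specific specialization, analogous to the phi::GPUContext one above
    template <typename T>
    struct SoftmaxFunctor<GPUContext, T> {
      void operator()() const { std::cout << "GPU path\n"; }
    };

    int main() {
      SoftmaxFunctor<GPUContext, float> functor;
      functor();  // selects the GPUContext specialization
      return 0;
    }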