[Dygraph] Refactor Model Parallel in eager mode (#41761) (#41960)
* refactor mp in eager mode

* update

* update

* add uts
haohongxiang authored Apr 20, 2022
1 parent ef78c9c commit 5ce7f48
Showing 12 changed files with 305 additions and 45 deletions.
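All of the kernel changes below share one dispatch pattern: in eager mode, the collective op first asks distributed::ProcessGroupMapFromGid whether a ProcessGroup is registered for its ring_id attribute; if one is, the collective goes through the ProcessGroup API (pg->AllReduce, pg->AllGather) and the returned task is waited on, otherwise the kernel falls back to the existing NCCLCommContext path on the global calculation stream. The following is a minimal sketch of that dispatch, distilled from the diffs; the wrapper function AllReduceSumByRingId is illustrative only, while the types and calls are the ones the changed kernels actually use.

// Sketch of the dispatch added in this commit (the helper itself is not part
// of the commit; types and calls mirror the kernels below).
#include <vector>

#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"

namespace paddle {
namespace operators {

// Illustrative helper: all-reduce (SUM) one tensor on ring `rid`, preferring
// the eager-mode ProcessGroup when one is registered for that ring id.
static void AllReduceSumByRingId(int rid, phi::DenseTensor* tensor,
                                 const platform::Place& place) {
  auto map = distributed::ProcessGroupMapFromGid::getInstance();
  if (map->has(rid)) {
    // Eager mode: a ProcessGroup is bound to this ring_id, so go through it.
    distributed::ProcessGroup* pg = map->get(rid);
    std::vector<phi::DenseTensor> in_tensor;
    std::vector<phi::DenseTensor> out_tensor;
    in_tensor.push_back(*tensor);
    out_tensor.push_back(*tensor);

    distributed::AllreduceOptions opts;
    opts.reduce_op = distributed::ReduceOp::SUM;
    auto task = pg->AllReduce(in_tensor, out_tensor, opts);
    task->Wait();
  } else {
    // Legacy path: raw NCCL communicator on the global calculation stream.
    const auto& comm = platform::NCCLCommContext::Instance().Get(rid, place);
    const auto stream = static_cast<platform::CUDADeviceContext*>(
                            platform::DeviceContextPool::Instance().Get(place))
                            ->stream();
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
        tensor->data(), tensor->data(), tensor->numel(),
        platform::ToNCCLDataType(
            framework::TransToProtoVarType(tensor->dtype())),
        ncclSum, comm->comm(), stream));
  }
}

}  // namespace operators
}  // namespace paddle

c_allreduce_op.h additionally maps its kRedSum/kRedMax/kRedMin/kRedProd reduce type onto distributed::ReduceOp before issuing the ProcessGroup all-reduce, as shown in its diff.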
43 changes: 30 additions & 13 deletions paddle/fluid/operators/class_center_sample_op.cu
@@ -27,8 +27,10 @@ namespace cub = hipcub;
#include <iterator>
#include <random>
#include "paddle/fluid/operators/class_center_sample_op.h"
#include "paddle/phi/api/include/tensor.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
@@ -328,19 +330,34 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (nranks > 1) {
const auto& comm =
platform::NCCLCommContext::Instance().Get(rid, ctx.GetPlace());
// use global calculate stream
const auto calcu_stream =
static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(ctx.GetPlace()))
->stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
num_classes_per_device_ptr, num_classes_per_device_ptr,
num_classes_per_device.numel(),
platform::ToNCCLDataType(
framework::TransToProtoVarType(num_classes_per_device.dtype())),
ncclSum, comm->comm(), calcu_stream));
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(num_classes_per_device);
out_tensor.push_back(num_classes_per_device);

distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg->AllReduce(in_tensor, out_tensor, opts);
task->Wait();
} else {
const auto& comm =
platform::NCCLCommContext::Instance().Get(rid, ctx.GetPlace());
// use global calculate stream
const auto calcu_stream =
static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(ctx.GetPlace()))
->stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
num_classes_per_device_ptr, num_classes_per_device_ptr,
num_classes_per_device.numel(),
platform::ToNCCLDataType(
framework::TransToProtoVarType(num_classes_per_device.dtype())),
ncclSum, comm->comm(), calcu_stream));
}
}
#endif

41 changes: 40 additions & 1 deletion paddle/fluid/operators/collective/c_allreduce_op.h
@@ -16,12 +16,14 @@ limitations under the License. */

#include <string>

#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/api/include/tensor.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_XPU_BKCL) || \
@@ -351,6 +353,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
int rid = ctx.Attr<int>("ring_id");

auto place = ctx.GetPlace();
ncclDataType_t dtype =
@@ -360,7 +363,43 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
out->Resize(in->dims());
void* recvbuff = out->mutable_data<T>(place);

int rid = ctx.Attr<int>("ring_id");
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(*in);
out_tensor.push_back(*out);

distributed::AllreduceOptions opts;
switch (red_type) {
case kRedSum:
opts.reduce_op = distributed::ReduceOp::SUM;
break;

case kRedMax:
opts.reduce_op = distributed::ReduceOp::MAX;
break;

case kRedMin:
opts.reduce_op = distributed::ReduceOp::MIN;
break;

case kRedProd:
opts.reduce_op = distributed::ReduceOp::PRODUCT;
break;

default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid reduce type: %d", red_type));
}

auto task = pg->AllReduce(in_tensor, out_tensor, opts);
task->Wait();
return;
}

auto comm = platform::NCCLCommContext::Instance().Get(rid, place);

gpuStream_t stream = nullptr;
47 changes: 31 additions & 16 deletions paddle/fluid/operators/collective/c_concat_op.cu.cc
@@ -16,8 +16,10 @@ limitations under the License. */

#include "paddle/fluid/operators/collective/c_concat_op.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/phi/api/include/tensor.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
@@ -55,26 +57,39 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
rank, nranks));

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(
nranks, comm->nranks(),
platform::errors::InvalidArgument("nranks: %s should equal to %s",
nranks, comm->nranks()));

framework::Tensor temp_out;
framework::DDim temp_out_dims = x->dims();
temp_out_dims[0] *= nranks;
temp_out.mutable_data<T>(temp_out_dims, place);
int64_t send_numel = x->numel();
const T* send_buff = x->data<T>();
T* recv_buff = temp_out.data<T>();
gpuStream_t stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();

PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
comm->comm(), stream));

auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(*x);
out_tensor.push_back(temp_out);
auto task = pg->AllGather(in_tensor, out_tensor);
task->Wait();
} else {
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(
nranks, comm->nranks(),
platform::errors::InvalidArgument("nranks: %s should equal to %s",
nranks, comm->nranks()));

int64_t send_numel = x->numel();
const T* send_buff = x->data<T>();
T* recv_buff = temp_out.data<T>();
gpuStream_t stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();

PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
comm->comm(), stream));
}

std::vector<framework::Tensor> inputs;
int axis = x->dims().size() - 1;
139 changes: 139 additions & 0 deletions paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu
@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"

namespace paddle {
@@ -73,6 +74,21 @@ template <typename T>
class CSoftmaxWithCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const int rid = ctx.Attr<int>("ring_id");
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> functor_;
functor_(ctx);
} else {
CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> functor_;
functor_(ctx);
}
}
};

template <typename T>
struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
void operator()(const framework::ExecutionContext& ctx) {
const Tensor* logits = ctx.Input<Tensor>("Logits");
const Tensor* labels = ctx.Input<Tensor>("Label");
Tensor* softmax = ctx.Output<Tensor>("Softmax");
@@ -201,6 +217,129 @@ class CSoftmaxWithCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
}
};

template <typename T>
struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
void operator()(const framework::ExecutionContext& ctx) {
const Tensor* logits = ctx.Input<Tensor>("Logits");
const Tensor* labels = ctx.Input<Tensor>("Label");
Tensor* softmax = ctx.Output<Tensor>("Softmax");
Tensor* loss = ctx.Output<Tensor>("Loss");

const int rid = ctx.Attr<int>("ring_id");
const int nranks = ctx.Attr<int>("nranks");
const int rank = ctx.Attr<int>("rank");

const auto& place = ctx.GetPlace();
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

auto map = distributed::ProcessGroupMapFromGid::getInstance();
distributed::ProcessGroup* pg = map->get(rid);
distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;

// allocate memory on device.
softmax->mutable_data<T>(place);
loss->mutable_data<T>(place);

const auto& logits_dims = logits->dims();
const auto& labels_dims = labels->dims();

const int axis = logits_dims.size() - 1;
const int N = phi::funcs::SizeToAxis(axis, logits_dims);
const int D = phi::funcs::SizeFromAxis(axis, logits_dims);

Tensor logits_2d, softmax_2d, loss_2d;
logits_2d.ShareDataWith(*logits).Resize({N, D});
softmax_2d.ShareDataWith(*softmax).Resize({N, D});
loss_2d.ShareDataWith(*loss).Resize({N, 1});

auto eigen_logits = math::EigenMatrix<T>::From(logits_2d);
auto eigen_softmax = math::EigenMatrix<T>::From(softmax_2d);

// step 1, obtain logit_max
Tensor logits_max;
logits_max =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);

auto eigen_logits_max = math::EigenMatrix<T>::From(logits_max);
Eigen::DSizes<int, 1> along_axis(1);
eigen_logits_max.device(*dev_ctx.eigen_device()) =
eigen_logits.maximum(along_axis);

std::vector<phi::DenseTensor> in_out;
in_out.push_back(logits_max);
pg->AllReduce(in_out, in_out, opts)->Synchronize();

// step 2, obtain logit - logit_max
Eigen::DSizes<int, 2> batch_by_one(N, 1);
Eigen::DSizes<int, 2> one_by_class(1, D);

eigen_softmax.device(*dev_ctx.eigen_device()) =
(eigen_logits -
eigen_logits_max.reshape(batch_by_one).broadcast(one_by_class))
.unaryExpr(math::ValueClip<T>());

// step 3, obtain predict target
Tensor predicted_logits;
predicted_logits =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
predicted_logits.mutable_data<T>(place);

auto t = framework::EigenVector<T>::Flatten(predicted_logits);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));

const int start_index = rank * D;
const int end_index = start_index + D;

int blocks = NumBlocks(N);
int threads = kNumCUDAThreads;
const auto& label_type = framework::TransToProtoVarType(labels->dtype());

if (label_type == framework::proto::VarType::INT32) {
MaskLabelByIndex<T, int32_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
predicted_logits.data<T>(), softmax_2d.data<T>(),
labels->data<int32_t>(), start_index, end_index, N, D, nranks);
} else if (label_type == framework::proto::VarType::INT64) {
MaskLabelByIndex<T, int64_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
predicted_logits.data<T>(), softmax_2d.data<T>(),
labels->data<int64_t>(), start_index, end_index, N, D, nranks);
}

in_out.clear();
in_out.push_back(predicted_logits);
pg->AllReduce(in_out, in_out, opts)->Synchronize();

// step 4, obtain exp(logit)
eigen_softmax.device(*dev_ctx.eigen_device()) = eigen_softmax.exp();

// step 5, obtain sum_exp_logits
Tensor sum_exp_logits;
sum_exp_logits =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
void* sum_exp_logits_buff = sum_exp_logits.mutable_data<T>(place);

auto eigen_sum_exp_logits = math::EigenMatrix<T>::From(sum_exp_logits);
eigen_sum_exp_logits.device(*dev_ctx.eigen_device()) =
eigen_softmax.sum(along_axis);

in_out.clear();
in_out.push_back(sum_exp_logits);
pg->AllReduce(in_out, in_out, opts)->Synchronize();

auto eigen_loss = math::EigenMatrix<T>::From(loss_2d);
auto eigen_predicted_logits = math::EigenMatrix<T>::From(predicted_logits);

eigen_loss.device(*dev_ctx.eigen_device()) =
(eigen_sum_exp_logits.log().unaryExpr(math::TolerableValue<T>()) -
eigen_predicted_logits)
.unaryExpr(math::TolerableValue<T>());

eigen_softmax.device(*dev_ctx.eigen_device()) =
(eigen_softmax *
eigen_sum_exp_logits.inverse().broadcast(one_by_class));
}
};

template <typename T>
class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
public:
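For reference, the ProcessGroup functor above computes the usual model-parallel softmax cross entropy. Each rank holds a D-column shard of the logical logits x_{ij}; writing y_i for the label of row i and c_i for the per-row shift accumulated across ranks in step 1 (the result is invariant to the choice of c_i), the computed quantities are, in summary (my notation, not text from the commit):

\ell_i = \log\!\Big(\sum_{r=0}^{\mathrm{nranks}-1}\ \sum_{j \in \mathrm{shard}(r)} e^{\,x_{ij}-c_i}\Big) - \big(x_{i,y_i}-c_i\big),
\qquad
\mathrm{softmax}_{ij} = \frac{e^{\,x_{ij}-c_i}}{\sum_{r}\sum_{j' \in \mathrm{shard}(r)} e^{\,x_{ij'}-c_i}}

The inner sums over shard(r) are computed locally; the three pg->AllReduce calls combine, respectively, the per-rank row maxima (giving c_i), the masked target logit x_{i,y_i} (non-zero only on the rank whose index range [rank*D, (rank+1)*D) covers y_i, produced by MaskLabelByIndex), and the per-rank partial sums of exponentials.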
paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h
@@ -18,11 +18,13 @@ limitations under the License. */
#include <utility>
#include <vector>

#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/phi/api/include/tensor.h"

namespace paddle {
namespace operators {
@@ -36,5 +38,15 @@ class CSoftmaxWithCrossEntropyOpCPUKernel : public framework::OpKernel<T> {
}
};

template <typename Context, typename T>
struct CSoftmaxWithCrossEntropyFunctor {
void operator()(const framework::ExecutionContext& ctx);
};

template <typename Context, typename T>
struct CSoftmaxWithCrossEntropyProcessGroupFunctor {
void operator()(const framework::ExecutionContext& ctx);
};

} // namespace operators
} // namespace paddle