Optimize performance of softmax_bwd when axis!=-1 (#38609)
* Optimize performance of softmax_bwd when axis!=-1
* fix
* fix
* fix
* fix
ZzSean authored Feb 11, 2022
1 parent a117497 commit 2ea15fc
Showing 1 changed file with 62 additions and 0 deletions.
paddle/fluid/operators/softmax_cudnn_op.cu.h (62 additions, 0 deletions)

@@ -584,6 +584,43 @@ __global__ void NormalSoftmaxForward(T* output, const T* input, int high_dim,
  }
}

template <typename T, typename AccT,
          template <typename, typename> class Functor>
__global__ void NormalSoftmaxBackward(T* input_grad, const T* output_grad,
                                      const T* output, int high_dim,
                                      int mid_dim, int low_dim) {
  using kMode = kps::details::ReduceMode;
  const int high_stride = mid_dim * low_dim;
  const int mid_stride = low_dim;
  for (int high_id = blockIdx.y; high_id < high_dim; high_id += gridDim.y) {
    for (int low_id = blockIdx.x * blockDim.x + threadIdx.x; low_id < low_dim;
         low_id += blockDim.x * gridDim.x) {
      const int grad_offset = high_id * high_stride + low_id;

      // 1. reduce sum
      AccT sum = 0;
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        int data_offset = grad_offset + mid_id * mid_stride;
        sum += static_cast<AccT>(output_grad[data_offset]) *
               static_cast<AccT>(output[data_offset]);
      }
      if (blockDim.y > 1) {
        kps::Reduce<AccT, 1, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
            &sum, &sum, kps::AddFunctor<AccT>(), false);
      }

      // 2. (log)softmax backward
      Functor<AccT, T> functor(sum);
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        int data_offset = grad_offset + mid_id * mid_stride;
        input_grad[data_offset] =
            functor(static_cast<AccT>(output_grad[data_offset]),
                    static_cast<AccT>(output[data_offset]));
      }
    }
  }
}
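
// Exposition (not part of the original commit): for plain softmax the
// identity dx_i = y_i * (dy_i - sum_j dy_j * y_j) is what makes the two-pass
// scheme above work -- one reduction of dy * y over the softmax axis
// (mid_dim), then one elementwise functor pass. Worked 1-D example:
//   y  = {0.1, 0.2, 0.7},  dy = {1.0, 0.0, 0.0}
//   sum = 1.0 * 0.1 + 0.0 * 0.2 + 0.0 * 0.7 = 0.1
//   dx = {0.1 * (1.0 - 0.1), 0.2 * (0.0 - 0.1), 0.7 * (0.0 - 0.1)}
//      = {0.09, -0.02, -0.07}   // entries sum to 0, as a softmax grad must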

template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxForward(const platform::CUDADeviceContext& dev_ctx,
                                T* output_data, const T* input_data,

@@ -603,6 +640,28 @@ void LaunchNormalSoftmaxForward(const platform::CUDADeviceContext& dev_ctx,
  }
}

template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxBackward(const platform::CUDADeviceContext& dev_ctx,
                                 T* input_grad_data, const T* output_grad_data,
                                 const T* output_data, int high_dim,
                                 int mid_dim, int low_dim) {
  using AccT = typename details::MPTypeTrait<T>::Type;
  dim3 grid, block;
  GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
  if (LogMode) {
    NormalSoftmaxBackward<
        T, AccT,
        LogSoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
        input_grad_data, output_grad_data, output_data, high_dim, mid_dim,
        low_dim);
  } else {
    NormalSoftmaxBackward<
        T, AccT, SoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
        input_grad_data, output_grad_data, output_data, high_dim, mid_dim,
        low_dim);
  }
}
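
// Exposition (not part of the original commit): GetLaunchConfig's body is
// outside this diff. Below is a minimal sketch of a configuration consistent
// with the kernel's thread mapping -- blockIdx.y strides high_dim,
// blockDim.x * gridDim.x covers low_dim, and threadIdx.y cooperates over
// mid_dim. The heuristics and the name GetLaunchConfigSketch are assumptions:
static inline void GetLaunchConfigSketch(int high_dim, int mid_dim,
                                         int low_dim, dim3* grid,
                                         dim3* block) {
  int block_x = low_dim < 32 ? low_dim : 32;   // coalesced walk over low_dim
  int block_y = mid_dim < 32 ? mid_dim : 32;   // cooperating reducers
  *block = dim3(block_x, block_y, 1);
  int grid_x = (low_dim + block_x - 1) / block_x;  // tile remaining low_dim
  int grid_y = high_dim < 65535 ? high_dim : 65535;  // gridDim.y hw limit
  *grid = dim3(grid_x, grid_y, 1);
}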

template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                                    const Tensor& x, const int input_axis,

@@ -741,6 +800,9 @@ void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
          blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N,
          dim, dim, kDimLog2);
    }
  } else if (D > 1) {
    LaunchNormalSoftmaxBackward<T, LogMode>(dev_ctx, dx_data, dout.data<T>(),
                                            out.data<T>(), N, dim, D);
  } else {
    ScopedTensorDescriptor desc;
    std::vector<int> tensor_dims = {N, dim, D, 1};
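
For context, N, dim and D follow the usual decomposition of the input shape
around the softmax axis (computed above the shown hunk, outside this diff):
N is the product of the leading dimensions, dim the size of the softmax axis,
and D the product of the trailing ones, so the new LaunchNormalSoftmaxBackward
path fires exactly when the softmax axis is not the innermost one (D > 1).
A hypothetical helper illustrating that convention; the name and signature
are assumptions, not part of this diff:

  // Illustrative only: mirrors the usual SizeToAxis / SizeOutAxis convention.
  static inline void DecomposeSoftmaxDims(const int* dims, int rank, int axis,
                                          int* N, int* dim, int* D) {
    *N = 1;
    for (int i = 0; i < axis; ++i) *N *= dims[i];  // batch above the axis
    *dim = dims[axis];                             // softmax axis length
    *D = 1;
    for (int i = axis + 1; i < rank; ++i) *D *= dims[i];  // below the axis
  }

  // e.g. shape {8, 16, 32}, axis = 1  ->  N = 8, dim = 16, D = 32 (> 1),
  // so the NormalSoftmaxBackward path is taken; with axis = -1
  // (canonicalized to the last dimension), D == 1 and the warp/cuDNN
  // paths apply instead.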