Skip to content

Commit

Permalink
optimize backward (#37055)
Browse files Browse the repository at this point in the history
  • Loading branch information
haohongxiang authored Nov 9, 2021
1 parent 7181670 commit aac00f6
Showing 1 changed file with 27 additions and 14 deletions.
41 changes: 27 additions & 14 deletions paddle/fluid/operators/index_select_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,18 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,

int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
int64_t begin_idx = idx + (delta * pre_idx - dim_idx) * stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
}

input_grad[idx] = 0.0;
for (int64_t i = 0; i < nums; i++) {
if (index[i] == dim_idx) {
input_grad[idx] += output_grad[begin_idx + i * stride];
}
template <typename T>
__global__ void index_select_grad_init(T* input_grad, int64_t N) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
input_grad[idx] = 0.0;
}

template <typename DeviceContext, typename T>
Expand Down Expand Up @@ -143,8 +147,8 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = framework::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = input_dim[dim];
int64_t delta = output_dim[dim] - size;
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;

const auto& index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
Expand All @@ -161,17 +165,22 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {

int64_t numel = in_grad->numel();
int64_t index_nums = index->numel();
int64_t out_nums = output_grad->numel();

auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();

index_select_grad_init<
T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_grad_data, numel);

if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
index_select_grad_cuda_kernel<T, int64_t><<<
(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
index_data, index_nums, numel,
stride, size, delta);
index_data, index_nums,
out_nums, stride, size, delta);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
Expand All @@ -180,10 +189,10 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
} else {
const int* index_data = index->data<int>();
index_select_grad_cuda_kernel<T, int><<<
(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
index_data, index_nums, numel,
stride, size, delta);
index_data, index_nums,
out_nums, stride, size, delta);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
Expand All @@ -201,12 +210,16 @@ REGISTER_OP_CUDA_KERNEL(
index_select,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
index_select_grad,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>);

0 comments on commit aac00f6

Please sign in to comment.