Skip to content

Commit

Permalink
optimize backward
Browse files Browse the repository at this point in the history
  • Loading branch information
haohongxiang committed Nov 9, 2021
1 parent 2a143f8 commit d9b93c5
Showing 1 changed file with 27 additions and 14 deletions.
41 changes: 27 additions & 14 deletions paddle/fluid/operators/index_select_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,18 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,

int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
int64_t begin_idx = idx + (delta * pre_idx - dim_idx) * stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
}

input_grad[idx] = 0.0;
for (int64_t i = 0; i < nums; i++) {
if (index[i] == dim_idx) {
input_grad[idx] += output_grad[begin_idx + i * stride];
}
template <typename T>
__global__ void index_select_grad_init(T* input_grad, int64_t N) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
input_grad[idx] = 0.0;
}

template <typename DeviceContext, typename T>
Expand Down Expand Up @@ -143,8 +147,8 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = framework::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = input_dim[dim];
int64_t delta = output_dim[dim] - size;
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;

const auto& index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
Expand All @@ -161,17 +165,22 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {

int64_t numel = in_grad->numel();
int64_t index_nums = index->numel();
int64_t out_nums = output_grad->numel();

auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();

index_select_grad_init<
T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_grad_data, numel);

if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
index_select_grad_cuda_kernel<T, int64_t><<<
(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
index_data, index_nums, numel,
stride, size, delta);
index_data, index_nums,
out_nums, stride, size, delta);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
Expand All @@ -180,10 +189,10 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
} else {
const int* index_data = index->data<int>();
index_select_grad_cuda_kernel<T, int><<<
(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
index_data, index_nums, numel,
stride, size, delta);
index_data, index_nums,
out_nums, stride, size, delta);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
Expand All @@ -201,12 +210,16 @@ REGISTER_OP_CUDA_KERNEL(
index_select,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
index_select_grad,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>);

1 comment on commit d9b93c5

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.