Skip to content

Commit

Permalink
[cherry-pick] Support FP16 for index_select op (#38751)
Browse files Browse the repository at this point in the history
* optimize backward (#37055)

* update

* update

* update

* modify code style
  • Loading branch information
haohongxiang authored Jan 6, 2022
1 parent 327f635 commit 0d081cb
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions paddle/fluid/operators/index_select_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,16 @@ REGISTER_OP_CUDA_KERNEL(
index_select,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
index_select_grad,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>);

0 comments on commit 0d081cb

Please sign in to comment.