diff --git a/src/operator/spatial_transformer.cu b/src/operator/spatial_transformer.cu
index 33dbe3e7c069..fd330bd4ca87 100644
--- a/src/operator/spatial_transformer.cu
+++ b/src/operator/spatial_transformer.cu
@@ -35,12 +35,23 @@ template<typename DType>
 __device__ bool between(DType value, int lowerBound, int upperBound) {
   return (value >= lowerBound && value <= upperBound);
 }
+
 template<typename DType>
-__global__ void BilinearSamplingForwardKernel(const int i_c, const int i_h,
-                                              const int i_w, const DType* data,
-                                              const DType* grid, const int o_n,
-                                              const int o_c, const int o_h,
-                                              const int o_w, DType* out) {
+__global__ void
+/*
+ * In order to not generate the code that uses too many
+ * registers (resulting in too many resources requested
+ * error) we need to tell the compiler that we will be
+ * launching this kernel with cuda::kMaxThreadsPerBlock
+ * threads per block. Setting __launch_bounds__ ensures
+ * that such configuration can always be launched.
+ */
+__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
+BilinearSamplingForwardKernel(const int i_c, const int i_h,
+                              const int i_w, const DType* data,
+                              const DType* grid, const int o_n,
+                              const int o_c, const int o_h,
+                              const int o_w, DType* out) {
   for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
        index < o_n * o_c * o_h * o_w;
        index += blockDim.x * gridDim.x * gridDim.y) {
@@ -77,13 +88,23 @@ __global__ void BilinearSamplingForwardKernel(const int i_c, const int i_h,
   }
 }
 
+/*
+ * In order to not generate the code that uses too many
+ * registers (resulting in too many resources requested
+ * error) we need to tell the compiler that we will be
+ * launching this kernel with cuda::kMaxThreadsPerBlock
+ * threads per block. Setting __launch_bounds__ ensures
+ * that such configuration can always be launched.
+ */
 template<typename DType>
-__global__ void BilinearSamplingBackwardKernel(const int i_c, const int i_h,
-                                               const int i_w, const DType* grad,
-                                               const DType* data, const int o_n,
-                                               const int o_c, const int o_h,
-                                               const int o_w, DType* g_input,
-                                               DType* grid_src) {
+__global__ void
+__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
+BilinearSamplingBackwardKernel(const int i_c, const int i_h,
+                               const int i_w, const DType* grad,
+                               const DType* data, const int o_n,
+                               const int o_c, const int o_h,
+                               const int o_w, DType* g_input,
+                               DType* grid_src) {
   for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
        index < o_n * o_h * o_w;
        index += blockDim.x * gridDim.x * gridDim.y) {
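
The comment added in the patch explains the idea: __launch_bounds__(maxThreadsPerBlock, minBlocksPerMultiprocessor) promises the compiler a maximum block size, so it limits per-thread register usage enough that a launch with that many threads never fails with "too many resources requested for launch". Below is a minimal standalone sketch of the same pattern, not code from this patch: the kernel name, the scaling example, and the local kMaxThreadsPerBlock constant (assumed 1024 here, standing in for the cuda::kMaxThreadsPerBlock constant that the MXNet/mshadow headers provide) are all hypothetical.

#include <cuda_runtime.h>

// Stand-in for cuda::kMaxThreadsPerBlock from the MXNet/mshadow headers
// (assumed to be 1024 for this sketch).
constexpr int kMaxThreadsPerBlock = 1024;

// __launch_bounds__(kMaxThreadsPerBlock, 1) tells the compiler the kernel
// will be launched with at most kMaxThreadsPerBlock threads per block, so it
// keeps register usage low enough for that configuration to always launch.
__global__ void
__launch_bounds__(kMaxThreadsPerBlock, 1)
ScaleKernel(const float* in, float* out, int n, float alpha) {
  // Grid-stride loop, same looping style as the kernels in the patch.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = alpha * in[i];
  }
}

int main() {
  const int n = 1 << 20;
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  // Launch with exactly the block size promised to the compiler above.
  const int threads = kMaxThreadsPerBlock;
  const int blocks = (n + threads - 1) / threads;
  ScaleKernel<<<blocks, threads>>>(in, out, n, 2.0f);
  cudaDeviceSynchronize();
  cudaFree(in);
  cudaFree(out);
  return 0;
}

The second argument to __launch_bounds__ (1 in both the patch and the sketch) is the minimum number of resident blocks per multiprocessor the compiler should aim for; a value of 1 adds no extra occupancy constraint beyond the block-size promise.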