From f322d6935b893e80e6877cbd4f9f9a1fd4f5206a Mon Sep 17 00:00:00 2001 From: ptredak Date: Thu, 8 Nov 2018 15:00:38 -0800 Subject: [PATCH] Fix launch bounds in spatial transformer --- src/operator/spatial_transformer.cu | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/operator/spatial_transformer.cu b/src/operator/spatial_transformer.cu index 33dbe3e7c069..1a16de3bab7c 100644 --- a/src/operator/spatial_transformer.cu +++ b/src/operator/spatial_transformer.cu @@ -36,11 +36,13 @@ __device__ bool between(DType value, int lowerBound, int upperBound) { return (value >= lowerBound && value <= upperBound); } template -__global__ void BilinearSamplingForwardKernel(const int i_c, const int i_h, - const int i_w, const DType* data, - const DType* grid, const int o_n, - const int o_c, const int o_h, - const int o_w, DType* out) { +__global__ void +__launch_bounds__(cuda::kMaxThreadsPerBlock, 1) +BilinearSamplingForwardKernel(const int i_c, const int i_h, + const int i_w, const DType* data, + const DType* grid, const int o_n, + const int o_c, const int o_h, + const int o_w, DType* out) { for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; index < o_n * o_c * o_h * o_w; index += blockDim.x * gridDim.x * gridDim.y) { @@ -78,12 +80,14 @@ __global__ void BilinearSamplingForwardKernel(const int i_c, const int i_h, } template -__global__ void BilinearSamplingBackwardKernel(const int i_c, const int i_h, - const int i_w, const DType* grad, - const DType* data, const int o_n, - const int o_c, const int o_h, - const int o_w, DType* g_input, - DType* grid_src) { +__global__ void +__launch_bounds__(cuda::kMaxThreadsPerBlock, 1) +BilinearSamplingBackwardKernel(const int i_c, const int i_h, + const int i_w, const DType* grad, + const DType* data, const int o_n, + const int o_c, const int o_h, + const int o_w, DType* g_input, + DType* grid_src) { for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; index < o_n * o_h * o_w; index += blockDim.x * gridDim.x * gridDim.y) {