diff --git a/src/operator/contrib/deformable_convolution-inl.h b/src/operator/contrib/deformable_convolution-inl.h index 0ea0a2af02d7..eb23d99bbb1a 100644 --- a/src/operator/contrib/deformable_convolution-inl.h +++ b/src/operator/contrib/deformable_convolution-inl.h @@ -61,9 +61,9 @@ struct DeformableConvolutionParam : public dmlc::Parameter layout; @@ -232,7 +232,7 @@ class DeformableConvolutionOp : public Operator { in_grad[conv::kData].shape_, col_buffer.shape_, param_.kernel, param_.pad, param_.stride, param_.dilate, param_.num_deformable_group, - in_grad[conv::kOffset].dptr() + n*input_offset_dim_); + in_grad[conv::kOffset].dptr() + n * input_offset_dim_); // gradient w.r.t. input data deformable_col2im(s, col_buffer.dptr(), diff --git a/src/operator/contrib/nn/deformable_im2col.cuh b/src/operator/contrib/nn/deformable_im2col.cuh index a12f04e8743f..9494fb379faf 100644 --- a/src/operator/contrib/nn/deformable_im2col.cuh +++ b/src/operator/contrib/nn/deformable_im2col.cuh @@ -76,27 +76,25 @@ namespace op { template __device__ DType deformable_im2col_bilinear(const DType* bottom_data, - const int data_width, - const int height, - const int width, + const index_t data_width, + const index_t height, + const index_t width, DType h, DType w) { - int h_low = floor(h); - int w_low = floor(w); - int h_high; - int w_high; + index_t h_low = floor(h); + index_t w_low = floor(w); + index_t h_high; + index_t w_high; if (h_low >= height - 1) { h_high = h_low = height - 1; h = static_cast(h_low); - } - else { + } else { h_high = h_low + 1; } if (w_low >= width - 1) { w_high = w_low = width - 1; w = static_cast(w_low); - } - else { + } else { w_high = w_low + 1; } @@ -117,8 +115,8 @@ __device__ DType deformable_im2col_bilinear(const DType* bottom_data, template __device__ DType get_gradient_weight(DType argmax_h, DType argmax_w, - const int h, const int w, - const int height, const int width) { + const index_t h, const index_t w, + const index_t height, const index_t width) { if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) { //empty return 0; @@ -127,10 +125,10 @@ __device__ DType get_gradient_weight(DType argmax_h, DType argmax_w, argmax_h = max(argmax_h, static_cast(0.0f)); argmax_w = max(argmax_w, static_cast(0.0f)); - int argmax_h_low = static_cast(argmax_h); - int argmax_w_low = static_cast(argmax_w); - int argmax_h_high; - int argmax_w_high; + index_t argmax_h_low = static_cast(argmax_h); + index_t argmax_w_low = static_cast(argmax_w); + index_t argmax_h_high; + index_t argmax_w_high; if (argmax_h_low >= height - 1) { argmax_h_high = argmax_h_low = height - 1; argmax_h = static_cast(argmax_h_low); @@ -164,10 +162,10 @@ __device__ DType get_gradient_weight(DType argmax_h, DType argmax_w, template __device__ DType get_coordinate_weight(DType argmax_h, DType argmax_w, - const int height, const int width, + const index_t height, const index_t width, const DType* im_data, - const int data_width, - const int bp_dir) { + const index_t data_width, + const index_t bp_dir) { if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) { //empty @@ -177,10 +175,10 @@ __device__ DType get_coordinate_weight(DType argmax_h, DType argmax_w, if (argmax_h < 0) argmax_h = 0; if (argmax_w < 0) argmax_w = 0; - int argmax_h_low = static_cast(argmax_h); - int argmax_w_low = static_cast(argmax_w); - int argmax_h_high; - int argmax_w_high; + index_t argmax_h_low = static_cast(argmax_h); + index_t argmax_w_low = static_cast(argmax_w); + index_t argmax_h_high; + index_t argmax_w_high; if (argmax_h_low >= height - 1) { argmax_h_high = argmax_h_low = height - 1; argmax_h = static_cast(argmax_h_low); @@ -220,38 +218,38 @@ __device__ DType get_coordinate_weight(DType argmax_h, DType argmax_w, * DO NOT call this directly. Use wrapper function im2col() instead; */ template -__global__ void deformable_im2col_gpu_kernel(const int n, const DType* data_im, +__global__ void deformable_im2col_gpu_kernel(const index_t n, const DType* data_im, const DType* data_offset, - const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_group, - const int height_col, const int width_col, + const index_t height, const index_t width, + const index_t kernel_h, const index_t kernel_w, + const index_t pad_h, const index_t pad_w, + const index_t stride_h, const index_t stride_w, + const index_t dilation_h, const index_t dilation_w, + const index_t channel_per_group, + const index_t height_col, const index_t width_col, DType* data_col) { CUDA_KERNEL_LOOP(index, n) { // index index of output matrix - const int w_col = index % width_col; - const int h_col = (index / width_col) % height_col; - const int c_im = (index / width_col) / height_col; - const int c_col = c_im * kernel_h * kernel_w; + const index_t w_col = index % width_col; + const index_t h_col = (index / width_col) % height_col; + const index_t c_im = (index / width_col) / height_col; + const index_t c_col = c_im * kernel_h * kernel_w; - const int group_index = c_im / channel_per_group; - const int group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col; + const index_t group_index = c_im / channel_per_group; + const index_t group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col; - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; + const index_t h_in = h_col * stride_h - pad_h; + const index_t w_in = w_col * stride_w - pad_w; DType* data_col_ptr = data_col + (c_col * height_col + h_col) * width_col + w_col; const DType* data_im_ptr = data_im + (c_im * height + h_in) * width + w_in; const DType* data_offset_ptr = data_offset + group_index * group_offset_step; - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * + for (index_t i = 0; i < kernel_h; ++i) { + for (index_t j = 0; j < kernel_w; ++j) { + const index_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; + const index_t data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; const DType offset_h = data_offset_ptr[data_offset_h_ptr]; const DType offset_w = data_offset_ptr[data_offset_w_ptr]; DType val = static_cast(0); @@ -260,8 +258,8 @@ __global__ void deformable_im2col_gpu_kernel(const int n, const DType* data_im, if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { const DType map_h = i * dilation_h + offset_h; const DType map_w = j * dilation_w + offset_w; - const int cur_height = height - h_in; - const int cur_width = width - w_in; + const index_t cur_height = height - h_in; + const index_t cur_width = width - w_in; val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); } *data_col_ptr = val; @@ -296,13 +294,13 @@ inline void deformable_im2col(mshadow::Stream* s, const mxnet::TShape& pad, const mxnet::TShape& stride, const mxnet::TShape& dilation, - const uint32_t deformable_group, + const index_t deformable_group, DType* data_col) { // num_axes should be smaller than block size - int num_spatial_axes = kernel_shape.ndim(); + const int num_spatial_axes = kernel_shape.ndim(); CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum); - int channel_per_group = im_shape[1] / deformable_group; - int num_kernels = im_shape[1] * col_shape.ProdShape(1, col_shape.ndim()); + index_t channel_per_group = im_shape[1] / deformable_group; + index_t num_kernels = im_shape[1] * col_shape.ProdShape(1, col_shape.ndim()); using namespace mxnet_op; switch (num_spatial_axes) { case 2: @@ -329,42 +327,42 @@ inline void deformable_im2col(mshadow::Stream* s, * \brief DO NOT call this directly. Use wrapper function deformable_col2im() instead; */ template -__global__ void deformable_col2im_gpu_kernel(const int n, const DType* data_col, - const DType* data_offset, const int channels, - const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_group, - const int height_col, const int width_col, +__global__ void deformable_col2im_gpu_kernel(const index_t n, const DType* data_col, + const DType* data_offset, const index_t channels, + const index_t height, const index_t width, + const index_t kernel_h, const index_t kernel_w, + const index_t pad_h, const index_t pad_w, + const index_t stride_h, const index_t stride_w, + const index_t dilation_h, const index_t dilation_w, + const index_t channel_per_group, + const index_t height_col, const index_t width_col, DType* grad_im) { CUDA_KERNEL_LOOP(index, n) { - const int j = (index / width_col / height_col) % kernel_w; - const int i = (index / width_col / height_col / kernel_w) % kernel_h; - const int c = index / width_col / height_col / kernel_w / kernel_h; + const index_t j = (index / width_col / height_col) % kernel_w; + const index_t i = (index / width_col / height_col / kernel_w) % kernel_h; + const index_t c = index / width_col / height_col / kernel_w / kernel_h; // compute the start and end of the output - const int group_index = c / channel_per_group; - const int group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col; + const index_t group_index = c / channel_per_group; + const index_t group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col; - int w_col = index % width_col; - int h_col = (index / width_col) % height_col; - int w_in = w_col * stride_w - pad_w; - int h_in = h_col * stride_h - pad_h; + index_t w_col = index % width_col; + index_t h_col = (index / width_col) % height_col; + index_t w_in = w_col * stride_w - pad_w; + index_t h_in = h_col * stride_h - pad_h; const DType* data_offset_ptr = data_offset + group_index * group_offset_step; - const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * + const index_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; + const index_t data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; const DType offset_h = data_offset_ptr[data_offset_h_ptr]; const DType offset_w = data_offset_ptr[data_offset_w_ptr]; const DType cur_inv_h_data = h_in + i * dilation_h + offset_h; const DType cur_inv_w_data = w_in + j * dilation_w + offset_w; const DType cur_top_grad = data_col[index]; - const int cur_h = (int)cur_inv_h_data; - const int cur_w = (int)cur_inv_w_data; + const index_t cur_h = static_cast(cur_inv_h_data); + const index_t cur_w = static_cast(cur_inv_w_data); for (int dy = -2; dy <= 2; dy++) { for (int dx = -2; dx <= 2; dx++) { if (cur_h + dy >= 0 && cur_h + dy < height && @@ -372,7 +370,7 @@ __global__ void deformable_col2im_gpu_kernel(const int n, const DType* data_col, abs(cur_inv_h_data - (cur_h + dy)) < 1 && abs(cur_inv_w_data - (cur_w + dx)) < 1 ) { - int cur_bottom_grad_pos = (c * height + cur_h + dy) * width + cur_w + dx; + index_t cur_bottom_grad_pos = (c * height + cur_h + dy) * width + cur_w + dx; DType weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); @@ -407,12 +405,12 @@ inline void deformable_col2im(mshadow::Stream* s, const mxnet::TShape& pad, const mxnet::TShape& stride, const mxnet::TShape& dilation, - const uint32_t deformable_group, + const index_t deformable_group, DType* grad_im) { - int num_spatial_axes = kernel_shape.ndim(); - int im_size = im_shape.ProdShape(1, im_shape.ndim()); - int channel_per_group = im_shape[1] / deformable_group; - int num_kernels = col_shape.ProdShape(0, col_shape.ndim()); + const int num_spatial_axes = kernel_shape.ndim(); + index_t im_size = im_shape.ProdShape(1, im_shape.ndim()); + index_t channel_per_group = im_shape[1] / deformable_group; + index_t num_kernels = col_shape.ProdShape(0, col_shape.ndim()); // num_axes should be smaller than block size CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum); using namespace mxnet_op; @@ -444,50 +442,50 @@ inline void deformable_col2im(mshadow::Stream* s, * \brief DO NOT call this directly. Use wrapper function deformable_col2im_coord() instead; */ template -__global__ void deformable_col2im_coord_gpu_kernel(const int n, const DType* data_col, +__global__ void deformable_col2im_coord_gpu_kernel(const index_t n, const DType* data_col, const DType* data_im, const DType* data_offset, - const int channels, - const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_group, - const int height_col, const int width_col, + const index_t channels, + const index_t height, const index_t width, + const index_t kernel_h, const index_t kernel_w, + const index_t pad_h, const index_t pad_w, + const index_t stride_h, const index_t stride_w, + const index_t dilation_h, const index_t dilation_w, + const index_t channel_per_group, + const index_t height_col, const index_t width_col, DType* grad_offset) { CUDA_KERNEL_LOOP(index, n) { DType val = 0; - int w = index % width_col; - int h = (index / width_col) % height_col; - int c = index / width_col / height_col; + index_t w = index % width_col; + index_t h = (index / width_col) % height_col; + index_t c = index / width_col / height_col; // compute the start and end of the output - const int group_index = c / (2 * kernel_h * kernel_w); - const int group_col_step = channel_per_group * width_col * height_col; - const int group_im_step = channel_per_group / kernel_h / kernel_w * height * width; - const int group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col; - const int col_step = kernel_h * kernel_w; + const index_t group_index = c / (2 * kernel_h * kernel_w); + const index_t group_col_step = channel_per_group * width_col * height_col; + const index_t group_im_step = channel_per_group / kernel_h / kernel_w * height * width; + const index_t group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col; + const index_t col_step = kernel_h * kernel_w; const DType* data_col_ptr = data_col + group_index * group_col_step; const DType* data_im_ptr = data_im + group_index * group_im_step; const DType* data_offset_ptr = data_offset + group_index * group_offset_step; - int cnt = 0; - const int offset_c = c - group_index * 2 * kernel_h * kernel_w; + index_t cnt = 0; + const index_t offset_c = c - group_index * 2 * kernel_h * kernel_w; - for (int col_c = (offset_c / 2); col_c < channel_per_group; col_c += col_step) { - const int col_pos = ((col_c * height_col) + h) * width_col + w; - const int bp_dir = offset_c % 2; + for (index_t col_c = (offset_c / 2); col_c < channel_per_group; col_c += col_step) { + const index_t col_pos = ((col_c * height_col) + h) * width_col + w; + const index_t bp_dir = offset_c % 2; - int j = (col_pos / width_col / height_col) % kernel_w; - int i = (col_pos / width_col / height_col / kernel_w) % kernel_h; - int w_col = col_pos % width_col; - int h_col = (col_pos / width_col) % height_col; - int w_in = w_col * stride_w - pad_w; - int h_in = h_col * stride_h - pad_h; - const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * + index_t j = (col_pos / width_col / height_col) % kernel_w; + index_t i = (col_pos / width_col / height_col / kernel_w) % kernel_h; + index_t w_col = col_pos % width_col; + index_t h_col = (col_pos / width_col) % height_col; + index_t w_in = w_col * stride_w - pad_w; + index_t h_in = h_col * stride_h - pad_h; + const index_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; + const index_t data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; const DType offset_h = data_offset_ptr[data_offset_h_ptr]; const DType offset_w = data_offset_ptr[data_offset_w_ptr]; DType inv_h = h_in + i * dilation_h + offset_h; @@ -533,12 +531,12 @@ inline void deformable_col2im_coord(mshadow::Stream* s, const mxnet::TShape& pad, const mxnet::TShape& stride, const mxnet::TShape& dilation, - const uint32_t deformable_group, + const index_t deformable_group, DType* grad_offset) { - int num_spatial_axes = kernel_shape.ndim(); - int num_kernels = col_shape[1] * col_shape[2] * 2 * + const int num_spatial_axes = kernel_shape.ndim(); + index_t num_kernels = col_shape[1] * col_shape[2] * 2 * kernel_shape[0] * kernel_shape[1] * deformable_group; - int channel_per_group = col_shape[0] / deformable_group; + index_t channel_per_group = col_shape[0] / deformable_group; // num_axes should be smaller than block size CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum); using namespace mxnet_op; diff --git a/src/operator/contrib/nn/deformable_im2col.h b/src/operator/contrib/nn/deformable_im2col.h index 3915bc95e247..3f42668b86be 100644 --- a/src/operator/contrib/nn/deformable_im2col.h +++ b/src/operator/contrib/nn/deformable_im2col.h @@ -73,13 +73,13 @@ namespace op { template inline DType im2col_bilinear_cpu(const DType* data, - const int height, - const int width, + const index_t height, + const index_t width, DType h, DType w) { - int h_low = floor(h); - int w_low = floor(w); - int h_high; - int w_high; + index_t h_low = floor(h); + index_t w_low = floor(w); + index_t h_high; + index_t w_high; if (h_low >= height - 1) { h_high = height - 1; @@ -111,8 +111,8 @@ inline DType im2col_bilinear_cpu(const DType* data, template inline DType get_gradient_weight_cpu(DType argmax_h, DType argmax_w, - const int h, const int w, - const int height, const int width) { + const index_t h, const index_t w, + const index_t height, const index_t width) { if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) { // empty return 0; @@ -121,10 +121,10 @@ inline DType get_gradient_weight_cpu(DType argmax_h, DType argmax_w, argmax_h = std::max(argmax_h, static_cast(0.0f)); argmax_w = std::max(argmax_w, static_cast(0.0f)); - int argmax_h_low = static_cast(argmax_h); - int argmax_w_low = static_cast(argmax_w); - int argmax_h_high; - int argmax_w_high; + index_t argmax_h_low = static_cast(argmax_h); + index_t argmax_w_low = static_cast(argmax_w); + index_t argmax_h_high; + index_t argmax_w_high; if (argmax_h_low >= height - 1) { argmax_h_high = argmax_h_low = height - 1; argmax_h = static_cast(argmax_h_low); @@ -157,9 +157,9 @@ inline DType get_gradient_weight_cpu(DType argmax_h, DType argmax_w, template inline DType get_coordinate_weight_cpu(DType argmax_h, DType argmax_w, - const int height, const int width, + const index_t height, const index_t width, const DType* im_data, - const int data_width, const int bp_dir) { + const index_t data_width, const index_t bp_dir) { if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) { // empty return 0; @@ -168,10 +168,10 @@ inline DType get_coordinate_weight_cpu(DType argmax_h, DType argmax_w, if (argmax_h < 0) argmax_h = 0; if (argmax_w < 0) argmax_w = 0; - int argmax_h_low = static_cast(argmax_h); - int argmax_w_low = static_cast(argmax_w); - int argmax_h_high; - int argmax_w_high; + index_t argmax_h_low = static_cast(argmax_h); + index_t argmax_w_low = static_cast(argmax_w); + index_t argmax_h_high; + index_t argmax_w_high; if (argmax_h_low >= height - 1) { argmax_h_high = argmax_h_low = height - 1; argmax_h = static_cast(argmax_h_low); @@ -214,31 +214,31 @@ inline DType get_coordinate_weight_cpu(DType argmax_h, DType argmax_w, template inline void deformable_im2col_cpu(const DType* data_im, const DType* data_offset, - const int channels, - const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group, - const int height_col, const int width_col, + const index_t channels, + const index_t height, const index_t width, + const index_t kernel_h, const index_t kernel_w, + const index_t pad_h, const index_t pad_w, + const index_t stride_h, const index_t stride_w, + const index_t dilation_h, const index_t dilation_w, + const index_t deformable_group, + const index_t height_col, const index_t width_col, DType* data_col) { - const int channel_size = height * width; - const int offset_size = 2 * kernel_h * kernel_w * height_col * width_col; - const int channel_per_group = channels / deformable_group; - for (int channel = 0; channel < channels; channel++, data_im += channel_size) { + const index_t channel_size = height * width; + const index_t offset_size = 2 * kernel_h * kernel_w * height_col * width_col; + const index_t channel_per_group = channels / deformable_group; + for (index_t channel = 0; channel < channels; channel++, data_im += channel_size) { if (channel % channel_per_group == 0 && channel != 0) { data_offset += offset_size; } - for (int i = 0; i < kernel_h; i++) { - for (int j = 0; j < kernel_w; j++) { - int input_row = -pad_h + i * dilation_h; - for (int h_col = 0; h_col < height_col; h_col++) { - int input_col = -pad_w + j * dilation_w; - for (int w_col = 0; w_col < width_col; w_col++) { - int offset_h_ptr = ((2 * (i * kernel_w + j)) * + for (index_t i = 0; i < kernel_h; i++) { + for (index_t j = 0; j < kernel_w; j++) { + index_t input_row = -pad_h + i * dilation_h; + for (index_t h_col = 0; h_col < height_col; h_col++) { + index_t input_col = -pad_w + j * dilation_w; + for (index_t w_col = 0; w_col < width_col; w_col++) { + index_t offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - int offset_w_ptr = offset_h_ptr + height_col * width_col; + index_t offset_w_ptr = offset_h_ptr + height_col * width_col; DType im_row = input_row + data_offset[offset_h_ptr]; DType im_col = input_col + data_offset[offset_w_ptr]; if (im_row >= 0 && im_col >= 0 && im_row < height && im_col < width) { @@ -279,7 +279,7 @@ inline void deformable_im2col(mshadow::Stream* s, const mxnet::TShape& pad, const mxnet::TShape& stride, const mxnet::TShape& dilation, - const uint32_t deformable_group, + const index_t deformable_group, DType* data_col) { if (2 == kernel_shape.ndim()) { deformable_im2col_cpu(data_im, data_offset, @@ -303,43 +303,43 @@ inline void deformable_im2col(mshadow::Stream* s, */ template inline void deformable_col2im_cpu(const DType* data_col, - const DType* data_offset, const int channels, - const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group, - const int height_col, const int width_col, + const DType* data_offset, const index_t channels, + const index_t height, const index_t width, + const index_t kernel_h, const index_t kernel_w, + const index_t pad_h, const index_t pad_w, + const index_t stride_h, const index_t stride_w, + const index_t dilation_h, const index_t dilation_w, + const index_t deformable_group, + const index_t height_col, const index_t width_col, DType* grad_im) { - int channel_per_group = channels / deformable_group; - int count = channels * kernel_h * kernel_w * height_col * width_col; - for (int index = 0; index < count; ++index) { - const int j = (index / width_col / height_col) % kernel_w; - const int i = (index / width_col / height_col / kernel_w) % kernel_h; - const int c = index / width_col / height_col / kernel_w / kernel_h; + index_t channel_per_group = channels / deformable_group; + index_t count = channels * kernel_h * kernel_w * height_col * width_col; + for (index_t index = 0; index < count; ++index) { + const index_t j = (index / width_col / height_col) % kernel_w; + const index_t i = (index / width_col / height_col / kernel_w) % kernel_h; + const index_t c = index / width_col / height_col / kernel_w / kernel_h; // compute the start and end of the output - const int group_index = c / channel_per_group; - const int group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col; + const index_t group_index = c / channel_per_group; + const index_t group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col; - int w_col = index % width_col; - int h_col = (index / width_col) % height_col; - int w_in = w_col * stride_w - pad_w; - int h_in = h_col * stride_h - pad_h; + index_t w_col = index % width_col; + index_t h_col = (index / width_col) % height_col; + index_t w_in = w_col * stride_w - pad_w; + index_t h_in = h_col * stride_h - pad_h; const DType* data_offset_ptr = data_offset + group_index * group_offset_step; - const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * + const index_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; + const index_t data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; const DType offset_h = data_offset_ptr[data_offset_h_ptr]; const DType offset_w = data_offset_ptr[data_offset_w_ptr]; const DType cur_inv_h_data = h_in + i * dilation_h + offset_h; const DType cur_inv_w_data = w_in + j * dilation_w + offset_w; const DType cur_top_grad = data_col[index]; - const int cur_h = static_cast(cur_inv_h_data); - const int cur_w = static_cast(cur_inv_w_data); + const index_t cur_h = static_cast(cur_inv_h_data); + const index_t cur_w = static_cast(cur_inv_w_data); for (int dy = -2; dy <= 2; dy++) { for (int dx = -2; dx <= 2; dx++) { if (cur_h + dy >= 0 && cur_h + dy < height && @@ -347,7 +347,7 @@ inline void deformable_col2im_cpu(const DType* data_col, std::abs(cur_inv_h_data - (cur_h + dy)) < 1 && std::abs(cur_inv_w_data - (cur_w + dx)) < 1 ) { - int cur_bottom_grad_pos = (c * height + cur_h + dy) * width + cur_w + dx; + index_t cur_bottom_grad_pos = (c * height + cur_h + dy) * width + cur_w + dx; DType weight = get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); grad_im[cur_bottom_grad_pos] += weight * cur_top_grad; @@ -382,7 +382,7 @@ inline void deformable_col2im(mshadow::Stream* s, const mxnet::TShape& pad, const mxnet::TShape& stride, const mxnet::TShape& dilation, - const uint32_t deformable_group, + const index_t deformable_group, DType* grad_im) { if (2 == kernel_shape.ndim()) { deformable_col2im_cpu(data_col, data_offset, @@ -407,49 +407,49 @@ template inline void deformable_col2im_coord_cpu(const DType* data_col, const DType* data_im, const DType* data_offset, - const int channels, - const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group, - const int height_col, const int width_col, + const index_t channels, + const index_t height, const index_t width, + const index_t kernel_h, const index_t kernel_w, + const index_t pad_h, const index_t pad_w, + const index_t stride_h, const index_t stride_w, + const index_t dilation_h, const index_t dilation_w, + const index_t deformable_group, + const index_t height_col, const index_t width_col, DType* grad_offset) { - int channel_per_group = channels * kernel_h * kernel_w / deformable_group; - int count = height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; - for (int index = 0; index < count; ++index) { + index_t channel_per_group = channels * kernel_h * kernel_w / deformable_group; + index_t count = height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + for (index_t index = 0; index < count; ++index) { DType val = 0; - int w = index % width_col; - int h = (index / width_col) % height_col; - int c = index / width_col / height_col; + index_t w = index % width_col; + index_t h = (index / width_col) % height_col; + index_t c = index / width_col / height_col; // compute the start and end of the output - const int group_index = c / (2 * kernel_h * kernel_w); - const int group_col_step = channel_per_group * width_col * height_col; - const int group_im_step = channel_per_group / kernel_h / kernel_w * height * width; - const int group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col; - const int col_step = kernel_h * kernel_w; + const index_t group_index = c / (2 * kernel_h * kernel_w); + const index_t group_col_step = channel_per_group * width_col * height_col; + const index_t group_im_step = channel_per_group / kernel_h / kernel_w * height * width; + const index_t group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col; + const index_t col_step = kernel_h * kernel_w; const DType* data_col_ptr = data_col + group_index * group_col_step; const DType* data_im_ptr = data_im + group_index * group_im_step; const DType* data_offset_ptr = data_offset + group_index * group_offset_step; - int cnt = 0; - const int offset_c = c - group_index * 2 * kernel_h * kernel_w; + index_t cnt = 0; + const index_t offset_c = c - group_index * 2 * kernel_h * kernel_w; - for (int col_c = (offset_c / 2); col_c < channel_per_group; col_c += col_step) { - const int col_pos = ((col_c * height_col) + h) * width_col + w; - const int bp_dir = offset_c % 2; + for (index_t col_c = (offset_c / 2); col_c < channel_per_group; col_c += col_step) { + const index_t col_pos = ((col_c * height_col) + h) * width_col + w; + const index_t bp_dir = offset_c % 2; - int j = (col_pos / width_col / height_col) % kernel_w; - int i = (col_pos / width_col / height_col / kernel_w) % kernel_h; - int w_col = col_pos % width_col; - int h_col = (col_pos / width_col) % height_col; - int w_in = w_col * stride_w - pad_w; - int h_in = h_col * stride_h - pad_h; - const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * + index_t j = (col_pos / width_col / height_col) % kernel_w; + index_t i = (col_pos / width_col / height_col / kernel_w) % kernel_h; + index_t w_col = col_pos % width_col; + index_t h_col = (col_pos / width_col) % height_col; + index_t w_in = w_col * stride_w - pad_w; + index_t h_in = h_col * stride_h - pad_h; + const index_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; + const index_t data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; const DType offset_h = data_offset_ptr[data_offset_h_ptr]; const DType offset_w = data_offset_ptr[data_offset_w_ptr]; DType inv_h = h_in + i * dilation_h + offset_h; @@ -495,7 +495,7 @@ inline void deformable_col2im_coord(mshadow::Stream* s, const mxnet::TShape& pad, const mxnet::TShape& stride, const mxnet::TShape& dilation, - const uint32_t deformable_group, + const index_t deformable_group, DType* grad_offset) { if (2 == kernel_shape.ndim()) { deformable_col2im_coord_cpu(data_col, data_im, data_offset,