diff --git a/src/operator/contrib/deformable_convolution-inl.h b/src/operator/contrib/deformable_convolution-inl.h index 000d703066d7..eb23d99bbb1a 100644 --- a/src/operator/contrib/deformable_convolution-inl.h +++ b/src/operator/contrib/deformable_convolution-inl.h @@ -61,9 +61,9 @@ struct DeformableConvolutionParam : public dmlc::Parameter layout; @@ -109,10 +109,10 @@ class DeformableConvolutionOp : public Operator { } virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(req[conv::kOut], kWriteTo); @@ -147,10 +147,11 @@ class DeformableConvolutionOp : public Operator { Shape4(num_, group_, M, N), s); for (index_t n = 0; n < num_; ++n) { // transform image to col_buffer in order to use gemm - deformable_im2col(s, in_data[conv::kData].dptr() + n*input_dim_, - in_data[conv::kOffset].dptr() + n*input_offset_dim_, in_data[conv::kData].shape_, - col_buffer.shape_, param_.kernel, param_.pad, param_.stride, param_.dilate, - param_.num_deformable_group, col_buffer.dptr()); + deformable_im2col(s, in_data[conv::kData].dptr() + n * input_dim_, + in_data[conv::kOffset].dptr() + n * input_offset_dim_, + in_data[conv::kData].shape_, col_buffer.shape_, + param_.kernel, param_.pad, param_.stride, param_.dilate, + param_.num_deformable_group, col_buffer.dptr()); Tensor output_3d = output_4d[n]; for (index_t g = 0; g < group_; ++g) { // Legacy approach shown here for comparison: @@ -168,12 +169,12 @@ class DeformableConvolutionOp : public Operator { } virtual void Backward(const OpContext &ctx, - const std::vector& out_grad, - const std::vector& in_data, - const std::vector& out_data, - const std::vector& req, - const std::vector& in_grad, - const std::vector& aux_args) { + const 
std::vector& out_grad, + const std::vector& in_data, + const std::vector& out_data, + const std::vector& req, + const std::vector& in_grad, + const std::vector& aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1U); @@ -226,26 +227,27 @@ class DeformableConvolutionOp : public Operator { // gradient w.r.t. input coordinate data deformable_col2im_coord(s, col_buffer.dptr(), - in_data[conv::kData].dptr() + n*input_dim_, - in_data[conv::kOffset].dptr() + n*input_offset_dim_, - in_grad[conv::kData].shape_, col_buffer.shape_, - param_.kernel, param_.pad, param_.stride, param_.dilate, param_.num_deformable_group, - in_grad[conv::kOffset].dptr() + n*input_offset_dim_, - req[conv::kOffset]); + in_data[conv::kData].dptr() + n * input_dim_, + in_data[conv::kOffset].dptr() + n * input_offset_dim_, + in_grad[conv::kData].shape_, col_buffer.shape_, + param_.kernel, param_.pad, param_.stride, + param_.dilate, param_.num_deformable_group, + in_grad[conv::kOffset].dptr() + n * input_offset_dim_); // gradient w.r.t. input data deformable_col2im(s, col_buffer.dptr(), - in_data[conv::kOffset].dptr() + n*input_offset_dim_, - in_grad[conv::kData].shape_, col_buffer.shape_, - param_.kernel, param_.pad, param_.stride, param_.dilate, param_.num_deformable_group, - in_grad[conv::kData].dptr() + n*input_dim_, - req[conv::kData]); + in_data[conv::kOffset].dptr() + n * input_offset_dim_, + in_grad[conv::kData].shape_, col_buffer.shape_, + param_.kernel, param_.pad, param_.stride, + param_.dilate, param_.num_deformable_group, + in_grad[conv::kData].dptr() + n * input_dim_); // gradient w.r.t. 
weight, dWeight should accumulate across the batch and group - deformable_im2col(s, in_data[conv::kData].dptr() + n*input_dim_, - in_data[conv::kOffset].dptr() + n*input_offset_dim_, in_data[conv::kData].shape_, - col_buffer.shape_, param_.kernel, param_.pad, param_.stride, param_.dilate, - param_.num_deformable_group, col_buffer.dptr()); + deformable_im2col(s, in_data[conv::kData].dptr() + n * input_dim_, + in_data[conv::kOffset].dptr() + n * input_offset_dim_, + in_data[conv::kData].shape_, col_buffer.shape_, param_.kernel, + param_.pad, param_.stride, param_.dilate, + param_.num_deformable_group, col_buffer.dptr()); for (index_t g = 0; g < group_; ++g) { auto request = (n == 0) ? req[conv::kWeight] : kAddTo; @@ -327,9 +329,9 @@ class DeformableConvolutionOp : public Operator { template Operator* CreateOp(DeformableConvolutionParam param, int dtype, - mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape, - Context ctx); + mxnet::ShapeVector *in_shape, + mxnet::ShapeVector *out_shape, + Context ctx); #if DMLC_USE_CXX11 class DeformableConvolutionProp : public OperatorProperty { @@ -360,8 +362,8 @@ class DeformableConvolutionProp : public OperatorProperty { } bool InferShape(mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape, - mxnet::ShapeVector *aux_shape) const override { + mxnet::ShapeVector *out_shape, + mxnet::ShapeVector *aux_shape) const override { using namespace mshadow; if (!param_.no_bias) { CHECK_EQ(in_shape->size(), 4U) << "Input:[data, offset, weight, bias]"; @@ -411,8 +413,6 @@ class DeformableConvolutionProp : public OperatorProperty { oshape[3] = (dshape[3] + 2 * param_.pad[1] - (param_.dilate[1] * (ksize_x - 1) + 1)) / param_.stride[1] + 1; SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCHW, param_.layout.value())); - CHECK_EQ(oshape[1] % param_.num_deformable_group, 0U) \ - << "output num_filter must divide deformable group size"; CHECK_EQ(oshape[2], offsetshape[2]) \ << "output height must equal to offset map 
height"; CHECK_EQ(oshape[3], offsetshape[3]) \ @@ -450,8 +450,8 @@ class DeformableConvolutionProp : public OperatorProperty { } bool InferType(std::vector *in_type, - std::vector *out_type, - std::vector *aux_type) const override { + std::vector *out_type, + std::vector *aux_type) const override { CHECK_GE(in_type->size(), 1U); int dtype = (*in_type)[0]; CHECK_NE(dtype, -1) << "First input must have specified type"; @@ -477,10 +477,9 @@ class DeformableConvolutionProp : public OperatorProperty { return "_contrib_DeformableConvolution"; } - std::vector DeclareBackwardDependency( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const override { + std::vector DeclareBackwardDependency(const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override { return{ out_grad[conv::kOut], in_data[conv::kData], in_data[conv::kOffset], in_data[conv::kWeight] }; } @@ -501,7 +500,7 @@ class DeformableConvolutionProp : public OperatorProperty { } Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, - std::vector *in_type) const override; + std::vector *in_type) const override; private: DeformableConvolutionParam param_; diff --git a/src/operator/contrib/deformable_convolution.cc b/src/operator/contrib/deformable_convolution.cc index 8bb1ae23f40d..60138376f287 100644 --- a/src/operator/contrib/deformable_convolution.cc +++ b/src/operator/contrib/deformable_convolution.cc @@ -62,7 +62,7 @@ The deformable convolution operation is described in https://arxiv.org/abs/1703. 
For 2-D deformable convolution, the shapes are - **data**: *(batch_size, channel, height, width)* -- **offset**: *(batch_size, num_deformable_group * kernel[0] * kernel[1], height, width)* +- **offset**: *(batch_size, num_deformable_group * kernel[0] * kernel[1] * 2, out_height, out_width)* - **weight**: *(num_filter, channel, kernel[0], kernel[1])* - **bias**: *(num_filter,)* - **out**: *(batch_size, num_filter, out_height, out_width)*. @@ -89,9 +89,9 @@ the *g* results. If ``num_deformable_group`` is larger than 1, denoted by *dg*, then split the input ``offset`` evenly into *dg* parts along the channel axis, and also evenly -split ``out`` evenly into *dg* parts along the channel axis. Next compute the -deformable convolution, apply the *i*-th part of the offset part on the *i*-th -out. +split ``data`` into *dg* parts along the channel axis. Then compute the +deformable convolution, applying the *i*-th part of the offset to the *i*-th part +of the data. Both ``weight`` and ``bias`` are learnable parameters. 
diff --git a/src/operator/contrib/nn/deformable_im2col.cuh b/src/operator/contrib/nn/deformable_im2col.cuh index 5f206d23d8d7..9494fb379faf 100644 --- a/src/operator/contrib/nn/deformable_im2col.cuh +++ b/src/operator/contrib/nn/deformable_im2col.cuh @@ -75,26 +75,26 @@ namespace mxnet { namespace op { template -__device__ DType deformable_im2col_bilinear(const DType* bottom_data, const int data_width, - const int height, const int width, DType h, DType w) { - - int h_low = floor(h); - int w_low = floor(w); - int h_high; - int w_high; +__device__ DType deformable_im2col_bilinear(const DType* bottom_data, + const index_t data_width, + const index_t height, + const index_t width, + DType h, DType w) { + index_t h_low = floor(h); + index_t w_low = floor(w); + index_t h_high; + index_t w_high; if (h_low >= height - 1) { h_high = h_low = height - 1; - h = (DType)h_low; - } - else { + h = static_cast(h_low); + } else { h_high = h_low + 1; } if (w_low >= width - 1) { w_high = w_low = width - 1; - w = (DType)w_low; - } - else { + w = static_cast(w_low); + } else { w_high = w_low + 1; } @@ -115,30 +115,30 @@ __device__ DType deformable_im2col_bilinear(const DType* bottom_data, const int template __device__ DType get_gradient_weight(DType argmax_h, DType argmax_w, - const int h, const int w, const int height, const int width) { - + const index_t h, const index_t w, + const index_t height, const index_t width) { if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) { //empty return 0; } - argmax_h = max(argmax_h, (DType)0.0f); - argmax_w = max(argmax_w, (DType)0.0f); + argmax_h = max(argmax_h, static_cast(0.0f)); + argmax_w = max(argmax_w, static_cast(0.0f)); - int argmax_h_low = (int)argmax_h; - int argmax_w_low = (int)argmax_w; - int argmax_h_high; - int argmax_w_high; + index_t argmax_h_low = static_cast(argmax_h); + index_t argmax_w_low = static_cast(argmax_w); + index_t argmax_h_high; + index_t argmax_w_high; if (argmax_h_low >= height - 1) { 
argmax_h_high = argmax_h_low = height - 1; - argmax_h = (DType)argmax_h_low; + argmax_h = static_cast(argmax_h_low); } else { argmax_h_high = argmax_h_low + 1; } if (argmax_w_low >= width - 1) { argmax_w_high = argmax_w_low = width - 1; - argmax_w = (DType)argmax_w_low; + argmax_w = static_cast(argmax_w_low); } else { argmax_w_high = argmax_w_low + 1; } @@ -162,9 +162,10 @@ __device__ DType get_gradient_weight(DType argmax_h, DType argmax_w, template __device__ DType get_coordinate_weight(DType argmax_h, DType argmax_w, - const int height, const int width, const DType* im_data, - const int data_width, const int bp_dir) { - + const index_t height, const index_t width, + const DType* im_data, + const index_t data_width, + const index_t bp_dir) { if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) { //empty @@ -174,34 +175,38 @@ __device__ DType get_coordinate_weight(DType argmax_h, DType argmax_w, if (argmax_h < 0) argmax_h = 0; if (argmax_w < 0) argmax_w = 0; - int argmax_h_low = (int)argmax_h; - int argmax_w_low = (int)argmax_w; - int argmax_h_high; - int argmax_w_high; + index_t argmax_h_low = static_cast(argmax_h); + index_t argmax_w_low = static_cast(argmax_w); + index_t argmax_h_high; + index_t argmax_w_high; if (argmax_h_low >= height - 1) { argmax_h_high = argmax_h_low = height - 1; - argmax_h = (DType)argmax_h_low; + argmax_h = static_cast(argmax_h_low); } else { argmax_h_high = argmax_h_low + 1; } if (argmax_w_low >= width - 1) { argmax_w_high = argmax_w_low = width - 1; - argmax_w = (DType)argmax_w_low; + argmax_w = static_cast(argmax_w_low); } else { argmax_w_high = argmax_w_low + 1; } - DType weight = 0; + DType weight = 0; + DType im_ll = im_data[argmax_h_low * data_width + argmax_w_low]; + DType im_lh = im_data[argmax_h_low * data_width + argmax_w_high]; + DType im_hl = im_data[argmax_h_high * data_width + argmax_w_low]; + DType im_hh = im_data[argmax_h_high * data_width + argmax_w_high]; if (bp_dir == 0) { - weight += -1 * 
(argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; - weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; - weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; - weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_ll; + weight += -1 * (argmax_w - argmax_w_low) * im_lh; + weight += (argmax_w_low + 1 - argmax_w) * im_hl; + weight += (argmax_w - argmax_w_low) * im_hh; } else if (bp_dir == 1) { - weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; - weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; - weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; - weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_ll; + weight += (argmax_h_low + 1 - argmax_h) * im_lh; + weight += -1 * (argmax_h - argmax_h_low) * im_hl; + weight += (argmax_h - argmax_h_low) * im_hh; } return weight; @@ -213,35 +218,38 @@ __device__ DType get_coordinate_weight(DType argmax_h, DType argmax_w, * DO NOT call this directly. 
Use wrapper function im2col() instead; */ template -__global__ void deformable_im2col_gpu_kernel(const int n, const DType* data_im, const DType* data_offset, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, - const int height_col, const int width_col, - DType* data_col) { +__global__ void deformable_im2col_gpu_kernel(const index_t n, const DType* data_im, + const DType* data_offset, + const index_t height, const index_t width, + const index_t kernel_h, const index_t kernel_w, + const index_t pad_h, const index_t pad_w, + const index_t stride_h, const index_t stride_w, + const index_t dilation_h, const index_t dilation_w, + const index_t channel_per_group, + const index_t height_col, const index_t width_col, + DType* data_col) { CUDA_KERNEL_LOOP(index, n) { // index index of output matrix - const int w_col = index % width_col; - const int h_col = (index / width_col) % height_col; - const int c_im = (index / width_col) / height_col; - const int c_col = c_im * kernel_h * kernel_w; + const index_t w_col = index % width_col; + const index_t h_col = (index / width_col) % height_col; + const index_t c_im = (index / width_col) / height_col; + const index_t c_col = c_im * kernel_h * kernel_w; - // compute deformable group index - const int deformable_group_index = c_im / channel_per_deformable_group; + const index_t group_index = c_im / channel_per_group; + const index_t group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col; - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; + const index_t h_in = h_col * stride_h - pad_h; + const index_t w_in = w_col * stride_w - pad_w; DType* data_col_ptr = data_col + (c_col * height_col + h_col) * width_col + w_col; const DType* data_im_ptr = data_im + (c_im * height + h_in) * width + w_in; - 
const DType* data_offset_ptr = data_offset + deformable_group_index * 2 * kernel_h * kernel_w * height_col * width_col; + const DType* data_offset_ptr = data_offset + group_index * group_offset_step; - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + for (index_t i = 0; i < kernel_h; ++i) { + for (index_t j = 0; j < kernel_w; ++j) { + const index_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * + height_col + h_col) * width_col + w_col; + const index_t data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; const DType offset_h = data_offset_ptr[data_offset_h_ptr]; const DType offset_w = data_offset_ptr[data_offset_w_ptr]; DType val = static_cast(0); @@ -250,8 +258,8 @@ __global__ void deformable_im2col_gpu_kernel(const int n, const DType* data_im, if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { const DType map_h = i * dilation_h + offset_h; const DType map_w = j * dilation_w + offset_w; - const int cur_height = height - h_in; - const int cur_width = width - w_in; + const index_t cur_height = height - h_in; + const index_t cur_width = width - w_in; val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); } *data_col_ptr = val; @@ -262,10 +270,6 @@ __global__ void deformable_im2col_gpu_kernel(const int n, const DType* data_im, } - - - - /*!\brief * cpu function of deformable_im2col algorithm * \param s device stream @@ -282,24 +286,33 @@ __global__ void deformable_im2col_gpu_kernel(const int n, const DType* data_im, */ template inline void deformable_im2col(mshadow::Stream* s, - const DType* data_im, const DType* data_offset, - const mxnet::TShape& im_shape, const mxnet::TShape& col_shape, const mxnet::TShape& kernel_shape, - const mxnet::TShape& pad, const mxnet::TShape& 
stride, const mxnet::TShape& dilation, - const uint32_t deformable_group, DType* data_col) { + const DType* data_im, + const DType* data_offset, + const mxnet::TShape& im_shape, + const mxnet::TShape& col_shape, + const mxnet::TShape& kernel_shape, + const mxnet::TShape& pad, + const mxnet::TShape& stride, + const mxnet::TShape& dilation, + const index_t deformable_group, + DType* data_col) { // num_axes should be smaller than block size - index_t num_spatial_axes = kernel_shape.ndim(); + const int num_spatial_axes = kernel_shape.ndim(); CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum); - index_t channel_per_deformable_group = im_shape[1] / deformable_group; + index_t channel_per_group = im_shape[1] / deformable_group; index_t num_kernels = im_shape[1] * col_shape.ProdShape(1, col_shape.ndim()); using namespace mxnet_op; switch (num_spatial_axes) { case 2: deformable_im2col_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) <<::GetStream(s)>>>( - num_kernels, data_im, data_offset, im_shape[2], im_shape[3], kernel_shape[0], kernel_shape[1], - pad[0], pad[1], stride[0], stride[1], dilation[0], dilation[1], channel_per_deformable_group, - col_shape[1], col_shape[2], data_col); + 0, mshadow::Stream::GetStream(s)>>>(num_kernels, data_im, data_offset, + im_shape[2], im_shape[3], + kernel_shape[0], kernel_shape[1], + pad[0], pad[1], stride[0], stride[1], + dilation[0], dilation[1], + channel_per_group, + col_shape[1], col_shape[2], data_col); MSHADOW_CUDA_POST_KERNEL_CHECK(deformable_im2col_gpu_kernel); break; default: @@ -314,39 +327,42 @@ inline void deformable_im2col(mshadow::Stream* s, * \brief DO NOT call this directly. 
Use wrapper function deformable_col2im() instead; */ template -__global__ void deformable_col2im_gpu_kernel(const int n, const DType* data_col, const DType* data_offset, - const int channels, const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, - const int height_col, const int width_col, - DType* grad_im, OpReqType req) { +__global__ void deformable_col2im_gpu_kernel(const index_t n, const DType* data_col, + const DType* data_offset, const index_t channels, + const index_t height, const index_t width, + const index_t kernel_h, const index_t kernel_w, + const index_t pad_h, const index_t pad_w, + const index_t stride_h, const index_t stride_w, + const index_t dilation_h, const index_t dilation_w, + const index_t channel_per_group, + const index_t height_col, const index_t width_col, + DType* grad_im) { CUDA_KERNEL_LOOP(index, n) { - const int j = (index / width_col / height_col) % kernel_w; - const int i = (index / width_col / height_col / kernel_w) % kernel_h; - const int c = index / width_col / height_col / kernel_w / kernel_h; + const index_t j = (index / width_col / height_col) % kernel_w; + const index_t i = (index / width_col / height_col / kernel_w) % kernel_h; + const index_t c = index / width_col / height_col / kernel_w / kernel_h; // compute the start and end of the output - const int deformable_group_index = c / channel_per_deformable_group; + const index_t group_index = c / channel_per_group; + const index_t group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col; - int w_out = index % width_col; - int h_out = (index / width_col) % height_col; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; + index_t w_col = index % width_col; + index_t h_col = (index / width_col) % height_col; + index_t w_in = w_col * stride_w - pad_w; + 
index_t h_in = h_col * stride_h - pad_h; - const DType* data_offset_ptr = data_offset + deformable_group_index * 2 * kernel_h * kernel_w * height_col * width_col; - const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; - const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const DType* data_offset_ptr = data_offset + group_index * group_offset_step; + const index_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * + height_col + h_col) * width_col + w_col; + const index_t data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; const DType offset_h = data_offset_ptr[data_offset_h_ptr]; const DType offset_w = data_offset_ptr[data_offset_w_ptr]; const DType cur_inv_h_data = h_in + i * dilation_h + offset_h; const DType cur_inv_w_data = w_in + j * dilation_w + offset_w; const DType cur_top_grad = data_col[index]; - const int cur_h = (int)cur_inv_h_data; - const int cur_w = (int)cur_inv_w_data; + const index_t cur_h = static_cast(cur_inv_h_data); + const index_t cur_w = static_cast(cur_inv_w_data); for (int dy = -2; dy <= 2; dy++) { for (int dx = -2; dx <= 2; dx++) { if (cur_h + dy >= 0 && cur_h + dy < height && @@ -354,8 +370,9 @@ __global__ void deformable_col2im_gpu_kernel(const int n, const DType* data_col, abs(cur_inv_h_data - (cur_h + dy)) < 1 && abs(cur_inv_w_data - (cur_w + dx)) < 1 ) { - int cur_bottom_grad_pos = (c * height + cur_h + dy) * width + cur_w + dx; - DType weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + index_t cur_bottom_grad_pos = (c * height + cur_h + dy) * width + cur_w + dx; + DType weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, + cur_h + dy, cur_w + dx, height, width); atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); } } @@ -380,14 +397,19 @@ __global__ void deformable_col2im_gpu_kernel(const int n, const DType* data_col, */ template inline void 
deformable_col2im(mshadow::Stream* s, - const DType* data_col, const DType* data_offset, - const mxnet::TShape& im_shape, const mxnet::TShape& col_shape, const mxnet::TShape& kernel_shape, - const mxnet::TShape& pad, const mxnet::TShape& stride, - const mxnet::TShape& dilation, const uint32_t deformable_group, - DType* grad_im, OpReqType req) { - index_t num_spatial_axes = kernel_shape.ndim(); + const DType* data_col, + const DType* data_offset, + const mxnet::TShape& im_shape, + const mxnet::TShape& col_shape, + const mxnet::TShape& kernel_shape, + const mxnet::TShape& pad, + const mxnet::TShape& stride, + const mxnet::TShape& dilation, + const index_t deformable_group, + DType* grad_im) { + const int num_spatial_axes = kernel_shape.ndim(); index_t im_size = im_shape.ProdShape(1, im_shape.ndim()); - index_t channel_per_deformable_group = im_shape[1] / deformable_group; + index_t channel_per_group = im_shape[1] / deformable_group; index_t num_kernels = col_shape.ProdShape(0, col_shape.ndim()); // num_axes should be smaller than block size CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum); @@ -397,11 +419,15 @@ inline void deformable_col2im(mshadow::Stream* s, // To avoid involving atomic operations, we will launch one kernel per // bottom dimension, and then in the kernel add up the top dimensions. 
// NOLINT_NEXT_LINE(whitespace/operators) - deformable_col2im_gpu_kernel<<::GetStream(s)>>>( - num_kernels, data_col, data_offset, im_shape[1], im_shape[2], im_shape[3], - kernel_shape[0], kernel_shape[1], pad[0], pad[1], stride[0], stride[1], - dilation[0], dilation[1], channel_per_deformable_group, col_shape[1], col_shape[2], grad_im, req); + deformable_col2im_gpu_kernel + <<::GetStream(s)>>>(num_kernels, data_col, data_offset, + im_shape[1], im_shape[2], im_shape[3], + kernel_shape[0], kernel_shape[1], + pad[0], pad[1], stride[0], stride[1], + dilation[0], dilation[1], + channel_per_group, + col_shape[1], col_shape[2], grad_im); MSHADOW_CUDA_POST_KERNEL_CHECK(deformable_col2im_gpu_kernel); break; default: @@ -416,44 +442,50 @@ inline void deformable_col2im(mshadow::Stream* s, * \brief DO NOT call this directly. Use wrapper function deformable_col2im_coord() instead; */ template -__global__ void deformable_col2im_coord_gpu_kernel(const int n, const DType* data_col, - const DType* data_im, const DType* data_offset, - const int channels, const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, - const int height_col, const int width_col, - DType* grad_offset, OpReqType req) { +__global__ void deformable_col2im_coord_gpu_kernel(const index_t n, const DType* data_col, + const DType* data_im, + const DType* data_offset, + const index_t channels, + const index_t height, const index_t width, + const index_t kernel_h, const index_t kernel_w, + const index_t pad_h, const index_t pad_w, + const index_t stride_h, const index_t stride_w, + const index_t dilation_h, const index_t dilation_w, + const index_t channel_per_group, + const index_t height_col, const index_t width_col, + DType* grad_offset) { CUDA_KERNEL_LOOP(index, n) { DType val = 0; - int w = index % width_col; - int h = (index / 
width_col) % height_col; - int c = index / width_col / height_col; + index_t w = index % width_col; + index_t h = (index / width_col) % height_col; + index_t c = index / width_col / height_col; // compute the start and end of the output - const int deformable_group_index = c / (2 * kernel_h * kernel_w); - const int col_step = kernel_h * kernel_w; - int cnt = 0; - const DType* data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * width_col * height_col; - const DType* data_im_ptr = data_im + deformable_group_index * channel_per_deformable_group / kernel_h / kernel_w * height * width; - const DType* data_offset_ptr = data_offset + deformable_group_index * 2 * kernel_h * kernel_w * height_col * width_col; - - const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; - - for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) { - const int col_pos = ((col_c * height_col) + h) * width_col + w; - const int bp_dir = offset_c % 2; - - int j = (col_pos / width_col / height_col) % kernel_w; - int i = (col_pos / width_col / height_col / kernel_w) % kernel_h; - int w_out = col_pos % width_col; - int h_out = (col_pos / width_col) % height_col; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); - const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const index_t group_index = c / (2 * kernel_h * kernel_w); + const index_t group_col_step = channel_per_group * width_col * height_col; + const index_t group_im_step = channel_per_group / kernel_h / kernel_w * height * width; + const index_t group_offset_step = 2 * kernel_h * kernel_w * height_col * width_col; + const index_t col_step = kernel_h * kernel_w; + const DType* data_col_ptr = data_col + group_index * group_col_step; + const DType* data_im_ptr = data_im + group_index 
* group_im_step; + const DType* data_offset_ptr = data_offset + group_index * group_offset_step; + + index_t cnt = 0; + const index_t offset_c = c - group_index * 2 * kernel_h * kernel_w; + + for (index_t col_c = (offset_c / 2); col_c < channel_per_group; col_c += col_step) { + const index_t col_pos = ((col_c * height_col) + h) * width_col + w; + const index_t bp_dir = offset_c % 2; + + index_t j = (col_pos / width_col / height_col) % kernel_w; + index_t i = (col_pos / width_col / height_col / kernel_w) % kernel_h; + index_t w_col = col_pos % width_col; + index_t h_col = (col_pos / width_col) % height_col; + index_t w_in = w_col * stride_w - pad_w; + index_t h_in = h_col * stride_h - pad_h; + const index_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * + height_col + h_col) * width_col + w_col; + const index_t data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; const DType offset_h = data_offset_ptr[data_offset_h_ptr]; const DType offset_w = data_offset_ptr[data_offset_w_ptr]; DType inv_h = h_in + i * dilation_h + offset_h; @@ -461,9 +493,9 @@ __global__ void deformable_col2im_coord_gpu_kernel(const int n, const DType* dat if (inv_h < 0 || inv_w < 0 || inv_h >= height || inv_w >= width) { inv_h = inv_w = -1; } - const DType weight = get_coordinate_weight( - inv_h, inv_w, - height, width, data_im_ptr + cnt * height * width, width, bp_dir); + const DType weight = get_coordinate_weight(inv_h, inv_w, height, width, + data_im_ptr + cnt * height * width, + width, bp_dir); val += weight * data_col_ptr[col_pos]; cnt += 1; } @@ -472,6 +504,7 @@ __global__ void deformable_col2im_coord_gpu_kernel(const int n, const DType* dat } } + /*!\brief * gpu function of deformable_col2im_coord algorithm * \param s device stream @@ -489,13 +522,21 @@ __global__ void deformable_col2im_coord_gpu_kernel(const int n, const DType* dat */ template inline void deformable_col2im_coord(mshadow::Stream* s, - const DType* data_col, const DType* data_im, const DType* data_offset, const 
mxnet::TShape& im_shape, - const mxnet::TShape& col_shape, const mxnet::TShape& kernel_shape, - const mxnet::TShape& pad, const mxnet::TShape& stride, - const mxnet::TShape& dilation, const uint32_t deformable_group, DType* grad_offset, OpReqType req) { - index_t num_spatial_axes = kernel_shape.ndim(); - index_t num_kernels = col_shape[1] * col_shape[2] * 2 * kernel_shape[0] * kernel_shape[1] * deformable_group; - index_t channel_per_deformable_group = col_shape[0] / deformable_group; + const DType* data_col, + const DType* data_im, + const DType* data_offset, + const mxnet::TShape& im_shape, + const mxnet::TShape& col_shape, + const mxnet::TShape& kernel_shape, + const mxnet::TShape& pad, + const mxnet::TShape& stride, + const mxnet::TShape& dilation, + const index_t deformable_group, + DType* grad_offset) { + const int num_spatial_axes = kernel_shape.ndim(); + index_t num_kernels = col_shape[1] * col_shape[2] * 2 * + kernel_shape[0] * kernel_shape[1] * deformable_group; + index_t channel_per_group = col_shape[0] / deformable_group; // num_axes should be smaller than block size CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum); using namespace mxnet_op; @@ -504,12 +545,15 @@ inline void deformable_col2im_coord(mshadow::Stream* s, // To avoid involving atomic operations, we will launch one kernel per // bottom dimension, and then in the kernel add up the top dimensions. 
// NOLINT_NEXT_LINE(whitespace/operators) - - deformable_col2im_coord_gpu_kernel << ::GetStream(s) >> >( - num_kernels, data_col, data_im, data_offset, im_shape[1], im_shape[2], im_shape[3], - kernel_shape[0], kernel_shape[1], pad[0], pad[1], stride[0], stride[1], - dilation[0], dilation[1], channel_per_deformable_group, col_shape[1], col_shape[2], grad_offset, req); + deformable_col2im_coord_gpu_kernel + <<::GetStream(s)>>>(num_kernels, data_col, data_im, data_offset, + im_shape[1], im_shape[2], im_shape[3], + kernel_shape[0], kernel_shape[1], + pad[0], pad[1], stride[0], stride[1], + dilation[0], dilation[1], + channel_per_group, + col_shape[1], col_shape[2], grad_offset); MSHADOW_CUDA_POST_KERNEL_CHECK(deformable_col2im_coord_gpu_kernel); break; default: diff --git a/src/operator/contrib/nn/deformable_im2col.h b/src/operator/contrib/nn/deformable_im2col.h index 1f96fe5b2366..3f42668b86be 100644 --- a/src/operator/contrib/nn/deformable_im2col.h +++ b/src/operator/contrib/nn/deformable_im2col.h @@ -65,11 +65,197 @@ #include #include #include +#include #include "../../mxnet_op.h" namespace mxnet { namespace op { +template +inline DType im2col_bilinear_cpu(const DType* data, + const index_t height, + const index_t width, + DType h, DType w) { + index_t h_low = floor(h); + index_t w_low = floor(w); + index_t h_high; + index_t w_high; + + if (h_low >= height - 1) { + h_high = height - 1; + h = static_cast(h_low); + } else { + h_high = h_low + 1; + } + + if (w_low >= width - 1) { + w_high = width - 1; + w = static_cast(w_low); + } else { + w_high = w_low + 1; + } + + DType lh = h - h_low; + DType lw = w - w_low; + DType hh = 1 - lh, hw = 1 - lw; + + DType v1 = data[h_low * width + w_low]; + DType v2 = data[h_low * width + w_high]; + DType v3 = data[h_high * width + w_low]; + DType v4 = data[h_high * width + w_high]; + DType w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; +} + + +template +inline DType 
get_gradient_weight_cpu(DType argmax_h, DType argmax_w, + const index_t h, const index_t w, + const index_t height, const index_t width) { + if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) { + // empty + return 0; + } + + argmax_h = std::max(argmax_h, static_cast(0.0f)); + argmax_w = std::max(argmax_w, static_cast(0.0f)); + + index_t argmax_h_low = static_cast(argmax_h); + index_t argmax_w_low = static_cast(argmax_w); + index_t argmax_h_high; + index_t argmax_w_high; + if (argmax_h_low >= height - 1) { + argmax_h_high = argmax_h_low = height - 1; + argmax_h = static_cast(argmax_h_low); + } else { + argmax_h_high = argmax_h_low + 1; + } + if (argmax_w_low >= width - 1) { + argmax_w_high = argmax_w_low = width - 1; + argmax_w = static_cast(argmax_w_low); + } else { + argmax_w_high = argmax_w_low + 1; + } + DType weight = 0; + if (h == argmax_h_low) { + if (w == argmax_w_low) { + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + } else if (w == argmax_w_high) { + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + } + } else if (h == argmax_h_high) { + if (w == argmax_w_low) { + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + } else if (w == argmax_w_high) { + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + } + } + return weight; +} + + +template +inline DType get_coordinate_weight_cpu(DType argmax_h, DType argmax_w, + const index_t height, const index_t width, + const DType* im_data, + const index_t data_width, const index_t bp_dir) { + if (argmax_h < 0 || argmax_h > height || argmax_w < 0 || argmax_w > width) { + // empty + return 0; + } + + if (argmax_h < 0) argmax_h = 0; + if (argmax_w < 0) argmax_w = 0; + + index_t argmax_h_low = static_cast(argmax_h); + index_t argmax_w_low = static_cast(argmax_w); + index_t argmax_h_high; + index_t argmax_w_high; + if (argmax_h_low >= height - 1) { + argmax_h_high = argmax_h_low = height - 1; + argmax_h = static_cast(argmax_h_low); + } else { + argmax_h_high = argmax_h_low + 1; + } + if 
(argmax_w_low >= width - 1) { + argmax_w_high = argmax_w_low = width - 1; + argmax_w = static_cast(argmax_w_low); + } else { + argmax_w_high = argmax_w_low + 1; + } + + DType weight = 0; + DType im_ll = im_data[argmax_h_low * data_width + argmax_w_low]; + DType im_lh = im_data[argmax_h_low * data_width + argmax_w_high]; + DType im_hl = im_data[argmax_h_high * data_width + argmax_w_low]; + DType im_hh = im_data[argmax_h_high * data_width + argmax_w_high]; + if (bp_dir == 0) { + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_ll; + weight += -1 * (argmax_w - argmax_w_low) * im_lh; + weight += (argmax_w_low + 1 - argmax_w) * im_hl; + weight += (argmax_w - argmax_w_low) * im_hh; + } else if (bp_dir == 1) { + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_ll; + weight += (argmax_h_low + 1 - argmax_h) * im_lh; + weight += -1 * (argmax_h - argmax_h_low) * im_hl; + weight += (argmax_h - argmax_h_low) * im_hh; + } + + return weight; +} + + +/*! + * \brief deformable_im2col 2D cpu version. + * DO NOT call this function directly. + * Use the wrapper function im2col() instead. 
+ */ +template +inline void deformable_im2col_cpu(const DType* data_im, + const DType* data_offset, + const index_t channels, + const index_t height, const index_t width, + const index_t kernel_h, const index_t kernel_w, + const index_t pad_h, const index_t pad_w, + const index_t stride_h, const index_t stride_w, + const index_t dilation_h, const index_t dilation_w, + const index_t deformable_group, + const index_t height_col, const index_t width_col, + DType* data_col) { + const index_t channel_size = height * width; + const index_t offset_size = 2 * kernel_h * kernel_w * height_col * width_col; + const index_t channel_per_group = channels / deformable_group; + for (index_t channel = 0; channel < channels; channel++, data_im += channel_size) { + if (channel % channel_per_group == 0 && channel != 0) { + data_offset += offset_size; + } + for (index_t i = 0; i < kernel_h; i++) { + for (index_t j = 0; j < kernel_w; j++) { + index_t input_row = -pad_h + i * dilation_h; + for (index_t h_col = 0; h_col < height_col; h_col++) { + index_t input_col = -pad_w + j * dilation_w; + for (index_t w_col = 0; w_col < width_col; w_col++) { + index_t offset_h_ptr = ((2 * (i * kernel_w + j)) * + height_col + h_col) * width_col + w_col; + index_t offset_w_ptr = offset_h_ptr + height_col * width_col; + DType im_row = input_row + data_offset[offset_h_ptr]; + DType im_col = input_col + data_offset[offset_w_ptr]; + if (im_row >= 0 && im_col >= 0 && im_row < height && im_col < width) { + *(data_col++) = im2col_bilinear_cpu(data_im, height, width, im_row, im_col); + } else { + *(data_col++) = 0; + } + input_col += stride_w; + } + input_row += stride_h; + } + } + } + } +} + + /*!\brief * cpu function of deformable_im2col algorithm * \param s device stream @@ -86,18 +272,92 @@ namespace op { */ template inline void deformable_im2col(mshadow::Stream* s, - const DType* data_im, const DType* data_offset, - const mxnet::TShape& im_shape, const mxnet::TShape& col_shape, const mxnet::TShape& 
kernel_shape, - const mxnet::TShape& pad, const mxnet::TShape& stride, const mxnet::TShape& dilation, - const uint32_t deformable_group, DType* data_col) { + const DType* data_im, const DType* data_offset, + const mxnet::TShape& im_shape, + const mxnet::TShape& col_shape, + const mxnet::TShape& kernel_shape, + const mxnet::TShape& pad, + const mxnet::TShape& stride, + const mxnet::TShape& dilation, + const index_t deformable_group, + DType* data_col) { if (2 == kernel_shape.ndim()) { - LOG(FATAL) << "only implemented in GPU"; + deformable_im2col_cpu(data_im, data_offset, + im_shape[1], im_shape[2], im_shape[3], + kernel_shape[0], kernel_shape[1], + pad[0], pad[1], + stride[0], stride[1], + dilation[0], dilation[1], + deformable_group, + col_shape[1], col_shape[2], data_col); } else { LOG(FATAL) << "not implemented"; } } +/*! + * \brief deformable_col2im cpu version. + * DO NOT call this directly. + * Use wrapper function deformable_col2im() instead; + */ +template +inline void deformable_col2im_cpu(const DType* data_col, + const DType* data_offset, const index_t channels, + const index_t height, const index_t width, + const index_t kernel_h, const index_t kernel_w, + const index_t pad_h, const index_t pad_w, + const index_t stride_h, const index_t stride_w, + const index_t dilation_h, const index_t dilation_w, + const index_t deformable_group, + const index_t height_col, const index_t width_col, + DType* grad_im) { + index_t channel_per_group = channels / deformable_group; + index_t count = channels * kernel_h * kernel_w * height_col * width_col; + for (index_t index = 0; index < count; ++index) { + const index_t j = (index / width_col / height_col) % kernel_w; + const index_t i = (index / width_col / height_col / kernel_w) % kernel_h; + const index_t c = index / width_col / height_col / kernel_w / kernel_h; + // compute the start and end of the output + + const index_t group_index = c / channel_per_group; + const index_t group_offset_step = 2 * kernel_h * kernel_w 
* height_col * width_col; + + index_t w_col = index % width_col; + index_t h_col = (index / width_col) % height_col; + index_t w_in = w_col * stride_w - pad_w; + index_t h_in = h_col * stride_h - pad_h; + + const DType* data_offset_ptr = data_offset + group_index * group_offset_step; + const index_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * + height_col + h_col) * width_col + w_col; + const index_t data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; + const DType offset_h = data_offset_ptr[data_offset_h_ptr]; + const DType offset_w = data_offset_ptr[data_offset_w_ptr]; + const DType cur_inv_h_data = h_in + i * dilation_h + offset_h; + const DType cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const DType cur_top_grad = data_col[index]; + const index_t cur_h = static_cast(cur_inv_h_data); + const index_t cur_w = static_cast(cur_inv_w_data); + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + std::abs(cur_inv_h_data - (cur_h + dy)) < 1 && + std::abs(cur_inv_w_data - (cur_w + dx)) < 1 + ) { + index_t cur_bottom_grad_pos = (c * height + cur_h + dy) * width + cur_w + dx; + DType weight = get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data, + cur_h + dy, cur_w + dx, height, width); + grad_im[cur_bottom_grad_pos] += weight * cur_top_grad; + } + } + } + } +} + + /*!\brief * cpu function of deformable_col2im algorithm * \param s device stream @@ -114,12 +374,98 @@ inline void deformable_im2col(mshadow::Stream* s, */ template inline void deformable_col2im(mshadow::Stream* s, - const DType* data_col, const DType* data_offset, - const mxnet::TShape& im_shape, const mxnet::TShape& col_shape, const mxnet::TShape& kernel_shape, - const mxnet::TShape& pad, const mxnet::TShape& stride, - const mxnet::TShape& dilation, const uint32_t deformable_group, - DType* grad_im, OpReqType req) { - LOG(FATAL) << "only implemented in GPU"; + const 
DType* data_col, + const DType* data_offset, + const mxnet::TShape& im_shape, + const mxnet::TShape& col_shape, + const mxnet::TShape& kernel_shape, + const mxnet::TShape& pad, + const mxnet::TShape& stride, + const mxnet::TShape& dilation, + const index_t deformable_group, + DType* grad_im) { + if (2 == kernel_shape.ndim()) { + deformable_col2im_cpu(data_col, data_offset, + im_shape[1], im_shape[2], im_shape[3], + kernel_shape[0], kernel_shape[1], + pad[0], pad[1], stride[0], stride[1], + dilation[0], dilation[1], + deformable_group, + col_shape[1], col_shape[2], grad_im); + } else { + LOG(FATAL) << "not implemented"; + } +} + + +/*! + * \brief deformable_col2im_coord cpu version. + * DO NOT call this directly. + * Use wrapper function deformable_col2im_coord() instead; + */ +template +inline void deformable_col2im_coord_cpu(const DType* data_col, + const DType* data_im, + const DType* data_offset, + const index_t channels, + const index_t height, const index_t width, + const index_t kernel_h, const index_t kernel_w, + const index_t pad_h, const index_t pad_w, + const index_t stride_h, const index_t stride_w, + const index_t dilation_h, const index_t dilation_w, + const index_t deformable_group, + const index_t height_col, const index_t width_col, + DType* grad_offset) { + index_t channel_per_group = channels * kernel_h * kernel_w / deformable_group; + index_t count = height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + for (index_t index = 0; index < count; ++index) { + DType val = 0; + index_t w = index % width_col; + index_t h = (index / width_col) % height_col; + index_t c = index / width_col / height_col; + // compute the start and end of the output + + const index_t group_index = c / (2 * kernel_h * kernel_w); + const index_t group_col_step = channel_per_group * width_col * height_col; + const index_t group_im_step = channel_per_group / kernel_h / kernel_w * height * width; + const index_t group_offset_step = 2 * kernel_h * kernel_w * 
height_col * width_col; + const index_t col_step = kernel_h * kernel_w; + const DType* data_col_ptr = data_col + group_index * group_col_step; + const DType* data_im_ptr = data_im + group_index * group_im_step; + const DType* data_offset_ptr = data_offset + group_index * group_offset_step; + + index_t cnt = 0; + const index_t offset_c = c - group_index * 2 * kernel_h * kernel_w; + + for (index_t col_c = (offset_c / 2); col_c < channel_per_group; col_c += col_step) { + const index_t col_pos = ((col_c * height_col) + h) * width_col + w; + const index_t bp_dir = offset_c % 2; + + index_t j = (col_pos / width_col / height_col) % kernel_w; + index_t i = (col_pos / width_col / height_col / kernel_w) % kernel_h; + index_t w_col = col_pos % width_col; + index_t h_col = (col_pos / width_col) % height_col; + index_t w_in = w_col * stride_w - pad_w; + index_t h_in = h_col * stride_h - pad_h; + const index_t data_offset_h_ptr = ((2 * (i * kernel_w + j)) * + height_col + h_col) * width_col + w_col; + const index_t data_offset_w_ptr = data_offset_h_ptr + height_col * width_col; + const DType offset_h = data_offset_ptr[data_offset_h_ptr]; + const DType offset_w = data_offset_ptr[data_offset_w_ptr]; + DType inv_h = h_in + i * dilation_h + offset_h; + DType inv_w = w_in + j * dilation_w + offset_w; + if (inv_h < 0 || inv_w < 0 || inv_h >= height || inv_w >= width) { + inv_h = inv_w = -1; + } + const DType weight = get_coordinate_weight_cpu(inv_h, inv_w, height, width, + data_im_ptr + cnt * height * width, + width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } } @@ -138,16 +484,30 @@ inline void deformable_col2im(mshadow::Stream* s, * \param deformable_group #offset group that deformable convolution use * \param grad_offset pointer of the offset (C, H, W,...) 
in the offset batch */ - template inline void deformable_col2im_coord(mshadow::Stream* s, - const DType* data_col, const DType* data_im, - const DType* data_offset, const mxnet::TShape& im_shape, - const mxnet::TShape& col_shape, const mxnet::TShape& kernel_shape, - const mxnet::TShape& pad, const mxnet::TShape& stride, - const mxnet::TShape& dilation, const uint32_t deformable_group, - DType* grad_offset, OpReqType req) { - LOG(FATAL) << "only implemented in GPU"; + const DType* data_col, + const DType* data_im, + const DType* data_offset, + const mxnet::TShape& im_shape, + const mxnet::TShape& col_shape, + const mxnet::TShape& kernel_shape, + const mxnet::TShape& pad, + const mxnet::TShape& stride, + const mxnet::TShape& dilation, + const index_t deformable_group, + DType* grad_offset) { + if (2 == kernel_shape.ndim()) { + deformable_col2im_coord_cpu(data_col, data_im, data_offset, + im_shape[1], im_shape[2], im_shape[3], + kernel_shape[0], kernel_shape[1], + pad[0], pad[1], stride[0], stride[1], + dilation[0], dilation[1], + deformable_group, + col_shape[1], col_shape[2], grad_offset); + } else { + LOG(FATAL) << "not implemented"; + } } } // namespace op diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 710686da9e7c..9c88dc15488c 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1656,10 +1656,14 @@ def test_deformable_convolution_with_type(): 'deformable_conv_data': (2, 2, 10, 10), 'deformable_conv_offset': (2, 18, 8, 8), 'type_dict': {'deformable_conv_data': np.float32, 'deformable_conv_offset': np.float32}}, - # {'ctx': mx.gpu(0), - # 'deformable_conv_data': (2, 2, 10, 10), - # 'deformable_conv_offset': (2, 18, 8, 8), - # 'type_dict': {'deformable_conv_data': np.float16, 'deformable_conv_offset': np.float16}}, + {'ctx': mx.cpu(0), + 'deformable_conv_data': (2, 2, 10, 10), + 'deformable_conv_offset': (2, 18, 8, 8), + 'type_dict': {'deformable_conv_data': np.float64, 
'deformable_conv_offset': np.float64}}, + {'ctx': mx.cpu(0), + 'deformable_conv_data': (2, 2, 10, 10), + 'deformable_conv_offset': (2, 18, 8, 8), + 'type_dict': {'deformable_conv_data': np.float32, 'deformable_conv_offset': np.float32}}, ] check_consistency(sym, ctx_list, scale=0.1, tol=tol) @@ -1676,9 +1680,9 @@ def test_deformable_convolution_options(): tol = {np.dtype(np.float32): 1e-1, np.dtype(np.float64): 1e-3} # 2D convolution + # since atomicAdd does not support fp16 (which deformable conv uses in backward), we do not test fp16 here # Pad > 0 - # since atomicAdd does not support fp16 (which deformable conv uses in backward), we do not test fp16 here ctx_list = [{'ctx': mx.gpu(0), 'deformable_conv_data': (2, 2, 7, 7), 'deformable_conv_offset': (2, 18, 7, 7), @@ -1687,12 +1691,19 @@ def test_deformable_convolution_options(): 'deformable_conv_data': (2, 2, 7, 7), 'deformable_conv_offset': (2, 18, 7, 7), 'type_dict': {'deformable_conv_data': np.float32, 'deformable_conv_offset': np.float32}}, + {'ctx': mx.cpu(0), + 'deformable_conv_data': (2, 2, 7, 7), + 'deformable_conv_offset': (2, 18, 7, 7), + 'type_dict': {'deformable_conv_data': np.float64, 'deformable_conv_offset': np.float64}}, + {'ctx': mx.cpu(0), + 'deformable_conv_data': (2, 2, 7, 7), + 'deformable_conv_offset': (2, 18, 7, 7), + 'type_dict': {'deformable_conv_data': np.float32, 'deformable_conv_offset': np.float32}}, ] sym = mx.sym.contrib.DeformableConvolution(num_filter=3, kernel=(3,3), pad=(1,1), name='deformable_conv') check_consistency(sym, ctx_list, scale=0.1, tol=tol) # Stride > 1 - # since atomicAdd does not support fp16 (which deformable conv uses in backward), we do not test fp16 here ctx_list = [{'ctx': mx.gpu(0), 'deformable_conv_data': (2, 2, 7, 7), 'deformable_conv_offset': (2, 18, 3, 3), @@ -1701,12 +1712,19 @@ def test_deformable_convolution_options(): 'deformable_conv_data': (2, 2, 7, 7), 'deformable_conv_offset': (2, 18, 3, 3), 'type_dict': {'deformable_conv_data': np.float32, 
'deformable_conv_offset': np.float32}}, + {'ctx': mx.cpu(0), + 'deformable_conv_data': (2, 2, 7, 7), + 'deformable_conv_offset': (2, 18, 3, 3), + 'type_dict': {'deformable_conv_data': np.float64, 'deformable_conv_offset': np.float64}}, + {'ctx': mx.cpu(0), + 'deformable_conv_data': (2, 2, 7, 7), + 'deformable_conv_offset': (2, 18, 3, 3), + 'type_dict': {'deformable_conv_data': np.float32, 'deformable_conv_offset': np.float32}}, ] sym = mx.sym.contrib.DeformableConvolution(num_filter=3, kernel=(3,3), stride=(2,2), name='deformable_conv') check_consistency(sym, ctx_list, scale=0.1, tol=tol) # Dilate > 1 - # since atomicAdd does not support fp16 (which deformable conv uses in backward), we do not test fp16 here ctx_list = [{'ctx': mx.gpu(0), 'deformable_conv_data': (2, 2, 7, 7), 'deformable_conv_offset': (2, 18, 3, 3), @@ -1715,12 +1733,19 @@ def test_deformable_convolution_options(): 'deformable_conv_data': (2, 2, 7, 7), 'deformable_conv_offset': (2, 18, 3, 3), 'type_dict': {'deformable_conv_data': np.float32, 'deformable_conv_offset': np.float32}}, + {'ctx': mx.cpu(0), + 'deformable_conv_data': (2, 2, 7, 7), + 'deformable_conv_offset': (2, 18, 3, 3), + 'type_dict': {'deformable_conv_data': np.float64, 'deformable_conv_offset': np.float64}}, + {'ctx': mx.cpu(0), + 'deformable_conv_data': (2, 2, 7, 7), + 'deformable_conv_offset': (2, 18, 3, 3), + 'type_dict': {'deformable_conv_data': np.float32, 'deformable_conv_offset': np.float32}}, ] sym = mx.sym.contrib.DeformableConvolution(num_filter=3, kernel=(3,3), dilate=(2,2), name='deformable_conv') check_consistency(sym, ctx_list, scale=0.1, tol=tol) # Deformable group > 1 - # since atomicAdd does not support fp16 (which deformable conv uses in backward), we do not test fp16 here ctx_list = [{'ctx': mx.gpu(0), 'deformable_conv_data': (2, 2, 7, 7), 'deformable_conv_offset': (2, 36, 5, 5), @@ -1729,13 +1754,18 @@ def test_deformable_convolution_options(): 'deformable_conv_data': (2, 2, 7, 7), 'deformable_conv_offset': (2, 
36, 5, 5), 'type_dict': {'deformable_conv_data': np.float32, 'deformable_conv_offset': np.float32}}, - # {'ctx': mx.gpu(0), - # 'deformable_conv_data': (2, 2, 7, 7), - # 'deformable_conv_offset': (2, 36, 5, 5), - # 'type_dict': {'deformable_conv_data': np.float16, 'deformable_offset': np.float16}}, + {'ctx': mx.cpu(0), + 'deformable_conv_data': (2, 2, 7, 7), + 'deformable_conv_offset': (2, 36, 5, 5), + 'type_dict': {'deformable_conv_data': np.float64, 'deformable_conv_offset': np.float64}}, + {'ctx': mx.cpu(0), + 'deformable_conv_data': (2, 2, 7, 7), + 'deformable_conv_offset': (2, 36, 5, 5), + 'type_dict': {'deformable_conv_data': np.float32, 'deformable_conv_offset': np.float32}}, ] - sym = mx.sym.contrib.DeformableConvolution(num_filter=4, kernel=(3,3), num_deformable_group=2, - name='deformable_conv') + sym = mx.sym.contrib.DeformableConvolution(num_filter=4, kernel=(3,3), num_deformable_group=2, name='deformable_conv') + check_consistency(sym, ctx_list, scale=0.1, tol=tol) + @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10')