
Fix spatial transformer op
anirudh2290 committed Jul 18, 2018
1 parent 5b4d528 commit a224fee
Showing 3 changed files with 11 additions and 14 deletions.
15 changes: 9 additions & 6 deletions src/operator/spatial_transformer.cu
@@ -110,20 +110,21 @@ __global__ void BilinearSamplingBackwardKernel(const int i_c, const int i_h,
     DType bottom_right_v = 0;
     // calc input grad
     if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
-      *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
+      atomicAdd((g_input + data_index), *(grad + grad_index) * top_left_y_w * top_left_x_w);
       top_left_v = *(data + data_index);
     }
     if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
-      *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w);
+      atomicAdd((g_input + data_index + 1),
+                *(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w));
       top_right_v = *(data + data_index + 1);
     }
     if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
-      *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w;
       bottom_left_v = *(data + data_index + i_w);
+      atomicAdd((g_input + data_index + i_w),
+                *(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w);
     }
     if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
-      *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w) *
-                                          (1.0 - top_left_x_w);
+      atomicAdd((g_input + data_index + i_w + 1),
+                *(grad + grad_index) * (1.0 - top_left_y_w) * (1.0 - top_left_x_w));
       bottom_right_v = *(data + data_index + i_w + 1);
     }
     // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
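
Why atomicAdd is the right tool here: the backward kernel assigns one thread per output element, and bilinear sampling means adjacent output pixels read from, and propagate gradient to, overlapping input pixels. Two threads can therefore perform the read-modify-write `+=` on the same g_input address at the same time, and one of the updates is silently lost. atomicAdd makes each update indivisible. A minimal standalone sketch of the failure mode (illustrative only, not part of this commit):

#include <cstdio>
#include <cuda_runtime.h>

// 65536 threads all accumulate into one float, mimicking many output
// pixels contributing gradient to one shared input pixel.
__global__ void racy_accumulate(float *sum) {
  *sum += 1.0f;          // non-atomic read-modify-write: updates get lost
}

__global__ void atomic_accumulate(float *sum) {
  atomicAdd(sum, 1.0f);  // hardware-serialized: every update lands
}

int main() {
  float *sum;
  cudaMallocManaged(&sum, sizeof(float));

  *sum = 0.0f;
  racy_accumulate<<<256, 256>>>(sum);
  cudaDeviceSynchronize();
  printf("+=        -> %.0f (expected 65536, typically far less)\n", *sum);

  *sum = 0.0f;
  atomic_accumulate<<<256, 256>>>(sum);
  cudaDeviceSynchronize();
  printf("atomicAdd -> %.0f (expected 65536)\n", *sum);

  cudaFree(sum);
  return 0;
}

One caveat worth noting: CUDA's native atomicAdd on double only exists on compute capability 6.0 and newer, so double-precision DType here presumes a suitable target architecture or a fallback definition.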
@@ -157,6 +158,7 @@ inline void BilinearSamplingForward(const Tensor<gpu, 4, DType> &output,
   cudaStream_t stream = Stream<gpu>::GetStream(output.stream_);
   BilinearSamplingForwardKernel<DType> << <num_blocks, threads_per_block, 0, stream >> >(
       i_c, i_h, i_w, data, grid, o_n, o_c, o_h, o_w, out);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(BilinearSamplingForwardKernel);
 }

template<typename DType>
@@ -180,6 +182,7 @@ inline void BilinearSamplingBackward(const Tensor<gpu, 4, DType> &input_grad,
   cudaStream_t stream = Stream<gpu>::GetStream(input_grad.stream_);
   BilinearSamplingBackwardKernel<DType> << <num_blocks, threads_per_block, 0, stream >> >(
       i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(BilinearSamplingBackwardKernel);
 }

 }  // namespace mshadow
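
The second part of the fix adds MSHADOW_CUDA_POST_KERNEL_CHECK after both launches. CUDA kernel launches are asynchronous and return no error code, so a failed launch otherwise surfaces only at the next CUDA call, far from the cause. The macro itself is defined in mshadow; the sketch below shows the general pattern such a check implements (the POST_KERNEL_CHECK macro is a hypothetical stand-in, not mshadow's definition):

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Generic post-launch check: pick up the launch error immediately and
// report which kernel it belongs to. Sketch only; mshadow's macro
// reports through its own CHECK/logging machinery.
#define POST_KERNEL_CHECK(kernel)                                    \
  do {                                                               \
    cudaError_t err = cudaPeekAtLastError();                         \
    if (err != cudaSuccess) {                                        \
      fprintf(stderr, "%s: %s\n", #kernel, cudaGetErrorString(err)); \
      exit(EXIT_FAILURE);                                            \
    }                                                                \
  } while (0)

__global__ void noop_kernel() {}

int main() {
  // 4096 threads per block exceeds the hardware limit (1024), so the
  // launch fails; without the check the error would go unnoticed here.
  noop_kernel<<<1, 4096>>>();
  POST_KERNEL_CHECK(noop_kernel);
  return 0;
}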
7 changes: 2 additions & 5 deletions tests/python/unittest/common.py
@@ -95,18 +95,15 @@ def random_seed(seed=None):
     random.seed(next_seed)


-def assert_raises_cudnn_disabled(assertion_error=False):
+def assert_raises_cudnn_disabled():
     def test_helper(orig_test):
         @make_decorator(orig_test)
         def test_new(*args, **kwargs):
             cudnn_disabled = (os.getenv('CUDNN_OFF_TEST_ONLY') == "true")
             if not cudnn_disabled or mx.context.current_context().device_type == 'cpu':
                 orig_test(*args, **kwargs)
             else:
-                if assertion_error:
-                    errors = (MXNetError, RuntimeError, AssertionError)
-                else:
-                    errors = (MXNetError, RuntimeError)
+                errors = (MXNetError, RuntimeError)
                 assert_raises(errors, orig_test, *args, **kwargs)
         return test_new
     return test_helper
3 changes: 0 additions & 3 deletions tests/python/unittest/test_operator.py
@@ -2407,9 +2407,6 @@ def test_flip():


 @with_seed()
-# The test is disabled with USE_CUDA=ON and USE_CUDNN=OFF because of failures with the SpatialTransformer op.
-# Tracked at /~https://github.com/apache/incubator-mxnet/issues/11568
-@assert_raises_cudnn_disabled(assertion_error=True)
 def test_stn():
     np.set_printoptions(threshold=np.nan)
     num_filter = 2  # conv of loc net
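
These test changes undo the workaround that the kernel fix makes unnecessary: test_stn had been wrapped in assert_raises_cudnn_disabled(assertion_error=True) because the op failed under USE_CUDA=ON with USE_CUDNN=OFF (tracked in issue #11568, cited in the removed comment). With that decorator use gone, the assertion_error parameter in common.py has no remaining callers and is dropped.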
