This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Issues with spatial transformer op when cudnn disabled #11568

Closed
anirudh2290 opened this issue Jul 5, 2018 · 2 comments · Fixed by #11806

Comments

@anirudh2290
Member

Description

As part of PR #11470, it was found that the SpatialTransformer op does not pass its tests when cuDNN is disabled.
To reproduce, try one of the two scripts below:

Script 1:

import sys

import numpy as np
import mxnet as mx
from mxnet.test_utils import assert_almost_equal, default_context

# Print arrays in full; threshold=np.nan is rejected by newer NumPy releases.
np.set_printoptions(threshold=sys.maxsize)
num_filter = 2  # conv of loc net
kernel = (3, 3)  # conv of loc net
num_hidden = 6  # fc of loc net
for n in [1, 2, 3, 4]:
    for c in [1, 2, 3, 4]:
        for h in [5, 9, 13, 17]:  # for test convenience, the third and fourth input dims should be of the form 4x + 1
            for w in [5, 9, 13, 17]:
                data_shape = (n, c, h, w)
                target_shape = (int((data_shape[2]+1)/2), int((data_shape[3]+1)/2))
                data = mx.sym.Variable(name="data")
                loc = mx.sym.Convolution(data=data, kernel=kernel, pad=(1, 1), num_filter=num_filter, name="loc_conv")
                loc = mx.sym.Flatten(data=loc)
                loc = mx.sym.FullyConnected(data=loc, num_hidden=num_hidden, name="loc_fc")
                stn = mx.sym.SpatialTransformer(data=data, loc=loc, target_shape=target_shape,
                                                transform_type="affine", sampler_type="bilinear")
                arg_names = stn.list_arguments()
                arg_shapes, out_shapes, _ = stn.infer_shape(data=data_shape)
                # check shape
                assert out_shapes[0] == (data_shape[0], data_shape[1], target_shape[0], target_shape[1])
                #dev = default_context()
                dev = mx.gpu(0)
                args = {}
                args['data'] = mx.random.normal(0, 1, data_shape, ctx=mx.cpu()).copyto(dev)
                args['loc_conv_weight'] = mx.nd.zeros((num_filter, data_shape[1], kernel[0], kernel[1]), ctx=dev)
                args['loc_conv_bias'] = mx.nd.zeros((num_filter,), ctx=dev)
                args['loc_fc_weight'] = mx.nd.zeros((6, num_filter*data_shape[2]*data_shape[3]), ctx=dev)
                args['loc_fc_bias'] = mx.nd.array([0.5, 0, 0, 0, 0.5, 0], ctx=dev)
                grad_grad = [mx.nd.zeros(shape, ctx=dev) for shape in arg_shapes]
                exe = stn.bind(dev, args=args, args_grad=grad_grad)
                exe.forward(is_train=True)
                out = exe.outputs[0].asnumpy()
                # check forward
                assert_almost_equal(out, args['data'].asnumpy()[:, :, h//4:h-h//4, w//4:w-w//4], rtol=1e-2, atol=1e-4)
                out_grad = mx.nd.ones(out.shape, ctx=dev)
                exe.backward([out_grad])
                # check backward
                assert_almost_equal(out_grad.asnumpy(), grad_grad[0].asnumpy()[:, :, h//4:h-h//4, w//4:w-w//4], rtol=1e-2, atol=1e-4)

Result:

AssertionError:
Items are not equal:
Error 9999.758789 exceeds tolerance rtol=0.010000, atol=0.000100.  Location of maximum error:(0, 0, 0, 0), a=1.000000, b=0.000000
 a: array([[[[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.]]]], dtype=float32)
 b: array([[[[0.00000024, 0.99999976, 1.        , ..., 1.        ,
          1.        , 1.        ],
         [0.00000024, 0.99999976, 1.        , ..., 1.        ,...
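
For reference, the forward check in Script 1 relies on the affine parameters [0.5, 0, 0, 0, 0.5, 0] selecting the central crop of the input. A minimal NumPy sketch of that expectation follows. It assumes the sampling grid is normalized to [-1, 1] in both axes and that out-of-range neighbours contribute zero; `affine_bilinear_sample` is a hypothetical helper written for illustration and is not MXNet's implementation.

```python
import numpy as np


def affine_bilinear_sample(data, theta, target_shape):
    """Affine grid + bilinear sampling on normalized [-1, 1] coordinates (illustrative only)."""
    n, c, h, w = data.shape
    out_h, out_w = target_shape
    # Target grid in homogeneous normalized coordinates, shape (3, out_h * out_w).
    ys, xs = np.meshgrid(np.linspace(-1.0, 1.0, out_h),
                         np.linspace(-1.0, 1.0, out_w), indexing="ij")
    grid = np.stack([xs.ravel(), ys.ravel(), np.ones(out_h * out_w)])
    out = np.zeros((n, c, out_h, out_w), dtype=data.dtype)
    for b in range(n):
        # Map target coordinates to source coordinates with the 2x3 affine matrix.
        src = theta[b].reshape(2, 3) @ grid
        # Convert normalized source coordinates to pixel coordinates.
        x = (src[0] + 1.0) * (w - 1) / 2.0
        y = (src[1] + 1.0) * (h - 1) / 2.0
        x0 = np.floor(x).astype(int)
        y0 = np.floor(y).astype(int)
        wx, wy = x - x0, y - y0

        def gather(yy, xx):
            # Read pixels at integer locations; out-of-range locations count as zero.
            valid = (yy >= 0) & (yy < h) & (xx >= 0) & (xx < w)
            vals = data[b][:, np.clip(yy, 0, h - 1), np.clip(xx, 0, w - 1)]
            return vals * valid

        sampled = ((1 - wy) * (1 - wx) * gather(y0, x0)
                   + (1 - wy) * wx * gather(y0, x0 + 1)
                   + wy * (1 - wx) * gather(y0 + 1, x0)
                   + wy * wx * gather(y0 + 1, x0 + 1))
        out[b] = sampled.reshape(c, out_h, out_w)
    return out


# Same setup as Script 1 for one shape: half-scale affine transform, no translation.
rng = np.random.default_rng(0)
h, w = 9, 13
data = rng.standard_normal((2, 3, h, w))
theta = np.tile([0.5, 0, 0, 0, 0.5, 0], (2, 1))
target_shape = ((h + 1) // 2, (w + 1) // 2)
out = affine_bilinear_sample(data, theta, target_shape)
crop = data[:, :, h // 4:h - h // 4, w // 4:w - w // 4]
print(np.allclose(out, crop))
```

Under these assumptions the sampled output matches the slice used in the forward assertion, and the gradient of the output sum with respect to the input would be one inside that crop, which is what the backward assertion compares against.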

Script 2:

import mxnet as mx
import numpy as np
from mxnet.test_utils import check_consistency

data = mx.sym.Variable('data')
loc = mx.sym.Flatten(data)
loc = mx.sym.FullyConnected(data=loc, num_hidden=10)
loc = mx.sym.Activation(data=loc, act_type='relu')
loc = mx.sym.FullyConnected(data=loc, num_hidden=6)
sym = mx.sym.SpatialTransformer(data=data, loc=loc, target_shape=(10, 10),
                                transform_type="affine", sampler_type="bilinear")
ctx_list = [{'ctx': mx.gpu(0), 'data': (1, 5, 10, 10), 'type_dict': {'data': np.float64}},
            {'ctx': mx.cpu(0), 'data': (1, 5, 10, 10), 'type_dict': {'data': np.float64}}]
check_consistency(sym, ctx_list)
check_consistency(sym, ctx_list, grad_req="add")

Result:

Traceback (most recent call last):
  File "test_spatial_transformer.py", line 14, in <module>
    check_consistency(sym, ctx_list)
  File "/home/ubuntu/sparse_support/mxnet/python/mxnet/test_utils.py", line 1356, in check_consistency
    gtarr = gt[name].astype(dtypes[i]).asnumpy()
  File "/home/ubuntu/sparse_support/mxnet/python/mxnet/ndarray/ndarray.py", line 1910, in asnumpy
    ctypes.c_size_t(data.size)))
  File "/home/ubuntu/sparse_support/mxnet/python/mxnet/base.py", line 210, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [21:50:56] /home/ubuntu/sparse_support/mxnet/3rdparty/mshadow/mshadow/././././cuda/tensor_gpu-inl.cuh:167: Check failed: err == cudaSuccess (7 vs. 0) Name: MapRedKeepLowestKernel ErrStr:too many resources requested for launch

Stack trace returned 10 entries:
[bt] (0) /home/ubuntu/sparse_support/mxnet/python/mxnet/../../build/libmxnet.so(dmlc::StackTrace[abi:cxx11]()+0x54) [0x7feab9a7b97d]
[bt] (1) /home/ubuntu/sparse_support/mxnet/python/mxnet/../../build/libmxnet.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x2a) [0x7feab9a7bc64]
[bt] (2) /home/ubuntu/sparse_support/mxnet/python/mxnet/../../build/libmxnet.so(void mshadow::cuda::MapReduceKeepLowest<mshadow::sv::saveto, mshadow::red::sum, mshadow::Tensor<mshadow::gpu, 1, double>, mshadow::Tensor<mshadow::gpu, 2, double>, double>(mshadow::expr::Plan<mshadow::Tensor<mshadow::gpu, 1, double>, double>, mshadow::expr::Plan<mshadow::Tensor<mshadow::gpu, 2, double>, double> const&, double, mshadow::Shape<2>, CUstream_st*)+0x2ca) [0x7feaba0b9007]
[bt] (3) /home/ubuntu/sparse_support/mxnet/python/mxnet/../../build/libmxnet.so(void mshadow::MapReduceKeepLowest<mshadow::sv::saveto, mshadow::red::sum, mshadow::Tensor<mshadow::gpu, 1, double>, double, mshadow::Tensor<mshadow::gpu, 2, double>, 0>(mshadow::TRValue<mshadow::Tensor<mshadow::gpu, 1, double>, mshadow::gpu, 1, double>*, mshadow::expr::Exp<mshadow::Tensor<mshadow::gpu, 2, double>, double, 0> const&, double)+0x39b) [0x7feaba0b8249]
[bt] (4) /home/ubuntu/sparse_support/mxnet/python/mxnet/../../build/libmxnet.so(mshadow::expr::ExpComplexEngine<mshadow::sv::saveto, mshadow::Tensor<mshadow::gpu, 1, double>, mshadow::expr::ReduceTo1DExp<mshadow::Tensor<mshadow::gpu, 2, double>, double, mshadow::red::sum, 1>, double>::Eval(mshadow::Tensor<mshadow::gpu, 1, double>*, mshadow::expr::ReduceTo1DExp<mshadow::Tensor<mshadow::gpu, 2, double>, double, mshadow::red::sum, 1> const&)+0x37) [0x7feaba0b729b]
[bt] (5) /home/ubuntu/sparse_support/mxnet/python/mxnet/../../build/libmxnet.so(void mshadow::expr::ExpEngine<mshadow::sv::saveto, mshadow::Tensor<mshadow::gpu, 1, double>, double>::Eval<mshadow::expr::ReduceTo1DExp<mshadow::Tensor<mshadow::gpu, 2, double>, double, mshadow::red::sum, 1> >(mshadow::Tensor<mshadow::gpu, 1, double>*, mshadow::expr::Exp<mshadow::expr::ReduceTo1DExp<mshadow::Tensor<mshadow::gpu, 2, double>, double, mshadow::red::sum, 1>, double, 7> const&)+0x37) [0x7feaba0b5a1c]
[bt] (6) /home/ubuntu/sparse_support/mxnet/python/mxnet/../../build/libmxnet.so(mshadow::Tensor<mshadow::gpu, 1, double>& mshadow::expr::RValueExp<mshadow::Tensor<mshadow::gpu, 1, double>, double>::__assign<mshadow::expr::ReduceTo1DExp<mshadow::Tensor<mshadow::gpu, 2, double>, double, mshadow::red::sum, 1>, 7>(mshadow::expr::Exp<mshadow::expr::ReduceTo1DExp<mshadow::Tensor<mshadow::gpu, 2, double>, double, mshadow::red::sum, 1>, double, 7> const&)+0x37) [0x7feaba0b4d49]
[bt] (7) /home/ubuntu/sparse_support/mxnet/python/mxnet/../../build/libmxnet.so(mshadow::Tensor<mshadow::gpu, 1, double>& mshadow::Tensor<mshadow::gpu, 1, double>::operator=<mshadow::expr::ReduceTo1DExp<mshadow::Tensor<mshadow::gpu, 2, double>, double, mshadow::red::sum, 1>, 7>(mshadow::expr::Exp<mshadow::expr::ReduceTo1DExp<mshadow::Tensor<mshadow::gpu, 2, double>, double, mshadow::red::sum, 1>, double, 7> const&)+0x23) [0x7feaba0b465b]
[bt] (8) /home/ubuntu/sparse_support/mxnet/python/mxnet/../../build/libmxnet.so(void mxnet::op::FCBackward<mshadow::gpu, double>(mxnet::OpContext const&, mxnet::op::FullyConnectedParam const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)+0xafd) [0x7feaba0b2f99]
[bt] (9) /home/ubuntu/sparse_support/mxnet/python/mxnet/../../build/libmxnet.so(void mxnet::op::FullyConnectedGradCompute<mshadow::gpu>(nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)+0x4b0) [0x7feaba0ad474]
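
The ErrStr in the trace above, "too many resources requested for launch", is typically reported when a kernel launch asks for more per-block resources (for example registers) than the device allows. A rough sketch of the register side of that check follows. All numbers are hypothetical: the 65536-register limit, the thread count, and the per-thread register counts are assumptions, not values measured for MapRedKeepLowestKernel, and register allocation granularity is ignored.

```python
def launch_fits(threads_per_block, regs_per_thread, regs_per_block_limit=65536):
    """Return True if a block's total register demand stays within the per-block limit."""
    return threads_per_block * regs_per_thread <= regs_per_block_limit


# Hypothetical figures: the same block size with a lighter and a heavier kernel.
print(launch_fits(1024, 32))  # 1024 * 32 = 32768 registers requested
print(launch_fits(1024, 80))  # 1024 * 80 = 81920 registers requested
```

If this is the cause here, reducing the block size or the per-thread register use of the kernel would be the usual directions to investigate.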

@andrewfayres
Contributor

@sandeep-krishnamurthy, requesting that this be labeled cuda and operator.

@anirudh2290
Member Author

Added disabled test, since the tests are disabled for USE_CUDA=ON USE_CUDNN=OFF builds as part of #11470.
