This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
As part of PR #11470, it was found that the spatial transformer op fails its tests when cuDNN is not enabled.
To reproduce, try one of the two scripts below:
Script 1:
import sys
import numpy as np
import mxnet as mx
from mxnet.test_utils import assert_almost_equal, default_context

# Print full arrays on failure (np.nan is rejected as a threshold by newer NumPy releases)
np.set_printoptions(threshold=sys.maxsize)
num_filter = 2  # conv of loc net
kernel = (3, 3)  # conv of loc net
num_hidden = 6  # fc of loc net
for n in [1, 2, 3, 4]:
    for c in [1, 2, 3, 4]:
        for h in [5, 9, 13, 17]:  # for convenience, the third and fourth input dims should be 4x + 1
            for w in [5, 9, 13, 17]:
                data_shape = (n, c, h, w)
                target_shape = (int((data_shape[2]+1)/2), int((data_shape[3]+1)/2))
                data = mx.sym.Variable(name="data")
                loc = mx.sym.Convolution(data=data, kernel=kernel, pad=(1, 1), num_filter=num_filter, name="loc_conv")
                loc = mx.sym.Flatten(data=loc)
                loc = mx.sym.FullyConnected(data=loc, num_hidden=num_hidden, name="loc_fc")
                stn = mx.sym.SpatialTransformer(data=data, loc=loc, target_shape=target_shape,
                                                transform_type="affine", sampler_type="bilinear")
                arg_names = stn.list_arguments()
                arg_shapes, out_shapes, _ = stn.infer_shape(data=data_shape)
                # check shape
                assert out_shapes[0] == (data_shape[0], data_shape[1], target_shape[0], target_shape[1])
                #dev = default_context()
                dev = mx.gpu(0)
                args = {}
                args['data'] = mx.random.normal(0, 1, data_shape, ctx=mx.cpu()).copyto(dev)
                # all loc net weights are zero, so the loc net outputs exactly the fc bias
                args['loc_conv_weight'] = mx.nd.zeros((num_filter, data_shape[1], kernel[0], kernel[1]), ctx=dev)
                args['loc_conv_bias'] = mx.nd.zeros((num_filter,), ctx=dev)
                args['loc_fc_weight'] = mx.nd.zeros((6, num_filter*data_shape[2]*data_shape[3]), ctx=dev)
                args['loc_fc_bias'] = mx.nd.array([0.5, 0, 0, 0, 0.5, 0], ctx=dev)
                grad_grad = [mx.nd.zeros(shape, ctx=dev) for shape in arg_shapes]
                exe = stn.bind(dev, args=args, args_grad=grad_grad)
                exe.forward(is_train=True)
                out = exe.outputs[0].asnumpy()
                # check forward
                assert_almost_equal(out, args['data'].asnumpy()[:, :, h//4:h-h//4, w//4:w-w//4], rtol=1e-2, atol=1e-4)
                out_grad = mx.nd.ones(out.shape, ctx=dev)
                exe.backward([out_grad])
                # check backward
                assert_almost_equal(out_grad.asnumpy(), grad_grad[0].asnumpy()[:, :, h//4:h-h//4, w//4:w-w//4], rtol=1e-2, atol=1e-4)
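For anyone checking the expected values independently of MXNet, the forward assertion above can be reproduced with a small NumPy reference. This is only a sketch, not the operator's actual implementation: the helper name `affine_bilinear_sample` is hypothetical, and it assumes the usual spatial-transformer convention that target coordinates are normalized to [-1, 1], ordered (x, y, 1), and that every sample point falls inside the input. Under those assumptions, the fixed affine parameters `[0.5, 0, 0, 0, 0.5, 0]` with sizes of the form 4x + 1 land exactly on integer pixels, so the output should equal the center crop that the script compares against.

```python
import numpy as np

def affine_bilinear_sample(data, theta, target_shape):
    """Hypothetical NumPy reference: sample `data` (N, C, H, W) on an affine grid.

    Assumes normalized coordinates in [-1, 1] and that all sample points lie
    inside the input (no zero padding for out-of-range points is modeled).
    """
    n, c, h, w = data.shape
    out_h, out_w = target_shape
    # Normalized target grid, one column (x, y, 1) per output pixel.
    ys, xs = np.meshgrid(np.linspace(-1, 1, out_h),
                         np.linspace(-1, 1, out_w), indexing="ij")
    grid = np.stack([xs.ravel(), ys.ravel(), np.ones(out_h * out_w)])
    src = theta @ grid  # (2, out_h * out_w): normalized source x and y
    # Convert normalized source coordinates to pixel coordinates.
    px = (src[0] + 1) * (w - 1) / 2
    py = (src[1] + 1) * (h - 1) / 2
    # Top-left neighbor of each sample point, kept inside the image.
    x0 = np.clip(np.floor(px).astype(int), 0, w - 2)
    y0 = np.clip(np.floor(py).astype(int), 0, h - 2)
    wx, wy = px - x0, py - y0
    # Bilinear interpolation over the four neighbors.
    out = (data[:, :, y0, x0] * (1 - wy) * (1 - wx)
           + data[:, :, y0, x0 + 1] * (1 - wy) * wx
           + data[:, :, y0 + 1, x0] * wy * (1 - wx)
           + data[:, :, y0 + 1, x0 + 1] * wy * wx)
    return out.reshape(n, c, out_h, out_w)

# Same fixed transform as loc_fc_bias in the script, as a 2x3 matrix.
theta = np.array([[0.5, 0.0, 0.0],
                  [0.0, 0.5, 0.0]])
h, w = 9, 13
data = np.random.RandomState(0).randn(2, 3, h, w)
out = affine_bilinear_sample(data, theta, ((h + 1) // 2, (w + 1) // 2))
center_crop = data[:, :, h // 4:h - h // 4, w // 4:w - w // 4]
print(np.allclose(out, center_crop))
```

If a non-cuDNN build produces output that differs from this center crop, that would point at the operator's grid or sampling code rather than at the test's expectation.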