diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index ea0ec188d6d6..706e5e4dfb12 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -24,7 +24,7 @@ 'PixelShuffle3D'] import warnings -from .... import nd, test_utils +from .... import nd, context from ...block import HybridBlock, Block from ...nn import Sequential, HybridSequential, BatchNorm @@ -233,7 +233,7 @@ def _get_num_devices(self): warnings.warn("Caution using SyncBatchNorm: " "if not using all the GPUs, please mannually set num_devices", UserWarning) - num_devices = mx.context.num_gpus() + num_devices = context.num_gpus() num_devices = num_devices if num_devices > 0 else 1 return num_devices