diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 4dbf82edd3f5..f3299163f323 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -564,8 +564,18 @@ def _conv_with_num_streams(seed): @with_seed() def test_convolution_multiple_streams(): + engines = ['NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice'] + + if os.getenv('MXNET_ENGINE_TYPE') is not None: + engines = [os.getenv('MXNET_ENGINE_TYPE'),] + print("Only running against '%s'" % engines[0], file=sys.stderr, end='') + # Remove this else clause when the ThreadedEngine can handle this test + else: + engines.remove('ThreadedEngine') + print("SKIP: 'ThreadedEngine', only running against %s" % engines, file=sys.stderr, end='') + for num_streams in [1, 2]: - for engine in ['NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice']: + for engine in engines: _test_in_separate_process(_conv_with_num_streams, {'MXNET_GPU_WORKER_NSTREAMS' : num_streams, 'MXNET_ENGINE_TYPE' : engine})