diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 582680fe7957..fb40474bc678 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1434,6 +1434,16 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', return gt +def list_gpus(): + """Return a list of GPUs + + Returns + ------- + list of int: + If there are n GPUs, then return a list [0,1,...,n-1]. Otherwise returns + []. + """ + return range(mx.util.get_gpu_count()) def download(url, fname=None, dirname=None, overwrite=False, retries=5): """Download an given URL