diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index d30595a6efb5..f43f17520654 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -88,24 +88,36 @@ def _check_metrics(self, metrics): return metrics def _check_context(self, context): - # handle context - if isinstance(context, Context): - context = [context] - elif isinstance(context, list) and all([isinstance(c, Context) for c in context]): - context = context - elif not context: - if num_gpus() > 0: + # infer available context + gpus = num_gpus() + available_gpus = [gpu(i) for i in range(gpus)] + + if context: + # check context values, only accept Context or a list of Context + if isinstance(context, Context): + context = [context] + elif isinstance(context, list) and all([isinstance(c, Context) for c in context]): + context = context + else: + raise ValueError("context must be a Context or a list of Context, " + "for example mx.cpu() or [mx.gpu(0), mx.gpu(1)], " + "refer to mxnet.Context:{}".format(context)) + for ctx in context: + assert ctx in available_gpus or str(ctx).startswith('cpu'), \ + "%s is not available, please make sure " \ + "your context is in one of: mx.cpu(), %s" % \ + (ctx, ", ".join([str(ctx) for ctx in available_gpus])) + else: + # provide default context + if gpus > 0: # only use 1 GPU by default - if num_gpus() > 1: + if gpus > 1: warnings.warn("You have multiple GPUs, gpu(0) will be used by default." "To utilize all your GPUs, specify context as a list of gpus, " "e.g. context=[mx.gpu(0), mx.gpu(1)] ") context = [gpu(0)] else: context = [cpu()] - else: - raise ValueError("context must be a Context or a list of Context, " - "refer to mxnet.Context:{}".format(context)) return context def _initialize(self, initializer): @@ -167,7 +179,8 @@ def prepare_loss_and_metrics(self): self.train_metrics = [Accuracy()] self.val_metrics = [] for loss in self.loss: - self.train_metrics.append(Loss(''.join([i for i in loss.name if not i.isdigit()]))) + # remove trailing numbers from loss name to avoid confusion + self.train_metrics.append(Loss(loss.name.rstrip('1234567890'))) for metric in self.train_metrics: val_metric = copy.deepcopy(metric) metric.name = "Train " + metric.name diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index 643214212e3a..b25baa255165 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -258,6 +258,12 @@ def test_context(): metrics=metrics, context='cpu') + with assert_raises(AssertionError): + est = Estimator(net=net, + loss=loss, + metrics=metrics, + context=[mx.gpu(0), mx.gpu(100)]) + def test_categorize_handlers(): class CustomHandler1(TrainBegin):