From acd0d3f6d40e4ce2218b3dabf23937c691cc3e5b Mon Sep 17 00:00:00 2001 From: Yuxi Hu Date: Fri, 18 Jan 2019 15:15:39 -0800 Subject: [PATCH] Handle horovod errors (#24) * add status manger to handle error in Horovod * add unit tests for testing errors * leverage MXNet callback to populate errors * invoke callback in a cleaner way * pass in dmlc::Error instead of char* * update imagenet example * use a function to invoke callback * fix wording --- examples/mxnet_imagenet_resnet50.py | 35 +++--- horovod/mxnet/__init__.py | 5 +- horovod/mxnet/adapter.cc | 15 --- horovod/mxnet/adapter.h | 8 +- horovod/mxnet/mpi_ops.cc | 92 +++++++++------- horovod/mxnet/mpi_ops.h | 1 + horovod/mxnet/mpi_ops.py | 6 +- test/test_mxnet.py | 161 +++++++++++++++++++++++++++- 8 files changed, 245 insertions(+), 78 deletions(-) diff --git a/examples/mxnet_imagenet_resnet50.py b/examples/mxnet_imagenet_resnet50.py index bba14ea032..b4988237c8 100644 --- a/examples/mxnet_imagenet_resnet50.py +++ b/examples/mxnet_imagenet_resnet50.py @@ -20,7 +20,6 @@ import math import os - from gluoncv.model_zoo import get_model import horovod.mxnet as hvd import mxnet as mx @@ -67,8 +66,8 @@ (default is : 40,60)') parser.add_argument('--warmup-lr', type=float, default=0.0, help='starting warmup learning rate (default: 0.0)') -parser.add_argument('--warmup-epochs', type=int, default=5, - help='number of warmup epochs (default: 5)') +parser.add_argument('--warmup-epochs', type=int, default=10, + help='number of warmup epochs (default: 10)') parser.add_argument('--last-gamma', action='store_true', default=False, help='whether to init gamma of the last BN layer in \ each bottleneck to 0 (default: False)') @@ -76,16 +75,14 @@ help='type of model to use. see vision_model for options.') parser.add_argument('--use-pretrained', action='store_true', default=False, help='load pretrained model weights (default: False)') -parser.add_argument('--optimizer', type=str, default='nag', - help='optimizer to use for training (default: nag)') parser.add_argument('--eval-epoch', action='store_true', default=False, help='evaluate validation accuracy after each epoch (default: False)') parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training (default: False)') parser.add_argument('--log-interval', type=int, default=0, help='number of batches to wait before logging (default: 0)') -parser.add_argument('--save-frequency', type=int, default=0, - help='frequency of model saving. (default: 0)') +parser.add_argument('--save-frequency', type=int, default=10, + help='frequency of model saving. (default: 10)') args = parser.parse_args() @@ -266,7 +263,7 @@ def reset(self): val_data = None -def main(): +def train(): # Get model from GluonCV model zoo # https://gluon-cv.mxnet.io/model_zoo/index.html net = get_model(args.model, **kwargs) @@ -303,7 +300,7 @@ def main(): 'lr_scheduler': lr_sched} if args.dtype == 'float16': optimizer_params['multi_precision'] = True - opt = mx.optimizer.create(args.optimizer, sym=out, **optimizer_params) + opt = mx.optimizer.create('sgd', sym=out, **optimizer_params) # Horovod: wrap optimizer with DistributedOptimizer opt = hvd.DistributedOptimizer(opt) @@ -329,7 +326,8 @@ def main(): eval_data = val_data batch_callback = None if args.log_interval > 0: - batch_callback = mx.callback.Speedometer(batch_size, max(1, args.log_interval)) + batch_callback = mx.callback.Speedometer(batch_size, + max(1, args.log_interval)) epoch_callback = None if args.save_frequency > 0: epoch_callback = mx.callback.do_checkpoint( @@ -346,14 +344,15 @@ def main(): optimizer=opt, optimizer_params=optimizer_params) - # Evaluate performance - acc_top1 = mx.metric.Accuracy() - acc_top5 = mx.metric.TopKAccuracy(5) - res = mod.score(val_data, [acc_top1, acc_top5]) - for name, val in res: - logging.info('Epoch[%d] Rank[%d] Validation-%s=%f', - args.num_epochs - 1, hvd.rank(), name, val) + # Evaluate performance if not using synthetic data + if args.use_rec: + acc_top1 = mx.metric.Accuracy() + acc_top5 = mx.metric.TopKAccuracy(5) + res = mod.score(val_data, [acc_top1, acc_top5]) + for name, val in res: + logging.info('Epoch[%d] Rank[%d] Validation-%s=%f', + args.num_epochs - 1, hvd.rank(), name, val) if __name__ == '__main__': - main() + train() diff --git a/horovod/mxnet/__init__.py b/horovod/mxnet/__init__.py index 534d10938d..946368a96c 100644 --- a/horovod/mxnet/__init__.py +++ b/horovod/mxnet/__init__.py @@ -19,8 +19,11 @@ from horovod.common import check_extension -from horovod.mxnet.mpi_ops import allreduce, allreduce_ +check_extension('horovod.mxnet', 'HOROVOD_WITH_MXNET', + __file__, 'mpi_lib') + from horovod.mxnet.mpi_ops import allgather +from horovod.mxnet.mpi_ops import allreduce, allreduce_ from horovod.mxnet.mpi_ops import broadcast, broadcast_ from horovod.mxnet.mpi_ops import init, shutdown from horovod.mxnet.mpi_ops import size, local_size, rank, local_rank diff --git a/horovod/mxnet/adapter.cc b/horovod/mxnet/adapter.cc index 3ad95d7dc7..2853d70bee 100644 --- a/horovod/mxnet/adapter.cc +++ b/horovod/mxnet/adapter.cc @@ -17,8 +17,6 @@ #include "cuda.h" #endif -#include - #include "adapter.h" #include "cuda_util.h" #include "tensor_util.h" @@ -124,19 +122,6 @@ template Framework MXOpContext::framework() const { return Framework::MXNET; } -void ThrowIfError(Status status) { - switch (status.type()) { - case StatusType::OK: - return; - case StatusType::PRECONDITION_ERROR: - throw std::logic_error(status.reason()); - case StatusType::ABORTED: - throw std::runtime_error(status.reason()); - default: // Includes UNKNOWN_ERROR - throw std::runtime_error(status.reason()); - } -} - template class MXTensor; template class MXTemporaryBuffer; template class MXOpContext; diff --git a/horovod/mxnet/adapter.h b/horovod/mxnet/adapter.h index d1a0967de8..6847708740 100644 --- a/horovod/mxnet/adapter.h +++ b/horovod/mxnet/adapter.h @@ -16,6 +16,8 @@ #ifndef HOROVOD_MXNET_ADAPTER_H #define HOROVOD_MXNET_ADAPTER_H +#include + #include "../common/common.h" namespace horovod { @@ -68,7 +70,11 @@ template class MXOpContext : public OpContext { T* output_; }; -void ThrowIfError(Status status); +inline void ThrowIfError(const Status& status) { + if (!status.ok()) { + throw dmlc::Error(status.reason()); + } +} } // namespace mxnet } // namespace horovod diff --git a/horovod/mxnet/mpi_ops.cc b/horovod/mxnet/mpi_ops.cc index 67c08c3b20..7cdac855a7 100644 --- a/horovod/mxnet/mpi_ops.cc +++ b/horovod/mxnet/mpi_ops.cc @@ -13,9 +13,6 @@ // limitations under the License. // ============================================================================= -#include -#include -#include #include #include "../common/operations.h" @@ -42,8 +39,17 @@ std::string GetOpName(std::string prefix, char* name) { } } // namespace +inline void InvokeCompleteCallback(Callback on_complete, const Status& status) { + if (status.ok()) { + on_complete(); + } else { + auto error = dmlc::Error(status.reason()); + on_complete(&error); + } +} + void DoAllreduce(NDArray* tensor, NDArray* output, const std::string& name, - Callback cb) { + Callback on_complete) { ThrowIfError(common::CheckInitialized()); auto device = TensorUtil::GetDevice(tensor); @@ -54,39 +60,40 @@ void DoAllreduce(NDArray* tensor, NDArray* output, const std::string& name, auto enqueue_result = EnqueueTensorAllreduce(hvd_context, hvd_tensor, hvd_output, nullptr, name, device, - [cb](const Status& status) { - cb(); + [on_complete](const Status& status) { + InvokeCompleteCallback(on_complete, status); }); ThrowIfError(enqueue_result); } #if HAVE_CUDA void DoAllreduceCudaOnCPU(NDArray* tensor, NDArray* output, std::string& name, - Callback cb) { + Callback on_complete) { ThrowIfError(common::CheckInitialized()); + // Make async copy of input tensor to CPU tensor and record completion event. auto hvd_cpu_buffer = std::make_shared>( CPU_DEVICE_ID, tensor->dtype()); TensorUtil::AsyncCopyCudaToCPU(tensor, hvd_cpu_buffer->tensor()); - auto ready_event = std::make_shared>(tensor); - auto hvd_context = std::make_shared>( CPU_DEVICE_ID, hvd_cpu_buffer->tensor()); + auto ready_event = std::make_shared>(tensor); auto enqueue_result = EnqueueTensorAllreduce( hvd_context, hvd_cpu_buffer, hvd_cpu_buffer, ready_event, name, CPU_DEVICE_ID, - [hvd_cpu_buffer, output, cb](const Status& status) { + [hvd_cpu_buffer, output, on_complete](const Status& status) { TensorUtil::CopyCPUToCuda(hvd_cpu_buffer->tensor(), output); - cb(); + InvokeCompleteCallback(on_complete, status); }); ThrowIfError(enqueue_result); } #endif void DoAllgather(NDArray* tensor, NDArray* output, std::string& name, - Callback cb) { + Callback on_complete) { ThrowIfError(common::CheckInitialized()); + auto device = TensorUtil::GetDevice(tensor); auto hvd_tensor = std::make_shared>(tensor); auto hvd_context = std::make_shared>(device, output); @@ -94,15 +101,15 @@ void DoAllgather(NDArray* tensor, NDArray* output, std::string& name, auto enqueue_result = EnqueueTensorAllgather(hvd_context, hvd_tensor, nullptr, name, device, - [cb](const Status& status) { - cb(); + [on_complete](const Status& status) { + InvokeCompleteCallback(on_complete, status); }); ThrowIfError(enqueue_result); } #if HAVE_CUDA void DoAllgatherCudaOnCPU(NDArray* tensor, NDArray* output, std::string& name, - Callback cb) { + Callback on_complete) { ThrowIfError(common::CheckInitialized()); // Make async copy of input tensor to CPU tensor and record completion event. @@ -119,17 +126,18 @@ void DoAllgatherCudaOnCPU(NDArray* tensor, NDArray* output, std::string& name, auto enqueue_result = EnqueueTensorAllgather( hvd_context, hvd_cpu_tensor, ready_event, name, CPU_DEVICE_ID, - [hvd_cpu_output, output, cb](const Status& status) { + [hvd_cpu_output, output, on_complete](const Status& status) { TensorUtil::CopyCPUToCuda(hvd_cpu_output->tensor(), output); - cb(); + InvokeCompleteCallback(on_complete, status); }); ThrowIfError(enqueue_result); } #endif void DoBroadcast(NDArray* tensor, NDArray* output, int root_rank, - std::string& name, Callback cb) { + std::string& name, Callback on_complete) { ThrowIfError(common::CheckInitialized()); + auto device = TensorUtil::GetDevice(tensor); auto hvd_tensor = std::make_shared>(tensor); auto hvd_context = std::make_shared>(device, output); @@ -145,8 +153,8 @@ void DoBroadcast(NDArray* tensor, NDArray* output, int root_rank, auto enqueue_result = EnqueueTensorBroadcast( hvd_context, hvd_tensor, hvd_output, root_rank, nullptr, name, device, - [cb](const Status& status) { - cb(); + [on_complete](const Status& status) { + InvokeCompleteCallback(on_complete, status); }); ThrowIfError(enqueue_result); } @@ -154,7 +162,7 @@ void DoBroadcast(NDArray* tensor, NDArray* output, int root_rank, #if HAVE_CUDA void DoBroadcastCudaOnCPU( std::shared_ptr>& hvd_cpu_buffer, int root_rank, - std::string& name, Callback cb) { + std::string& name, Callback on_complete) { // Make async copy of input tensor to CPU tensor and record completion event. auto hvd_context = std::make_shared>( CPU_DEVICE_ID, hvd_cpu_buffer->tensor()); @@ -164,8 +172,8 @@ void DoBroadcastCudaOnCPU( auto enqueue_result = EnqueueTensorBroadcast( hvd_context, hvd_cpu_buffer, hvd_cpu_buffer, root_rank, ready_event, name, CPU_DEVICE_ID, - [cb](const Status& status) { - cb(); + [on_complete](const Status& status) { + InvokeCompleteCallback(on_complete, status); }); ThrowIfError(enqueue_result); } @@ -173,18 +181,20 @@ void DoBroadcastCudaOnCPU( extern "C" int horovod_mxnet_allreduce_async(NDArray* input, NDArray* output, char* name, bool average) { + MX_API_BEGIN(); std::string op_name = GetOpName("allreduce", name); auto allreduce_async_fn = [input, output, op_name](RunContext rctx, - Callback cb) mutable { - DoAllreduce(input, output, op_name, cb); + Callback on_complete) mutable { + DoAllreduce(input, output, op_name, on_complete); }; + #if HAVE_CUDA auto allreduce_async_cpu_fn = [input, output, op_name](RunContext rctx, - Callback cb) mutable { - DoAllreduceCudaOnCPU(input, output, op_name, cb); + Callback on_complete) mutable { + DoAllreduceCudaOnCPU(input, output, op_name, on_complete); }; #endif @@ -217,23 +227,26 @@ extern "C" int horovod_mxnet_allreduce_async(NDArray* input, NDArray* output, if (average) { *output /= horovod_size(); } - return 0; + + MX_API_END(); } extern "C" int horovod_mxnet_allgather_async(NDArray* input, NDArray* output, char* name) { + MX_API_BEGIN(); std::string op_name = GetOpName("allgather", name); auto allgather_async_fn = [input, output, op_name](RunContext rctx, - Callback cb) mutable { - DoAllgather(input, output, op_name, cb); + Callback on_complete) mutable { + DoAllgather(input, output, op_name, on_complete); }; + #if HAVE_CUDA auto allgather_async_cpu_fn = [input, output, op_name](RunContext rctx, - Callback cb) mutable { - DoAllgatherCudaOnCPU(input, output, op_name, cb); + Callback on_complete) mutable { + DoAllgatherCudaOnCPU(input, output, op_name, on_complete); }; #endif @@ -261,17 +274,19 @@ extern "C" int horovod_mxnet_allgather_async(NDArray* input, NDArray* output, "HorovodAllgather"); } #endif - return 0; + + MX_API_END(); } extern "C" int horovod_mxnet_broadcast_async(NDArray* input, NDArray* output, int root_rank, char* name) { + MX_API_BEGIN(); std::string op_name = GetOpName("broadcast", name); auto broadcast_async_fn = [input, output, op_name, root_rank](RunContext rctx, - Callback cb) mutable { - DoBroadcast(input, output, root_rank, op_name, cb); + Callback on_complete) mutable { + DoBroadcast(input, output, root_rank, op_name, on_complete); }; #if HAVE_CUDA && HOROVOD_GPU_BROADCAST != 'M' @@ -283,8 +298,8 @@ extern "C" int horovod_mxnet_broadcast_async(NDArray* input, NDArray* output, TensorUtil::AsyncCopyCudaToCPU(input, hvd_cpu_buffer->tensor()); auto broadcast_async_cpu_fn = [hvd_cpu_buffer, op_name, root_rank] - (RunContext rctx, Callback cb) mutable { - DoBroadcastCudaOnCPU(hvd_cpu_buffer, root_rank, op_name, cb); + (RunContext rctx, Callback on_complete) mutable { + DoBroadcastCudaOnCPU(hvd_cpu_buffer, root_rank, op_name, on_complete); }; Engine::Get()->PushAsync(broadcast_async_cpu_fn, input->ctx(), {}, @@ -297,7 +312,8 @@ extern "C" int horovod_mxnet_broadcast_async(NDArray* input, NDArray* output, {output->var()}, FnProperty::kNormal, 0, "HorovodBroadcast"); #endif - return 0; + + MX_API_END(); } } // namespace mxnet diff --git a/horovod/mxnet/mpi_ops.h b/horovod/mxnet/mpi_ops.h index 518fb7bb4d..b0acb8b50e 100644 --- a/horovod/mxnet/mpi_ops.h +++ b/horovod/mxnet/mpi_ops.h @@ -18,6 +18,7 @@ #include #include +#include #include #include diff --git a/horovod/mxnet/mpi_ops.py b/horovod/mxnet/mpi_ops.py index 239082b3f9..4da31eb814 100644 --- a/horovod/mxnet/mpi_ops.py +++ b/horovod/mxnet/mpi_ops.py @@ -18,12 +18,11 @@ from __future__ import print_function # Load all the necessary MXNet C types. -import mxnet as mx import ctypes import os -from mxnet.base import c_str_array, c_handle_array, c_array, c_array_buf, c_str -from mxnet.base import check_call, string_types, mx_uint, py_str, string_types +import mxnet as mx +from mxnet.base import c_str, check_call, string_types from horovod.common import get_ext_suffix from horovod.common import HorovodBasics as _HorovodBasics @@ -77,6 +76,7 @@ def allreduce(tensor, average=True, name=None): else: check_call(MPI_MXNET_LIB_CTYPES.horovod_mxnet_allreduce_async(c_in, c_out, name, ctypes.c_bool(average))) + return output diff --git a/test/test_mxnet.py b/test/test_mxnet.py index f826745d4d..3a4f4711ff 100644 --- a/test/test_mxnet.py +++ b/test/test_mxnet.py @@ -17,11 +17,12 @@ from __future__ import division from __future__ import print_function +import horovod.mxnet as hvd import itertools import mxnet as mx -import unittest import numpy as np -import horovod.mxnet as hvd +import unittest +from mxnet.base import MXNetError from mxnet.test_utils import same @@ -161,6 +162,94 @@ def test_horovod_allreduce_inplace(self): incorrect results for self' mx.ndarray.waitall() + def test_horovod_allreduce_error(self): + """Test that the allreduce raises an error if different ranks try to + send tensors of different rank or dimension.""" + hvd.init() + rank = hvd.rank() + size = hvd.size() + + # This test does not apply if there is only one worker. + if size == 1: + return + + # Same rank, different dimension + ctx = self._current_context() + + shape = (17 + rank, 3) + tensor = mx.nd.ones(shape=shape, ctx=ctx) + try: + output = hvd.allreduce(tensor) + output.wait_to_read() + assert False, 'hvd.allreduce did not throw error' + except (MXNetError, RuntimeError): + pass + + # Same number of elements, different rank + if rank == 0: + shape = (17, 23 * 57) + else: + shape = (17, 23, 57) + tensor = mx.nd.ones(shape=shape, ctx=ctx) + try: + output = hvd.allreduce(tensor) + output.wait_to_read() + assert False, 'hvd.allreduce did not throw error' + except (MXNetError, RuntimeError): + pass + + def test_horovod_allreduce_type_error(self): + """Test that the allreduce raises an error if different ranks try to + send tensors of different type.""" + hvd.init() + rank = hvd.rank() + size = hvd.size() + + # This test does not apply if there is only one worker. + if size == 1: + return + + ctx = self._current_context() + shape = (17, 3) + tensor = mx.nd.ones(shape=shape, ctx=ctx) + if rank % 2 == 0: + tensor = tensor.astype('int32') + else: + tensor = tensor.astype('float32') + + try: + output = hvd.allreduce(tensor) + output.wait_to_read() + assert False, 'hvd.allreduce did not throw error' + except (MXNetError, RuntimeError): + pass + + @unittest.skipUnless(has_gpu, "no gpu detected") + def test_horovod_allreduce_cpu_gpu_error(self): + """Test that the allreduce raises an error if different ranks try to + perform reduction on CPU and GPU.""" + hvd.init() + rank = hvd.rank() + size = hvd.size() + + # This test does not apply if there is only one worker. + if size == 1: + return + + shape = (17, 17, 17) + if rank % 2 == 0: + ctx = mx.gpu(hvd.rank()) + else: + ctx = mx.cpu(hvd.rank()) + tensor = mx.nd.ones(shape=shape, ctx=ctx) + + try: + output = hvd.allreduce(tensor) + output.wait_to_read() + assert False, 'hvd.allreduce did not throw cpu-gpu error' + except (MXNetError, RuntimeError): + pass + def test_horovod_broadcast(self): """Test that the broadcast correctly broadcasts 1D, 2D, 3D tensors.""" hvd.init() @@ -299,6 +388,74 @@ def test_horovod_broadcast_grad(self): 'hvd.broadcast produces incorrect broadcasted tensor' mx.ndarray.waitall() + def test_horovod_broadcast_error(self): + """Test that the broadcast returns an error if any dimension besides + the first is different among the tensors being broadcasted.""" + hvd.init() + rank = hvd.rank() + size = hvd.size() + + # This test does not apply if there is only one worker. + if size == 1: + return + + ctx = self._current_context() + shape = (17, rank+1) + tensor = mx.nd.ones(shape=shape, ctx=ctx) + + try: + output = hvd.broadcast(tensor, 0) + output.wait_to_read() + assert False, 'hvd.broadcast did not throw error' + except (MXNetError, RuntimeError): + pass + + def test_horovod_broadcast_type_error(self): + """Test that the broadcast returns an error if the types being broadcasted + differ among the processes""" + hvd.init() + rank = hvd.rank() + size = hvd.size() + + # This test does not apply if there is only one worker. + if size == 1: + return + + ctx = self._current_context() + shape = (17, 3) + tensor = mx.nd.ones(shape=shape, ctx=ctx) + if rank % 2 == 0: + tensor = tensor.astype('int32') + else: + tensor = tensor.astype('float32') + + try: + output = hvd.broadcast(tensor, 0) + output.wait_to_read() + assert False, 'hvd.broadcast did not throw error' + except (MXNetError, RuntimeError): + pass + + def test_horovod_broadcast_rank_error(self): + """Test that the broadcast returns an error if different ranks + specify different root rank.""" + hvd.init() + rank = hvd.rank() + size = hvd.size() + + # This test does not apply if there is only one worker. + if size == 1: + return + + ctx = self._current_context() + shape = (17, 17, 17) + tensor = mx.nd.ones(shape=shape, ctx=ctx) + try: + output = hvd.broadcast(tensor, root_rank=rank) + output.wait_to_read() + assert False, 'hvd.broadcast did not throw rank error' + except (MXNetError, RuntimeError): + pass if __name__ == '__main__': unittest.main()