Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add exception handling support for waitall #14397

Merged
merged 25 commits into from
Apr 8, 2019
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
a85b3f0
Relax constexpr restriction
anirudh2290 Feb 5, 2019
5bd6428
Change imagenet_gen_qsym_mkldnn
anirudh2290 Feb 12, 2019
d684e4c
Merge branch 'master' of /~https://github.com/apache/incubator-mxnet
anirudh2290 Feb 12, 2019
5debcc2
Merge branch 'master' of /~https://github.com/apache/incubator-mxnet
anirudh2290 Feb 26, 2019
082a0aa
Merge branch 'master' of /~https://github.com/apache/incubator-mxnet
anirudh2290 Mar 1, 2019
f194aa2
Merge branch 'master' of /~https://github.com/apache/incubator-mxnet
anirudh2290 Mar 4, 2019
26079e6
Add exception handling support for waitall
anirudh2290 Mar 12, 2019
f0a76e3
Fix exception handling documentation
anirudh2290 Mar 12, 2019
c169a86
Fix quantization file
anirudh2290 Mar 12, 2019
cd98fa9
Revert constexpr change
anirudh2290 Mar 12, 2019
4f694f6
Add comments
anirudh2290 Mar 12, 2019
d79560b
Fix test
anirudh2290 Mar 12, 2019
3a581e8
Skip exception for op check names
anirudh2290 Mar 12, 2019
0b5444b
Print exceptions thrown for CPP Package NDArray module
anirudh2290 Mar 12, 2019
38b8dca
Reducing batch_size to make cpp-package example pass
anirudh2290 Mar 14, 2019
1c0d936
Fix bug: #14426
anirudh2290 Mar 14, 2019
034e9c7
use ExceptionRef in threaded_engine code
anirudh2290 Mar 14, 2019
d34d95e
add note for performance impact of waitall
anirudh2290 Mar 14, 2019
48a6638
Add check for GPU contxt
anirudh2290 Mar 20, 2019
f631d57
Use range for with const reference
anirudh2290 Mar 21, 2019
54e301f
Improve comments and error message for exception handling test
anirudh2290 Mar 21, 2019
9e6972a
Change exception_ptr name in waitall
anirudh2290 Mar 21, 2019
ee1fabd
Merge branch 'master' of /~https://github.com/apache/incubator-mxnet in…
anirudh2290 Mar 22, 2019
ff8151c
Fix bug
anirudh2290 Mar 25, 2019
2a13065
Merge branch 'master' of /~https://github.com/apache/incubator-mxnet in…
anirudh2290 Apr 2, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cpp-package/include/mxnet-cpp/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,12 @@ inline NDArray NDArray::Reshape(const Shape &new_shape) const {
return NDArray(handle);
}
inline void NDArray::WaitToRead() const {
CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0);
CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0) << MXGetLastError();
}
inline void NDArray::WaitToWrite() {
CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0);
CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0) << MXGetLastError();
}
inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0); }
inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0) << MXGetLastError(); }
inline void NDArray::SampleGaussian(mx_float mu, mx_float sigma, NDArray *out) {
Operator("_random_normal")(mu, sigma).Invoke(*out);
}
Expand Down
3 changes: 0 additions & 3 deletions docs/architecture/exception_handling.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,3 @@ except mx.base.MXNetError as ex:
d.asnumpy()
```

### Limitation
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved

Rethrowing exceptions as part of `mx.nd.waitall` is not supported. So if your code executes a few operators and then calls `waitall` instead of `wait_to_read`/`asnumpy`, the exception will disappear. Please avoid waitalls in your code unless you are confident about your code not throwing exception in any scenario.
7 changes: 0 additions & 7 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,6 @@ def waitall():
"""Wait for all async operations to finish in MXNet.

This function is used for benchmarking only.

.. warning::

If your code has exceptions, `waitall` can cause silent failures.
For this reason you should avoid `waitall` in your code.
Use it only if you are confident that your code is error free.
Then make sure you call `wait_to_read` on all outputs after `waitall`.
"""
check_call(_LIB.MXNDArrayWaitAll())

Expand Down
27 changes: 27 additions & 0 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,25 @@ void ThreadedEngine::WaitForAll() {
finished_cv_.wait(lock, [this]() {
return pending_.load() == 0 || kill_.load();
});
std::exception_ptr tmp;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall this code be wrapped in a function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently it is used only once so it is fine to not use a function.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe then a better variable name than tmp? ex_to_rethrow?

if (!global_exception_refs_.empty()) {
// iterate through all exception refs
for (auto itr = global_exception_refs_.begin();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we use range for with const reference? is much less noisy.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review!

itr != global_exception_refs_.end(); ++itr) {
const std::shared_ptr<std::exception_ptr>& ptr = *itr;
// the first exception will be saved to be rethrown later
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the order of exceptions stored in the "global_exception_refs_" ? If we are throwing the first one then is it the innermost in the stack that causes all other exceptions or the outermost ? If its outermost then it might not give correct idea about what was the root cause

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@access2rohit the order of the exceptions will be maintained exception thrown first will be rethrown first.

if (*ptr != nullptr && !tmp) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can be evaluated in bool context, so less noise.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

tmp = *ptr;
}
// clear exceptions, WaitToRead following WaitForAll shouldn't throw
*ptr = nullptr;
}
// A waitall following a waitall shouldn't throw any exceptions
global_exception_refs_.clear();
if (tmp != nullptr) {
std::rethrow_exception(tmp);
}
}
}

inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
Expand All @@ -428,6 +447,14 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
for (auto&& i : threaded_opr->mutable_vars) {
if (threaded_opr->opr_exception && *threaded_opr->opr_exception) {
i->var_exception = threaded_opr->opr_exception;
// add current operator exceptions to global exceptions if not already
// added
auto it = std::find(global_exception_refs_.begin(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L452-L457 are used in three places. Can we make it a function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review ! I have moved it into a function.

global_exception_refs_.end(),
threaded_opr->opr_exception);
if (it == global_exception_refs_.end()) {
global_exception_refs_.push_back(threaded_opr->opr_exception);
}
}
const bool debug_info = (engine_info_ && debug_wait_var_ == i);
if (debug_info) {
Expand Down
29 changes: 27 additions & 2 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,11 @@ class ThreadedVar final
static std::atomic<std::size_t> counter;
~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
#endif // ENGINE_DEBUG
/*! \brief exception_ptr associated with the ThreadedVar */
/*!
* \brief exception_ptr associated with the ThreadedOpr
* cannot modify state of exception object since dereferencing
* exception_ptr is undefined behavior. Using shared_ptr to hold
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved
* exception_ptr and overcome this limitation */
std::shared_ptr<std::exception_ptr> var_exception;

private:
Expand Down Expand Up @@ -254,7 +258,11 @@ struct ThreadedOpr final : public Opr,
}
// define possible debug information
DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
/*! \brief exception_ptr associated with the ThreadedOpr */
/*!
* \brief exception_ptr associated with the ThreadedOpr
* cannot modify state of exception object since dereferencing
* exception_ptr is undefined behavior. Using shared_ptr to hold
* exception_ptr and overcome this limitation */
std::shared_ptr<std::exception_ptr> opr_exception;
}; // struct ThreadedOpr

Expand Down Expand Up @@ -432,6 +440,9 @@ class ThreadedEngine : public Engine {
};
/*! thread local store for bulk */
typedef dmlc::ThreadLocalStore<BulkStatus> BulkStatusStore;
/*! shared_ptr to exception_ptr, used for exception handling */
typedef std::shared_ptr<std::exception_ptr> ExceptionRef;
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief check if thee is duplication in const_vars and mutable_vars.
* \param const_vars the variables to read from.
Expand Down Expand Up @@ -460,13 +471,25 @@ class ThreadedEngine : public Engine {
for (auto&& i : threaded_opr->const_vars) {
if (i->var_exception && *i->var_exception) {
threaded_opr->opr_exception = i->var_exception;
auto it = std::find(global_exception_refs_.begin(),
global_exception_refs_.end(),
threaded_opr->opr_exception);
if (it == global_exception_refs_.end()) {
global_exception_refs_.push_back(threaded_opr->opr_exception);
}
break;
}
}
if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception)) {
for (auto&& i : threaded_opr->mutable_vars) {
if (i->var_exception && *i->var_exception) {
threaded_opr->opr_exception = i->var_exception;
auto it = std::find(global_exception_refs_.begin(),
global_exception_refs_.end(),
threaded_opr->opr_exception);
if (it == global_exception_refs_.end()) {
global_exception_refs_.push_back(threaded_opr->opr_exception);
}
break;
}
}
Expand Down Expand Up @@ -542,6 +565,8 @@ class ThreadedEngine : public Engine {
*/
std::mutex finished_m_;
std::condition_variable finished_cv_;
/*! \brief global exception refs, which are rethrown when WaitForAll is called */
std::vector<ExceptionRef> global_exception_refs_;

/*!
* \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early
Expand Down
110 changes: 78 additions & 32 deletions tests/python/unittest/test_exc_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def imperative(exec_numpy=True):
c.asnumpy()

imperative(exec_numpy=False)
assert_raises(MXNetError, imperative, True)
assert_raises(MXNetError, imperative, exec_numpy=True)

@with_seed()
def test_exc_symbolic():
def symbolic(exec_backward=True):
def symbolic(exec_backward=True, waitall=True):
x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = mx.sym.Variable('z')
Expand All @@ -58,16 +58,25 @@ def symbolic(exec_backward=True):
outputs = exec1.forward()
if exec_backward:
exec1.backward()
exec1.grad_arrays[0].asnumpy()
if waitall:
mx.nd.waitall()
else:
exec1.grad_arrays[0].asnumpy()
else:
outputs[0].asnumpy()
if waitall:
mx.nd.waitall()
else:
outputs[0].asnumpy()

assert_raises(MXNetError, symbolic, False)
assert_raises(MXNetError, symbolic, True)
assert_raises(MXNetError, symbolic, exec_backward=False)
assert_raises(MXNetError, symbolic, exec_backward=True)

assert_raises(MXNetError, symbolic, exec_backward=False, waitall=True)
assert_raises(MXNetError, symbolic, exec_backward=True, waitall=True)

@with_seed()
def test_exc_gluon():
def gluon(exec_wait=True):
def gluon(exec_wait=True, waitall=False):
model = nn.Sequential()
model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False))
model.add(nn.Dropout(1))
Expand All @@ -77,46 +86,83 @@ def gluon(exec_wait=True):
y = model(x)
model.collect_params().initialize(ctx=[default_context()])
z = model(mx.nd.random.normal(10, -10, (32, 2, 10), ctx=default_context()))
if exec_wait:
if waitall:
mx.nd.waitall()
elif exec_wait:
z.wait_to_read()

gluon(exec_wait=False)
assert_raises(MXNetError, gluon, True)
assert_raises(MXNetError, gluon, exec_wait=True)

assert_raises(MXNetError, gluon, waitall=True)

@with_seed()
def test_exc_multiple_waits():
caught = False
try:
a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
a.wait_to_read()
except MXNetError:
caught = True
assert caught, "No exception thrown"
try:
b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
b.wait_to_read()
except MXNetError:
caught = True
assert caught, "No exception thrown"
def multiple_waits(waitall=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it make sense to use "@raises"? maybe it would be easier to read.

https://nose.readthedocs.io/en/latest/testing_tools.html

At least a small comment explaining the test approach for future readers and that we expect exception to be thrown, is that the intent?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added comments. Intention is to test multiple wait_to_reads and waitalls for vars in same scope.

caught = False
try:
a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
if waitall:
mx.nd.waitall()
else:
a.wait_to_read()
except MXNetError:
caught = True
assert caught, "No exception thrown"
try:
b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
if waitall:
mx.nd.waitall()
else:
b.wait_to_read()
except MXNetError:
caught = True
assert caught, "No exception thrown"

multiple_waits(waitall=False)
multiple_waits(waitall=True)

@with_seed()
def test_exc_post_fail():
def post_fail(waitall=False):
caught = False
try:
a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
if waitall:
mx.nd.waitall()
else:
a.asnumpy()
except MXNetError:
caught = True
assert caught, "No exception thrown"
b.asnumpy()
post_fail(waitall=False)
post_fail(waitall=True)

@with_seed()
def test_exc_mutable_var_fail():
def mutable_var_check(waitall=False):
a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
a = mx.nd.dot(a, a)
if waitall:
mx.nd.waitall()
else:
a.asnumpy()
assert_raises(MXNetError, mutable_var_check, waitall=False)
assert_raises(MXNetError, mutable_var_check, waitall=True)

@with_seed()
def test_multiple_waitalls():
caught = False
try:
a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
a.asnumpy()
a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
mx.nd.waitall()
except MXNetError:
caught = True
assert caught, "No exception thrown"
b.asnumpy()
mx.nd.waitall()


@with_seed()
def test_exc_mutable_var_fail():
def mutable_var_check():
a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
a = mx.nd.dot(a, a)
a.asnumpy()
assert_raises(MXNetError, mutable_var_check)

if __name__ == '__main__':
import nose
Expand Down
14 changes: 12 additions & 2 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6961,7 +6961,12 @@ def get_output_names_callback(name, arr):

op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
op_exe.set_monitor_callback(get_output_names_callback, monitor_all=False)
op_exe.forward()
try:
op_exe.forward()
mx.nd.waitall()
except mx.base.MXNetError:
# skip errors since test is to check output names
pass
for output_name, expected_name in zip(output_names, expected_names):
assert output_name == expected_name

Expand Down Expand Up @@ -7007,7 +7012,12 @@ def get_output_names_callback(name, arr):

op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
op_exe.set_monitor_callback(get_output_names_callback, monitor_all=True)
op_exe.forward()
try:
op_exe.forward()
mx.nd.waitall()
except mx.base.MXNetError:
# skip errors since test is to check all names
pass
for output_name, expected_name in zip(output_names, expected_names):
assert output_name == expected_name

Expand Down