Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Properly handling custom op exception by modify engine #14693

Merged
merged 6 commits into from
Apr 16, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions docs/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
* MXNET_MP_OPENCV_NUM_THREADS
- Values: Int ```(default=0)```
- The number of OpenCV execution threads given to multiprocess workers. OpenCV multithreading is disabled if `MXNET_MP_OPENCV_NUM_THREADS` < 1 (default). Enlarging this number may boost the performance of individual workers when executing underlying OpenCV functions, but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows).
* MXNET_CUSTOM_OP_NUM_THREADS
- Values: Int ```(default=16)```
- The maximum number of threads given to custom operators.

## Memory Options

Expand Down
6 changes: 5 additions & 1 deletion include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ enum class FnProperty {
/*! \brief Delete variable call */
kDeleteVar,
/*! \brief Prioritized sync operation on GPU */
kGPUPrioritized
kGPUPrioritized,
/*! \brief Operation not to be skipped even with associated exception */
kNoSkip
}; // enum class FnProperty

/*!
Expand Down Expand Up @@ -230,6 +232,8 @@ class MXNET_API Engine {
* \brief Wait until all the activity of engine finishes.
*/
virtual void WaitForAll() = 0;
/*!\brief Throw if there is an exception associated with var */
virtual void Throw(VarHandle var) = 0;
/*!\brief virtual destructor */
virtual ~Engine() noexcept(false) {}
/*!
Expand Down
3 changes: 3 additions & 0 deletions src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ class NaiveEngine final : public Engine {
/*! \brief Empty override of Engine::WaitForAll(): performs no work and returns immediately. */
void WaitForAll() override {
}

/*!
 * \brief Empty override of Engine::Throw(): never raises for the given variable.
 * \param var engine variable handle; unused by this implementation.
 * NOTE(review): presumably this engine surfaces operator errors directly at
 * the call site, so there is nothing recorded on the var to rethrow -- confirm.
 */
void Throw(VarHandle var) override {
}

/*! \brief Flag the engine as shutting down by storing true into shutdown_phase_. */
void NotifyShutdown() override {
shutdown_phase_.store(true);
}
Expand Down
5 changes: 5 additions & 0 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,11 @@ inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) {
return;
}

/*!
 * \brief Public entry point that raises the exception associated with a variable.
 * \param var engine variable handle; it is converted to the engine's
 *        ThreadedVar representation and handed to ThrowException().
 */
void ThreadedEngine::Throw(VarHandle var) {
  ThrowException(ThreadedVar::CastFromBase(var));
}
szha marked this conversation as resolved.
Show resolved Hide resolved

void ThreadedEngine::OnCompleteStatic(Engine *engine, void *opr_block_,
const dmlc::Error* error) {
OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
Expand Down
5 changes: 3 additions & 2 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ class ThreadedEngine : public Engine {
void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override;
void WaitForVar(VarHandle var) override;
void WaitForAll() override;
void Throw(VarHandle var) override;
/*! \brief Flag the engine as shutting down by storing true into shutdown_phase_. */
void NotifyShutdown() override {
shutdown_phase_.store(true);
}
Expand Down Expand Up @@ -374,8 +375,8 @@ class ThreadedEngine : public Engine {
LOG(INFO) << "ExecuteOprFn ";
}
try {
if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
threaded_opr->wait) {
if ((!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
threaded_opr->prop == FnProperty::kNoSkip) || threaded_opr->wait) {
threaded_opr->fn(run_ctx, callback);
} else {
callback();
Expand Down
45 changes: 35 additions & 10 deletions src/operator/custom/custom-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ class CustomOperator {
bool prev_recording = Imperative::Get()->set_is_recording(recording);
bool prev_training = Imperative::Get()->set_is_training(training);

func();
try {
func();
} catch (dmlc::Error& e) {
exception_ =
std::make_shared<std::exception_ptr>(std::current_exception());
}

Imperative::Get()->set_is_training(prev_training);
Imperative::Get()->set_is_recording(prev_recording);
Expand All @@ -116,6 +121,16 @@ class CustomOperator {

Engine::Get()->PushSync(
[=](RunContext rctx) {
try {
Throw();
for (const auto& i : arrs) {
Engine::Get()->Throw(i.var());
}
szha marked this conversation as resolved.
Show resolved Hide resolved
} catch(dmlc::Error& err) {
ctx.async_on_complete(&err);
return;
}

for (size_t i = 0, out_idx = 0; i < arrs.size(); i++) {
if (arrs[i].storage_type() == kDefaultStorage ||
arrs[i].storage_type() == kUndefinedStorage)
Expand All @@ -125,14 +140,15 @@ class CustomOperator {
out_idx++;
}
}

ctx.async_on_complete();
},
ctx.run_ctx.ctx, vars, vars2, FnProperty::kNormal, 0,
ctx.run_ctx.ctx, vars, vars2, FnProperty::kNoSkip, 0,
"CustomOperator");
});
// increase num_threads if there is not enough threads to execute custom operator
if (q_.size() > num_free_threads)
CreateThreads(q_.size() - num_free_threads);
if (q_.size() > num_free_threads_)
CreateThreads(q_.size() - num_free_threads_);
cv_.notify_all();
}

Expand All @@ -142,9 +158,10 @@ class CustomOperator {
}

void Start() {
num_free_threads = 0;
num_free_threads_ = 0;
destructing_ = false;
naive_engine_ = true;
exception_ = nullptr;
if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) {
naive_engine_ = false;
}
Expand All @@ -162,6 +179,14 @@ class CustomOperator {
workers_.clear();
}

/*!
 * \brief Rethrow the exception held in exception_, if one is pending.
 *
 * The held pointer is cleared before rethrowing, so a stored exception is
 * raised by at most one call; with nothing pending the call is a no-op.
 * NOTE(review): exception_ looks like it is assigned from the custom-op
 * worker and consumed here -- confirm both sides are synchronized.
 */
inline void Throw() {
  const bool has_pending = exception_ && *exception_;
  if (!has_pending) {
    return;
  }
  // Copy out first: dropping exception_ would otherwise release the payload.
  std::exception_ptr pending = *exception_;
  exception_.reset();
  std::rethrow_exception(pending);
}

private:
CustomOperator() {
this->Start();
Expand All @@ -171,21 +196,20 @@ class CustomOperator {
while (!q_.empty() || !destructing_) {
cv_.wait(lock, [&] {return !q_.empty() || destructing_;});
while (!q_.empty()) {
--num_free_threads;
--num_free_threads_;
auto fn = q_.front();
q_.pop();
lock.unlock();
fn();
++num_free_threads;
++num_free_threads_;
lock.lock();
}
}
}
/*!
 * \brief Grow the worker pool until workers_ holds num_threads entries.
 * \param num_threads requested pool size; never shrinks the pool, since the
 *        loop only runs while workers_.size() is below the target.
 * NOTE(review): this span appears to interleave the pre- and post-change
 * lines of a diff (the env-var clamp and both spellings of the free-thread
 * counter) -- confirm which lines are live before relying on it.
 */
void SetNumThreads(int num_threads) {
// Clamp the request to MXNET_CUSTOM_OP_NUM_THREADS (defaults to 16 when unset).
num_threads = std::min(dmlc::GetEnv("MXNET_CUSTOM_OP_NUM_THREADS", 16), num_threads);
for (int i = workers_.size(); i < num_threads; ++i) {
// Each new worker runs ThreadTarget() and is counted as free on creation.
workers_.emplace_back(std::thread([this]{this->ThreadTarget();}));
++num_free_threads;
++num_free_threads_;
}
}
void CreateThreads(int num_new_threads) {
Expand All @@ -196,8 +220,9 @@ class CustomOperator {
// async worker
std::condition_variable cv_;
std::vector<std::thread> workers_;
std::atomic<uint32_t> num_free_threads;
std::atomic<uint32_t> num_free_threads_;
std::queue<std::function<void(void)> > q_;
std::shared_ptr<std::exception_ptr> exception_;
bool naive_engine_;
bool destructing_;
};
Expand Down
56 changes: 34 additions & 22 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from mxnet.test_utils import *
from mxnet.base import py_str, MXNetError, _as_list
from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied, assertRaises
from nose.tools import assert_raises
import unittest
import os

Expand Down Expand Up @@ -5355,29 +5356,29 @@ def create_operator(self, ctx, shapes, dtypes):

# test custom operator fork
# see /~https://github.com/apache/incubator-mxnet/issues/14396
if not sys.platform.startswith('win'): # no fork in windows
class AdditionOP(mx.operator.CustomOp):
def __init__(self):
super(AdditionOP, self).__init__()
def forward(self, is_train, req, in_data, out_data, aux):
out_data[0][:] = in_data[0] + in_data[1]
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
in_grad[0][:] = out_grad[0]
in_grad[1][:] = out_grad[0]

@mx.operator.register("AdditionOP")
class AdditionOPProp(mx.operator.CustomOpProp):
def __init__(self):
super(AdditionOPProp, self).__init__()
def list_arguments(self):
return ['a', 'b']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
return in_shape, [in_shape[0]]
def create_operator(self, ctx, shapes, dtypes):
return AdditionOP()
class AdditionOP(mx.operator.CustomOp):
def __init__(self):
super(AdditionOP, self).__init__()
def forward(self, is_train, req, in_data, out_data, aux):
out_data[0][:] = in_data[0] + in_data[1]
szha marked this conversation as resolved.
Show resolved Hide resolved
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
in_grad[0][:] = out_grad[0]
in_grad[1][:] = out_grad[0]

@mx.operator.register("AdditionOP")
class AdditionOPProp(mx.operator.CustomOpProp):
def __init__(self):
super(AdditionOPProp, self).__init__()
def list_arguments(self):
return ['a', 'b']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
return in_shape, [in_shape[0]]
def create_operator(self, ctx, shapes, dtypes):
return AdditionOP()

if not sys.platform.startswith('win'): # no fork in windows
def custom_add():
a = mx.nd.array([1, 2, 3])
b = mx.nd.array([4, 5, 6])
Expand All @@ -5392,6 +5393,17 @@ def custom_add():
p.join(5)
assert not p.is_alive(), "deadlock may exist in custom operator"

# test exception handling
def custom_add_exc():
    """Run the 'AdditionOP' custom op on a length-3 and a length-2 array.

    The caller wraps this in assert_raises(MXNetError, ...), so the call
    chain is expected to raise once the result is waited on.
    """
    lhs = mx.nd.array([1, 2, 3])
    rhs = mx.nd.array([4, 5])
    # operand shapes deliberately disagree so the custom op fails
    result = mx.nd.Custom(lhs, rhs, op_type='AdditionOP')
    result.wait_to_read()

assert_raises(MXNetError, custom_add_exc)


@with_seed()
def test_psroipooling():
for num_rois in [1, 2]:
Expand Down