From 4b1811cce5933f5c9a6de2e930fa15d31e62bfcc Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Thu, 21 Mar 2019 00:23:01 +0800 Subject: [PATCH] fix custom operation in fork (#14451) * fix custom operation in fork * add test * fix custom stop * swap order * add docs * update doc --- docs/tutorials/gluon/customop.md | 20 +++++++++++++ src/c_api/c_api.cc | 1 + src/initialize.cc | 5 ++++ src/operator/custom/custom-inl.h | 27 ++++++++++++------ src/operator/custom/custom.cc | 5 ---- tests/python/unittest/test_operator.py | 39 ++++++++++++++++++++++++++ 6 files changed, 83 insertions(+), 14 deletions(-) diff --git a/docs/tutorials/gluon/customop.md b/docs/tutorials/gluon/customop.md index eae0344c8702..29ab21843114 100644 --- a/docs/tutorials/gluon/customop.md +++ b/docs/tutorials/gluon/customop.md @@ -30,6 +30,7 @@ Custom operator in python is easy to develop and good for prototyping, but may h import numpy as np import mxnet as mx from mxnet import gluon, autograd +import os ``` ## Parameter-less operators @@ -214,5 +215,24 @@ y = dense(x) print(y) ``` +## Using custom operators with fork +In Linux systems, the default method in multiprocessing to create process is by using fork. If there are unfinished async custom operations when forking, the program will be blocked because of python GIL. Always use sync calls like `wait_to_read` or `waitall` before calling fork. + +``` +x = mx.nd.array([0, 1, 2, 3]) +y = mx.nd.Custom(x, op_type='sigmoid') +# unfinished async sigmoid operation will cause blocking +os.fork() +``` + +Correctly handling this will make mxnet depend upon libpython, so the workaround now is to ensure that all custom operations are executed before forking process. + +``` +x = mx.nd.array([0, 1, 2, 3]) +y = mx.nd.Custom(x, op_type='sigmoid') +# force execution by reading y +print(y.asnumpy()) +os.fork() +``` diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 5a7329acaeab..70ba84b5f94b 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -111,6 +111,7 @@ int MXRandomSeedContext(int seed, int dev_type, int dev_id) { int MXNotifyShutdown() { API_BEGIN(); + mxnet::op::custom::CustomOperator::Get()->Stop(); Engine::Get()->NotifyShutdown(); API_END(); } diff --git a/src/initialize.cc b/src/initialize.cc index 8d0e3c304216..00a736abd8ba 100644 --- a/src/initialize.cc +++ b/src/initialize.cc @@ -26,6 +26,7 @@ #include #include #include "./engine/openmp.h" +#include "./operator/custom/custom-inl.h" #if MXNET_USE_OPENCV #include #endif // MXNET_USE_OPENCV @@ -53,12 +54,15 @@ class LibraryInitializer { // disable openmp for multithreaded workers #ifndef _WIN32 + using op::custom::CustomOperator; pthread_atfork( []() { + CustomOperator::Get()->Stop(); Engine::Get()->Stop(); }, []() { Engine::Get()->Start(); + CustomOperator::Get()->Start(); }, []() { // Conservative thread management for multiprocess workers @@ -71,6 +75,7 @@ class LibraryInitializer { #endif // MXNET_USE_OPENCV engine::OpenMP::Get()->set_enabled(false); Engine::Get()->Start(); + CustomOperator::Get()->Start(); }); #endif } diff --git a/src/operator/custom/custom-inl.h b/src/operator/custom/custom-inl.h index f88e830bc573..c5eaea13661e 100644 --- a/src/operator/custom/custom-inl.h +++ b/src/operator/custom/custom-inl.h @@ -136,7 +136,21 @@ class CustomOperator { cv_.notify_all(); } - ~CustomOperator() { + static CustomOperator* Get() { + static CustomOperator inst; + return &inst; + } + + void Start() { + num_free_threads = 0; + destructing_ = false; + naive_engine_ = true; + if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) { + naive_engine_ = false; + } + } + + void Stop() { if (naive_engine_) return; { std::unique_lock lock(mutex_); @@ -145,17 +159,12 @@ class CustomOperator { } for (auto &worker : workers_) worker.join(); + workers_.clear(); } - static CustomOperator* Get(); - private: - CustomOperator() : num_free_threads(0) { - destructing_ = false; - naive_engine_ = true; - if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) { - naive_engine_ = false; - } + CustomOperator() { + this->Start(); } void ThreadTarget() { std::unique_lock lock(mutex_); diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc index 39cca4d7c436..46249c9bbcc6 100644 --- a/src/operator/custom/custom.cc +++ b/src/operator/custom/custom.cc @@ -34,11 +34,6 @@ namespace mxnet { namespace op { namespace custom { -CustomOperator* CustomOperator::Get() { - static CustomOperator inst; - return &inst; -} - struct CustomParam { std::string op_type; size_t num_args, num_outs, num_auxs; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index f4d2ef32cc2e..c9498ecb0bd2 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -5198,6 +5198,45 @@ def create_operator(self, ctx, shapes, dtypes): x = mx.nd.Custom(length=10, depth=10, op_type="no_input_op") assert_almost_equal(x.asnumpy(), np.ones(shape=(10, 10), dtype=np.float32)) + # test custom operator fork + # see /~https://github.com/apache/incubator-mxnet/issues/14396 + if not sys.platform.startswith('win'): # no fork in windows + class AdditionOP(mx.operator.CustomOp): + def __init__(self): + super(AdditionOP, self).__init__() + def forward(self, is_train, req, in_data, out_data, aux): + out_data[0][:] = in_data[0] + in_data[1] + def backward(self, req, out_grad, in_data, out_data, in_grad, aux): + in_grad[0][:] = out_grad[0] + in_grad[1][:] = out_grad[0] + + @mx.operator.register("AdditionOP") + class AdditionOPProp(mx.operator.CustomOpProp): + def __init__(self): + super(AdditionOPProp, self).__init__() + def list_arguments(self): + return ['a', 'b'] + def list_outputs(self): + return ['output'] + def infer_shape(self, in_shape): + return in_shape, [in_shape[0]] + def create_operator(self, ctx, shapes, dtypes): + return AdditionOP() + + def custom_add(): + a = mx.nd.array([1, 2, 3]) + b = mx.nd.array([4, 5, 6]) + c = mx.nd.Custom(a, b, op_type='AdditionOP') + assert_almost_equal((a + b).asnumpy(), c.asnumpy()) + + custom_add() + from multiprocessing import Process + p = Process(target=custom_add) + p.daemon = True + p.start() + p.join(5) + assert not p.is_alive(), "deadlock may exist in custom operator" + @with_seed() def test_psroipooling(): for num_rois in [1, 2]: