diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 7169395205e0..d5a5a3076240 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -5198,6 +5198,50 @@ def create_operator(self, ctx, shapes, dtypes): x = mx.nd.Custom(length=10, depth=10, op_type="no_input_op") assert_almost_equal(x.asnumpy(), np.ones(shape=(10, 10), dtype=np.float32)) + # test custom operator fork + # see /~https://github.com/apache/incubator-mxnet/issues/14396 + class AdditionOP(mx.operator.CustomOp): + def __init__(self): + super(AdditionOP, self).__init__() + def forward(self, is_train, req, in_data, out_data, aux): + out_data[0][:] = in_data[0] + in_data[1] + def backward(self, req, out_grad, in_data, out_data, in_grad, aux): + in_grad[0][:] = out_grad[0] + in_grad[1][:] = out_grad[0] + + @mx.operator.register("AdditionOP") + class AdditionOPProp(mx.operator.CustomOpProp): + def __init__(self): + super(AdditionOPProp, self).__init__() + def list_arguments(self): + return ['a', 'b'] + def list_outputs(self): + return ['output'] + def infer_shape(self, in_shape): + return in_shape, [in_shape[0]] + def create_operator(self, ctx, shapes, dtypes): + return AdditionOP() + + def custom_add(): + a = mx.nd.array([1, 2, 3]) + b = mx.nd.array([4, 5, 6]) + a.attach_grad() + b.attach_grad() + + with mx.autograd.record(): + c = mx.nd.Custom(a, b, op_type='AdditionOP') + + dc = mx.nd.array([7, 8, 9]) + c.backward(dc) + + custom_add() + from multiprocessing import Process + p = Process(target=custom_add) + p.daemon = True + p.start() + p.join(5) + assert not p.is_alive(), "deadlock may exist in custom operator" + @with_seed() def test_psroipooling(): for num_rois in [1, 2]: