This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[BUGFIX] fix log_sigmoid bugs #20372

Merged
merged 3 commits on Jun 30, 2021
2 changes: 1 addition & 1 deletion src/common/cuda/rtc/backward_functions-inl.h
@@ -47,7 +47,7 @@ backward_sigmoid(const DTypeGrad grad, const DType val) {
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_log_sigmoid(const DTypeGrad grad, const DType val) {
return grad * 1 / (1 + op::exp(val));
return grad * (1 - op::exp(val));
}

template <typename DType, typename DTypeGrad>
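The reasoning behind this one-line change (here `val` is now the activation's output `y = log_sigmoid(x)` rather than the input, consistent with the `ElemwiseGradUseOut` switch further down in this PR):

```math
\frac{d}{dx}\log\sigma(x) = 1-\sigma(x) = \frac{1}{1+e^{x}},
\qquad y=\log\sigma(x)\;\Rightarrow\;\sigma(x)=e^{y}\;\Rightarrow\;\frac{d}{dx}\log\sigma(x)=1-e^{y}.
```

The old expression `1 / (1 + op::exp(val))` is the derivative written in terms of the input `x`; written in terms of the output it becomes `1 - op::exp(val)`.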
2 changes: 1 addition & 1 deletion src/operator/mshadow_op.h
@@ -413,7 +413,7 @@ MXNET_UNARY_MATH_OP(sigmoid_grad, math::id(a) * (1.0f - math::id(a)));

MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));

MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f / (1.0f + math::exp(a)));
MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f - math::exp(a));

struct mish : public mxnet_op::tunable {
template<typename DType>
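A quick standalone check of the identity behind the new `log_sigmoid_grad` definition — a NumPy sketch for illustration only (the helper name `log_sigmoid` is mine; this is not part of the patch):

```python
import numpy as np

def log_sigmoid(x):
    # mirrors the forward op: log(1 / (1 + exp(-x)))
    return np.log(1.0 / (1.0 + np.exp(-x)))

x = np.linspace(-5.0, 5.0, 11)
y = log_sigmoid(x)

grad_old = 1.0 / (1.0 + np.exp(x))   # old definition, argument is the input x
grad_new = 1.0 - np.exp(y)           # new definition, argument is the output y

# Both express d(log_sigmoid)/dx; they agree because exp(y) == sigmoid(x).
assert np.allclose(grad_old, grad_new)
```

The derivative itself is unchanged; what changes is which tensor (`x` or `y`) the kernel receives, which is what the `ElemwiseGradUseOut` change below is about.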
13 changes: 8 additions & 5 deletions src/operator/nn/activation.cu
@@ -56,13 +56,16 @@ void ActivationCompute<gpu>(const nnvm::NodeAttrs& attrs,
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
const int act_type = param.act_type;

// SoftReLU, kSoftSign and Mish are not supported by CUDNN yet
// SoftReLU, SoftSign, Log_Sigmoid and Mish are not supported by CUDNN yet
if (act_type == activation::kSoftReLU) {
ActivationForward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else if (act_type == activation::kSoftSign) {
ActivationForward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else if (act_type == activation::kLogSigmoid) {
ActivationForward<gpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else if (act_type == activation::kMish) {
ActivationForward<gpu, mshadow_op::mish, mshadow_op::mish_grad>(ctx,
inputs[0], req[0], outputs[0]);
@@ -87,10 +90,13 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,

bool do_memory_opt = dmlc::GetEnv("MXNET_MEMORY_OPT", 0);

// SoftReLU, SoftSign and Mish not supported by CUDNN yet
// SoftReLU, SoftSign, Log_Sigmoid and Mish not supported by CUDNN yet
if (act_type == activation::kSoftReLU) {
ActivationBackward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else if (act_type == activation::kLogSigmoid) {
ActivationBackward<gpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else if (act_type == activation::kMish) {
ActivationBackward<gpu, mshadow_op::mish, mshadow_op::mish_grad>(
ctx, inputs.at(0), inputs.at(2), req[0], outputs[0]);
@@ -121,9 +127,6 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
} else if (act_type == activation::kSigmoid) {
ActivationBackward<gpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else if (act_type == activation::kLogSigmoid) {
ActivationBackward<gpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else {
LOG(FATAL) << "unknown activation type";
}
28 changes: 26 additions & 2 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -161,10 +161,34 @@ The storage type of ``log_sigmoid`` output is always dense

)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::log_sigmoid>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_log_sigmoid"});
Member:
The previous version looks correct with ElemwiseGradUseIn, which passes the elementwise function's input as the input to the gradient function. Could you elaborate on which cases this would fail in, and why you need to change it to ElemwiseGradUseOut and change the definition?

Member:
I'm not sure yet how a scalar array would trigger the problem.

Contributor Author:
Hi @szha.
The reason a scalar array triggers the problem is here:
/~https://github.com/apache/incubator-mxnet/blob/835e25031f847b80277b6d11db0519723d26a80a/src/operator/nn/activation.cc#L126-L140

There are two ways to make scalar array input work:

  1. Make y the input of log_sigmoid_grad, i.e. modify the following code, which currently takes x as its input. This is what I am doing in this PR.
    /~https://github.com/apache/incubator-mxnet/blob/835e25031f847b80277b6d11db0519723d26a80a/src/operator/mshadow_op.h#L416
    MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f - math::exp(a));
  2. Keep log_sigmoid_grad taking x as its input, and instead change the following code so that x (inputs[2]) is passed to log_sigmoid_grad.
    /~https://github.com/apache/incubator-mxnet/blob/835e25031f847b80277b6d11db0519723d26a80a/src/operator/nn/activation-inl.h#L207-L210
        case activation::kLogSigmoid:
          ActivationBackward<xpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
              ctx, inputs[0], inputs[2], req[0], outputs[0]);
          break;

I think solution 1 is better: for y = log_sigmoid(x), it computes dx from (dy, y) instead of (dy, x), which allows the forward pass y = log_sigmoid(x) to run in place (i.e. y and x share the same memory); see the sketch after this comment.

Adopting solution 1 raised another problem: the gradient of sym.log_sigmoid() became wrong. The reason is that the input of _backward_log_sigmoid was still x, while with solution 1 it should be y. The source code of sym.log_sigmoid() is here:
/~https://github.com/apache/incubator-mxnet/blob/835e25031f847b80277b6d11db0519723d26a80a/src/operator/tensor/elemwise_unary_op_basic.cc#L152-L167
So I changed it to ElemwiseGradUseOut, following the source code of sym.sigmoid().
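A minimal NumPy sketch of the in-place argument above (plain NumPy for illustration, not MXNet code; the function names are hypothetical):

```python
import numpy as np

def log_sigmoid_forward(x, inplace=False):
    y = np.log(1.0 / (1.0 + np.exp(-x)))
    if inplace:
        x[...] = y        # y reuses x's buffer, as an in-place activation would
        return x
    return y

def log_sigmoid_backward(dy, y):
    # solution 1: dx is computed from (dy, y) only -- x is not needed
    return dy * (1.0 - np.exp(y))

x = np.random.uniform(-2.0, 2.0, size=(2, 3))
x_ref = x.copy()                            # kept only to verify the result
y = log_sigmoid_forward(x, inplace=True)    # x has now been overwritten by y
dx = log_sigmoid_backward(np.ones_like(y), y)

assert np.allclose(dx, 1.0 / (1.0 + np.exp(x_ref)))  # d/dx log_sigmoid at the original x
```

With solution 2 the backward would still need the original `x`, which is exactly what an in-place forward destroys.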

Member:
Thanks for the detailed analysis. The proposed change looks good and I have no further concerns.

.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_log_sigmoid"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_log_sigmoid,
unary_bwd<mshadow_op::log_sigmoid_grad>);
unary_bwd<mshadow_op::log_sigmoid_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// n->inputs[0] : y_grad
// n->inputs[1] : f(x) = log_sigmoid(x)
// ograds[0] : head_grads
// f''(x) = f'(x) * (f'(x) - 1)
// NodeEntry{n} : y_grad * f'(x)
auto ones = MakeNode("ones_like", n->attrs.name + "_grad_ones", {n->inputs[1]}, nullptr, &n);
auto grad_minus_one = MakeNode("elemwise_sub", n->attrs.name + "_grad_sub",
{n->inputs[0], nnvm::NodeEntry{ones}}, nullptr, &n);
auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
{n->inputs[0], nnvm::NodeEntry{grad_minus_one}}, nullptr, &n);
auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
{nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);

// when building the gradient graph, the backward node of n->inputs[1] will be
// added to the graph again, therefore f'(x) will be multiplied
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
{ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
{ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
return ret;
});
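As a sanity check on the `f''(x) = f'(x) * (f'(x) - 1)` identity used in the comments above, a small finite-difference sketch (plain NumPy, illustration only):

```python
import numpy as np

def fprime(x):
    # f'(x) for f(x) = log_sigmoid(x)
    return 1.0 / (1.0 + np.exp(x))

x = np.linspace(-3.0, 3.0, 13)
eps = 1e-5

f2_numeric = (fprime(x + eps) - fprime(x - eps)) / (2.0 * eps)  # central difference
f2_closed = fprime(x) * (fprime(x) - 1.0)

assert np.allclose(f2_numeric, f2_closed, atol=1e-6)
```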

// mish
MXNET_OPERATOR_REGISTER_UNARY(mish)
74 changes: 63 additions & 11 deletions tests/python/unittest/test_numpy_op.py
@@ -3449,6 +3449,42 @@ def hybrid_forward(self, F, a):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@use_np
def test_npx_activation_log_sigmoid():
    def np_log_sigmoid(x):
        return _np.log(_np.divide(1.0, (1.0 + _np.exp(-x))))
    def np_log_sigmoid_grad(x):
        return _np.divide(1.0, _np.add(1.0, _np.exp(x)))

    class TestLogSigmoid(HybridBlock):
        def __init__(self):
            super(TestLogSigmoid, self).__init__()

        def forward(self, a):
            return npx.activation(a, act_type='log_sigmoid')

    shapes = [(), (2, 3, 4)]
    for hybridize in [True, False]:
        for shape in shapes:
            test_log_sigmoid = TestLogSigmoid()
            if hybridize:
                test_log_sigmoid.hybridize()
            x = rand_ndarray(shape).as_np_ndarray()
            x.attach_grad()
            np_out = np_log_sigmoid(x.asnumpy())
            with mx.autograd.record():
                mx_out = test_log_sigmoid(x)
            assert mx_out.shape == np_out.shape
            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
            mx_out.backward()
            np_backward = np_log_sigmoid_grad(x.asnumpy())
            assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)

            mx_out = npx.activation(x, act_type='log_sigmoid')
            np_out = np_log_sigmoid(x.asnumpy())
            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@use_np
def test_npx_activation_mish():
    def np_mish(a):
@@ -3459,17 +3495,33 @@ def np_mish_grad(a):
        sigmoid = _np.divide(1.0, (1.0 + _np.exp(-a)))
        return tanh + a * sigmoid * (1.0 - tanh * tanh)

    shape = (3, 4)
    A = mx.np.random.uniform(low=-1.0, high=1.0, size=shape)
    A.attach_grad()
    np_out = np_mish(A.asnumpy())
    with mx.autograd.record():
        B = mx.npx.activation(A, act_type='mish')
    assert B.shape == np_out.shape
    assert_almost_equal(B.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
    B.backward()
    np_backward = np_mish_grad(A.asnumpy())
    assert_almost_equal(A.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)
    class TestMish(HybridBlock):
        def __init__(self):
            super(TestMish, self).__init__()

        def forward(self, a):
            return npx.activation(a, act_type='mish')

    shapes = [(), (2, 3, 4)]
    for hybridize in [True, False]:
        for shape in shapes:
            test_mish = TestMish()
            if hybridize:
                test_mish.hybridize()
            x = rand_ndarray(shape).as_np_ndarray()
            x.attach_grad()
            np_out = np_mish(x.asnumpy())
            with mx.autograd.record():
                mx_out = test_mish(x)
            assert mx_out.shape == np_out.shape
            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
            mx_out.backward()
            np_backward = np_mish_grad(x.asnumpy())
            assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)

            mx_out = npx.activation(x, act_type='mish')
            np_out = np_mish(x.asnumpy())
            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@use_np