[MXNET-978] Second order gradient support for some unary operators #14613
Conversation
…xnet into develop/higher_order_grad
Nice! LGTM
src/imperative/imperative.cc
@@ -347,8 +347,9 @@ std::vector<NDArray*> Imperative::Backward(
     x_reqs.push_back(info.grad_req);
     info.fresh_out_grad = true;
   }
-  CHECK_GT(xs.size(), 0)
-      << "There are no inputs in computation graph that require gradients.";
+  if (xs.empty()) {
Why change from CHECK to warning?
The current backward operation requires that an operator have at least one input, because the gradient of a constant is always zero. However, the second-order gradient of some operators such as relu is actually the gradient of a constant (ones or zeros). Therefore we need to support taking gradients of constant operators.
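For example (a minimal sketch of the relu case, assuming the second-order support added in this PR is in place):

```python
from mxnet import nd, autograd

# f(x) = relu(x); f'(x) is 0 or 1 (piecewise constant), so f''(x) = 0 away from x = 0.
x = nd.array([-1.0, 2.0, 3.0])
x.attach_grad()
with autograd.record():
    y = nd.relu(x)
    y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
y_grad.backward()
print(x.grad)  # expected: [0. 0. 0.] -- the gradient of a constant
```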
I think we should dive deeper into this one. Does it produce the warning (or, previously, the failure) for some of the test cases?
In the original code I think the intention is to check whether there are any input nodes with a gradient attached. I understand your explanation, but what I don't see is where we would store the gradient for such constants. Is it because the grad_req of the constant is kNullOp? The constant is just another node, right?
The root cause is the second-order gradient of the negative(x) operator: its backward graph does not require any input and therefore triggers this condition. If I remove the test for negative(x), then we do not need to modify this.
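For reference, the math behind this (which is why the backward-of-backward graph of negative(x) has no inputs that require gradients):

```latex
f(x) = -x, \qquad f'(x) = -1, \qquad f''(x) = 0
```

The first-order backward just multiplies the head gradient by the constant -1 and never reads x, so its own backward graph contains no nodes that require gradients, and xs ends up empty.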
Thanks for your contributions @apeforest. Can you also look into the CI failures? @mxnet-label-bot Add [Operator, pr-awaiting-review, Backend].
@sxjscience Please help to review.
Looks good. However, the test has not passed and we need to check what caused the error. Also, we should later add a testing utility function that checks higher-order gradients by numerical differentiation.
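Something along these lines could work as that utility (a rough sketch; numeric_second_grad is a hypothetical name and the tolerances would need tuning):

```python
import numpy as np
from mxnet import nd

def numeric_second_grad(op, x, eps=1e-4):
    """Central-difference estimate of the elementwise second derivative of a unary op."""
    x = np.asarray(x, dtype='float64')
    f = lambda v: op(nd.array(v, dtype='float64')).asnumpy()
    return (f(x + eps) - 2.0 * f(x) + f(x - eps)) / (eps ** 2)

# Compare against the analytic second derivative of sin:
x = [1.0, 2.0, 3.0]
np.testing.assert_allclose(numeric_second_grad(nd.sin, x), -np.sin(x), rtol=1e-4, atol=1e-4)
```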
@apeforest Thanks for working on this. Can you look into the build failures?
@with_seed()
def test_elemwise_mul():
    x = nd.array([1, 2, 3])
    y = nd.zeros(3)
Do we need this y?
def cos(x):
    return nd.cos(x)

x = nd.array([1, 2, 3])
Can we randomize the test arrays with random_arrays and rand_shape_2d?
I think for second order, not using random inputs helps reason about the gradient result...
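For example, with x = [1, 2, 3] the expected values are easy to check by eye (values below are rounded):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
print(np.cos(x))   # first-order gradient of sin(x):  [ 0.540, -0.416, -0.990]
print(-np.sin(x))  # second-order gradient of sin(x): [-0.841, -0.909, -0.141]
```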
def negative(x):
    return nd.negative(x)

x = nd.array([1, 2, 3])
Same as above, and for the rest of the tests.
@apeforest Could you have a look at the CI failures?
auto grad_grad_x_mid = MakeNode("cos", n->attrs.name + "_mid_grad_grad",
                                {n->inputs[1]}, nullptr, &n);
auto grad_grad_x = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
                            {nnvm::NodeEntry(grad_grad_x_mid)}, nullptr, &n);
Ditto.
Updated.
.set_attr<nnvm::FGradient>("FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      std::vector<nnvm::NodeEntry> ret;
      // f(x) -> f = relu
Should we be having these comments? Couldn't they be included as part of the R"code" description?
This is a hidden operator, so users do not see this.
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_sin, unary_bwd<mshadow_op::sin_grad>)
.set_attr<nnvm::FGradient>("FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      // f(x) = sin(x)
Same as above.
This is a hidden operator, so users do not see this.
arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
for array in arrays:
    check_second_order_unary(array, sin, grad_grad_op)
I think the check_second_order_unary function should be moved to python/mxnet/test_utils.py.
It's only used in this test file. If we add a different test file, then moving it as you suggested makes sense.
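For reference, the helper is roughly along these lines (my reconstruction, not the exact code in the test file):

```python
from mxnet import nd, autograd
from mxnet.test_utils import assert_almost_equal

def check_second_order_unary(x, op, grad_grad_op):
    # Compare the autograd second-order gradient of `op` at `x` against the analytic `grad_grad_op`.
    x = nd.array(x)
    expect_grad_grad = grad_grad_op(x)
    x.attach_grad()
    with autograd.record():
        y = op(x)
        y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
    y_grad.backward()
    assert_almost_equal(expect_grad_grad.asnumpy(), x.grad.asnumpy())
```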
def grad_grad_op(x):
    return nd.zeros_like(x)

arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
Shouldn't we test 1-d arrays as well?
Not sure if it is needed here, but there is a helper to randomize the shape of an array: /~https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/test_utils.py#L418
Thanks for the suggestion. Updated.
Everything else LGTM.
@kshitij12345 Gentle ping again :) Please approve if there is no other concern. Thanks.
@apeforest As mentioned in #14992, for:

```python
from mxnet import nd, autograd
import numpy
import math

grad_grad_op = lambda x: -nd.sin(x)  # -nd.cos(x)

x = nd.random.normal(0, 1, (3, 3))
x.attach_grad()
with autograd.record():
    y = nd.sin(x)  # nd.cos(x)
    y_grad = autograd.grad(y, x, head_grads=nd.ones_like(y) * 0.5, create_graph=True, retain_graph=True)[0]
y_grad.backward(nd.ones_like(y_grad) * 0.6)
numpy.testing.assert_allclose(x.grad.asnumpy(), (grad_grad_op(x) * 0.5 * 0.6).asnumpy(), rtol=1e-7, atol=1e-7)
```

As the
@kshitij12345 Thanks for catching this bug. I have updated my sin and cos implementations with comments to help better understand the mathematics behind the second-order gradient calculation. Please help to review again.
@kshitij12345 Could you please take a look at the PR again? Thanks.
// f''(x) = 0
auto gx = nnvm::NodeEntry{n};  // f'(x)
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
                          {ograds[0], gx}, nullptr, &n));
Similar to what you have done below for sin and cos: gx is actually f'(x) * {head_grads/output_gradient}. It should actually only be f'(x).
Explanation: gx = f'(x) * head_grads. Therefore the gradient of gx w.r.t. f'(x) is head_grads, and similarly the gradient of gx w.r.t. head_grads is f'(x).
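Writing the chain rule out, with h1 the head gradient fed into the first backward pass and h2 the head gradient fed into the second:

```latex
g_x = f'(x)\,h_1, \qquad
\frac{\partial g_x}{\partial x}\,h_2 = f''(x)\,h_1 h_2, \qquad
\frac{\partial g_x}{\partial h_1}\,h_2 = f'(x)\,h_2
```

That is, the partial w.r.t. the head gradient is f'(x), not f'(x) * head_grads.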
Updated
Could you update the test? It covers the check for the gradient of the first input argument as well. I have tested a similar PyTorch script which works (code in PR #15120). However, do note that for PR #15120 the assertion fails. Please check to see if it works for you.
@kshitij12345 I updated the test locally and it also fails with sin(x) and cos(x), for the same reason as log(x). I have not pushed it to the remote because it requires your PR #15120 to be merged first; otherwise the test will always fail.
Ping @aidan-plenert-macdonald for review.
Looks great! I think the code is right, but I would like the tests to be updated slightly to catch ops that don't support Nth order gradients.
for dim in range(1, 5):
    shape = rand_shape_nd(dim)
    array = random_arrays(shape)
    check_second_order_unary(array, cos, grad_grad_op)
Can these check_second_order_unary checks be changed to Nth order?
This PR only verifies second-order gradients. Can we add tests for Nth-order gradients in a separate PR?
// f''(x) = -sin(x)
auto x_grad = MakeNode("cos", n->attrs.name + "_x_grad",
                       {n->inputs[1]}, nullptr, &n);
auto x_grad_grad = MakeNode("negative", n->attrs.name + "_x_grad_grad",
Although it appears that this permits Nth-order derivatives, the _grad_grad name concerns me, as though this simply registers a second-order grad and calls that good. Assuming this does support Nth order, it worries me that someone else trying to copy this would "cheat" by only having two gradients.
Ideally the testing would catch this.
As long as the nodes are themselves differentiable, it would support additional differentiation. For some complex functions, having just the 2nd gradient should be good. In this case I would say it's N-times differentiable. The _x_grad_grad is just the name of the node, as far as I understand it. Why does it concern you?
I renamed it. Also, this indicates that the gradient is taken on top of the previous gradient. Users can still apply this recursively for Nth-order gradients.
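Something like the following is what recursive application looks like (untested sketch; it assumes every node that appears in the generated backward graphs has its own FGradient registered):

```python
from mxnet import nd, autograd

x = nd.array([1.0, 2.0, 3.0])
x.attach_grad()
with autograd.record():
    y = nd.sin(x)
    g1 = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]   # cos(x)
    g2 = autograd.grad(g1, x, create_graph=True, retain_graph=True)[0]  # -sin(x)
g2.backward()
print(x.grad)  # third-order gradient of sin: -cos(x)
```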
To summarize the discussion in #15120: I think in MXNet we calculate input gradients based on the variables specified in the autograd.grad() API. In this case, even if a grad attribute is attached to y_grad, no gradient values will be assigned to it during the second pass of backward.
Thank you very much for the clarification.
BTW, I was wondering if the above should be documented? Most probably users won't end up in this situation, but still. Your thoughts?
Yes, I think we should document this in the autograd API. What do you think?
@kshitij12345 Do you think we can merge this PR now? If so, could you please approve it? Thanks!
LGTM
…pache#14613)

* try to add support some ops
* add unit test for second order grad
* implement grad for relu and add unit test
* fix lint
* register FGradient attribute for backward relu
* resolve conflict
* remove unused imports
* change gradient using set_attr
* remove higher order grad test for negative(x)
* fix lint
* reverse indent
* remove unused backward operator
* refactor backward for sin(x) and cos(x)
* change value init to list init
* change to list initialization
* generate random shape in test
* fix a bug in second order backward
* fix lint
* fix lint
* address reviewer comment and renaming
Description
This PR supports higher-order gradients for some basic unary operators. Thanks to @sxjscience; this PR is continued work from #12821.
This PR adds additional support for the relu operator and fixes the issue with negative operator support in #12821. In addition, I added unit tests for the higher-order gradients.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes