Update `ad.backward_pass` to support non-linear functions of constants #26811
Conversation
Looks like this has some failures that seem related to the "type errors" mentioned previously. I'm not sure I totally understand the source of these issues, but I'll think about it some more with this as a starting point!
I suspect this approach doesn't work, because I think we tried it in #1749 (note "To make remat_call_p transposable, we need to upgrade our ad.backward_pass code to do some nonlinear evaluation (to rematerialize the residuals needed)" in the PR message, and the double loop over eqns here and here). See also the predecessor #1719, which I think was lumped into #1749.

The trouble was that a HOP like `jit` could include a mix of linear and nonlinear stuff. That is, we can't assume that a HOP is purely linear or purely nonlinear, and we may have to run just the nonlinear parts of it first and run the linear parts later. (I haven't read this code carefully, but I suspect we can construct examples that this PR won't be able to transpose, e.g. by making an inner `jit` that includes linear and nonlinear parts, then trying to transpose wrt the linear inputs.)

To run nested nonlinear parts, saving any intermediates we need in HOP bodies (i.e. recursing in to handle mixed linear/nonlinear HOPs, rather than just doing it at the top level and assuming each HOP application is purely linear or purely nonlinear), is essentially partial evaluation. Indeed that's the approach we eventually took for remat transpose: we evaluate the nonlinear parts using partial eval and then run the linear-only jaxpr. We could build that same logic into `ad.backward_pass`.
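For concreteness, here is a sketch of the kind of program this describes (a hypothetical example, not taken from the PR): the inner `jit` mixes a computation that is nonlinear in a constant with one that is linear in the input, so transposing wrt `x` requires first evaluating just the nonlinear part *inside* the HOP. Current JAX handles this via partial eval; the question is whether a top-level-only sweep would.

```python
import jax
import jax.numpy as jnp

def f(x, c):
  @jax.jit
  def inner(x):
    r = jnp.sin(c)  # nonlinear in the (closed-over) constant c
    return r * x    # linear in x
  return inner(x)

# Transposing wrt x requires running the nonlinear part (sin(c)) inside
# the jit before reversing the linear part (the multiply). This works
# today because transposition partial-evaluates the nonlinear pieces.
f_t = jax.linear_transpose(lambda x: f(x, 1.5), jnp.float32(1.0))
(ct_x,) = f_t(jnp.float32(1.0))  # ct_x == sin(1.5)
```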
Thanks for this clear (as always!!) explanation @mattjj, and for the references. Indeed, this all makes sense, and I see the complications. For the specific case of direct linearization, I guess we could update the lin rule signature to explicitly return the unzipped functions: the usual linear part, plus a new one to compute the residuals from the primals. But that doesn't seem very useful in most cases.
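A hypothetical sketch of what such an unzipped lin rule signature might look like (the names `residuals_fn` and `linear_fn` are made up for illustration, not an existing JAX API):

```python
# Hypothetical unzipped linearization rule; names are illustrative only.
def my_op_lin_rule(*primals):
  def residuals_fn(*primals):
    # nonlinear part: compute the residuals from the primals
    ...

  def linear_fn(residuals, *tangents):
    # the usual linear part, given the residuals
    ...

  return residuals_fn, linear_fn
```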
Yeah - I guess there's also the question of how important it is that we support this behavior in linearization rules. I can see in principle why we might want to do this, but did you or @dougalm have any specific applications in mind? It seems like it would be simple enough to support this in that case.
Sorry, I wasn't clear about my original proposal for nonlinear operations within linear functions. In particular, linear functions may have nonlinear ops, but they can't return nonlinear values. So there should be no need to worry about unzipping functions!

Here's the type system I had in mind. Each variable and literal in a jaxpr is tagged with a linearity, which is just a boolean attribute that can be either "linear" or "nonlinear". Also, each jaxpr has a linearity (a jaxpr can be linear or nonlinear; there is no mixed option). And each eqn in a jaxpr has a linearity too. Again, eqns can be linear or nonlinear, never a mixture.

The rule for linear jaxprs is that they only return linear values. Their arguments can be any mixture of linear and nonlinear. Similarly, linear eqns can only return linear values. Nonlinear eqns can only have nonlinear arguments. Linear eqns can have a mixture of linear and nonlinear arguments, depending on the op. We could imagine a "check_linearity" function that checks whether the argument linearity is acceptable for a given op. For example, here are the rules for add and mul:

```python
IsLinear = bool

def check_linearity_add(x_lin: IsLinear, y_lin: IsLinear):
  assert x_lin and y_lin

def check_linearity_mul(x_lin: IsLinear, y_lin: IsLinear):
  assert (x_lin and not y_lin) or (not x_lin and y_lin)
```

The way to transpose a linear jaxpr is to pull all the nonlinear eqns to the beginning and then reverse the order of the linear eqns and transpose their operations. Easy. The challenge is that we don't currently have these linearity annotations on jaxprs, so we need to infer them on the fly. We can mostly do this by looking at data dependencies: any variable that depends on a linear variable is linear, and the rest are nonlinear. This is what "undefined primal" is all about in transposition. "Undefined primal" means "I've inferred that this is a linear value". But there are some genuinely ambiguous cases.

Anyway, bottom line: no need to worry about unzipping mixed linear/nonlinear functions or HOPs.
Thanks @dougalm! I think that what you've suggested here is close to what I have implemented in this PR: any eqn that has any linear inputs is treated as linear, while the others are treated as non-linear and evaluated first. But my implementation introduces some failures that seem to indicate problems with the logic as I have implemented it. In particular, there seem to be issues with the control flow tests, so Matt's comment about HOPs seems to be pointing in the right direction wrt what I'm missing!
Yes exactly, your implementation looks like the best approximation we can currently muster until we add linearity annotations. I'm curious to understand where it fell down. Can you paste the jaxpr that failed to transpose under your new logic?
Right now, `ad.backward_pass` assumes that all primitives that it encounters are linear, and our current AD approach of JVP -> partial eval -> transposition means that this is always satisfied. When using direct linearization, it is conceivable that we would want to support non-linear eqns in the backward pass, as long as their inputs are all literals or primals (i.e. functions of the residuals in the linear rule).

This PR adds an initial forward sweep in `ad.backward_pass` (based on `core.eval_jaxpr`) to evaluate any of these deferred operations before transposing the remaining eqns. I think this is a reasonable approach, but I'm not sure I've thought through all the potential failure modes. WDYT?
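A simplified sketch of that forward sweep (hypothetical pseudocode of the approach, not the actual diff; `eval_eqn`, `transpose_eqn`, and `is_linear` stand in for the real machinery):

```python
# Hypothetical sketch of the forward sweep; eval_eqn/transpose_eqn are
# stand-ins for the real evaluation and transposition machinery, and
# is_linear(v) marks undefined primals (inferred linear values).
def backward_pass_with_forward_sweep(jaxpr, env, is_linear,
                                     eval_eqn, transpose_eqn):
  deferred = []
  for eqn in jaxpr.eqns:
    if any(is_linear(v) for v in eqn.invars):
      deferred.append(eqn)   # linear eqn: transpose later
    else:
      eval_eqn(eqn, env)     # non-linear eqn of constants: evaluate now
  for eqn in reversed(deferred):
    transpose_eqn(eqn, env)  # usual reverse-order transposition
```

As Matt's comment notes, this only partitions eqns at the top level, so a single HOP eqn that mixes linear and nonlinear parts internally would not be split by this sweep.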