Update ad.backward_pass to support non-linear functions of constants #26811

Open. Wants to merge 1 commit into base: main.
Conversation

@dfm (Collaborator) commented Feb 27, 2025

Right now, ad.backward_pass assumes that all primitives that it encounters are linear, and our current AD approach of JVP -> partial eval -> transposition means that this is always satisfied. When using direct linearization, it is conceivable that we would want to support non-linear eqns in the backwards pass, as long as their inputs are all literals or primals (i.e. functions of the residuals in the linear rule).

This PR adds an initial forward sweep in ad.backward_pass (based on core.eval_jaxpr) to evaluate any of these deferred operations before transposing the remaining eqns. I think this is a reasonable approach, but I'm not sure I've thought through all the potential failure modes. WDYT?
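As a toy illustration of the kind of function in question (this example is mine, not from the PR; in this toy the nonlinear op on the constant is folded at trace time because `c` is a concrete Python float, whereas the PR concerns the case where such eqns survive into the jaxpr, e.g. under direct linearization):

```python
import jax
import jax.numpy as jnp

c = 2.0  # a constant; conceptually a residual/literal in the jaxpr

def f(x):
    # jnp.sin(c) is a nonlinear computation on a constant;
    # x only enters linearly, so f is a linear function of x.
    return jnp.sin(c) * x

# Transposing f requires evaluating sin(c) at some point.
transpose_f = jax.linear_transpose(f, 3.0)
(ct,) = transpose_f(1.0)  # ct == sin(c)
```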

@dfm dfm requested review from mattjj and dougalm February 27, 2025 16:18
@dfm dfm self-assigned this Feb 27, 2025
@dfm (Collaborator, Author) commented Feb 27, 2025

Looks like this has some failures that seem related to the "type errors" mentioned in the FIXME comments. I knew it couldn't be as straightforward as this :D

I'm not sure I totally understand the source of these issues, but I'll think about it some more with this as a starting point!

@mattjj (Collaborator) commented Feb 27, 2025

I suspect this approach doesn't work, because I think we tried it in #1749 (note "To make remat_call_p transposable, we need to upgrade our ad.backward_pass code to do some nonlinear evaluation (to rematerialize the residuals needed)" in the PR message, and the double loop over eqns here and here). See also the predecessor #1719 which I think was lumped into #1749.

The trouble was that a HOP like jit could include a mix of linear and nonlinear stuff. That is, we can't assume that a HOP is only purely linear or nonlinear, and we may have to run just the nonlinear parts of it first and run the linear parts later. (I haven't read this code carefully but I suspect we can construct examples that this PR won't be able to transpose, e.g. by making an inner jit that includes linear and nonlinear parts, then trying to transpose wrt the linear inputs using jax.linear_transpose. I haven't actually tried that yet though.)

To run nested nonlinear parts, saving any intermediates we need in HOP bodies (i.e. recursing in to handle mixed linear/nonlinear HOPs, rather than just doing it at the top level and assuming each HOP application is purely linear or purely nonlinear), is essentially partial evaluation. Indeed that's the approach we eventually took for remat transpose: we evaluate the nonlinear parts using partial eval and then run the linear-only backward_pass on the result. (By the way, we use the same logic for transposing run_state and for_loop, since we realized it's exactly what we need for handling mutable arrays.)

We could build that same logic into ad.backward_pass, ie in it use partial evaluation to evaluate the (nested) nonlinear stuff, then apply the current logic to transpose the purely linear part. That's worth considering, but it won't let us delete partial eval!
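For reference, the existing JVP → partial eval → transpose pipeline already handles a HOP whose body mixes linear and nonlinear parts; a quick sanity check (the example function is mine, not from the thread):

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(c, x):
    # nonlinear in c, linear in x: a mixed HOP body
    return jnp.sin(c) * x

# jax.vjp linearizes via JVP + partial eval, then transposes the linear part.
primal, vjp_fn = jax.vjp(g, 1.5, 2.0)
dc, dx = vjp_fn(1.0)
# dc == cos(1.5) * 2.0, dx == sin(1.5)
```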

@dfm (Collaborator, Author) commented Feb 27, 2025

Thanks for this clear (as always!!) explanation @mattjj, and for the references. Indeed, this all makes sense, and I see the complications.

For the specific case of direct linearization, I guess we could update the lin rule signature to explicitly return the unzipped functions: the usual linear part, plus a new one to compute the residuals from the primals. But, that doesn't seem very useful in most cases.
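A hypothetical sketch of what such an "unzipped" lin rule might look like, using sin as the example primitive (the name and signature here are purely illustrative, not a real JAX API):

```python
import jax.numpy as jnp

# Hypothetical "unzipped" rule for sin: the tangent map is tx -> cos(x) * tx.
def sin_lin_rule():
    def residuals_fn(x):      # nonlinear part: residuals from the primals
        return jnp.cos(x)
    def linear_fn(res, tx):   # linear part: tangent map given the residuals
        return res * tx
    return residuals_fn, linear_fn

residuals_fn, linear_fn = sin_lin_rule()
res = residuals_fn(0.0)
ty = linear_fn(res, 1.0)  # == cos(0.0) * 1.0 == 1.0
```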

That's worth considering, but it won't let us delete partial eval!

Yeah - I guess there's also the question of how important it is that we support this behavior in linearization rules. I can see in principle why we might want to do this, but did you or @dougalm have any specific applications in mind? It seems like it would be simple enough to support this in the case of the jax.custom_lin API...

@dougalm (Collaborator) commented Feb 27, 2025

Sorry, I wasn't clear about my original proposal for nonlinear operations within linear functions. In particular, linear functions may have nonlinear ops but they can't return nonlinear values. So there should be no need to worry about unzipping functions!

Here's the type system I had in mind. Each variable and literal in a jaxpr is tagged with a linearity, which is just a boolean attribute that can be either "linear" or "nonlinear". Each jaxpr also has a linearity (a jaxpr can be linear or nonlinear; there is no mixed option), and each eqn in a jaxpr has a linearity too: again, linear or nonlinear, never a mixture.

The rule for linear jaxprs is that they only return linear values. Their arguments can be any mixture of linear and nonlinear. Similarly, linear eqns can only return linear values. Nonlinear eqns can only have nonlinear arguments. Linear eqns can have a mixture of linear and nonlinear arguments, depending on the op. We could imagine a "check_linearity" function that checks whether the argument linearity is acceptable for a given op. For example, here are the rules for add and mul.

IsLinear = bool

def check_linearity_add(x_lin: IsLinear, y_lin: IsLinear):
  # add is linear in both arguments, so both must be linear
  assert x_lin and y_lin

def check_linearity_mul(x_lin: IsLinear, y_lin: IsLinear):
  # mul is bilinear, so exactly one argument may be linear
  assert (x_lin and not y_lin) or (not x_lin and y_lin)

The way to transpose a linear jaxpr is to pull all the nonlinear eqns to the beginning and then reverse the order of the linear eqns and transpose their operations. Easy.

The challenge is that we don't currently have these linearity annotations on jaxprs so we need to infer them on the fly. We can mostly do this by looking at data dependencies. Any variable that depends on a linear variable is linear. The rest are nonlinear. This is what "undefined primal" is all about in transposition. "undefined primal" means "I've inferred that this is a linear value". But there are some genuinely ambiguous cases. np.zeros can be considered linear even if it doesn't depend on a linear value. But it can be considered nonlinear too. There's no way to infer which way it's intended. Probably we're fine if we just assume operations are nonlinear by default if they don't depend on linear values. But it would be much simpler if we just had the annotations.
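The inference-by-data-dependency described above can be sketched in a few lines over a toy eqn representation (illustrative pseudocode over (outvars, invars, prim) tuples, not JAX internals):

```python
def infer_linearity(eqns, linear_invars):
    """Mark every variable that (transitively) depends on a linear input.

    eqns: list of (outvars, invars, prim_name) tuples in topological order.
    Everything not marked linear is treated as nonlinear by default,
    matching the convention suggested above.
    """
    linear = set(linear_invars)
    for outvars, invars, _prim in eqns:
        if any(v in linear for v in invars):
            linear.update(outvars)
    return linear

# Toy jaxpr: t = sin(c); y = mul(t, x), with x the linear input.
eqns = [(["t"], ["c"], "sin"),
        (["y"], ["t", "x"], "mul")]
assert infer_linearity(eqns, ["x"]) == {"x", "y"}
```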

Anyway, bottom line: no need to worry about unzipping mixed linear/nonlinear functions or HOPs.

@dfm (Collaborator, Author) commented Feb 27, 2025

Thanks @dougalm! I think that what you've suggested here is close to what I have implemented in this PR: any eqn that has any linear inputs is treated as linear, while the others are treated as non-linear and evaluated first. But my implementation introduces some failures that suggest problems with the logic as I've implemented it. In particular, there seem to be issues with the control flow tests, so Matt's comment about HOPs seems to be pointing in the right direction wrt what I'm missing!

@dougalm (Collaborator) commented Feb 27, 2025

Yes exactly, your implementation looks like the best approximation we can currently muster until we add linearity annotations. I'm curious to understand where it fell down. Can you paste the jaxpr that failed to transpose under your new logic?
