Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ad.backward_pass to support non-linear functions of constants #26811

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,31 @@ def write_primal(v, val):
# forces primal_in to contain UndefinedPrimals for tangent values!
map(write_primal, jaxpr.invars, primals_in)

# Start with a forward pass to evaluate any JaxprEqns that only operate on
# primals. This is required to support primitives with linearization rules
# that include computations on the residuals.
lin_eqns = []
for eqn in jaxpr.eqns:
if any(type(x) is not Literal and x not in primal_env for x in eqn.invars):
lin_eqns.append(eqn)
continue
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
traceback = eqn.source_info.traceback
with source_info_util.user_context(
traceback, name_stack=name_stack), eqn.ctx.manager:
ans = eqn.primitive.bind(*subfuns, *map(read_primal, eqn.invars), **bind_params)
if eqn.primitive.multiple_results:
map(write_primal, eqn.outvars, ans)
else:
write_primal(eqn.outvars[0], ans)

ct_env: dict[Any, Any] = {}
ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
else contextlib.nullcontext())
with ctx:
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
for eqn in jaxpr.eqns[::-1]:
for eqn in lin_eqns[::-1]:
if eqn.primitive.ref_primitive:
if eqn.primitive is core.mutable_array_p:
val_var, = eqn.invars
Expand Down
14 changes: 14 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from jax._src import debugging
from jax._src import pjit as pjit_lib
from jax._src.ad_checkpoint import saved_residuals
from jax._src.interpreters import ad as ad_internal
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.compilation_cache import is_persistent_cache_enabled
Expand Down Expand Up @@ -4721,6 +4722,19 @@ def sin_of_sin(x):

check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0))

def test_deferred_primal_with_direct_linearize(self):
def my_sin_lin(nzs, x):
nz, = nzs
return (my_sin_p.bind(x), nz, x, lambda x, t: lax.mul(t, lax.cos(x)))

my_sin_p = core.Primitive("my_sin_p")
my_sin_p.def_impl(lax.sin)
my_sin_p.def_abstract_eval(lambda x: x)
ad_internal.primitive_linearizations[my_sin_p] = my_sin_lin

with config.use_direct_linearize(True):
jax.grad(my_sin_p.bind)(1.0) # doesn't crash


class RematTest(jtu.JaxTestCase):

Expand Down
Loading