Skip to content

Commit

Permalink
Fix assertion fails on DebugInfo with JAX_USE_DIRECT_LINEARIZE=1
Browse files Browse the repository at this point in the history
When using JAX_USE_DIRECT_LINEARIZE=1 there was an assertion that
result_paths is not None. This is because in this case we create
a DebugInfo that is not used with `lu.wrap_init` (which normally
takes care of the result_paths). To allow others to make
progress, I just fill in a (wrong) result_paths, and a TODO
to come back and fix this later.
  • Loading branch information
gnecula committed Feb 27, 2025
1 parent 99a12ef commit 766abba
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
8 changes: 4 additions & 4 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,27 +590,27 @@ def debug_info(
*,
static_argnums: Sequence[int] = (),
static_argnames: Sequence[str] = (),
result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None = None,
# TODO(necula): check if we really need this, e.g., to speed up tracing?
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> core.DebugInfo:
"""Constructd core.DebugInfo for a function given example args and kwargs.
See docstring for linear_util.DebugInfo.
`args` and `kwargs` are example positional and keyword arguments, users with
`inspect.Signature` to get the names of argments. The arguments that are
considered static for tracing purposes should be included, and designated
using `static_argnums` and `static_argnames`.
See docstring for linear_util.DebugInfo.
"""
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
signature = fun_signature(fun)
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames)
return core.DebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)
return core.DebugInfo(traced_for, sourceinfo, arg_names, result_paths)


def fun_signature(fun: Callable) -> inspect.Signature | None:
Expand Down
7 changes: 5 additions & 2 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,10 @@ def fallback_linearize_rule(_prim: core.Primitive,
if not jvp:
msg = f"Differentiation rule for '{_prim}' not implemented"
raise NotImplementedError(msg)
debug_jvp = debug_info("linearize_prim_jvp", jvp, primals, params)
# TODO(necula): this is needed when JAX_USE_DIRECT_LINEARIZE=1; figure out
# how to set the result_paths.
debug_jvp = debug_info("linearize_prim_jvp", jvp, primals, params,
result_paths=("",))
return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, False, False,
debug_jvp, primals, params)

Expand Down Expand Up @@ -855,7 +858,7 @@ def to_concrete_value(self):

primitive_jvps : dict[core.Primitive, Callable] = {}
primitive_transposes: dict[core.Primitive, Callable] = {}
primitive_linearizations : dict[core.Primitive, Callable] = {}
primitive_linearizations : dict[core.Primitive, Callable] = {}

def deflinear(primitive, transpose_rule):
primitive_jvps[primitive] = partial(linear_jvp, primitive)
Expand Down

0 comments on commit 766abba

Please sign in to comment.