diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 1fd371034a09..65a284477436 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -590,19 +590,19 @@ def debug_info( *, static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = (), - result_paths_thunk: Callable[[], tuple[str, ...]] | None = None, + result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None = None, # TODO(necula): check if we really need this, e.g., to speed up tracing? sourceinfo: str | None = None, signature: inspect.Signature | None = None, ) -> core.DebugInfo: """Constructd core.DebugInfo for a function given example args and kwargs. + See docstring for linear_util.DebugInfo. + `args` and `kwargs` are example positional and keyword arguments, users with `inspect.Signature` to get the names of argments. The arguments that are considered static for tracing purposes should be included, and designated using `static_argnums` and `static_argnames`. - - See docstring for linear_util.DebugInfo. """ if sourceinfo is None: sourceinfo = fun_sourceinfo(fun) @@ -610,7 +610,7 @@ def debug_info( signature = fun_signature(fun) arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums, static_argnames) - return core.DebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) + return core.DebugInfo(traced_for, sourceinfo, arg_names, result_paths) def fun_signature(fun: Callable) -> inspect.Signature | None: diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 37ad40d22a3d..b19ecba69d1b 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -761,7 +761,10 @@ def fallback_linearize_rule(_prim: core.Primitive, if not jvp: msg = f"Differentiation rule for '{_prim}' not implemented" raise NotImplementedError(msg) - debug_jvp = debug_info("linearize_prim_jvp", jvp, primals, params) + # TODO(necula): this is needed when JAX_USE_DIRECT_LINEARIZE=1; figure out + # how to set the result_paths. + debug_jvp = debug_info("linearize_prim_jvp", jvp, primals, params, + result_paths=("",)) return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, False, False, debug_jvp, primals, params) @@ -855,7 +858,7 @@ def to_concrete_value(self): primitive_jvps : dict[core.Primitive, Callable] = {} primitive_transposes: dict[core.Primitive, Callable] = {} -primitive_linearizations : dict[core.Primitive, Callable] = {} +primitive_linearizations : dict[core.Primitive, Callable] = {} def deflinear(primitive, transpose_rule): primitive_jvps[primitive] = partial(linear_jvp, primitive)