Fixed printing order of results in jax.debug.print
documentation
#26839
+2
−2
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Here: https://docs.jax.dev/en/latest/debugging/print_breakpoint.html#printing-under-jax-pmap
It mentions that when
jax.pmap
-ed,jax.debug.prints
outputs might be reordered!But, in the comments, they show printing in the same order 2 times. I initially got confused reading that.
I think the comments should show the printing order in both ways like it is shown here: https://docs.jax.dev/en/latest/debugging/print_breakpoint.html#ordering-of-printed-results