From 6caf4475fd63698638fd1d6303b6dcf274c711ff Mon Sep 17 00:00:00 2001 From: = Date: Fri, 28 Feb 2025 06:17:14 +0000 Subject: [PATCH] Fixed printing order of results in jax.debug.print documentation. --- docs/debugging/print_breakpoint.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index 73ac0262851d..85580120c0a9 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -91,8 +91,8 @@ def f(x): jax.debug.print("x: {}", x) return x jax.pmap(f)(xs) -# Prints: x: 1.0 -# x: 0.0 +# Prints: x: 0.0 +# x: 1.0 # OR # Prints: x: 1.0 # x: 0.0