From 56285aec6b7e6d41efd99544467acfd7033b6576 Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Fri, 28 Feb 2025 06:17:14 +0000 Subject: [PATCH] Fixed printing order of results in jax.debug.print documentation. --- docs/debugging/print_breakpoint.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index 73ac0262851d..85580120c0a9 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -91,8 +91,8 @@ def f(x): jax.debug.print("x: {}", x) return x jax.pmap(f)(xs) -# Prints: x: 1.0 -# x: 0.0 +# Prints: x: 0.0 +# x: 1.0 # OR # Prints: x: 1.0 # x: 0.0