Skip to content

Commit

Permalink
Fix test train_unconditional (huggingface#2481)
Browse files Browse the repository at this point in the history
* Fix tensorboard tracking with `accelerate` @ `main`

* Fix `train_unconditional.py` with accelerate from main.
  • Loading branch information
pcuenca authored Feb 24, 2023
1 parent 54bc882 commit 5de4347
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions examples/unconditional_image_generation/train_unconditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available
from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
Expand Down Expand Up @@ -628,10 +628,13 @@ def transform_images(examples):
images_processed = (images * 255).round().astype("uint8")

if args.logger == "tensorboard":
accelerator.get_tracker("tensorboard").add_images(
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
)
if is_accelerate_version(">=", "0.17.0.dev0"):
tracker = accelerator.get_tracker("tensorboard", unwrap=True)
else:
tracker = accelerator.get_tracker()
tracker.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch)
elif args.logger == "wandb":
# Upcoming `log_images` helper coming in /~https://github.com/huggingface/accelerate/pull/962/files
accelerator.get_tracker("wandb").log(
{"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
step=global_step,
Expand Down

0 comments on commit 5de4347

Please sign in to comment.