diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py
index 107bec60..37157f8a 100644
--- a/diffusion/datasets/image_caption.py
+++ b/diffusion/datasets/image_caption.py
@@ -32,7 +32,9 @@ class StreamingImageCaptionDataset(StreamingDataset):
             ``StreamingImageCaptionDataset`` uses either ``streams`` or ``remote``/``local``. Default:``None``.
         remote (str, optional): Remote directory (S3 or local filesystem) where dataset is stored. Default: ``None``.
         local (str, optional): Local filesystem directory where dataset is cached during operation. Default: ``None``.
-        tokenizer_name_or_path (str): The name or path of the tokenizer to use. Default: ``'stabilityai/stable-diffusion-2-base'``.
+        tokenizer_name_or_path (str): The name or path of the tokenizer to use. One of
+            ``'stabilityai/stable-diffusion-2-base'``, ``'stabilityai/stable-diffusion-xl-base-1.0'``, or ``'sdxl-e5'``.
+            Default: ``'stabilityai/stable-diffusion-2-base'``.
         caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
         microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL. Default: ``0.0``.
         caption_selection (str): If there are multiple captions, specifies how to select a single caption.
@@ -42,7 +44,6 @@ class StreamingImageCaptionDataset(StreamingDataset):
         transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
         image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
         caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
-        sdxl (bool): Whether or not we're training SDXL. Default: `False`.
         zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``False``.
 
         **streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader
@@ -61,7 +62,6 @@ def __init__(
         transform: Optional[Callable] = None,
         image_key: str = 'image',
         caption_key: str = 'caption',
-        sdxl: bool = False,
         zero_dropped_captions: bool = False,
         **streaming_kwargs,
     ) -> None:
@@ -82,7 +82,6 @@ def __init__(
 
         self.crop = crop
         self.transform = transform
-        self.sdxl = sdxl
         self.caption_drop_prob = caption_drop_prob
         self.microcond_drop_prob = microcond_drop_prob
         self.caption_selection = caption_selection
@@ -90,6 +89,7 @@ def __init__(
         self.caption_key = caption_key
         self.zero_dropped_captions = zero_dropped_captions
 
+        self.sdxl = _is_sdxl(tokenizer_name_or_path)
         if self.sdxl:
             self.tokenizer = SDXLTokenizer(tokenizer_name_or_path)
         else:
@@ -197,7 +197,9 @@ def build_streaming_image_caption_dataloader(
         remote (str, Sequence[str]): One or more remote directories (S3 or local filesystem) where dataset is stored.
         local (str, Sequence[str]): One or more local filesystem directories where dataset is cached during operation.
         batch_size (int): The batch size to use for both the ``StreamingDataset`` and ``DataLoader``.
-        tokenizer_name_or_path (str): The name or path of the tokenizer to use. Default: ``'stabilityai/stable-diffusion-2-base'``.
+        tokenizer_name_or_path (str): The name or path of the tokenizer to use. One of
+            ``'stabilityai/stable-diffusion-2-base'``, ``'stabilityai/stable-diffusion-xl-base-1.0'``, or ``'sdxl-e5'``.
+            Default: ``'stabilityai/stable-diffusion-2-base'``.
         caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
         microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL. Default: ``0.0``.
         resize_size (int): The size to resize the image to. Default: ``256``.
@@ -241,9 +243,11 @@ def build_streaming_image_caption_dataloader(
             streams.append(Stream(remote=r, local=l))
 
     # Infer SDXL from tokenizer path
-    sdxl = (tokenizer_name_or_path == 'stabilityai/stable-diffusion-xl-base-1.0')
+    sdxl = _is_sdxl(tokenizer_name_or_path)
     if sdxl:
         log.info('Detected SDXL tokenizer, using SDXL crop transform and tokenizers.')
+        if tokenizer_name_or_path == 'sdxl-e5':
+            log.info('Using E5 text encoder')
 
     # Set the crop to apply
     if crop_type == 'square':
@@ -271,7 +275,6 @@ def build_streaming_image_caption_dataloader(
         image_key=image_key,
         caption_key=caption_key,
         batch_size=batch_size,
-        sdxl=sdxl,
         zero_dropped_captions=zero_dropped_captions,
         **streaming_kwargs,
     )
@@ -284,3 +287,8 @@ def build_streaming_image_caption_dataloader(
     )
 
     return dataloader
+
+
+def _is_sdxl(tokenizer_name_or_path):
+    """Infer SDXL from tokenizer path."""
+    return (tokenizer_name_or_path == 'stabilityai/stable-diffusion-xl-base-1.0' or tokenizer_name_or_path == 'sdxl-e5')
diff --git a/diffusion/models/models.py b/diffusion/models/models.py
index 0bc8d557..4321b8ca 100644
--- a/diffusion/models/models.py
+++ b/diffusion/models/models.py
@@ -13,7 +13,8 @@
 from torchmetrics import MeanSquaredError
 from torchmetrics.image.fid import FrechetInceptionDistance
 from torchmetrics.multimodal.clip_score import CLIPScore
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig
+from transformers import (AutoModel, AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer,
+                          PretrainedConfig)
 
 from diffusion.models.autoencoder import (AutoEncoder, AutoEncoderLoss, ComposerAutoEncoder,
                                           ComposerDiffusersAutoEncoder, load_autoencoder)
@@ -222,7 +223,8 @@ def stable_diffusion_xl(
 
     Args:
         model_name (str): Name of the model to load. Determines the text encoders, tokenizers,
-            and noise scheduler. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
+            and noise scheduler. One of 'sdxl-e5' or 'stabilityai/stable-diffusion-xl-base-1.0'.
+            Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
         unet_model_name (str): Name of the UNet model to load. Defaults to
             'stabilityai/stable-diffusion-xl-base-1.0'.
         vae_model_name (str): Name of the VAE model to load. Defaults to
@@ -311,6 +313,8 @@ def stable_diffusion_xl(
     # Adapt the unet config to account for differing number of latent channels if necessary
     unet_config['in_channels'] = vae.config['latent_channels']
     unet_config['out_channels'] = vae.config['latent_channels']
+    if model_name == 'sdxl-e5':  # e5 + clip embedding dims + micro conditioning
+        unet_config['cross_attention_dim'] = 2304
     # Init the unet from the config
     unet = UNet2DConditionModel(**unet_config)
 
@@ -325,8 +329,18 @@ def stable_diffusion_xl(
     # Last conv block out projection
     unet.conv_out = zero_module(unet.conv_out)
 
-    # Make the noise schedulers
-    noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder='scheduler')
+    torch_dtype = torch.float16 if encode_latents_in_fp16 else None
+    try:
+        vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=torch_dtype)
+    except:  # for handling SDXL vae fp16 fixed checkpoint
+        vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch_dtype)
+
+    tokenizer = SDXLTokenizer(model_name)
+    text_encoder = SDXLTextEncoder(model_name, encode_latents_in_fp16)
+
+    scheduler_model_name = 'stabilityai/stable-diffusion-xl-base-1.0' if model_name == 'sdxl-e5' else model_name
+    noise_scheduler = DDPMScheduler.from_pretrained(scheduler_model_name, subfolder='scheduler')
+
     inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000,
                                                        beta_start=0.00085,
                                                        beta_end=0.012,
@@ -635,14 +649,23 @@ class SDXLTextEncoder(torch.nn.Module):
     Creates two text encoders (a CLIPTextModel and CLIPTextModelWithProjection) that behave like one.
 
     Args:
-        model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
+        model_name (str): Name of the model's text encoders to load. One of 'sdxl-e5' or
+            'stabilityai/stable-diffusion-xl-base-1.0'.
+            Default: 'stabilityai/stable-diffusion-xl-base-1.0'.
         encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
     """
 
     def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0', encode_latents_in_fp16=True):
         super().__init__()
+        _validate_model_name(model_name)
         torch_dtype = torch.float16 if encode_latents_in_fp16 else None
-        self.text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', torch_dtype=torch_dtype)
+        if model_name == 'sdxl-e5':
+            self.text_encoder = AutoModel.from_pretrained('intfloat/e5-large-v2', torch_dtype=torch_dtype)
+            model_name = 'stabilityai/stable-diffusion-xl-base-1.0'  # set model name to sdxl to pull other encoder
+        else:
+            self.text_encoder = CLIPTextModel.from_pretrained(model_name,
+                                                              subfolder='text_encoder',
+                                                              torch_dtype=torch_dtype)
         self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_name,
                                                                           subfolder='text_encoder_2',
                                                                           torch_dtype=torch_dtype)
@@ -669,18 +692,25 @@ class SDXLTokenizer:
     Tokenizes prompt with two tokenizers and returns the joined output.
 
     Args:
-        model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
+        model_name (str): Name of the model's tokenizers to load. One of 'sdxl-e5' or
+            'stabilityai/stable-diffusion-xl-base-1.0'.
+            Default: 'stabilityai/stable-diffusion-xl-base-1.0'.
""" def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0'): - self.tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer') + _validate_model_name(model_name) + if model_name == 'sdxl-e5': + self.tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-large-v2') + model_name = 'stabilityai/stable-diffusion-xl-base-1.0' + else: + self.tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer') self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer_2') def __call__(self, prompt, padding, truncation, return_tensors, max_length=None): tokenized_output = self.tokenizer( prompt, padding=padding, - max_length=self.tokenizer.model_max_length if max_length is None else max_length, + max_length=self.tokenizer_2.model_max_length if max_length is None else max_length, truncation=truncation, return_tensors=return_tensors) tokenized_output_2 = self.tokenizer_2( @@ -691,6 +721,12 @@ def __call__(self, prompt, padding, truncation, return_tensors, max_length=None) return_tensors=return_tensors) # Add second tokenizer output to first tokenizer - for key in tokenized_output.keys(): + for key in tokenized_output_2.keys(): tokenized_output[key] = [tokenized_output[key], tokenized_output_2[key]] return tokenized_output + + +def _validate_model_name(model_name): + valid_model_names = {'sdxl-e5', 'stabilityai/stable-diffusion-xl-base-1.0'} + if model_name not in valid_model_names: + raise ValueError(f'model_name must be one of {valid_model_names}.') diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index de92e3e8..b2fe4fda 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -283,8 +283,6 @@ def eval_forward(self, batch, outputs=None): # Skip this if outputs have already been computed, e.g. 
         if outputs is not None:
             return outputs
-        # Get unet outputs
-        unet_out, targets, timesteps = self.forward(batch)
         # Sample images from the prompts in the batch
         prompts = batch[self.text_key]
         height, width = batch[self.image_key].shape[-2], batch[self.image_key].shape[-1]
@@ -306,6 +304,8 @@ def eval_forward(self, batch, outputs=None):
         # Set to resolution we are trying to generate
         batch['cond_target_size'] = torch.tensor([[width, height]]).repeat(bsz, 1).to(device)
 
+        unet_out, targets, timesteps = self.forward(batch)
+
         generated_images = {}
         for guidance_scale in self.val_guidance_scales:
             gen_images = self.generate(tokenized_prompts=prompts,
diff --git a/tests/test_model.py b/tests/test_model.py
index d2091ac4..1deae55a 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -7,7 +7,7 @@
 import pytest
 import torch
 
-from diffusion.models.models import stable_diffusion_2
+from diffusion.models.models import stable_diffusion_2, stable_diffusion_xl
 
 
 def test_model_forward():
@@ -44,3 +44,34 @@ def test_model_generate(guidance_scale, negative_prompt):
         progress_bar=False,
     )
     assert output.shape == (1, 3, 8, 8)
+
+
+@pytest.mark.parametrize('model_name', ['stabilityai/stable-diffusion-xl-base-1.0', 'sdxl-e5'])
+def test_model_forward_sdxl(model_name):
+    model = stable_diffusion_xl(model_name=model_name,
+                                pretrained=False,
+                                fsdp=False,
+                                encode_latents_in_fp16=False,
+                                clip_qkv=None)
+    batch_size = 1
+    H = 32
+    W = 32
+    image = torch.randn(batch_size, 3, H, W)
+    latent = torch.randn(batch_size, 4, H // 8, W // 8)
+    caption = torch.randint(low=0, high=128, size=(
+        batch_size,
+        77,
+    ), dtype=torch.long)
+    caption = torch.stack([caption, caption], dim=1)
+    micro_conditioning = torch.randint(1, H, (batch_size, 2))
+
+    batch = {
+        'image': image,
+        'captions': caption,
+        'cond_original_size': micro_conditioning,
+        'cond_crops_coords_top_left': micro_conditioning,
+        'cond_target_size': micro_conditioning
+    }
+    output, target, _ = model(batch)  # model.forward generates the unet output noise or v_pred target.
+    assert output.shape == latent.shape
+    assert target.shape == latent.shape
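
Usage sketch (not part of the diff): a minimal, hedged example of exercising the new 'sdxl-e5' path, mirroring test_model_forward_sdxl above. It assumes the diffusion package from this repo is importable and that the 'intfloat/e5-large-v2' and SDXL checkpoints can be fetched from the Hugging Face Hub; the import locations and keyword arguments are copied from the diff rather than verified against a released version.

    from diffusion.models.models import SDXLTokenizer, stable_diffusion_xl

    # Tokenizer 1 is the E5 tokenizer, tokenizer 2 the SDXL CLIP tokenizer; per the change
    # above, max_length falls back to tokenizer_2.model_max_length and the two outputs are
    # collected into a two-element list under each key.
    tokenizer = SDXLTokenizer('sdxl-e5')
    tokens = tokenizer(['a photo of a corgi'], padding='max_length', truncation=True, return_tensors='pt')

    # With model_name='sdxl-e5' the UNet is built with cross_attention_dim=2304 and the text
    # encoder pairs E5 ('intfloat/e5-large-v2') with the SDXL CLIP projection encoder.
    model = stable_diffusion_xl(model_name='sdxl-e5',
                                pretrained=False,
                                fsdp=False,
                                encode_latents_in_fp16=False,
                                clip_qkv=None)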