Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to use E5 text encoder for SDXL #108

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
6f36711
add hacky e5 option to sdxl
A-Jacobson Nov 22, 2023
38f3231
fix masking broadcasting
A-Jacobson Nov 22, 2023
0ddbc7f
use value, not batch for masking
A-Jacobson Nov 22, 2023
4856b49
use e5 instead of overai clip or projection layer
A-Jacobson Nov 22, 2023
39307a6
reverse tokenizer dicts for merging
A-Jacobson Nov 22, 2023
f65412d
change attention conditioning dim in unet for e5
A-Jacobson Nov 22, 2023
94c97af
fix conditioning dims for e5 in unet
A-Jacobson Nov 22, 2023
bc506ef
add tests for sdxl forward and sdxl_e5 forward
A-Jacobson Nov 22, 2023
30ad22d
fix sdxl_forward naming
A-Jacobson Nov 22, 2023
f3f9a17
print conditioning dim
A-Jacobson Nov 22, 2023
23d7c21
new config key
A-Jacobson Nov 22, 2023
972971e
swap back proj layer dim
A-Jacobson Nov 22, 2023
eb175a7
only change xatten dim
A-Jacobson Nov 22, 2023
de8b7c1
try to fix masking
A-Jacobson Nov 27, 2023
227114c
Merge branch 'main' into e5
A-Jacobson Jan 3, 2024
42a9923
add docstrings
A-Jacobson Jan 3, 2024
1ada4b8
pre-commit hooks
A-Jacobson Jan 3, 2024
b5badcd
update sdxl dummy batch
A-Jacobson Jan 3, 2024
80dc844
pre-commit
A-Jacobson Jan 3, 2024
4b3237d
remove boolean flag, replace with sdxl-e5 model_name and tokenizer_path
A-Jacobson Jan 30, 2024
fdb99a4
update sdxl tests
A-Jacobson Jan 30, 2024
af9a06d
pull everything but e5 using standing sdxl model_name
A-Jacobson Jan 30, 2024
d6a7191
fix sdxl boolean in image caption dataset
A-Jacobson Jan 30, 2024
d80dc75
styling
A-Jacobson Jan 30, 2024
31e0570
Merge branch 'main' into e5
A-Jacobson Feb 21, 2024
37693a4
Update diffusion/models/models.py
A-Jacobson Feb 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions diffusion/datasets/image_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ class StreamingImageCaptionDataset(StreamingDataset):
``StreamingImageCaptionDataset`` uses either ``streams`` or ``remote``/``local``. Default:``None``.
remote (str, optional): Remote directory (S3 or local filesystem) where dataset is stored. Default: ``None``.
local (str, optional): Local filesystem directory where dataset is cached during operation. Default: ``None``.
tokenizer_name_or_path (str): The name or path of the tokenizer to use. Default: ``'stabilityai/stable-diffusion-2-base'``.
tokenizer_name_or_path (str): The name or path of the tokenizer to use.
``'stabilityai/stable-diffusion-2-base'``, ``'stabilityai/stable-diffusion-xl-base-1.0'`` or ``'sdxl-e5'``.
Default: ``'stabilityai/stable-diffusion-2-base'``.
caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL. Default: ``0.0``.
caption_selection (str): If there are multiple captions, specifies how to select a single caption.
Expand All @@ -42,7 +44,6 @@ class StreamingImageCaptionDataset(StreamingDataset):
transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
sdxl (bool): Whether or not we're training SDXL. Default: `False`.
zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``False``.

**streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader
Expand All @@ -61,7 +62,6 @@ def __init__(
transform: Optional[Callable] = None,
image_key: str = 'image',
caption_key: str = 'caption',
sdxl: bool = False,
zero_dropped_captions: bool = False,
**streaming_kwargs,
) -> None:
Expand All @@ -82,14 +82,14 @@ def __init__(

self.crop = crop
self.transform = transform
self.sdxl = sdxl
self.caption_drop_prob = caption_drop_prob
self.microcond_drop_prob = microcond_drop_prob
self.caption_selection = caption_selection
self.image_key = image_key
self.caption_key = caption_key
self.zero_dropped_captions = zero_dropped_captions

self.sdxl = _is_sdxl(tokenizer_name_or_path)
if self.sdxl:
self.tokenizer = SDXLTokenizer(tokenizer_name_or_path)
else:
Expand Down Expand Up @@ -197,7 +197,9 @@ def build_streaming_image_caption_dataloader(
remote (str, Sequence[str]): One or more remote directories (S3 or local filesystem) where dataset is stored.
local (str, Sequence[str]): One or more local filesystem directories where dataset is cached during operation.
batch_size (int): The batch size to use for both the ``StreamingDataset`` and ``DataLoader``.
tokenizer_name_or_path (str): The name or path of the tokenizer to use. Default: ``'stabilityai/stable-diffusion-2-base'``.
tokenizer_name_or_path (str): The name or path of the tokenizer to use.
``'stabilityai/stable-diffusion-2-base'``, ``'stabilityai/stable-diffusion-xl-base-1.0'`` or ``'sdxl-e5'``.
Default: ``'stabilityai/stable-diffusion-2-base'``.
caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL. Default: ``0.0``.
resize_size (int): The size to resize the image to. Default: ``256``.
Expand Down Expand Up @@ -241,9 +243,11 @@ def build_streaming_image_caption_dataloader(
streams.append(Stream(remote=r, local=l))

# Infer SDXL from tokenizer path
sdxl = (tokenizer_name_or_path == 'stabilityai/stable-diffusion-xl-base-1.0')
sdxl = _is_sdxl(tokenizer_name_or_path)
if sdxl:
log.info('Detected SDXL tokenizer, using SDXL crop transform and tokenizers.')
if tokenizer_name_or_path == 'sdxl-e5':
log.info('Using E5 text encoder')

# Set the crop to apply
if crop_type == 'square':
Expand Down Expand Up @@ -271,7 +275,6 @@ def build_streaming_image_caption_dataloader(
image_key=image_key,
caption_key=caption_key,
batch_size=batch_size,
sdxl=sdxl,
zero_dropped_captions=zero_dropped_captions,
**streaming_kwargs,
)
Expand All @@ -284,3 +287,8 @@ def build_streaming_image_caption_dataloader(
)

return dataloader


def _is_sdxl(tokenizer_name_or_path):
"""Infer SDXL from tokenizer path."""
return (tokenizer_name_or_path == 'stabilityai/stable-diffusion-xl-base-1.0' or tokenizer_name_or_path == 'sdxl-e5')
56 changes: 46 additions & 10 deletions diffusion/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from torchmetrics import MeanSquaredError
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.multimodal.clip_score import CLIPScore
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig
from transformers import (AutoModel, AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer,
PretrainedConfig)

from diffusion.models.autoencoder import (AutoEncoder, AutoEncoderLoss, ComposerAutoEncoder,
ComposerDiffusersAutoEncoder, load_autoencoder)
Expand Down Expand Up @@ -222,7 +223,8 @@ def stable_diffusion_xl(

Args:
model_name (str): Name of the model to load. Determines the text encoders, tokenizers,
and noise scheduler. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
and noise scheduler. 'sdxl-e5' or 'stabilityai/stable-diffusion-xl-base-1.0'.
Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
unet_model_name (str): Name of the UNet model to load. Defaults to
'stabilityai/stable-diffusion-xl-base-1.0'.
vae_model_name (str): Name of the VAE model to load. Defaults to
Expand Down Expand Up @@ -311,6 +313,8 @@ def stable_diffusion_xl(
# Adapt the unet config to account for differing number of latent channels if necessary
unet_config['in_channels'] = vae.config['latent_channels']
unet_config['out_channels'] = vae.config['latent_channels']
if model_name == 'sdxl-e5': # e5 + clip embedding dims + micro conditioning
unet_config['cross_attention_dim'] = 2304
# Init the unet from the config
unet = UNet2DConditionModel(**unet_config)

Expand All @@ -325,8 +329,18 @@ def stable_diffusion_xl(
# Last conv block out projection
unet.conv_out = zero_module(unet.conv_out)

# Make the noise schedulers
noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder='scheduler')
torch_dtype = torch.float16 if encode_latents_in_fp16 else None
try:
vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=torch_dtype)
except: # for handling SDXL vae fp16 fixed checkpoint
vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch_dtype)

tokenizer = SDXLTokenizer(model_name)
text_encoder = SDXLTextEncoder(model_name, encode_latents_in_fp16)

scheduler_model_name = 'stabilityai/stable-diffusion-xl-base-1.0' if model_name == 'sdxl-e5' else model_name
noise_scheduler = DDPMScheduler.from_pretrained(scheduler_model_name, subfolder='scheduler')

inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
Expand Down Expand Up @@ -635,14 +649,23 @@ class SDXLTextEncoder(torch.nn.Module):
Creates two text encoders (a CLIPTextModel and CLIPTextModelWithProjection) that behave like one.

Args:
model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
model_name (str): Name of the model's text encoders to load.
'sdxl-e5' or 'stabilityai/stable-diffusion-xl-base-1.0'.
Default: 'stabilityai/stable-diffusion-xl-base-1.0'.
encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
"""

def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0', encode_latents_in_fp16=True):
super().__init__()
_validate_model_name(model_name)
torch_dtype = torch.float16 if encode_latents_in_fp16 else None
self.text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', torch_dtype=torch_dtype)
if model_name == 'sdxl-e5':
self.text_encoder = AutoModel.from_pretrained('intfloat/e5-large-v2', torch_dtype=torch_dtype)
model_name = 'stabilityai/stable-diffusion-xl-base-1.0' # set model name to sdxl to pull other encoder
else:
self.text_encoder = CLIPTextModel.from_pretrained(model_name,
subfolder='text_encoder',
torch_dtype=torch_dtype)
self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_name,
subfolder='text_encoder_2',
torch_dtype=torch_dtype)
Expand All @@ -669,18 +692,25 @@ class SDXLTokenizer:
Tokenizes prompt with two tokenizers and returns the joined output.

Args:
model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
model_name (str): Name of the model's tokenizers to load.
'sdxl-e5' or 'stabilityai/stable-diffusion-xl-base-1.0'.
Default: 'stabilityai/stable-diffusion-xl-base-1.0'.
"""

def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0'):
self.tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer')
_validate_model_name(model_name)
if model_name == 'sdxl-e5':
self.tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-large-v2')
model_name = 'stabilityai/stable-diffusion-xl-base-1.0'
else:
self.tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer')
self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer_2')

def __call__(self, prompt, padding, truncation, return_tensors, max_length=None):
tokenized_output = self.tokenizer(
prompt,
padding=padding,
max_length=self.tokenizer.model_max_length if max_length is None else max_length,
max_length=self.tokenizer_2.model_max_length if max_length is None else max_length,
truncation=truncation,
return_tensors=return_tensors)
tokenized_output_2 = self.tokenizer_2(
Expand All @@ -691,6 +721,12 @@ def __call__(self, prompt, padding, truncation, return_tensors, max_length=None)
return_tensors=return_tensors)

# Add second tokenizer output to first tokenizer
for key in tokenized_output.keys():
for key in tokenized_output_2.keys():
tokenized_output[key] = [tokenized_output[key], tokenized_output_2[key]]
return tokenized_output


def _validate_model_name(model_name):
valid_model_names = {'sdxl-e5', 'stabilityai/stable-diffusion-xl-base-1.0'}
if model_name not in valid_model_names:
raise ValueError(f'model_name must be one of {valid_model_names}.')
4 changes: 2 additions & 2 deletions diffusion/models/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,6 @@ def eval_forward(self, batch, outputs=None):
# Skip this if outputs have already been computed, e.g. during training
if outputs is not None:
return outputs
# Get unet outputs
unet_out, targets, timesteps = self.forward(batch)
# Sample images from the prompts in the batch
prompts = batch[self.text_key]
height, width = batch[self.image_key].shape[-2], batch[self.image_key].shape[-1]
Expand All @@ -306,6 +304,8 @@ def eval_forward(self, batch, outputs=None):
# Set to resolution we are trying to generate
batch['cond_target_size'] = torch.tensor([[width, height]]).repeat(bsz, 1).to(device)

unet_out, targets, timesteps = self.forward(batch)

generated_images = {}
for guidance_scale in self.val_guidance_scales:
gen_images = self.generate(tokenized_prompts=prompts,
Expand Down
33 changes: 32 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
import torch

from diffusion.models.models import stable_diffusion_2
from diffusion.models.models import stable_diffusion_2, stable_diffusion_xl


def test_model_forward():
Expand Down Expand Up @@ -44,3 +44,34 @@ def test_model_generate(guidance_scale, negative_prompt):
progress_bar=False,
)
assert output.shape == (1, 3, 8, 8)


@pytest.mark.parametrize('model_name', ['stabilityai/stable-diffusion-xl-base-1.0', 'sdxl-e5'])
def test_model_forward_sdxl(model_name):
    """Run one forward pass of the SDXL model and check the output shapes.

    The model is built without pretrained weights for each parametrized
    ``model_name``; both the prediction and the target returned by the forward
    pass must have the shape of a 4-channel latent at 1/8 of the image size.
    """
    build_kwargs = {
        'pretrained': False,
        'fsdp': False,
        'encode_latents_in_fp16': False,
        'clip_qkv': None,
    }
    sdxl_model = stable_diffusion_xl(model_name=model_name, **build_kwargs)

    num_samples = 1
    height = width = 32

    # Dummy inputs are drawn in the same order as before so RNG use is unchanged.
    pixels = torch.randn(num_samples, 3, height, width)
    reference_latent = torch.randn(num_samples, 4, height // 8, width // 8)
    token_ids = torch.randint(low=0, high=128, size=(num_samples, 77), dtype=torch.long)
    # One copy of the token ids per tokenizer, stacked along dim 1.
    paired_token_ids = torch.stack([token_ids, token_ids], dim=1)
    size_cond = torch.randint(1, height, (num_samples, 2))

    batch = {'image': pixels, 'captions': paired_token_ids}
    # The same tensor stands in for all three micro-conditioning entries.
    for cond_key in ('cond_original_size', 'cond_crops_coords_top_left', 'cond_target_size'):
        batch[cond_key] = size_cond

    # model.forward generates the unet output noise or v_pred target.
    prediction, target, _ = sdxl_model(batch)
    expected_shape = reference_latent.shape
    assert prediction.shape == expected_shape
    assert target.shape == expected_shape
Loading