
CLIPTokenizer does not work as expected #2018

Open
fdtomasi opened this issue Dec 11, 2024 · 4 comments
Labels
type:Bug Something isn't working


@fdtomasi

To Reproduce

from keras_hub import models

# sequence_length is set on both the tokenizer and the preprocessor.
tokenizer = models.Tokenizer.from_preset(
    "clip_vit_h_14_laion2b_s32b_b79k",
    sequence_length=77,
    pad_with_end_token=True,
)
preprocessor = models.CLIPPreprocessor(tokenizer, sequence_length=77)
preprocessor(["a cat sitting on the table"])

which returns

{'token_ids': <tf.Tensor: shape=(1, 77), dtype=int32, numpy=
 array([[49406,   320,  2368,  4919,   525,   518,  2175,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0, 49407]], dtype=int32)>,
 'padding_mask': <tf.Tensor: shape=(1, 77), dtype=bool, numpy=
 array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True]])>}

This is surprising for a few reasons. First, even though pad_with_end_token=True, the padding uses 0 (which corresponds to ! in this vocabulary). Second, the end token is placed after the padding instead of at the end of the original sequence.
Third, padding_mask is all True, whereas I would expect it to be False at the padding positions.
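
To make the expected behavior concrete, here is a minimal pure-Python sketch of the packing described above: start token, the text tokens, the end token, then padding (with the end token when pad_with_end_token is set, otherwise 0), plus a mask that is False at padding positions. The helper name pack_expected is hypothetical and is not a keras_hub API; the ids 49406/49407 and the token ids for "a cat sitting on the table" are taken from the outputs in this thread.

```python
from typing import List, Tuple

START_ID = 49406  # start-of-text id, as seen in the outputs in this thread
END_ID = 49407    # end-of-text id, as seen in the outputs in this thread


def pack_expected(
    token_ids: List[int],
    sequence_length: int,
    pad_with_end_token: bool = True,
) -> Tuple[List[int], List[bool]]:
    """Hypothetical helper: pack one unpadded sequence the way the report expects."""
    # Reserve two slots for the start and end tokens, truncating the text if needed.
    body = token_ids[: sequence_length - 2]
    packed = [START_ID] + body + [END_ID]
    num_pad = sequence_length - len(packed)
    pad_id = END_ID if pad_with_end_token else 0
    # Real tokens (including start/end) are True; padding positions are False.
    mask = [True] * len(packed) + [False] * num_pad
    return packed + [pad_id] * num_pad, mask


ids, mask = pack_expected([320, 2368, 4919, 525, 518, 2175], sequence_length=77)
print(ids[:10])   # [49406, 320, 2368, 4919, 525, 518, 2175, 49407, 49407, 49407]
print(mask[:10])  # [True, True, True, True, True, True, True, True, False, False]
```

This is the same layout the workaround later in this thread produces: the end token directly follows the text, and only the eight real positions are unmasked.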

Additional context
Using keras_hub==0.18.1, keras==3.7.0.

@james77777778
Collaborator

You can work around the issue by not specifying sequence_length on the tokenizer.
I have proposed a fix in #2031.

import keras_hub

preset = "clip_vit_h_14_laion2b_s32b_b79k"
text = ["a cat sitting on the table"]

# Leave sequence_length unset on the tokenizer; only the preprocessor pads.
tokenizer = keras_hub.models.Tokenizer.from_preset(
    preset, pad_with_end_token=True
)
preprocessor = keras_hub.models.CLIPPreprocessor(tokenizer, sequence_length=77)
print(preprocessor(text))
{'token_ids': Array([[49406,   320,  2368,  4919,   525,   518,  2175, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407]], dtype=int32), 'padding_mask': Array([[ True,  True,  True,  True,  True,  True,  True,  True, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False]], dtype=bool)}
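
A simplified pure-Python model of why setting sequence_length on both layers goes wrong, assuming (this is not confirmed in the thread) that the tokenizer pads its own output with 0 up to its sequence_length before the preprocessor's packer adds the start and end tokens. Under that assumption the packer treats the zeros as real tokens, truncates so the end token lands at the last position, and marks every position as valid. The function names are hypothetical and this is not keras_hub's implementation; it only reproduces the shape of the outputs reported above.

```python
from typing import List, Tuple

START_ID, END_ID = 49406, 49407  # ids observed in this thread


def tokenizer_with_length(token_ids: List[int], sequence_length: int) -> List[int]:
    """Hypothetical model of a tokenizer that pads to its own length with 0."""
    return (token_ids + [0] * sequence_length)[:sequence_length]


def packer(token_ids: List[int], sequence_length: int) -> Tuple[List[int], List[bool]]:
    """Hypothetical model of the preprocessor's start/end packing (pads with END_ID)."""
    body = token_ids[: sequence_length - 2]  # truncate to leave room for start/end
    packed = [START_ID] + body + [END_ID]
    num_pad = sequence_length - len(packed)
    mask = [True] * len(packed) + [False] * num_pad
    return packed + [END_ID] * num_pad, mask


text_ids = [320, 2368, 4919, 525, 518, 2175]

# Both layers set sequence_length=77: the zeros look like real tokens to the packer.
ids, mask = packer(tokenizer_with_length(text_ids, 77), 77)
print(ids[:9], ids[-1], all(mask))  # [49406, 320, 2368, 4919, 525, 518, 2175, 0, 0] 49407 True

# Workaround: only the preprocessor sets sequence_length.
ids, mask = packer(text_ids, 77)
print(ids[:9], mask.count(True))  # [49406, 320, 2368, 4919, 525, 518, 2175, 49407, 49407] 8
```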

@mattdangerw
Member

Thanks for the bug report! I think @james77777778's suggestion is the right one: don't set the sequence length on both the tokenizer and the preprocessor.

@mattdangerw
Member

In general, we want our tokenizers to handle only the mapping from strings to ragged ints. Tokenizers should not pad; instead, they should be composed with other layers (e.g. StartEndPacker) for special-token packing and padding. The goal is to keep our tokenizers narrow rather than turning them into layers that do everything: flexibility through composition rather than sprawling init args.
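
A minimal pure-Python sketch of the composition idea: a tokenizer that only maps strings to ragged id lists, and a separate packer that owns start/end tokens and padding. ToyTokenizer and ToyPacker are hypothetical stand-ins (the real keras_hub packing layer is StartEndPacker, which this sketch does not use), and the one-word-per-token vocabulary is an assumption that reuses the ids seen in this thread; real CLIP tokenization is BPE-based.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class ToyTokenizer:
    """Hypothetical tokenizer: maps strings to ragged id lists and never pads."""
    vocab: Dict[str, int]

    def __call__(self, texts: List[str]) -> List[List[int]]:
        return [[self.vocab[word] for word in text.split()] for text in texts]


@dataclass
class ToyPacker:
    """Hypothetical stand-in for a start/end packing layer; owns all padding."""
    sequence_length: int
    start_value: int
    end_value: int
    pad_value: int

    def __call__(self, batch: List[List[int]]) -> Tuple[List[List[int]], List[List[bool]]]:
        all_ids, all_masks = [], []
        for ids in batch:
            body = ids[: self.sequence_length - 2]  # leave room for start/end
            packed = [self.start_value] + body + [self.end_value]
            num_pad = self.sequence_length - len(packed)
            all_ids.append(packed + [self.pad_value] * num_pad)
            all_masks.append([True] * len(packed) + [False] * num_pad)
        return all_ids, all_masks


# Toy vocabulary (assumption): one id per word, ids borrowed from this thread.
tokenizer = ToyTokenizer(vocab={"a": 320, "cat": 2368, "sitting": 4919,
                                "on": 525, "the": 518, "table": 2175})
ragged = tokenizer(["a cat", "a cat sitting on the table"])

# The same ragged output can feed differently configured packers.
pad_with_end = ToyPacker(sequence_length=6, start_value=49406,
                         end_value=49407, pad_value=49407)
pad_with_zero = ToyPacker(sequence_length=6, start_value=49406,
                          end_value=49407, pad_value=0)
ids_end, mask_end = pad_with_end(ragged)
ids_zero, mask_zero = pad_with_zero(ragged)
print(ids_end)  # [[49406, 320, 2368, 49407, 49407, 49407], [49406, 320, 2368, 4919, 525, 49407]]
print(ids_zero[0], mask_zero[0])  # [49406, 320, 2368, 49407, 0, 0] [True, True, True, True, False, False]
```

With this split, a choice like "pad with the end token" becomes packer configuration rather than a tokenizer init arg.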

However, CLIPTokenizer seems to buck this trend. Is there a reason we need to have the pad_with_end_token argument on the tokenizer at all? Also, what is CLIPPreprocessor for? In general, we have a tokenizer (not specialized for any task) and a preprocessor for a specific task. We might want to do some cleanup of the CLIP API.

@divyashreepathihalli and @james77777778 what do you think?

@james77777778
Collaborator

@mattdangerw I didn't notice I had been tagged until today.

However, CLIPTokenizer seems to buck this trend. Is there a reason we need to have the pad_with_end_token argument on the tokenizer at all?

That option is required by some downstream models, such as SD3 (Stable Diffusion 3).

Also what is CLIPPreprocessor for?

It is currently specific to SD3.

These implementations may be tailored to SD3 because they were originally developed for it. I can open a PR to refactor them.
