Add PaliGemma2 (#96)
* Add PaliGemma2 arch

* Enable mixed precision check for PaliGemma

* Add conversion script

* Revert ImageConverter and reduce mem usage in the conversion script

* Remove `compute_output_spec`

* Fix `compute_output_shape` issue for keras 3.1

* Add model cards and update conversion script

* update presets

---------

Co-authored-by: divyashreepathihalli <divyashreepathihalli@gmail.com>
james77777778 and divyashreepathihalli authored Dec 4, 2024
1 parent 8b44ade commit 7eb2044
Showing 6 changed files with 917 additions and 46 deletions.
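
For orientation (not part of the commit itself), here is a minimal sketch of constructing the updated backbone with the Gemma2-style arguments this commit introduces. It mirrors the tiny configuration used in the new PaliGemma2BackboneTest further below; the import path follows the file layout shown in this diff, and all sizes are illustrative only.

import numpy as np

from keras_hub.src.models.pali_gemma.pali_gemma_backbone import (
    PaliGemmaBackbone,
)

# Tiny, test-sized configuration. The last block of arguments is the set of
# Gemma2-style options added by this commit.
backbone = PaliGemmaBackbone(
    vocabulary_size=256,
    image_size=16,
    num_layers=2,
    num_query_heads=2,
    num_key_value_heads=1,
    hidden_dim=8,
    intermediate_dim=16,
    head_dim=4,
    vit_patch_size=4,
    vit_num_layers=2,
    vit_num_heads=2,
    vit_hidden_dim=8,
    vit_intermediate_dim=16,
    query_head_dim_normalize=True,
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
    final_logit_soft_cap=30,
    attention_logit_soft_cap=50,
    use_sliding_window_attention=True,
    sliding_window_size=4096,
)

batch_size, text_len = 2, 64
outputs = backbone(
    {
        "images": np.random.rand(batch_size, 16, 16, 3),
        "token_ids": np.random.randint(0, 256, (batch_size, text_len)),
        "padding_mask": np.ones((batch_size, text_len), dtype="int32"),
        "response_mask": np.zeros((batch_size, text_len), dtype="int32"),
    }
)
# Expected shape: (2, 64 + (16 // 4) ** 2, 8) == (2, 80, 8), matching the
# shape asserted in the new backbone test.
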
72 changes: 61 additions & 11 deletions keras_hub/src/models/pali_gemma/pali_gemma_backbone.py
@@ -48,22 +48,40 @@ class PaliGemmaBackbone(Backbone):
a two-layer feedforward network for each transformer decoder block.
head_dim: int. The size of each attention head in the mixed decoder.
vit_patch_size: int. The size of each square patch in the input image.
vit_num_heads: int. The number of attention heads for the vision(image)
vit_num_heads: int. The number of attention heads for the vision (image)
transformer encoder.
vit_hidden_dim: int. The size of the transformer hidden state at the end
of each vision transformer layer.
vit_num_layers: int. The number of vision transformer layers.
vit_intermediate_dim: int. The output dimension of the first Dense layer
in a two-layer feedforward network for vision transformer.
vit_pooling: string. The encoded vision embeddings are pooled using the
specified polling setting. The accepted values are `"map"`, `"gap"`,
`"0"` or `"none"`. Defaults to `"none"`.
in a two-layer feedforward network for vision transformer. Defaults
to `4304`.
vit_pooling: `None` or string. The encoded vision embeddings are pooled
using the specified pooling setting. The accepted values are
`"map"`, `"gap"`, `"0"` or `None`. Defaults to `None`.
vit_classifier_activation: activation function. The activation that
is used for final output classification in the vision transformer.
Defaults to `None`.
vit_name: string. The name used for vision transformer layers.
query_head_dim_normalize: boolean. If `True` normalize the query before
attention with `head_dim`. If `False`, normalize the query with
`hidden_dim / num_query_heads`. Defaults to `True`.
use_post_ffw_norm: boolean. Whether to normalize after the feedforward
block. Defaults to `False`.
use_post_attention_norm: boolean. Whether to normalize after the attention
block. Defaults to `False`.
attention_logit_soft_cap: `None` or int. Soft cap for the attention
logits. Defaults to `None`.
final_logit_soft_cap: `None` or int. Soft cap for the final logits.
Defaults to `None`.
use_sliding_window_attention: boolean. Whether to use sliding local
window attention. Defaults to `False`.
sliding_window_size: int. Size of the sliding local window. Defaults to
`4096`.
layer_norm_epsilon: float. The epsilon value used for every layer norm
in all transformer blocks.
in all transformer blocks. Defaults to `1e-6`.
dropout: float. Dropout probability for the Transformer decoder blocks.
Defaults to `0`.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for the model's computations and weights. Note that some
computations, such as softmax and layer normalization will always
@@ -119,6 +137,13 @@ def __init__(
vit_pooling=None,
vit_classifier_activation=None,
vit_name=None,
query_head_dim_normalize=True,
use_post_ffw_norm=False,
use_post_attention_norm=False,
attention_logit_soft_cap=None,
final_logit_soft_cap=None,
use_sliding_window_attention=False,
sliding_window_size=4096,
layer_norm_epsilon=1e-6,
dropout=0,
dtype=None,
@@ -136,6 +161,7 @@ def __init__(
seed=None,
),
dtype=dtype,
logit_soft_cap=final_logit_soft_cap,
name="token_embedding",
)
# TODO Remove this. Work around for previous serialization bug.
@@ -155,12 +181,19 @@ def __init__(
)
self.transformer_layers = []
for i in range(num_layers):
sliding_window = use_sliding_window_attention and (i % 2 == 0)
layer = PaliGemmaDecoderBlock(
hidden_dim=hidden_dim,
intermediate_dim=intermediate_dim,
num_query_heads=num_query_heads,
head_dim=head_dim,
num_query_heads=num_query_heads,
num_key_value_heads=num_key_value_heads,
query_head_dim_normalize=query_head_dim_normalize,
use_post_ffw_norm=use_post_ffw_norm,
use_post_attention_norm=use_post_attention_norm,
logit_soft_cap=attention_logit_soft_cap,
use_sliding_window_attention=sliding_window,
sliding_window_size=sliding_window_size,
dropout=dropout,
dtype=dtype,
name=f"decoder_block_{i}",
@@ -173,7 +206,9 @@ def __init__(
)

# === Functional Model ===
image_input = self.vit_encoder.inputs[0]
image_input = keras.Input(
shape=(image_size, image_size, 3), name="images"
)
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)
@@ -219,7 +254,15 @@ def __init__(
self.head_dim = head_dim
self.layer_norm_epsilon = layer_norm_epsilon
self.dropout = dropout
# VIT Params
# Gemma2 params
self.query_head_dim_normalize = query_head_dim_normalize
self.use_post_ffw_norm = use_post_ffw_norm
self.use_post_attention_norm = use_post_attention_norm
self.attention_logit_soft_cap = attention_logit_soft_cap
self.final_logit_soft_cap = final_logit_soft_cap
self.sliding_window_size = sliding_window_size
self.use_sliding_window_attention = use_sliding_window_attention
# ViT params
self.vit_patch_size = vit_patch_size
self.vit_num_heads = vit_num_heads
self.vit_hidden_dim = vit_hidden_dim
@@ -243,8 +286,6 @@ def get_config(self):
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"head_dim": self.head_dim,
"layer_norm_epsilon": self.layer_norm_epsilon,
"dropout": self.dropout,
"vit_patch_size": self.vit_patch_size,
"vit_num_heads": self.vit_num_heads,
"vit_hidden_dim": self.vit_hidden_dim,
@@ -253,6 +294,15 @@
"vit_pooling": self.vit_pooling,
"vit_classifier_activation": self.vit_classifier_activation,
"vit_name": self.vit_name,
"query_head_dim_normalize": self.query_head_dim_normalize,
"use_post_ffw_norm": self.use_post_ffw_norm,
"use_post_attention_norm": self.use_post_attention_norm,
"final_logit_soft_cap": self.final_logit_soft_cap,
"attention_logit_soft_cap": self.attention_logit_soft_cap,
"sliding_window_size": self.sliding_window_size,
"use_sliding_window_attention": self.use_sliding_window_attention,
"layer_norm_epsilon": self.layer_norm_epsilon,
"dropout": self.dropout,
}
)
return config
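
A note on the sliding-window wiring in the constructor above: when `use_sliding_window_attention` is enabled, only even-indexed decoder blocks use the local window (the `i % 2 == 0` check), so local and global attention alternate across layers, in the style of Gemma2. A throwaway illustration of the per-layer flags:

num_layers = 4
use_sliding_window_attention = True
per_layer = [
    use_sliding_window_attention and (i % 2 == 0) for i in range(num_layers)
]
print(per_layer)  # [True, False, True, False]
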
73 changes: 72 additions & 1 deletion keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py
@@ -61,7 +61,6 @@ def test_backbone_basics(self):
8,
),
variable_length_data=[self.input_data],
run_mixed_precision_check=False, # TODO: Set to `True`
)

@pytest.mark.large
@@ -98,3 +97,75 @@ def test_all_presets(self):
preset=preset,
input_data=self.input_data,
)


class PaliGemma2BackboneTest(TestCase):
def setUp(self):
self.batch_size = 2
self.vocabulary_size = 256
self.text_sequence_length = 64
self.image_size = 16
self.image_sequence_length = int((self.image_size / 4) ** 2)
self.init_kwargs = {
"vocabulary_size": self.vocabulary_size,
"image_size": self.image_size,
"num_layers": 2,
"num_query_heads": 2,
"num_key_value_heads": 1,
"hidden_dim": 8,
"intermediate_dim": 16,
"head_dim": 4,
"vit_patch_size": 4,
"vit_num_layers": 2,
"vit_num_heads": 2,
"vit_hidden_dim": 8,
"vit_intermediate_dim": 16,
# Gemma2
"query_head_dim_normalize": True,
"use_post_ffw_norm": True,
"use_post_attention_norm": True,
"final_logit_soft_cap": 30,
"attention_logit_soft_cap": 50,
"use_sliding_window_attention": True,
"sliding_window_size": 4096,
}

dummy_images = np.random.rand(
self.batch_size, self.image_size, self.image_size, 3
)
dummy_text_token_ids = np.random.rand(
self.batch_size, self.text_sequence_length
)
self.input_data = {
"token_ids": dummy_text_token_ids,
"images": dummy_images,
"padding_mask": np.ones(
(self.batch_size, self.text_sequence_length),
dtype="int32",
),
"response_mask": np.zeros(
(self.batch_size, self.text_sequence_length),
dtype="int32",
),
}

def test_backbone_basics(self):
self.run_backbone_test(
cls=PaliGemmaBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(
self.batch_size,
self.text_sequence_length + self.image_sequence_length,
8,
),
variable_length_data=[self.input_data],
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=PaliGemmaBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)
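
The expected output shape asserted in the new test follows directly from the ViT patching: with `image_size=16` and `vit_patch_size=4`, the image contributes `(16 / 4) ** 2 = 16` tokens, which are concatenated with the 64 text tokens. A quick sanity check of the arithmetic:

batch_size = 2
image_size = 16
vit_patch_size = 4
text_sequence_length = 64
hidden_dim = 8

image_sequence_length = (image_size // vit_patch_size) ** 2  # 16 patches
expected_output_shape = (
    batch_size,
    text_sequence_length + image_sequence_length,  # 64 + 16 = 80
    hidden_dim,
)
print(expected_output_shape)  # (2, 80, 8)
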
44 changes: 21 additions & 23 deletions keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py
@@ -31,33 +31,25 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
the attention layer.
num_key_value_heads: int. The number of heads for the key and value
projections in the attention layer.
query_head_dim_normalize: boolean. If `True` normalize the query before
attention with `head_dim`. If `False`, normalize the query with
`hidden_dim / num_query_heads`. Defaults to `True`.
use_post_ffw_norm: boolean. Whether to normalize after the feedforward
block. Defaults to `False`.
use_post_attention_norm: boolean. Whether to normalize after the
attention block. Defaults to `False`.
logit_soft_cap: `None` or int. Soft cap for the attention logits.
Defaults to `None`.
use_sliding_window_attention: boolean. Whether to use sliding local
window attention. Defaults to `False`.
sliding_window_size: int. Size of the sliding local window. Defaults to
`4096`.
layer_norm_epsilon: float. The epsilon hyperparameter used for layer
normalization.
normalization. Defaults to `1e-6`.
dropout: float. The dropout rate for the transformer attention layer.
Defaults to `0`.
"""

def __init__(
self,
hidden_dim,
intermediate_dim,
head_dim,
num_query_heads,
num_key_value_heads,
layer_norm_epsilon=1e-6,
dropout=0,
**kwargs,
):
super().__init__(
hidden_dim=hidden_dim,
intermediate_dim=intermediate_dim,
head_dim=head_dim,
num_query_heads=num_query_heads,
num_key_value_heads=num_key_value_heads,
layer_norm_epsilon=layer_norm_epsilon,
dropout=dropout,
**kwargs,
)

def call(
self,
x,
Expand All @@ -83,6 +75,9 @@ def call(
attention_mask=attention_mask,
)

if self.use_post_attention_norm:
attention = self.post_attention_norm(attention)

if self.dropout:
attention = self.attention_dropout(attention)

@@ -94,6 +89,9 @@ def call(
x = keras.activations.gelu(x1, approximate=True) * x2
x = self.ffw_linear(x)

if self.use_post_ffw_norm:
x = self.post_ffw_norm(x)

x = x + attention_x

if cache is not None:
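
The `logit_soft_cap` argument plumbed through the decoder block above is consumed inside the attention layer, which this diff does not touch. For reference, Gemma2-style soft capping rescales logits with a tanh; the snippet below is a minimal sketch of what the new `attention_logit_soft_cap` / `final_logit_soft_cap` values control, not the actual keras_hub implementation.

import numpy as np

def soft_cap(logits, cap):
    # Keeps logits roughly linear near zero but bounds them to (-cap, cap),
    # preventing extreme attention scores or final logits.
    return cap * np.tanh(logits / cap)

print(soft_cap(np.array([0.5, 10.0, 1000.0]), cap=50.0))
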