From f251ed3af839b592f543b6c7c31e1178ed754ff0 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 4 Dec 2024 21:06:00 -0800 Subject: [PATCH 1/5] Adding PaliGemma2 to KerasHub (#1998) * Add PaliGemma2 (#96) * Add PaliGemma2 arch * Enable mixed precision check for PaliGemma * Add conversion script * Revert ImageConverter and reduce mem usage in the conversion script * Remove `compute_output_spec` * Fix `compute_output_shape` issue for keras 3.1 * Add model cards and update conversion script * update presets --------- Co-authored-by: divyashreepathihalli * Update pali_gemma_presets.py - remove mix presets * Update pali_gemma_presets.py * Update convert_pali_gemma2_checkpoints.py --------- Co-authored-by: james77777778 <20734616+james77777778@users.noreply.github.com> --- .../models/pali_gemma/pali_gemma_backbone.py | 72 ++- .../pali_gemma/pali_gemma_backbone_test.py | 73 ++- .../pali_gemma/pali_gemma_decoder_block.py | 44 +- .../models/pali_gemma/pali_gemma_presets.py | 166 ++++++ .../src/models/pali_gemma/pali_gemma_vit.py | 23 +- .../convert_pali_gemma2_checkpoints.py | 518 ++++++++++++++++++ 6 files changed, 850 insertions(+), 46 deletions(-) create mode 100644 tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py b/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py index 75c1cb8ad0..6447ca2fc5 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py @@ -48,22 +48,40 @@ class PaliGemmaBackbone(Backbone): a two-layer feedforward network for each transformer decoder block. head_dim: int. The size of each attention head in the mixed decoder. vit_patch_size: int. The size of each square patch in the input image. - vit_num_heads: int. The number of attention heads for the vision(image) + vit_num_heads: int. The number of attention heads for the vision (image) transformer encoder. vit_hidden_dim: int. The size of the transformer hidden state at the end of each vision transformer layer. vit_num_layers: int. The number of vision transformer layers. vit_intermediate_dim: int. The output dimension of the first Dense layer - in a two-layer feedforward network for vision transformer. - vit_pooling: string. The encoded vision embeddings are pooled using the - specified polling setting. The accepted values are `"map"`, `"gap"`, - `"0"` or `"none"`. Defaults to `"none"`. + in a two-layer feedforward network for vision transformer. Defaults + to `4304`. + vit_pooling: `None` or string. The encoded vision embeddings are pooled + using the specified polling setting. The accepted values are + `"map"`, `"gap"`, `"0"` or `None`. Defaults to `None`. vit_classifier_activation: activation function. The activation that is used for final output classification in the vision transformer. + Defaults to `None`. vit_name: string. The name used for vision transformer layers. + query_head_dim_normalize: boolean. If `True` normalize the query before + attention with `head_dim`. If `False`, normalize the query with + `hidden_dim / num_query_heads`. Defaults to `True`. + use_post_ffw_norm: boolean. Whether to normalize after the feedforward + block. Defaults to `False`. + use_post_attention_norm: boolean. Whether to normalize after the attention + block. Defaults to `False`. + attention_logit_soft_cap: `None` or int. Soft cap for the attention + logits. Defaults to `None`. + final_logit_soft_cap: `None` or int. Soft cap for the final logits. 
+ Defaults to `None`. + use_sliding_window_attention: boolean. Whether to use sliding local + window attention. Defaults to `False`. + sliding_window_size: int. Size of the sliding local window. Defaults to + `4096`. layer_norm_epsilon: float. The epsilon value user for every layer norm - in all transformer blocks. + in all transformer blocks. Defaults to `1e-6`. dropout: float. Dropout probability for the Transformer decoder blocks. + Defaults to `0`. dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use for the models computations and weights. Note that some computations, such as softmax and layer normalization will always @@ -119,6 +137,13 @@ def __init__( vit_pooling=None, vit_classifier_activation=None, vit_name=None, + query_head_dim_normalize=True, + use_post_ffw_norm=False, + use_post_attention_norm=False, + attention_logit_soft_cap=None, + final_logit_soft_cap=None, + use_sliding_window_attention=False, + sliding_window_size=4096, layer_norm_epsilon=1e-6, dropout=0, dtype=None, @@ -136,6 +161,7 @@ def __init__( seed=None, ), dtype=dtype, + logit_soft_cap=final_logit_soft_cap, name="token_embedding", ) # TODO Remove this. Work around for previous serialization bug. @@ -155,12 +181,19 @@ def __init__( ) self.transformer_layers = [] for i in range(num_layers): + sliding_window = use_sliding_window_attention and (i % 2 == 0) layer = PaliGemmaDecoderBlock( hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, - num_query_heads=num_query_heads, head_dim=head_dim, + num_query_heads=num_query_heads, num_key_value_heads=num_key_value_heads, + query_head_dim_normalize=query_head_dim_normalize, + use_post_ffw_norm=use_post_ffw_norm, + use_post_attention_norm=use_post_attention_norm, + logit_soft_cap=attention_logit_soft_cap, + use_sliding_window_attention=sliding_window, + sliding_window_size=sliding_window_size, dropout=dropout, dtype=dtype, name=f"decoder_block_{i}", @@ -173,7 +206,9 @@ def __init__( ) # === Functional Model === - image_input = self.vit_encoder.inputs[0] + image_input = keras.Input( + shape=(image_size, image_size, 3), name="images" + ) token_id_input = keras.Input( shape=(None,), dtype="int32", name="token_ids" ) @@ -219,7 +254,15 @@ def __init__( self.head_dim = head_dim self.layer_norm_epsilon = layer_norm_epsilon self.dropout = dropout - # VIT Params + # Gemma2 params + self.query_head_dim_normalize = query_head_dim_normalize + self.use_post_ffw_norm = use_post_ffw_norm + self.use_post_attention_norm = use_post_attention_norm + self.attention_logit_soft_cap = attention_logit_soft_cap + self.final_logit_soft_cap = final_logit_soft_cap + self.sliding_window_size = sliding_window_size + self.use_sliding_window_attention = use_sliding_window_attention + # ViT params self.vit_patch_size = vit_patch_size self.vit_num_heads = vit_num_heads self.vit_hidden_dim = vit_hidden_dim @@ -243,8 +286,6 @@ def get_config(self): "hidden_dim": self.hidden_dim, "intermediate_dim": self.intermediate_dim, "head_dim": self.head_dim, - "layer_norm_epsilon": self.layer_norm_epsilon, - "dropout": self.dropout, "vit_patch_size": self.vit_patch_size, "vit_num_heads": self.vit_num_heads, "vit_hidden_dim": self.vit_hidden_dim, @@ -253,6 +294,15 @@ def get_config(self): "vit_pooling": self.vit_pooling, "vit_classifier_activation": self.vit_classifier_activation, "vit_name": self.vit_name, + "query_head_dim_normalize": self.query_head_dim_normalize, + "use_post_ffw_norm": self.use_post_ffw_norm, + "use_post_attention_norm": self.use_post_attention_norm, + "final_logit_soft_cap": 
self.final_logit_soft_cap, + "attention_logit_soft_cap": self.attention_logit_soft_cap, + "sliding_window_size": self.sliding_window_size, + "use_sliding_window_attention": self.use_sliding_window_attention, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, } ) return config diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py b/keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py index 4d02e075cd..af6ddbe946 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py @@ -61,7 +61,6 @@ def test_backbone_basics(self): 8, ), variable_length_data=[self.input_data], - run_mixed_precision_check=False, # TODO: Set to `True` ) @pytest.mark.large @@ -98,3 +97,75 @@ def test_all_presets(self): preset=preset, input_data=self.input_data, ) + + +class PaliGemma2BackboneTest(TestCase): + def setUp(self): + self.batch_size = 2 + self.vocabulary_size = 256 + self.text_sequence_length = 64 + self.image_size = 16 + self.image_sequence_length = int((self.image_size / 4) ** 2) + self.init_kwargs = { + "vocabulary_size": self.vocabulary_size, + "image_size": self.image_size, + "num_layers": 2, + "num_query_heads": 2, + "num_key_value_heads": 1, + "hidden_dim": 8, + "intermediate_dim": 16, + "head_dim": 4, + "vit_patch_size": 4, + "vit_num_layers": 2, + "vit_num_heads": 2, + "vit_hidden_dim": 8, + "vit_intermediate_dim": 16, + # Gemma2 + "query_head_dim_normalize": True, + "use_post_ffw_norm": True, + "use_post_attention_norm": True, + "final_logit_soft_cap": 30, + "attention_logit_soft_cap": 50, + "use_sliding_window_attention": True, + "sliding_window_size": 4096, + } + + dummy_images = np.random.rand( + self.batch_size, self.image_size, self.image_size, 3 + ) + dummy_text_token_ids = np.random.rand( + self.batch_size, self.text_sequence_length + ) + self.input_data = { + "token_ids": dummy_text_token_ids, + "images": dummy_images, + "padding_mask": np.ones( + (self.batch_size, self.text_sequence_length), + dtype="int32", + ), + "response_mask": np.zeros( + (self.batch_size, self.text_sequence_length), + dtype="int32", + ), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=PaliGemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=( + self.batch_size, + self.text_sequence_length + self.image_sequence_length, + 8, + ), + variable_length_data=[self.input_data], + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=PaliGemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py b/keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py index dd445a13c0..b59b0498e7 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py @@ -31,33 +31,25 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock): the attention layer. num_key_value_heads: int. The number of heads for the key and value projections in the attention layer. + query_head_dim_normalize: boolean. If `True` normalize the query before + attention with `head_dim`. If `False`, normalize the query with + `hidden_dim / num_query_heads`. Defaults to `True`. + use_post_ffw_norm: boolean. Whether to normalize after the feedforward + block. Defaults to `False`. + use_post_attention_norm: boolean. Whether to normalize after the + attention block. 
Defaults to `False`. + logit_soft_cap: `None` or int. Soft cap for the attention logits. + Defaults to `None`. + use_sliding_window_attention: boolean. Whether to use sliding local + window attention. Defaults to `False`. + sliding_window_size: int. Size of the sliding local window. Defaults to + `4096`. layer_norm_epsilon: float. The epsilon hyperparameter used for layer - normalization. + normalization. Defaults to `1e-6`. dropout: float. The dropout rate for the transformer attention layer. + Defaults to `0`. """ - def __init__( - self, - hidden_dim, - intermediate_dim, - head_dim, - num_query_heads, - num_key_value_heads, - layer_norm_epsilon=1e-6, - dropout=0, - **kwargs, - ): - super().__init__( - hidden_dim=hidden_dim, - intermediate_dim=intermediate_dim, - head_dim=head_dim, - num_query_heads=num_query_heads, - num_key_value_heads=num_key_value_heads, - layer_norm_epsilon=layer_norm_epsilon, - dropout=dropout, - **kwargs, - ) - def call( self, x, @@ -83,6 +75,9 @@ def call( attention_mask=attention_mask, ) + if self.use_post_attention_norm: + attention = self.post_attention_norm(attention) + if self.dropout: attention = self.attention_dropout(attention) @@ -94,6 +89,9 @@ def call( x = keras.activations.gelu(x1, approximate=True) * x2 x = self.ffw_linear(x) + if self.use_post_ffw_norm: + x = self.post_ffw_norm(x) + x = x + attention_x if cache is not None: diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py index eb27596469..056d949d80 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py @@ -52,4 +52,170 @@ }, "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_896/3", }, + # PaliGemma2 + "pali_gemma2_3b_ft_docci_448": { + "metadata": { + "description": ( + "3 billion parameter, image size 448, 27-layer for " + "SigLIP-So400m vision encoder and 26-layer Gemma2 2B lanuage " + "model. This model has been fine-tuned on the DOCCI dataset " + "for improved descriptions with fine-grained details." + ), + "params": 3032979696, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_ft_docci_448/1", + }, + "pali_gemma2_10b_ft_docci_448": { + "metadata": { + "description": ( + "10 billion parameter, 27-layer for SigLIP-So400m vision " + "encoder and 42-layer Gemma2 9B lanuage model. This model has " + "been fine-tuned on the DOCCI dataset for improved " + "descriptions with fine-grained details." + ), + "params": 9663294192, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_ft_docci_448/1", + }, + "pali_gemma2_3b_pt_224": { + "metadata": { + "description": ( + "3 billion parameter, image size 224, 27-layer for " + "SigLIP-So400m vision encoder and 26-layer Gemma2 2B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." 
+ ), + "params": 3032094960, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_pt_224/1", + }, + "pali_gemma2_3b_pt_448": { + "metadata": { + "description": ( + "3 billion parameter, image size 448, 27-layer for " + "SigLIP-So400m vision encoder and 26-layer Gemma2 2B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." + ), + "params": 3032979696, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_pt_448/1", + }, + "pali_gemma2_3b_pt_896": { + "metadata": { + "description": ( + "3 billion parameter, image size 896, 27-layer for " + "SigLIP-So400m vision encoder and 26-layer Gemma2 2B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." + ), + "params": 3036518640, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_pt_896/1", + }, + "pali_gemma2_10b_pt_224": { + "metadata": { + "description": ( + "10 billion parameter, image size 224, 27-layer for " + "SigLIP-So400m vision encoder and 42-layer Gemma2 9B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." + ), + "params": 9662409456, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_pt_224/1", + }, + "pali_gemma2_10b_pt_448": { + "metadata": { + "description": ( + "10 billion parameter, image size 448, 27-layer for " + "SigLIP-So400m vision encoder and 42-layer Gemma2 9B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." + ), + "params": 9663294192, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_pt_448/1", + }, + "pali_gemma2_10b_pt_896": { + "metadata": { + "description": ( + "10 billion parameter, image size 896, 27-layer for " + "SigLIP-So400m vision encoder and 42-layer Gemma2 9B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." + ), + "params": 9666833136, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_pt_896/1", + }, + "pali_gemma2_28b_pt_224": { + "metadata": { + "description": ( + "28 billion parameter, image size 224, 27-layer for " + "SigLIP-So400m vision encoder and 46-layer Gemma2 27B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." + ), + "params": 9662409456, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_pt_224/1", + }, + "pali_gemma2_28b_pt_448": { + "metadata": { + "description": ( + "28 billion parameter, image size 448, 27-layer for " + "SigLIP-So400m vision encoder and 46-layer Gemma2 27B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." 
+ ), + "params": 9663294192, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_pt_448/1", + }, + "pali_gemma2_28b_pt_896": { + "metadata": { + "description": ( + "28 billion parameter, image size 896, 27-layer for " + "SigLIP-So400m vision encoder and 46-layer Gemma2 27B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." + ), + "params": 9666833136, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_pt_896/1", + }, } diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_vit.py b/keras_hub/src/models/pali_gemma/pali_gemma_vit.py index 190a5e8e13..bed2b3ea66 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_vit.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_vit.py @@ -12,7 +12,7 @@ def __init__( dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.hidden_dim = hidden_dim self.image_size = image_size self.patch_size = patch_size @@ -72,7 +72,7 @@ def __init__( dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.hidden_dim = hidden_dim self.num_heads = num_heads @@ -282,7 +282,7 @@ def __init__( dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.hidden_dim = hidden_dim self.num_layers = num_layers self.num_heads = num_heads @@ -311,25 +311,26 @@ def __init__( for i in range(self.num_layers) ] - def build(self, input_shape): - self.vision_embeddings.build(input_shape) + def build(self, inputs_shape): + self.vision_embeddings.build(inputs_shape) for block in self.resblocks: block.build([None, None, self.hidden_dim]) self.encoder_layer_norm.build([None, None, self.hidden_dim]) self.built = True - def call( - self, - x, - mask=None, - ): - x = self.vision_embeddings(x) + def call(self, inputs, mask=None): + x = self.vision_embeddings(inputs) for block in self.resblocks: x = block(x, mask=mask) x = self.encoder_layer_norm(x) return x def compute_output_shape(self, inputs_shape): + if inputs_shape is None: + # Fix the compatibility issue with Keras 3.1 where + # `compute_output_spec` fails to propagate `inputs_shape` + # correctly, causing it to be `None`. + inputs_shape = [None, None, None] return [inputs_shape[0], inputs_shape[1], self.hidden_dim] def get_config(self): diff --git a/tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py b/tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py new file mode 100644 index 0000000000..596fca6063 --- /dev/null +++ b/tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py @@ -0,0 +1,518 @@ +""" +Convert PaliGemma2 checkpoints to the Keras format. + +The checkpoints are from here: +https://www.kaggle.com/models/google/paligemma-2 + +The `vocabulary.spm` is from here: +https://www.kaggle.com/models/keras/paligemma/ + +Setup: + +```shell +pip install kaggle +export KAGGLE_USERNAME=... +export KAGGLE_KEY=... 
+``` + +Usage: + +```shell +python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 +python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 --weights_path ./path/to/weights.npz +python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 --proto_path ./path/to/vocabulary.spm +python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 --upload_uri kaggle://divyasss/hongyu_sharing/keras/pali_gemma2_3b_pt_224 +``` +""" + +import io +import os +import pathlib + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +os.environ["KERAS_BACKEND"] = "jax" +# No GPU for conversion, makes memory management easier. +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +os.environ["JAX_PLATFORMS"] = "cpu" + +import jax # noqa: E402 +import kagglehub # noqa: E402 +import keras # noqa: E402 +import ml_dtypes # noqa: E402 +import numpy as np # noqa: E402 +import PIL # noqa: E402 +import requests # noqa: E402 +from absl import app # noqa: E402 +from absl import flags # noqa: E402 + +import keras_hub # noqa: E402 + +FLAGS = flags.FLAGS + +PRESET_MAP = { + "pali_gemma2_3b_ft_docci_448": "google/paligemma-2/jax/paligemma2-3b-ft-docci-448", + "pali_gemma2_10b_ft_docci_448": "google/paligemma-2/jax/paligemma2-10b-ft-docci-448", + "pali_gemma2_3b_pt_224": "google/paligemma-2/jax/paligemma2-3b-pt-224", + "pali_gemma2_3b_pt_448": "google/paligemma-2/jax/paligemma2-3b-pt-448", + "pali_gemma2_3b_pt_896": "google/paligemma-2/jax/paligemma2-3b-pt-896", + "pali_gemma2_10b_pt_224": "google/paligemma-2/jax/paligemma2-10b-pt-224", + "pali_gemma2_10b_pt_448": "google/paligemma-2/jax/paligemma2-10b-pt-448", + "pali_gemma2_10b_pt_896": "google/paligemma-2/jax/paligemma2-10b-pt-896", + "pali_gemma2_28b_pt_224": "google/paligemma-2/jax/paligemma2-28b-pt-224", + "pali_gemma2_28b_pt_448": "google/paligemma-2/jax/paligemma2-28b-pt-448", + "pali_gemma2_28b_pt_896": "google/paligemma-2/jax/paligemma2-28b-pt-896", +} + + +flags.DEFINE_string( + "preset", + None, + f'Must be one of {",".join(PRESET_MAP.keys())}', + required=True, +) +flags.DEFINE_string( + "weights_path", + None, + "Optional path for the model weights to convert.", +) +flags.DEFINE_string( + "proto_path", + "vocabulary.spm", + "Optional path for the SentencePiece proto file of the tokenizer.", +) +flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/{variant}/keras/{preset}"', + required=False, +) + + +def format_weights(weights): + def recover_dtype(a): + """Numpy's stores bfloat16 type as "void" type, so we recover it.""" + if hasattr(a, "dtype") and a.dtype.type is np.void: + assert a.itemsize == 2, "Unknown dtype!" 
+ return a.view(ml_dtypes.bfloat16) + else: + return a + + weights = dict(weights) + weights = jax.tree.map(recover_dtype, weights) + + formatted = {} + + # LLM part + prefix = "params/llm" + num_layers = int(weights[f"{prefix}/layers/mlp/linear"].shape[0]) + formatted["llm/embedding"] = weights[f"{prefix}/embedder/input_embedding"] + for i in range(num_layers): + layer_prefix = f"{prefix}/layers" + formatted_prefix = f"llm/decoder_block_{i}" + # RMSNorm + formatted[f"{formatted_prefix}/pre_norm/scale"] = weights[ + f"{layer_prefix}/pre_attention_norm/scale" + ][i] + formatted[f"{formatted_prefix}/post_norm/scale"] = weights[ + f"{layer_prefix}/post_attention_norm/scale" + ][i] + formatted[f"{formatted_prefix}/pre_ffw_norm/scale"] = weights[ + f"{layer_prefix}/pre_ffw_norm/scale" + ][i] + formatted[f"{formatted_prefix}/post_ffw_norm/scale"] = weights[ + f"{layer_prefix}/post_ffw_norm/scale" + ][i] + # MHA + formatted[f"{formatted_prefix}/mha/q/kernel"] = weights[ + f"{layer_prefix}/attn/q_einsum/w" + ][i] + formatted[f"{formatted_prefix}/mha/k/kernel"] = weights[ + f"{layer_prefix}/attn/kv_einsum/w" + ][i, 0] + formatted[f"{formatted_prefix}/mha/v/kernel"] = weights[ + f"{layer_prefix}/attn/kv_einsum/w" + ][i, 1] + formatted[f"{formatted_prefix}/mha/o/kernel"] = weights[ + f"{layer_prefix}/attn/attn_vec_einsum/w" + ][i] + # MLP + formatted[f"{formatted_prefix}/ffw_gating/kernel"] = weights[ + f"{layer_prefix}/mlp/gating_einsum" + ][i, 0] + formatted[f"{formatted_prefix}/ffw_gating_2/kernel"] = weights[ + f"{layer_prefix}/mlp/gating_einsum" + ][i, 1] + formatted[f"{formatted_prefix}/ffw_linear/kernel"] = weights[ + f"{layer_prefix}/mlp/linear" + ][i] + formatted["llm/final_normalization/scale"] = weights[ + f"{prefix}/final_norm/scale" + ] + + # ViT part + prefix = "params/img" + num_layers = int( + weights[f"{prefix}/Transformer/encoderblock/LayerNorm_1/scale"].shape[0] + ) + formatted["img/embedding/kernel"] = weights[f"{prefix}/embedding/kernel"] + formatted["img/embedding/bias"] = weights[f"{prefix}/embedding/bias"] + formatted["img/embedding/pos"] = weights[f"{prefix}/pos_embedding"] + formatted["img/ln/gamma"] = weights[ + f"{prefix}/Transformer/encoder_norm/scale" + ] + formatted["img/ln/beta"] = weights[ + f"{prefix}/Transformer/encoder_norm/bias" + ] + for i in range(num_layers): + encoder_prefix = f"{prefix}/Transformer/encoderblock" + formatted_prefix = f"img/encoder_block_{i}" + # MHA + formatted[f"{formatted_prefix}/mha/q/kernel"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/query/kernel" + ][i] + formatted[f"{formatted_prefix}/mha/q/bias"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/query/bias" + ][i] + formatted[f"{formatted_prefix}/mha/k/kernel"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/key/kernel" + ][i] + formatted[f"{formatted_prefix}/mha/k/bias"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/key/bias" + ][i] + formatted[f"{formatted_prefix}/mha/v/kernel"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/value/kernel" + ][i] + formatted[f"{formatted_prefix}/mha/v/bias"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/value/bias" + ][i] + formatted[f"{formatted_prefix}/mha/o/kernel"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/out/kernel" + ][i] + formatted[f"{formatted_prefix}/mha/o/bias"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/out/bias" + ][i] + # LN 0 + formatted[f"{formatted_prefix}/ln_0/gamma"] = weights[ + 
f"{encoder_prefix}/LayerNorm_0/scale" + ][i] + formatted[f"{formatted_prefix}/ln_0/beta"] = weights[ + f"{encoder_prefix}/LayerNorm_0/bias" + ][i] + # MLP + formatted[f"{formatted_prefix}/mlp_1/kernel"] = weights[ + f"{encoder_prefix}/MlpBlock_0/Dense_0/kernel" + ][i] + formatted[f"{formatted_prefix}/mlp_1/bias"] = weights[ + f"{encoder_prefix}/MlpBlock_0/Dense_0/bias" + ][i] + formatted[f"{formatted_prefix}/mlp_2/kernel"] = weights[ + f"{encoder_prefix}/MlpBlock_0/Dense_1/kernel" + ][i] + formatted[f"{formatted_prefix}/mlp_2/bias"] = weights[ + f"{encoder_prefix}/MlpBlock_0/Dense_1/bias" + ][i] + # LN 1 + formatted[f"{formatted_prefix}/ln_1/gamma"] = weights[ + f"{encoder_prefix}/LayerNorm_1/scale" + ][i] + formatted[f"{formatted_prefix}/ln_1/beta"] = weights[ + f"{encoder_prefix}/LayerNorm_1/bias" + ][i] + formatted["img/head/kernel"] = weights[f"{prefix}/head/kernel"] + formatted["img/head/bias"] = weights[f"{prefix}/head/bias"] + return formatted + + +def convert_tokenizer(proto_path): + return keras_hub.models.PaliGemmaTokenizer(proto=proto_path) + + +def convert_image_converter(image_size): + return keras_hub.layers.PaliGemmaImageConverter( + image_size=(image_size, image_size), + scale=1.0 / 127.5, + offset=-1, + ) + + +def convert_model(preset): + model_config = { + "vocabulary_size": 257152, + "vit_patch_size": 14, + "vit_num_heads": 16, + "vit_hidden_dim": 1152, + "vit_num_layers": 27, + # Gemma2 + "query_head_dim_normalize": True, + "use_post_ffw_norm": True, + "use_post_attention_norm": True, + "final_logit_soft_cap": 30, + "attention_logit_soft_cap": 50, + "use_sliding_window_attention": True, + "sliding_window_size": 4096, + } + preset = str(preset) + + # 2B, 10B, 28B + if "_3b_" in preset: + model_config.update( + { + "num_layers": 26, + "num_query_heads": 8, + "num_key_value_heads": 4, + "hidden_dim": 2304, + "intermediate_dim": 18432, + "head_dim": 256, + } + ) + elif "_10b_" in preset: + model_config.update( + { + "num_layers": 42, + "num_query_heads": 16, + "num_key_value_heads": 8, + "hidden_dim": 3584, + "intermediate_dim": 28672, + "head_dim": 256, + } + ) + elif "_28b_" in preset: + model_config.update( + { + "num_layers": 46, + "num_query_heads": 32, + "num_key_value_heads": 16, + "hidden_dim": 4608, + "intermediate_dim": 73728, + "head_dim": 128, + "query_head_dim_normalize": False, # Only for 28B + } + ) + + # Image size + image_size = int(preset.split("_")[-1]) + model_config.update({"image_size": image_size}) + + return keras_hub.models.PaliGemmaBackbone(**model_config) + + +def convert_weights(keras_model, weights): + from keras_hub.src.models.pali_gemma.pali_gemma_decoder_block import ( + PaliGemmaDecoderBlock, + ) + from keras_hub.src.models.pali_gemma.pali_gemma_vit import ( + PaliGemmaVitEncoder, + ) + from keras_hub.src.models.pali_gemma.pali_gemma_vit import ( + PaliGemmaVitEncoderBlock, + ) + + if not isinstance(keras_model, keras_hub.models.PaliGemmaBackbone): + raise ValueError( + "`keras_model` must be a `keras_hub.models.PaliGemmaBackbone`. 
" + f"Received: keras_model={keras_model} of type {type(keras_model)}" + ) + + # LLM part + keras_model.token_embedding.embeddings.assign(weights["llm/embedding"]) + for i, layer in enumerate(keras_model.transformer_layers): + if not isinstance(layer, PaliGemmaDecoderBlock): + raise ValueError + prefix = f"llm/decoder_block_{i}" + # RMSNorm + layer.pre_attention_norm.scale.assign( + weights[f"{prefix}/pre_norm/scale"] + ) + layer.post_attention_norm.scale.assign( + weights[f"{prefix}/post_norm/scale"] + ) + layer.pre_ffw_norm.scale.assign(weights[f"{prefix}/pre_ffw_norm/scale"]) + layer.post_ffw_norm.scale.assign( + weights[f"{prefix}/post_ffw_norm/scale"] + ) + # MHA + layer.attention.query_dense.kernel.assign( + weights[f"{prefix}/mha/q/kernel"] + ) + layer.attention.key_dense.kernel.assign( + weights[f"{prefix}/mha/k/kernel"] + ) + layer.attention.value_dense.kernel.assign( + weights[f"{prefix}/mha/v/kernel"] + ) + layer.attention.output_dense.kernel.assign( + weights[f"{prefix}/mha/o/kernel"] + ) + # MLP + layer.gating_ffw.kernel.assign(weights[f"{prefix}/ffw_gating/kernel"]) + layer.gating_ffw_2.kernel.assign( + weights[f"{prefix}/ffw_gating_2/kernel"] + ) + layer.ffw_linear.kernel.assign(weights[f"{prefix}/ffw_linear/kernel"]) + keras_model.layer_norm.scale.assign( + weights["llm/final_normalization/scale"] + ) + + # ViT part + vit_encoder = keras_model.vit_encoder.get_layer("image_encoder") + if not isinstance(vit_encoder, PaliGemmaVitEncoder): + raise ValueError + vit_encoder.encoder_layer_norm.gamma.assign(weights["img/ln/gamma"]) + vit_encoder.encoder_layer_norm.beta.assign(weights["img/ln/beta"]) + vit_encoder.vision_embeddings.patch_embedding.kernel.assign( + weights["img/embedding/kernel"] + ) + vit_encoder.vision_embeddings.patch_embedding.bias.assign( + weights["img/embedding/bias"] + ) + vit_encoder.vision_embeddings.position_embedding.embeddings.assign( + weights["img/embedding/pos"][0] + ) + for i, layer in enumerate(vit_encoder.resblocks): + if not isinstance(layer, PaliGemmaVitEncoderBlock): + raise ValueError + prefix = f"img/encoder_block_{i}" + input_dim = hidden_dim = layer.attn.hidden_dim + # MHA + layer.attn.query_proj.kernel.assign( + np.reshape( + weights[f"{prefix}/mha/q/kernel"], (input_dim, hidden_dim) + ) + ) + layer.attn.query_proj.bias.assign( + np.reshape(weights[f"{prefix}/mha/q/bias"], (-1,)) + ) + layer.attn.key_proj.kernel.assign( + np.reshape( + weights[f"{prefix}/mha/k/kernel"], (input_dim, hidden_dim) + ) + ) + layer.attn.key_proj.bias.assign( + np.reshape(weights[f"{prefix}/mha/k/bias"], (-1,)) + ) + layer.attn.value_proj.kernel.assign( + np.reshape( + weights[f"{prefix}/mha/v/kernel"], (input_dim, hidden_dim) + ) + ) + layer.attn.value_proj.bias.assign( + np.reshape(weights[f"{prefix}/mha/v/bias"], (-1,)) + ) + layer.attn.out_proj.kernel.assign( + np.reshape( + weights[f"{prefix}/mha/o/kernel"], (input_dim, hidden_dim) + ) + ) + layer.attn.out_proj.bias.assign(weights[f"{prefix}/mha/o/bias"]) + # LN 0 + layer.layer_norm_1.gamma.assign(weights[f"{prefix}/ln_0/gamma"]) + layer.layer_norm_1.beta.assign(weights[f"{prefix}/ln_0/beta"]) + # MLP + layer.mlp_dense_1.kernel.assign(weights[f"{prefix}/mlp_1/kernel"]) + layer.mlp_dense_1.bias.assign(weights[f"{prefix}/mlp_1/bias"]) + layer.mlp_dense_2.kernel.assign(weights[f"{prefix}/mlp_2/kernel"]) + layer.mlp_dense_2.bias.assign(weights[f"{prefix}/mlp_2/bias"]) + # LN 1 + layer.layer_norm_2.gamma.assign(weights[f"{prefix}/ln_1/gamma"]) + layer.layer_norm_2.beta.assign(weights[f"{prefix}/ln_1/beta"]) + 
vit_classifier = keras_model.vit_encoder.get_layer("image_classifier") + if not isinstance(vit_classifier, keras.layers.Dense): + raise ValueError + vit_classifier.kernel.assign(weights["img/head/kernel"]) + vit_classifier.bias.assign(weights["img/head/bias"]) + + return keras_model + + +def validate_output(keras_model, keras_tokenizer, keras_image_converter): + def read_image(url): + contents = io.BytesIO(requests.get(url).content) + image = PIL.Image.open(contents) + image = np.array(image).astype("float32") + # Remove alpha channel if neccessary. + if image.shape[2] == 4: + image = image[:, :, :3] + return image + + image = read_image( + "https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png" + ) + prompt = "answer en where is the cow standing?\n" + max_length = 32 + preprocessor = keras_hub.models.PaliGemmaCausalLMPreprocessor( + tokenizer=keras_tokenizer, image_converter=keras_image_converter + ) + pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM( + preprocessor=preprocessor, backbone=keras_model + ) + keras_output = pali_gemma_lm.generate( + inputs={"images": image, "prompts": prompt}, max_length=max_length + ) + keras_output = str(keras_output).replace(prompt, "") + print("🔶 Prompt:", prompt.replace("\n", "")) + print("🔶 KerasHub output:", keras_output) + + # TODO: Verify numerics with JAX model. + + +def main(_): + preset = str(FLAGS.preset) + print(f"🏃 Coverting {preset}") + + # Currently all weights are bfloat16 (and have much faster download times + # for it). We follow suit with Keras weights. + keras.config.set_floatx("bfloat16") + + if FLAGS.weights_path is not None: + weights_path = pathlib.Path(FLAGS.weights_path) + else: + presets = PRESET_MAP.keys() + if preset not in presets: + raise ValueError( + f"Invalid preset {preset}. Must be one of {list(presets)}" + ) + handle = PRESET_MAP[preset] + model_dir = kagglehub.model_download(handle) + print("✅ JAX model downloaded from kaggle") + + files = list(pathlib.Path(model_dir).glob("*.npz")) + if len(files) != 1: + raise ValueError( + f"Found too many files in {model_dir}. Expected only one file. 
" + f"Recevied: {files}" + ) + weights_path = files[0] + + weights = np.load(weights_path, allow_pickle=False) + weights = format_weights(weights) + image_size = int(preset.split("_")[-1]) + print("✅ JAX model weights loaded") + + keras_tokenizer = convert_tokenizer(FLAGS.proto_path) + keras_image_converter = convert_image_converter(image_size) + keras_model = convert_model(preset) + print("✅ Keras model loaded") + + convert_weights(keras_model, weights) + del weights + print("✅ Weights converted") + + validate_output(keras_model, keras_tokenizer, keras_image_converter) + print("✅ Output validated") + + keras_model.save_to_preset(preset) + keras_tokenizer.save_to_preset(preset) + keras_image_converter.save_to_preset(preset) + del keras_model + del keras_tokenizer + del keras_image_converter + print(f"🏁 Preset saved to ./{preset}") + + upload_uri = FLAGS.upload_uri + if upload_uri: + keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}") + print(f"🏁 Preset uploaded to {upload_uri}") + + +if __name__ == "__main__": + app.run(main) From 141b85976362c29e76796a922b5e2cd2e426c8ef Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 5 Dec 2024 05:27:11 +0000 Subject: [PATCH 2/5] Version bump to 0.18.0 --- keras_hub/src/version_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/version_utils.py b/keras_hub/src/version_utils.py index 422d964bb7..3fc7c9de55 100644 --- a/keras_hub/src/version_utils.py +++ b/keras_hub/src/version_utils.py @@ -1,7 +1,7 @@ from keras_hub.src.api_export import keras_hub_export # Unique source of truth for the version number. -__version__ = "0.18.0.dev0" +__version__ = "0.18.0" @keras_hub_export("keras_hub.version") From da6db4eab598d74443bd5f1da78199786b761507 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 4 Dec 2024 21:50:28 -0800 Subject: [PATCH 3/5] Update pali_gemma_presets.py (#2003) * Update pali_gemma_presets.py * code reformat --- .../models/pali_gemma/pali_gemma_presets.py | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py index 056d949d80..ffcf3ecafd 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py @@ -83,7 +83,7 @@ }, "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_ft_docci_448/1", }, - "pali_gemma2_3b_pt_224": { + "pali_gemma2_pt_3b_224": { "metadata": { "description": ( "3 billion parameter, image size 224, 27-layer for " @@ -96,9 +96,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_pt_224/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_3b_224/1", }, - "pali_gemma2_3b_pt_448": { + "pali_gemma2_pt_3b_448": { "metadata": { "description": ( "3 billion parameter, image size 448, 27-layer for " @@ -111,9 +111,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_pt_448/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_3b_448/1", }, - "pali_gemma2_3b_pt_896": { + "pali_gemma2_pt_3b_896": { "metadata": { "description": ( "3 billion parameter, image size 896, 27-layer for " @@ -126,9 +126,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - 
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_pt_896/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_3b_896/1", }, - "pali_gemma2_10b_pt_224": { + "pali_gemma2_pt_10b_224": { "metadata": { "description": ( "10 billion parameter, image size 224, 27-layer for " @@ -141,9 +141,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_pt_224/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_224/1", }, - "pali_gemma2_10b_pt_448": { + "pali_gemma2_pt_10b_448": { "metadata": { "description": ( "10 billion parameter, image size 448, 27-layer for " @@ -156,9 +156,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_pt_448/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_448/1", }, - "pali_gemma2_10b_pt_896": { + "pali_gemma2_pt_10b_896": { "metadata": { "description": ( "10 billion parameter, image size 896, 27-layer for " @@ -171,9 +171,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_pt_896/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_896/1", }, - "pali_gemma2_28b_pt_224": { + "pali_gemma2_pt_28b_224": { "metadata": { "description": ( "28 billion parameter, image size 224, 27-layer for " @@ -186,9 +186,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_pt_224/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_224/1", }, - "pali_gemma2_28b_pt_448": { + "pali_gemma2_pt_28b_448": { "metadata": { "description": ( "28 billion parameter, image size 448, 27-layer for " @@ -201,9 +201,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_pt_448/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_448/1", }, - "pali_gemma2_28b_pt_896": { + "pali_gemma2_pt_28b_896": { "metadata": { "description": ( "28 billion parameter, image size 896, 27-layer for " @@ -216,6 +216,6 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_pt_896/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_896/1", }, } From 80f1aa202f74c78018edcf0da0b126019db8d3b0 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 4 Dec 2024 21:06:00 -0800 Subject: [PATCH 4/5] Adding PaliGemma2 to KerasHub (#1998) * Add PaliGemma2 (#96) * Add PaliGemma2 arch * Enable mixed precision check for PaliGemma * Add conversion script * Revert ImageConverter and reduce mem usage in the conversion script * Remove `compute_output_spec` * Fix `compute_output_shape` issue for keras 3.1 * Add model cards and update conversion script * update presets --------- Co-authored-by: divyashreepathihalli * Update pali_gemma_presets.py - remove mix presets * Update pali_gemma_presets.py * Update convert_pali_gemma2_checkpoints.py --------- Co-authored-by: james77777778 <20734616+james77777778@users.noreply.github.com> --- .../models/pali_gemma/pali_gemma_backbone.py | 72 ++- 
.../pali_gemma/pali_gemma_backbone_test.py | 73 ++- .../pali_gemma/pali_gemma_decoder_block.py | 44 +- .../models/pali_gemma/pali_gemma_presets.py | 166 ++++++ .../src/models/pali_gemma/pali_gemma_vit.py | 23 +- .../convert_pali_gemma2_checkpoints.py | 518 ++++++++++++++++++ 6 files changed, 850 insertions(+), 46 deletions(-) create mode 100644 tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py b/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py index 75c1cb8ad0..6447ca2fc5 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py @@ -48,22 +48,40 @@ class PaliGemmaBackbone(Backbone): a two-layer feedforward network for each transformer decoder block. head_dim: int. The size of each attention head in the mixed decoder. vit_patch_size: int. The size of each square patch in the input image. - vit_num_heads: int. The number of attention heads for the vision(image) + vit_num_heads: int. The number of attention heads for the vision (image) transformer encoder. vit_hidden_dim: int. The size of the transformer hidden state at the end of each vision transformer layer. vit_num_layers: int. The number of vision transformer layers. vit_intermediate_dim: int. The output dimension of the first Dense layer - in a two-layer feedforward network for vision transformer. - vit_pooling: string. The encoded vision embeddings are pooled using the - specified polling setting. The accepted values are `"map"`, `"gap"`, - `"0"` or `"none"`. Defaults to `"none"`. + in a two-layer feedforward network for vision transformer. Defaults + to `4304`. + vit_pooling: `None` or string. The encoded vision embeddings are pooled + using the specified polling setting. The accepted values are + `"map"`, `"gap"`, `"0"` or `None`. Defaults to `None`. vit_classifier_activation: activation function. The activation that is used for final output classification in the vision transformer. + Defaults to `None`. vit_name: string. The name used for vision transformer layers. + query_head_dim_normalize: boolean. If `True` normalize the query before + attention with `head_dim`. If `False`, normalize the query with + `hidden_dim / num_query_heads`. Defaults to `True`. + use_post_ffw_norm: boolean. Whether to normalize after the feedforward + block. Defaults to `False`. + use_post_attention_norm: boolean. Whether to normalize after the attention + block. Defaults to `False`. + attention_logit_soft_cap: `None` or int. Soft cap for the attention + logits. Defaults to `None`. + final_logit_soft_cap: `None` or int. Soft cap for the final logits. + Defaults to `None`. + use_sliding_window_attention: boolean. Whether to use sliding local + window attention. Defaults to `False`. + sliding_window_size: int. Size of the sliding local window. Defaults to + `4096`. layer_norm_epsilon: float. The epsilon value user for every layer norm - in all transformer blocks. + in all transformer blocks. Defaults to `1e-6`. dropout: float. Dropout probability for the Transformer decoder blocks. + Defaults to `0`. dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use for the models computations and weights. 
Note that some computations, such as softmax and layer normalization will always @@ -119,6 +137,13 @@ def __init__( vit_pooling=None, vit_classifier_activation=None, vit_name=None, + query_head_dim_normalize=True, + use_post_ffw_norm=False, + use_post_attention_norm=False, + attention_logit_soft_cap=None, + final_logit_soft_cap=None, + use_sliding_window_attention=False, + sliding_window_size=4096, layer_norm_epsilon=1e-6, dropout=0, dtype=None, @@ -136,6 +161,7 @@ def __init__( seed=None, ), dtype=dtype, + logit_soft_cap=final_logit_soft_cap, name="token_embedding", ) # TODO Remove this. Work around for previous serialization bug. @@ -155,12 +181,19 @@ def __init__( ) self.transformer_layers = [] for i in range(num_layers): + sliding_window = use_sliding_window_attention and (i % 2 == 0) layer = PaliGemmaDecoderBlock( hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, - num_query_heads=num_query_heads, head_dim=head_dim, + num_query_heads=num_query_heads, num_key_value_heads=num_key_value_heads, + query_head_dim_normalize=query_head_dim_normalize, + use_post_ffw_norm=use_post_ffw_norm, + use_post_attention_norm=use_post_attention_norm, + logit_soft_cap=attention_logit_soft_cap, + use_sliding_window_attention=sliding_window, + sliding_window_size=sliding_window_size, dropout=dropout, dtype=dtype, name=f"decoder_block_{i}", @@ -173,7 +206,9 @@ def __init__( ) # === Functional Model === - image_input = self.vit_encoder.inputs[0] + image_input = keras.Input( + shape=(image_size, image_size, 3), name="images" + ) token_id_input = keras.Input( shape=(None,), dtype="int32", name="token_ids" ) @@ -219,7 +254,15 @@ def __init__( self.head_dim = head_dim self.layer_norm_epsilon = layer_norm_epsilon self.dropout = dropout - # VIT Params + # Gemma2 params + self.query_head_dim_normalize = query_head_dim_normalize + self.use_post_ffw_norm = use_post_ffw_norm + self.use_post_attention_norm = use_post_attention_norm + self.attention_logit_soft_cap = attention_logit_soft_cap + self.final_logit_soft_cap = final_logit_soft_cap + self.sliding_window_size = sliding_window_size + self.use_sliding_window_attention = use_sliding_window_attention + # ViT params self.vit_patch_size = vit_patch_size self.vit_num_heads = vit_num_heads self.vit_hidden_dim = vit_hidden_dim @@ -243,8 +286,6 @@ def get_config(self): "hidden_dim": self.hidden_dim, "intermediate_dim": self.intermediate_dim, "head_dim": self.head_dim, - "layer_norm_epsilon": self.layer_norm_epsilon, - "dropout": self.dropout, "vit_patch_size": self.vit_patch_size, "vit_num_heads": self.vit_num_heads, "vit_hidden_dim": self.vit_hidden_dim, @@ -253,6 +294,15 @@ def get_config(self): "vit_pooling": self.vit_pooling, "vit_classifier_activation": self.vit_classifier_activation, "vit_name": self.vit_name, + "query_head_dim_normalize": self.query_head_dim_normalize, + "use_post_ffw_norm": self.use_post_ffw_norm, + "use_post_attention_norm": self.use_post_attention_norm, + "final_logit_soft_cap": self.final_logit_soft_cap, + "attention_logit_soft_cap": self.attention_logit_soft_cap, + "sliding_window_size": self.sliding_window_size, + "use_sliding_window_attention": self.use_sliding_window_attention, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, } ) return config diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py b/keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py index 4d02e075cd..af6ddbe946 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py +++ 
b/keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py @@ -61,7 +61,6 @@ def test_backbone_basics(self): 8, ), variable_length_data=[self.input_data], - run_mixed_precision_check=False, # TODO: Set to `True` ) @pytest.mark.large @@ -98,3 +97,75 @@ def test_all_presets(self): preset=preset, input_data=self.input_data, ) + + +class PaliGemma2BackboneTest(TestCase): + def setUp(self): + self.batch_size = 2 + self.vocabulary_size = 256 + self.text_sequence_length = 64 + self.image_size = 16 + self.image_sequence_length = int((self.image_size / 4) ** 2) + self.init_kwargs = { + "vocabulary_size": self.vocabulary_size, + "image_size": self.image_size, + "num_layers": 2, + "num_query_heads": 2, + "num_key_value_heads": 1, + "hidden_dim": 8, + "intermediate_dim": 16, + "head_dim": 4, + "vit_patch_size": 4, + "vit_num_layers": 2, + "vit_num_heads": 2, + "vit_hidden_dim": 8, + "vit_intermediate_dim": 16, + # Gemma2 + "query_head_dim_normalize": True, + "use_post_ffw_norm": True, + "use_post_attention_norm": True, + "final_logit_soft_cap": 30, + "attention_logit_soft_cap": 50, + "use_sliding_window_attention": True, + "sliding_window_size": 4096, + } + + dummy_images = np.random.rand( + self.batch_size, self.image_size, self.image_size, 3 + ) + dummy_text_token_ids = np.random.rand( + self.batch_size, self.text_sequence_length + ) + self.input_data = { + "token_ids": dummy_text_token_ids, + "images": dummy_images, + "padding_mask": np.ones( + (self.batch_size, self.text_sequence_length), + dtype="int32", + ), + "response_mask": np.zeros( + (self.batch_size, self.text_sequence_length), + dtype="int32", + ), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=PaliGemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=( + self.batch_size, + self.text_sequence_length + self.image_sequence_length, + 8, + ), + variable_length_data=[self.input_data], + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=PaliGemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py b/keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py index dd445a13c0..b59b0498e7 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py @@ -31,33 +31,25 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock): the attention layer. num_key_value_heads: int. The number of heads for the key and value projections in the attention layer. + query_head_dim_normalize: boolean. If `True` normalize the query before + attention with `head_dim`. If `False`, normalize the query with + `hidden_dim / num_query_heads`. Defaults to `True`. + use_post_ffw_norm: boolean. Whether to normalize after the feedforward + block. Defaults to `False`. + use_post_attention_norm: boolean. Whether to normalize after the + attention block. Defaults to `False`. + logit_soft_cap: `None` or int. Soft cap for the attention logits. + Defaults to `None`. + use_sliding_window_attention: boolean. Whether to use sliding local + window attention. Defaults to `False`. + sliding_window_size: int. Size of the sliding local window. Defaults to + `4096`. layer_norm_epsilon: float. The epsilon hyperparameter used for layer - normalization. + normalization. Defaults to `1e-6`. dropout: float. The dropout rate for the transformer attention layer. + Defaults to `0`. 
""" - def __init__( - self, - hidden_dim, - intermediate_dim, - head_dim, - num_query_heads, - num_key_value_heads, - layer_norm_epsilon=1e-6, - dropout=0, - **kwargs, - ): - super().__init__( - hidden_dim=hidden_dim, - intermediate_dim=intermediate_dim, - head_dim=head_dim, - num_query_heads=num_query_heads, - num_key_value_heads=num_key_value_heads, - layer_norm_epsilon=layer_norm_epsilon, - dropout=dropout, - **kwargs, - ) - def call( self, x, @@ -83,6 +75,9 @@ def call( attention_mask=attention_mask, ) + if self.use_post_attention_norm: + attention = self.post_attention_norm(attention) + if self.dropout: attention = self.attention_dropout(attention) @@ -94,6 +89,9 @@ def call( x = keras.activations.gelu(x1, approximate=True) * x2 x = self.ffw_linear(x) + if self.use_post_ffw_norm: + x = self.post_ffw_norm(x) + x = x + attention_x if cache is not None: diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py index eb27596469..056d949d80 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py @@ -52,4 +52,170 @@ }, "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_896/3", }, + # PaliGemma2 + "pali_gemma2_3b_ft_docci_448": { + "metadata": { + "description": ( + "3 billion parameter, image size 448, 27-layer for " + "SigLIP-So400m vision encoder and 26-layer Gemma2 2B lanuage " + "model. This model has been fine-tuned on the DOCCI dataset " + "for improved descriptions with fine-grained details." + ), + "params": 3032979696, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_ft_docci_448/1", + }, + "pali_gemma2_10b_ft_docci_448": { + "metadata": { + "description": ( + "10 billion parameter, 27-layer for SigLIP-So400m vision " + "encoder and 42-layer Gemma2 9B lanuage model. This model has " + "been fine-tuned on the DOCCI dataset for improved " + "descriptions with fine-grained details." + ), + "params": 9663294192, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_ft_docci_448/1", + }, + "pali_gemma2_3b_pt_224": { + "metadata": { + "description": ( + "3 billion parameter, image size 224, 27-layer for " + "SigLIP-So400m vision encoder and 26-layer Gemma2 2B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." + ), + "params": 3032094960, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_pt_224/1", + }, + "pali_gemma2_3b_pt_448": { + "metadata": { + "description": ( + "3 billion parameter, image size 448, 27-layer for " + "SigLIP-So400m vision encoder and 26-layer Gemma2 2B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." 
+ ), + "params": 3032979696, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_pt_448/1", + }, + "pali_gemma2_3b_pt_896": { + "metadata": { + "description": ( + "3 billion parameter, image size 896, 27-layer for " + "SigLIP-So400m vision encoder and 26-layer Gemma2 2B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." + ), + "params": 3036518640, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_pt_896/1", + }, + "pali_gemma2_10b_pt_224": { + "metadata": { + "description": ( + "10 billion parameter, image size 224, 27-layer for " + "SigLIP-So400m vision encoder and 42-layer Gemma2 9B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." + ), + "params": 9662409456, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_pt_224/1", + }, + "pali_gemma2_10b_pt_448": { + "metadata": { + "description": ( + "10 billion parameter, image size 448, 27-layer for " + "SigLIP-So400m vision encoder and 42-layer Gemma2 9B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." + ), + "params": 9663294192, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_pt_448/1", + }, + "pali_gemma2_10b_pt_896": { + "metadata": { + "description": ( + "10 billion parameter, image size 896, 27-layer for " + "SigLIP-So400m vision encoder and 42-layer Gemma2 9B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." + ), + "params": 9666833136, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_pt_896/1", + }, + "pali_gemma2_28b_pt_224": { + "metadata": { + "description": ( + "28 billion parameter, image size 224, 27-layer for " + "SigLIP-So400m vision encoder and 46-layer Gemma2 27B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." + ), + "params": 9662409456, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_pt_224/1", + }, + "pali_gemma2_28b_pt_448": { + "metadata": { + "description": ( + "28 billion parameter, image size 448, 27-layer for " + "SigLIP-So400m vision encoder and 46-layer Gemma2 27B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." + ), + "params": 9663294192, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_pt_448/1", + }, + "pali_gemma2_28b_pt_896": { + "metadata": { + "description": ( + "28 billion parameter, image size 896, 27-layer for " + "SigLIP-So400m vision encoder and 46-layer Gemma2 27B lanuage " + "model. This model has been pre-trained on a mixture of " + "datasets." 
+ ), + "params": 9666833136, + "official_name": "PaliGemma2", + "path": "pali_gemma2", + "model_card": "https://www.kaggle.com/models/google/paligemma-2", + }, + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_pt_896/1", + }, } diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_vit.py b/keras_hub/src/models/pali_gemma/pali_gemma_vit.py index 190a5e8e13..bed2b3ea66 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_vit.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_vit.py @@ -12,7 +12,7 @@ def __init__( dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.hidden_dim = hidden_dim self.image_size = image_size self.patch_size = patch_size @@ -72,7 +72,7 @@ def __init__( dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.hidden_dim = hidden_dim self.num_heads = num_heads @@ -282,7 +282,7 @@ def __init__( dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.hidden_dim = hidden_dim self.num_layers = num_layers self.num_heads = num_heads @@ -311,25 +311,26 @@ def __init__( for i in range(self.num_layers) ] - def build(self, input_shape): - self.vision_embeddings.build(input_shape) + def build(self, inputs_shape): + self.vision_embeddings.build(inputs_shape) for block in self.resblocks: block.build([None, None, self.hidden_dim]) self.encoder_layer_norm.build([None, None, self.hidden_dim]) self.built = True - def call( - self, - x, - mask=None, - ): - x = self.vision_embeddings(x) + def call(self, inputs, mask=None): + x = self.vision_embeddings(inputs) for block in self.resblocks: x = block(x, mask=mask) x = self.encoder_layer_norm(x) return x def compute_output_shape(self, inputs_shape): + if inputs_shape is None: + # Fix the compatibility issue with Keras 3.1 where + # `compute_output_spec` fails to propagate `inputs_shape` + # correctly, causing it to be `None`. + inputs_shape = [None, None, None] return [inputs_shape[0], inputs_shape[1], self.hidden_dim] def get_config(self): diff --git a/tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py b/tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py new file mode 100644 index 0000000000..596fca6063 --- /dev/null +++ b/tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py @@ -0,0 +1,518 @@ +""" +Convert PaliGemma2 checkpoints to the Keras format. + +The checkpoints are from here: +https://www.kaggle.com/models/google/paligemma-2 + +The `vocabulary.spm` is from here: +https://www.kaggle.com/models/keras/paligemma/ + +Setup: + +```shell +pip install kaggle +export KAGGLE_USERNAME=... +export KAGGLE_KEY=... +``` + +Usage: + +```shell +python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 +python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 --weights_path ./path/to/weights.npz +python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 --proto_path ./path/to/vocabulary.spm +python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 --upload_uri kaggle://divyasss/hongyu_sharing/keras/pali_gemma2_3b_pt_224 +``` +""" + +import io +import os +import pathlib + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +os.environ["KERAS_BACKEND"] = "jax" +# No GPU for conversion, makes memory management easier. 
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +os.environ["JAX_PLATFORMS"] = "cpu" + +import jax # noqa: E402 +import kagglehub # noqa: E402 +import keras # noqa: E402 +import ml_dtypes # noqa: E402 +import numpy as np # noqa: E402 +import PIL # noqa: E402 +import requests # noqa: E402 +from absl import app # noqa: E402 +from absl import flags # noqa: E402 + +import keras_hub # noqa: E402 + +FLAGS = flags.FLAGS + +PRESET_MAP = { + "pali_gemma2_3b_ft_docci_448": "google/paligemma-2/jax/paligemma2-3b-ft-docci-448", + "pali_gemma2_10b_ft_docci_448": "google/paligemma-2/jax/paligemma2-10b-ft-docci-448", + "pali_gemma2_3b_pt_224": "google/paligemma-2/jax/paligemma2-3b-pt-224", + "pali_gemma2_3b_pt_448": "google/paligemma-2/jax/paligemma2-3b-pt-448", + "pali_gemma2_3b_pt_896": "google/paligemma-2/jax/paligemma2-3b-pt-896", + "pali_gemma2_10b_pt_224": "google/paligemma-2/jax/paligemma2-10b-pt-224", + "pali_gemma2_10b_pt_448": "google/paligemma-2/jax/paligemma2-10b-pt-448", + "pali_gemma2_10b_pt_896": "google/paligemma-2/jax/paligemma2-10b-pt-896", + "pali_gemma2_28b_pt_224": "google/paligemma-2/jax/paligemma2-28b-pt-224", + "pali_gemma2_28b_pt_448": "google/paligemma-2/jax/paligemma2-28b-pt-448", + "pali_gemma2_28b_pt_896": "google/paligemma-2/jax/paligemma2-28b-pt-896", +} + + +flags.DEFINE_string( + "preset", + None, + f'Must be one of {",".join(PRESET_MAP.keys())}', + required=True, +) +flags.DEFINE_string( + "weights_path", + None, + "Optional path for the model weights to convert.", +) +flags.DEFINE_string( + "proto_path", + "vocabulary.spm", + "Optional path for the SentencePiece proto file of the tokenizer.", +) +flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/{variant}/keras/{preset}"', + required=False, +) + + +def format_weights(weights): + def recover_dtype(a): + """Numpy's stores bfloat16 type as "void" type, so we recover it.""" + if hasattr(a, "dtype") and a.dtype.type is np.void: + assert a.itemsize == 2, "Unknown dtype!" 
+ return a.view(ml_dtypes.bfloat16) + else: + return a + + weights = dict(weights) + weights = jax.tree.map(recover_dtype, weights) + + formatted = {} + + # LLM part + prefix = "params/llm" + num_layers = int(weights[f"{prefix}/layers/mlp/linear"].shape[0]) + formatted["llm/embedding"] = weights[f"{prefix}/embedder/input_embedding"] + for i in range(num_layers): + layer_prefix = f"{prefix}/layers" + formatted_prefix = f"llm/decoder_block_{i}" + # RMSNorm + formatted[f"{formatted_prefix}/pre_norm/scale"] = weights[ + f"{layer_prefix}/pre_attention_norm/scale" + ][i] + formatted[f"{formatted_prefix}/post_norm/scale"] = weights[ + f"{layer_prefix}/post_attention_norm/scale" + ][i] + formatted[f"{formatted_prefix}/pre_ffw_norm/scale"] = weights[ + f"{layer_prefix}/pre_ffw_norm/scale" + ][i] + formatted[f"{formatted_prefix}/post_ffw_norm/scale"] = weights[ + f"{layer_prefix}/post_ffw_norm/scale" + ][i] + # MHA + formatted[f"{formatted_prefix}/mha/q/kernel"] = weights[ + f"{layer_prefix}/attn/q_einsum/w" + ][i] + formatted[f"{formatted_prefix}/mha/k/kernel"] = weights[ + f"{layer_prefix}/attn/kv_einsum/w" + ][i, 0] + formatted[f"{formatted_prefix}/mha/v/kernel"] = weights[ + f"{layer_prefix}/attn/kv_einsum/w" + ][i, 1] + formatted[f"{formatted_prefix}/mha/o/kernel"] = weights[ + f"{layer_prefix}/attn/attn_vec_einsum/w" + ][i] + # MLP + formatted[f"{formatted_prefix}/ffw_gating/kernel"] = weights[ + f"{layer_prefix}/mlp/gating_einsum" + ][i, 0] + formatted[f"{formatted_prefix}/ffw_gating_2/kernel"] = weights[ + f"{layer_prefix}/mlp/gating_einsum" + ][i, 1] + formatted[f"{formatted_prefix}/ffw_linear/kernel"] = weights[ + f"{layer_prefix}/mlp/linear" + ][i] + formatted["llm/final_normalization/scale"] = weights[ + f"{prefix}/final_norm/scale" + ] + + # ViT part + prefix = "params/img" + num_layers = int( + weights[f"{prefix}/Transformer/encoderblock/LayerNorm_1/scale"].shape[0] + ) + formatted["img/embedding/kernel"] = weights[f"{prefix}/embedding/kernel"] + formatted["img/embedding/bias"] = weights[f"{prefix}/embedding/bias"] + formatted["img/embedding/pos"] = weights[f"{prefix}/pos_embedding"] + formatted["img/ln/gamma"] = weights[ + f"{prefix}/Transformer/encoder_norm/scale" + ] + formatted["img/ln/beta"] = weights[ + f"{prefix}/Transformer/encoder_norm/bias" + ] + for i in range(num_layers): + encoder_prefix = f"{prefix}/Transformer/encoderblock" + formatted_prefix = f"img/encoder_block_{i}" + # MHA + formatted[f"{formatted_prefix}/mha/q/kernel"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/query/kernel" + ][i] + formatted[f"{formatted_prefix}/mha/q/bias"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/query/bias" + ][i] + formatted[f"{formatted_prefix}/mha/k/kernel"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/key/kernel" + ][i] + formatted[f"{formatted_prefix}/mha/k/bias"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/key/bias" + ][i] + formatted[f"{formatted_prefix}/mha/v/kernel"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/value/kernel" + ][i] + formatted[f"{formatted_prefix}/mha/v/bias"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/value/bias" + ][i] + formatted[f"{formatted_prefix}/mha/o/kernel"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/out/kernel" + ][i] + formatted[f"{formatted_prefix}/mha/o/bias"] = weights[ + f"{encoder_prefix}/MultiHeadDotProductAttention_0/out/bias" + ][i] + # LN 0 + formatted[f"{formatted_prefix}/ln_0/gamma"] = weights[ + 
f"{encoder_prefix}/LayerNorm_0/scale" + ][i] + formatted[f"{formatted_prefix}/ln_0/beta"] = weights[ + f"{encoder_prefix}/LayerNorm_0/bias" + ][i] + # MLP + formatted[f"{formatted_prefix}/mlp_1/kernel"] = weights[ + f"{encoder_prefix}/MlpBlock_0/Dense_0/kernel" + ][i] + formatted[f"{formatted_prefix}/mlp_1/bias"] = weights[ + f"{encoder_prefix}/MlpBlock_0/Dense_0/bias" + ][i] + formatted[f"{formatted_prefix}/mlp_2/kernel"] = weights[ + f"{encoder_prefix}/MlpBlock_0/Dense_1/kernel" + ][i] + formatted[f"{formatted_prefix}/mlp_2/bias"] = weights[ + f"{encoder_prefix}/MlpBlock_0/Dense_1/bias" + ][i] + # LN 1 + formatted[f"{formatted_prefix}/ln_1/gamma"] = weights[ + f"{encoder_prefix}/LayerNorm_1/scale" + ][i] + formatted[f"{formatted_prefix}/ln_1/beta"] = weights[ + f"{encoder_prefix}/LayerNorm_1/bias" + ][i] + formatted["img/head/kernel"] = weights[f"{prefix}/head/kernel"] + formatted["img/head/bias"] = weights[f"{prefix}/head/bias"] + return formatted + + +def convert_tokenizer(proto_path): + return keras_hub.models.PaliGemmaTokenizer(proto=proto_path) + + +def convert_image_converter(image_size): + return keras_hub.layers.PaliGemmaImageConverter( + image_size=(image_size, image_size), + scale=1.0 / 127.5, + offset=-1, + ) + + +def convert_model(preset): + model_config = { + "vocabulary_size": 257152, + "vit_patch_size": 14, + "vit_num_heads": 16, + "vit_hidden_dim": 1152, + "vit_num_layers": 27, + # Gemma2 + "query_head_dim_normalize": True, + "use_post_ffw_norm": True, + "use_post_attention_norm": True, + "final_logit_soft_cap": 30, + "attention_logit_soft_cap": 50, + "use_sliding_window_attention": True, + "sliding_window_size": 4096, + } + preset = str(preset) + + # 2B, 10B, 28B + if "_3b_" in preset: + model_config.update( + { + "num_layers": 26, + "num_query_heads": 8, + "num_key_value_heads": 4, + "hidden_dim": 2304, + "intermediate_dim": 18432, + "head_dim": 256, + } + ) + elif "_10b_" in preset: + model_config.update( + { + "num_layers": 42, + "num_query_heads": 16, + "num_key_value_heads": 8, + "hidden_dim": 3584, + "intermediate_dim": 28672, + "head_dim": 256, + } + ) + elif "_28b_" in preset: + model_config.update( + { + "num_layers": 46, + "num_query_heads": 32, + "num_key_value_heads": 16, + "hidden_dim": 4608, + "intermediate_dim": 73728, + "head_dim": 128, + "query_head_dim_normalize": False, # Only for 28B + } + ) + + # Image size + image_size = int(preset.split("_")[-1]) + model_config.update({"image_size": image_size}) + + return keras_hub.models.PaliGemmaBackbone(**model_config) + + +def convert_weights(keras_model, weights): + from keras_hub.src.models.pali_gemma.pali_gemma_decoder_block import ( + PaliGemmaDecoderBlock, + ) + from keras_hub.src.models.pali_gemma.pali_gemma_vit import ( + PaliGemmaVitEncoder, + ) + from keras_hub.src.models.pali_gemma.pali_gemma_vit import ( + PaliGemmaVitEncoderBlock, + ) + + if not isinstance(keras_model, keras_hub.models.PaliGemmaBackbone): + raise ValueError( + "`keras_model` must be a `keras_hub.models.PaliGemmaBackbone`. 
" + f"Received: keras_model={keras_model} of type {type(keras_model)}" + ) + + # LLM part + keras_model.token_embedding.embeddings.assign(weights["llm/embedding"]) + for i, layer in enumerate(keras_model.transformer_layers): + if not isinstance(layer, PaliGemmaDecoderBlock): + raise ValueError + prefix = f"llm/decoder_block_{i}" + # RMSNorm + layer.pre_attention_norm.scale.assign( + weights[f"{prefix}/pre_norm/scale"] + ) + layer.post_attention_norm.scale.assign( + weights[f"{prefix}/post_norm/scale"] + ) + layer.pre_ffw_norm.scale.assign(weights[f"{prefix}/pre_ffw_norm/scale"]) + layer.post_ffw_norm.scale.assign( + weights[f"{prefix}/post_ffw_norm/scale"] + ) + # MHA + layer.attention.query_dense.kernel.assign( + weights[f"{prefix}/mha/q/kernel"] + ) + layer.attention.key_dense.kernel.assign( + weights[f"{prefix}/mha/k/kernel"] + ) + layer.attention.value_dense.kernel.assign( + weights[f"{prefix}/mha/v/kernel"] + ) + layer.attention.output_dense.kernel.assign( + weights[f"{prefix}/mha/o/kernel"] + ) + # MLP + layer.gating_ffw.kernel.assign(weights[f"{prefix}/ffw_gating/kernel"]) + layer.gating_ffw_2.kernel.assign( + weights[f"{prefix}/ffw_gating_2/kernel"] + ) + layer.ffw_linear.kernel.assign(weights[f"{prefix}/ffw_linear/kernel"]) + keras_model.layer_norm.scale.assign( + weights["llm/final_normalization/scale"] + ) + + # ViT part + vit_encoder = keras_model.vit_encoder.get_layer("image_encoder") + if not isinstance(vit_encoder, PaliGemmaVitEncoder): + raise ValueError + vit_encoder.encoder_layer_norm.gamma.assign(weights["img/ln/gamma"]) + vit_encoder.encoder_layer_norm.beta.assign(weights["img/ln/beta"]) + vit_encoder.vision_embeddings.patch_embedding.kernel.assign( + weights["img/embedding/kernel"] + ) + vit_encoder.vision_embeddings.patch_embedding.bias.assign( + weights["img/embedding/bias"] + ) + vit_encoder.vision_embeddings.position_embedding.embeddings.assign( + weights["img/embedding/pos"][0] + ) + for i, layer in enumerate(vit_encoder.resblocks): + if not isinstance(layer, PaliGemmaVitEncoderBlock): + raise ValueError + prefix = f"img/encoder_block_{i}" + input_dim = hidden_dim = layer.attn.hidden_dim + # MHA + layer.attn.query_proj.kernel.assign( + np.reshape( + weights[f"{prefix}/mha/q/kernel"], (input_dim, hidden_dim) + ) + ) + layer.attn.query_proj.bias.assign( + np.reshape(weights[f"{prefix}/mha/q/bias"], (-1,)) + ) + layer.attn.key_proj.kernel.assign( + np.reshape( + weights[f"{prefix}/mha/k/kernel"], (input_dim, hidden_dim) + ) + ) + layer.attn.key_proj.bias.assign( + np.reshape(weights[f"{prefix}/mha/k/bias"], (-1,)) + ) + layer.attn.value_proj.kernel.assign( + np.reshape( + weights[f"{prefix}/mha/v/kernel"], (input_dim, hidden_dim) + ) + ) + layer.attn.value_proj.bias.assign( + np.reshape(weights[f"{prefix}/mha/v/bias"], (-1,)) + ) + layer.attn.out_proj.kernel.assign( + np.reshape( + weights[f"{prefix}/mha/o/kernel"], (input_dim, hidden_dim) + ) + ) + layer.attn.out_proj.bias.assign(weights[f"{prefix}/mha/o/bias"]) + # LN 0 + layer.layer_norm_1.gamma.assign(weights[f"{prefix}/ln_0/gamma"]) + layer.layer_norm_1.beta.assign(weights[f"{prefix}/ln_0/beta"]) + # MLP + layer.mlp_dense_1.kernel.assign(weights[f"{prefix}/mlp_1/kernel"]) + layer.mlp_dense_1.bias.assign(weights[f"{prefix}/mlp_1/bias"]) + layer.mlp_dense_2.kernel.assign(weights[f"{prefix}/mlp_2/kernel"]) + layer.mlp_dense_2.bias.assign(weights[f"{prefix}/mlp_2/bias"]) + # LN 1 + layer.layer_norm_2.gamma.assign(weights[f"{prefix}/ln_1/gamma"]) + layer.layer_norm_2.beta.assign(weights[f"{prefix}/ln_1/beta"]) + 
vit_classifier = keras_model.vit_encoder.get_layer("image_classifier")
+    if not isinstance(vit_classifier, keras.layers.Dense):
+        raise ValueError
+    vit_classifier.kernel.assign(weights["img/head/kernel"])
+    vit_classifier.bias.assign(weights["img/head/bias"])
+
+    return keras_model
+
+
+def validate_output(keras_model, keras_tokenizer, keras_image_converter):
+    def read_image(url):
+        contents = io.BytesIO(requests.get(url).content)
+        image = PIL.Image.open(contents)
+        image = np.array(image).astype("float32")
+        # Remove alpha channel if necessary.
+        if image.shape[2] == 4:
+            image = image[:, :, :3]
+        return image
+
+    image = read_image(
+        "https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png"
+    )
+    prompt = "answer en where is the cow standing?\n"
+    max_length = 32
+    preprocessor = keras_hub.models.PaliGemmaCausalLMPreprocessor(
+        tokenizer=keras_tokenizer, image_converter=keras_image_converter
+    )
+    pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM(
+        preprocessor=preprocessor, backbone=keras_model
+    )
+    keras_output = pali_gemma_lm.generate(
+        inputs={"images": image, "prompts": prompt}, max_length=max_length
+    )
+    keras_output = str(keras_output).replace(prompt, "")
+    print("🔶 Prompt:", prompt.replace("\n", ""))
+    print("🔶 KerasHub output:", keras_output)
+
+    # TODO: Verify numerics with JAX model.
+
+
+def main(_):
+    preset = str(FLAGS.preset)
+    print(f"🏃 Converting {preset}")
+
+    # Currently all weights are bfloat16 (and have much faster download times
+    # for it). We follow suit with Keras weights.
+    keras.config.set_floatx("bfloat16")
+
+    if FLAGS.weights_path is not None:
+        weights_path = pathlib.Path(FLAGS.weights_path)
+    else:
+        presets = PRESET_MAP.keys()
+        if preset not in presets:
+            raise ValueError(
+                f"Invalid preset {preset}. Must be one of {list(presets)}"
+            )
+        handle = PRESET_MAP[preset]
+        model_dir = kagglehub.model_download(handle)
+        print("✅ JAX model downloaded from kaggle")
+
+        files = list(pathlib.Path(model_dir).glob("*.npz"))
+        if len(files) != 1:
+            raise ValueError(
+                f"Found too many files in {model_dir}. Expected only one file. "
" + f"Recevied: {files}" + ) + weights_path = files[0] + + weights = np.load(weights_path, allow_pickle=False) + weights = format_weights(weights) + image_size = int(preset.split("_")[-1]) + print("✅ JAX model weights loaded") + + keras_tokenizer = convert_tokenizer(FLAGS.proto_path) + keras_image_converter = convert_image_converter(image_size) + keras_model = convert_model(preset) + print("✅ Keras model loaded") + + convert_weights(keras_model, weights) + del weights + print("✅ Weights converted") + + validate_output(keras_model, keras_tokenizer, keras_image_converter) + print("✅ Output validated") + + keras_model.save_to_preset(preset) + keras_tokenizer.save_to_preset(preset) + keras_image_converter.save_to_preset(preset) + del keras_model + del keras_tokenizer + del keras_image_converter + print(f"🏁 Preset saved to ./{preset}") + + upload_uri = FLAGS.upload_uri + if upload_uri: + keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}") + print(f"🏁 Preset uploaded to {upload_uri}") + + +if __name__ == "__main__": + app.run(main) From 95d6959d0bcb7e117e9453bc4fdb2f583a9d2da7 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 4 Dec 2024 21:50:28 -0800 Subject: [PATCH 5/5] Update pali_gemma_presets.py (#2003) * Update pali_gemma_presets.py * code reformat --- .../models/pali_gemma/pali_gemma_presets.py | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py index 056d949d80..ffcf3ecafd 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py @@ -83,7 +83,7 @@ }, "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_ft_docci_448/1", }, - "pali_gemma2_3b_pt_224": { + "pali_gemma2_pt_3b_224": { "metadata": { "description": ( "3 billion parameter, image size 224, 27-layer for " @@ -96,9 +96,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_pt_224/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_3b_224/1", }, - "pali_gemma2_3b_pt_448": { + "pali_gemma2_pt_3b_448": { "metadata": { "description": ( "3 billion parameter, image size 448, 27-layer for " @@ -111,9 +111,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_pt_448/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_3b_448/1", }, - "pali_gemma2_3b_pt_896": { + "pali_gemma2_pt_3b_896": { "metadata": { "description": ( "3 billion parameter, image size 896, 27-layer for " @@ -126,9 +126,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_pt_896/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_3b_896/1", }, - "pali_gemma2_10b_pt_224": { + "pali_gemma2_pt_10b_224": { "metadata": { "description": ( "10 billion parameter, image size 224, 27-layer for " @@ -141,9 +141,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_pt_224/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_224/1", }, - "pali_gemma2_10b_pt_448": { + "pali_gemma2_pt_10b_448": { "metadata": { "description": ( 
"10 billion parameter, image size 448, 27-layer for " @@ -156,9 +156,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_pt_448/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_448/1", }, - "pali_gemma2_10b_pt_896": { + "pali_gemma2_pt_10b_896": { "metadata": { "description": ( "10 billion parameter, image size 896, 27-layer for " @@ -171,9 +171,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_pt_896/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_896/1", }, - "pali_gemma2_28b_pt_224": { + "pali_gemma2_pt_28b_224": { "metadata": { "description": ( "28 billion parameter, image size 224, 27-layer for " @@ -186,9 +186,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_pt_224/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_224/1", }, - "pali_gemma2_28b_pt_448": { + "pali_gemma2_pt_28b_448": { "metadata": { "description": ( "28 billion parameter, image size 448, 27-layer for " @@ -201,9 +201,9 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_pt_448/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_448/1", }, - "pali_gemma2_28b_pt_896": { + "pali_gemma2_pt_28b_896": { "metadata": { "description": ( "28 billion parameter, image size 896, 27-layer for " @@ -216,6 +216,6 @@ "path": "pali_gemma2", "model_card": "https://www.kaggle.com/models/google/paligemma-2", }, - "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_pt_896/1", + "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_896/1", }, }