Add a VGG image converter and classifier preprocessor, rebuild the VGG
classifier head as a convolutional `keras.Sequential`, run every VGG stack in
the backbone, and add timm -> KerasHub checkpoint conversion for VGG.

Review notes for this revision of the patch:
* Line structure is restored (one patch line per line); hunk bodies match the
  line counts in their `@@` headers.
* convert_vgg.py: `timm_config` is annotated `dict[str, Any]` instead of
  `dict[Any]`, which has the wrong arity for a mapping type.
* convert_vgg_checkpoints.py: `PIL.Image` is imported explicitly, since
  `import PIL` alone does not load the `Image` submodule; the module docstring
  now names Kaggle as the upload target, matching the example URI.
* NOTE(review): the `index` blob ids of the two new files were not
  regenerated after these edits -- refresh them if anything relies on them.

diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py
index 0fe7b300fa..8471f4a2a2 100644
--- a/keras_hub/api/layers/__init__.py
+++ b/keras_hub/api/layers/__init__.py
@@ -49,6 +49,7 @@
 from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
 from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
 from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
+from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageConverter
 from keras_hub.src.models.whisper.whisper_audio_converter import (
     WhisperAudioConverter,
 )
diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index 1450ddceb3..fd0c2e461f 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -288,6 +288,9 @@
 from keras_hub.src.models.text_to_image import TextToImage
 from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
 from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
+from keras_hub.src.models.vgg.vgg_image_classifier import (
+    VGGImageClassifierPreprocessor,
+)
 from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
 from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
 from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
diff --git a/keras_hub/src/models/vgg/__init__.py b/keras_hub/src/models/vgg/__init__.py
index e69de29bb2..076cf83025 100644
--- a/keras_hub/src/models/vgg/__init__.py
+++ b/keras_hub/src/models/vgg/__init__.py
@@ -0,0 +1 @@
+from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
diff --git a/keras_hub/src/models/vgg/vgg_backbone.py b/keras_hub/src/models/vgg/vgg_backbone.py
index cf2638146e..504624a6c4 100644
--- a/keras_hub/src/models/vgg/vgg_backbone.py
+++ b/keras_hub/src/models/vgg/vgg_backbone.py
@@ -47,12 +47,11 @@ def __init__(
         image_shape=(None, None, 3),
         **kwargs,
     ):
-
         # === Functional Model ===
         img_input = keras.layers.Input(shape=image_shape)
         x = img_input
 
-        for stack_index in range(len(stackwise_num_repeats) - 1):
+        for stack_index in range(len(stackwise_num_repeats)):
             x = apply_vgg_block(
                 x=x,
                 num_layers=stackwise_num_repeats[stack_index],
diff --git a/keras_hub/src/models/vgg/vgg_backbone_test.py b/keras_hub/src/models/vgg/vgg_backbone_test.py
index 87e9ed6ef5..19dd7844da 100644
--- a/keras_hub/src/models/vgg/vgg_backbone_test.py
+++ b/keras_hub/src/models/vgg/vgg_backbone_test.py
@@ -19,7 +19,7 @@ def test_backbone_basics(self):
             cls=VGGBackbone,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
-            expected_output_shape=(2, 4, 4, 64),
+            expected_output_shape=(2, 2, 2, 64),
             run_mixed_precision_check=False,
         )
 
diff --git a/keras_hub/src/models/vgg/vgg_image_classifier.py b/keras_hub/src/models/vgg/vgg_image_classifier.py
index 4d02f1ca5f..f570d05b94 100644
--- a/keras_hub/src/models/vgg/vgg_image_classifier.py
+++ b/keras_hub/src/models/vgg/vgg_image_classifier.py
@@ -1,11 +1,26 @@
 import keras
 
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
 from keras_hub.src.models.image_classifier import ImageClassifier
+from keras_hub.src.models.image_classifier_preprocessor import (
+    ImageClassifierPreprocessor,
+)
 from keras_hub.src.models.task import Task
 from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
 
 
+@keras_hub_export("keras_hub.layers.VGGImageConverter")
+class VGGImageConverter(ImageConverter):
+    backbone_cls = VGGBackbone
+
+
+@keras_hub_export("keras_hub.models.VGGImageClassifierPreprocessor")
+class VGGImageClassifierPreprocessor(ImageClassifierPreprocessor):
+    backbone_cls = VGGBackbone
+    image_converter_cls = VGGImageConverter
+
+
 @keras_hub_export("keras_hub.models.VGGImageClassifier")
 class VGGImageClassifier(ImageClassifier):
     """VGG image classification task.
@@ -96,13 +111,14 @@ class VGGImageClassifier(ImageClassifier):
     """
 
     backbone_cls = VGGBackbone
+    preprocessor_cls = VGGImageClassifierPreprocessor
 
     def __init__(
         self,
         backbone,
         num_classes,
         preprocessor=None,
-        pooling="flatten",
+        pooling="avg",
         pooling_hidden_dim=4096,
         activation=None,
         dropout=0.0,
@@ -141,24 +157,46 @@ def __init__(
                 "Unknown `pooling` type. Polling should be either `'avg'` or "
                 f"`'max'`. Received: pooling={pooling}."
             )
-        self.output_dropout = keras.layers.Dropout(
-            dropout,
-            dtype=head_dtype,
-            name="output_dropout",
-        )
-        self.output_dense = keras.layers.Dense(
-            num_classes,
-            activation=activation,
-            dtype=head_dtype,
-            name="predictions",
+
+        self.head = keras.Sequential(
+            [
+                keras.layers.Conv2D(
+                    filters=4096,
+                    kernel_size=7,
+                    name="fc1",
+                    activation=activation,
+                    use_bias=True,
+                    padding="same",
+                ),
+                keras.layers.Dropout(
+                    rate=dropout,
+                    dtype=head_dtype,
+                    name="output_dropout",
+                ),
+                keras.layers.Conv2D(
+                    filters=4096,
+                    kernel_size=1,
+                    name="fc2",
+                    activation=activation,
+                    use_bias=True,
+                    padding="same",
+                ),
+                self.pooler,
+                keras.layers.Dense(
+                    num_classes,
+                    activation=activation,
+                    dtype=head_dtype,
+                    name="predictions",
+                ),
+            ],
+            name="head",
         )
 
         # === Functional Model ===
         inputs = self.backbone.input
         x = self.backbone(inputs)
-        x = self.pooler(x)
-        x = self.output_dropout(x)
-        outputs = self.output_dense(x)
+        outputs = self.head(x)
+
         # Skip the parent class functional model.
         Task.__init__(
             self,
diff --git a/keras_hub/src/utils/timm/convert_vgg.py b/keras_hub/src/utils/timm/convert_vgg.py
new file mode 100644
index 0000000000..445d0ee436
--- /dev/null
+++ b/keras_hub/src/utils/timm/convert_vgg.py
@@ -0,0 +1,85 @@
+from typing import Any
+
+import numpy as np
+
+from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
+from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
+
+backbone_cls = VGGBackbone
+
+
+REPEATS_BY_SIZE = {
+    "vgg11": [1, 1, 2, 2, 2],
+    "vgg13": [2, 2, 2, 2, 2],
+    "vgg16": [2, 2, 3, 3, 3],
+    "vgg19": [2, 2, 4, 4, 4],
+}
+
+
+def convert_backbone_config(timm_config):
+    architecture = timm_config["architecture"]
+    stackwise_num_repeats = REPEATS_BY_SIZE[architecture]
+    return dict(
+        stackwise_num_repeats=stackwise_num_repeats,
+        stackwise_num_filters=[64, 128, 256, 512, 512],
+    )
+
+
+def convert_conv2d(
+    model,
+    loader,
+    keras_layer_name: str,
+    hf_layer_name: str,
+):
+    loader.port_weight(
+        model.get_layer(keras_layer_name).kernel,
+        hf_weight_key=f"{hf_layer_name}.weight",
+        hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+    )
+    loader.port_weight(
+        model.get_layer(keras_layer_name).bias,
+        hf_weight_key=f"{hf_layer_name}.bias",
+    )
+
+
+def convert_weights(
+    backbone: VGGBackbone,
+    loader,
+    timm_config: dict[str, Any],
+):
+    architecture = timm_config["architecture"]
+    stackwise_num_repeats = REPEATS_BY_SIZE[architecture]
+
+    hf_index_to_keras_layer_name = {}
+    layer_index = 0
+    for block_index, repeats_in_block in enumerate(stackwise_num_repeats):
+        for repeat_index in range(repeats_in_block):
+            hf_index = layer_index
+            layer_index += 2  # Conv + activation layers.
+            layer_name = f"block{block_index + 1}_conv{repeat_index + 1}"
+            hf_index_to_keras_layer_name[hf_index] = layer_name
+        layer_index += 1  # Pooling layer after blocks.
+
+    for hf_index, keras_layer_name in hf_index_to_keras_layer_name.items():
+        convert_conv2d(
+            backbone, loader, keras_layer_name, f"features.{hf_index}"
+        )
+
+
+def convert_head(
+    task: VGGImageClassifier,
+    loader,
+    timm_config: dict[str, Any],
+):
+    convert_conv2d(task.head, loader, "fc1", "pre_logits.fc1")
+    convert_conv2d(task.head, loader, "fc2", "pre_logits.fc2")
+
+    loader.port_weight(
+        task.head.get_layer("predictions").kernel,
+        hf_weight_key="head.fc.weight",
+        hook_fn=lambda x, _: np.transpose(np.squeeze(x)),
+    )
+    loader.port_weight(
+        task.head.get_layer("predictions").bias,
+        hf_weight_key="head.fc.bias",
+    )
diff --git a/keras_hub/src/utils/timm/preset_loader.py b/keras_hub/src/utils/timm/preset_loader.py
index e5b72333e0..1524db8530 100644
--- a/keras_hub/src/utils/timm/preset_loader.py
+++ b/keras_hub/src/utils/timm/preset_loader.py
@@ -5,6 +5,7 @@
 from keras_hub.src.utils.preset_utils import jax_memory_cleanup
 from keras_hub.src.utils.timm import convert_densenet
 from keras_hub.src.utils.timm import convert_resnet
+from keras_hub.src.utils.timm import convert_vgg
 from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
 
 
@@ -16,6 +17,8 @@ def __init__(self, preset, config):
             self.converter = convert_resnet
         elif "densenet" in architecture:
             self.converter = convert_densenet
+        elif "vgg" in architecture:
+            self.converter = convert_vgg
         else:
             raise ValueError(
                 "KerasHub has no converter for timm models "
diff --git a/tools/checkpoint_conversion/convert_vgg_checkpoints.py b/tools/checkpoint_conversion/convert_vgg_checkpoints.py
new file mode 100644
index 0000000000..fea9aaf01f
--- /dev/null
+++ b/tools/checkpoint_conversion/convert_vgg_checkpoints.py
@@ -0,0 +1,116 @@
+"""Loads an external VGG model and saves it in Keras format.
+
+Optionally uploads the preset to Kaggle if the `--upload_uri` flag is passed.
+
+python tools/checkpoint_conversion/convert_vgg_checkpoints.py \
+    --preset vgg11 --upload_uri kaggle://kerashub/vgg/keras/vgg11
+"""
+
+import os
+import shutil
+
+import keras
+import numpy as np
+import PIL.Image
+import timm
+import torch
+from absl import app
+from absl import flags
+
+import keras_hub
+
+PRESET_MAP = {
+    "vgg11": "timm/vgg11.tv_in1k",
+    "vgg13": "timm/vgg13.tv_in1k",
+    "vgg16": "timm/vgg16.tv_in1k",
+    "vgg19": "timm/vgg19.tv_in1k",
+    # TODO(jeffcarp): Add BN variants.
+}
+
+
+PRESET = flags.DEFINE_string(
+    "preset",
+    None,
+    "Must be a valid `VGG` preset from KerasHub",
+    required=True,
+)
+UPLOAD_URI = flags.DEFINE_string(
+    "upload_uri",
+    None,
+    'Could be "kaggle://keras/{variant}/keras/{preset}_int8"',
+)
+
+
+def validate_output(keras_model, timm_model):
+    file = keras.utils.get_file(
+        origin=(
+            "https://storage.googleapis.com/keras-cv/"
+            "models/paligemma/cow_beach_1.png"
+        )
+    )
+    image = PIL.Image.open(file)
+    batch = np.array([image])
+
+    # Preprocess with Timm.
+    data_config = timm.data.resolve_model_data_config(timm_model)
+    data_config["crop_pct"] = 1.0  # Stop timm from cropping.
+    transforms = timm.data.create_transform(**data_config, is_training=False)
+    timm_preprocessed = transforms(image)
+    timm_preprocessed = keras.ops.transpose(timm_preprocessed, axes=(1, 2, 0))
+    timm_preprocessed = keras.ops.expand_dims(timm_preprocessed, 0)
+
+    # Preprocess with Keras.
+    keras_preprocessed = keras_model.preprocessor(batch)
+
+    # Call with Timm. Use the keras preprocessed image so we can keep modeling
+    # and preprocessing comparisons independent.
+    timm_batch = keras.ops.transpose(keras_preprocessed, axes=(0, 3, 1, 2))
+    timm_batch = torch.from_numpy(np.array(timm_batch))
+    timm_outputs = timm_model(timm_batch).detach().numpy()
+    timm_label = np.argmax(timm_outputs[0])
+
+    # Call with Keras.
+    keras_outputs = keras_model.predict(batch)
+    keras_label = np.argmax(keras_outputs[0])
+
+    print("🔶 Keras output:", keras_outputs[0, :10])
+    print("🔶 TIMM output:", timm_outputs[0, :10])
+    print("🔶 Keras label:", keras_label)
+    print("🔶 TIMM label:", timm_label)
+    modeling_diff = np.mean(np.abs(keras_outputs - timm_outputs))
+    print("🔶 Modeling difference:", modeling_diff)
+    preprocessing_diff = np.mean(np.abs(keras_preprocessed - timm_preprocessed))
+    print("🔶 Preprocessing difference:", preprocessing_diff)
+
+
+def main(_):
+    preset = PRESET.value
+    if os.path.exists(preset):
+        shutil.rmtree(preset)
+    os.makedirs(preset)
+
+    timm_name = PRESET_MAP[preset]
+
+    timm_model = timm.create_model(timm_name, pretrained=True)
+    timm_model = timm_model.eval()
+    print("✅ Loaded TIMM model.")
+    print(timm_model)
+
+    keras_model = keras_hub.models.ImageClassifier.from_preset(
+        "hf://" + timm_name,
+    )
+    print("✅ Loaded KerasHub model.")
+
+    keras_model.save_to_preset(f"./{preset}")
+    print(f"🏁 Preset saved to ./{preset}")
+
+    validate_output(keras_model, timm_model)
+
+    upload_uri = UPLOAD_URI.value
+    if upload_uri:
+        keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}")
+        print(f"🏁 Preset uploaded to {upload_uri}")
+
+
+if __name__ == "__main__":
+    app.run(main)