Skip to content

Commit

Permalink
Update VGG model to be compatible with HF and add conversion scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffcarp committed Oct 10, 2024
1 parent f25c8ff commit 335f7c5
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 17 deletions.
1 change: 1 addition & 0 deletions keras_hub/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageConverter
from keras_hub.src.models.whisper.whisper_audio_converter import (
WhisperAudioConverter,
)
3 changes: 3 additions & 0 deletions keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@
from keras_hub.src.models.text_to_image import TextToImage
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
from keras_hub.src.models.vgg.vgg_image_classifier import (
VGGImageClassifierPreprocessor,
)
from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
Expand Down
1 change: 1 addition & 0 deletions keras_hub/src/models/vgg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
3 changes: 1 addition & 2 deletions keras_hub/src/models/vgg/vgg_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,11 @@ def __init__(
image_shape=(None, None, 3),
**kwargs,
):

# === Functional Model ===
img_input = keras.layers.Input(shape=image_shape)
x = img_input

for stack_index in range(len(stackwise_num_repeats) - 1):
for stack_index in range(len(stackwise_num_repeats)):
x = apply_vgg_block(
x=x,
num_layers=stackwise_num_repeats[stack_index],
Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/vgg/vgg_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_backbone_basics(self):
cls=VGGBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 4, 4, 64),
expected_output_shape=(2, 2, 2, 64),
run_mixed_precision_check=False,
)

Expand Down
66 changes: 52 additions & 14 deletions keras_hub/src/models/vgg/vgg_image_classifier.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.image_classifier import ImageClassifier
from keras_hub.src.models.image_classifier_preprocessor import (
ImageClassifierPreprocessor,
)
from keras_hub.src.models.task import Task
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone


@keras_hub_export("keras_hub.layers.VGGImageConverter")
class VGGImageConverter(ImageConverter):
    """Image converter paired with the VGG backbone.

    Adds no behavior of its own; it only binds `ImageConverter` to
    `VGGBackbone` through the `backbone_cls` class attribute.
    """

    # Backbone class this converter is associated with.
    backbone_cls = VGGBackbone


@keras_hub_export("keras_hub.models.VGGImageClassifierPreprocessor")
class VGGImageClassifierPreprocessor(ImageClassifierPreprocessor):
    """Image classification preprocessor paired with the VGG backbone.

    Adds no behavior of its own; it only binds `ImageClassifierPreprocessor`
    to `VGGBackbone` and `VGGImageConverter` through class attributes.
    """

    # Backbone class this preprocessor is associated with.
    backbone_cls = VGGBackbone
    # Image converter class used by this preprocessor.
    image_converter_cls = VGGImageConverter


@keras_hub_export("keras_hub.models.VGGImageClassifier")
class VGGImageClassifier(ImageClassifier):
"""VGG image classification task.
Expand Down Expand Up @@ -96,13 +111,14 @@ class VGGImageClassifier(ImageClassifier):
"""

backbone_cls = VGGBackbone
preprocessor_cls = VGGImageClassifierPreprocessor

def __init__(
self,
backbone,
num_classes,
preprocessor=None,
pooling="flatten",
pooling="avg",
pooling_hidden_dim=4096,
activation=None,
dropout=0.0,
Expand Down Expand Up @@ -141,24 +157,46 @@ def __init__(
"Unknown `pooling` type. Polling should be either `'avg'` or "
f"`'max'`. Received: pooling={pooling}."
)
self.output_dropout = keras.layers.Dropout(
dropout,
dtype=head_dtype,
name="output_dropout",
)
self.output_dense = keras.layers.Dense(
num_classes,
activation=activation,
dtype=head_dtype,
name="predictions",

self.head = keras.Sequential(
[
keras.layers.Conv2D(
filters=4096,
kernel_size=7,
name="fc1",
activation=activation,
use_bias=True,
padding="same",
),
keras.layers.Dropout(
rate=dropout,
dtype=head_dtype,
name="output_dropout",
),
keras.layers.Conv2D(
filters=4096,
kernel_size=1,
name="fc2",
activation=activation,
use_bias=True,
padding="same",
),
self.pooler,
keras.layers.Dense(
num_classes,
activation=activation,
dtype=head_dtype,
name="predictions",
),
],
name="head",
)

# === Functional Model ===
inputs = self.backbone.input
x = self.backbone(inputs)
x = self.pooler(x)
x = self.output_dropout(x)
outputs = self.output_dense(x)
outputs = self.head(x)

# Skip the parent class functional model.
Task.__init__(
self,
Expand Down
85 changes: 85 additions & 0 deletions keras_hub/src/utils/timm/convert_vgg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import Any

import numpy as np

from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier

# Backbone class this converter module targets.
# NOTE(review): presumably read by the timm preset loader — confirm.
backbone_cls = VGGBackbone


# Per-block repeat counts for each supported timm `architecture` name. Each
# list is passed to the backbone as `stackwise_num_repeats` and is walked
# block by block in `convert_weights`.
REPEATS_BY_SIZE = {
    "vgg11": [1, 1, 2, 2, 2],
    "vgg13": [2, 2, 2, 2, 2],
    "vgg16": [2, 2, 3, 3, 3],
    "vgg19": [2, 2, 4, 4, 4],
}


def convert_backbone_config(timm_config):
    """Translate a timm VGG config into backbone constructor arguments.

    Args:
        timm_config: Mapping with an `"architecture"` entry naming the VGG
            variant; it must be a key of `REPEATS_BY_SIZE`.

    Returns:
        A dict with `stackwise_num_repeats` (looked up per architecture) and
        `stackwise_num_filters` (the same five filter counts for every
        variant).

    Raises:
        KeyError: If the architecture is not present in `REPEATS_BY_SIZE`.
    """
    repeats = REPEATS_BY_SIZE[timm_config["architecture"]]
    return {
        "stackwise_num_repeats": repeats,
        "stackwise_num_filters": [64, 128, 256, 512, 512],
    }


def convert_conv2d(
    model,
    loader,
    keras_layer_name: str,
    hf_layer_name: str,
):
    """Port one convolution's kernel and bias from a checkpoint.

    Args:
        model: Object exposing `get_layer(name)`; the returned layer must
            have `kernel` and `bias` attributes.
        loader: Object exposing `port_weight(variable, hf_weight_key=...,
            hook_fn=...)`.
        keras_layer_name: Name of the destination layer inside `model`.
        hf_layer_name: Checkpoint key prefix; `.weight` and `.bias` are
            appended to it.
    """

    def to_keras_kernel(weights, _):
        # Reorder the kernel axes with the permutation (2, 3, 1, 0).
        # NOTE(review): presumably torch (out, in, h, w) -> Keras
        # (h, w, in, out) — confirm.
        return np.transpose(weights, (2, 3, 1, 0))

    conv = model.get_layer(keras_layer_name)
    loader.port_weight(
        conv.kernel,
        hf_weight_key=f"{hf_layer_name}.weight",
        hook_fn=to_keras_kernel,
    )
    # The bias is ported as-is, so no hook is supplied.
    loader.port_weight(
        conv.bias,
        hf_weight_key=f"{hf_layer_name}.bias",
    )


def convert_weights(
    backbone: VGGBackbone,
    loader,
    timm_config: dict[str, Any],
):
    """Port every backbone convolution from a timm VGG checkpoint.

    Walks the blocks described by `REPEATS_BY_SIZE` and, for each
    convolution, copies `features.<index>` from the checkpoint into the
    Keras layer named `block<B>_conv<R>` (both one-based).

    Args:
        backbone: Destination model; must contain the `block<B>_conv<R>`
            layers.
        loader: Weight loader forwarded to `convert_conv2d`.
        timm_config: Mapping with an `"architecture"` entry that is a key of
            `REPEATS_BY_SIZE`.

    Raises:
        KeyError: If the architecture is not present in `REPEATS_BY_SIZE`.
    """
    stackwise_num_repeats = REPEATS_BY_SIZE[timm_config["architecture"]]

    # Running position inside the checkpoint's `features` sequence.
    hf_index = 0
    for block_index, repeats_in_block in enumerate(stackwise_num_repeats):
        for repeat_index in range(repeats_in_block):
            convert_conv2d(
                backbone,
                loader,
                f"block{block_index + 1}_conv{repeat_index + 1}",
                f"features.{hf_index}",
            )
            hf_index += 2  # Conv + activation layers.
        hf_index += 1  # Pooling layer after blocks.


def convert_head(
    task: VGGImageClassifier,
    loader,
    timm_config: dict[str, Any],
):
    """Port the classification head from a timm VGG checkpoint.

    Copies `pre_logits.fc1` and `pre_logits.fc2` into the `fc1` and `fc2`
    convolutions of `task.head`, then `head.fc` into its `predictions`
    layer.

    Args:
        task: Classifier whose `head` exposes `get_layer` for the `fc1`,
            `fc2` and `predictions` layers.
        loader: Weight loader exposing `port_weight(...)`.
        timm_config: Not read by this function.
            NOTE(review): presumably kept so the signature matches the other
            converter entry points — confirm.
    """
    convert_conv2d(task.head, loader, "fc1", "pre_logits.fc1")
    convert_conv2d(task.head, loader, "fc2", "pre_logits.fc2")

    predictions = task.head.get_layer("predictions")
    loader.port_weight(
        predictions.kernel,
        hf_weight_key="head.fc.weight",
        # Drop singleton axes, then reverse the order of the remaining axes.
        hook_fn=lambda x, _: np.transpose(np.squeeze(x)),
    )
    loader.port_weight(
        predictions.bias,
        hf_weight_key="head.fc.bias",
    )
3 changes: 3 additions & 0 deletions keras_hub/src/utils/timm/preset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from keras_hub.src.utils.preset_utils import jax_memory_cleanup
from keras_hub.src.utils.timm import convert_densenet
from keras_hub.src.utils.timm import convert_resnet
from keras_hub.src.utils.timm import convert_vgg
from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader


Expand All @@ -16,6 +17,8 @@ def __init__(self, preset, config):
self.converter = convert_resnet
elif "densenet" in architecture:
self.converter = convert_densenet
elif "vgg" in architecture:
self.converter = convert_vgg
else:
raise ValueError(
"KerasHub has no converter for timm models "
Expand Down
116 changes: 116 additions & 0 deletions tools/checkpoint_conversion/convert_vgg_checkpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Loads an external VGG model and saves it in Keras format.
Optionally uploads the converted preset to the URI passed via `--upload_uri`.
python tools/checkpoint_conversion/convert_vgg_checkpoints.py \
--preset vgg11 --upload_uri kaggle://kerashub/vgg/keras/vgg11
"""

import os
import shutil

import keras
import numpy as np
import PIL
import timm
import torch
from absl import app
from absl import flags

import keras_hub

# Maps each supported `--preset` value to the timm checkpoint it is converted
# from. `main` passes the value to `timm.create_model` and, prefixed with
# "hf://", to `ImageClassifier.from_preset`.
PRESET_MAP = {
    "vgg11": "timm/vgg11.tv_in1k",
    "vgg13": "timm/vgg13.tv_in1k",
    "vgg16": "timm/vgg16.tv_in1k",
    "vgg19": "timm/vgg19.tv_in1k",
    # TODO(jeffcarp): Add BN variants.
}


# Required: name of the preset to convert; `main` looks it up in `PRESET_MAP`.
PRESET = flags.DEFINE_string(
    "preset",
    None,
    "Must be a valid `VGG` preset from KerasHub",
    required=True,
)
# Optional: when set, `main` uploads the saved preset to this URI. The help
# example mirrors the invocation shown in the module docstring rather than
# the `_int8` URI of a quantization script.
UPLOAD_URI = flags.DEFINE_string(
    "upload_uri",
    None,
    'Could be "kaggle://kerashub/vgg/keras/{preset}"',
)


def validate_output(keras_model, timm_model):
    """Print a side-by-side comparison of the Keras and timm models.

    Fetches a sample image, preprocesses it with both timm's transforms and
    the Keras preprocessor, runs both models, and prints the first ten
    outputs, the argmax labels, and the mean absolute differences for
    modeling and preprocessing. Nothing is asserted or returned; the numbers
    are meant to be inspected by the person running the script.

    Args:
        keras_model: Converted KerasHub task; its `preprocessor` and
            `predict` are used.
        timm_model: Reference timm model, called directly on a torch tensor.
    """
    # `import PIL` alone does not load the `PIL.Image` submodule, so import
    # it explicitly instead of relying on another package having done so.
    from PIL import Image

    file = keras.utils.get_file(
        origin=(
            "https://storage.googleapis.com/keras-cv/"
            "models/paligemma/cow_beach_1.png"
        )
    )
    image = Image.open(file)
    batch = np.array([image])

    # Preprocess with Timm.
    data_config = timm.data.resolve_model_data_config(timm_model)
    data_config["crop_pct"] = 1.0  # Stop timm from cropping.
    transforms = timm.data.create_transform(**data_config, is_training=False)
    timm_preprocessed = transforms(image)
    # Move the first axis last and add a leading batch axis so the result can
    # be compared with the Keras-preprocessed batch.
    timm_preprocessed = keras.ops.transpose(timm_preprocessed, axes=(1, 2, 0))
    timm_preprocessed = keras.ops.expand_dims(timm_preprocessed, 0)

    # Preprocess with Keras.
    keras_preprocessed = keras_model.preprocessor(batch)

    # Call with Timm. Use the keras preprocessed image so we can keep modeling
    # and preprocessing comparisons independent.
    timm_batch = keras.ops.transpose(keras_preprocessed, axes=(0, 3, 1, 2))
    timm_batch = torch.from_numpy(np.array(timm_batch))
    timm_outputs = timm_model(timm_batch).detach().numpy()
    timm_label = np.argmax(timm_outputs[0])

    # Call with Keras.
    keras_outputs = keras_model.predict(batch)
    keras_label = np.argmax(keras_outputs[0])

    print("🔶 Keras output:", keras_outputs[0, :10])
    print("🔶 TIMM output:", timm_outputs[0, :10])
    print("🔶 Keras label:", keras_label)
    print("🔶 TIMM label:", timm_label)
    modeling_diff = np.mean(np.abs(keras_outputs - timm_outputs))
    print("🔶 Modeling difference:", modeling_diff)
    preprocessing_diff = np.mean(np.abs(keras_preprocessed - timm_preprocessed))
    print("🔶 Preprocessing difference:", preprocessing_diff)


def main(_):
    """Convert one timm VGG checkpoint into a KerasHub preset.

    Loads the timm model and the KerasHub model for the requested preset,
    saves the KerasHub model to `./<preset>`, prints a numerical comparison
    via `validate_output`, and uploads the preset when `--upload_uri` is set.

    Args:
        _: Positional arguments supplied by `app.run`; unused.

    Raises:
        ValueError: If `--preset` is not a key of `PRESET_MAP`.
    """
    preset = PRESET.value
    # Resolve the preset before touching the filesystem so that a mistyped
    # `--preset` naming an existing path can never get that path deleted.
    if preset not in PRESET_MAP:
        raise ValueError(
            f"Unknown preset {preset!r}. "
            f"Valid presets are: {sorted(PRESET_MAP)}."
        )
    timm_name = PRESET_MAP[preset]

    # Start from an empty output directory.
    if os.path.exists(preset):
        shutil.rmtree(preset)
    os.makedirs(preset)

    timm_model = timm.create_model(timm_name, pretrained=True)
    timm_model = timm_model.eval()
    print("✅ Loaded TIMM model.")
    print(timm_model)

    keras_model = keras_hub.models.ImageClassifier.from_preset(
        "hf://" + timm_name,
    )
    print("✅ Loaded KerasHub model.")

    keras_model.save_to_preset(f"./{preset}")
    print(f"🏁 Preset saved to ./{preset}")

    validate_output(keras_model, timm_model)

    upload_uri = UPLOAD_URI.value
    if upload_uri:
        keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}")
        print(f"🏁 Preset uploaded to {upload_uri}")


# Script entry point: hand control to absl, which invokes `main`.
if __name__ == "__main__":
    app.run(main)

0 comments on commit 335f7c5

Please sign in to comment.