Skip to content

Commit

Permalink
matched mobilenetv3 inference, working now
Browse files Browse the repository at this point in the history
  • Loading branch information
pkgoogle committed Jan 17, 2025
1 parent cfe4a4f commit a5a0bb3
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 52 deletions.
24 changes: 0 additions & 24 deletions keras_hub/src/models/image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ def __init__(
activation=None,
dropout=0.0,
head_dtype=None,
include_conv=False,
flatten=False,
**kwargs,
):
head_dtype = head_dtype or backbone.dtype_policy
Expand Down Expand Up @@ -129,20 +127,6 @@ def __init__(
dtype=head_dtype,
name="output_dropout",
)

if include_conv:
self.output_conv = keras.layers.Conv2D(
filters=1024,
kernel_size=(1, 1),
strides=(1, 1),
use_bias=True,
padding='valid',
activation="hardswish",
)

if flatten:
self.flatten = keras.layers.Flatten()

self.output_dense = keras.layers.Dense(
num_classes,
activation=activation,
Expand All @@ -155,10 +139,6 @@ def __init__(
x = self.backbone(inputs)
x = self.pooler(x)
x = self.output_dropout(x)
if include_conv:
x = self.output_conv(x)
if flatten:
x = self.flatten(x)
outputs = self.output_dense(x)
super().__init__(
inputs=inputs,
Expand All @@ -171,8 +151,6 @@ def __init__(
self.activation = activation
self.pooling = pooling
self.dropout = dropout
self.include_conv = include_conv
self.flatten = flatten

def get_config(self):
# Backbone serialized in `super`
Expand All @@ -183,8 +161,6 @@ def get_config(self):
"pooling": self.pooling,
"activation": self.activation,
"dropout": self.dropout,
"include_conv": self.include_conv,
"flatten": self.flatten,
}
)
return config
Expand Down
5 changes: 2 additions & 3 deletions keras_hub/src/models/mobilenet/conv_bn_act_block.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import keras


BN_EPSILON = 1e-5
BN_MOMENTUM = 0.9
BN_AXIS = 3
Expand All @@ -9,8 +8,8 @@
class ConvBnActBlock(keras.layers.Layer):
def __init__(
self,
filter,
activation,
filter,
activation,
name=None,
**kwargs,
):
Expand Down
21 changes: 11 additions & 10 deletions keras_hub/src/models/mobilenet/depthwise_conv_block.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import keras

from keras_hub.src.models.mobilenet.squeeze_and_excite_2d import SqueezeAndExcite2D
from keras_hub.src.models.mobilenet.squeeze_and_excite_2d import (
SqueezeAndExcite2D,
)
from keras_hub.src.models.mobilenet.util import adjust_channels


BN_EPSILON = 1e-5
BN_MOMENTUM = 0.9
BN_AXIS = 3
Expand Down Expand Up @@ -36,10 +37,10 @@ class DepthwiseConvBlock(keras.layers.Layer):
def __init__(
self,
infilters,
filters,
kernel_size=3,
stride=2,
se=None,
filters,
kernel_size=3,
stride=2,
se=None,
name=None,
**kwargs,
):
Expand Down Expand Up @@ -121,10 +122,10 @@ def call(self, inputs):
def get_config(self):
config = {
"infilters": self.infilters,
"filters": self.filters,
"kernel_size": self.kernel_size,
"stride": self.stride,
"se": self.se,
"filters": self.filters,
"kernel_size": self.kernel_size,
"stride": self.stride,
"se": self.se,
"name": self.name,
}

Expand Down
5 changes: 3 additions & 2 deletions keras_hub/src/models/mobilenet/inverted_residual_block.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import keras

from keras_hub.src.models.mobilenet.squeeze_and_excite_2d import SqueezeAndExcite2D
from keras_hub.src.models.mobilenet.squeeze_and_excite_2d import (
SqueezeAndExcite2D,
)
from keras_hub.src.models.mobilenet.util import adjust_channels


BN_EPSILON = 1e-5
BN_MOMENTUM = 0.9
BN_AXIS = 3
Expand Down
16 changes: 11 additions & 5 deletions keras_hub/src/models/mobilenet/mobilenet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.mobilenet.depthwise_conv_block import DepthwiseConvBlock
from keras_hub.src.models.mobilenet.inverted_residual_block import InvertedResidualBlock
from keras_hub.src.models.mobilenet.conv_bn_act_block import ConvBnActBlock
from keras_hub.src.models.mobilenet.depthwise_conv_block import (
DepthwiseConvBlock,
)
from keras_hub.src.models.mobilenet.inverted_residual_block import (
InvertedResidualBlock,
)
from keras_hub.src.models.mobilenet.util import adjust_channels


BN_EPSILON = 1e-5
BN_MOMENTUM = 0.9

Expand Down Expand Up @@ -140,7 +143,7 @@ def __init__(
input_num_filters = adjust_channels(input_num_filters)

x = keras.layers.ZeroPadding2D(
padding=(1,1),
padding=(1, 1),
name="input_pad",
)(x)
x = keras.layers.Conv2D(
Expand All @@ -160,7 +163,10 @@ def __init__(
x = keras.layers.Activation(input_activation)(x)

x = DepthwiseConvBlock(
input_num_filters, depthwise_filters, se=squeeze_and_excite, name="block_0"
input_num_filters,
depthwise_filters,
se=squeeze_and_excite,
name="block_0",
)(x)

for block in range(len(stackwise_num_blocks)):
Expand Down
66 changes: 66 additions & 0 deletions keras_hub/src/models/mobilenet/mobilenet_image_classifier.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.image_classifier import ImageClassifier
from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone
Expand All @@ -10,3 +12,67 @@
class MobileNetImageClassifier(ImageClassifier):
    """Image classification task with a MobileNet-style convolutional head.

    The head applied in `call` is: global average pooling that keeps the
    spatial dimensions -> 1x1 convolution with 1024 filters and `hard_silu`
    activation -> flatten -> dense projection to `num_classes` outputs.

    Args:
        backbone: Backbone model producing the feature map fed to the head.
            Its `dtype_policy` is the fallback for `head_dtype`, and its
            optional `data_format` attribute configures the pooling layer.
        num_classes: int. Number of units of the final dense layer.
        preprocessor: Optional preprocessor stored on the task.
        head_dtype: Optional dtype (or dtype policy) for the head layers.
            Defaults to `backbone.dtype_policy`.
        **kwargs: Forwarded unchanged to `ImageClassifier.__init__`.
    """

    backbone_cls = MobileNetBackbone
    preprocessor_cls = MobileNetImageClassifierPreprocessor

    def __init__(
        self,
        backbone,
        num_classes,
        preprocessor=None,
        head_dtype=None,
        **kwargs,
    ):
        # NOTE(review): the base constructor appears to build its own
        # pooling/dropout/dense head before the layers below replace
        # `pooler` and `output_dense`, and `call` is overridden afterwards.
        # Confirm that the base head is not left behind as unused weights
        # and that two layers named "predictions" do not collide.
        super().__init__(
            backbone,
            num_classes,
            preprocessor=preprocessor,
            head_dtype=head_dtype,
            **kwargs,
        )

        head_dtype = head_dtype or backbone.dtype_policy
        data_format = getattr(backbone, "data_format", None)

        # === Layers ===
        self.backbone = backbone
        self.preprocessor = preprocessor
        # `keepdims=True` keeps a 4D tensor so the 1x1 conv can follow.
        self.pooler = keras.layers.GlobalAveragePooling2D(
            data_format, keepdims=True, dtype=head_dtype, name="pooler"
        )

        # Every head layer uses `head_dtype`, so the conv and flatten steps
        # run under the same dtype policy as the pooler and the final dense
        # layer instead of silently falling back to the global policy.
        self.output_conv = keras.layers.Conv2D(
            filters=1024,
            kernel_size=(1, 1),
            strides=(1, 1),
            use_bias=True,
            padding="valid",
            activation="hard_silu",
            dtype=head_dtype,
        )

        self.flatten = keras.layers.Flatten(dtype=head_dtype)

        self.output_dense = keras.layers.Dense(
            num_classes,
            dtype=head_dtype,
            name="predictions",
        )

        # === Config ===
        self.num_classes = num_classes

    def get_config(self):
        """Return the task config, including `num_classes`."""
        # Backbone serialized in `super`
        config = super().get_config()
        config.update(
            {
                "num_classes": self.num_classes,
            }
        )
        return config

    def call(self, inputs):
        """Run the backbone followed by the pool -> conv -> dense head."""
        x = self.backbone(inputs)
        x = self.pooler(x)
        x = self.output_conv(x)
        x = self.flatten(x)
        x = self.output_dense(x)
        return x
2 changes: 1 addition & 1 deletion keras_hub/src/models/mobilenet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ def adjust_channels(x, divisor=8, min_value=None):
# make sure that round down does not go down by more than 10%.
if new_x < 0.9 * x:
new_x += divisor
return new_x
return new_x
32 changes: 27 additions & 5 deletions keras_hub/src/utils/timm/convert_mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,14 @@ def port_conv2d(keras_layer, hf_weight_prefix, port_bias=False):
hf_weight_key=f"{hf_weight_prefix}.weight",
hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
)

if port_bias:
print(f"porting bias {hf_weight_prefix} -> {keras_layer}")
loader.port_weight(
keras_layer.bias,
hf_weight_key=f"{hf_weight_prefix}.bias",
)


def port_batch_normalization(keras_layer, hf_weight_prefix):
print(f"porting weights {hf_weight_prefix} -> {keras_layer}")
loader.port_weight(
Expand Down Expand Up @@ -163,11 +162,34 @@ def port_batch_normalization(keras_layer, hf_weight_prefix):
cba_block_name = f"block_{num_stacks+1}_0"
cba_block = backbone.get_layer(cba_block_name)
port_conv2d(cba_block.conv, f"blocks.{num_stacks+1}.0.conv")
port_batch_normalization(
cba_block.bn, f"blocks.{num_stacks+1}.0.bn1"
)
port_batch_normalization(cba_block.bn, f"blocks.{num_stacks+1}.0.bn1")


def convert_head(task, loader, timm_config):
def port_conv2d(keras_layer, hf_weight_prefix, port_bias=False):
    """Copy a convolution's kernel (and optionally bias) into `keras_layer`.

    Args:
        keras_layer: Layer exposing `kernel` (and `bias` when `port_bias`
            is true) variables to be filled.
        hf_weight_prefix: Checkpoint key prefix; `.weight` and `.bias` are
            appended to form the source keys.
        port_bias: Whether to also port the bias. Defaults to `False`.
    """

    def _reorder_kernel_axes(weights, _target_shape):
        # Axis permutation (2, 3, 1, 0); presumably converts an
        # (out, in, h, w) kernel to (h, w, in, out) -- confirm.
        return np.transpose(weights, (2, 3, 1, 0))

    print(f"porting weights {hf_weight_prefix} -> {keras_layer}")
    kernel_key = f"{hf_weight_prefix}.weight"
    loader.port_weight(
        keras_layer.kernel,
        hf_weight_key=kernel_key,
        hook_fn=_reorder_kernel_axes,
    )

    if not port_bias:
        return

    print(f"porting bias {hf_weight_prefix} -> {keras_layer}")
    bias_key = f"{hf_weight_prefix}.bias"
    loader.port_weight(keras_layer.bias, hf_weight_key=bias_key)

data_format = getattr(task.backbone, "data_format", None)
if not data_format or data_format == "channels_last":
conv_head_input_shape = (None, 1, 1, task.backbone.output_shape[-1])
else:
conv_head_input_shape = (None, task.backbone.output_shape[1], 1, 1)
task.output_conv.build(input_shape=conv_head_input_shape)
task.output_dense.build(input_shape=(None, 1024))

port_conv2d(task.output_conv, "conv_head", True)
prefix = "classifier."
loader.port_weight(
task.output_dense.kernel,
Expand Down
3 changes: 1 addition & 2 deletions tools/checkpoint_conversion/convert_mobilenet_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def validate_output(keras_model, timm_model):
preprocessing_diff = np.mean(np.abs(keras_preprocessed - timm_preprocessed))
print("🔶 Preprocessing difference:", preprocessing_diff)


def main(_):
preset = FLAGS.preset
if os.path.exists(preset):
Expand All @@ -94,8 +95,6 @@ def main(_):
print("✅ Loaded KerasHub model.")
keras_model = keras_hub.models.ImageClassifier.from_preset(
"hf://" + timm_name,
include_conv=True,
flatten=True,
)

keras_model.save_to_preset(f"./{preset}")
Expand Down

0 comments on commit a5a0bb3

Please sign in to comment.