From 0dd12dd97e31bba0fa160420317aca6c59b3a12d Mon Sep 17 00:00:00 2001 From: Usha Rengaraju <34335028+ushareng@users.noreply.github.com> Date: Tue, 1 Oct 2024 03:41:29 +0530 Subject: [PATCH 01/21] kaggle weights --- .../src/models/mobilenet/mobilenet_presets.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 keras_hub/src/models/mobilenet/mobilenet_presets.py diff --git a/keras_hub/src/models/mobilenet/mobilenet_presets.py b/keras_hub/src/models/mobilenet/mobilenet_presets.py new file mode 100644 index 0000000000..5f29b28d0f --- /dev/null +++ b/keras_hub/src/models/mobilenet/mobilenet_presets.py @@ -0,0 +1,67 @@ +"""MobileNet model preset configurations.""" + +backbone_presets_no_weights = { + "mobilenet_v3_small": { + "metadata": { + "description": ( + "MobileNetV3 model with 14 layers where the batch " + "normalization and hard-swish activation are applied after the " + "convolution layers." + ), + "params": 933502, + "official_name": "MobileNetV3", + "path": "mobilenetv3", + }, + "kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_small/2", # noqa: E501 + }, + "mobilenet_v3_large": { + "metadata": { + "description": ( + "MobileNetV3 model with 28 layers where the batch " + "normalization and hard-swish activation are applied after the " + "convolution layers." + ), + "params": 2994518, + "official_name": "MobileNetV3", + "path": "mobilenetv3", + }, + "kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_large/2", # noqa: E501 + }, +} + +backbone_presets_with_weights = { + "mobile_net_v3_large": { + "metadata": { + "description": ( + "MobileNetV3 model with 28 layers where the batch " + "normalization and hard-swish activation are applied after the " + "convolution layers. " + "Pre-trained on the ImageNet 2012 classification task." + ), + "params": 2994518, + "official_name": "MobileNetV3", + "path": "mobilenetv3", + }, + "kaggle_handle": "kaggle://alexbutcher/mobilenet/keras/mobile_net_v3_large/1", + # "kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_large_imagenet/2", # noqa: E501 + }, + # "mobilenet_v3_small_imagenet": { + # "metadata": { + # "description": ( + # "MobileNetV3 model with 14 layers where the batch " + # "normalization and hard-swish activation are applied after the " + # "convolution layers. " + # "Pre-trained on the ImageNet 2012 classification task." 
+ # ), + # "params": 933502, + # "official_name": "MobileNetV3", + # "path": "mobilenetv3", + # }, + # "kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_small_imagenet/2", # noqa: E501 + # }, +} + +backbone_presets = { + # **backbone_presets_no_weights, + **backbone_presets_with_weights, +} From d359d515ff545672af34f340376b05c32312534c Mon Sep 17 00:00:00 2001 From: Usha Rengaraju <34335028+ushareng@users.noreply.github.com> Date: Fri, 4 Oct 2024 20:18:17 +0530 Subject: [PATCH 02/21] updated Mobilenet backbone to match it with torch implementation --- .../models/mobilenet/mobilenet_backbone.py | 1020 +++++++++-------- .../mobilenet/mobilenet_backbone_test.py | 105 +- .../mobilenet_image_classifier_test.py | 128 ++- 3 files changed, 647 insertions(+), 606 deletions(-) diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone.py b/keras_hub/src/models/mobilenet/mobilenet_backbone.py index d3aff7e9b8..545c66d441 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone.py @@ -1,506 +1,514 @@ -import keras -from keras import ops - -from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.backbone import Backbone - -BN_EPSILON = 1e-3 -BN_MOMENTUM = 0.999 - - -@keras_hub_export("keras_hub.models.MobileNetBackbone") -class MobileNetBackbone(Backbone): - """Instantiates the MobileNet architecture. - - MobileNet is a lightweight convolutional neural network (CNN) - optimized for mobile and edge devices, striking a balance between - accuracy and efficiency. By employing depthwise separable convolutions - and techniques like Squeeze-and-Excitation (SE) blocks, - MobileNet models are highly suitable for real-time applications on - resource-constrained devices. - - References: - - [MobileNets: Efficient Convolutional Neural Networks - for Mobile Vision Applications]( - https://arxiv.org/abs/1704.04861) - - [MobileNetV2: Inverted Residuals and Linear Bottlenecks]( - https://arxiv.org/abs/1801.04381) (CVPR 2018) - - [Searching for MobileNetV3](https://arxiv.org/pdf/1905.02244.pdf) - (ICCV 2019) - - Args: - stackwise_expansion: list of ints or floats, the expansion ratio for - each inverted residual block in the model. - stackwise_num_filters: list of ints, number of filters for each inverted - residual block in the model. - stackwise_kernel_size: list of ints, kernel size for each inverted - residual block in the model. - stackwise_num_strides: list of ints, stride length for each inverted - residual block in the model. - stackwise_se_ratio: se ratio for each inverted residual block in the - model. 0 if dont want to add Squeeze and Excite layer. - stackwise_activation: list of activation functions, for each inverted - residual block in the model. - image_shape: optional shape tuple, defaults to (224, 224, 3). - depth_multiplier: float, controls the width of the network. - - If `depth_multiplier` < 1.0, proportionally decreases the number - of filters in each layer. - - If `depth_multiplier` > 1.0, proportionally increases the number - of filters in each layer. - - If `depth_multiplier` = 1, default number of filters from the paper - are used at each layer. - input_num_filters: number of filters in first convolution layer - output_num_filters: specifies whether to add conv and batch_norm in the end, - if set to None, it will not add these layers in the end. 
- 'None' for MobileNetV1 - input_activation: activation function to be used in the input layer - 'hard_swish' for MobileNetV3, - 'relu6' for MobileNetV1 and MobileNetV2 - output_activation: activation function to be used in the output layer - 'hard_swish' for MobileNetV3, - 'relu6' for MobileNetV1 and MobileNetV2 - inverted_res_block: whether to use inverted residual blocks or not, - 'False' for MobileNetV1, - 'True' for MobileNetV2 and MobileNetV3 - - - Example: - ```python - input_data = tf.ones(shape=(8, 224, 224, 3)) - - # Randomly initialized backbone with a custom config - model = MobileNetBackbone( - stackwise_expansion=[1, 4, 6], - stackwise_num_filters=[4, 8, 16], - stackwise_kernel_size=[3, 3, 5], - stackwise_num_strides=[2, 2, 1], - stackwise_se_ratio=[0.25, None, 0.25], - stackwise_activation=["relu", "relu6", "hard_swish"], - output_num_filters=1280, - input_activation='hard_swish', - output_activation='hard_swish', - inverted_res_block=True, - - ) - output = model(input_data) - ``` - """ - - def __init__( - self, - stackwise_expansion, - stackwise_num_filters, - stackwise_kernel_size, - stackwise_num_strides, - stackwise_se_ratio, - stackwise_activation, - output_num_filters, - inverted_res_block, - image_shape=(224, 224, 3), - input_activation="hard_swish", - output_activation="hard_swish", - depth_multiplier=1.0, - input_num_filters=16, - **kwargs, - ): - # === Functional Model === - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - - image_input = keras.layers.Input(shape=image_shape) - x = image_input # Intermediate result. - input_num_filters = adjust_channels(input_num_filters) - x = keras.layers.Conv2D( - input_num_filters, - kernel_size=3, - strides=(2, 2), - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name="input_conv", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name="input_batch_norm", - )(x) - x = keras.layers.Activation(input_activation)(x) - - for stack_index in range(len(stackwise_num_filters)): - filters = adjust_channels( - (stackwise_num_filters[stack_index]) * depth_multiplier - ) - - if inverted_res_block: - x = apply_inverted_res_block( - x, - expansion=stackwise_expansion[stack_index], - filters=filters, - kernel_size=stackwise_kernel_size[stack_index], - stride=stackwise_num_strides[stack_index], - se_ratio=(stackwise_se_ratio[stack_index]), - activation=stackwise_activation[stack_index], - expansion_index=stack_index, - ) - else: - x = apply_depthwise_conv_block( - x, - filters=filters, - kernel_size=3, - stride=stackwise_num_strides[stack_index], - depth_multiplier=depth_multiplier, - block_id=stack_index, - ) - - if output_num_filters is not None: - last_conv_ch = adjust_channels(x.shape[channel_axis] * 6) - - x = keras.layers.Conv2D( - last_conv_ch, - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name="output_conv", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name="output_batch_norm", - )(x) - x = keras.layers.Activation(output_activation)(x) - - super().__init__(inputs=image_input, outputs=x, **kwargs) - - # === Config === - self.stackwise_expansion = stackwise_expansion - self.stackwise_num_filters = stackwise_num_filters - self.stackwise_kernel_size = stackwise_kernel_size - self.stackwise_num_strides = stackwise_num_strides - self.stackwise_se_ratio = stackwise_se_ratio - 
self.stackwise_activation = stackwise_activation - self.depth_multiplier = depth_multiplier - self.input_num_filters = input_num_filters - self.output_num_filters = output_num_filters - self.input_activation = keras.activations.get(input_activation) - self.output_activation = keras.activations.get(output_activation) - self.inverted_res_block = inverted_res_block - self.image_shape = image_shape - - def get_config(self): - config = super().get_config() - config.update( - { - "stackwise_expansion": self.stackwise_expansion, - "stackwise_num_filters": self.stackwise_num_filters, - "stackwise_kernel_size": self.stackwise_kernel_size, - "stackwise_num_strides": self.stackwise_num_strides, - "stackwise_se_ratio": self.stackwise_se_ratio, - "stackwise_activation": self.stackwise_activation, - "image_shape": self.image_shape, - "depth_multiplier": self.depth_multiplier, - "input_num_filters": self.input_num_filters, - "output_num_filters": self.output_num_filters, - "input_activation": keras.activations.serialize( - activation=self.input_activation - ), - "output_activation": keras.activations.serialize( - activation=self.output_activation - ), - "inverted_res_block": self.inverted_res_block, - } - ) - return config - - -def adjust_channels(x, divisor=8, min_value=None): - """Ensure that all layers have a channel number divisible by the `divisor`. - - Args: - x: integer, input value. - divisor: integer, the value by which a channel number should be - divisible, defaults to 8. - min_value: float, optional minimum value for the new tensor. If None, - defaults to value of divisor. - - Returns: - the updated input scalar. - """ - - if min_value is None: - min_value = divisor - - new_x = max(min_value, int(x + divisor / 2) // divisor * divisor) - - # make sure that round down does not go down by more than 10%. - if new_x < 0.9 * x: - new_x += divisor - return new_x - - -def apply_inverted_res_block( - x, - expansion, - filters, - kernel_size, - stride, - se_ratio, - activation, - expansion_index, -): - """An Inverted Residual Block. - - Args: - x: input tensor. - expansion: integer, the expansion ratio, multiplied with infilters to - get the minimum value passed to adjust_channels. - filters: integer, number of filters for convolution layer. - kernel_size: integer, the kernel size for DepthWise Convolutions. - stride: integer, the stride length for DepthWise Convolutions. - se_ratio: float, ratio for bottleneck filters. Number of bottleneck - filters = filters * se_ratio. - activation: the activation layer to use. - expansion_index: integer, a unique identification if you want to use - expanded convolutions. If greater than 0, an additional Conv+BN - layer is added after the expanded convolutional layer. - - Returns: - the updated input tensor. 
- """ - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - activation = keras.activations.get(activation) - shortcut = x - prefix = "expanded_conv_" - infilters = x.shape[channel_axis] - - if expansion_index > 0: - prefix = f"expanded_conv_{expansion_index}_" - - x = keras.layers.Conv2D( - adjust_channels(infilters * expansion), - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name=prefix + "expand", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=prefix + "expand_BatchNorm", - )(x) - x = keras.layers.Activation(activation=activation)(x) - - if stride == 2: - x = keras.layers.ZeroPadding2D( - padding=correct_pad_downsample(x, kernel_size), - name=prefix + "depthwise_pad", - )(x) - - x = keras.layers.DepthwiseConv2D( - kernel_size, - strides=stride, - padding="same" if stride == 1 else "valid", - data_format=keras.config.image_data_format(), - use_bias=False, - name=prefix + "depthwise", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=prefix + "depthwise_BatchNorm", - )(x) - x = keras.layers.Activation(activation=activation)(x) - - if se_ratio: - se_filters = adjust_channels(infilters * expansion) - x = SqueezeAndExcite2D( - input=x, - filters=se_filters, - bottleneck_filters=adjust_channels(se_filters * se_ratio), - squeeze_activation="relu", - excite_activation=keras.activations.hard_sigmoid, - ) - - x = keras.layers.Conv2D( - filters, - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name=prefix + "project", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=prefix + "project_BatchNorm", - )(x) - - if stride == 1 and infilters == filters: - x = keras.layers.Add(name=prefix + "Add")([shortcut, x]) - - return x - - -def apply_depthwise_conv_block( - x, - filters, - kernel_size=3, - depth_multiplier=1, - stride=1, - block_id=1, -): - """Adds a depthwise convolution block. - - A depthwise convolution block consists of a depthwise conv, - batch normalization, relu6, pointwise convolution, - batch normalization and relu6 activation. - - Args: - x: Input tensor of shape `(rows, cols, channels) - filters: Integer, the dimensionality of the output space - (i.e. the number of output filters in the pointwise convolution). - depth_multiplier: controls the width of the network. - - If `depth_multiplier` < 1.0, proportionally decreases the number - of filters in each layer. - - If `depth_multiplier` > 1.0, proportionally increases the number - of filters in each layer. - - If `depth_multiplier` = 1, default number of filters from the - paper are used at each layer. - strides: An integer or tuple/list of 2 integers, specifying the strides - of the convolution along the width and height. - Can be a single integer to specify the same value for - all spatial dimensions. Specifying any stride value != 1 is - incompatible with specifying any `dilation_rate` value != 1. - block_id: Integer, a unique identification designating the block number. - - Input shape: - 4D tensor with shape: `(batch, rows, cols, channels)` in "channels_last" - 4D tensor with shape: `(batch, channels, rows, cols)` in "channels_first" - Returns: - Output tensor of block. 
- """ - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - if stride == 2: - x = keras.layers.ZeroPadding2D( - padding=correct_pad_downsample(x, kernel_size), - name="conv_pad_%d" % block_id, - )(x) - - x = keras.layers.DepthwiseConv2D( - kernel_size, - strides=stride, - padding="same" if stride == 1 else "valid", - data_format=keras.config.image_data_format(), - depth_multiplier=depth_multiplier, - use_bias=False, - name="depthwise_%d" % block_id, - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name="depthwise_BatchNorm_%d" % block_id, - )(x) - x = keras.layers.ReLU(6.0)(x) - - x = keras.layers.Conv2D( - filters, - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name="conv_%d" % block_id, - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name="BatchNorm_%d" % block_id, - )(x) - return keras.layers.ReLU(6.0)(x) - - -def SqueezeAndExcite2D( - input, - filters, - bottleneck_filters=None, - squeeze_activation="relu", - excite_activation="sigmoid", -): - """ - Description: - This layer applies a content-aware mechanism to adaptively assign - channel-wise weights. It uses global average pooling to compress - feature maps into single values, which are then processed by - two Conv1D layers: the first reduces the dimensionality, and - the second restores it. - Args: - filters: Number of input and output filters. The number of input and - output filters is same. - bottleneck_filters: (Optional) Number of bottleneck filters. Defaults - to `0.25 * filters` - squeeze_activation: (Optional) String, callable (or - keras.layers.Layer) or keras.activations.Activation instance - denoting activation to be applied after squeeze convolution. - Defaults to `relu`. - excite_activation: (Optional) String, callable (or - keras.layers.Layer) or keras.activations.Activation instance - denoting activation to be applied after excite convolution. - Defaults to `sigmoid`. - """ - if not bottleneck_filters: - bottleneck_filters = filters // 4 - - x = keras.layers.GlobalAveragePooling2D(keepdims=True)(input) - - x = keras.layers.Conv2D( - bottleneck_filters, - (1, 1), - data_format=keras.config.image_data_format(), - activation=squeeze_activation, - )(x) - x = keras.layers.Conv2D( - filters, - (1, 1), - data_format=keras.config.image_data_format(), - activation=excite_activation, - )(x) - - x = ops.multiply(x, input) - return x - - -def correct_pad_downsample(inputs, kernel_size): - """Returns a tuple for zero-padding for 2D convolution with downsampling. - - Args: - inputs: Input tensor. - kernel_size: An integer or tuple/list of 2 integers. - - Returns: - A tuple. - """ - img_dim = 1 - input_size = inputs.shape[img_dim : (img_dim + 2)] - if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size) - if input_size[0] is None: - adjust = (1, 1) - else: - adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) - correct = (kernel_size[0] // 2, kernel_size[1] // 2) - return ( - (correct[0] - adjust[0], correct[0]), - (correct[1] - adjust[1], correct[1]), - ) +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone + +BN_EPSILON = 1e-5 +BN_MOMENTUM = 0.9 + + +@keras_hub_export("keras_hub.models.MobileNetBackbone") +class MobileNetBackbone(Backbone): + """Instantiates the MobileNet architecture. 
+ + MobileNet is a lightweight convolutional neural network (CNN) + optimized for mobile and edge devices, striking a balance between + accuracy and efficiency. By employing depthwise separable convolutions + and techniques like Squeeze-and-Excitation (SE) blocks, + MobileNet models are highly suitable for real-time applications on + resource-constrained devices. + + References: + - [MobileNets: Efficient Convolutional Neural Networks + for Mobile Vision Applications]( + https://arxiv.org/abs/1704.04861) + - [MobileNetV2: Inverted Residuals and Linear Bottlenecks]( + https://arxiv.org/abs/1801.04381) (CVPR 2018) + - [Searching for MobileNetV3](https://arxiv.org/pdf/1905.02244.pdf) + (ICCV 2019) + + Args: + stackwise_expansion: list of list of ints, the expanded filters for + each inverted residual block for each block in the model. + stackwise_num_blocks: list of ints, number of inversted residual blocks + per block + stackwise_num_filters: list of list of ints, number of filters for each inverted + residual block in the model. + stackwise_kernel_size: list of list of ints, kernel size for each inverted + residual block in the model. + stackwise_num_strides: list of list of ints, stride length for each inverted + residual block in the model. + stackwise_se_ratio: se ratio for each inverted residual block in the + model. 0 if dont want to add Squeeze and Excite layer. + stackwise_activation: list of list of activation functions, for each inverted + residual block in the model. + image_shape: optional shape tuple, defaults to (224, 224, 3). + input_num_filters: number of filters in first convolution layer + output_num_filters: specifies whether to add conv and batch_norm in the end, + if set to None, it will not add these layers in the end. + 'None' for MobileNetV1 + stackwise_padding: list of list of ints, padding value for each inverted + residual block in the model. 
+ input_activation: activation function to be used in the input layer + 'hard_swish' for MobileNetV3, + 'relu6' for MobileNetV1 and MobileNetV2 + output_activation: activation function to be used in the output layer + 'hard_swish' for MobileNetV3, + 'relu6' for MobileNetV1 and MobileNetV2 + depthwise_filters: int, number of filters in depthwise separable + convolution layer + squeeze_and_excite: float, squeeze and excite ratio in the depthwise layer, + None, if dont want to do squeeze and excite + + + Example: + ```python + input_data = tf.ones(shape=(8, 224, 224, 3)) + + # Randomly initialized backbone with a custom config + model = MobileNetBackbone( + stackwise_expansion=[1, 4, 6], + stackwise_num_blocks=[2, 3, 2, 3], + stackwise_num_filters=[4, 8, 16], + stackwise_kernel_size=[[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]], + stackwise_num_strides=[[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]], + stackwise_se_ratio=[ + [None, None], + [0.25, 0.25, 0.25], + [0.3, 0.3], + [0.3, 0.25, 0.25], + ], + stackwise_activation=[ + ["relu", "relu"], + ["hard_swish", "hard_swish", "hard_swish"], + ["hard_swish", "hard_swish"], + ["hard_swish", "hard_swish", "hard_swish"], + ], + stackwise_padding=[[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]], + output_num_filters=288, + input_activation="hard_swish", + output_activation="hard_swish", + inverted_res_block=True, + input_num_filters=16, + image_shape=(224, 224, 3), + depthwise_filters=8, + squeeze_and_excite=0.5, + + ) + output = model(input_data) + ``` + """ + + def __init__( + self, + stackwise_expansion, + stackwise_num_blocks, + stackwise_num_filters, + stackwise_kernel_size, + stackwise_num_strides, + stackwise_se_ratio, + stackwise_activation, + stackwise_padding, + output_num_filters, + depthwise_filters, + squeeze_and_excite=None, + image_shape=(224, 224, 3), + input_activation="hard_swish", + output_activation="hard_swish", + input_num_filters=16, + **kwargs, + ): + # === Functional Model === + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + + image_input = keras.layers.Input(shape=image_shape) + x = image_input + input_num_filters = adjust_channels(input_num_filters) + x = keras.layers.Conv2D( + input_num_filters, + kernel_size=3, + strides=(2, 2), + padding=(1, 1), + data_format=keras.config.image_data_format(), + use_bias=False, + name="input_conv", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name="input_batch_norm", + )(x) + x = keras.layers.Activation(input_activation)(x) + + x = apply_depthwise_conv_block( + x, depthwise_filters, squeeze_and_excite, name="block_0") + + for block in range(1, len(stackwise_num_blocks)): + for inverted_block in range(stackwise_num_blocks[block]): + x = apply_inverted_res_block( + x, + expansion=stackwise_expansion[block][inverted_block], + filters=adjust_channels( + stackwise_num_filters[block][inverted_block] + ), + kernel_size=stackwise_kernel_size[block][inverted_block], + stride=stackwise_num_strides[block][inverted_block], + padding=stackwise_padding[block][inverted_block], + se_ratio=stackwise_se_ratio[block][inverted_block], + activation=stackwise_activation[block][inverted_block], + name=f"block_{block}_{inverted_block}", + ) + + if output_num_filters is not None: + last_conv_ch = adjust_channels(output_num_filters) + + x = keras.layers.Conv2D( + last_conv_ch, + kernel_size=1, + padding=(1, 1), + data_format=keras.config.image_data_format(), + use_bias=False, + name="output_conv", + )(x) + x = 
keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name="output_batch_norm", + )(x) + x = keras.layers.Activation(output_activation)(x) + + super().__init__(inputs=image_input, outputs=x, **kwargs) + + # === Config === + self.stackwise_expansion = stackwise_expansion + self.stackwise_num_blocks = stackwise_num_blocks + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_kernel_size = stackwise_kernel_size + self.stackwise_num_strides = stackwise_num_strides + self.stackwise_se_ratio = stackwise_se_ratio + self.stackwise_activation = stackwise_activation + self.stackwise_padding = stackwise_padding + self.input_num_filters = input_num_filters + self.output_num_filters = output_num_filters + self.depthwise_filters = depthwise_filters + self.squeeze_and_excite = squeeze_and_excite + self.input_activation = keras.activations.get(input_activation) + self.output_activation = keras.activations.get(output_activation) + self.image_shape = image_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_expansion": self.stackwise_expansion, + "stackwise_num_blocks": self.stackwise_num_blocks, + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_kernel_size": self.stackwise_kernel_size, + "stackwise_num_strides": self.stackwise_num_strides, + "stackwise_se_ratio": self.stackwise_se_ratio, + "stackwise_activation": self.stackwise_activation, + "stackwise_padding": self.stackwise_padding, + "image_shape": self.image_shape, + "input_num_filters": self.input_num_filters, + "output_num_filters": self.output_num_filters, + "depthwise_filters": self.depthwise_filters, + "squeeze_and_excite": self.squeeze_and_excite, + "input_activation": keras.activations.serialize( + activation=self.input_activation + ), + "output_activation": keras.activations.serialize( + activation=self.output_activation + ), + } + ) + return config + + +def adjust_channels(x, divisor=8, min_value=None): + """Ensure that all layers have a channel number divisible by the `divisor`. + + Args: + x: integer, input value. + divisor: integer, the value by which a channel number should be + divisible, defaults to 8. + min_value: float, optional minimum value for the new tensor. If None, + defaults to value of divisor. + + Returns: + the updated input scalar. + """ + + if min_value is None: + min_value = divisor + + new_x = max(min_value, int(x + divisor / 2) // divisor * divisor) + + # make sure that round down does not go down by more than 10%. + if new_x < 0.9 * x: + new_x += divisor + return new_x + + +def apply_inverted_res_block( + x, + expansion, + filters, + kernel_size, + stride, + padding, + se_ratio, + activation, + name=None, +): + """An Inverted Residual Block. + + Args: + x: input tensor. + expansion: integer, the expansion ratio, multiplied with infilters to + get the minimum value passed to adjust_channels. + filters: integer, number of filters for convolution layer. + kernel_size: integer, the kernel size for DepthWise Convolutions. + stride: integer, the stride length for DepthWise Convolutions. + padding: integer, padding for the convolution layer + se_ratio: float, ratio for bottleneck filters. Number of bottleneck + filters = filters * se_ratio. + activation: the activation layer to use. + name: string, block label. + + Returns: + the updated input tensor. 
+ """ + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + activation = keras.activations.get(activation) + shortcut = x + infilters = x.shape[channel_axis] + expanded_channels = adjust_channels(expansion) + + x = keras.layers.Conv2D( + expanded_channels, + kernel_size=1, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv1", + )(x) + + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn1", + )(x) + + x = keras.layers.Activation(activation=activation)(x) + + x = keras.layers.Conv2D( + expanded_channels, + kernel_size, + strides=stride, + padding=padding, + groups=expanded_channels, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv2", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn2", + )(x) + + x = keras.layers.Activation(activation=activation)(x) + + if se_ratio: + se_filters = expanded_channels + x = SqueezeAndExcite2D( + input=x, + filters=se_filters, + bottleneck_filters=adjust_channels(se_filters * se_ratio), + squeeze_activation="relu", + excite_activation=keras.activations.hard_sigmoid, + name=f"{name}_se", + ) + + x = keras.layers.Conv2D( + filters, + kernel_size=1, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv3", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn3", + )(x) + + if stride == 1 and infilters == filters: + x = keras.layers.Add(name=f"{name}_add")([shortcut, x]) + return x + + +def apply_depthwise_conv_block( + x, filters, kernel_size=3, stride=1, se=True, name=None +): + """Adds a depthwise convolution block. + + A depthwise convolution block consists of a depthwise conv, + batch normalization, relu6, pointwise convolution, + batch normalization and relu6 activation. + + Args: + x: Input tensor of shape `(rows, cols, channels) + filters: Integer, the dimensionality of the output space + (i.e. the number of output filters in the pointwise convolution). + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the width and height. + Can be a single integer to specify the same value for + all spatial dimensions. Specifying any stride value != 1 is + incompatible with specifying any `dilation_rate` value != 1. + block_id: Integer, a unique identification designating the block number. + + Input shape: + 4D tensor with shape: `(batch, rows, cols, channels)` in "channels_last" + 4D tensor with shape: `(batch, channels, rows, cols)` in "channels_first" + Returns: + Output tensor of block. 
+ """ + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + infilters = x.shape[channel_axis] + name = f"{name}_0" + + x = keras.layers.Conv2D( + infilters, + kernel_size, + strides=stride, + padding=(1, 1), + data_format=keras.config.image_data_format(), + groups=infilters, + use_bias=False, + name=f"{name}_conv1", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn1", + )(x) + x = keras.layers.ReLU(6.0)(x) + + if se: + x = SqueezeAndExcite2D( + input=x, + filters=infilters, + bottleneck_filters=adjust_channels(infilters * se), + squeeze_activation="relu", + excite_activation=keras.activations.hard_sigmoid, + name=f"{name}_se", + ) + + x = keras.layers.Conv2D( + filters, + kernel_size=1, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv2", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn2", + )(x) + return x + + +def SqueezeAndExcite2D( + input, + filters, + bottleneck_filters=None, + squeeze_activation="relu", + excite_activation="sigmoid", + name=None, +): + """ + Description: + This layer applies a content-aware mechanism to adaptively assign + channel-wise weights. It uses global average pooling to compress + feature maps into single values, which are then processed by + two Conv1D layers: the first reduces the dimensionality, and + the second restores it. + Args: + filters: Number of input and output filters. The number of input and + output filters is same. + bottleneck_filters: (Optional) Number of bottleneck filters. Defaults + to `0.25 * filters` + squeeze_activation: (Optional) String, callable (or + keras.layers.Layer) or keras.activations.Activation instance + denoting activation to be applied after squeeze convolution. + Defaults to `relu`. + excite_activation: (Optional) String, callable (or + keras.layers.Layer) or keras.activations.Activation instance + denoting activation to be applied after excite convolution. + Defaults to `sigmoid`. + name: Name of the layer + """ + if not bottleneck_filters: + bottleneck_filters = filters // 4 + + x =input + x = keras.layers.Conv2D( + bottleneck_filters, + (1, 1), + data_format=keras.config.image_data_format(), + activation=squeeze_activation, + name=f"{name}_conv_reduce", + )(x) + x = keras.layers.Conv2D( + filters, + (1, 1), + data_format=keras.config.image_data_format(), + activation=excite_activation, + name=f"{name}_conv_expand", + )(x) + + x = ops.multiply(x, input) + return x + + +def correct_pad_downsample(inputs, kernel_size): + """Returns a tuple for zero-padding for 2D convolution with downsampling. + + Args: + inputs: Input tensor. + kernel_size: An integer or tuple/list of 2 integers. + + Returns: + A tuple. 
+ """ + img_dim = 1 + input_size = inputs.shape[img_dim : (img_dim + 2)] + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if input_size[0] is None: + adjust = (1, 1) + else: + adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) + correct = (kernel_size[0] // 2, kernel_size[1] // 2) + return ( + (correct[0] - adjust[0], correct[0]), + (correct[1] - adjust[1], correct[1]), + ) diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py index 24fdd0db4c..3c6024aef1 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py @@ -1,43 +1,62 @@ -import numpy as np -import pytest - -from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone -from keras_hub.src.tests.test_case import TestCase - - -class MobileNetBackboneTest(TestCase): - def setUp(self): - self.init_kwargs = { - "stackwise_expansion": [1, 4, 6], - "stackwise_num_filters": [4, 8, 16], - "stackwise_kernel_size": [3, 3, 5], - "stackwise_num_strides": [2, 2, 1], - "stackwise_se_ratio": [0.25, None, 0.25], - "stackwise_activation": ["relu", "relu", "hard_swish"], - "output_num_filters": 1280, - "input_activation": "hard_swish", - "output_activation": "hard_swish", - "inverted_res_block": True, - "input_num_filters": 16, - "image_shape": (224, 224, 3), - "depth_multiplier": 1, - } - self.input_data = np.ones((2, 224, 224, 3), dtype="float32") - - def test_backbone_basics(self): - self.run_vision_backbone_test( - cls=MobileNetBackbone, - init_kwargs=self.init_kwargs, - input_data=self.input_data, - expected_output_shape=(2, 28, 28, 96), - run_mixed_precision_check=False, - run_data_format_check=False, - ) - - @pytest.mark.large - def test_saved_model(self): - self.run_model_saving_test( - cls=MobileNetBackbone, - init_kwargs=self.init_kwargs, - input_data=self.input_data, - ) +import numpy as np +import pytest + +from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class MobileNetBackboneTest(TestCase): + def setUp(self): + + self.init_kwargs = { + "stackwise_expansion": [ + [40, 56], + [64, 144, 144], + [ 72, 72], + [144, 288, 288] + ], + "stackwise_num_blocks": [2, 3, 2, 3], + "stackwise_num_filters": [4, 8, 16], + "stackwise_kernel_size": [[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]], + "stackwise_num_strides": [[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]], + "stackwise_se_ratio": [ + [None, None], + [0.25, 0.25, 0.25], + [0.3, 0.3], + [0.3, 0.25, 0.25], + ], + "stackwise_activation": [ + ["relu", "relu"], + ["hard_swish", "hard_swish", "hard_swish"], + ["hard_swish", "hard_swish"], + ["hard_swish", "hard_swish", "hard_swish"], + ], + "stackwise_padding": [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]], + "output_num_filters": 288, + "input_activation": "hard_swish", + "output_activation": "hard_swish", + "inverted_res_block": True, + "input_num_filters": 16, + "image_shape": (224, 224, 3), + "depthwise_filters": 8, + "squeeze_and_excite": 0.5, + } + self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_vision_backbone_test( + cls=MobileNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 28, 28, 96), + run_mixed_precision_check=False, + run_data_format_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + 
cls=MobileNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py index 57ebd65039..cb251918a0 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py @@ -1,57 +1,71 @@ -import numpy as np -import pytest - -from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone -from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( - MobileNetImageClassifier, -) -from keras_hub.src.tests.test_case import TestCase - - -class MobileNetImageClassifierTest(TestCase): - def setUp(self): - # Setup model. - self.images = np.ones((2, 224, 224, 3), dtype="float32") - self.labels = [0, 3] - self.backbone = MobileNetBackbone( - stackwise_expansion=[1, 4, 6], - stackwise_num_filters=[4, 8, 16], - stackwise_kernel_size=[3, 3, 5], - stackwise_num_strides=[2, 2, 1], - stackwise_se_ratio=[0.25, None, 0.25], - stackwise_activation=["relu", "relu", "hard_swish"], - output_num_filters=1280, - input_activation="hard_swish", - output_activation="hard_swish", - inverted_res_block=True, - input_num_filters=16, - image_shape=(224, 224, 3), - ) - self.init_kwargs = { - "backbone": self.backbone, - "num_classes": 2, - "activation": "softmax", - } - self.train_data = ( - self.images, - self.labels, - ) - - def test_classifier_basics(self): - pytest.skip( - reason="TODO: enable after preprocessor flow is figured out" - ) - self.run_task_test( - cls=MobileNetImageClassifier, - init_kwargs=self.init_kwargs, - train_data=self.train_data, - expected_output_shape=(2, 2), - ) - - @pytest.mark.large - def test_saved_model(self): - self.run_model_saving_test( - cls=MobileNetImageClassifier, - init_kwargs=self.init_kwargs, - input_data=self.images, - ) +import numpy as np +import pytest + +from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone +from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( + MobileNetImageClassifier, +) +from keras_hub.src.tests.test_case import TestCase + + +class MobileNetImageClassifierTest(TestCase): + def setUp(self): + # Setup model. 
+ self.images = np.ones((2, 224, 224, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = MobileNetBackbone( + stackwise_expansion=[1, 4, 6], + stackwise_num_blocks=[2, 3, 2, 3], + stackwise_num_filters=[4, 8, 16], + stackwise_kernel_size=[[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]], + stackwise_num_strides=[[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]], + stackwise_se_ratio=[ + [None, None], + [0.25, 0.25, 0.25], + [0.3, 0.3], + [0.3, 0.25, 0.25], + ], + stackwise_activation=[ + ["relu", "relu"], + ["hard_swish", "hard_swish", "hard_swish"], + ["hard_swish", "hard_swish"], + ["hard_swish", "hard_swish", "hard_swish"], + ], + stackwise_padding=[[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]], + output_num_filters=288, + input_activation="hard_swish", + output_activation="hard_swish", + inverted_res_block=True, + input_num_filters=16, + image_shape=(224, 224, 3), + depthwise_filters=8, + squeeze_and_excite=0.5, + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=MobileNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MobileNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) From 966384da8d163dec922389f84ddab2dc3f555027 Mon Sep 17 00:00:00 2001 From: Usha Rengaraju <34335028+ushareng@users.noreply.github.com> Date: Fri, 4 Oct 2024 20:21:17 +0530 Subject: [PATCH 03/21] Deleted presets --- .../src/models/mobilenet/mobilenet_presets.py | 66 ------------------- 1 file changed, 66 deletions(-) diff --git a/keras_hub/src/models/mobilenet/mobilenet_presets.py b/keras_hub/src/models/mobilenet/mobilenet_presets.py index 5f29b28d0f..d3f5a12faa 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_presets.py +++ b/keras_hub/src/models/mobilenet/mobilenet_presets.py @@ -1,67 +1 @@ -"""MobileNet model preset configurations.""" -backbone_presets_no_weights = { - "mobilenet_v3_small": { - "metadata": { - "description": ( - "MobileNetV3 model with 14 layers where the batch " - "normalization and hard-swish activation are applied after the " - "convolution layers." - ), - "params": 933502, - "official_name": "MobileNetV3", - "path": "mobilenetv3", - }, - "kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_small/2", # noqa: E501 - }, - "mobilenet_v3_large": { - "metadata": { - "description": ( - "MobileNetV3 model with 28 layers where the batch " - "normalization and hard-swish activation are applied after the " - "convolution layers." - ), - "params": 2994518, - "official_name": "MobileNetV3", - "path": "mobilenetv3", - }, - "kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_large/2", # noqa: E501 - }, -} - -backbone_presets_with_weights = { - "mobile_net_v3_large": { - "metadata": { - "description": ( - "MobileNetV3 model with 28 layers where the batch " - "normalization and hard-swish activation are applied after the " - "convolution layers. " - "Pre-trained on the ImageNet 2012 classification task." 
- ), - "params": 2994518, - "official_name": "MobileNetV3", - "path": "mobilenetv3", - }, - "kaggle_handle": "kaggle://alexbutcher/mobilenet/keras/mobile_net_v3_large/1", - # "kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_large_imagenet/2", # noqa: E501 - }, - # "mobilenet_v3_small_imagenet": { - # "metadata": { - # "description": ( - # "MobileNetV3 model with 14 layers where the batch " - # "normalization and hard-swish activation are applied after the " - # "convolution layers. " - # "Pre-trained on the ImageNet 2012 classification task." - # ), - # "params": 933502, - # "official_name": "MobileNetV3", - # "path": "mobilenetv3", - # }, - # "kaggle_handle": "kaggle://keras/mobilenetv3/keras/mobilenet_v3_small_imagenet/2", # noqa: E501 - # }, -} - -backbone_presets = { - # **backbone_presets_no_weights, - **backbone_presets_with_weights, -} From 38fa5d3a8bb23b8e3b95bc9f247caa7135516d19 Mon Sep 17 00:00:00 2001 From: ushareng Date: Fri, 4 Oct 2024 20:32:31 +0530 Subject: [PATCH 04/21] Mobilenet preset deleted --- keras_hub/src/models/mobilenet/mobilenet_presets.py | 1 - 1 file changed, 1 deletion(-) delete mode 100644 keras_hub/src/models/mobilenet/mobilenet_presets.py diff --git a/keras_hub/src/models/mobilenet/mobilenet_presets.py b/keras_hub/src/models/mobilenet/mobilenet_presets.py deleted file mode 100644 index d3f5a12faa..0000000000 --- a/keras_hub/src/models/mobilenet/mobilenet_presets.py +++ /dev/null @@ -1 +0,0 @@ - From 16d81ba6ff3223c3eec9213dcd85f66ac2b90324 Mon Sep 17 00:00:00 2001 From: ushareng Date: Fri, 4 Oct 2024 21:33:35 +0530 Subject: [PATCH 05/21] code reformat --- keras_hub/src/models/mobilenet/mobilenet_backbone.py | 9 +++++---- .../src/models/mobilenet/mobilenet_backbone_test.py | 8 ++++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone.py b/keras_hub/src/models/mobilenet/mobilenet_backbone.py index 545c66d441..69b3ac85dd 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone.py @@ -56,7 +56,7 @@ class MobileNetBackbone(Backbone): output_activation: activation function to be used in the output layer 'hard_swish' for MobileNetV3, 'relu6' for MobileNetV1 and MobileNetV2 - depthwise_filters: int, number of filters in depthwise separable + depthwise_filters: int, number of filters in depthwise separable convolution layer squeeze_and_excite: float, squeeze and excite ratio in the depthwise layer, None, if dont want to do squeeze and excite @@ -125,7 +125,7 @@ def __init__( ) image_input = keras.layers.Input(shape=image_shape) - x = image_input + x = image_input input_num_filters = adjust_channels(input_num_filters) x = keras.layers.Conv2D( input_num_filters, @@ -145,7 +145,8 @@ def __init__( x = keras.layers.Activation(input_activation)(x) x = apply_depthwise_conv_block( - x, depthwise_filters, squeeze_and_excite, name="block_0") + x, depthwise_filters, squeeze_and_excite, name="block_0" + ) for block in range(1, len(stackwise_num_blocks)): for inverted_block in range(stackwise_num_blocks[block]): @@ -469,7 +470,7 @@ def SqueezeAndExcite2D( if not bottleneck_filters: bottleneck_filters = filters // 4 - x =input + x = input x = keras.layers.Conv2D( bottleneck_filters, (1, 1), diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py index 3c6024aef1..c85251fc60 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py +++ 
b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py @@ -10,10 +10,10 @@ def setUp(self): self.init_kwargs = { "stackwise_expansion": [ - [40, 56], - [64, 144, 144], - [ 72, 72], - [144, 288, 288] + [40, 56], + [64, 144, 144], + [72, 72], + [144, 288, 288], ], "stackwise_num_blocks": [2, 3, 2, 3], "stackwise_num_filters": [4, 8, 16], From 72982b82cfc5e4a4a6058966aade9fa2d8ef0d30 Mon Sep 17 00:00:00 2001 From: ushareng Date: Fri, 4 Oct 2024 22:09:17 +0530 Subject: [PATCH 06/21] padding changed --- keras_hub/src/models/mobilenet/mobilenet_backbone.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone.py b/keras_hub/src/models/mobilenet/mobilenet_backbone.py index 69b3ac85dd..25ac182e4d 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone.py @@ -85,7 +85,6 @@ class MobileNetBackbone(Backbone): ["hard_swish", "hard_swish"], ["hard_swish", "hard_swish", "hard_swish"], ], - stackwise_padding=[[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]], output_num_filters=288, input_activation="hard_swish", output_activation="hard_swish", @@ -131,7 +130,7 @@ def __init__( input_num_filters, kernel_size=3, strides=(2, 2), - padding=(1, 1), + padding="same", data_format=keras.config.image_data_format(), use_bias=False, name="input_conv", @@ -170,7 +169,7 @@ def __init__( x = keras.layers.Conv2D( last_conv_ch, kernel_size=1, - padding=(1, 1), + padding="same", data_format=keras.config.image_data_format(), use_bias=False, name="output_conv", @@ -314,7 +313,7 @@ def apply_inverted_res_block( expanded_channels, kernel_size, strides=stride, - padding=padding, + padding=correct_pad_downsample(x, kernel_size), groups=expanded_channels, data_format=keras.config.image_data_format(), use_bias=False, @@ -396,7 +395,7 @@ def apply_depthwise_conv_block( infilters, kernel_size, strides=stride, - padding=(1, 1), + padding="same", data_format=keras.config.image_data_format(), groups=infilters, use_bias=False, From 78411c94cd130d1d293504e7783915082b4d4419 Mon Sep 17 00:00:00 2001 From: ushareng Date: Fri, 4 Oct 2024 22:36:51 +0530 Subject: [PATCH 07/21] downsample_padding --- keras_hub/src/models/mobilenet/mobilenet_backbone.py | 8 -------- keras_hub/src/models/mobilenet/mobilenet_backbone_test.py | 1 - .../models/mobilenet/mobilenet_image_classifier_test.py | 1 - 3 files changed, 10 deletions(-) diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone.py b/keras_hub/src/models/mobilenet/mobilenet_backbone.py index 25ac182e4d..0a8b69db74 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone.py @@ -48,8 +48,6 @@ class MobileNetBackbone(Backbone): output_num_filters: specifies whether to add conv and batch_norm in the end, if set to None, it will not add these layers in the end. 'None' for MobileNetV1 - stackwise_padding: list of list of ints, padding value for each inverted - residual block in the model. 
input_activation: activation function to be used in the input layer 'hard_swish' for MobileNetV3, 'relu6' for MobileNetV1 and MobileNetV2 @@ -108,7 +106,6 @@ def __init__( stackwise_num_strides, stackwise_se_ratio, stackwise_activation, - stackwise_padding, output_num_filters, depthwise_filters, squeeze_and_excite=None, @@ -157,7 +154,6 @@ def __init__( ), kernel_size=stackwise_kernel_size[block][inverted_block], stride=stackwise_num_strides[block][inverted_block], - padding=stackwise_padding[block][inverted_block], se_ratio=stackwise_se_ratio[block][inverted_block], activation=stackwise_activation[block][inverted_block], name=f"block_{block}_{inverted_block}", @@ -192,7 +188,6 @@ def __init__( self.stackwise_num_strides = stackwise_num_strides self.stackwise_se_ratio = stackwise_se_ratio self.stackwise_activation = stackwise_activation - self.stackwise_padding = stackwise_padding self.input_num_filters = input_num_filters self.output_num_filters = output_num_filters self.depthwise_filters = depthwise_filters @@ -212,7 +207,6 @@ def get_config(self): "stackwise_num_strides": self.stackwise_num_strides, "stackwise_se_ratio": self.stackwise_se_ratio, "stackwise_activation": self.stackwise_activation, - "stackwise_padding": self.stackwise_padding, "image_shape": self.image_shape, "input_num_filters": self.input_num_filters, "output_num_filters": self.output_num_filters, @@ -260,7 +254,6 @@ def apply_inverted_res_block( filters, kernel_size, stride, - padding, se_ratio, activation, name=None, @@ -274,7 +267,6 @@ def apply_inverted_res_block( filters: integer, number of filters for convolution layer. kernel_size: integer, the kernel size for DepthWise Convolutions. stride: integer, the stride length for DepthWise Convolutions. - padding: integer, padding for the convolution layer se_ratio: float, ratio for bottleneck filters. Number of bottleneck filters = filters * se_ratio. activation: the activation layer to use. 
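The per-block `stackwise_padding` argument is removed above; the padding applied before downsampling convolutions and the rounded channel counts are instead derived from the module-level helpers `correct_pad_downsample` and `adjust_channels`. Below is a small standalone sketch of what these helpers return for concrete inputs, assuming they are imported from the `mobilenet_backbone.py` module edited in this patch; the numbers are illustrative only.

```python
import keras

from keras_hub.src.models.mobilenet.mobilenet_backbone import adjust_channels
from keras_hub.src.models.mobilenet.mobilenet_backbone import (
    correct_pad_downsample,
)

x = keras.layers.Input(shape=(224, 224, 3))

# Asymmetric zero-padding used ahead of a stride-2 convolution on an
# even-sized 224x224 feature map.
print(correct_pad_downsample(x, kernel_size=3))  # ((0, 1), (0, 1))
print(correct_pad_downsample(x, kernel_size=5))  # ((1, 2), (1, 2))

# Channel counts are snapped to a multiple of 8, never dropping more than
# 10% below the requested value.
print(adjust_channels(33))  # 32
print(adjust_channels(71))  # 72
```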
diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py index c85251fc60..6dfc020b6b 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py @@ -31,7 +31,6 @@ def setUp(self): ["hard_swish", "hard_swish"], ["hard_swish", "hard_swish", "hard_swish"], ], - "stackwise_padding": [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]], "output_num_filters": 288, "input_activation": "hard_swish", "output_activation": "hard_swish", diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py index cb251918a0..8ac4607320 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py @@ -31,7 +31,6 @@ def setUp(self): ["hard_swish", "hard_swish"], ["hard_swish", "hard_swish", "hard_swish"], ], - stackwise_padding=[[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]], output_num_filters=288, input_activation="hard_swish", output_activation="hard_swish", From e859fb4563fddc3d048e970d5f666afe7912fe80 Mon Sep 17 00:00:00 2001 From: ushareng Date: Fri, 4 Oct 2024 23:16:39 +0530 Subject: [PATCH 08/21] typo fixed --- keras_hub/src/models/mobilenet/mobilenet_backbone.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone.py b/keras_hub/src/models/mobilenet/mobilenet_backbone.py index 0a8b69db74..9707a906a8 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone.py @@ -141,7 +141,7 @@ def __init__( x = keras.layers.Activation(input_activation)(x) x = apply_depthwise_conv_block( - x, depthwise_filters, squeeze_and_excite, name="block_0" + x, depthwise_filters, se=squeeze_and_excite, name="block_0" ) for block in range(1, len(stackwise_num_blocks)): @@ -352,7 +352,7 @@ def apply_inverted_res_block( def apply_depthwise_conv_block( - x, filters, kernel_size=3, stride=1, se=True, name=None + x, filters, kernel_size=3, stride=1, se=None, name=None ): """Adds a depthwise convolution block. From 35217ba88dc35db4b920bdbdc7b6e6c9ad488558 Mon Sep 17 00:00:00 2001 From: ushareng Date: Sat, 5 Oct 2024 18:56:32 +0530 Subject: [PATCH 09/21] timm script added --- .../models/mobilenet/mobilenet_backbone.py | 114 ++++++-- .../mobilenet/mobilenet_backbone_test.py | 28 +- .../mobilenet_image_classifier_test.py | 22 +- keras_hub/src/utils/timm/convert_mobilenet.py | 267 ++++++++++++++++++ .../src/utils/timm/convert_mobilenet_test.py | 26 ++ keras_hub/src/utils/timm/preset_loader.py | 3 + 6 files changed, 418 insertions(+), 42 deletions(-) create mode 100644 keras_hub/src/utils/timm/convert_mobilenet.py create mode 100644 keras_hub/src/utils/timm/convert_mobilenet_test.py diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone.py b/keras_hub/src/models/mobilenet/mobilenet_backbone.py index 9707a906a8..34ddbda6d0 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone.py @@ -33,20 +33,20 @@ class MobileNetBackbone(Backbone): each inverted residual block for each block in the model. stackwise_num_blocks: list of ints, number of inversted residual blocks per block - stackwise_num_filters: list of list of ints, number of filters for each inverted - residual block in the model. 
- stackwise_kernel_size: list of list of ints, kernel size for each inverted - residual block in the model. - stackwise_num_strides: list of list of ints, stride length for each inverted - residual block in the model. + stackwise_num_filters: list of list of ints, number of filters for + each inverted residual block in the model. + stackwise_kernel_size: list of list of ints, kernel size for each + inverted residual block in the model. + stackwise_num_strides: list of list of ints, stride length for each + inverted residual block in the model. stackwise_se_ratio: se ratio for each inverted residual block in the model. 0 if dont want to add Squeeze and Excite layer. - stackwise_activation: list of list of activation functions, for each inverted - residual block in the model. + stackwise_activation: list of list of activation functions, for each + inverted residual block in the model. image_shape: optional shape tuple, defaults to (224, 224, 3). input_num_filters: number of filters in first convolution layer - output_num_filters: specifies whether to add conv and batch_norm in the end, - if set to None, it will not add these layers in the end. + output_num_filters: specifies whether to add conv and batch_norm in the + end, if set to None, it will not add these layers in the end. 'None' for MobileNetV1 input_activation: activation function to be used in the input layer 'hard_swish' for MobileNetV3, @@ -56,8 +56,8 @@ class MobileNetBackbone(Backbone): 'relu6' for MobileNetV1 and MobileNetV2 depthwise_filters: int, number of filters in depthwise separable convolution layer - squeeze_and_excite: float, squeeze and excite ratio in the depthwise layer, - None, if dont want to do squeeze and excite + squeeze_and_excite: float, squeeze and excite ratio in the depthwise + layer, None, if dont want to do squeeze and excite Example: @@ -66,9 +66,19 @@ class MobileNetBackbone(Backbone): # Randomly initialized backbone with a custom config model = MobileNetBackbone( - stackwise_expansion=[1, 4, 6], + stackwise_expansion=[ + [40, 56], + [64, 144, 144], + [72, 72], + [144, 288, 288], + ], stackwise_num_blocks=[2, 3, 2, 3], - stackwise_num_filters=[4, 8, 16], + stackwise_num_filters=[ + [16, 16], + [24, 24, 24], + [24, 24], + [48, 48, 48], + ], stackwise_kernel_size=[[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]], stackwise_num_strides=[[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]], stackwise_se_ratio=[ @@ -86,7 +96,6 @@ class MobileNetBackbone(Backbone): output_num_filters=288, input_activation="hard_swish", output_activation="hard_swish", - inverted_res_block=True, input_num_filters=16, image_shape=(224, 224, 3), depthwise_filters=8, @@ -108,6 +117,7 @@ def __init__( stackwise_activation, output_num_filters, depthwise_filters, + last_layer_filter, squeeze_and_excite=None, image_shape=(224, 224, 3), input_activation="hard_swish", @@ -144,7 +154,7 @@ def __init__( x, depthwise_filters, se=squeeze_and_excite, name="block_0" ) - for block in range(1, len(stackwise_num_blocks)): + for block in range(len(stackwise_num_blocks)): for inverted_block in range(stackwise_num_blocks[block]): x = apply_inverted_res_block( x, @@ -156,27 +166,37 @@ def __init__( stride=stackwise_num_strides[block][inverted_block], se_ratio=stackwise_se_ratio[block][inverted_block], activation=stackwise_activation[block][inverted_block], - name=f"block_{block}_{inverted_block}", + name=f"block_{block+1}_{inverted_block}", ) - if output_num_filters is not None: - last_conv_ch = adjust_channels(output_num_filters) + x = ConvBnAct( + x, + 
filter=adjust_channels(last_layer_filter), + activation="hard_swish", + name=f"block_{len(stackwise_num_blocks)+1}_0", + ) - x = keras.layers.Conv2D( - last_conv_ch, - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name="output_conv", - )(x) + last_conv_ch = adjust_channels(output_num_filters) + + x = keras.layers.Conv2D( + last_conv_ch, + kernel_size=1, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name="output_conv", + )(x) + + # no output normalization in mobilenetv3 + if output_activation == "relu6": x = keras.layers.BatchNormalization( axis=channel_axis, epsilon=BN_EPSILON, momentum=BN_MOMENTUM, name="output_batch_norm", )(x) - x = keras.layers.Activation(output_activation)(x) + + x = keras.layers.Activation(output_activation)(x) super().__init__(inputs=image_input, outputs=x, **kwargs) @@ -191,6 +211,7 @@ def __init__( self.input_num_filters = input_num_filters self.output_num_filters = output_num_filters self.depthwise_filters = depthwise_filters + self.last_layer_filter = last_layer_filter self.squeeze_and_excite = squeeze_and_excite self.input_activation = keras.activations.get(input_activation) self.output_activation = keras.activations.get(output_activation) @@ -211,6 +232,7 @@ def get_config(self): "input_num_filters": self.input_num_filters, "output_num_filters": self.output_num_filters, "depthwise_filters": self.depthwise_filters, + "last_layer_filter": self.last_layer_filter, "squeeze_and_excite": self.squeeze_and_excite, "input_activation": keras.activations.serialize( activation=self.input_activation @@ -301,11 +323,16 @@ def apply_inverted_res_block( x = keras.layers.Activation(activation=activation)(x) + if stride == 2: + x = keras.layers.ZeroPadding2D( + padding=correct_pad_downsample(x, kernel_size), + )(x) + x = keras.layers.Conv2D( expanded_channels, kernel_size, strides=stride, - padding=correct_pad_downsample(x, kernel_size), + padding="same" if stride == 1 else "valid", groups=expanded_channels, data_format=keras.config.image_data_format(), use_bias=False, @@ -383,11 +410,16 @@ def apply_depthwise_conv_block( infilters = x.shape[channel_axis] name = f"{name}_0" + if stride == 2: + x = keras.layers.ZeroPadding2D( + padding=correct_pad_downsample(x, kernel_size), + )(x) + x = keras.layers.Conv2D( infilters, kernel_size, strides=stride, - padding="same", + padding="same" if stride == 1 else "valid", data_format=keras.config.image_data_format(), groups=infilters, use_bias=False, @@ -481,6 +513,28 @@ def SqueezeAndExcite2D( return x +def ConvBnAct(x, filter, activation, name=None): + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + x = keras.layers.Conv2D( + filter, + kernel_size=1, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn", + )(x) + x = keras.layers.Activation(activation)(x) + return x + + def correct_pad_downsample(inputs, kernel_size): """Returns a tuple for zero-padding for 2D convolution with downsampling. 
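Note on the stride-2 change above: downsampling convolutions now apply an explicit `ZeroPadding2D` and then use `padding="valid"` (instead of `padding="same"`) so the output grid lines up with the torch/timm reference. A minimal sketch of that pad-then-valid pattern, assuming `correct_pad_downsample` follows the standard Keras `correct_pad` logic (its body is outside this hunk):

```python
import keras
from keras import layers


def correct_pad_downsample(inputs, kernel_size):
    # Assumed implementation: the standard Keras "correct_pad" used before
    # stride-2 convolutions, producing an asymmetric (left/top-light) pad.
    img_dim = 1 if keras.config.image_data_format() == "channels_last" else 2
    input_size = inputs.shape[img_dim : img_dim + 2]
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if input_size[0] is None:
        adjust = (1, 1)
    else:
        adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)
    correct = (kernel_size[0] // 2, kernel_size[1] // 2)
    return (
        (correct[0] - adjust[0], correct[0]),
        (correct[1] - adjust[1], correct[1]),
    )


# Pad-then-"valid" pattern used by the stride-2 depthwise blocks.
inputs = keras.Input(shape=(224, 224, 16))
x = layers.ZeroPadding2D(padding=correct_pad_downsample(inputs, 3))(inputs)
x = layers.Conv2D(
    16, 3, strides=2, padding="valid", groups=16, use_bias=False
)(x)
print(x.shape)  # (None, 112, 112, 16)
```

With a 224x224 input and a 3x3 kernel this pads to 225x225, so the stride-2 "valid" convolution produces the expected 112x112 feature map.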
diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py index 6dfc020b6b..8119d0aa1b 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py @@ -16,9 +16,24 @@ def setUp(self): [144, 288, 288], ], "stackwise_num_blocks": [2, 3, 2, 3], - "stackwise_num_filters": [4, 8, 16], - "stackwise_kernel_size": [[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]], - "stackwise_num_strides": [[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]], + "stackwise_num_filters": [ + [16, 16], + [24, 24, 24], + [24, 24], + [48, 48, 48], + ], + "stackwise_kernel_size": [ + [3, 3], + [5, 5, 5], + [5, 5], + [5, 5, 5], + ], + "stackwise_num_strides": [ + [2, 1], + [2, 1, 1], + [1, 1], + [2, 1, 1], + ], "stackwise_se_ratio": [ [None, None], [0.25, 0.25, 0.25], @@ -30,15 +45,16 @@ def setUp(self): ["hard_swish", "hard_swish", "hard_swish"], ["hard_swish", "hard_swish"], ["hard_swish", "hard_swish", "hard_swish"], + ["hard_swish"], ], - "output_num_filters": 288, + "output_num_filters": 1024, "input_activation": "hard_swish", "output_activation": "hard_swish", - "inverted_res_block": True, "input_num_filters": 16, "image_shape": (224, 224, 3), "depthwise_filters": 8, "squeeze_and_excite": 0.5, + "last_layer_filter": 288, } self.input_data = np.ones((2, 224, 224, 3), dtype="float32") @@ -47,7 +63,7 @@ def test_backbone_basics(self): cls=MobileNetBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 28, 28, 96), + expected_output_shape=(2, 14, 14, 1024), run_mixed_precision_check=False, run_data_format_check=False, ) diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py index 8ac4607320..36adb46613 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py @@ -14,11 +14,21 @@ def setUp(self): self.images = np.ones((2, 224, 224, 3), dtype="float32") self.labels = [0, 3] self.backbone = MobileNetBackbone( - stackwise_expansion=[1, 4, 6], + stackwise_expansion=[ + [40, 56], + [64, 144, 144], + [72, 72], + [144, 288, 288], + ], stackwise_num_blocks=[2, 3, 2, 3], - stackwise_num_filters=[4, 8, 16], - stackwise_kernel_size=[[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]], - stackwise_num_strides=[[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]], + stackwise_num_filters=[ + [16, 16], + [24, 24, 24], + [24, 24], + [48, 48, 48], + ], + stackwise_kernel_size=[[3, 3], [5, 5, 5], [5, 5], [5, 5, 5], [1]], + stackwise_num_strides=[[2, 1], [2, 1, 1], [1, 1], [2, 1, 1], [1]], stackwise_se_ratio=[ [None, None], [0.25, 0.25, 0.25], @@ -31,14 +41,14 @@ def setUp(self): ["hard_swish", "hard_swish"], ["hard_swish", "hard_swish", "hard_swish"], ], - output_num_filters=288, + output_num_filters=1024, input_activation="hard_swish", output_activation="hard_swish", - inverted_res_block=True, input_num_filters=16, image_shape=(224, 224, 3), depthwise_filters=8, squeeze_and_excite=0.5, + last_layer_filter=288, ) self.init_kwargs = { "backbone": self.backbone, diff --git a/keras_hub/src/utils/timm/convert_mobilenet.py b/keras_hub/src/utils/timm/convert_mobilenet.py new file mode 100644 index 0000000000..e362fca7be --- /dev/null +++ b/keras_hub/src/utils/timm/convert_mobilenet.py @@ -0,0 +1,267 @@ +import numpy as np + +from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone + +backbone_cls = 
MobileNetBackbone + + +def convert_backbone_config(timm_config): + timm_architecture = timm_config["architecture"] + + if "mobilenetv3_" in timm_architecture: + input_activation = "hard_swish" + output_activation = "hard_swish" + + else: + input_activation = "relu6" + output_activation = "relu6" + + if timm_architecture == "mobilenetv3_small_050": + stackwise_num_blocks = [2, 3, 2, 3] + stackwise_expansion = [ + [40, 56], + [64, 144, 144], + [72, 72], + [144, 288, 288], + ] + stackwise_num_filters = [[16, 16], [24, 24, 24], [24, 24], [48, 48, 48]] + stackwise_kernel_size = [[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]] + stackwise_num_strides = [[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]] + stackwise_se_ratio = ( + [ + [None, None], + [0.25, 0.25, 0.25], + [0.3, 0.3], + [0.3, 0.25, 0.25], + ], + ) + stackwise_activation = ( + [ + ["relu6", "relu6"], + ["hard_swish", "hard_swish", "hard_swish"], + ["hard_swish", "hard_swish"], + ["hard_swish", "hard_swish", "hard_swish"], + ], + ) + output_num_filters = 1024 + input_num_filters = 16 + depthwise_filters = 8 + squeeze_and_excite = 0.5 + last_layer_filter = 288 + + elif timm_architecture == "mobilenetv2_050": + stackwise_num_blocks = ([2, 3, 4, 3, 3, 1],) + stackwise_expansion = ( + [ + [48, 96], + [96, 96, 96], + [96, 192, 192, 192], + [192, 288, 288], + [288, 480, 480], + [480], + ], + ) + stackwise_num_filters = ( + [ + [16, 16], + [16, 16, 16], + [32, 32, 32, 32], + [48, 48, 48], + [80, 80, 80], + [160], + ], + ) + stackwise_kernel_size = ( + [[3, 3], [3, 3, 3], [3, 3, 3, 3], [3, 3, 3], [3, 3, 3], [3]], + ) + stackwise_num_strides = ( + [[2, 1], [2, 1, 1], [2, 1, 1, 1], [1, 1, 1], [2, 1, 1], [1]], + ) + stackwise_se_ratio = ( + [ + [None, None], + [None, None, None], + [None, None, None, None], + [None, None, None], + [None, None, None], + [None], + ], + ) + stackwise_activation = ( + [ + ["relu6", "relu6"], + ["relu6", "relu6", "relu6"], + ["relu6", "relu6", "relu6", "relu6"], + ["relu6", "relu6", "relu6"], + ["relu6", "relu6", "relu6"], + ["relu6"], + ], + ) + output_num_filters = 1280 + input_num_filters = 16 + depthwise_filters = 8 + squeeze_and_excite = None + else: + raise ValueError( + f"Currently, the architecture {timm_architecture} is not supported." 
+ ) + + return dict( + input_num_filters=input_num_filters, + input_activation=input_activation, + depthwise_filters=depthwise_filters, + squeeze_and_excite=squeeze_and_excite, + stackwise_num_blocks=stackwise_num_blocks, + stackwise_expansion=stackwise_expansion, + stackwise_num_filters=stackwise_num_filters, + stackwise_kernel_size=stackwise_kernel_size, + stackwise_num_strides=stackwise_num_strides, + stackwise_se_ratio=stackwise_se_ratio, + stackwise_activation=stackwise_activation, + output_num_filters=output_num_filters, + output_activation=output_activation, + last_layer_filter=last_layer_filter, + ) + + +def convert_weights(backbone, loader, timm_config): + def port_conv2d(keras_layer_name, hf_weight_prefix): + loader.port_weight( + backbone.get_layer(keras_layer_name).kernel, + hf_weight_key=f"{hf_weight_prefix}.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + + def port_batch_normalization(keras_layer_name, hf_weight_prefix): + loader.port_weight( + backbone.get_layer(keras_layer_name).gamma, + hf_weight_key=f"{hf_weight_prefix}.weight", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).beta, + hf_weight_key=f"{hf_weight_prefix}.bias", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).moving_mean, + hf_weight_key=f"{hf_weight_prefix}.running_mean", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).moving_variance, + hf_weight_key=f"{hf_weight_prefix}.running_var", + ) + + version = "v3" if backbone.output_activation == "hard_swish" else "v2" + + # Stem + port_conv2d("input_conv", "conv_stem") + port_batch_normalization("input_batch_norm", "bn1") + + # DepthWise Block (block 0) + hf_name = "blocks.0.0" + keras_name = "blocks_0" + port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_dw") + port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1") + + port_conv2d(f"{keras_name}_se_conv_reduce", f"{hf_name}.se.conv_reduce") + port_conv2d(f"{keras_name}_se_conv_expand", f"{hf_name}.se.conv_expand") + + port_conv2d(f"{keras_name}_conv2", f"{hf_name}.conv_pw") + port_batch_normalization(f"{keras_name}_bn2", f"{hf_name}.bn2") + + # Stages + num_stacks = len(backbone.stackwise_num_blocks) + for block_idx in range(num_stacks): + for inverted_block in range(backbone.stackwise_num_blocks[block_idx]): + # if version == "v1": + # keras_name = f"stack{stack_index}_block{block_idx}" + # hf_name = f"layer{stack_index+1}.{block_idx}" + # else: + # keras_name = f"stack{stack_index}_block{block_idx}" + # hf_name = f"stages.{stack_index}.blocks.{block_idx}" + keras_name = f"block_{block_idx+1}_{inverted_block}" + hf_name = f"blocks.{block_idx+1}.{inverted_block}" + + # ConvBnAct Block + if block_idx == num_stacks - 1 and version == "v3": + port_conv2d(f"{keras_name}_conv", f"{hf_name}.conv") + port_batch_normalization(f"{keras_name}_bn", f"{hf_name}.bn1") + + # Inverted Residual Block + else: + port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_pw") + port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1") + port_conv2d(f"{keras_name}_conv2", f"{hf_name}.conv_dw") + port_batch_normalization(f"{keras_name}_bn2", f"{hf_name}.bn2") + + if backbone.stackwise_se_ratio[block_idx][inverted_block]: + port_conv2d( + f"{keras_name}_se_conv_reduce", + f"{hf_name}.se.conv_reduce", + ) + port_conv2d( + f"{keras_name}_se_conv_expand", + f"{hf_name}.se.conv_expand", + ) + + port_conv2d(f"{keras_name}_c onv3", f"{hf_name}.conv_pwl") + port_batch_normalization(f"{keras_name}_bn3", f"{hf_name}.bn3") + + if version == "v3": + hf_name = 
f"blocks.{num_stacks+1}.0" + keras_name = "Dfs" + port_conv2d("output_conv", "conv_head") + if version == "v2": + port_batch_normalization("output_batch_norm", "bn2") + + # if version == "v1": + # if block_idx == 0 and ( + # block_type == "bottleneck_block" or stack_index > 0 + # ): + # port_conv2d( + # f"{keras_name}_0_conv", f"{hf_name}.downsample.0" + # ) + # port_batch_normalization( + # f"{keras_name}_0_bn", f"{hf_name}.downsample.1" + # ) + # port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1") + # port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.bn1") + # port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2") + # port_batch_normalization(f"{keras_name}_2_bn", f"{hf_name}.bn2") + # if block_type == "bottleneck_block": + # port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3") + # port_batch_normalization( + # f"{keras_name}_3_bn", f"{hf_name}.bn3" + # ) + # else: + # if block_idx == 0 and ( + # block_type == "bottleneck_block" or stack_index > 0 + # ): + # port_conv2d( + # f"{keras_name}_0_conv", f"{hf_name}.downsample.conv" + # ) + # port_batch_normalization( + # f"{keras_name}_pre_activation_bn", f"{hf_name}.norm1" + # ) + # port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1") + # port_batch_normalization( + # f"{keras_name}_1_bn", f"{hf_name}.norm2" + # ) + # port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2") + # if block_type == "bottleneck_block": + # port_batch_normalization( + # f"{keras_name}_2_bn", f"{hf_name}.norm3" + # ) + # port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3") + + +def convert_head(task, loader, timm_config): + prefix = "classifier." + loader.port_weight( + task.output_dense.kernel, + hf_weight_key=prefix + "weight", + hook_fn=lambda x, _: np.transpose(np.squeeze(x)), + ) + loader.port_weight( + task.output_dense.bias, + hf_weight_key=prefix + "bias", + ) diff --git a/keras_hub/src/utils/timm/convert_mobilenet_test.py b/keras_hub/src/utils/timm/convert_mobilenet_test.py new file mode 100644 index 0000000000..8f876d261f --- /dev/null +++ b/keras_hub/src/utils/timm/convert_mobilenet_test.py @@ -0,0 +1,26 @@ +import pytest +from keras import ops + +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.image_classifier import ImageClassifier +from keras_hub.src.tests.test_case import TestCase + + +class TimmMobileNetBackboneTest(TestCase): + @pytest.mark.large + def test_convert_mobilenet_backbone(self): + model = Backbone.from_preset( + "hf://timm/mobilenetv3_small_050.lamb_in1k" + ) + outputs = model.predict(ops.ones((1, 224, 224, 3))) + self.assertEqual(outputs.shape, (1, 14, 14, 1024)) + + @pytest.mark.large + def test_convert_mobilenet_classifier(self): + model = ImageClassifier.from_preset( + "hf://timm/mobilenetv3_small_050.lamb_in1k" + ) + outputs = model.predict(ops.ones((1, 512, 512, 3))) + self.assertEqual(outputs.shape, (1, 1000)) + + # TODO: compare numerics with timm model diff --git a/keras_hub/src/utils/timm/preset_loader.py b/keras_hub/src/utils/timm/preset_loader.py index e5b72333e0..65149b042f 100644 --- a/keras_hub/src/utils/timm/preset_loader.py +++ b/keras_hub/src/utils/timm/preset_loader.py @@ -4,6 +4,7 @@ from keras_hub.src.utils.preset_utils import PresetLoader from keras_hub.src.utils.preset_utils import jax_memory_cleanup from keras_hub.src.utils.timm import convert_densenet +from keras_hub.src.utils.timm import convert_mobilenet from keras_hub.src.utils.timm import convert_resnet from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -16,6 +17,8 @@ def 
__init__(self, preset, config): self.converter = convert_resnet elif "densenet" in architecture: self.converter = convert_densenet + elif "mobilenet" in architecture: + self.converter = convert_mobilenet else: raise ValueError( "KerasHub has no converter for timm models " From 2e1e9c003591306bb5de5b67c10f766daec0055a Mon Sep 17 00:00:00 2001 From: ushareng Date: Tue, 8 Oct 2024 01:43:47 +0530 Subject: [PATCH 10/21] checkpoint conversion added --- keras_hub/api/layers/__init__.py | 3 + keras_hub/api/models/__init__.py | 3 + .../mobilenet/mobilenet_image_classifier.py | 5 +- ...mobilenet_image_classifier_preprocessor.py | 14 ++ .../mobilenet/mobilenet_image_converter.py | 8 + .../src/models/mobilenet/mobilenet_presets.py | 0 keras_hub/src/utils/timm/convert_mobilenet.py | 230 +++++++----------- .../src/utils/timm/convert_mobilenet_test.py | 2 +- .../convert_mobilenet_checkpoints.py | 112 +++++++++ 9 files changed, 235 insertions(+), 142 deletions(-) create mode 100644 keras_hub/src/models/mobilenet/mobilenet_image_classifier_preprocessor.py create mode 100644 keras_hub/src/models/mobilenet/mobilenet_image_converter.py create mode 100644 keras_hub/src/models/mobilenet/mobilenet_presets.py create mode 100644 tools/checkpoint_conversion/convert_mobilenet_checkpoints.py diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 0fe7b300fa..95d0d40919 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -40,6 +40,9 @@ from keras_hub.src.models.densenet.densenet_image_converter import ( DenseNetImageConverter, ) +from keras_hub.src.models.mobilenet.mobilenet_image_converter import ( + MobileNetImageConverter, +) from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( PaliGemmaImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 1450ddceb3..7f4f87d9cc 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -210,6 +210,9 @@ from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( MobileNetImageClassifier, ) +from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( + MobileNetImageClassifierPreprocessor, +) from keras_hub.src.models.opt.opt_backbone import OPTBackbone from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM from keras_hub.src.models.opt.opt_causal_lm_preprocessor import ( diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py index 96977bdf9f..bf07914781 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py @@ -1,8 +1,11 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone - +from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( + MobileNetImageClassifierPreprocessor, +) @keras_hub_export("keras_hub.models.MobileNetImageClassifier") class MobileNetImageClassifier(ImageClassifier): backbone_cls = MobileNetBackbone + preprocessor_cls = MobileNetImageClassifierPreprocessor diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_preprocessor.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_preprocessor.py new file mode 100644 index 0000000000..2ad3ef1ed7 --- /dev/null +++ 
b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_preprocessor.py @@ -0,0 +1,14 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) +from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone +from keras_hub.src.models.mobilenet.mobilenet_image_converter import ( + MobileNetImageConverter, +) + + +@keras_hub_export("keras_hub.models.MobileNetImageClassifierPreprocessor") +class MobileNetImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = MobileNetBackbone + image_converter_cls = MobileNetImageConverter diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_converter.py b/keras_hub/src/models/mobilenet/mobilenet_image_converter.py new file mode 100644 index 0000000000..da6fb0ab6a --- /dev/null +++ b/keras_hub/src/models/mobilenet/mobilenet_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone + + +@keras_hub_export("keras_hub.layers.MobileNetImageConverter") +class MobileNetImageConverter(ImageConverter): + backbone_cls = MobileNetBackbone diff --git a/keras_hub/src/models/mobilenet/mobilenet_presets.py b/keras_hub/src/models/mobilenet/mobilenet_presets.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/utils/timm/convert_mobilenet.py b/keras_hub/src/utils/timm/convert_mobilenet.py index e362fca7be..7b4cf4c8e1 100644 --- a/keras_hub/src/utils/timm/convert_mobilenet.py +++ b/keras_hub/src/utils/timm/convert_mobilenet.py @@ -27,80 +27,76 @@ def convert_backbone_config(timm_config): stackwise_num_filters = [[16, 16], [24, 24, 24], [24, 24], [48, 48, 48]] stackwise_kernel_size = [[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]] stackwise_num_strides = [[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]] - stackwise_se_ratio = ( - [ - [None, None], - [0.25, 0.25, 0.25], - [0.3, 0.3], - [0.3, 0.25, 0.25], - ], - ) - stackwise_activation = ( - [ - ["relu6", "relu6"], - ["hard_swish", "hard_swish", "hard_swish"], - ["hard_swish", "hard_swish"], - ["hard_swish", "hard_swish", "hard_swish"], - ], - ) + stackwise_se_ratio = [ + [None, None], + [0.25, 0.25, 0.25], + [0.3, 0.3], + [0.3, 0.25, 0.25], + ] + stackwise_activation = [ + ["relu", "relu"], + ["hard_swish", "hard_swish", "hard_swish"], + ["hard_swish", "hard_swish"], + ["hard_swish", "hard_swish", "hard_swish"], + ] output_num_filters = 1024 input_num_filters = 16 depthwise_filters = 8 squeeze_and_excite = 0.5 last_layer_filter = 288 - elif timm_architecture == "mobilenetv2_050": - stackwise_num_blocks = ([2, 3, 4, 3, 3, 1],) - stackwise_expansion = ( - [ - [48, 96], - [96, 96, 96], - [96, 192, 192, 192], - [192, 288, 288], - [288, 480, 480], - [480], - ], - ) - stackwise_num_filters = ( - [ - [16, 16], - [16, 16, 16], - [32, 32, 32, 32], - [48, 48, 48], - [80, 80, 80], - [160], - ], - ) - stackwise_kernel_size = ( - [[3, 3], [3, 3, 3], [3, 3, 3, 3], [3, 3, 3], [3, 3, 3], [3]], - ) - stackwise_num_strides = ( - [[2, 1], [2, 1, 1], [2, 1, 1, 1], [1, 1, 1], [2, 1, 1], [1]], - ) - stackwise_se_ratio = ( - [ - [None, None], - [None, None, None], - [None, None, None, None], - [None, None, None], - [None, None, None], - [None], - ], - ) - stackwise_activation = ( - [ - ["relu6", "relu6"], - ["relu6", "relu6", "relu6"], - ["relu6", "relu6", "relu6", "relu6"], - ["relu6", "relu6", "relu6"], - 
["relu6", "relu6", "relu6"], - ["relu6"], - ], - ) - output_num_filters = 1280 - input_num_filters = 16 - depthwise_filters = 8 - squeeze_and_excite = None + # elif timm_architecture == "mobilenetv2_050": + # stackwise_num_blocks = ([2, 3, 4, 3, 3, 1],) + # stackwise_expansion = ( + # [ + # [48, 96], + # [96, 96, 96], + # [96, 192, 192, 192], + # [192, 288, 288], + # [288, 480, 480], + # [480], + # ], + # ) + # stackwise_num_filters = ( + # [ + # [16, 16], + # [16, 16, 16], + # [32, 32, 32, 32], + # [48, 48, 48], + # [80, 80, 80], + # [160], + # ], + # ) + # stackwise_kernel_size = ( + # [[3, 3], [3, 3, 3], [3, 3, 3, 3], [3, 3, 3], [3, 3, 3], [3]], + # ) + # stackwise_num_strides = ( + # [[2, 1], [2, 1, 1], [2, 1, 1, 1], [1, 1, 1], [2, 1, 1], [1]], + # ) + # stackwise_se_ratio = ( + # [ + # [None, None], + # [None, None, None], + # [None, None, None, None], + # [None, None, None], + # [None, None, None], + # [None], + # ], + # ) + # stackwise_activation = ( + # [ + # ["relu6", "relu6"], + # ["relu6", "relu6", "relu6"], + # ["relu6", "relu6", "relu6", "relu6"], + # ["relu6", "relu6", "relu6"], + # ["relu6", "relu6", "relu6"], + # ["relu6"], + # ], + # ) + # output_num_filters = 1280 + # input_num_filters = 16 + # depthwise_filters = 8 + # squeeze_and_excite = None else: raise ValueError( f"Currently, the architecture {timm_architecture} is not supported." @@ -158,7 +154,7 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): # DepthWise Block (block 0) hf_name = "blocks.0.0" - keras_name = "blocks_0" + keras_name = "block_0_0" port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_dw") port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1") @@ -172,86 +168,40 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): num_stacks = len(backbone.stackwise_num_blocks) for block_idx in range(num_stacks): for inverted_block in range(backbone.stackwise_num_blocks[block_idx]): - # if version == "v1": - # keras_name = f"stack{stack_index}_block{block_idx}" - # hf_name = f"layer{stack_index+1}.{block_idx}" - # else: - # keras_name = f"stack{stack_index}_block{block_idx}" - # hf_name = f"stages.{stack_index}.blocks.{block_idx}" keras_name = f"block_{block_idx+1}_{inverted_block}" hf_name = f"blocks.{block_idx+1}.{inverted_block}" - # ConvBnAct Block - if block_idx == num_stacks - 1 and version == "v3": - port_conv2d(f"{keras_name}_conv", f"{hf_name}.conv") - port_batch_normalization(f"{keras_name}_bn", f"{hf_name}.bn1") - # Inverted Residual Block - else: - port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_pw") - port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1") - port_conv2d(f"{keras_name}_conv2", f"{hf_name}.conv_dw") - port_batch_normalization(f"{keras_name}_bn2", f"{hf_name}.bn2") - - if backbone.stackwise_se_ratio[block_idx][inverted_block]: - port_conv2d( - f"{keras_name}_se_conv_reduce", - f"{hf_name}.se.conv_reduce", - ) - port_conv2d( - f"{keras_name}_se_conv_expand", - f"{hf_name}.se.conv_expand", - ) - - port_conv2d(f"{keras_name}_c onv3", f"{hf_name}.conv_pwl") - port_batch_normalization(f"{keras_name}_bn3", f"{hf_name}.bn3") + port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_pw") + port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1") + port_conv2d(f"{keras_name}_conv2", f"{hf_name}.conv_dw") + port_batch_normalization(f"{keras_name}_bn2", f"{hf_name}.bn2") + + if backbone.stackwise_se_ratio[block_idx][inverted_block]: + port_conv2d( + f"{keras_name}_se_conv_reduce", + f"{hf_name}.se.conv_reduce", + ) + port_conv2d( + 
f"{keras_name}_se_conv_expand", + f"{hf_name}.se.conv_expand", + ) + + port_conv2d(f"{keras_name}_conv3", f"{hf_name}.conv_pwl") + port_batch_normalization(f"{keras_name}_bn3", f"{hf_name}.bn3") + + # ConvBnAct Block + port_conv2d(f"block_{num_stacks+1}_0_conv", f"blocks.{num_stacks+1}.0.conv") + port_batch_normalization( + f"block_{num_stacks+1}_0_bn", f"blocks.{num_stacks+1}.0.bn1" + ) if version == "v3": hf_name = f"blocks.{num_stacks+1}.0" keras_name = "Dfs" port_conv2d("output_conv", "conv_head") - if version == "v2": - port_batch_normalization("output_batch_norm", "bn2") - - # if version == "v1": - # if block_idx == 0 and ( - # block_type == "bottleneck_block" or stack_index > 0 - # ): - # port_conv2d( - # f"{keras_name}_0_conv", f"{hf_name}.downsample.0" - # ) - # port_batch_normalization( - # f"{keras_name}_0_bn", f"{hf_name}.downsample.1" - # ) - # port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1") - # port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.bn1") - # port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2") - # port_batch_normalization(f"{keras_name}_2_bn", f"{hf_name}.bn2") - # if block_type == "bottleneck_block": - # port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3") - # port_batch_normalization( - # f"{keras_name}_3_bn", f"{hf_name}.bn3" - # ) - # else: - # if block_idx == 0 and ( - # block_type == "bottleneck_block" or stack_index > 0 - # ): - # port_conv2d( - # f"{keras_name}_0_conv", f"{hf_name}.downsample.conv" - # ) - # port_batch_normalization( - # f"{keras_name}_pre_activation_bn", f"{hf_name}.norm1" - # ) - # port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1") - # port_batch_normalization( - # f"{keras_name}_1_bn", f"{hf_name}.norm2" - # ) - # port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2") - # if block_type == "bottleneck_block": - # port_batch_normalization( - # f"{keras_name}_2_bn", f"{hf_name}.norm3" - # ) - # port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3") + # if version == "v2": + # port_batch_normalization("output_batch_norm", "bn2") def convert_head(task, loader, timm_config): diff --git a/keras_hub/src/utils/timm/convert_mobilenet_test.py b/keras_hub/src/utils/timm/convert_mobilenet_test.py index 8f876d261f..4d036ae033 100644 --- a/keras_hub/src/utils/timm/convert_mobilenet_test.py +++ b/keras_hub/src/utils/timm/convert_mobilenet_test.py @@ -20,7 +20,7 @@ def test_convert_mobilenet_classifier(self): model = ImageClassifier.from_preset( "hf://timm/mobilenetv3_small_050.lamb_in1k" ) - outputs = model.predict(ops.ones((1, 512, 512, 3))) + outputs = model.predict(ops.ones((1, 224, 224, 3))) self.assertEqual(outputs.shape, (1, 1000)) # TODO: compare numerics with timm model diff --git a/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py b/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py new file mode 100644 index 0000000000..270d18eef9 --- /dev/null +++ b/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py @@ -0,0 +1,112 @@ +"""Convert mobilenet checkpoints. 
+ +python tools/checkpoint_conversion/convert_mobilenet_checkpoints.py \ + --preset mobilenetv3_small_050 --upload_uri kaggle://alexbutcher/mobilenet/keras/mobilenetv3_small_050 +""" + +import os +import shutil + +import keras +import numpy as np +import PIL +import timm +import torch +from absl import app +from absl import flags + +import keras_hub + +PRESET_MAP = { + "mobilenetv3_small_050": "timm/mobilenetv3_small_050.lamb_in1k", +} +FLAGS = flags.FLAGS + + +flags.DEFINE_string( + "preset", + None, + "Must be a valid `MobileNet` preset from KerasHub", + required=True, +) +flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/{variant}/keras/{preset}_int8"', + required=False, +) + + +def validate_output(keras_model, timm_model): + file = keras.utils.get_file( + origin=( + "https://storage.googleapis.com/keras-cv/" + "models/paligemma/cow_beach_1.png" + ) + ) + image = PIL.Image.open(file) + batch = np.array([image]) + + # Preprocess with Timm. + data_config = timm.data.resolve_model_data_config(timm_model) + data_config["crop_pct"] = 1.0 # Stop timm from cropping. + transforms = timm.data.create_transform(**data_config, is_training=False) + timm_preprocessed = transforms(image) + timm_preprocessed = keras.ops.transpose(timm_preprocessed, axes=(1, 2, 0)) + timm_preprocessed = keras.ops.expand_dims(timm_preprocessed, 0) + + # Preprocess with Keras. + keras_preprocessed = keras_model.preprocessor(batch) + + # Call with Timm. Use the keras preprocessed image so we can keep modeling + # and preprocessing comparisons independent. + timm_batch = keras.ops.transpose(keras_preprocessed, axes=(0, 3, 1, 2)) + timm_batch = torch.from_numpy(np.array(timm_batch)) + timm_outputs = timm_model(timm_batch).detach().numpy() + timm_label = np.argmax(timm_outputs[0]) + + # Call with Keras. 
+ keras_outputs = keras_model.predict(batch) + keras_label = np.argmax(keras_outputs[0]) + + print("🔶 Keras output:", keras_outputs[0, :10]) + print("🔶 TIMM output:", timm_outputs[0, :10]) + print("🔶 Keras label:", keras_label) + print("🔶 TIMM label:", timm_label) + modeling_diff = np.mean(np.abs(keras_outputs - timm_outputs)) + print("🔶 Modeling difference:", modeling_diff) + preprocessing_diff = np.mean(np.abs(keras_preprocessed - timm_preprocessed)) + print("🔶 Preprocessing difference:", preprocessing_diff) + + +def main(_): + preset = FLAGS.preset + if os.path.exists(preset): + shutil.rmtree(preset) + os.makedirs(preset) + + timm_name = PRESET_MAP[preset] + + print("✅ Loaded TIMM model.") + timm_model = timm.create_model(timm_name, pretrained=True) + timm_model = timm_model.eval() + + print("✅ Loaded KerasHub model.") + keras_model = keras_hub.models.ImageClassifier.from_preset( + "hf://" + timm_name, + ) + + keras_model.save_to_preset(f"./{preset}") + print(f"🏁 Preset saved to ./{preset}") + + validate_output(keras_model, timm_model) + + upload_uri = FLAGS.upload_uri + if upload_uri: + keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}") + print(f"🏁 Preset uploaded to {upload_uri}") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main) From c40088db3528578789bb2dabf7a301ee94cee38a Mon Sep 17 00:00:00 2001 From: ushareng Date: Tue, 8 Oct 2024 01:58:20 +0530 Subject: [PATCH 11/21] preset added --- .../src/models/mobilenet/mobilenet_presets.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/keras_hub/src/models/mobilenet/mobilenet_presets.py b/keras_hub/src/models/mobilenet/mobilenet_presets.py index e69de29bb2..e18364676f 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_presets.py +++ b/keras_hub/src/models/mobilenet/mobilenet_presets.py @@ -0,0 +1,15 @@ +"""MobileNet preset configurations.""" + +backbone_presets = { + "mobilenetv3_small_050": { + "metadata": { + "description": ( + "Small MObilenet V3 model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "official_name": "MobileNet", + "path": "mobilenet3", + }, + "kaggle_handle": "kaggle://alexbutcher/mobilenet3/keras/mobilenetv3_small_050", + }, +} From 318e7d6ba841387b414b2d09812200378ae1b98c Mon Sep 17 00:00:00 2001 From: ushareng Date: Tue, 8 Oct 2024 02:50:26 +0530 Subject: [PATCH 12/21] preset testcase added MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BytePairTokenizer must not split sequences of \n (#1910) * fix for loading of special tokens in Llama tokenizer * fix for Llama tokenizer which can have multiple end tokens * bug fix * adding some missing tokens to Llama3 tokenizer * fixed tests and Llama3Tokenizer init. * now loading correct eos_token config from Hugging Face checkpoint. Using hack for Keras checkpoint because it does not have this info * fix for BytePairTokenizer to make Lllama3-instruct work in chat: \n\n sequences are significant in the chat template and must be preserved by the tokenizer --------- Co-authored-by: Martin Görner fix for generation that never stops in Llama3-Instruct variants (#1904) * fix for loading of special tokens in Llama tokenizer * fix for Llama tokenizer which can have multiple end tokens * bug fix * adding some missing tokens to Llama3 tokenizer * fixed tests and Llama3Tokenizer init. * now loading correct eos_token config from Hugging Face checkpoint. 
Using hack for Keras checkpoint because it does not have this info --------- Co-authored-by: Martin Görner fix failing JAX GPU test (#1911) * fix tests * fix test Refactor `MMDiT`, add `ImageToImage` and `Inpaint` for SD3 (#1909) * Refactor `MMDiT` and add `ImageToImage` * Update model version * Fix minor bugs. * Add `Inpaint` for SD3. * Fix warnings of MMDiT. * Addcomment to Inpaint * Simplify `MMDiT` implementation and info of `summary()`. * Refactor `generate()` API of `TextToImage`, `ImageToImage` and `Inpaint`. Minor bug fix (#1915) Change to image_converter.image_size since it is a tuple and it's not a callable function. [Mix Transformer] Add Presets for MiTB0...MiTB5 (#1893) * add presets for mit * add standin paths * register presets in __init__.py * fix op in overlapping patching and embedding, start adding conversion utils * style * add padding to MiT patchingandembedding * update to support other presets * update conversin script * fix link for b5 * add cityscapes weights * update presets * update presets * update conversion script to make directories * use save_preset * change name of output dir * add preprocessor flow * api gen and add preprocessor to mits * conform to new image classifier style * format * resizing image converter -> ImageConverter * address comments refactoring remove default resizing for vision backbones (#1916) * remove defailt resizing * fix GPU test Update VGG model to be compatible with HF and add conversion scripts (#1914) Deeplab presets (#1918) * add preset configurations for deeplabv3 * fix uri * Add training details update presets to point to the main Keras Kaggle page (#1921) * update presets to point to the main keras page * update mit path Added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Lama chat templates (#1912) * added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Lama chat templates * un commented the test lines that were commented by mistake * fixed linter errors Task models fix (#1922) * added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Lama chat templates * fix for wrongly configured task models LLama, PaliGemma, Mistral and Phi3 + test * comments * un commented the test lines that were commented by mistake * fixed linter errors adding option strip_prompt to generate() (#1913) * added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Lama chat templates * un commented the test lines that were commented by mistake * fixed linter errors * added options strip_prompt to generate() * fix for tensorflow: the compiled version of generate(strip_prompt=True) now works + code refactoring to make it more understandable * added test for generate(strip_prompt=True) * minor edits Layout map for Llama (#1923) * added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Lama chat templates * un commented the test lines that were commented by mistake * fixed linter errors * added default layout map for Llama * minor fixes in tests Update deeplab_v3_presets.py (#1924) Add paths to get SAM weights from (#1925) Two fixes for image resizing in preprocessing (#1927) 1. Properly display when are not resizing the input image in `model.summary()` 2. Allow setting the `image_size` directly on a preprocessing layer. 2. is just to allow a more consistent way to set the input shape across tasks. 
We now have: ```python text_classifier = keras_hub.models.TextClassifer.from_preset( "bert_base_en", ) text_classifier.preprocessor.sequence_length = 256 image_classifier = keras_hub.models.TextClassifer.from_preset( "bert_base_en", ) image_classifier.preprocessor.image_size = (256, 256) multi_modal_lm = keras_hub.models.CausalLM.from_preset( "some_preset", ) multi_modal_lm.preprocessor.sequence_length = 256 multi_modal_lm.preprocessor.image_size = (256, 256) ``` add back default image resizing (#1926) Update deeplab_v3_presets.py (#1928) * Update deeplab_v3_presets.py * Update deeplab_v3_presets.py Update PaliGemma to remove `include_rescaling` arg (#1917) * update PaliGemma * update conversion script * fix GPU tests fix path (#1929) * fix path * nit Fix paligemma checkpoint conversion script (#1931) * add back default image resizing * fix bug in image converter * fix paligemma checkpoint conversion file * fix preset name * remove debug code * revert unintended changes update preset path to point to latest version of models (#1932) Update sdv3 path (#1934) update sam docstring to show correct backbone in docstring (#1936) Convert input dict to tensors during train_on_batch (#1919) Register VGG presets. (#1935) * register vgg preset * nit * nit * nit Add ResNetVD presets (#1897) * Add ResNetVD presets * Updated Kaggle handles * Add weight conversion script for ResNet_vd * Add usage rebase conflict resolved conflict resolve Update sam_presets.py (#1940) Update vit_det_backbone.py (#1941) fix gpu test (#1939) * fix gpu test * cast input * update dtype * change to resnet preset * remove arg Added Support for Returning Attention Scores in TransformerEncoder call (#1879) * Added: Return attention scores argument to transformer encoder * Added: docstring for return_attention_scores and added a test to chek the working of the argument * Fixed: Test case by removing print stmts and using self.assertAllEqual * Fixed: Linting Mark preset tests as large (#1942) * fix tests * fix test * Update preset_utils_test.py version bump to 0.17.0.dev0 (#1944) Update stable_diffusion_3_presets.py (#1946) [Semantic Segmentation] - Add SegFormer Architecture, Weight Conversion Script and Presets (#1883) * initial commit - tf-based, kcv * porting to keras_hub structure - removing aliases, presets, etc. 
* enable instantiation of segformer backbone with custom MiT backbone * remove num_classes from backbone * fix input * add imports to __init__ * update preset * update docstrings * add basic tests * remove redundant imports * update docstrings * remove unused import * running api_gen.py * undo refactor of mit * update docstrings * add presets for mit * add standin paths * add presets for segformer backbone * register presets in __init__.py * addressing comments * addressing comments * addressing comments * update most tests * add remaining tests * remove copyright * fix test * override from_config * fix op in overlapping patching and embedding, start adding conversion utils * style * add padding to MiT patchingandembedding * update to support other presets * update conversin script * fix link for b5 * add cityscapes weights * update presets * update presets * update conversion script to make directories * use save_preset * change name of output dir * add preprocessor flow * api gen and add preprocessor to mits * conform to new image classifier style * format * resizing image converter -> ImageConverter * merge mit branch into segformer branch * add preprocessor and converter * address comments * clarify backbone usage * add conversion script * numerical equivalence changes * fix numerical inaccuracies * update conversion script * update conversion script * remove transpose * add preprocessor to segformer class * fix preset path * update test shape * update presets * update test shape * expand docstrings * add rescaling and normalization to preprocessor * remove backbone presets, remove copyrights, remove backbone cls from segmenter * remove copyright and unused import * apply same transformation to masks as input images * fix import * fix shape in tests Update readme (#1949) * Update README.md * Update README.md Update llama_backbone.py docstring (#1950) Update path (#1953) Update preset path for keras.io. There is no LLaMA2 in keras.io https://keras.io/api/keras_hub/models/llama2 This is the actual link: https://keras.io/api/keras_hub/models/llama2 For Vicuna it does not have it's own model direcotry, since it is also the part of Llama,, updated the path. Update SD3 init parameters (replacing `height`, `width` with `image_shape`) (#1951) * Replace SD3 `height` and `width` with `image_shape` * Update URI * Revert comment * Update SD3 handle * Replace `height` and `width` with `image_shape` * Update docstrings * Fix CI Update docstring (#1954) AudioConverter is registered as "keras_hub.layers.WhisperAudioConverter" and not as part of models. 
updated Mobilenet backbone to match it with torch implementation timm script added checkpoint conversion added Refactoring --- README.md | 30 +- keras_hub/api/layers/__init__.py | 7 + keras_hub/api/models/__init__.py | 27 +- .../layers/modeling/transformer_encoder.py | 35 +- .../modeling/transformer_encoder_test.py | 11 + .../layers/preprocessing/image_converter.py | 3 +- .../preprocessing/image_converter_test.py | 27 +- keras_hub/src/models/causal_lm.py | 42 +- .../deeplab_v3/deeplab_v3_backbone_test.py | 1 + .../models/deeplab_v3/deeplab_v3_presets.py | 18 +- .../src/models/densenet/densenet_presets.py | 6 +- keras_hub/src/models/gemma/gemma_backbone.py | 14 +- .../src/models/gemma/gemma_backbone_test.py | 8 +- keras_hub/src/models/image_to_image.py | 411 ++++++ keras_hub/src/models/inpaint.py | 513 ++++++++ keras_hub/src/models/llama/llama_backbone.py | 120 +- .../src/models/llama/llama_backbone_test.py | 85 ++ keras_hub/src/models/llama/llama_causal_lm.py | 4 +- keras_hub/src/models/llama/llama_presets.py | 10 +- .../llama3_causal_lm_preprocessor_test.py | 2 + .../models/llama3/llama3_causal_lm_test.py | 10 +- .../src/models/llama3/llama3_tokenizer.py | 27 +- .../models/llama3/llama3_tokenizer_test.py | 2 + .../src/models/mistral/mistral_causal_lm.py | 4 +- keras_hub/src/models/mit/__init__.py | 6 + .../mit_backbone.py} | 23 +- keras_hub/src/models/mit/mit_backbone_test.py | 45 + .../mit_image_classifier.py} | 6 +- .../mit/mit_image_classifier_preprocessor.py | 12 + .../mit_image_classifier_test.py} | 14 +- .../src/models/mit/mit_image_converter.py | 8 + .../mit_layers.py} | 30 +- keras_hub/src/models/mit/mit_presets.py | 151 +++ .../mix_transformer_backbone_test.py | 12 +- .../src/models/mix_transformer/__init__.py | 0 .../models/mobilenet/mobilenet_backbone.py | 1143 +++++++++-------- .../mobilenet/mobilenet_backbone_test.py | 7 +- .../mobilenet/mobilenet_image_classifier.py | 1 + .../mobilenet_image_classifier_test.py | 17 +- .../models/pali_gemma/pali_gemma_backbone.py | 6 - .../models/pali_gemma/pali_gemma_causal_lm.py | 4 +- .../models/pali_gemma/pali_gemma_presets.py | 10 +- .../src/models/pali_gemma/pali_gemma_vit.py | 13 - .../models/pali_gemma/pali_gemma_vit_test.py | 17 - keras_hub/src/models/phi3/phi3_causal_lm.py | 4 +- keras_hub/src/models/preprocessor.py | 24 +- .../src/models/resnet/resnet_backbone.py | 3 +- keras_hub/src/models/resnet/resnet_presets.py | 153 ++- .../src/models/sam/sam_image_segmenter.py | 2 +- keras_hub/src/models/sam/sam_presets.py | 6 +- keras_hub/src/models/segformer/__init__.py | 8 + .../models/segformer/segformer_backbone.py | 163 +++ .../segformer/segformer_backbone_tests.py | 76 ++ .../segformer/segformer_image_converter.py | 8 + .../segformer/segformer_image_segmenter.py | 171 +++ .../segformer_image_segmenter_preprocessor.py | 31 + .../segformer_image_segmenter_tests.py | 65 + .../src/models/segformer/segformer_presets.py | 136 ++ .../src/models/stable_diffusion_3/mmdit.py | 485 ++++--- .../stable_diffusion_3_backbone.py | 170 ++- .../stable_diffusion_3_backbone_test.py | 9 +- .../stable_diffusion_3_image_to_image.py | 171 +++ .../stable_diffusion_3_image_to_image_test.py | 180 +++ .../stable_diffusion_3_inpaint.py | 194 +++ .../stable_diffusion_3_inpaint_test.py | 197 +++ .../stable_diffusion_3_presets.py | 4 +- .../stable_diffusion_3_text_to_image.py | 23 +- .../stable_diffusion_3_text_to_image_test.py | 29 +- keras_hub/src/models/task.py | 7 +- keras_hub/src/models/text_to_image.py | 125 +- keras_hub/src/models/vae/vae_backbone.py | 14 +- 
keras_hub/src/models/vgg/__init__.py | 5 + keras_hub/src/models/vgg/vgg_backbone.py | 5 +- keras_hub/src/models/vgg/vgg_backbone_test.py | 2 +- .../src/models/vgg/vgg_image_classifier.py | 55 +- .../vgg/vgg_image_classifier_preprocessor.py | 12 + .../models/vgg/vgg_image_classifier_test.py | 18 +- .../src/models/vgg/vgg_image_converter.py | 8 + keras_hub/src/models/vgg/vgg_presets.py | 56 + .../src/models/vit_det/vit_det_backbone.py | 4 +- .../models/whisper/whisper_audio_converter.py | 2 +- keras_hub/src/tests/test_case.py | 13 +- .../src/tokenizers/byte_pair_tokenizer.py | 6 +- .../tokenizers/byte_pair_tokenizer_test.py | 20 +- keras_hub/src/tokenizers/tokenizer.py | 14 +- keras_hub/src/utils/pipeline_model.py | 6 +- keras_hub/src/utils/preset_utils.py | 20 +- keras_hub/src/utils/preset_utils_test.py | 2 + keras_hub/src/utils/timm/convert_mobilenet.py | 74 +- .../src/utils/timm/convert_mobilenet_test.py | 2 +- keras_hub/src/utils/timm/convert_vgg.py | 85 ++ keras_hub/src/utils/timm/preset_loader.py | 3 + .../src/utils/transformers/convert_llama3.py | 26 +- keras_hub/src/version_utils.py | 2 +- .../convert_mix_transformer.py | 196 +++ .../convert_pali_gemma_checkpoints.py | 69 +- .../convert_resnet_vd_checkpoints.py | 317 +++++ .../convert_sam_checkpoints.py | 4 + .../convert_segformer_checkpoints.py | 143 +++ .../convert_stable_diffusion_3_checkpoints.py | 34 +- .../convert_vgg_checkpoints.py | 116 ++ 101 files changed, 5527 insertions(+), 1192 deletions(-) create mode 100644 keras_hub/src/models/image_to_image.py create mode 100644 keras_hub/src/models/inpaint.py create mode 100644 keras_hub/src/models/mit/__init__.py rename keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py => mit/mit_backbone.py} (87%) create mode 100644 keras_hub/src/models/mit/mit_backbone_test.py rename keras_hub/src/models/{mix_transformer/mix_transformer_classifier.py => mit/mit_image_classifier.py} (53%) create mode 100644 keras_hub/src/models/mit/mit_image_classifier_preprocessor.py rename keras_hub/src/models/{mix_transformer/mix_transformer_classifier_test.py => mit/mit_image_classifier_test.py} (78%) create mode 100644 keras_hub/src/models/mit/mit_image_converter.py rename keras_hub/src/models/{mix_transformer/mix_transformer_layers.py => mit/mit_layers.py} (92%) create mode 100644 keras_hub/src/models/mit/mit_presets.py rename keras_hub/src/models/{mix_transformer => mit}/mix_transformer_backbone_test.py (81%) delete mode 100644 keras_hub/src/models/mix_transformer/__init__.py create mode 100644 keras_hub/src/models/segformer/__init__.py create mode 100644 keras_hub/src/models/segformer/segformer_backbone.py create mode 100644 keras_hub/src/models/segformer/segformer_backbone_tests.py create mode 100644 keras_hub/src/models/segformer/segformer_image_converter.py create mode 100644 keras_hub/src/models/segformer/segformer_image_segmenter.py create mode 100644 keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py create mode 100644 keras_hub/src/models/segformer/segformer_image_segmenter_tests.py create mode 100644 keras_hub/src/models/segformer/segformer_presets.py create mode 100644 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py create mode 100644 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py create mode 100644 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py create mode 100644 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py create mode 100644 
keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py create mode 100644 keras_hub/src/models/vgg/vgg_image_converter.py create mode 100644 keras_hub/src/models/vgg/vgg_presets.py create mode 100644 keras_hub/src/utils/timm/convert_vgg.py create mode 100644 tools/checkpoint_conversion/convert_mix_transformer.py create mode 100644 tools/checkpoint_conversion/convert_resnet_vd_checkpoints.py create mode 100644 tools/checkpoint_conversion/convert_segformer_checkpoints.py create mode 100644 tools/checkpoint_conversion/convert_vgg_checkpoints.py diff --git a/README.md b/README.md index cdbce8b532..5c9157e43c 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](/~https://github.com/keras-team/keras-hub/issues) > [!IMPORTANT] -> 📢 KerasNLP is becoming KerasHub! 📢 Read +> 📢 KerasNLP is now KerasHub! 📢 Read > [the announcement](/~https://github.com/keras-team/keras-hub/issues/1831). > > We have renamed the repo to KerasHub in preparation for the release, but have not yet @@ -26,7 +26,7 @@ All models support JAX, TensorFlow, and PyTorch from a single model definition and can be fine-tuned on GPUs and TPUs out of the box. Models can be trained on individual accelerators with built-in PEFT techniques, or fine-tuned at scale with model and data parallel training. See our -[Getting Started guide](https://keras.io/guides/keras_nlp/getting_started) +[Getting Started guide](https://keras.io/guides/keras_hub/getting_started) to start learning our API. Browse our models on [Kaggle](https://www.kaggle.com/organizations/keras/models). We welcome contributions. @@ -35,9 +35,9 @@ We welcome contributions. ### For everyone -- [Home Page](https://keras.io/keras_nlp) -- [Developer Guides](https://keras.io/guides/keras_nlp) -- [API Reference](https://keras.io/api/keras_nlp) +- [Home Page](https://keras.io/keras_hub) +- [Developer Guides](https://keras.io/guides/keras_hub) +- [API Reference](https://keras.io/api/keras_hub) - [Pre-trained Models](https://www.kaggle.com/organizations/keras/models) ### For contributors @@ -56,7 +56,7 @@ Fine-tune a BERT classifier on IMDb movie reviews: import os os.environ["KERAS_BACKEND"] = "jax" # Or "tensorflow" or "torch"! -import keras_nlp +import keras_hub import tensorflow_datasets as tfds imdb_train, imdb_test = tfds.load( @@ -67,7 +67,7 @@ imdb_train, imdb_test = tfds.load( ) # Load a BERT model. -classifier = keras_nlp.models.Classifier.from_preset( +classifier = keras_hub.models.Classifier.from_preset( "bert_base_en", num_classes=2, activation="softmax", @@ -79,25 +79,17 @@ classifier.fit(imdb_train, validation_data=imdb_test) classifier.predict(["What an amazing movie!", "A total waste of my time."]) ``` -Try it out [in a colab](https://colab.research.google.com/gist/mattdangerw/e457e42d5ea827110c8d5cb4eb9d9a07/kerasnlp-quickstart.ipynb). +Try it out [in a colab](https://colab.research.google.com/drive/1gSWkh3yOLwmKAaNh2dQQ6kQIlnGte7P2?usp=sharing). For more in depth guides and examples, visit -[keras.io/keras_nlp](https://keras.io/keras_nlp/). +[keras.io/keras_hub](https://keras.io/keras_hub/). ## Installation -KerasHub is currently in pre-release. Note that pre-release versions may -introduce breaking changes to the API in future versions. 
For a stable and -supported experience, we recommend installing `keras-nlp` version 0.15.1: - -```bash -pip install keras-nlp==0.15.1 -``` - -To try out the latest pre-release version of KerasHub, you can use +To try out the latest version of KerasHub, you can use our nightly package: ```bash -pip install keras-hub-nightly +pip install keras-hub ``` KerasHub currently requires TensorFlow to be installed for use of the diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 95d0d40919..53e0074414 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -40,6 +40,9 @@ from keras_hub.src.models.densenet.densenet_image_converter import ( DenseNetImageConverter, ) +from keras_hub.src.models.mix_transformer.mix_transformer_image_converter import ( + MiTImageConverter, +) from keras_hub.src.models.mobilenet.mobilenet_image_converter import ( MobileNetImageConverter, ) @@ -52,6 +55,10 @@ from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder +from keras_hub.src.models.segformer.segformer_image_converter import ( + SegFormerImageConverter, +) +from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter from keras_hub.src.models.whisper.whisper_audio_converter import ( WhisperAudioConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 7f4f87d9cc..88aa733c78 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -180,6 +180,8 @@ from keras_hub.src.models.image_segmenter_preprocessor import ( ImageSegmenterPreprocessor, ) +from keras_hub.src.models.image_to_image import ImageToImage +from keras_hub.src.models.inpaint import Inpaint from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import ( @@ -200,11 +202,10 @@ MistralCausalLMPreprocessor, ) from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer -from keras_hub.src.models.mix_transformer.mix_transformer_backbone import ( - MiTBackbone, -) -from keras_hub.src.models.mix_transformer.mix_transformer_classifier import ( - MiTImageClassifier, +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier +from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( + MiTImageClassifierPreprocessor, ) from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( @@ -268,11 +269,24 @@ from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( SAMImageSegmenterPreprocessor, ) +from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone +from keras_hub.src.models.segformer.segformer_image_segmenter import ( + SegFormerImageSegmenter, +) +from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( + SegFormerImageSegmenterPreprocessor, +) from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( StableDiffusion3Backbone, ) +from 
keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import ( + StableDiffusion3ImageToImage, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import ( + StableDiffusion3Inpaint, +) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import ( StableDiffusion3TextToImage, ) @@ -291,6 +305,9 @@ from keras_hub.src.models.text_to_image import TextToImage from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier +from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( + VGGImageClassifierPreprocessor, +) from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer diff --git a/keras_hub/src/layers/modeling/transformer_encoder.py b/keras_hub/src/layers/modeling/transformer_encoder.py index 8d3fb0f950..5ed121e457 100644 --- a/keras_hub/src/layers/modeling/transformer_encoder.py +++ b/keras_hub/src/layers/modeling/transformer_encoder.py @@ -170,7 +170,12 @@ def build(self, inputs_shape): self.built = True def call( - self, inputs, padding_mask=None, attention_mask=None, training=None + self, + inputs, + padding_mask=None, + attention_mask=None, + training=None, + return_attention_scores=False, ): """Forward pass of the TransformerEncoder. @@ -185,6 +190,7 @@ def call( [batch_size, sequence_length, sequence_length]. training: a boolean indicating whether the layer should behave in training mode or in inference mode. + return_attention_scores: a boolean indicating whether the output should be `(attention_output, attention_scores)` if `True` or `attention_output` if `False`. Defaults to `False`. Returns: A Tensor of the same shape as the `inputs`. 
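For readers skimming the diff, a minimal usage sketch of the new `return_attention_scores` flag (a hedged example mirroring the test added further down; it assumes 2 heads and a sequence length of 4, so the scores have shape `(batch, num_heads, seq_len, seq_len)`):

```python
from keras import random

from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder

encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)
inputs = random.uniform(shape=[1, 4, 6])

# With the flag set, call() returns an (outputs, attention_scores) tuple
# instead of only the encoded outputs.
outputs, attention_scores = encoder(inputs, return_attention_scores=True)
print(outputs.shape)           # (1, 4, 6)
print(attention_scores.shape)  # (1, 2, 4, 4)
```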
@@ -200,12 +206,24 @@ def call( residual = x if self.normalize_first: x = self._self_attention_layer_norm(x) - x = self._self_attention_layer( - query=x, - value=x, - attention_mask=self_attention_mask, - training=training, - ) + + if return_attention_scores: + x, attention_scores = self._self_attention_layer( + query=x, + value=x, + attention_mask=self_attention_mask, + return_attention_scores=return_attention_scores, + training=training, + ) + return x, attention_scores + else: + x = self._self_attention_layer( + query=x, + value=x, + attention_mask=self_attention_mask, + training=training, + ) + x = self._self_attention_dropout(x, training=training) x = x + residual if not self.normalize_first: @@ -222,6 +240,9 @@ def call( if not self.normalize_first: x = self._feedforward_layer_norm(x) + if return_attention_scores: + return x, attention_scores + return x def get_config(self): diff --git a/keras_hub/src/layers/modeling/transformer_encoder_test.py b/keras_hub/src/layers/modeling/transformer_encoder_test.py index c4763d3763..0f12a0920b 100644 --- a/keras_hub/src/layers/modeling/transformer_encoder_test.py +++ b/keras_hub/src/layers/modeling/transformer_encoder_test.py @@ -95,3 +95,14 @@ def test_mask_propagation(self): inputs._keras_mask = mask outputs = encoder(inputs) self.assertAllEqual(outputs._keras_mask, mask) + + def test_attention_scores(self): + encoder = TransformerEncoder(intermediate_dim=4, num_heads=2) + inputs = random.uniform(shape=[1, 4, 6]) + outputs, attention_scores = encoder( + inputs, return_attention_scores=True + ) + self.assertAllEqual(outputs.shape, inputs.shape) + + # attention scores shape (batch_size, num_of_attn_heads, seq_length, seq_length) + self.assertAllEqual(attention_scores.shape, [1, 2, 4, 4]) diff --git a/keras_hub/src/layers/preprocessing/image_converter.py b/keras_hub/src/layers/preprocessing/image_converter.py index e3b55bbde0..89142c469b 100644 --- a/keras_hub/src/layers/preprocessing/image_converter.py +++ b/keras_hub/src/layers/preprocessing/image_converter.py @@ -145,8 +145,9 @@ def image_size(self, value): @preprocessing_function def call(self, inputs): + x = inputs if self.image_size is not None: - x = self.resizing(inputs) + x = self.resizing(x) if self.scale is not None: x = x * self._expand_non_channel_dims(self.scale, x) if self.offset is not None: diff --git a/keras_hub/src/layers/preprocessing/image_converter_test.py b/keras_hub/src/layers/preprocessing/image_converter_test.py index 5e0fd940c2..d638ccf9ab 100644 --- a/keras_hub/src/layers/preprocessing/image_converter_test.py +++ b/keras_hub/src/layers/preprocessing/image_converter_test.py @@ -6,12 +6,10 @@ from keras import ops from keras_hub.src.layers.preprocessing.image_converter import ImageConverter -from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( - PaliGemmaBackbone, -) from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( PaliGemmaImageConverter, ) +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.tests.test_case import TestCase @@ -86,24 +84,19 @@ def test_from_preset_errors(self): def test_save_to_preset(self): save_dir = self.get_temp_dir() converter = ImageConverter.from_preset( - "pali_gemma_3b_mix_224", + "resnet_50_imagenet", interpolation="nearest", ) converter.save_to_preset(save_dir) # Save a tiny backbone so the preset is valid. 
- backbone = PaliGemmaBackbone( - vocabulary_size=100, - image_size=224, - num_layers=1, - num_query_heads=1, - num_key_value_heads=1, - hidden_dim=8, - intermediate_dim=16, - head_dim=8, - vit_patch_size=14, - vit_num_heads=1, - vit_hidden_dim=8, - vit_num_layers=1, + backbone = ResNetBackbone( + input_conv_filters=[64], + input_conv_kernel_sizes=[7], + stackwise_num_filters=[64, 64, 64], + stackwise_num_blocks=[2, 2, 2], + stackwise_num_strides=[1, 2, 2], + block_type="basic_block", + use_pre_activation=True, ) backbone.save_to_preset(save_dir) diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index c86bd7be9f..2514022c4d 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -274,6 +274,7 @@ def generate( inputs, max_length=None, stop_token_ids="auto", + strip_prompt=False, ): """Generate text given prompt `inputs`. @@ -309,6 +310,9 @@ def generate( specify a list of token id's the model should stop on. Note that sequences of tokens will each be interpreted as a stop token, multi-token stop sequences are not supported. + strip_prompt: Optional. By default, generate() returns the full prompt + followed by its completion generated by the model. If this option + is set to True, only the newly generated text is returned. """ # Setup our three main passes. # 1. Optionally preprocessing strings to dense integer tensors. @@ -326,6 +330,10 @@ def generate( ) elif stop_token_ids == "auto": stop_token_ids = [self.preprocessor.tokenizer.end_token_id] + # Some models like Llama3 use two end tokens: <|eot_id|> in + # "instruct" versions and <|end_of_text|> in others. + if hasattr(self.preprocessor.tokenizer, "end_token2_id"): + stop_token_ids.append(self.preprocessor.tokenizer.end_token2_id) def preprocess(x): return self.preprocessor.generate_preprocess( @@ -335,6 +343,33 @@ def preprocess(x): def generate(x): return generate_function(x, stop_token_ids=stop_token_ids) + def strip_prompt_function(x, prompt): + # This function removes the prompt from the generated + # response, in a batch-friendly fashion. + y = {} + prompt_mask = prompt["padding_mask"] + seq_len = prompt_mask.shape[1] + + # We need to shift every output sequence by the size of the prompt. + shifts = -ops.sum(ops.cast(prompt_mask, "int"), axis=1) % seq_len + ix = ops.arange(seq_len, dtype="int") + ix = ops.expand_dims(ix, axis=0) - ops.expand_dims(shifts, axis=1) + + # This produces the desired shift (in fact a rollover). + def roll_sequence(seq): + return ops.take_along_axis(seq, ix, axis=1) + + # The shifting rolls the content over so the prompt is at the end of + # the sequence and the generated text is at the beginning. We mask + # it to retain the generated text only. 
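+            # For example, with seq_len=6 and a 2-token prompt,
+            # prompt_mask is [1, 1, 0, 0, 0, 0]; after generating two tokens
+            # the output mask is [1, 1, 1, 1, 0, 0]. Rolling both left by 2
+            # gives [0, 0, 0, 0, 1, 1] and [1, 1, 0, 0, 1, 1], and their XOR,
+            # [1, 1, 0, 0, 0, 0], keeps exactly the generated tokens, which
+            # now sit at the start of the rolled sequence.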
+ y["padding_mask"] = ops.logical_xor( + roll_sequence(prompt_mask), roll_sequence(x["padding_mask"]) + ) + # we assume the mask is enough and there is no need to zero-out the values + y["token_ids"] = roll_sequence(x["token_ids"]) + + return y + def postprocess(x): return self.preprocessor.generate_postprocess(x) @@ -343,7 +378,12 @@ def postprocess(x): if self.preprocessor is not None: inputs = [preprocess(x) for x in inputs] - outputs = [generate(x) for x in inputs] + + if strip_prompt: + outputs = [strip_prompt_function(generate(x), x) for x in inputs] + else: + outputs = [generate(x) for x in inputs] + if self.preprocessor is not None: outputs = [postprocess(x) for x in outputs] diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py index a7b1809085..a0a3a8d1df 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py @@ -51,6 +51,7 @@ def test_saved_model(self): cls=DeepLabV3Backbone, init_kwargs=self.init_kwargs, input_data=self.input_data, + atol=0.00001, ) diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py index 1b1dde181d..85cd186830 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py @@ -1,4 +1,18 @@ """DeepLabV3 preset configurations.""" -# TODO /~https://github.com/keras-team/keras-hub/issues/1896, -backbone_presets = {} +backbone_presets = { + "deeplab_v3_plus_resnet50_pascalvoc": { + "metadata": { + "description": ( + "DeepLabV3+ model with ResNet50 as image encoder and trained on " + "augmented Pascal VOC dataset by Semantic Boundaries Dataset(SBD)" + "which is having categorical accuracy of 90.01 and 0.63 Mean IoU." 
+ ), + "params": 39190656, + "official_name": "DeepLabV3", + "path": "deeplab_v3", + "model_card": "https://arxiv.org/abs/1802.02611", + }, + "kaggle_handle": "kaggle://keras/deeplabv3plus/keras/deeplab_v3_plus_resnet50_pascalvoc/3", + }, +} diff --git a/keras_hub/src/models/densenet/densenet_presets.py b/keras_hub/src/models/densenet/densenet_presets.py index 2c3ef77842..99702bf86f 100644 --- a/keras_hub/src/models/densenet/densenet_presets.py +++ b/keras_hub/src/models/densenet/densenet_presets.py @@ -12,7 +12,7 @@ "path": "densenet", "model_card": "https://arxiv.org/abs/1608.06993", }, - "kaggle_handle": "kaggle://kerashub/densenet/keras/densenet_121_imagenet", + "kaggle_handle": "kaggle://keras/densenet/keras/densenet_121_imagenet/2", }, "densenet_169_imagenet": { "metadata": { @@ -25,7 +25,7 @@ "path": "densenet", "model_card": "https://arxiv.org/abs/1608.06993", }, - "kaggle_handle": "kaggle://kerashub/densenet/keras/densenet_169_imagenet", + "kaggle_handle": "kaggle://keras/densenet/keras/densenet_169_imagenet/2", }, "densenet_201_imagenet": { "metadata": { @@ -38,6 +38,6 @@ "path": "densenet", "model_card": "https://arxiv.org/abs/1608.06993", }, - "kaggle_handle": "kaggle://kerashub/densenet/keras/densenet_201_imagenet", + "kaggle_handle": "kaggle://keras/densenet/keras/densenet_201_imagenet/2", }, } diff --git a/keras_hub/src/models/gemma/gemma_backbone.py b/keras_hub/src/models/gemma/gemma_backbone.py index c34547b83e..1d6482b96b 100644 --- a/keras_hub/src/models/gemma/gemma_backbone.py +++ b/keras_hub/src/models/gemma/gemma_backbone.py @@ -224,7 +224,7 @@ def get_layout_map( Example: ``` - # Feel free to change the mesh shape to balance data and model parallel + # Feel free to change the mesh shape to balance data and model parallelism mesh = keras.distribution.DeviceMesh( shape=(1, 8), axis_names=('batch', 'model'), devices=keras.distribution.list_devices()) @@ -232,11 +232,19 @@ def get_layout_map( mesh, model_parallel_dim_name="model") distribution = keras.distribution.ModelParallel( - mesh, layout_map, batch_dim_name='batch') + layout_map=layout_map, batch_dim_name='batch') with distribution.scope(): gemma_model = keras_hub.models.GemmaCausalLM.from_preset() ``` + To see how the layout map was applied, load the model then run (for one decoder block): + ``` + embedding_layer = gemma_model.backbone.get_layer("token_embedding") + decoder_block_1 = gemma_model.backbone.get_layer('decoder_block_1') + for variable in embedding_layer.weights + decoder_block_1.weights: + print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}') + ``` + Args: device_mesh: The `keras.distribution.DeviceMesh` instance for distribution. @@ -246,7 +254,7 @@ def get_layout_map( the data should be partition on. Return: `keras.distribution.LayoutMap` that contains the sharding spec - of all the model weights. + for all the model weights. 
""" # The weight path and shape of the Gemma backbone is like below (for 2G) # token_embedding/embeddings, (256128, 2048), 524550144 diff --git a/keras_hub/src/models/gemma/gemma_backbone_test.py b/keras_hub/src/models/gemma/gemma_backbone_test.py index bbd383e687..b5f8575332 100644 --- a/keras_hub/src/models/gemma/gemma_backbone_test.py +++ b/keras_hub/src/models/gemma/gemma_backbone_test.py @@ -74,11 +74,10 @@ def test_architecture_characteristics(self): def test_distribution(self): if keras.backend.backend() != "jax": - return + self.skipTest("`ModelParallel` testing requires the Jax backend.") devices = keras.distribution.list_devices("CPU") if len(devices) == 1: - # Need more than 1 device for distribution testing. - return + self.skipTest("`ModelParallel` testing requires multiple devices.") device_mesh = keras.distribution.DeviceMesh( shape=(1, len(devices)), axis_names=("batch", "model"), @@ -86,7 +85,7 @@ def test_distribution(self): ) layout_map = GemmaBackbone.get_layout_map(device_mesh) - distribution = keras.distribution.ModelParallel(device_mesh, layout_map) + distribution = keras.distribution.ModelParallel(layout_map=layout_map) with distribution.scope(): model = GemmaBackbone(**self.init_kwargs) @@ -129,7 +128,6 @@ def test_distribution_with_lora(self): self.skipTest("`ModelParallel` testing requires the Jax backend.") devices = keras.distribution.list_devices("CPU") if len(devices) == 1: - # Need more than 1 device for distribution testing. self.skipTest("`ModelParallel` testing requires multiple devices.") device_mesh = keras.distribution.DeviceMesh( shape=(1, len(devices)), diff --git a/keras_hub/src/models/image_to_image.py b/keras_hub/src/models/image_to_image.py new file mode 100644 index 0000000000..d3194a5815 --- /dev/null +++ b/keras_hub/src/models/image_to_image.py @@ -0,0 +1,411 @@ +import itertools +from functools import partial + +import keras +from keras import ops +from keras import random + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.task import Task +from keras_hub.src.utils.keras_utils import standardize_data_format + +try: + import tensorflow as tf +except ImportError: + tf = None + + +@keras_hub_export("keras_hub.models.ImageToImage") +class ImageToImage(Task): + """Base class for image-to-image tasks. + + `ImageToImage` tasks wrap a `keras_hub.models.Backbone` and + a `keras_hub.models.Preprocessor` to create a model that can be used for + generation and generative fine-tuning. + + `ImageToImage` tasks provide an additional, high-level `generate()` function + which can be used to generate image by token with a (image, string) in, + image out signature. + + All `ImageToImage` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + + Example: + + ```python + # Load a Stable Diffusion 3 backbone with pre-trained weights. + reference_image = np.ones((1024, 1024, 3), dtype="float32") + image_to_image = keras_hub.models.ImageToImage.from_preset( + "stable_diffusion_3_medium", + ) + image_to_image.generate( + reference_image, + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + ) + + # Load a Stable Diffusion 3 backbone at bfloat16 precision. 
+ image_to_image = keras_hub.models.ImageToImage.from_preset( + "stable_diffusion_3_medium", + dtype="bfloat16", + ) + image_to_image.generate( + reference_image, + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + ) + ``` + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + @property + def support_negative_prompts(self): + """Whether the model supports `negative_prompts` key in `generate()`.""" + return bool(True) + + @property + def image_shape(self): + return tuple(self.backbone.image_shape) + + @property + def latent_shape(self): + return tuple(self.backbone.latent_shape) + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `ImageToImage` task for training. + + The `ImageToImage` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.MeanSquaredError` loss will be applied. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.MeanSquaredError` will be applied to + track the loss of the model during training. See + `keras.Model.compile` and `keras.metrics` for more info on + possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. 
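+
+        Example (a hypothetical override of the `"auto"` defaults):
+        ```python
+        image_to_image.compile(
+            optimizer=keras.optimizers.AdamW(5e-5),
+            loss=keras.losses.MeanSquaredError(),
+        )
+        ```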
+ """ + # Ref: /~https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414 + if optimizer == "auto": + optimizer = keras.optimizers.AdamW( + 1e-4, weight_decay=1e-2, epsilon=1e-8, clipnorm=1.0 + ) + if loss == "auto": + loss = keras.losses.MeanSquaredError() + if metrics == "auto": + metrics = [keras.metrics.MeanSquaredError()] + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) + self.generate_function = None + + def generate_step(self, *args, **kwargs): + """Run generation on batches of input.""" + raise NotImplementedError + + def make_generate_function(self): + """Create or return the compiled generation function.""" + if self.generate_function is not None: + return self.generate_function + + self.generate_function = self.generate_step + if keras.config.backend() == "torch": + import torch + + def wrapped_function(*args, **kwargs): + with torch.no_grad(): + return self.generate_step(*args, **kwargs) + + self.generate_function = wrapped_function + elif keras.config.backend() == "tensorflow" and not self.run_eagerly: + self.generate_function = tf.function( + self.generate_step, jit_compile=self.jit_compile + ) + elif keras.config.backend() == "jax" and not self.run_eagerly: + import jax + + @partial(jax.jit) + def compiled_function(state, *args, **kwargs): + ( + trainable_variables, + non_trainable_variables, + ) = state + mapping = itertools.chain( + zip(self.trainable_variables, trainable_variables), + zip(self.non_trainable_variables, non_trainable_variables), + ) + + with keras.StatelessScope(state_mapping=mapping): + outputs = self.generate_step(*args, **kwargs) + return outputs + + def wrapped_function(*args, **kwargs): + # Create an explicit tuple of all variable state. + state = ( + # Use the explicit variable.value to preserve the + # sharding spec of distribution. + [v.value for v in self.trainable_variables], + [v.value for v in self.non_trainable_variables], + ) + outputs = compiled_function(state, *args, **kwargs) + return outputs + + self.generate_function = wrapped_function + return self.generate_function + + def _normalize_generate_inputs(self, inputs): + """Normalize user input to the generate function. + + This function converts all inputs to tensors, adds a batch dimension if + necessary, and returns a iterable "dataset like" object (either an + actual `tf.data.Dataset` or a list with a single batch element). + + The input format must be one of the following: + - A dict with "images", "prompts" and/or "negative_prompts" keys + - A tf.data.Dataset with "images", "prompts" and/or "negative_prompts" + keys + + The output will be a dict with "images", "prompts" and/or + "negative_prompts" keys. + """ + if tf and isinstance(inputs, tf.data.Dataset): + _inputs = { + "images": inputs.map(lambda x: x["images"]).as_numpy_iterator(), + "prompts": inputs.map( + lambda x: x["prompts"] + ).as_numpy_iterator(), + } + if self.support_negative_prompts: + _inputs["negative_prompts"] = inputs.map( + lambda x: x["negative_prompts"] + ).as_numpy_iterator() + return _inputs, False + + if ( + not isinstance(inputs, dict) + or "images" not in inputs + or "prompts" not in inputs + ): + raise ValueError( + '`inputs` must be a dict with "images" and "prompts" keys or a' + f"tf.data.Dataset. 
Received: inputs={inputs}" + ) + + def normalize(x): + if isinstance(x, str): + return [x], True + if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0: + return x[tf.newaxis], True + return x, False + + def normalize_images(x): + data_format = getattr( + self.backbone, "data_format", standardize_data_format(None) + ) + input_is_scalar = False + x = ops.convert_to_tensor(x) + if len(ops.shape(x)) < 4: + x = ops.expand_dims(x, axis=0) + input_is_scalar = True + x = ops.image.resize( + x, + (self.backbone.image_shape[0], self.backbone.image_shape[1]), + interpolation="nearest", + data_format=data_format, + ) + return x, input_is_scalar + + def get_dummy_prompts(x): + dummy_prompts = [""] * len(x) + if tf and isinstance(x, tf.Tensor): + return tf.convert_to_tensor(dummy_prompts) + else: + return dummy_prompts + + for key in inputs: + if key == "images": + inputs[key], input_is_scalar = normalize_images(inputs[key]) + else: + inputs[key], input_is_scalar = normalize(inputs[key]) + + if self.support_negative_prompts and "negative_prompts" not in inputs: + inputs["negative_prompts"] = get_dummy_prompts(inputs["prompts"]) + + return [inputs], input_is_scalar + + def _normalize_generate_outputs(self, outputs, input_is_scalar): + """Normalize user output from the generate function. + + This function converts all output to numpy with a value range of + `[0, 255]`. If a batch dimension was added to the input, it is removed + from the output. + """ + + def normalize(x): + outputs = ops.concatenate(x, axis=0) + outputs = ops.clip(ops.divide(ops.add(outputs, 1.0), 2.0), 0.0, 1.0) + outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8") + outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs + return ops.convert_to_numpy(outputs) + + if isinstance(outputs[0], dict): + normalized = {} + for key in outputs[0]: + normalized[key] = normalize([x[key] for x in outputs]) + return normalized + return normalize([x for x in outputs]) + + def generate( + self, + inputs, + num_steps, + guidance_scale, + strength, + seed=None, + ): + """Generate image based on the provided `inputs`. + + Typically, `inputs` is a dict with `"images"` and `"prompts"` keys. + `"images"` are reference images within a value range of + `[-1.0, 1.0]`, which will be resized to `self.backbone.height` and + `self.backbone.width`, then encoded into latent space by the VAE + encoder. `"prompts"` are strings that will be tokenized and encoded by + the text encoder. + + Some models support a `"negative_prompts"` key, which helps steer the + model away from generating certain styles and elements. To enable this, + add `"negative_prompts"` to the input dict. + + If `inputs` are a `tf.data.Dataset`, outputs will be generated + "batch-by-batch" and concatenated. Otherwise, all inputs will be + processed as batches. + + Args: + inputs: python data, tensor data, or a `tf.data.Dataset`. The format + must be one of the following: + - A dict with `"images"`, `"prompts"` and/or + `"negative_prompts"` keys. + - A `tf.data.Dataset` with `"images"`, `"prompts"` and/or + `"negative_prompts"` keys. + num_steps: int. The number of diffusion steps to take. + guidance_scale: float. The classifier free guidance scale defined in + [Classifier-Free Diffusion Guidance]( + https://arxiv.org/abs/2207.12598). A higher scale encourages + generating images more closely related to the prompts, typically + at the cost of lower image quality. + strength: float. Indicates the extent to which the reference + `images` are transformed. 
Must be between `0.0` and `1.0`. When + `strength=1.0`, `images` is essentially ignore and added noise + is maximum and the denoising process runs for the full number of + iterations specified in `num_steps`. + seed: optional int. Used as a random seed. + """ + num_steps = int(num_steps) + guidance_scale = float(guidance_scale) + strength = float(strength) + if strength < 0.0 or strength > 1.0: + raise ValueError( + "`strength` must be between `0.0` and `1.0`. " + f"Received strength={strength}." + ) + starting_step = int(num_steps * (1.0 - strength)) + starting_step = ops.convert_to_tensor(starting_step, "int32") + num_steps = ops.convert_to_tensor(num_steps, "int32") + guidance_scale = ops.convert_to_tensor(guidance_scale) + + # Check `inputs` format. + required_keys = ["images", "prompts"] + if tf and isinstance(inputs, tf.data.Dataset): + spec = inputs.element_spec + if not all(key in spec for key in required_keys): + raise ValueError( + "Expected a `tf.data.Dataset` with the following keys:" + f"{required_keys}. Received: inputs.element_spec={spec}" + ) + else: + if not isinstance(inputs, dict): + raise ValueError( + "Expected a `dict` or `tf.data.Dataset`. " + f"Received: inputs={inputs} of type {type(inputs)}." + ) + if not all(key in inputs for key in required_keys): + raise ValueError( + "Expected a `dict` with the following keys:" + f"{required_keys}. " + f"Received: inputs.keys={list(inputs.keys())}" + ) + + # Setup our three main passes. + # 1. Preprocessing strings to dense integer tensors. + # 2. Generate outputs via a compiled function on dense tensors. + # 3. Postprocess dense tensors to a value range of `[0, 255]`. + generate_function = self.make_generate_function() + + def preprocess(x): + if self.preprocessor is not None: + return self.preprocessor.generate_preprocess(x) + else: + return x + + def generate(images, x): + token_ids = x[0] if self.support_negative_prompts else x + + # Initialize noises. + if isinstance(token_ids, dict): + arbitrary_key = list(token_ids.keys())[0] + batch_size = ops.shape(token_ids[arbitrary_key])[0] + else: + batch_size = ops.shape(token_ids)[0] + noise_shape = (batch_size,) + self.latent_shape[1:] + noises = random.normal(noise_shape, dtype="float32", seed=seed) + + return generate_function( + images, noises, x, starting_step, num_steps, guidance_scale + ) + + # Normalize and preprocess inputs. + inputs, input_is_scalar = self._normalize_generate_inputs(inputs) + if self.support_negative_prompts: + images = [x["images"] for x in inputs] + token_ids = [preprocess(x["prompts"]) for x in inputs] + negative_token_ids = [ + preprocess(x["negative_prompts"]) for x in inputs + ] + # Tuple format: (images, (token_ids, negative_token_ids)). + inputs = [ + x for x in zip(images, zip(token_ids, negative_token_ids)) + ] + else: + images = [x["images"] for x in inputs] + token_ids = [preprocess(x["prompts"]) for x in inputs] + # Tuple format: (images, token_ids). + inputs = [x for x in zip(images, token_ids)] + + # Image-to-image. 
+ outputs = [generate(*x) for x in inputs] + return self._normalize_generate_outputs(outputs, input_is_scalar) diff --git a/keras_hub/src/models/inpaint.py b/keras_hub/src/models/inpaint.py new file mode 100644 index 0000000000..40bcc7ad15 --- /dev/null +++ b/keras_hub/src/models/inpaint.py @@ -0,0 +1,513 @@ +import itertools +from functools import partial + +import keras +from keras import ops +from keras import random + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.task import Task +from keras_hub.src.utils.keras_utils import standardize_data_format + +try: + import tensorflow as tf +except ImportError: + tf = None + + +@keras_hub_export("keras_hub.models.Inpaint") +class Inpaint(Task): + """Base class for image-to-image tasks. + + `Inpaint` tasks wrap a `keras_hub.models.Backbone` and + a `keras_hub.models.Preprocessor` to create a model that can be used for + generation and generative fine-tuning. + + `Inpaint` tasks provide an additional, high-level `generate()` function + which can be used to generate image by token with a (image, mask, string) + in, image out signature. + + All `Inpaint` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + + Example: + + ```python + # Load a Stable Diffusion 3 backbone with pre-trained weights. + reference_image = np.ones((1024, 1024, 3), dtype="float32") + reference_mask = np.ones((1024, 1024), dtype="float32") + inpaint = keras_hub.models.Inpaint.from_preset( + "stable_diffusion_3_medium", + ) + inpaint.generate( + reference_image, + reference_mask, + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + ) + + # Load a Stable Diffusion 3 backbone at bfloat16 precision. + inpaint = keras_hub.models.Inpaint.from_preset( + "stable_diffusion_3_medium", + dtype="bfloat16", + ) + inpaint.generate( + reference_image, + reference_mask, + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + ) + ``` + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + @property + def support_negative_prompts(self): + """Whether the model supports `negative_prompts` key in `generate()`.""" + return bool(True) + + @property + def image_shape(self): + return tuple(self.backbone.image_shape) + + @property + def latent_shape(self): + return tuple(self.backbone.latent_shape) + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `Inpaint` task for training. + + The `Inpaint` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.MeanSquaredError` loss will be applied. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. 
Defaults to `"auto"`, + where a `keras.metrics.MeanSquaredError` will be applied to + track the loss of the model during training. See + `keras.Model.compile` and `keras.metrics` for more info on + possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. + """ + # Ref: /~https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414 + if optimizer == "auto": + optimizer = keras.optimizers.AdamW( + 1e-4, weight_decay=1e-2, epsilon=1e-8, clipnorm=1.0 + ) + if loss == "auto": + loss = keras.losses.MeanSquaredError() + if metrics == "auto": + metrics = [keras.metrics.MeanSquaredError()] + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) + self.generate_function = None + + def generate_step(self, *args, **kwargs): + """Run generation on batches of input.""" + raise NotImplementedError + + def make_generate_function(self): + """Create or return the compiled generation function.""" + if self.generate_function is not None: + return self.generate_function + + self.generate_function = self.generate_step + if keras.config.backend() == "torch": + import torch + + def wrapped_function(*args, **kwargs): + with torch.no_grad(): + return self.generate_step(*args, **kwargs) + + self.generate_function = wrapped_function + elif keras.config.backend() == "tensorflow" and not self.run_eagerly: + self.generate_function = tf.function( + self.generate_step, jit_compile=self.jit_compile + ) + elif keras.config.backend() == "jax" and not self.run_eagerly: + import jax + + @partial(jax.jit) + def compiled_function(state, *args, **kwargs): + ( + trainable_variables, + non_trainable_variables, + ) = state + mapping = itertools.chain( + zip(self.trainable_variables, trainable_variables), + zip(self.non_trainable_variables, non_trainable_variables), + ) + + with keras.StatelessScope(state_mapping=mapping): + outputs = self.generate_step(*args, **kwargs) + return outputs + + def wrapped_function(*args, **kwargs): + # Create an explicit tuple of all variable state. + state = ( + # Use the explicit variable.value to preserve the + # sharding spec of distribution. + [v.value for v in self.trainable_variables], + [v.value for v in self.non_trainable_variables], + ) + outputs = compiled_function(state, *args, **kwargs) + return outputs + + self.generate_function = wrapped_function + return self.generate_function + + def _normalize_generate_images(self, inputs): + """Normalize user image to the generate function. + + This function converts all inputs to tensors, adds a batch dimension if + necessary, and returns a iterable "dataset like" object (either an + actual `tf.data.Dataset` or a list with a single batch element). 
+ """ + if tf and isinstance(inputs, tf.data.Dataset): + return inputs.as_numpy_iterator(), False + + def normalize(x): + data_format = getattr( + self.backbone, "data_format", standardize_data_format(None) + ) + input_is_scalar = False + x = ops.convert_to_tensor(x) + if len(ops.shape(x)) < 4: + x = ops.expand_dims(x, axis=0) + input_is_scalar = True + x = ops.image.resize( + x, + (self.backbone.image_shape[0], self.backbone.image_shape[1]), + interpolation="nearest", + data_format=data_format, + ) + return x, input_is_scalar + + if isinstance(inputs, dict): + for key in inputs: + inputs[key], input_is_scalar = normalize(inputs[key]) + else: + inputs, input_is_scalar = normalize(inputs) + + return inputs, input_is_scalar + + def _normalize_generate_masks(self, inputs): + """Normalize user masks to the generate function. + + This function converts all inputs to tensors, adds a batch dimension if + necessary, and returns a iterable "dataset like" object (either an + actual `tf.data.Dataset` or a list with a single batch element). + """ + if tf and isinstance(inputs, tf.data.Dataset): + return inputs.as_numpy_iterator(), False + + def normalize(x): + data_format = getattr( + self.backbone, "data_format", standardize_data_format(None) + ) + input_is_scalar = False + x = ops.convert_to_tensor(x) + if len(ops.shape(x)) < 3: + x = ops.expand_dims(x, axis=0) + input_is_scalar = True + x = ops.expand_dims(x, axis=-1) + if keras.backend.standardize_dtype(x.dtype) == "bool": + x = ops.cast(x, "float32") + x = ops.image.resize( + x, + (self.backbone.image_shape[0], self.backbone.image_shape[1]), + interpolation="nearest", + data_format=data_format, + ) + x = ops.squeeze(x, axis=-1) + return x, input_is_scalar + + if isinstance(inputs, dict): + for key in inputs: + inputs[key], input_is_scalar = normalize(inputs[key]) + else: + inputs, input_is_scalar = normalize(inputs) + + return inputs, input_is_scalar + + def _normalize_generate_inputs(self, inputs): + """Normalize user input to the generate function. + + This function converts all inputs to tensors, adds a batch dimension if + necessary, and returns a iterable "dataset like" object (either an + actual `tf.data.Dataset` or a list with a single batch element). + + The input format must be one of the following: + - A dict with "images", "masks", "prompts" and/or "negative_prompts" + keys + - A tf.data.Dataset with "images", "masks", "prompts" and/or + "negative_prompts" keys + + The output will be a dict with "images", "masks", "prompts" and/or + "negative_prompts" keys. 
+ """ + if tf and isinstance(inputs, tf.data.Dataset): + _inputs = { + "images": inputs.map(lambda x: x["images"]).as_numpy_iterator(), + "masks": inputs.map(lambda x: x["masks"]).as_numpy_iterator(), + "prompts": inputs.map( + lambda x: x["prompts"] + ).as_numpy_iterator(), + } + if self.support_negative_prompts: + _inputs["negative_prompts"] = inputs.map( + lambda x: x["negative_prompts"] + ).as_numpy_iterator() + return _inputs, False + + def normalize(x): + if isinstance(x, str): + return [x], True + if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0: + return x[tf.newaxis], True + return x, False + + def normalize_images(x): + data_format = getattr( + self.backbone, "data_format", standardize_data_format(None) + ) + input_is_scalar = False + x = ops.convert_to_tensor(x) + if len(ops.shape(x)) < 4: + x = ops.expand_dims(x, axis=0) + input_is_scalar = True + x = ops.image.resize( + x, + (self.backbone.image_shape[0], self.backbone.image_shape[1]), + interpolation="nearest", + data_format=data_format, + ) + return x, input_is_scalar + + def normalize_masks(x): + data_format = getattr( + self.backbone, "data_format", standardize_data_format(None) + ) + input_is_scalar = False + x = ops.convert_to_tensor(x) + if len(ops.shape(x)) < 3: + x = ops.expand_dims(x, axis=0) + input_is_scalar = True + x = ops.expand_dims(x, axis=-1) + if keras.backend.standardize_dtype(x.dtype) == "bool": + x = ops.cast(x, "float32") + x = ops.image.resize( + x, + (self.backbone.image_shape[0], self.backbone.image_shape[1]), + interpolation="nearest", + data_format=data_format, + ) + x = ops.squeeze(x, axis=-1) + return x, input_is_scalar + + def get_dummy_prompts(x): + dummy_prompts = [""] * len(x) + if tf and isinstance(x, tf.Tensor): + return tf.convert_to_tensor(dummy_prompts) + else: + return dummy_prompts + + for key in inputs: + if key == "images": + inputs[key], input_is_scalar = normalize_images(inputs[key]) + elif key == "masks": + inputs[key], input_is_scalar = normalize_masks(inputs[key]) + else: + inputs[key], input_is_scalar = normalize(inputs[key]) + + if self.support_negative_prompts and "negative_prompts" not in inputs: + inputs["negative_prompts"] = get_dummy_prompts(inputs["prompts"]) + + return [inputs], input_is_scalar + + def _normalize_generate_outputs(self, outputs, input_is_scalar): + """Normalize user output from the generate function. + + This function converts all output to numpy with a value range of + `[0, 255]`. If a batch dimension was added to the input, it is removed + from the output. + """ + + def normalize(x): + outputs = ops.concatenate(x, axis=0) + outputs = ops.clip(ops.divide(ops.add(outputs, 1.0), 2.0), 0.0, 1.0) + outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8") + outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs + return ops.convert_to_numpy(outputs) + + if isinstance(outputs[0], dict): + normalized = {} + for key in outputs[0]: + normalized[key] = normalize([x[key] for x in outputs]) + return normalized + return normalize([x for x in outputs]) + + def generate( + self, + inputs, + num_steps, + guidance_scale, + strength, + seed=None, + ): + """Generate image based on the provided `inputs`. + + Typically, `inputs` is a dict with `"images"` `"masks"` and `"prompts"` + keys. `"images"` are reference images within a value range of + `[-1.0, 1.0]`, which will be resized to height and width from + `self.backbone.image_shape`, then encoded into latent space by the VAE + encoder. 
`"masks"` are mask images with a boolean dtype, where white + pixels are repainted while black pixels are preserved. `"prompts"` are + strings that will be tokenized and encoded by the text encoder. + + Some models support a `"negative_prompts"` key, which helps steer the + model away from generating certain styles and elements. To enable this, + add `"negative_prompts"` to the input dict. + + If `inputs` are a `tf.data.Dataset`, outputs will be generated + "batch-by-batch" and concatenated. Otherwise, all inputs will be + processed as batches. + + Args: + inputs: python data, tensor data, or a `tf.data.Dataset`. The format + must be one of the following: + - A dict with `"images"`, `"masks"`, `"prompts"` and/or + `"negative_prompts"` keys. + - A `tf.data.Dataset` with `"images"`, `"masks"`, `"prompts"` + and/or `"negative_prompts"` keys. + num_steps: int. The number of diffusion steps to take. + guidance_scale: float. The classifier free guidance scale defined in + [Classifier-Free Diffusion Guidance]( + https://arxiv.org/abs/2207.12598). A higher scale encourages + generating images more closely related to the prompts, typically + at the cost of lower image quality. + strength: float. Indicates the extent to which the reference + `images` are transformed. Must be between `0.0` and `1.0`. When + `strength=1.0`, `images` is essentially ignore and added noise + is maximum and the denoising process runs for the full number of + iterations specified in `num_steps`. + seed: optional int. Used as a random seed. + """ + num_steps = int(num_steps) + guidance_scale = float(guidance_scale) + strength = float(strength) + if strength < 0.0 or strength > 1.0: + raise ValueError( + "`strength` must be between `0.0` and `1.0`. " + f"Received strength={strength}." + ) + starting_step = int(num_steps * (1.0 - strength)) + starting_step = ops.convert_to_tensor(starting_step, "int32") + num_steps = ops.convert_to_tensor(num_steps, "int32") + guidance_scale = ops.convert_to_tensor(guidance_scale) + + # Check `inputs` format. + required_keys = ["images", "masks", "prompts"] + if tf and isinstance(inputs, tf.data.Dataset): + spec = inputs.element_spec + if not all(key in spec for key in required_keys): + raise ValueError( + "Expected a `tf.data.Dataset` with the following keys:" + f"{required_keys}. Received: inputs.element_spec={spec}" + ) + else: + if not isinstance(inputs, dict): + raise ValueError( + "Expected a `dict` or `tf.data.Dataset`. " + f"Received: inputs={inputs} of type {type(inputs)}." + ) + if not all(key in inputs for key in required_keys): + raise ValueError( + "Expected a `dict` with the following keys:" + f"{required_keys}. " + f"Received: inputs.keys={list(inputs.keys())}" + ) + + # Setup our three main passes. + # 1. Preprocessing strings to dense integer tensors. + # 2. Generate outputs via a compiled function on dense tensors. + # 3. Postprocess dense tensors to a value range of `[0, 255]`. + generate_function = self.make_generate_function() + + def preprocess(x): + if self.preprocessor is not None: + return self.preprocessor.generate_preprocess(x) + else: + return x + + def generate(images, masks, x): + token_ids = x[0] if self.support_negative_prompts else x + + # Initialize noises. 
+ if isinstance(token_ids, dict): + arbitrary_key = list(token_ids.keys())[0] + batch_size = ops.shape(token_ids[arbitrary_key])[0] + else: + batch_size = ops.shape(token_ids)[0] + noise_shape = (batch_size,) + self.latent_shape[1:] + noises = random.normal(noise_shape, dtype="float32", seed=seed) + + return generate_function( + images, + masks, + noises, + x, + starting_step, + num_steps, + guidance_scale, + ) + + # Normalize and preprocess inputs. + inputs, input_is_scalar = self._normalize_generate_inputs(inputs) + if self.support_negative_prompts: + images = [x["images"] for x in inputs] + masks = [x["masks"] for x in inputs] + token_ids = [preprocess(x["prompts"]) for x in inputs] + negative_token_ids = [ + preprocess(x["negative_prompts"]) for x in inputs + ] + # Tuple format: (images, masks, (token_ids, negative_token_ids)). + inputs = [ + x + for x in zip(images, masks, zip(token_ids, negative_token_ids)) + ] + else: + images = [x["images"] for x in inputs] + masks = [x["masks"] for x in inputs] + token_ids = [preprocess(x["prompts"]) for x in inputs] + # Tuple format: (images, masks, token_ids). + inputs = [x for x in zip(images, masks, token_ids)] + + # Inpaint. + outputs = [generate(*x) for x in inputs] + return self._normalize_generate_outputs(outputs, input_is_scalar) diff --git a/keras_hub/src/models/llama/llama_backbone.py b/keras_hub/src/models/llama/llama_backbone.py index a654bdf267..0e923c29cd 100644 --- a/keras_hub/src/models/llama/llama_backbone.py +++ b/keras_hub/src/models/llama/llama_backbone.py @@ -59,7 +59,7 @@ class LlamaBackbone(Backbone): } # Pretrained Llama decoder. - model = keras_hub.models.LlamaBackbone.from_preset("llama7b_base_en") + model = keras_hub.models.LlamaBackbone.from_preset("llama2_7b_en") model(input_data) # Randomly initialized Llama decoder with custom config. @@ -175,3 +175,121 @@ def get_config(self): } ) return config + + @staticmethod + def get_layout_map( + device_mesh, + model_parallel_dim_name="model", + data_parallel_dim_name="batch", + ): + """Get a `keras.distribution.LayoutMap` for model parallel distribution. + + The returned `LayoutMap` contains the sharding spec for the Llama + backbone weights, so that you can use it to distribute weights across + the accelerators. + + Example: + ``` + # Feel free to change the mesh shape to balance data and model parallelism + mesh = keras.distribution.DeviceMesh( + shape=(1, 8), + axis_names=('batch', 'model'), + devices=keras.distribution.list_devices(), + ) + layout_map = LlamaBackbone.get_layout_map( + mesh, + model_parallel_dim_name="model", + ) + + distribution = keras.distribution.ModelParallel( + layout_map=layout_map, + batch_dim_name='batch', + ) + + with distribution.scope(): + llama_model = keras_hub.models.LlamaCausalLM.from_preset() + ``` + + To see how the layout map was applied, load the model then run (for one decoder block): + ``` + embedding_layer = llama_model.backbone.get_layer("token_embedding") + decoder_block_1 = llama_model.backbone.get_layer('transformer_layer_0') + for variable in embedding_layer.weights + decoder_block_1.weights: + print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}') + ``` + + Args: + device_mesh: The `keras.distribution.DeviceMesh` instance for + distribution. + model_parallel_dim_name: The axis name of the device mesh, where + the weights should be partition on. + data_parallel_dim_name: The axis name of the device mesh, where + the data should be partition on. 
+ Return: + `keras.distribution.LayoutMap` that contains the sharding spec + for all the model weights. + """ + # The weight path and shape of the Llama backbone is like below + # token_embedding/embeddings (128256, 2048) + # repeat block for decoder + # transformer_layer_0/self_attention/query/kernel (2048, 32, 64) + # transformer_layer_0/self_attention/key/kernel (2048, 8, 64) + # transformer_layer_0/self_attention/value/kernel (2048, 8, 64) + # transformer_layer_0/self_attention/attention_output/kernel (32, 64, 2048) + # transformer_layer_0/self_attention_layernorm/scale (2048,) + # transformer_layer_0/feedforward_intermediate_dense/kernel (2048, 8192) + # transformer_layer_0/feedforward_gate_dense/kernel (2048, 8192) + # transformer_layer_0/feedforward_output_dense/kernel (8192, 2048) + # transformer_layer_0/feedforward_layernorm/scale (2048,) + + if not isinstance(device_mesh, keras.distribution.DeviceMesh): + raise ValueError( + "Invalid device_mesh type. Expected `keras.distribution.Device`," + f" got {type(device_mesh)}" + ) + if model_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{model_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + if data_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{data_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + # Note that it is possible to further config the mesh to be 3D, eg + # (data, seq, model). We leave it as 2D for now for simplicity. + data_dim = data_parallel_dim_name + model_dim = model_parallel_dim_name + # The sharding config is based on the Gemma team training config. + # See https://arxiv.org/abs/2403.08295 + layout_map = keras.distribution.LayoutMap(device_mesh) + layout_map["token_embedding/embeddings"] = (model_dim, data_dim) + layout_map[ + "transformer_layer.*self_attention.*(query|key|value).kernel" + ] = ( + model_dim, + data_dim, + None, + ) + layout_map["transformer_layer.*attention_output.kernel"] = ( + model_dim, + None, + data_dim, + ) + layout_map[ + "transformer_layer.*feedforward_intermediate_dense.kernel" + ] = ( + data_dim, + model_dim, + ) + layout_map["transformer_layer.*feedforward_gate_dense.kernel"] = ( + data_dim, + model_dim, + ) + layout_map["transformer_layer.*feedforward_output_dense.kernel"] = ( + model_dim, + data_dim, + ) + + return layout_map diff --git a/keras_hub/src/models/llama/llama_backbone_test.py b/keras_hub/src/models/llama/llama_backbone_test.py index 3b8eca49fe..0007dd7a96 100644 --- a/keras_hub/src/models/llama/llama_backbone_test.py +++ b/keras_hub/src/models/llama/llama_backbone_test.py @@ -1,3 +1,4 @@ +import keras import pytest from keras import ops @@ -66,3 +67,87 @@ def test_all_presets(self): preset=preset, input_data=self.input_data, ) + + def test_distribution(self): + if keras.backend.backend() != "jax": + self.skipTest("`ModelParallel` testing requires the Jax backend.") + devices = keras.distribution.list_devices("CPU") + if len(devices) == 1: + self.skipTest("`ModelParallel` testing requires multiple devices.") + device_mesh = keras.distribution.DeviceMesh( + shape=(1, len(devices)), + axis_names=("batch", "model"), + devices=devices, + ) + + layout_map = LlamaBackbone.get_layout_map(device_mesh) + distribution = keras.distribution.ModelParallel(layout_map=layout_map) + with distribution.scope(): + model = LlamaBackbone(**self.init_kwargs) + + for w in model.weights: + if "token_embedding/embeddings" in w.path: + 
self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch") + ) + if "self_attention/query/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch", None) + ) + if "self_attention/key/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch", None) + ) + if "self_attention/value/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch", None) + ) + if "self_attention/attention_output/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", None, "batch") + ) + if "feedforward_intermediate_dense/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("batch", "model") + ) + if "feedforward_gate_dense/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("batch", "model") + ) + if "feedforward_output_dense" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch") + ) + + def test_distribution_with_lora(self): + if keras.backend.backend() != "jax": + self.skipTest("`ModelParallel` testing requires the Jax backend.") + devices = keras.distribution.list_devices("CPU") + if len(devices) == 1: + # Need more than 1 device for distribution testing. + self.skipTest("`ModelParallel` testing requires multiple devices.") + device_mesh = keras.distribution.DeviceMesh( + shape=(1, len(devices)), + axis_names=("batch", "model"), + devices=devices, + ) + + layout_map = LlamaBackbone.get_layout_map(device_mesh) + distribution = keras.distribution.ModelParallel(layout_map=layout_map) + with distribution.scope(): + model = LlamaBackbone(**self.init_kwargs) + model.enable_lora(rank=4) + + for w in model.weights: + if "self_attention/query/lora_kernel_a" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), (None, None, None) + ) + if "self_attention/query/lora_kernel_b" in w.path: + self.assertEqual(tuple(w.value.sharding.spec), (None, None)) + if "self_attention/value/lora_kernel_a" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), (None, None, None) + ) + if "self_attention/value/lora_kernel_b" in w.path: + self.assertEqual(tuple(w.value.sharding.spec), (None, None)) diff --git a/keras_hub/src/models/llama/llama_causal_lm.py b/keras_hub/src/models/llama/llama_causal_lm.py index 7e1e319f1d..7f0f901d52 100644 --- a/keras_hub/src/models/llama/llama_causal_lm.py +++ b/keras_hub/src/models/llama/llama_causal_lm.py @@ -42,7 +42,9 @@ def __init__(self, backbone, preprocessor=None, **kwargs): self.preprocessor = preprocessor # === Functional Model === - inputs = backbone.inputs + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. 
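+        # For example, for LlamaBackbone "backbone.input" is a dict like
+        # {"token_ids": ..., "padding_mask": ...}, whereas "backbone.inputs"
+        # is the flattened list [token_ids, padding_mask].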
+ inputs = backbone.input hidden_states = backbone(inputs) outputs = backbone.token_embedding(hidden_states, reverse=True) super().__init__( diff --git a/keras_hub/src/models/llama/llama_presets.py b/keras_hub/src/models/llama/llama_presets.py index 6197cfe07f..f72a0ec95f 100644 --- a/keras_hub/src/models/llama/llama_presets.py +++ b/keras_hub/src/models/llama/llama_presets.py @@ -7,7 +7,7 @@ "description": "7 billion parameter, 32-layer, base LLaMA 2 model.", "params": 6738415616, "official_name": "LLaMA 2", - "path": "llama2", + "path": "llama", "model_card": "/~https://github.com/meta-llama/llama", }, "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/1", @@ -20,7 +20,7 @@ ), "params": 6739839488, "official_name": "LLaMA 2", - "path": "llama2", + "path": "llama", "model_card": "/~https://github.com/meta-llama/llama", }, "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en_int8/1", @@ -33,7 +33,7 @@ ), "params": 6738415616, "official_name": "LLaMA 2", - "path": "llama2", + "path": "llama", "model_card": "/~https://github.com/meta-llama/llama", }, "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/1", @@ -46,7 +46,7 @@ ), "params": 6739839488, "official_name": "LLaMA 2", - "path": "llama2", + "path": "llama", "model_card": "/~https://github.com/meta-llama/llama", }, "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en_int8/1", @@ -59,7 +59,7 @@ ), "params": 6738415616, "official_name": "Vicuna", - "path": "vicuna", + "path": "llama", "model_card": "/~https://github.com/lm-sys/FastChat", }, "kaggle_handle": "kaggle://keras/vicuna/keras/vicuna_1.5_7b_en/1", diff --git a/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py b/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py index b8b45d8fd6..f79be674fb 100644 --- a/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py +++ b/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py @@ -11,6 +11,8 @@ class Llama3CausalLMPreprocessorTest(TestCase): def setUp(self): self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"] + self.vocab += ["<|start_header_id|>", "<|end_header_id|>"] + self.vocab += ["<|eot_id|>"] self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] diff --git a/keras_hub/src/models/llama3/llama3_causal_lm_test.py b/keras_hub/src/models/llama3/llama3_causal_lm_test.py index 995c1a00e1..a054b8ae14 100644 --- a/keras_hub/src/models/llama3/llama3_causal_lm_test.py +++ b/keras_hub/src/models/llama3/llama3_causal_lm_test.py @@ -16,6 +16,8 @@ class Llama3CausalLMTest(TestCase): def setUp(self): self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"] + self.vocab += ["<|start_header_id|>", "<|end_header_id|>"] + self.vocab += ["<|eot_id|>"] self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] @@ -44,7 +46,7 @@ def test_causal_lm_basics(self): cls=Llama3CausalLM, init_kwargs=self.init_kwargs, train_data=self.train_data, - expected_output_shape=(2, 7, 8), + expected_output_shape=(2, 7, 11), ) def test_generate(self): @@ -67,6 +69,12 @@ def test_generate(self): prompt_ids["padding_mask"][:, :5], ) + def test_generate_strip_prompt(self): + 
causal_lm = Llama3CausalLM(**self.init_kwargs) + prompt = " airplane at airport" + output = causal_lm.generate(prompt, strip_prompt=True) + self.assertFalse(output.startswith(prompt)) + def test_early_stopping(self): causal_lm = Llama3CausalLM(**self.init_kwargs) call_with_cache = causal_lm.call_with_cache diff --git a/keras_hub/src/models/llama3/llama3_tokenizer.py b/keras_hub/src/models/llama3/llama3_tokenizer.py index 397b5e1923..ee3037e854 100644 --- a/keras_hub/src/models/llama3/llama3_tokenizer.py +++ b/keras_hub/src/models/llama3/llama3_tokenizer.py @@ -16,10 +16,33 @@ def __init__( self, vocabulary=None, merges=None, + bos_token="<|begin_of_text|>", + eos_token="<|end_of_text|>", + misc_special_tokens={"<|start_header_id|>", "<|end_header_id|>"}, **kwargs, ): - self._add_special_token("<|begin_of_text|>", "start_token") - self._add_special_token("<|end_of_text|>", "end_token") + # Note: all special tokens must also appear in "vocabulary" + + self._add_special_token(bos_token, "start_token") + misc_special_tokens -= {bos_token} + self._add_special_token(eos_token, "end_token") + misc_special_tokens -= {eos_token} + for i, token in enumerate(misc_special_tokens): + self._add_special_token(token, f"special_token_{i:03d}") + + # Hack: + # Llama models use the <|end_of_text|> or the <|eot_id|> as the stop + # token. This info can be read from config when loading a Hugging Face + # checkpoint but no such config exists for Keras checkpoints. + # Setting both probable end tokens when no config is availble will + # make text generation work in all cases as it will stop + # on both end tokens. However, the packer will always use + # "<|end_of_text|>" , which will be the wrong eos_token for "instruct" + # variants of Llama3. + # TODO: load this correctly from a Keras tokenizer config. + if eos_token == "<|end_of_text|>": + self._add_special_token("<|eot_id|>", "end_token2") + self.pad_token_id = 0 super().__init__( vocabulary=vocabulary, diff --git a/keras_hub/src/models/llama3/llama3_tokenizer_test.py b/keras_hub/src/models/llama3/llama3_tokenizer_test.py index 8440d8ebb2..aff591de04 100644 --- a/keras_hub/src/models/llama3/llama3_tokenizer_test.py +++ b/keras_hub/src/models/llama3/llama3_tokenizer_test.py @@ -8,6 +8,8 @@ class Llama3TokenizerTest(TestCase): def setUp(self): self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] self.vocab += ["<|end_of_text|>", "<|begin_of_text|>"] + self.vocab += ["<|start_header_id|>", "<|end_header_id|>"] + self.vocab += ["<|eot_id|>"] self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] diff --git a/keras_hub/src/models/mistral/mistral_causal_lm.py b/keras_hub/src/models/mistral/mistral_causal_lm.py index 7f7ff03d14..06170aa089 100644 --- a/keras_hub/src/models/mistral/mistral_causal_lm.py +++ b/keras_hub/src/models/mistral/mistral_causal_lm.py @@ -42,7 +42,9 @@ def __init__(self, backbone, preprocessor=None, **kwargs): self.preprocessor = preprocessor # === Functional Model === - inputs = backbone.inputs + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. 
+ inputs = backbone.input hidden_states = backbone(inputs) outputs = backbone.token_embedding(hidden_states, reverse=True) super().__init__( diff --git a/keras_hub/src/models/mit/__init__.py b/keras_hub/src/models/mit/__init__.py new file mode 100644 index 0000000000..f581202b1c --- /dev/null +++ b/keras_hub/src/models/mit/__init__.py @@ -0,0 +1,6 @@ +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier +from keras_hub.src.models.mit.mit_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, MiTBackbone) diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_backbone.py b/keras_hub/src/models/mit/mit_backbone.py similarity index 87% rename from keras_hub/src/models/mix_transformer/mix_transformer_backbone.py rename to keras_hub/src/models/mit/mit_backbone.py index e8f881aee3..a6c57816c4 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +++ b/keras_hub/src/models/mit/mit_backbone.py @@ -1,15 +1,22 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import keras import numpy as np from keras import ops from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone -from keras_hub.src.models.mix_transformer.mix_transformer_layers import ( - HierarchicalTransformerEncoder, -) -from keras_hub.src.models.mix_transformer.mix_transformer_layers import ( - OverlappingPatchingAndEmbedding, -) +from keras_hub.src.models.mit.mit_layers import HierarchicalTransformerEncoder +from keras_hub.src.models.mit.mit_layers import OverlappingPatchingAndEmbedding @keras_hub_export("keras_hub.models.MiTBackbone") @@ -61,7 +68,7 @@ def __init__( ```python images = np.ones(shape=(1, 96, 96, 3)) labels = np.zeros(shape=(1, 96, 96, 1)) - backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_imagenet") + backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512") # Evaluate model model(images) @@ -104,7 +111,7 @@ def __init__( ] transformer_blocks.append(transformer_block) cur += depths[i] - layer_norms.append(keras.layers.LayerNormalization()) + layer_norms.append(keras.layers.LayerNormalization(epsilon=1e-5)) # === Functional Model === image_input = keras.layers.Input(shape=image_shape) diff --git a/keras_hub/src/models/mit/mit_backbone_test.py b/keras_hub/src/models/mit/mit_backbone_test.py new file mode 100644 index 0000000000..88c58e96a2 --- /dev/null +++ b/keras_hub/src/models/mit/mit_backbone_test.py @@ -0,0 +1,45 @@ +import numpy as np +import pytest + +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.tests.test_case import TestCase + + +class MiTBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "depths": [2, 2], + "image_shape": (32, 32, 3), + "hidden_dims": [4, 8], + "num_layers": 2, + "blockwise_num_heads": [1, 2], + "blockwise_sr_ratios": [8, 4], + "max_drop_path_rate": 
0.1, + "patch_sizes": [7, 3], + "strides": [4, 2], + } + self.input_size = 32 + self.input_data = np.ones( + (2, self.input_size, self.input_size, 3), dtype="float32" + ) + + def test_backbone_basics(self): + self.run_vision_backbone_test( + cls=MiTBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 4, 4, 8), + expected_pyramid_output_keys=["P1", "P2"], + expected_pyramid_image_sizes=[(8, 8), (4, 4)], + run_quantization_check=False, + run_mixed_precision_check=False, + run_data_format_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MiTBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py b/keras_hub/src/models/mit/mit_image_classifier.py similarity index 53% rename from keras_hub/src/models/mix_transformer/mix_transformer_classifier.py rename to keras_hub/src/models/mit/mit_image_classifier.py index 0daac9327f..370920ddf9 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +++ b/keras_hub/src/models/mit/mit_image_classifier.py @@ -1,10 +1,12 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_classifier import ImageClassifier -from keras_hub.src.models.mix_transformer.mix_transformer_backbone import ( - MiTBackbone, +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( + MiTImageClassifierPreprocessor, ) @keras_hub_export("keras_hub.models.MiTImageClassifier") class MiTImageClassifier(ImageClassifier): backbone_cls = MiTBackbone + preprocessor_cls = MiTImageClassifierPreprocessor diff --git a/keras_hub/src/models/mit/mit_image_classifier_preprocessor.py b/keras_hub/src/models/mit/mit_image_classifier_preprocessor.py new file mode 100644 index 0000000000..d3859c30d6 --- /dev/null +++ b/keras_hub/src/models/mit/mit_image_classifier_preprocessor.py @@ -0,0 +1,12 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter + + +@keras_hub_export("keras_hub.models.MiTImageClassifierPreprocessor") +class MiTImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = MiTBackbone + image_converter_cls = MiTImageConverter diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_classifier_test.py b/keras_hub/src/models/mit/mit_image_classifier_test.py similarity index 78% rename from keras_hub/src/models/mix_transformer/mix_transformer_classifier_test.py rename to keras_hub/src/models/mit/mit_image_classifier_test.py index fb7ff5ce2b..32055c47ed 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_classifier_test.py +++ b/keras_hub/src/models/mit/mit_image_classifier_test.py @@ -1,23 +1,19 @@ import numpy as np import pytest -from keras_hub.src.models.mix_transformer.mix_transformer_backbone import ( - MiTBackbone, -) -from keras_hub.src.models.mix_transformer.mix_transformer_classifier import ( - MiTImageClassifier, -) +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier from keras_hub.src.tests.test_case import TestCase class MiTImageClassifierTest(TestCase): def setUp(self): # Setup model. 
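The expected pyramid sizes in the MiT backbone test above follow from the patch-embedding change in this patch (explicit `ZeroPadding2D` of `patch_size // 2` followed by a "valid" convolution). A quick check of the arithmetic for the test configuration (32x32 input, patch sizes [7, 3], strides [4, 2]); the helper below is illustrative only:

```python
def conv_output_size(size, kernel, stride, pad):
    # Standard output-size formula for a strided 2D convolution.
    return (size + 2 * pad - kernel) // stride + 1

size = 32
for kernel, stride in zip((7, 3), (4, 2)):
    size = conv_output_size(size, kernel, stride, pad=kernel // 2)
    print(size)
# Prints 8 then 4, matching expected_pyramid_image_sizes=[(8, 8), (4, 4)]
# and the final expected_output_shape of (2, 4, 4, 8).
```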
- self.images = np.ones((2, 16, 16, 3), dtype="float32") + self.images = np.ones((2, 32, 32, 3), dtype="float32") self.labels = [0, 3] self.backbone = MiTBackbone( depths=[2, 2, 2, 2], - image_shape=(16, 16, 3), + image_shape=(32, 32, 3), hidden_dims=[4, 8], num_layers=2, blockwise_num_heads=[1, 2], @@ -44,7 +40,7 @@ def test_classifier_basics(self): cls=MiTImageClassifier, init_kwargs=self.init_kwargs, train_data=self.train_data, - expected_output_shape=(2, 2), + expected_output_shape=(4, 4), ) @pytest.mark.large diff --git a/keras_hub/src/models/mit/mit_image_converter.py b/keras_hub/src/models/mit/mit_image_converter.py new file mode 100644 index 0000000000..269fcb88fd --- /dev/null +++ b/keras_hub/src/models/mit/mit_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.mit import MiTBackbone + + +@keras_hub_export("keras_hub.layers.MiTImageConverter") +class MiTImageConverter(ImageConverter): + backbone_cls = MiTBackbone diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_layers.py b/keras_hub/src/models/mit/mit_layers.py similarity index 92% rename from keras_hub/src/models/mix_transformer/mix_transformer_layers.py rename to keras_hub/src/models/mit/mit_layers.py index 42402da7ea..b949fcb6e2 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_layers.py +++ b/keras_hub/src/models/mit/mit_layers.py @@ -28,19 +28,23 @@ def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs): self.patch_size = patch_size self.stride = stride + padding_size = self.patch_size // 2 + + self.padding = keras.layers.ZeroPadding2D( + padding=(padding_size, padding_size) + ) self.proj = keras.layers.Conv2D( filters=project_dim, kernel_size=patch_size, strides=stride, - padding="same", + padding="valid", ) - self.norm = keras.layers.LayerNormalization() + self.norm = keras.layers.LayerNormalization(epsilon=1e-5) def call(self, x): + x = self.padding(x) x = self.proj(x) - # B, H, W, C - shape = x.shape - x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3])) + x = ops.reshape(x, (-1, x.shape[1] * x.shape[2], x.shape[3])) x = self.norm(x) return x @@ -179,20 +183,21 @@ def __init__(self, project_dim, num_heads, sr_ratio): self.k = keras.layers.Dense(project_dim) self.v = keras.layers.Dense(project_dim) self.proj = keras.layers.Dense(project_dim) + self.dropout = keras.layers.Dropout(0.1) + self.proj_drop = keras.layers.Dropout(0.1) if sr_ratio > 1: self.sr = keras.layers.Conv2D( filters=project_dim, kernel_size=sr_ratio, strides=sr_ratio, - padding="same", ) - self.norm = keras.layers.LayerNormalization() + self.norm = keras.layers.LayerNormalization(epsilon=1e-5) def call(self, x): input_shape = ops.shape(x) H, W = int(math.sqrt(input_shape[1])), int(math.sqrt(input_shape[1])) - B, C = input_shape[0], input_shape[2] + B, N, C = input_shape[0], input_shape[1], input_shape[2] q = self.q(x) q = ops.reshape( @@ -208,12 +213,11 @@ def call(self, x): if self.sr_ratio > 1: x = ops.reshape( - ops.transpose(x, [0, 2, 1]), + x, (B, H, W, C), ) x = self.sr(x) - x = ops.reshape(x, [input_shape[0], input_shape[2], -1]) - x = ops.transpose(x, [0, 2, 1]) + x = ops.reshape(x, [B, -1, C]) x = self.norm(x) k = self.k(x) @@ -237,14 +241,16 @@ def call(self, x): attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale attn = ops.nn.softmax(attn, axis=-1) + attn = self.dropout(attn) attn = attn @ v attn = ops.reshape( ops.transpose(attn, [0, 2, 1, 3]), - 
[input_shape[0], input_shape[1], input_shape[2]], + [B, N, C], ) x = self.proj(attn) + x = self.proj_drop(x) return x diff --git a/keras_hub/src/models/mit/mit_presets.py b/keras_hub/src/models/mit/mit_presets.py new file mode 100644 index 0000000000..9c2a5fe362 --- /dev/null +++ b/keras_hub/src/models/mit/mit_presets.py @@ -0,0 +1,151 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MiT model preset configurations.""" + +backbone_presets_with_weights = { + "mit_b0_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 8 transformer blocks." + ), + "params": 3321962, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_ade20k_512/1", + }, + "mit_b1_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 8 transformer blocks." + ), + "params": 13156554, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_ade20k_512/1", + }, + "mit_b2_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 16 transformer blocks." + ), + "params": 24201418, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_ade20k_512/1", + }, + "mit_b3_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 28 transformer blocks." + ), + "params": 44077258, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_ade20k_512/1", + }, + "mit_b4_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 41 transformer blocks." + ), + "params": 60847818, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_ade20k_512/1", + }, + "mit_b5_ade20k_640": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 52 transformer blocks." + ), + "params": 81448138, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_ade20k_640/1", + }, + "mit_b0_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 8 transformer blocks." + ), + "params": 3321962, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_cityscapes_1024/1", + }, + "mit_b1_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 8 transformer blocks." + ), + "params": 13156554, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_cityscapes_1024/1", + }, + "mit_b2_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 16 transformer blocks." + ), + "params": 24201418, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_cityscapes_1024/1", + }, + "mit_b3_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 28 transformer blocks." 
+ ), + "params": 44077258, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_cityscapes_1024/1", + }, + "mit_b4_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 41 transformer blocks." + ), + "params": 60847818, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_cityscapes_1024/1", + }, + "mit_b5_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 52 transformer blocks." + ), + "params": 81448138, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_cityscapes_1024/1", + }, +} + +backbone_presets = { + **backbone_presets_with_weights, +} diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_backbone_test.py b/keras_hub/src/models/mit/mix_transformer_backbone_test.py similarity index 81% rename from keras_hub/src/models/mix_transformer/mix_transformer_backbone_test.py rename to keras_hub/src/models/mit/mix_transformer_backbone_test.py index b3840f5c07..88c58e96a2 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_backbone_test.py +++ b/keras_hub/src/models/mit/mix_transformer_backbone_test.py @@ -1,9 +1,7 @@ import numpy as np import pytest -from keras_hub.src.models.mix_transformer.mix_transformer_backbone import ( - MiTBackbone, -) +from keras_hub.src.models.mit.mit_backbone import MiTBackbone from keras_hub.src.tests.test_case import TestCase @@ -11,7 +9,7 @@ class MiTBackboneTest(TestCase): def setUp(self): self.init_kwargs = { "depths": [2, 2], - "image_shape": (16, 16, 3), + "image_shape": (32, 32, 3), "hidden_dims": [4, 8], "num_layers": 2, "blockwise_num_heads": [1, 2], @@ -20,7 +18,7 @@ def setUp(self): "patch_sizes": [7, 3], "strides": [4, 2], } - self.input_size = 16 + self.input_size = 32 self.input_data = np.ones( (2, self.input_size, self.input_size, 3), dtype="float32" ) @@ -30,9 +28,9 @@ def test_backbone_basics(self): cls=MiTBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 2, 2, 8), + expected_output_shape=(2, 4, 4, 8), expected_pyramid_output_keys=["P1", "P2"], - expected_pyramid_image_sizes=[(4, 4), (2, 2)], + expected_pyramid_image_sizes=[(8, 8), (4, 4)], run_quantization_check=False, run_mixed_precision_check=False, run_data_format_check=False, diff --git a/keras_hub/src/models/mix_transformer/__init__.py b/keras_hub/src/models/mix_transformer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone.py b/keras_hub/src/models/mobilenet/mobilenet_backbone.py index 34ddbda6d0..e40eac32b1 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone.py @@ -1,560 +1,583 @@ -import keras -from keras import ops - -from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.backbone import Backbone - -BN_EPSILON = 1e-5 -BN_MOMENTUM = 0.9 - - -@keras_hub_export("keras_hub.models.MobileNetBackbone") -class MobileNetBackbone(Backbone): - """Instantiates the MobileNet architecture. - - MobileNet is a lightweight convolutional neural network (CNN) - optimized for mobile and edge devices, striking a balance between - accuracy and efficiency. 
By employing depthwise separable convolutions - and techniques like Squeeze-and-Excitation (SE) blocks, - MobileNet models are highly suitable for real-time applications on - resource-constrained devices. - - References: - - [MobileNets: Efficient Convolutional Neural Networks - for Mobile Vision Applications]( - https://arxiv.org/abs/1704.04861) - - [MobileNetV2: Inverted Residuals and Linear Bottlenecks]( - https://arxiv.org/abs/1801.04381) (CVPR 2018) - - [Searching for MobileNetV3](https://arxiv.org/pdf/1905.02244.pdf) - (ICCV 2019) - - Args: - stackwise_expansion: list of list of ints, the expanded filters for - each inverted residual block for each block in the model. - stackwise_num_blocks: list of ints, number of inversted residual blocks - per block - stackwise_num_filters: list of list of ints, number of filters for - each inverted residual block in the model. - stackwise_kernel_size: list of list of ints, kernel size for each - inverted residual block in the model. - stackwise_num_strides: list of list of ints, stride length for each - inverted residual block in the model. - stackwise_se_ratio: se ratio for each inverted residual block in the - model. 0 if dont want to add Squeeze and Excite layer. - stackwise_activation: list of list of activation functions, for each - inverted residual block in the model. - image_shape: optional shape tuple, defaults to (224, 224, 3). - input_num_filters: number of filters in first convolution layer - output_num_filters: specifies whether to add conv and batch_norm in the - end, if set to None, it will not add these layers in the end. - 'None' for MobileNetV1 - input_activation: activation function to be used in the input layer - 'hard_swish' for MobileNetV3, - 'relu6' for MobileNetV1 and MobileNetV2 - output_activation: activation function to be used in the output layer - 'hard_swish' for MobileNetV3, - 'relu6' for MobileNetV1 and MobileNetV2 - depthwise_filters: int, number of filters in depthwise separable - convolution layer - squeeze_and_excite: float, squeeze and excite ratio in the depthwise - layer, None, if dont want to do squeeze and excite - - - Example: - ```python - input_data = tf.ones(shape=(8, 224, 224, 3)) - - # Randomly initialized backbone with a custom config - model = MobileNetBackbone( - stackwise_expansion=[ - [40, 56], - [64, 144, 144], - [72, 72], - [144, 288, 288], - ], - stackwise_num_blocks=[2, 3, 2, 3], - stackwise_num_filters=[ - [16, 16], - [24, 24, 24], - [24, 24], - [48, 48, 48], - ], - stackwise_kernel_size=[[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]], - stackwise_num_strides=[[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]], - stackwise_se_ratio=[ - [None, None], - [0.25, 0.25, 0.25], - [0.3, 0.3], - [0.3, 0.25, 0.25], - ], - stackwise_activation=[ - ["relu", "relu"], - ["hard_swish", "hard_swish", "hard_swish"], - ["hard_swish", "hard_swish"], - ["hard_swish", "hard_swish", "hard_swish"], - ], - output_num_filters=288, - input_activation="hard_swish", - output_activation="hard_swish", - input_num_filters=16, - image_shape=(224, 224, 3), - depthwise_filters=8, - squeeze_and_excite=0.5, - - ) - output = model(input_data) - ``` - """ - - def __init__( - self, - stackwise_expansion, - stackwise_num_blocks, - stackwise_num_filters, - stackwise_kernel_size, - stackwise_num_strides, - stackwise_se_ratio, - stackwise_activation, - output_num_filters, - depthwise_filters, - last_layer_filter, - squeeze_and_excite=None, - image_shape=(224, 224, 3), - input_activation="hard_swish", - output_activation="hard_swish", - 
input_num_filters=16, - **kwargs, - ): - # === Functional Model === - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - - image_input = keras.layers.Input(shape=image_shape) - x = image_input - input_num_filters = adjust_channels(input_num_filters) - x = keras.layers.Conv2D( - input_num_filters, - kernel_size=3, - strides=(2, 2), - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name="input_conv", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name="input_batch_norm", - )(x) - x = keras.layers.Activation(input_activation)(x) - - x = apply_depthwise_conv_block( - x, depthwise_filters, se=squeeze_and_excite, name="block_0" - ) - - for block in range(len(stackwise_num_blocks)): - for inverted_block in range(stackwise_num_blocks[block]): - x = apply_inverted_res_block( - x, - expansion=stackwise_expansion[block][inverted_block], - filters=adjust_channels( - stackwise_num_filters[block][inverted_block] - ), - kernel_size=stackwise_kernel_size[block][inverted_block], - stride=stackwise_num_strides[block][inverted_block], - se_ratio=stackwise_se_ratio[block][inverted_block], - activation=stackwise_activation[block][inverted_block], - name=f"block_{block+1}_{inverted_block}", - ) - - x = ConvBnAct( - x, - filter=adjust_channels(last_layer_filter), - activation="hard_swish", - name=f"block_{len(stackwise_num_blocks)+1}_0", - ) - - last_conv_ch = adjust_channels(output_num_filters) - - x = keras.layers.Conv2D( - last_conv_ch, - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name="output_conv", - )(x) - - # no output normalization in mobilenetv3 - if output_activation == "relu6": - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name="output_batch_norm", - )(x) - - x = keras.layers.Activation(output_activation)(x) - - super().__init__(inputs=image_input, outputs=x, **kwargs) - - # === Config === - self.stackwise_expansion = stackwise_expansion - self.stackwise_num_blocks = stackwise_num_blocks - self.stackwise_num_filters = stackwise_num_filters - self.stackwise_kernel_size = stackwise_kernel_size - self.stackwise_num_strides = stackwise_num_strides - self.stackwise_se_ratio = stackwise_se_ratio - self.stackwise_activation = stackwise_activation - self.input_num_filters = input_num_filters - self.output_num_filters = output_num_filters - self.depthwise_filters = depthwise_filters - self.last_layer_filter = last_layer_filter - self.squeeze_and_excite = squeeze_and_excite - self.input_activation = keras.activations.get(input_activation) - self.output_activation = keras.activations.get(output_activation) - self.image_shape = image_shape - - def get_config(self): - config = super().get_config() - config.update( - { - "stackwise_expansion": self.stackwise_expansion, - "stackwise_num_blocks": self.stackwise_num_blocks, - "stackwise_num_filters": self.stackwise_num_filters, - "stackwise_kernel_size": self.stackwise_kernel_size, - "stackwise_num_strides": self.stackwise_num_strides, - "stackwise_se_ratio": self.stackwise_se_ratio, - "stackwise_activation": self.stackwise_activation, - "image_shape": self.image_shape, - "input_num_filters": self.input_num_filters, - "output_num_filters": self.output_num_filters, - "depthwise_filters": self.depthwise_filters, - "last_layer_filter": self.last_layer_filter, - "squeeze_and_excite": self.squeeze_and_excite, - 
"input_activation": keras.activations.serialize( - activation=self.input_activation - ), - "output_activation": keras.activations.serialize( - activation=self.output_activation - ), - } - ) - return config - - -def adjust_channels(x, divisor=8, min_value=None): - """Ensure that all layers have a channel number divisible by the `divisor`. - - Args: - x: integer, input value. - divisor: integer, the value by which a channel number should be - divisible, defaults to 8. - min_value: float, optional minimum value for the new tensor. If None, - defaults to value of divisor. - - Returns: - the updated input scalar. - """ - - if min_value is None: - min_value = divisor - - new_x = max(min_value, int(x + divisor / 2) // divisor * divisor) - - # make sure that round down does not go down by more than 10%. - if new_x < 0.9 * x: - new_x += divisor - return new_x - - -def apply_inverted_res_block( - x, - expansion, - filters, - kernel_size, - stride, - se_ratio, - activation, - name=None, -): - """An Inverted Residual Block. - - Args: - x: input tensor. - expansion: integer, the expansion ratio, multiplied with infilters to - get the minimum value passed to adjust_channels. - filters: integer, number of filters for convolution layer. - kernel_size: integer, the kernel size for DepthWise Convolutions. - stride: integer, the stride length for DepthWise Convolutions. - se_ratio: float, ratio for bottleneck filters. Number of bottleneck - filters = filters * se_ratio. - activation: the activation layer to use. - name: string, block label. - - Returns: - the updated input tensor. - """ - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - activation = keras.activations.get(activation) - shortcut = x - infilters = x.shape[channel_axis] - expanded_channels = adjust_channels(expansion) - - x = keras.layers.Conv2D( - expanded_channels, - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv1", - )(x) - - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn1", - )(x) - - x = keras.layers.Activation(activation=activation)(x) - - if stride == 2: - x = keras.layers.ZeroPadding2D( - padding=correct_pad_downsample(x, kernel_size), - )(x) - - x = keras.layers.Conv2D( - expanded_channels, - kernel_size, - strides=stride, - padding="same" if stride == 1 else "valid", - groups=expanded_channels, - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv2", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn2", - )(x) - - x = keras.layers.Activation(activation=activation)(x) - - if se_ratio: - se_filters = expanded_channels - x = SqueezeAndExcite2D( - input=x, - filters=se_filters, - bottleneck_filters=adjust_channels(se_filters * se_ratio), - squeeze_activation="relu", - excite_activation=keras.activations.hard_sigmoid, - name=f"{name}_se", - ) - - x = keras.layers.Conv2D( - filters, - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv3", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn3", - )(x) - - if stride == 1 and infilters == filters: - x = keras.layers.Add(name=f"{name}_add")([shortcut, x]) - return x - - -def apply_depthwise_conv_block( - x, filters, kernel_size=3, stride=1, se=None, 
name=None -): - """Adds a depthwise convolution block. - - A depthwise convolution block consists of a depthwise conv, - batch normalization, relu6, pointwise convolution, - batch normalization and relu6 activation. - - Args: - x: Input tensor of shape `(rows, cols, channels) - filters: Integer, the dimensionality of the output space - (i.e. the number of output filters in the pointwise convolution). - strides: An integer or tuple/list of 2 integers, specifying the strides - of the convolution along the width and height. - Can be a single integer to specify the same value for - all spatial dimensions. Specifying any stride value != 1 is - incompatible with specifying any `dilation_rate` value != 1. - block_id: Integer, a unique identification designating the block number. - - Input shape: - 4D tensor with shape: `(batch, rows, cols, channels)` in "channels_last" - 4D tensor with shape: `(batch, channels, rows, cols)` in "channels_first" - Returns: - Output tensor of block. - """ - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - infilters = x.shape[channel_axis] - name = f"{name}_0" - - if stride == 2: - x = keras.layers.ZeroPadding2D( - padding=correct_pad_downsample(x, kernel_size), - )(x) - - x = keras.layers.Conv2D( - infilters, - kernel_size, - strides=stride, - padding="same" if stride == 1 else "valid", - data_format=keras.config.image_data_format(), - groups=infilters, - use_bias=False, - name=f"{name}_conv1", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn1", - )(x) - x = keras.layers.ReLU(6.0)(x) - - if se: - x = SqueezeAndExcite2D( - input=x, - filters=infilters, - bottleneck_filters=adjust_channels(infilters * se), - squeeze_activation="relu", - excite_activation=keras.activations.hard_sigmoid, - name=f"{name}_se", - ) - - x = keras.layers.Conv2D( - filters, - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv2", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn2", - )(x) - return x - - -def SqueezeAndExcite2D( - input, - filters, - bottleneck_filters=None, - squeeze_activation="relu", - excite_activation="sigmoid", - name=None, -): - """ - Description: - This layer applies a content-aware mechanism to adaptively assign - channel-wise weights. It uses global average pooling to compress - feature maps into single values, which are then processed by - two Conv1D layers: the first reduces the dimensionality, and - the second restores it. - Args: - filters: Number of input and output filters. The number of input and - output filters is same. - bottleneck_filters: (Optional) Number of bottleneck filters. Defaults - to `0.25 * filters` - squeeze_activation: (Optional) String, callable (or - keras.layers.Layer) or keras.activations.Activation instance - denoting activation to be applied after squeeze convolution. - Defaults to `relu`. - excite_activation: (Optional) String, callable (or - keras.layers.Layer) or keras.activations.Activation instance - denoting activation to be applied after excite convolution. - Defaults to `sigmoid`. 
- name: Name of the layer - """ - if not bottleneck_filters: - bottleneck_filters = filters // 4 - - x = input - x = keras.layers.Conv2D( - bottleneck_filters, - (1, 1), - data_format=keras.config.image_data_format(), - activation=squeeze_activation, - name=f"{name}_conv_reduce", - )(x) - x = keras.layers.Conv2D( - filters, - (1, 1), - data_format=keras.config.image_data_format(), - activation=excite_activation, - name=f"{name}_conv_expand", - )(x) - - x = ops.multiply(x, input) - return x - - -def ConvBnAct(x, filter, activation, name=None): - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - x = keras.layers.Conv2D( - filter, - kernel_size=1, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn", - )(x) - x = keras.layers.Activation(activation)(x) - return x - - -def correct_pad_downsample(inputs, kernel_size): - """Returns a tuple for zero-padding for 2D convolution with downsampling. - - Args: - inputs: Input tensor. - kernel_size: An integer or tuple/list of 2 integers. - - Returns: - A tuple. - """ - img_dim = 1 - input_size = inputs.shape[img_dim : (img_dim + 2)] - if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size) - if input_size[0] is None: - adjust = (1, 1) - else: - adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) - correct = (kernel_size[0] // 2, kernel_size[1] // 2) - return ( - (correct[0] - adjust[0], correct[0]), - (correct[1] - adjust[1], correct[1]), - ) +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone + +BN_EPSILON = 1e-5 +BN_MOMENTUM = 0.9 + + +@keras_hub_export("keras_hub.models.MobileNetBackbone") +class MobileNetBackbone(Backbone): + """Instantiates the MobileNet architecture. + + MobileNet is a lightweight convolutional neural network (CNN) + optimized for mobile and edge devices, striking a balance between + accuracy and efficiency. By employing depthwise separable convolutions + and techniques like Squeeze-and-Excitation (SE) blocks, + MobileNet models are highly suitable for real-time applications on + resource-constrained devices. + + References: + - [MobileNets: Efficient Convolutional Neural Networks + for Mobile Vision Applications]( + https://arxiv.org/abs/1704.04861) + - [MobileNetV2: Inverted Residuals and Linear Bottlenecks]( + https://arxiv.org/abs/1801.04381) (CVPR 2018) + - [Searching for MobileNetV3](https://arxiv.org/pdf/1905.02244.pdf) + (ICCV 2019) + + Args: + stackwise_expansion: list of list of ints, the expanded filters for + each inverted residual block for each block in the model. + stackwise_num_blocks: list of ints, number of inversted residual blocks + per block + stackwise_num_filters: list of list of ints, number of filters for + each inverted residual block in the model. + stackwise_kernel_size: list of list of ints, kernel size for each + inverted residual block in the model. + stackwise_num_strides: list of list of ints, stride length for each + inverted residual block in the model. + stackwise_se_ratio: se ratio for each inverted residual block in the + model. 0 if dont want to add Squeeze and Excite layer. + stackwise_activation: list of list of activation functions, for each + inverted residual block in the model. + image_shape: optional shape tuple, defaults to (224, 224, 3). 
+ input_num_filters: number of filters in first convolution layer + output_num_filters: specifies whether to add conv and batch_norm in the + end, if set to None, it will not add these layers in the end. + 'None' for MobileNetV1 + input_activation: activation function to be used in the input layer + 'hard_swish' for MobileNetV3, + 'relu6' for MobileNetV1 and MobileNetV2 + output_activation: activation function to be used in the output layer + 'hard_swish' for MobileNetV3, + 'relu6' for MobileNetV1 and MobileNetV2 + depthwise_filters: int, number of filters in depthwise separable + convolution layer + squeeze_and_excite: float, squeeze and excite ratio in the depthwise + layer, None, if dont want to do squeeze and excite + + + Example: + ```python + input_data = tf.ones(shape=(8, 224, 224, 3)) + + # Randomly initialized backbone with a custom config + model = MobileNetBackbone( + stackwise_expansion=[ + [40, 56], + [64, 144, 144], + [72, 72], + [144, 288, 288], + ], + stackwise_num_blocks=[2, 3, 2, 3], + stackwise_num_filters=[ + [16, 16], + [24, 24, 24], + [24, 24], + [48, 48, 48], + ], + stackwise_kernel_size=[[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]], + stackwise_num_strides=[[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]], + stackwise_se_ratio=[ + [None, None], + [0.25, 0.25, 0.25], + [0.3, 0.3], + [0.3, 0.25, 0.25], + ], + stackwise_activation=[ + ["relu", "relu"], + ["hard_swish", "hard_swish", "hard_swish"], + ["hard_swish", "hard_swish"], + ["hard_swish", "hard_swish", "hard_swish"], + ], + output_num_filters=288, + input_activation="hard_swish", + output_activation="hard_swish", + input_num_filters=16, + image_shape=(224, 224, 3), + depthwise_filters=8, + squeeze_and_excite=0.5, + + ) + output = model(input_data) + ``` + """ + + def __init__( + self, + stackwise_expansion, + stackwise_num_blocks, + stackwise_num_filters, + stackwise_kernel_size, + stackwise_num_strides, + stackwise_se_ratio, + stackwise_activation, + stackwise_padding, + output_num_filters, + depthwise_filters, + last_layer_filter, + squeeze_and_excite=None, + image_shape=(None, None, 3), + input_activation="hard_swish", + output_activation="hard_swish", + input_num_filters=16, + **kwargs, + ): + # === Functional Model === + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + + image_input = keras.layers.Input(shape=image_shape) + x = image_input + input_num_filters = adjust_channels(input_num_filters) + + pad_width = ( + (0, 0), # No padding for batch + (1, 1), # 1 pixel padding for height + (1, 1), # 1 pixel padding for width + (0, 0), + ) # No padding for channels + x = ops.pad(x, pad_width=pad_width) + x = keras.layers.Conv2D( + input_num_filters, + kernel_size=3, + strides=(2, 2), + data_format=keras.config.image_data_format(), + use_bias=False, + name="input_conv", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name="input_batch_norm", + )(x) + x = keras.layers.Activation(input_activation)(x) + + x = apply_depthwise_conv_block( + x, depthwise_filters, se=squeeze_and_excite, name="block_0" + ) + + for block in range(len(stackwise_num_blocks)): + for inverted_block in range(stackwise_num_blocks[block]): + x = apply_inverted_res_block( + x, + expansion=stackwise_expansion[block][inverted_block], + filters=adjust_channels( + stackwise_num_filters[block][inverted_block] + ), + kernel_size=stackwise_kernel_size[block][inverted_block], + stride=stackwise_num_strides[block][inverted_block], + 
se_ratio=stackwise_se_ratio[block][inverted_block], + activation=stackwise_activation[block][inverted_block], + padding=stackwise_padding[block][inverted_block], + name=f"block_{block+1}_{inverted_block}", + ) + + x = ConvBnAct( + x, + filter=adjust_channels(last_layer_filter), + activation="hard_swish", + name=f"block_{len(stackwise_num_blocks)+1}_0", + ) + + last_conv_ch = adjust_channels(output_num_filters) + + x = keras.layers.Conv2D( + last_conv_ch, + kernel_size=1, + data_format=keras.config.image_data_format(), + use_bias=False, + name="output_conv", + )(x) + + # no output normalization in mobilenetv3 + if output_activation == "relu6": + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name="output_batch_norm", + )(x) + + x = keras.layers.Activation(output_activation)(x) + + super().__init__(inputs=image_input, outputs=x, **kwargs) + + # === Config === + self.stackwise_expansion = stackwise_expansion + self.stackwise_num_blocks = stackwise_num_blocks + self.stackwise_num_filters = stackwise_num_filters + self.stackwise_kernel_size = stackwise_kernel_size + self.stackwise_num_strides = stackwise_num_strides + self.stackwise_se_ratio = stackwise_se_ratio + self.stackwise_activation = stackwise_activation + self.stackwise_padding = stackwise_padding + self.input_num_filters = input_num_filters + self.output_num_filters = output_num_filters + self.depthwise_filters = depthwise_filters + self.last_layer_filter = last_layer_filter + self.squeeze_and_excite = squeeze_and_excite + self.input_activation = keras.activations.get(input_activation) + self.output_activation = keras.activations.get(output_activation) + self.image_shape = image_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "stackwise_expansion": self.stackwise_expansion, + "stackwise_num_blocks": self.stackwise_num_blocks, + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_kernel_size": self.stackwise_kernel_size, + "stackwise_num_strides": self.stackwise_num_strides, + "stackwise_se_ratio": self.stackwise_se_ratio, + "stackwise_activation": self.stackwise_activation, + "stackwise_padding": self.stackwise_padding, + "image_shape": self.image_shape, + "input_num_filters": self.input_num_filters, + "output_num_filters": self.output_num_filters, + "depthwise_filters": self.depthwise_filters, + "last_layer_filter": self.last_layer_filter, + "squeeze_and_excite": self.squeeze_and_excite, + "input_activation": keras.activations.serialize( + activation=self.input_activation + ), + "output_activation": keras.activations.serialize( + activation=self.output_activation + ), + } + ) + return config + + +def adjust_channels(x, divisor=8, min_value=None): + """Ensure that all layers have a channel number divisible by the `divisor`. + + Args: + x: integer, input value. + divisor: integer, the value by which a channel number should be + divisible, defaults to 8. + min_value: float, optional minimum value for the new tensor. If None, + defaults to value of divisor. + + Returns: + the updated input scalar. + """ + + if min_value is None: + min_value = divisor + + new_x = max(min_value, int(x + divisor / 2) // divisor * divisor) + + # make sure that round down does not go down by more than 10%. + if new_x < 0.9 * x: + new_x += divisor + return new_x + + +def apply_inverted_res_block( + x, + expansion, + filters, + kernel_size, + stride, + se_ratio, + activation, + padding, + name=None, +): + """An Inverted Residual Block. 
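The channel rounding used throughout the new backbone (defined in `adjust_channels` above) keeps every filter count divisible by 8 while never shrinking it by more than 10%. A few concrete values make that behaviour explicit; this sketch simply calls the function defined in this file:

```python
from keras_hub.src.models.mobilenet.mobilenet_backbone import adjust_channels

for x in (8, 9, 12, 30, 63):
    print(x, "->", adjust_channels(x))
# 8 -> 8, 9 -> 16, 12 -> 16, 30 -> 32, 63 -> 64
# (9 would round down to 8, which is more than a 10% drop, so it is bumped to 16.)
```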
+ + Args: + x: input tensor. + expansion: integer, the expansion ratio, multiplied with infilters to + get the minimum value passed to adjust_channels. + filters: integer, number of filters for convolution layer. + kernel_size: integer, the kernel size for DepthWise Convolutions. + stride: integer, the stride length for DepthWise Convolutions. + se_ratio: float, ratio for bottleneck filters. Number of bottleneck + filters = filters * se_ratio. + activation: the activation layer to use. + padding: padding in the conv2d layer + name: string, block label. + + Returns: + the updated input tensor. + """ + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + activation = keras.activations.get(activation) + shortcut = x + infilters = x.shape[channel_axis] + expanded_channels = adjust_channels(expansion) + + x = keras.layers.Conv2D( + expanded_channels, + kernel_size=1, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv1", + )(x) + + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn1", + )(x) + + x = keras.layers.Activation(activation=activation)(x) + + # if stride == 2: + # x = keras.layers.ZeroPadding2D( + # padding=correct_pad_downsample(x, kernel_size), + # )(x) + + # pad_width=[[padding, padding], [padding, padding]] + pad_width = ( + (0, 0), # No padding for batch + (padding, padding), # 1 pixel padding for height + (padding, padding), # 1 pixel padding for width + (0, 0), + ) # No padding for channels + x = ops.pad(x, pad_width=pad_width) + + x = keras.layers.Conv2D( + expanded_channels, + kernel_size, + strides=stride, + padding="valid", + groups=expanded_channels, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv2", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn2", + )(x) + + x = keras.layers.Activation(activation=activation)(x) + + if se_ratio: + se_filters = expanded_channels + x = SqueezeAndExcite2D( + input=x, + filters=se_filters, + bottleneck_filters=adjust_channels(se_filters * se_ratio), + squeeze_activation="relu", + excite_activation=keras.activations.hard_sigmoid, + name=f"{name}_se", + ) + + x = keras.layers.Conv2D( + filters, + kernel_size=1, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv3", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn3", + )(x) + + if stride == 1 and infilters == filters: + x = keras.layers.Add(name=f"{name}_add")([shortcut, x]) + return x + + +def apply_depthwise_conv_block( + x, filters, kernel_size=3, stride=2, se=None, name=None +): + """Adds a depthwise convolution block. + + A depthwise convolution block consists of a depthwise conv, + batch normalization, relu6, pointwise convolution, + batch normalization and relu6 activation. + + Args: + x: Input tensor of shape `(rows, cols, channels) + filters: Integer, the dimensionality of the output space + (i.e. the number of output filters in the pointwise convolution). + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the width and height. + Can be a single integer to specify the same value for + all spatial dimensions. Specifying any stride value != 1 is + incompatible with specifying any `dilation_rate` value != 1. 
+ block_id: Integer, a unique identification designating the block number. + + Input shape: + 4D tensor with shape: `(batch, rows, cols, channels)` in "channels_last" + 4D tensor with shape: `(batch, channels, rows, cols)` in "channels_first" + Returns: + Output tensor of block. + """ + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + infilters = x.shape[channel_axis] + name = f"{name}_0" + + # if stride == 2: + # x = keras.layers.ZeroPadding2D( + # padding=correct_pad_downsample(x, kernel_size), + # )(x) + pad_width = ( + (0, 0), # No padding for batch + (1, 1), # 1 pixel padding for height + (1, 1), # 1 pixel padding for width + (0, 0), + ) # No padding for channels + x = ops.pad(x, pad_width=pad_width) + x = keras.layers.Conv2D( + infilters, + kernel_size, + strides=stride, + padding="valid", + data_format=keras.config.image_data_format(), + groups=infilters, + use_bias=False, + name=f"{name}_conv1", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn1", + )(x) + x = keras.layers.ReLU(6.0)(x) + + if se: + x = SqueezeAndExcite2D( + input=x, + filters=infilters, + bottleneck_filters=adjust_channels(infilters * se), + squeeze_activation="relu", + excite_activation=keras.activations.hard_sigmoid, + name=f"{name}_se", + ) + + x = keras.layers.Conv2D( + filters, + kernel_size=1, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv2", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn2", + )(x) + return x + + +def SqueezeAndExcite2D( + input, + filters, + bottleneck_filters=None, + squeeze_activation="relu", + excite_activation="sigmoid", + name=None, +): + """ + Description: + This layer applies a content-aware mechanism to adaptively assign + channel-wise weights. It uses global average pooling to compress + feature maps into single values, which are then processed by + two Conv1D layers: the first reduces the dimensionality, and + the second restores it. + Args: + filters: Number of input and output filters. The number of input and + output filters is same. + bottleneck_filters: (Optional) Number of bottleneck filters. Defaults + to `0.25 * filters` + squeeze_activation: (Optional) String, callable (or + keras.layers.Layer) or keras.activations.Activation instance + denoting activation to be applied after squeeze convolution. + Defaults to `relu`. + excite_activation: (Optional) String, callable (or + keras.layers.Layer) or keras.activations.Activation instance + denoting activation to be applied after excite convolution. + Defaults to `sigmoid`. 
+ name: Name of the layer + """ + if not bottleneck_filters: + bottleneck_filters = filters // 4 + + x = input + x = keras.layers.Conv2D( + bottleneck_filters, + (1, 1), + data_format=keras.config.image_data_format(), + activation=squeeze_activation, + name=f"{name}_conv_reduce", + )(x) + x = keras.layers.Conv2D( + filters, + (1, 1), + data_format=keras.config.image_data_format(), + activation=excite_activation, + name=f"{name}_conv_expand", + )(x) + + x = ops.multiply(x, input) + return x + + +def ConvBnAct(x, filter, activation, name=None): + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + x = keras.layers.Conv2D( + filter, + kernel_size=1, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn", + )(x) + x = keras.layers.Activation(activation)(x) + return x + + +def correct_pad_downsample(inputs, kernel_size): + """Returns a tuple for zero-padding for 2D convolution with downsampling. + + Args: + inputs: Input tensor. + kernel_size: An integer or tuple/list of 2 integers. + + Returns: + A tuple. + """ + img_dim = 1 + input_size = inputs.shape[img_dim : (img_dim + 2)] + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if input_size[0] is None: + adjust = (1, 1) + else: + adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) + correct = (kernel_size[0] // 2, kernel_size[1] // 2) + return ( + (correct[0] - adjust[0], correct[0]), + (correct[1] - adjust[1], correct[1]), + ) diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py index 8119d0aa1b..3d909c9221 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py @@ -37,8 +37,8 @@ def setUp(self): "stackwise_se_ratio": [ [None, None], [0.25, 0.25, 0.25], - [0.3, 0.3], - [0.3, 0.25, 0.25], + [0.25, 0.25], + [0.25, 0.25, 0.25], ], "stackwise_activation": [ ["relu", "relu"], @@ -47,6 +47,7 @@ def setUp(self): ["hard_swish", "hard_swish", "hard_swish"], ["hard_swish"], ], + "stackwise_padding": [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]], "output_num_filters": 1024, "input_activation": "hard_swish", "output_activation": "hard_swish", @@ -63,7 +64,7 @@ def test_backbone_basics(self): cls=MobileNetBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 14, 14, 1024), + expected_output_shape=(2, 7, 7, 1024), run_mixed_precision_check=False, run_data_format_check=False, ) diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py index bf07914781..e9cc0fc153 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py @@ -5,6 +5,7 @@ MobileNetImageClassifierPreprocessor, ) + @keras_hub_export("keras_hub.models.MobileNetImageClassifier") class MobileNetImageClassifier(ImageClassifier): backbone_cls = MobileNetBackbone diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py index 36adb46613..7997b444fd 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py @@ -32,8 +32,8 @@ def setUp(self): 
stackwise_se_ratio=[ [None, None], [0.25, 0.25, 0.25], - [0.3, 0.3], - [0.3, 0.25, 0.25], + [0.25, 0.25], + [0.25, 0.25, 0.25], ], stackwise_activation=[ ["relu", "relu"], @@ -41,6 +41,7 @@ def setUp(self): ["hard_swish", "hard_swish"], ["hard_swish", "hard_swish", "hard_swish"], ], + stackwise_padding=[[1, 1], [2, 2, 2], [2, 2], [2, 2, 2], [1]], output_num_filters=1024, input_activation="hard_swish", output_activation="hard_swish", @@ -71,6 +72,18 @@ def test_classifier_basics(self): expected_output_shape=(2, 2), ) + @pytest.mark.large + def test_smallest_preset(self): + # Test that our forward pass is stable! + image_batch = self.load_test_image()[None, ...] / 255.0 + self.run_preset_test( + cls=MobileNetImageClassifier, + preset="mobilenetv3_small_050", + input_data=image_batch, + expected_output_shape=(1, 1000), + expected_labels=[85], + ) + @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py b/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py index e8eb1dd232..75c1cb8ad0 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_backbone.py @@ -61,8 +61,6 @@ class PaliGemmaBackbone(Backbone): vit_classifier_activation: activation function. The activation that is used for final output classification in the vision transformer. vit_name: string. The name used for vision transformer layers. - include_rescaling: bool. If true, the image input will be rescaled from - the range `[0, 255]`, to the range `[0, 1]`. layer_norm_epsilon: float. The epsilon value user for every layer norm in all transformer blocks. dropout: float. Dropout probability for the Transformer decoder blocks. @@ -121,7 +119,6 @@ def __init__( vit_pooling=None, vit_classifier_activation=None, vit_name=None, - include_rescaling=True, layer_norm_epsilon=1e-6, dropout=0, dtype=None, @@ -145,7 +142,6 @@ def __init__( vit_intermediate_dim = vit_intermediate_dim or 4304 self.vit_encoder = PaliGemmaVit( image_size=image_size, - include_rescaling=include_rescaling, patch_size=vit_patch_size, num_heads=vit_num_heads, hidden_dim=vit_hidden_dim, @@ -215,7 +211,6 @@ def __init__( # === Config === self.vocabulary_size = vocabulary_size self.image_size = image_size - self.include_rescaling = include_rescaling self.num_layers = num_layers self.num_query_heads = num_query_heads self.num_key_value_heads = num_key_value_heads @@ -242,7 +237,6 @@ def get_config(self): { "vocabulary_size": self.vocabulary_size, "image_size": self.image_size, - "include_rescaling": self.include_rescaling, "num_layers": self.num_layers, "num_query_heads": self.num_query_heads, "num_key_value_heads": self.num_key_value_heads, diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py b/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py index 5419daee5b..a0f912add1 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py @@ -110,7 +110,9 @@ def __init__( self.backbone = backbone # === Functional Model === - inputs = backbone.inputs + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. 
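The comment above about `backbone.input` versus `backbone.inputs` describes generic Keras functional-model behaviour, repeated for each causal LM touched in this patch. A minimal sketch, not KerasHub-specific, with arbitrary layer sizes:

```python
import keras

token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(shape=(None,), dtype="int32", name="padding_mask")
outputs = keras.layers.Embedding(10, 4)(token_ids)
model = keras.Model(
    inputs={"token_ids": token_ids, "padding_mask": padding_mask},
    outputs=outputs,
)

print(model.inputs)  # flat list of the two input tensors
print(model.input)   # the original dict structure, keyed by input name
```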
+ inputs = backbone.input hidden_state = backbone(inputs=inputs) outputs = backbone.token_embedding(hidden_state, reverse=True) outputs = outputs[:, backbone.image_sequence_length :, :] diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py index 3f642833a4..af5443fd1a 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py @@ -12,7 +12,7 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_224/2", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_224/3", }, "pali_gemma_3b_mix_448": { "metadata": { @@ -24,7 +24,7 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_448/2", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_448/3", }, "pali_gemma_3b_224": { "metadata": { @@ -36,7 +36,7 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_224/2", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_224/3", }, "pali_gemma_3b_448": { "metadata": { @@ -48,7 +48,7 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_448/2", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_448/3", }, "pali_gemma_3b_896": { "metadata": { @@ -60,6 +60,6 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_896/2", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_896/3", }, } diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_vit.py b/keras_hub/src/models/pali_gemma/pali_gemma_vit.py index 20194a6039..190a5e8e13 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_vit.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_vit.py @@ -410,8 +410,6 @@ class PaliGemmaVit(keras.Model): Args: image_size: int. The height/width of the image. Both height and width is expected to be the same. - include_rescaling: bool. If true, the image input will be rescaled from - the range `[0, 255]`, to the range `[0, 1]`. patch_size: int. The size of each square patch in the input image. num_heads: int. The number of attention heads for the vision(image) transformer encoder. @@ -452,7 +450,6 @@ def __init__( num_layers, intermediate_dim, num_classes, - include_rescaling=True, pooling=None, classifier_activation=None, dtype=None, @@ -463,14 +460,6 @@ def __init__( shape=(image_size, image_size, 3), name="images" ) x = image_input # Intermediate result. - # TODO we have moved this rescaling to preprocessing layers for most - # models. We should consider removing it here, though it would break - # compatibility. 
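With `include_rescaling` being removed from PaliGemma in this patch, the `[0, 255]` to `[-1, 1]` mapping that the deleted `Rescaling` layer performed inside the ViT is expected to happen in preprocessing instead. A sketch of the standalone equivalent, using the same scale and offset as the removed layer:

```python
import numpy as np
import keras

images = np.random.randint(0, 256, size=(1, 224, 224, 3)).astype("float32")
rescale = keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)
scaled = rescale(images)  # values now lie in [-1, 1]
```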
- if include_rescaling: - rescaling = keras.layers.Rescaling( - scale=1.0 / 127.5, offset=-1.0, name="rescaling" - ) - x = rescaling(image_input) x = PaliGemmaVitEncoder( hidden_dim=hidden_dim, num_layers=num_layers, @@ -520,7 +509,6 @@ def __init__( self.pooling = pooling self.num_classes = num_classes self.image_size = image_size - self.include_rescaling = include_rescaling self.patch_size = patch_size self.classifier_activation = keras.activations.get( classifier_activation @@ -549,7 +537,6 @@ def get_config(self): self.classifier_activation ), "image_size": self.image_size, - "include_rescaling": self.include_rescaling, "patch_size": self.patch_size, } ) diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_vit_test.py b/keras_hub/src/models/pali_gemma/pali_gemma_vit_test.py index 9611590da0..76d11e356b 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_vit_test.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_vit_test.py @@ -30,23 +30,6 @@ def test_vit_encoder(self): output.shape, (batch_size, intermediate_dim, hidden_dim) ) - def test_vit_rescaling(self): - vit_encoder = PaliGemmaVit( - image_size=16, - patch_size=4, - hidden_dim=8, - num_layers=2, - num_heads=2, - intermediate_dim=16, - num_classes=32, - ) - self.assertIsNotNone(vit_encoder.get_layer("rescaling")) - with self.assertRaises(ValueError): - config = vit_encoder.get_config() - config["include_rescaling"] = False - vit_encoder = PaliGemmaVit.from_config(config) - vit_encoder.get_layer("rescaling") - def test_vision_embeddings(self): embeddings_layer = PaliGemmaVitEmbeddings( image_size=16, diff --git a/keras_hub/src/models/phi3/phi3_causal_lm.py b/keras_hub/src/models/phi3/phi3_causal_lm.py index fed4c2ea27..a60c336afb 100644 --- a/keras_hub/src/models/phi3/phi3_causal_lm.py +++ b/keras_hub/src/models/phi3/phi3_causal_lm.py @@ -41,7 +41,9 @@ def __init__(self, backbone, preprocessor=None, **kwargs): self.preprocessor = preprocessor # === Functional Model === - inputs = backbone.inputs + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. + inputs = backbone.input hidden_states = backbone(inputs) outputs = backbone.token_embedding(hidden_states, reverse=True) super().__init__( diff --git a/keras_hub/src/models/preprocessor.py b/keras_hub/src/models/preprocessor.py index f0569a36f8..f338b45339 100644 --- a/keras_hub/src/models/preprocessor.py +++ b/keras_hub/src/models/preprocessor.py @@ -32,7 +32,7 @@ class Preprocessor(PreprocessingLayer): image_converter_cls = None def __init__(self, *args, **kwargs): - self.config_name = kwargs.pop("config_name", PREPROCESSOR_CONFIG_FILE) + self.config_file = kwargs.pop("config_file", PREPROCESSOR_CONFIG_FILE) super().__init__(*args, **kwargs) self._tokenizer = None self._image_converter = None @@ -71,6 +71,22 @@ def image_converter(self): def image_converter(self, value): self._image_converter = value + @property + def image_size(self): + """Shortcut to get/set the image size of the image converter.""" + if self.image_converter is None: + return None + return self.image_converter.image_size + + @image_size.setter + def image_size(self, value): + if self.image_converter is None: + raise ValueError( + "Cannot set `image_size` on preprocessor if `image_converter` " + " is `None`." 
+ ) + self.image_converter.image_size = value + def get_config(self): config = super().get_config() if self.tokenizer: @@ -85,7 +101,7 @@ def get_config(self): ) config.update( { - "config_name": self.config_name, + "config_file": self.config_file, } ) return config @@ -117,7 +133,7 @@ def presets(cls): def from_preset( cls, preset, - config_name=PREPROCESSOR_CONFIG_FILE, + config_file=PREPROCESSOR_CONFIG_FILE, **kwargs, ): """Instantiate a `keras_hub.models.Preprocessor` from a model preset. @@ -167,7 +183,7 @@ def from_preset( # Detect the correct subclass if we need to. if cls.backbone_cls != backbone_cls: cls = find_subclass(preset, cls, backbone_cls) - return loader.load_preprocessor(cls, config_name, **kwargs) + return loader.load_preprocessor(cls, config_file, **kwargs) @classmethod def _add_missing_kwargs(cls, loader, kwargs): diff --git a/keras_hub/src/models/resnet/resnet_backbone.py b/keras_hub/src/models/resnet/resnet_backbone.py index bc8def804a..407ce44f5b 100644 --- a/keras_hub/src/models/resnet/resnet_backbone.py +++ b/keras_hub/src/models/resnet/resnet_backbone.py @@ -68,7 +68,7 @@ class ResNetBackbone(FeaturePyramidBackbone): input_data = np.random.uniform(0, 1, size=(2, 224, 224, 3)) # Pretrained ResNet backbone. - model = keras_hub.models.ResNetBackbone.from_preset("resnet50") + model = keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet") model(input_data) # Randomly initialized ResNetV2 backbone with a custom config. @@ -80,7 +80,6 @@ class ResNetBackbone(FeaturePyramidBackbone): stackwise_num_strides=[1, 2, 2], block_type="basic_block", use_pre_activation=True, - pooling="avg", ) model(input_data) ``` diff --git a/keras_hub/src/models/resnet/resnet_presets.py b/keras_hub/src/models/resnet/resnet_presets.py index 58bed3d90a..c3f7c17de6 100644 --- a/keras_hub/src/models/resnet/resnet_presets.py +++ b/keras_hub/src/models/resnet/resnet_presets.py @@ -12,7 +12,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_18_imagenet/3", + "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_18_imagenet/2", }, "resnet_50_imagenet": { "metadata": { @@ -25,7 +25,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_50_imagenet/3", + "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_50_imagenet/2", }, "resnet_101_imagenet": { "metadata": { @@ -38,7 +38,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_101_imagenet/3", + "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_101_imagenet/2", }, "resnet_152_imagenet": { "metadata": { @@ -51,7 +51,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_152_imagenet/3", + "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_152_imagenet/2", }, "resnet_v2_50_imagenet": { "metadata": { @@ -64,7 +64,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_50_imagenet/3", + "kaggle_handle": "kaggle://keras/resnetv2/keras/resnet_v2_50_imagenet/2", }, "resnet_v2_101_imagenet": { "metadata": { @@ -77,6 +77,147 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_101_imagenet/3", + "kaggle_handle": 
"kaggle://keras/resnetv2/keras/resnet_v2_101_imagenet/2", + }, + "resnet_vd_18_imagenet": { + "metadata": { + "description": ( + "18-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution." + ), + "params": 11722824, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_18_imagenet", + }, + "resnet_vd_34_imagenet": { + "metadata": { + "description": ( + "34-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution." + ), + "params": 21838408, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_34_imagenet", + }, + "resnet_vd_50_imagenet": { + "metadata": { + "description": ( + "50-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution." + ), + "params": 25629512, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_50_imagenet", + }, + "resnet_vd_50_ssld_imagenet": { + "metadata": { + "description": ( + "50-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution with knowledge distillation." + ), + "params": 25629512, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_50_ssld_imagenet", + }, + "resnet_vd_50_ssld_v2_imagenet": { + "metadata": { + "description": ( + "50-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution with knowledge distillation and AutoAugment." + ), + "params": 25629512, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_50_ssld_v2_imagenet", + }, + "resnet_vd_50_ssld_v2_fix_imagenet": { + "metadata": { + "description": ( + "50-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution with knowledge distillation, AutoAugment and " + "additional fine-tuning of the classification head." + ), + "params": 25629512, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_50_ssld_v2_fix_imagenet", + }, + "resnet_vd_101_imagenet": { + "metadata": { + "description": ( + "101-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution." + ), + "params": 44673864, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_101_imagenet", + }, + "resnet_vd_101_ssld_imagenet": { + "metadata": { + "description": ( + "101-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution with knowledge distillation." 
+ ), + "params": 44673864, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_101_ssld_imagenet", + }, + "resnet_vd_152_imagenet": { + "metadata": { + "description": ( + "152-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution." + ), + "params": 60363592, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_152_imagenet", + }, + "resnet_vd_200_imagenet": { + "metadata": { + "description": ( + "200-layer ResNetVD (ResNet with bag of tricks) model " + "pre-trained on the ImageNet 1k dataset at a 224x224 " + "resolution." + ), + "params": 74933064, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/1812.01187", + }, + "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_200_imagenet", }, } diff --git a/keras_hub/src/models/sam/sam_image_segmenter.py b/keras_hub/src/models/sam/sam_image_segmenter.py index ed4b63ecd0..19b0035cb7 100644 --- a/keras_hub/src/models/sam/sam_image_segmenter.py +++ b/keras_hub/src/models/sam/sam_image_segmenter.py @@ -31,7 +31,7 @@ class SAMImageSegmenter(ImageSegmenter): Args: - backbone: A `keras_hub.models.VGGBackbone` instance. + backbone: A `keras_hub.models.SAMBackbone` instance. Example: Load pretrained model using `from_preset`. diff --git a/keras_hub/src/models/sam/sam_presets.py b/keras_hub/src/models/sam/sam_presets.py index 7b7986662c..60e33616e7 100644 --- a/keras_hub/src/models/sam/sam_presets.py +++ b/keras_hub/src/models/sam/sam_presets.py @@ -9,7 +9,7 @@ "path": "sam", "model_card": "https://arxiv.org/abs/2304.02643", }, - "kaggle_handle": "kaggle://kerashub/sam/keras/sam_base_sa1b/2", + "kaggle_handle": "kaggle://keras/sam/keras/sam_base_sa1b/4", }, "sam_large_sa1b": { "metadata": { @@ -19,7 +19,7 @@ "path": "sam", "model_card": "https://arxiv.org/abs/2304.02643", }, - "kaggle_handle": "kaggle://kerashub/sam/keras/sam_large_sa1b/2", + "kaggle_handle": "kaggle://keras/sam/keras/sam_large_sa1b/4", }, "sam_huge_sa1b": { "metadata": { @@ -29,6 +29,6 @@ "path": "sam", "model_card": "https://arxiv.org/abs/2304.02643", }, - "kaggle_handle": "kaggle://kerashub/sam/keras/sam_huge_sa1b/2", + "kaggle_handle": "kaggle://keras/sam/keras/sam_huge_sa1b/4", }, } diff --git a/keras_hub/src/models/segformer/__init__.py b/keras_hub/src/models/segformer/__init__.py new file mode 100644 index 0000000000..3a95690dba --- /dev/null +++ b/keras_hub/src/models/segformer/__init__.py @@ -0,0 +1,8 @@ +from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone +from keras_hub.src.models.segformer.segformer_image_segmenter import ( + SegFormerImageSegmenter, +) +from keras_hub.src.models.segformer.segformer_presets import presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(presets, SegFormerImageSegmenter) diff --git a/keras_hub/src/models/segformer/segformer_backbone.py b/keras_hub/src/models/segformer/segformer_backbone.py new file mode 100644 index 0000000000..f5563b4c02 --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_backbone.py @@ -0,0 +1,163 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone + + +@keras_hub_export("keras_hub.models.SegFormerBackbone") +class SegFormerBackbone(Backbone): 
+ """A Keras model implementing the SegFormer architecture for semantic segmentation. + + This class implements the majority of the SegFormer architecture described in + [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers] + (https://arxiv.org/abs/2105.15203) and [based on the TensorFlow implementation from DeepVision] + (/~https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer). + + SegFormers are meant to be used with the MixTransformer (MiT) encoder family, and + and use a very lightweight all-MLP decoder head. + + The MiT encoder uses a hierarchical transformer which outputs features at multiple scales, + similar to that of the hierarchical outputs typically associated with CNNs. + + Args: + image_encoder: `keras.Model`. The backbone network for the model that is + used as a feature extractor for the SegFormer encoder. + Should be used with the MiT backbone model + (`keras_hub.models.MiTBackbone`) which was created + specifically for SegFormers. + num_classes: int, the number of classes for the detection model, + including the background class. + projection_filters: int, number of filters in the + convolution layer projecting the concatenated features into + a segmentation map. Defaults to 256`. + + Example: + + Using the class with a custom `backbone`: + + ```python + import keras_hub + + backbone = keras_hub.models.MiTBackbone( + depths=[2, 2, 2, 2], + image_shape=(224, 224, 3), + hidden_dims=[32, 64, 160, 256], + num_layers=4, + blockwise_num_heads=[1, 2, 5, 8], + blockwise_sr_ratios=[8, 4, 2, 1], + max_drop_path_rate=0.1, + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + ) + + segformer_backbone = keras_hub.models.SegFormerBackbone(image_encoder=backbone, projection_filters=256) + ``` + + Using the class with a preset `backbone`: + + ```python + import keras_hub + + backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512") + segformer_backbone = keras_hub.models.SegFormerBackbone(image_encoder=backbone, projection_filters=256) + ``` + + """ + + def __init__( + self, + image_encoder, + projection_filters, + **kwargs, + ): + if not isinstance(image_encoder, keras.layers.Layer) or not isinstance( + image_encoder, keras.Model + ): + raise ValueError( + "Argument `image_encoder` must be a `keras.layers.Layer` instance " + f" or `keras.Model`. Received instead " + f"image_encoder={image_encoder} (of type {type(image_encoder)})." 
+ ) + + # === Layers === + inputs = keras.layers.Input(shape=image_encoder.input.shape[1:]) + + self.feature_extractor = keras.Model( + image_encoder.inputs, image_encoder.pyramid_outputs + ) + + features = self.feature_extractor(inputs) + # Get height and width of level one output + _, height, width, _ = features["P1"].shape + + self.mlp_blocks = [] + + for feature_dim, feature in zip(image_encoder.hidden_dims, features): + self.mlp_blocks.append( + keras.layers.Dense( + projection_filters, name=f"linear_{feature_dim}" + ) + ) + + self.resizing = keras.layers.Resizing( + height, width, interpolation="bilinear" + ) + self.concat = keras.layers.Concatenate(axis=-1) + self.linear_fuse = keras.Sequential( + [ + keras.layers.Conv2D( + filters=projection_filters, kernel_size=1, use_bias=False + ), + keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9), + keras.layers.Activation("relu"), + ] + ) + + # === Functional Model === + # Project all multi-level outputs onto + # the same dimensionality and feature map shape + multi_layer_outs = [] + for index, (feature_dim, feature) in enumerate( + zip(image_encoder.hidden_dims, features) + ): + out = self.mlp_blocks[index](features[feature]) + out = self.resizing(out) + multi_layer_outs.append(out) + + # Concat now-equal feature maps + concatenated_outs = self.concat(multi_layer_outs[::-1]) + + # Fuse concatenated features into a segmentation map + seg = self.linear_fuse(concatenated_outs) + + super().__init__( + inputs=inputs, + outputs=seg, + **kwargs, + ) + + # === Config === + self.projection_filters = projection_filters + self.image_encoder = image_encoder + + def get_config(self): + config = super().get_config() + config.update( + { + "projection_filters": self.projection_filters, + "image_encoder": keras.saving.serialize_keras_object( + self.image_encoder + ), + } + ) + return config + + @classmethod + def from_config(cls, config): + if "image_encoder" in config and isinstance( + config["image_encoder"], dict + ): + config["image_encoder"] = keras.layers.deserialize( + config["image_encoder"] + ) + return super().from_config(config) diff --git a/keras_hub/src/models/segformer/segformer_backbone_tests.py b/keras_hub/src/models/segformer/segformer_backbone_tests.py new file mode 100644 index 0000000000..22133763e7 --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_backbone_tests.py @@ -0,0 +1,76 @@ +import numpy as np +import pytest +from keras import ops + +from keras_hub.api.models import MiTBackbone +from keras_hub.api.models import SegFormerBackbone +from keras_hub.src.tests.test_case import TestCase + + +class SegFormerTest(TestCase): + def setUp(self): + image_encoder = MiTBackbone( + depths=[2, 2], + image_shape=(224, 224, 3), + hidden_dims=[32, 64], + num_layers=2, + blockwise_num_heads=[1, 2], + blockwise_sr_ratios=[8, 4], + max_drop_path_rate=0.1, + patch_sizes=[7, 3], + strides=[4, 2], + ) + projection_filters = 256 + self.input_size = 224 + self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) + + self.init_kwargs = { + "projection_filters": projection_filters, + "image_encoder": image_encoder, + } + + def test_segformer_backbone_construction(self): + + SegFormerBackbone( + image_encoder=self.init_kwargs["image_encoder"], + projection_filters=self.init_kwargs["projection_filters"], + ) + + @pytest.mark.large + def test_segformer_call(self): + segformer_backbone = SegFormerBackbone( + image_encoder=self.init_kwargs["image_encoder"], + projection_filters=self.init_kwargs["projection_filters"], + ) + + 
images = np.random.uniform(size=(2, 224, 224, 3))
+        segformer_output = segformer_backbone(images)
+        segformer_predict = segformer_backbone.predict(images)
+
+        assert segformer_output.shape == (2, 56, 56, 256)
+        assert segformer_predict.shape == (2, 56, 56, 256)
+
+    def test_backbone_basics(self):
+
+        self.run_vision_backbone_test(
+            cls=SegFormerBackbone,
+            init_kwargs={**self.init_kwargs},
+            input_data=self.input_data,
+            expected_output_shape=(2, 56, 56, 256),
+        )
+
+    def test_task(self):
+        self.run_task_test(
+            cls=SegFormerBackbone,
+            init_kwargs={**self.init_kwargs},
+            train_data=self.input_data,
+            expected_output_shape=(2, 56, 56, 256),
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=SegFormerBackbone,
+            init_kwargs={**self.init_kwargs},
+            input_data=self.input_data,
+        )
diff --git a/keras_hub/src/models/segformer/segformer_image_converter.py b/keras_hub/src/models/segformer/segformer_image_converter.py
new file mode 100644
index 0000000000..44febd6833
--- /dev/null
+++ b/keras_hub/src/models/segformer/segformer_image_converter.py
@@ -0,0 +1,8 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
+
+
+@keras_hub_export("keras_hub.layers.SegFormerImageConverter")
+class SegFormerImageConverter(ImageConverter):
+    backbone_cls = SegFormerBackbone
diff --git a/keras_hub/src/models/segformer/segformer_image_segmenter.py b/keras_hub/src/models/segformer/segformer_image_segmenter.py
new file mode 100644
index 0000000000..1b00c7a754
--- /dev/null
+++ b/keras_hub/src/models/segformer/segformer_image_segmenter.py
@@ -0,0 +1,171 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_segmenter import ImageSegmenter
+from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
+from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import (
+    SegFormerImageSegmenterPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.SegFormerImageSegmenter")
+class SegFormerImageSegmenter(ImageSegmenter):
+    """A Keras model implementing the SegFormer architecture for semantic segmentation.
+
+    This class implements the segmentation head of the SegFormer architecture described in
+    [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers]
+    (https://arxiv.org/abs/2105.15203) and is [based on the TensorFlow implementation from DeepVision]
+    (/~https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).
+
+    SegFormers are meant to be used with the MixTransformer (MiT) encoder family, and
+    use a very lightweight all-MLP decoder head.
+
+    The MiT encoder uses a hierarchical transformer which outputs features at multiple scales,
+    similar to the hierarchical feature maps typically produced by CNNs.
+
+    Args:
+        backbone: `keras_hub.models.SegFormerBackbone`. The backbone network
+            for the model, wrapping the image encoder used as the feature
+            extractor for the SegFormer encoder.
+            The image encoder is *intended* to be a MiT backbone
+            (`keras_hub.models.MiTBackbone`), which was created
+            specifically for SegFormers.
+            Alternatively, it can be a `keras_hub.models.Backbone`, a model
+            subclassing `keras_hub.models.FeaturePyramidBackbone`, or a
+            `keras.Model` that has a `pyramid_outputs` property which is
+            a dictionary with keys "P2", "P3", "P4", and "P5" and layer names as values.
+        num_classes: int, the number of classes for the segmentation model,
+            including the background class.
+        preprocessor: optional. A
+            `keras_hub.models.SegFormerImageSegmenterPreprocessor` instance,
+            or `None`. If `None`, inputs should be preprocessed before
+            calling the model.
+
+    Examples:
+
+    Using presets:
+
+    ```python
+    import keras_hub
+    import numpy as np
+
+    segmenter = keras_hub.models.SegFormerImageSegmenter.from_preset("segformer_b0_ade20k_512")
+
+    images = np.random.rand(1, 512, 512, 3)
+    segmenter(images)
+    ```
+
+    Using the SegFormer backbone:
+
+    ```python
+    encoder = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
+    backbone = keras_hub.models.SegFormerBackbone(image_encoder=encoder, projection_filters=256)
+    ```
+
+    Using the SegFormer backbone with a custom encoder:
+
+    ```python
+    import keras
+    import keras_hub
+    import numpy as np
+
+    images = np.ones(shape=(1, 96, 96, 3))
+    labels = np.zeros(shape=(1, 96, 96, 1))
+
+    encoder = keras_hub.models.MiTBackbone(
+        depths=[2, 2, 2, 2],
+        image_shape=(96, 96, 3),
+        hidden_dims=[32, 64, 160, 256],
+        num_layers=4,
+        blockwise_num_heads=[1, 2, 5, 8],
+        blockwise_sr_ratios=[8, 4, 2, 1],
+        max_drop_path_rate=0.1,
+        patch_sizes=[7, 3, 3, 3],
+        strides=[4, 2, 2, 2],
+    )
+
+    backbone = keras_hub.models.SegFormerBackbone(image_encoder=encoder, projection_filters=256)
+    segformer = keras_hub.models.SegFormerImageSegmenter(backbone=backbone, num_classes=4)
+
+    segformer(images)
+    ```
+
+    Using the segmenter class with a preset backbone:
+
+    ```python
+    import keras_hub
+
+    image_encoder = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
+    backbone = keras_hub.models.SegFormerBackbone(image_encoder=image_encoder, projection_filters=256)
+    segformer = keras_hub.models.SegFormerImageSegmenter(backbone=backbone, num_classes=4)
+    ```
+    """
+
+    backbone_cls = SegFormerBackbone
+    preprocessor_cls = SegFormerImageSegmenterPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        preprocessor=None,
+        **kwargs,
+    ):
+        if not isinstance(backbone, keras.layers.Layer) or not isinstance(
+            backbone, keras.Model
+        ):
+            raise ValueError(
+                "Argument `backbone` must be a `keras.layers.Layer` instance "
+                f" or `keras.Model`. Received instead "
+                f"backbone={backbone} (of type {type(backbone)})."
+ ) + + # === Layers === + inputs = backbone.input + + self.backbone = backbone + self.preprocessor = preprocessor + self.dropout = keras.layers.Dropout(0.1) + self.output_segmentation_head = keras.layers.Conv2D( + filters=num_classes, kernel_size=1, strides=1 + ) + self.resizing = keras.layers.Resizing( + height=inputs.shape[1], + width=inputs.shape[2], + interpolation="bilinear", + ) + + # === Functional Model === + x = self.backbone(inputs) + x = self.dropout(x) + x = self.output_segmentation_head(x) + output = self.resizing(x) + + super().__init__( + inputs=inputs, + outputs=output, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.backbone = backbone + + def get_config(self): + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "backbone": keras.saving.serialize_keras_object(self.backbone), + } + ) + return config + + @classmethod + def from_config(cls, config): + if "image_encoder" in config and isinstance( + config["image_encoder"], dict + ): + config["image_encoder"] = keras.layers.deserialize( + config["image_encoder"] + ) + return super().from_config(config) diff --git a/keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py b/keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py new file mode 100644 index 0000000000..fd8c5fba35 --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py @@ -0,0 +1,31 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_segmenter_preprocessor import ( + ImageSegmenterPreprocessor, +) +from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone +from keras_hub.src.models.segformer.segformer_image_converter import ( + SegFormerImageConverter, +) +from keras_hub.src.utils.tensor_utils import preprocessing_function + +IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] +IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225] + + +@keras_hub_export("keras_hub.models.SegFormerImageSegmenterPreprocessor") +class SegFormerImageSegmenterPreprocessor(ImageSegmenterPreprocessor): + backbone_cls = SegFormerBackbone + image_converter_cls = SegFormerImageConverter + + @preprocessing_function + def call(self, x, y=None, sample_weight=None): + if self.image_converter: + x = self.image_converter(x) + y = self.image_converter(y) + + x = x / 255 + x = (x - IMAGENET_DEFAULT_MEAN) / IMAGENET_DEFAULT_STD + + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_hub/src/models/segformer/segformer_image_segmenter_tests.py b/keras_hub/src/models/segformer/segformer_image_segmenter_tests.py new file mode 100644 index 0000000000..4ad2e8bc6f --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_image_segmenter_tests.py @@ -0,0 +1,65 @@ +import numpy as np +import pytest +from keras import ops + +from keras_hub.api.models import MiTBackbone +from keras_hub.api.models import SegFormerBackbone +from keras_hub.api.models import SegFormerImageSegmenter +from keras_hub.src.tests.test_case import TestCase + + +class SegFormerTest(TestCase): + def setUp(self): + image_encoder = MiTBackbone( + depths=[2, 2], + image_shape=(224, 224, 3), + hidden_dims=[32, 64], + num_layers=2, + blockwise_num_heads=[1, 2], + blockwise_sr_ratios=[8, 4], + max_drop_path_rate=0.1, + patch_sizes=[7, 3], + strides=[4, 2], + ) + projection_filters = 256 + self.backbone = SegFormerBackbone( + image_encoder=image_encoder, projection_filters=projection_filters + ) + + self.input_size = 
224 + self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) + + self.init_kwargs = {"backbone": self.backbone, "num_classes": 4} + + def test_segformer_segmenter_construction(self): + SegFormerImageSegmenter(backbone=self.backbone, num_classes=4) + + @pytest.mark.large + def test_segformer_call(self): + + segformer = SegFormerImageSegmenter( + backbone=self.backbone, num_classes=4 + ) + + images = np.random.uniform(size=(2, 224, 224, 4)) + segformer_output = segformer(images) + segformer_predict = segformer.predict(images) + + assert segformer_output.shape == images.shape + assert segformer_predict.shape == images.shape + + def test_task(self): + self.run_task_test( + cls=SegFormerImageSegmenter, + init_kwargs={**self.init_kwargs}, + train_data=self.input_data, + expected_output_shape=(2, 224, 224), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=SegFormerImageSegmenter, + init_kwargs={**self.init_kwargs}, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/segformer/segformer_presets.py b/keras_hub/src/models/segformer/segformer_presets.py new file mode 100644 index 0000000000..2c0fff0a50 --- /dev/null +++ b/keras_hub/src/models/segformer/segformer_presets.py @@ -0,0 +1,136 @@ +"""SegFormer model preset configurations.""" + +presets = { + "segformer_b0_ade20k_512": { + "metadata": { + "description": ( + "SegFormer model with MiTB0 backbone fine-tuned on ADE20k in 512x512 resolution." + ), + "params": 3719027, + "official_name": "SegFormerB0", + "path": "segformer_b0", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b0_ade20k_512", + }, + "segformer_b1_ade20k_512": { + "metadata": { + "description": ( + "SegFormer model with MiTB1 backbone fine-tuned on ADE20k in 512x512 resolution." + ), + "params": 13682643, + "official_name": "SegFormerB1", + "path": "segformer_b1", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b1_ade20k_512", + }, + "segformer_b2_ade20k_512": { + "metadata": { + "description": ( + "SegFormer model with MiTB2 backbone fine-tuned on ADE20k in 512x512 resolution." + ), + "params": 24727507, + "official_name": "SegFormerB2", + "path": "segformer_b2", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b2_ade20k_512", + }, + "segformer_b3_ade20k_512": { + "metadata": { + "description": ( + "SegFormer model with MiTB3 backbone fine-tuned on ADE20k in 512x512 resolution." + ), + "params": 44603347, + "official_name": "SegFormerB3", + "path": "segformer_b3", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b3_ade20k_512", + }, + "segformer_b4_ade20k_512": { + "metadata": { + "description": ( + "SegFormer model with MiTB4 backbone fine-tuned on ADE20k in 512x512 resolution." + ), + "params": 61373907, + "official_name": "SegFormerB4", + "path": "segformer_b4", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b4_ade20k_512", + }, + "segformer_b5_ade20k_640": { + "metadata": { + "description": ( + "SegFormer model with MiTB5 backbone fine-tuned on ADE20k in 640x640 resolution." + ), + "params": 81974227, + "official_name": "SegFormerB5", + "path": "segformer_b5", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b5_ade20k_640", + }, + "segformer_b0_cityscapes_1024": { + "metadata": { + "description": ( + "SegFormer model with MiTB0 backbone fine-tuned on Cityscapes in 1024x1024 resolution." 
+ ), + "params": 3719027, + "official_name": "SegFormerB0", + "path": "segformer_b0", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b0_cityscapes_1024", + }, + "segformer_b1_cityscapes_1024": { + "metadata": { + "description": ( + "SegFormer model with MiTB1 backbone fine-tuned on Cityscapes in 1024x1024 resolution." + ), + "params": 13682643, + "official_name": "SegFormerB1", + "path": "segformer_b1", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b1_ade20k_512", + }, + "segformer_b2_cityscapes_1024": { + "metadata": { + "description": ( + "SegFormer model with MiTB2 backbone fine-tuned on Cityscapes in 1024x1024 resolution." + ), + "params": 24727507, + "official_name": "SegFormerB2", + "path": "segformer_b2", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b2_cityscapes_1024", + }, + "segformer_b3_cityscapes_1024": { + "metadata": { + "description": ( + "SegFormer model with MiTB3 backbone fine-tuned on Cityscapes in 1024x1024 resolution." + ), + "params": 44603347, + "official_name": "SegFormerB3", + "path": "segformer_b3", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b3_cityscapes_1024", + }, + "segformer_b4_cityscapes_1024": { + "metadata": { + "description": ( + "SegFormer model with MiTB4 backbone fine-tuned on Cityscapes in 1024x1024 resolution." + ), + "params": 61373907, + "official_name": "SegFormerB4", + "path": "segformer_b4", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b4_cityscapes_1024", + }, + "segformer_b5_cityscapes_1024": { + "metadata": { + "description": ( + "SegFormer model with MiTB5 backbone fine-tuned on Cityscapes in 1024x1024 resolution." + ), + "params": 81974227, + "official_name": "SegFormerB5", + "path": "segformer_b5", + }, + "kaggle_handle": "kaggle://kerashub/segformer/keras/segformer_b5_cityscapes_1024", + }, +} diff --git a/keras_hub/src/models/stable_diffusion_3/mmdit.py b/keras_hub/src/models/stable_diffusion_3/mmdit.py index 0fe78e571b..546d56f13a 100644 --- a/keras_hub/src/models/stable_diffusion_3/mmdit.py +++ b/keras_hub/src/models/stable_diffusion_3/mmdit.py @@ -2,7 +2,6 @@ import keras from keras import layers -from keras import models from keras import ops from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding @@ -11,7 +10,167 @@ from keras_hub.src.utils.keras_utils import standardize_data_format +class AdaptiveLayerNormalization(layers.Layer): + """Adaptive layer normalization. + + Args: + embedding_dim: int. The size of each embedding vector. + residual_modulation: bool. Whether to output the modulation parameters + of the residual connection within the block of the diffusion + transformers. Defaults to `False`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + + References: + - [FiLM: Visual Reasoning with a General Conditioning Layer]( + https://arxiv.org/abs/1709.07871). + - [Scalable Diffusion Models with Transformers]( + https://arxiv.org/abs/2212.09748). 
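+
+    In effect, the layer computes `norm(x) * (1 + scale) + shift`, where
+    `shift` and `scale` are predicted from the conditioning embedding by a
+    `Dense` layer applied to `silu(embeddings)`.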
+ """ + + def __init__(self, hidden_dim, residual_modulation=False, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.residual_modulation = bool(residual_modulation) + num_modulations = 6 if self.residual_modulation else 2 + + self.silu = layers.Activation("silu", dtype=self.dtype_policy) + self.dense = layers.Dense( + num_modulations * hidden_dim, dtype=self.dtype_policy, name="dense" + ) + self.norm = layers.LayerNormalization( + epsilon=1e-6, + center=False, + scale=False, + dtype="float32", + name="norm", + ) + + def build(self, inputs_shape, embeddings_shape): + self.silu.build(embeddings_shape) + self.dense.build(embeddings_shape) + self.norm.build(inputs_shape) + + def call(self, inputs, embeddings, training=None): + x = inputs + emb = self.dense(self.silu(embeddings), training=training) + if self.residual_modulation: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + ops.split(emb, 6, axis=1) + ) + else: + shift_msa, scale_msa = ops.split(emb, 2, axis=1) + scale_msa = ops.expand_dims(scale_msa, axis=1) + shift_msa = ops.expand_dims(shift_msa, axis=1) + x = ops.add( + ops.multiply( + self.norm(x, training=training), + ops.add(1.0, scale_msa), + ), + shift_msa, + ) + if self.residual_modulation: + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + else: + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "residual_modulation": self.residual_modulation, + } + ) + return config + + def compute_output_shape(self, inputs_shape, embeddings_shape): + if self.residual_modulation: + return ( + inputs_shape, + embeddings_shape, + embeddings_shape, + embeddings_shape, + embeddings_shape, + ) + else: + return inputs_shape + + +class MLP(layers.Layer): + """A MLP block with architecture. + + Args: + hidden_dim: int. The number of units in the hidden layers. + output_dim: int. The number of units in the output layer. + activation: str of callable. Activation to use in the hidden layers. + Default to `None`. + """ + + def __init__(self, hidden_dim, output_dim, activation=None, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.output_dim = int(output_dim) + self.activation = keras.activations.get(activation) + + self.dense1 = layers.Dense( + hidden_dim, + activation=self.activation, + dtype=self.dtype_policy, + name="dense1", + ) + self.dense2 = layers.Dense( + output_dim, + activation=None, + dtype=self.dtype_policy, + name="dense2", + ) + + def build(self, inputs_shape): + self.dense1.build(inputs_shape) + inputs_shape = self.dense1.compute_output_shape(inputs_shape) + self.dense2.build(inputs_shape) + + def call(self, inputs, training=None): + x = self.dense1(inputs, training=training) + return self.dense2(x, training=training) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + "activation": keras.activations.serialize(self.activation), + } + ) + return config + + def compute_output_shape(self, inputs_shape): + outputs_shape = list(inputs_shape) + outputs_shape[-1] = self.output_dim + return outputs_shape + + class PatchEmbedding(layers.Layer): + """A layer that converts images into patches. + + Args: + patch_size: int. The size of one side of each patch. + hidden_dim: int. The number of units in the hidden layers. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. 
The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + def __init__(self, patch_size, hidden_dim, data_format=None, **kwargs): super().__init__(**kwargs) self.patch_size = int(patch_size) @@ -48,6 +207,15 @@ def get_config(self): class AdjustablePositionEmbedding(PositionEmbedding): + """A position embedding layer with adjustable height and width. + + The embedding will be cropped to match the input dimensions. + + Args: + height: int. The maximum height of the embedding. + width: int. The maximum width of the embedding. + """ + def __init__( self, height, @@ -84,11 +252,36 @@ def call(self, inputs, height=None, width=None): position_embedding = ops.expand_dims(position_embedding, axis=0) return position_embedding + def get_config(self): + config = super().get_config() + del config["sequence_length"] + config.update( + { + "height": self.height, + "width": self.width, + } + ) + return config + def compute_output_shape(self, input_shape): return input_shape class TimestepEmbedding(layers.Layer): + """A layer which learns embedding for input timesteps. + + Args: + embedding_dim: int. The size of the embedding. + frequency_dim: int. The size of the frequency. + max_period: int. Controls the maximum frequency of the embeddings. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + + Reference: + - [Denoising Diffusion Probabilistic Models]( + https://arxiv.org/abs/2006.11239). + """ + def __init__( self, embedding_dim, frequency_dim=256, max_period=10000, **kwargs ): @@ -96,17 +289,23 @@ def __init__( self.embedding_dim = int(embedding_dim) self.frequency_dim = int(frequency_dim) self.max_period = float(max_period) - self.half_frequency_dim = self.frequency_dim // 2 - - self.mlp = models.Sequential( - [ - layers.Dense( - embedding_dim, activation="silu", dtype=self.dtype_policy - ), - layers.Dense( - embedding_dim, activation=None, dtype=self.dtype_policy + # Precomputed `freq`. 
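+        # freq[i] = exp(-log(max_period) * i / half_frequency_dim), the same
+        # sinusoidal frequency schedule as before, now computed once here
+        # instead of on every call.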
+ half_frequency_dim = frequency_dim // 2 + self.freq = ops.exp( + ops.divide( + ops.multiply( + -math.log(max_period), + ops.arange(0, half_frequency_dim, dtype="float32"), ), - ], + half_frequency_dim, + ) + ) + + self.mlp = MLP( + embedding_dim, + embedding_dim, + "silu", + dtype=self.dtype_policy, name="mlp", ) @@ -118,16 +317,7 @@ def build(self, inputs_shape): def _create_timestep_embedding(self, inputs): compute_dtype = keras.backend.result_type(self.compute_dtype, "float32") x = ops.cast(inputs, compute_dtype) - freqs = ops.exp( - ops.divide( - ops.multiply( - -math.log(self.max_period), - ops.arange(0, self.half_frequency_dim, dtype="float32"), - ), - self.half_frequency_dim, - ) - ) - freqs = ops.cast(freqs, compute_dtype) + freqs = ops.cast(self.freq, compute_dtype) x = ops.multiply(x, ops.expand_dims(freqs, axis=0)) embedding = ops.concatenate([ops.cos(x), ops.sin(x)], axis=-1) if self.frequency_dim % 2 != 0: @@ -143,6 +333,7 @@ def get_config(self): config.update( { "embedding_dim": self.embedding_dim, + "frequency_dim": self.frequency_dim, "max_period": self.max_period, } ) @@ -155,6 +346,18 @@ def compute_output_shape(self, inputs_shape): class DismantledBlock(layers.Layer): + """A dismantled block used to compute pre- and post-attention. + + Args: + num_heads: int. Number of attention heads. + hidden_dim: int. The number of units in the hidden layers. + mlp_ratio: float. The expansion ratio of `MLP`. + use_projection: bool. Whether to use an attention projection layer at + the end of the block. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + def __init__( self, num_heads, @@ -173,25 +376,18 @@ def __init__( self.head_dim = head_dim mlp_hidden_dim = int(hidden_dim * mlp_ratio) self.mlp_hidden_dim = mlp_hidden_dim - num_modulations = 6 if use_projection else 2 - self.num_modulations = num_modulations - - self.adaptive_norm_modulation = models.Sequential( - [ - layers.Activation("silu", dtype=self.dtype_policy), - layers.Dense( - num_modulations * hidden_dim, dtype=self.dtype_policy - ), - ], - name="adaptive_norm_modulation", - ) - self.norm1 = layers.LayerNormalization( - epsilon=1e-6, - center=False, - scale=False, - dtype="float32", - name="norm1", - ) + + if use_projection: + self.ada_layer_norm = AdaptiveLayerNormalization( + hidden_dim, + residual_modulation=True, + dtype=self.dtype_policy, + name="ada_layer_norm", + ) + else: + self.ada_layer_norm = AdaptiveLayerNormalization( + hidden_dim, dtype=self.dtype_policy, name="ada_layer_norm" + ) self.attention_qkv = layers.Dense( hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv" ) @@ -206,73 +402,45 @@ def __init__( dtype="float32", name="norm2", ) - self.mlp = models.Sequential( - [ - layers.Dense( - mlp_hidden_dim, - activation=gelu_approximate, - dtype=self.dtype_policy, - ), - layers.Dense( - hidden_dim, - dtype=self.dtype_policy, - ), - ], + self.mlp = MLP( + mlp_hidden_dim, + hidden_dim, + gelu_approximate, + dtype=self.dtype_policy, name="mlp", ) def build(self, inputs_shape, timestep_embedding): - self.adaptive_norm_modulation.build(timestep_embedding) + self.ada_layer_norm.build(inputs_shape, timestep_embedding) self.attention_qkv.build(inputs_shape) - self.norm1.build(inputs_shape) if self.use_projection: self.attention_proj.build(inputs_shape) self.norm2.build(inputs_shape) self.mlp.build(inputs_shape) def _modulate(self, inputs, shift, scale): - shift = ops.expand_dims(shift, axis=1) - scale = ops.expand_dims(scale, axis=1) + inputs = 
ops.cast(inputs, self.compute_dtype) + shift = ops.cast(shift, self.compute_dtype) + scale = ops.cast(scale, self.compute_dtype) return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift) def _compute_pre_attention(self, inputs, timestep_embedding, training=None): batch_size = ops.shape(inputs)[0] if self.use_projection: - modulation = self.adaptive_norm_modulation( - timestep_embedding, training=training - ) - modulation = ops.reshape( - modulation, (batch_size, 6, self.hidden_dim) - ) - ( - shift_msa, - scale_msa, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - ) = ops.unstack(modulation, 6, axis=1) - qkv = self.attention_qkv( - self._modulate(self.norm1(inputs), shift_msa, scale_msa), - training=training, + x, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.ada_layer_norm( + inputs, timestep_embedding, training=training ) + qkv = self.attention_qkv(x, training=training) qkv = ops.reshape( qkv, (batch_size, -1, 3, self.num_heads, self.head_dim) ) q, k, v = ops.unstack(qkv, 3, axis=2) return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp) else: - modulation = self.adaptive_norm_modulation( - timestep_embedding, training=training - ) - modulation = ops.reshape( - modulation, (batch_size, 2, self.hidden_dim) - ) - shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1) - qkv = self.attention_qkv( - self._modulate(self.norm1(inputs), shift_msa, scale_msa), - training=training, + x = self.ada_layer_norm( + inputs, timestep_embedding, training=training ) + qkv = self.attention_qkv(x, training=training) qkv = ops.reshape( qkv, (batch_size, -1, 3, self.num_heads, self.head_dim) ) @@ -283,12 +451,16 @@ def _compute_post_attention( self, inputs, inputs_intermediates, training=None ): x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates + gate_msa = ops.expand_dims(gate_msa, axis=1) + shift_mlp = ops.expand_dims(shift_mlp, axis=1) + scale_mlp = ops.expand_dims(scale_mlp, axis=1) + gate_mlp = ops.expand_dims(gate_mlp, axis=1) attn = self.attention_proj(inputs, training=training) - x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn)) + x = ops.add(x, ops.multiply(gate_msa, attn)) x = ops.add( x, ops.multiply( - ops.expand_dims(gate_mlp, axis=1), + gate_mlp, self.mlp( self._modulate(self.norm2(x), shift_mlp, scale_mlp), training=training, @@ -328,6 +500,27 @@ def get_config(self): class MMDiTBlock(layers.Layer): + """A MMDiT block consisting of two `DismantledBlock` layers. + + One `DismantledBlock` processes the input latents, and the other processes + the context embedding. This block integrates two modalities within the + attention operation, allowing each representation to operate in its own + space while considering the other. + + Args: + num_heads: int. Number of attention heads. + hidden_dim: int. The number of units in the hidden layers. + mlp_ratio: float. The expansion ratio of `MLP`. + use_context_projection: bool. Whether to use an attention projection + layer at the end of the context block. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. 
+ + Reference: + - [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis]( + https://arxiv.org/abs/2403.03206) + """ + def __init__( self, num_heads, @@ -345,8 +538,6 @@ def __init__( head_dim = hidden_dim // num_heads self.head_dim = head_dim self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim) - self._dot_product_equation = "aecd,abcd->acbe" - self._combine_equation = "acbe,aecd->abcd" self.x_block = DismantledBlock( num_heads=num_heads, @@ -371,20 +562,18 @@ def build(self, inputs_shape, context_shape, timestep_embedding_shape): self.context_block.build(context_shape, timestep_embedding_shape) def _compute_attention(self, query, key, value): - query = ops.multiply( - query, ops.cast(self._inverse_sqrt_key_dim, query.dtype) - ) - attention_scores = ops.einsum(self._dot_product_equation, key, query) - attention_scores = self.softmax(attention_scores) - attention_scores = ops.cast(attention_scores, self.compute_dtype) - attention_output = ops.einsum( - self._combine_equation, attention_scores, value - ) - batch_size = ops.shape(attention_output)[0] - attention_output = ops.reshape( - attention_output, (batch_size, -1, self.num_heads * self.head_dim) - ) - return attention_output + # Ref: jax.nn.dot_product_attention + # /~https://github.com/jax-ml/jax/blob/db89c245ac66911c98f265a05956fdfa4bc79d83/jax/_src/nn/functions.py#L846 + batch_size = ops.shape(query)[0] + logits = ops.einsum("BTNH,BSNH->BNTS", query, key) + logits = ops.multiply(logits, self._inverse_sqrt_key_dim) + probs = self.softmax(logits) + probs = ops.cast(probs, self.compute_dtype) + encoded = ops.einsum("BNTS,BSNH->BTNH", probs, value) + encoded = ops.reshape( + encoded, (batch_size, -1, self.num_heads * self.head_dim) + ) + return encoded def call(self, inputs, context, timestep_embedding, training=None): # Compute pre-attention. 
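# The einsum pair in `_compute_attention` implements standard multi-head
# scaled dot-product attention: per-head logits `Q @ K^T / sqrt(head_dim)`
# ("BTNH,BSNH->BNTS"), a softmax over the key axis, then a weighted sum of
# the values ("BNTS,BSNH->BTNH") before the heads are merged back into the
# feature dimension.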
@@ -453,74 +642,16 @@ def compute_output_shape( return inputs_shape -class OutputLayer(layers.Layer): - def __init__(self, hidden_dim, output_dim, **kwargs): - super().__init__(**kwargs) - self.hidden_dim = hidden_dim - self.output_dim = output_dim - num_modulation = 2 - - self.adaptive_norm_modulation = models.Sequential( - [ - layers.Activation("silu", dtype=self.dtype_policy), - layers.Dense( - num_modulation * hidden_dim, dtype=self.dtype_policy - ), - ], - name="adaptive_norm_modulation", - ) - self.norm = layers.LayerNormalization( - epsilon=1e-6, - center=False, - scale=False, - dtype="float32", - name="norm", - ) - self.output_dense = layers.Dense( - output_dim, - use_bias=True, - dtype=self.dtype_policy, - name="output_dense", - ) - - def build(self, inputs_shape, timestep_embedding_shape): - self.adaptive_norm_modulation.build(timestep_embedding_shape) - self.norm.build(inputs_shape) - self.output_dense.build(inputs_shape) - - def _modulate(self, inputs, shift, scale): - shift = ops.expand_dims(shift, axis=1) - scale = ops.expand_dims(scale, axis=1) - return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift) - - def call(self, inputs, timestep_embedding, training=None): - x = inputs - modulation = self.adaptive_norm_modulation( - timestep_embedding, training=training - ) - modulation = ops.reshape(modulation, (-1, 2, self.hidden_dim)) - shift, scale = ops.unstack(modulation, 2, axis=1) - x = self._modulate(self.norm(x), shift, scale) - x = self.output_dense(x, training=training) - return x - - def get_config(self): - config = super().get_config() - config.update( - { - "hidden_dim": self.hidden_dim, - "output_dim": self.output_dim, - } - ) - return config - - def compute_output_shape(self, inputs_shape): - outputs_shape = list(inputs_shape) - outputs_shape[-1] = self.output_dim - return outputs_shape +class Unpatch(layers.Layer): + """A layer that reconstructs the image from hidden patches. + Args: + patch_size: int. The size of each square patch in the input image. + output_dim: int. The number of units in the output layer. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ -class Unpatch(layers.Layer): def __init__(self, patch_size, output_dim, **kwargs): super().__init__(**kwargs) self.patch_size = int(patch_size) @@ -556,7 +687,7 @@ def compute_output_shape(self, inputs_shape): class MMDiT(Backbone): - """Multimodal Diffusion Transformer (MMDiT) model for Stable Diffusion 3. + """A Multimodal Diffusion Transformer (MMDiT) model. 
MMDiT is introduced in [ Scaling Rectified Flow Transformers for High-Resolution Image Synthesis]( @@ -636,12 +767,8 @@ def __init__( dtype=dtype, name="context_embedding", ) - self.vector_embedding = models.Sequential( - [ - layers.Dense(hidden_dim, activation="silu", dtype=dtype), - layers.Dense(hidden_dim, activation=None, dtype=dtype), - ], - name="vector_embedding", + self.vector_embedding = MLP( + hidden_dim, hidden_dim, "silu", dtype=dtype, name="vector_embedding" ) self.vector_embedding_add = layers.Add( dtype=dtype, name="vector_embedding_add" @@ -660,8 +787,11 @@ def __init__( ) for i in range(num_layers) ] - self.output_layer = OutputLayer( - hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer" + self.output_ada_layer_norm = AdaptiveLayerNormalization( + hidden_dim, dtype=dtype, name="output_ada_layer_norm" + ) + self.output_dense = layers.Dense( + output_dim_in_final, dtype=dtype, name="output_dense" ) self.unpatch = Unpatch( patch_size, output_dim, dtype=dtype, name="unpatch" @@ -696,7 +826,8 @@ def __init__( x = block(x, context, timestep_embedding) # Output layer. - x = self.output_layer(x, timestep_embedding) + x = self.output_ada_layer_norm(x, timestep_embedding) + x = self.output_dense(x) outputs = self.unpatch(x, height=image_height, width=image_width) super().__init__( diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py index c5930a3460..4dd3e4403d 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py @@ -51,11 +51,52 @@ def compute_output_shape(self, inputs_shape): return (inputs_shape[0], self.hidden_dim) -class ClassifierFreeGuidanceConcatenate(layers.Layer): - def __init__(self, axis=0, **kwargs): - super().__init__(**kwargs) - self.axis = axis +class CLIPConcatenate(layers.Layer): + def call( + self, + clip_l_projection, + clip_g_projection, + clip_l_intermediate_output, + clip_g_intermediate_output, + padding, + ): + pooled_embeddings = ops.concatenate( + [clip_l_projection, clip_g_projection], axis=-1 + ) + embeddings = ops.concatenate( + [clip_l_intermediate_output, clip_g_intermediate_output], axis=-1 + ) + embeddings = ops.pad(embeddings, [[0, 0], [0, 0], [0, padding]]) + return pooled_embeddings, embeddings + + +class ImageRescaling(layers.Rescaling): + """Rescales inputs from image space to latent space. + + The rescaling is performed using the formula: `(inputs - offset) * scale`. + """ + + def call(self, inputs): + dtype = self.compute_dtype + scale = self.backend.cast(self.scale, dtype) + offset = self.backend.cast(self.offset, dtype) + return (self.backend.cast(inputs, dtype) - offset) * scale + + +class LatentRescaling(layers.Rescaling): + """Rescales inputs from latent space to image space. + The rescaling is performed using the formula: `inputs / scale + offset`. 
+ """ + + def call(self, inputs): + dtype = self.compute_dtype + scale = self.backend.cast(self.scale, dtype) + offset = self.backend.cast(self.offset, dtype) + return (self.backend.cast(inputs, dtype) / scale) + offset + + +class ClassifierFreeGuidanceConcatenate(layers.Layer): def call( self, latents, @@ -66,20 +107,16 @@ def call( timestep, ): timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1]) - latents = ops.concatenate([latents, latents], axis=self.axis) + latents = ops.concatenate([latents, latents], axis=0) contexts = ops.concatenate( - [positive_contexts, negative_contexts], axis=self.axis + [positive_contexts, negative_contexts], axis=0 ) pooled_projections = ops.concatenate( - [positive_pooled_projections, negative_pooled_projections], - axis=self.axis, + [positive_pooled_projections, negative_pooled_projections], axis=0 ) - timesteps = ops.concatenate([timestep, timestep], axis=self.axis) + timesteps = ops.concatenate([timestep, timestep], axis=0) return latents, contexts, pooled_projections, timesteps - def get_config(self): - return super().get_config() - class ClassifierFreeGuidance(layers.Layer): """Perform classifier free guidance. @@ -100,9 +137,6 @@ class ClassifierFreeGuidance(layers.Layer): - [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). """ - def __init__(self, **kwargs): - super().__init__(**kwargs) - def call(self, inputs, guidance_scale): positive_noise, negative_noise = ops.split(inputs, 2, axis=0) return ops.add( @@ -112,9 +146,6 @@ def call(self, inputs, guidance_scale): ), ) - def get_config(self): - return super().get_config() - def compute_output_shape(self, inputs_shape): outputs_shape = list(inputs_shape) if outputs_shape[0] is not None: @@ -142,16 +173,10 @@ class EulerStep(layers.Layer): https://arxiv.org/abs/2206.00364). """ - def __init__(self, **kwargs): - super().__init__(**kwargs) - def call(self, latents, noise_residual, sigma, sigma_next): sigma_diff = ops.subtract(sigma_next, sigma) return ops.add(latents, ops.multiply(sigma_diff, noise_residual)) - def get_config(self): - return super().get_config() - def compute_output_shape(self, latents_shape): return latents_shape @@ -190,8 +215,8 @@ class StableDiffusion3Backbone(Backbone): model. Defaults to `1000`. shift: float. The shift value for the timestep schedule. Defaults to `3.0`. - height: optional int. The output height of the image. - width: optional int. The output width of the image. + image_shape: tuple. The input shape without the batch size. Defaults to + `(1024, 1024, 3)`. data_format: `None` or str. If specified, either `"channels_last"` or `"channels_first"`. The ordering of the dimensions in the inputs. `"channels_last"` corresponds to inputs with shape @@ -245,23 +270,21 @@ def __init__( output_channels=3, num_train_timesteps=1000, shift=3.0, - height=None, - width=None, + image_shape=(1024, 1024, 3), data_format=None, dtype=None, **kwargs, ): - height = int(height or 1024) - width = int(width or 1024) - if height % 8 != 0 or width % 8 != 0: - raise ValueError( - "`height` and `width` must be divisible by 8. " - f"Received: height={height}, width={width}" - ) data_format = standardize_data_format(data_format) if data_format != "channels_last": raise NotImplementedError - image_shape = (height, width, int(vae.input_channels)) + height = image_shape[0] + width = image_shape[1] + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + "height and width in `image_shape` must be divisible by 8. 
" + f"Received: image_shape={image_shape}" + ) latent_shape = (height // 8, width // 8, int(latent_channels)) context_shape = (None, 4096 if t5 is None else t5.hidden_dim) pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,) @@ -272,12 +295,13 @@ def __init__( self.clip_l_projection = CLIPProjection( clip_l.hidden_dim, dtype=dtype, name="clip_l_projection" ) - self.clip_l_projection.build([None, clip_l.hidden_dim], None) self.clip_g = clip_g self.clip_g_projection = CLIPProjection( clip_g.hidden_dim, dtype=dtype, name="clip_g_projection" ) - self.clip_g_projection.build([None, clip_g.hidden_dim], None) + self.clip_concatenate = CLIPConcatenate( + dtype=dtype, name="clip_concatenate" + ) self.t5 = t5 self.diffuser = MMDiT( mmdit_patch_size, @@ -293,6 +317,12 @@ def __init__( name="diffuser", ) self.vae = vae + self.cfg_concat = ClassifierFreeGuidanceConcatenate( + dtype=dtype, name="classifier_free_guidance_concat" + ) + self.cfg = ClassifierFreeGuidance( + dtype=dtype, name="classifier_free_guidance" + ) # Set `dtype="float32"` to ensure the high precision for the noise # residual. self.scheduler = FlowMatchEulerDiscreteScheduler( @@ -301,17 +331,17 @@ def __init__( dtype="float32", name="scheduler", ) - self.cfg_concat = ClassifierFreeGuidanceConcatenate( - dtype="float32", name="classifier_free_guidance_concat" - ) - self.cfg = ClassifierFreeGuidance( - dtype="float32", name="classifier_free_guidance" - ) self.euler_step = EulerStep(dtype="float32", name="euler_step") - self.latent_rescaling = layers.Rescaling( - scale=1.0 / self.vae.scale, + self.image_rescaling = ImageRescaling( + scale=self.vae.scale, offset=self.vae.shift, - dtype="float32", + dtype=dtype, + name="image_rescaling", + ) + self.latent_rescaling = LatentRescaling( + scale=self.vae.scale, + offset=self.vae.shift, + dtype=dtype, name="latent_rescaling", ) @@ -420,8 +450,7 @@ def __init__( self.output_channels = output_channels self.num_train_timesteps = num_train_timesteps self.shift = shift - self.height = height - self.width = width + self.image_shape = image_shape @property def latent_shape(self): @@ -440,8 +469,12 @@ def encode_text_step(self, token_ids, negative_token_ids): t5_hidden_dim = self.t5_hidden_dim def encode(token_ids): - clip_l_outputs = self.clip_l(token_ids["clip_l"], training=False) - clip_g_outputs = self.clip_g(token_ids["clip_g"], training=False) + clip_l_outputs = self.clip_l( + {"token_ids": token_ids["clip_l"]}, training=False + ) + clip_g_outputs = self.clip_g( + {"token_ids": token_ids["clip_g"]}, training=False + ) clip_l_projection = self.clip_l_projection( clip_l_outputs["sequence_output"], token_ids["clip_l"], @@ -452,23 +485,21 @@ def encode(token_ids): token_ids["clip_g"], training=False, ) - pooled_embeddings = ops.concatenate( - [clip_l_projection, clip_g_projection], - axis=-1, - ) - embeddings = ops.concatenate( - [ - clip_l_outputs["intermediate_output"], - clip_g_outputs["intermediate_output"], - ], - axis=-1, - ) - embeddings = ops.pad( - embeddings, - [[0, 0], [0, 0], [0, t5_hidden_dim - clip_hidden_dim]], + pooled_embeddings, embeddings = self.clip_concatenate( + clip_l_projection, + clip_g_projection, + clip_l_outputs["intermediate_output"], + clip_g_outputs["intermediate_output"], + padding=t5_hidden_dim - clip_hidden_dim, ) if self.t5 is not None: - t5_outputs = self.t5(token_ids["t5"], training=False) + t5_outputs = self.t5( + { + "token_ids": token_ids["t5"], + "padding_mask": ops.ones_like(token_ids["t5"]), + }, + training=False, + ) embeddings = 
ops.concatenate([embeddings, t5_outputs], axis=-2) else: padded_size = self.clip_l.max_sequence_length @@ -490,9 +521,7 @@ def encode(token_ids): def encode_image_step(self, images): latents = self.vae.encode(images) - return ops.multiply( - ops.subtract(latents, self.vae.shift), self.vae.scale - ) + return self.image_rescaling(latents) def add_noise_step(self, latents, noises, step, num_steps): return self.scheduler.add_noise(latents, noises, step, num_steps) @@ -553,8 +582,7 @@ def get_config(self): "output_channels": self.output_channels, "num_train_timesteps": self.num_train_timesteps, "shift": self.shift, - "height": self.height, - "width": self.width, + "image_shape": self.image_shape, } ) return config diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py index 37723b0b5a..77415a6eec 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py @@ -11,7 +11,8 @@ class StableDiffusion3BackboneTest(TestCase): def setUp(self): - height, width = 64, 64 + image_shape = (64, 64, 3) + height, width = image_shape[0], image_shape[1] vae = VAEBackbone( [32, 32, 32, 32], [1, 1, 1, 1], @@ -36,8 +37,7 @@ def setUp(self): "vae": vae, "clip_l": clip_l, "clip_g": clip_g, - "height": height, - "width": width, + "image_shape": image_shape, } self.input_data = { "images": ops.ones((2, height, width, 3)), @@ -82,7 +82,6 @@ def test_all_presets(self): preset=preset, input_data=self.input_data, init_kwargs={ - "height": self.init_kwargs["height"], - "width": self.init_kwargs["width"], + "image_shape": self.init_kwargs["image_shape"], }, ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py new file mode 100644 index 0000000000..285ba834b4 --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py @@ -0,0 +1,171 @@ +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_to_image import ImageToImage +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( + StableDiffusion3Backbone, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( + StableDiffusion3TextToImagePreprocessor, +) + + +@keras_hub_export("keras_hub.models.StableDiffusion3ImageToImage") +class StableDiffusion3ImageToImage(ImageToImage): + """An end-to-end Stable Diffusion 3 model for image-to-image generation. + + This model has a `generate()` method, which generates images based + on a combination of a reference image and a text prompt. + + Args: + backbone: A `keras_hub.models.StableDiffusion3Backbone` instance. + preprocessor: A + `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance. + + Examples: + + Use `generate()` to do image generation. + ```python + image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset( + "stable_diffusion_3_medium", image_shape=(512, 512, 3) + ) + image_to_image.generate( + { + "images": np.ones((512, 512, 3), dtype="float32"), + "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + } + ) + + # Generate with batched prompts. 
+    image_to_image.generate(
+        {
+            "images": np.ones((2, 512, 512, 3), dtype="float32"),
+            "prompts": ["cute wallpaper art of a cat", "cute wallpaper art of a dog"],
+        }
+    )
+
+    # Generate with different `num_steps`, `guidance_scale` and `strength`.
+    image_to_image.generate(
+        {
+            "images": np.ones((512, 512, 3), dtype="float32"),
+            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+        },
+        num_steps=50,
+        guidance_scale=5.0,
+        strength=0.6,
+    )
+
+    # Generate with `negative_prompts`.
+    image_to_image.generate(
+        {
+            "images": np.ones((512, 512, 3), dtype="float32"),
+            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "negative_prompts": "green color",
+        }
+    )
+    ```
+    """
+
+    backbone_cls = StableDiffusion3Backbone
+    preprocessor_cls = StableDiffusion3TextToImagePreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        outputs = backbone.output
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Currently, `fit` is not supported for "
+            "`StableDiffusion3ImageToImage`."
+        )
+
+    def generate_step(
+        self,
+        images,
+        noises,
+        token_ids,
+        starting_step,
+        num_steps,
+        guidance_scale,
+    ):
+        """A compilable generation function for a batch of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for batched inputs.
+
+        Args:
+            images: A (batch_size, image_height, image_width, 3) tensor
+                containing the reference images.
+            noises: A (batch_size, latent_height, latent_width, channels) tensor
+                containing the noises to be added to the latents. Typically,
+                this tensor is sampled from the Gaussian distribution.
+            token_ids: A pair of (batch_size, num_tokens) tensors containing the
+                tokens based on the input prompts and negative prompts.
+            starting_step: int. The number of the starting diffusion step.
+            num_steps: int. The number of diffusion steps to take.
+            guidance_scale: float. The classifier free guidance scale defined in
+                [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). Higher scale encourages the
+                model to generate images that are closely linked to prompts,
+                usually at the expense of lower image quality.
+        """
+        token_ids, negative_token_ids = token_ids
+
+        # Encode images.
+        latents = self.backbone.encode_image_step(images)
+
+        # Add noises to latents.
+        latents = self.backbone.add_noise_step(
+            latents, noises, starting_step, num_steps
+        )
+
+        # Encode inputs.
+        embeddings = self.backbone.encode_text_step(
+            token_ids, negative_token_ids
+        )
+
+        # Denoise.
+        def body_fun(step, latents):
+            return self.backbone.denoise_step(
+                latents,
+                embeddings,
+                step,
+                num_steps,
+                guidance_scale,
+            )
+
+        latents = ops.fori_loop(starting_step, num_steps, body_fun, latents)
+
+        # Decode.
+ return self.backbone.decode_step(latents) + + def generate( + self, + inputs, + num_steps=50, + guidance_scale=7.0, + strength=0.8, + seed=None, + ): + return super().generate( + inputs, + num_steps=num_steps, + guidance_scale=guidance_scale, + strength=strength, + seed=seed, + ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py new file mode 100644 index 0000000000..8fa5b167ab --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py @@ -0,0 +1,180 @@ +import keras +import pytest +from keras import ops + +from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor +from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( + StableDiffusion3Backbone, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import ( + StableDiffusion3ImageToImage, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( + StableDiffusion3TextToImagePreprocessor, +) +from keras_hub.src.models.vae.vae_backbone import VAEBackbone +from keras_hub.src.tests.test_case import TestCase + + +class StableDiffusion3ImageToImageTest(TestCase): + def setUp(self): + # Instantiate the preprocessor. + vocab = ["air", "plane", "port"] + vocab += ["<|endoftext|>", "<|startoftext|>"] + vocab = dict([(token, i) for i, token in enumerate(vocab)]) + merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"] + merges += ["po rt", "pla ne"] + clip_l_tokenizer = CLIPTokenizer(vocab, merges, pad_with_end_token=True) + clip_g_tokenizer = CLIPTokenizer(vocab, merges) + clip_l_preprocessor = CLIPPreprocessor(clip_l_tokenizer) + clip_g_preprocessor = CLIPPreprocessor(clip_g_tokenizer) + self.preprocessor = StableDiffusion3TextToImagePreprocessor( + clip_l_preprocessor, clip_g_preprocessor + ) + + self.backbone = StableDiffusion3Backbone( + mmdit_patch_size=2, + mmdit_hidden_dim=16 * 2, + mmdit_num_layers=2, + mmdit_num_heads=2, + mmdit_position_size=192, + vae=VAEBackbone( + [32, 32, 32, 32], + [1, 1, 1, 1], + [32, 32, 32, 32], + [1, 1, 1, 1], + # Use `mode` generate a deterministic output. 
+ sampler_method="mode", + name="vae", + ), + clip_l=CLIPTextEncoder( + 20, 64, 64, 2, 2, 128, "quick_gelu", -2, name="clip_l" + ), + clip_g=CLIPTextEncoder( + 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g" + ), + image_shape=(64, 64, 3), + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.input_data = { + "images": ops.ones((2, 64, 64, 3)), + "latents": ops.ones((2, 8, 8, 16)), + "clip_l_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_l_negative_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_g_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_g_negative_token_ids": ops.ones((2, 5), dtype="int32"), + "num_steps": ops.ones((2,), dtype="int32"), + "guidance_scale": ops.ones((2,)), + } + + def test_image_to_image_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=StableDiffusion3ImageToImage, + init_kwargs=self.init_kwargs, + train_data=None, + expected_output_shape={ + "images": (2, 64, 64, 3), + "latents": (2, 8, 8, 16), + }, + ) + + def test_generate(self): + image_to_image = StableDiffusion3ImageToImage(**self.init_kwargs) + seed = 42 + image = self.input_data["images"][0] + # String input. + prompt = ["airplane"] + negative_prompt = [""] + output = image_to_image.generate( + { + "images": image, + "prompts": prompt, + "negative_prompts": negative_prompt, + }, + seed=seed, + ) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess(prompt) + negative_prompt_ids = self.preprocessor.generate_preprocess( + negative_prompt + ) + image_to_image.preprocessor = None + output2 = image_to_image.generate( + { + "images": image, + "prompts": prompt_ids, + "negative_prompts": negative_prompt_ids, + }, + seed=seed, + ) + self.assertAllClose(output, output2) + + def test_generate_with_lower_precision(self): + original_floatx = keras.config.floatx() + try: + for dtype in ["float16", "bfloat16"]: + keras.config.set_floatx(dtype) + image_to_image = StableDiffusion3ImageToImage( + **self.init_kwargs + ) + seed = 42 + image = self.input_data["images"][0] + # String input. + prompt = ["airplane"] + negative_prompt = [""] + output = image_to_image.generate( + { + "images": image, + "prompts": prompt, + "negative_prompts": negative_prompt, + }, + seed=seed, + ) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess(prompt) + negative_prompt_ids = self.preprocessor.generate_preprocess( + negative_prompt + ) + image_to_image.preprocessor = None + output2 = image_to_image.generate( + { + "images": image, + "prompts": prompt_ids, + "negative_prompts": negative_prompt_ids, + }, + seed=seed, + ) + self.assertAllClose(output, output2) + finally: + # Restore floatx to the original value to prevent impact on other + # tests even if there is an exception. + keras.config.set_floatx(original_floatx) + + def test_generate_compilation(self): + image_to_image = StableDiffusion3ImageToImage(**self.init_kwargs) + image = self.input_data["images"][0] + # Assert we do not recompile with successive calls. + image_to_image.generate({"images": image, "prompts": "airplane"}) + first_fn = image_to_image.generate_function + image_to_image.generate({"images": image, "prompts": "airplane"}) + second_fn = image_to_image.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. 
+ image_to_image.compile() + self.assertIsNone(image_to_image.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=StableDiffusion3ImageToImage, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py new file mode 100644 index 0000000000..8d5ed7c6af --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py @@ -0,0 +1,194 @@ +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.inpaint import Inpaint +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( + StableDiffusion3Backbone, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( + StableDiffusion3TextToImagePreprocessor, +) + + +@keras_hub_export("keras_hub.models.StableDiffusion3Inpaint") +class StableDiffusion3Inpaint(Inpaint): + """An end-to-end Stable Diffusion 3 model for inpaint generation. + + This model has a `generate()` method, which generates images based + on a combination of a reference image, mask and a text prompt. + + Args: + backbone: A `keras_hub.models.StableDiffusion3Backbone` instance. + preprocessor: A + `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance. + + Examples: + + Use `generate()` to do image generation. + ```python + reference_image = np.ones((1024, 1024, 3), dtype="float32") + reference_mask = np.ones((1024, 1024), dtype="float32") + inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset( + "stable_diffusion_3_medium", image_shape=(512, 512, 3) + ) + inpaint.generate( + reference_image, + reference_mask, + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + ) + + # Generate with batched prompts. + reference_images = np.ones((2, 512, 512, 3), dtype="float32") + reference_mask = np.ones((2, 1024, 1024), dtype="float32") + inpaint.generate( + reference_images, + reference_mask, + ["cute wallpaper art of a cat", "cute wallpaper art of a dog"] + ) + + # Generate with different `num_steps`, `guidance_scale` and `strength`. + inpaint.generate( + reference_image, + reference_mask, + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + num_steps=50, + guidance_scale=5.0, + strength=0.6, + ) + ``` + """ + + backbone_cls = StableDiffusion3Backbone + preprocessor_cls = StableDiffusion3TextToImagePreprocessor + + def __init__( + self, + backbone, + preprocessor, + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.input + outputs = backbone.output + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def fit(self, *args, **kwargs): + raise NotImplementedError( + "Currently, `fit` is not supported for " + "`StableDiffusion3Inpaint`." + ) + + def generate_step( + self, + images, + masks, + noises, + token_ids, + starting_step, + num_steps, + guidance_scale, + ): + """A compilable generation function for batched of inputs. + + This function represents the inner, XLA-compilable, generation function + for batched inputs. + + Args: + images: A (batch_size, image_height, image_width, 3) tensor + containing the reference images. + masks: A (batch_size, image_height, image_width) tensor + containing the reference masks. 
+ noises: A (batch_size, latent_height, latent_width, channels) tensor + containing the noises to be added to the latents. Typically, + this tensor is sampled from the Gaussian distribution. + token_ids: A pair of (batch_size, num_tokens) tensor containing the + tokens based on the input prompts and negative prompts. + starting_step: int. The number of the starting diffusion step. + num_steps: int. The number of diffusion steps to take. + guidance_scale: float. The classifier free guidance scale defined in + [Classifier-Free Diffusion Guidance]( + https://arxiv.org/abs/2207.12598). Higher scale encourages to + generate images that are closely linked to prompts, usually at + the expense of lower image quality. + """ + token_ids, negative_token_ids = token_ids + + # Get masked images. + masks = ops.cast(ops.expand_dims(masks, axis=-1) > 0.5, images.dtype) + masks_latent_size = ops.image.resize( + masks, + (self.backbone.latent_shape[1], self.backbone.latent_shape[2]), + interpolation="nearest", + ) + + # Encode images. + image_latents = self.backbone.encode_image_step(images) + + # Add noises to latents. + latents = self.backbone.add_noise_step( + image_latents, noises, starting_step, num_steps + ) + + # Encode inputs. + embeddings = self.backbone.encode_text_step( + token_ids, negative_token_ids + ) + + # Denoise. + def body_fun(step, latents): + latents = self.backbone.denoise_step( + latents, + embeddings, + step, + num_steps, + guidance_scale, + ) + + # Compute the previous latents x_t -> x_t-1. + def true_fn(): + next_step = ops.add(step, 1) + return self.backbone.add_noise_step( + image_latents, noises, next_step, num_steps + ) + + init_latents = ops.cond( + step < ops.subtract(num_steps, 1), + true_fn, + lambda: ops.cast(image_latents, noises.dtype), + ) + latents = ops.add( + ops.multiply( + ops.subtract(1.0, masks_latent_size), init_latents + ), + ops.multiply(masks_latent_size, latents), + ) + return latents + + latents = ops.fori_loop(starting_step, num_steps, body_fun, latents) + + # Decode. + return self.backbone.decode_step(latents) + + def generate( + self, + inputs, + num_steps=50, + guidance_scale=7.0, + strength=0.6, + seed=None, + ): + return super().generate( + inputs, + num_steps=num_steps, + guidance_scale=guidance_scale, + strength=strength, + seed=seed, + ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py new file mode 100644 index 0000000000..5e8ddd32c6 --- /dev/null +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py @@ -0,0 +1,197 @@ +import keras +import pytest +from keras import ops + +from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor +from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( + StableDiffusion3Backbone, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import ( + StableDiffusion3Inpaint, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( + StableDiffusion3TextToImagePreprocessor, +) +from keras_hub.src.models.vae.vae_backbone import VAEBackbone +from keras_hub.src.tests.test_case import TestCase + + +class StableDiffusion3InpaintTest(TestCase): + def setUp(self): + # Instantiate the preprocessor. 
+ vocab = ["air", "plane", "port"] + vocab += ["<|endoftext|>", "<|startoftext|>"] + vocab = dict([(token, i) for i, token in enumerate(vocab)]) + merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"] + merges += ["po rt", "pla ne"] + clip_l_tokenizer = CLIPTokenizer(vocab, merges, pad_with_end_token=True) + clip_g_tokenizer = CLIPTokenizer(vocab, merges) + clip_l_preprocessor = CLIPPreprocessor(clip_l_tokenizer) + clip_g_preprocessor = CLIPPreprocessor(clip_g_tokenizer) + self.preprocessor = StableDiffusion3TextToImagePreprocessor( + clip_l_preprocessor, clip_g_preprocessor + ) + + self.backbone = StableDiffusion3Backbone( + mmdit_patch_size=2, + mmdit_hidden_dim=16 * 2, + mmdit_num_layers=2, + mmdit_num_heads=2, + mmdit_position_size=192, + vae=VAEBackbone( + [32, 32, 32, 32], + [1, 1, 1, 1], + [32, 32, 32, 32], + [1, 1, 1, 1], + # Use `mode` generate a deterministic output. + sampler_method="mode", + name="vae", + ), + clip_l=CLIPTextEncoder( + 20, 64, 64, 2, 2, 128, "quick_gelu", -2, name="clip_l" + ), + clip_g=CLIPTextEncoder( + 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g" + ), + image_shape=(64, 64, 3), + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.input_data = { + "images": ops.ones((2, 64, 64, 3)), + "latents": ops.ones((2, 8, 8, 16)), + "clip_l_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_l_negative_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_g_token_ids": ops.ones((2, 5), dtype="int32"), + "clip_g_negative_token_ids": ops.ones((2, 5), dtype="int32"), + "num_steps": ops.ones((2,), dtype="int32"), + "guidance_scale": ops.ones((2,)), + } + + def test_inpaint_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=StableDiffusion3Inpaint, + init_kwargs=self.init_kwargs, + train_data=None, + expected_output_shape={ + "images": (2, 64, 64, 3), + "latents": (2, 8, 8, 16), + }, + ) + + def test_generate(self): + inpaint = StableDiffusion3Inpaint(**self.init_kwargs) + seed = 42 + image = self.input_data["images"][0] + mask = self.input_data["images"][0][..., 0] # (B, H, W) + # String input. + prompt = ["airplane"] + negative_prompt = [""] + output = inpaint.generate( + { + "images": image, + "masks": mask, + "prompts": prompt, + "negative_prompts": negative_prompt, + }, + seed=seed, + ) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess(prompt) + negative_prompt_ids = self.preprocessor.generate_preprocess( + negative_prompt + ) + inpaint.preprocessor = None + output2 = inpaint.generate( + { + "images": image, + "masks": mask, + "prompts": prompt_ids, + "negative_prompts": negative_prompt_ids, + }, + seed=seed, + ) + self.assertAllClose(output, output2) + + def test_generate_with_lower_precision(self): + original_floatx = keras.config.floatx() + try: + for dtype in ["float16", "bfloat16"]: + keras.config.set_floatx(dtype) + inpaint = StableDiffusion3Inpaint(**self.init_kwargs) + seed = 42 + image = self.input_data["images"][0] + mask = self.input_data["images"][0][..., 0] # (B, H, W) + # String input. + prompt = ["airplane"] + negative_prompt = [""] + output = inpaint.generate( + { + "images": image, + "masks": mask, + "prompts": prompt, + "negative_prompts": negative_prompt, + }, + seed=seed, + ) + # Int tensor input. 
+ prompt_ids = self.preprocessor.generate_preprocess(prompt) + negative_prompt_ids = self.preprocessor.generate_preprocess( + negative_prompt + ) + inpaint.preprocessor = None + output2 = inpaint.generate( + { + "images": image, + "masks": mask, + "prompts": prompt_ids, + "negative_prompts": negative_prompt_ids, + }, + seed=seed, + ) + self.assertAllClose(output, output2) + finally: + # Restore floatx to the original value to prevent impact on other + # tests even if there is an exception. + keras.config.set_floatx(original_floatx) + + def test_generate_compilation(self): + inpaint = StableDiffusion3Inpaint(**self.init_kwargs) + image = self.input_data["images"][0] + mask = self.input_data["images"][0][..., 0] # (B, H, W) + # Assert we do not recompile with successive calls. + inpaint.generate( + { + "images": image, + "masks": mask, + "prompts": "airplane", + } + ) + first_fn = inpaint.generate_function + inpaint.generate( + { + "images": image, + "masks": mask, + "prompts": "airplane", + } + ) + second_fn = inpaint.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. + inpaint.compile() + self.assertIsNone(inpaint.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=StableDiffusion3Inpaint, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py index 2067fdb8dc..a7756fc645 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py @@ -10,9 +10,9 @@ ), "params": 2987080931, "official_name": "StableDiffusion3", - "path": "stablediffusion3", + "path": "stable_diffusion_3", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/3", + "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/3", } } diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py index 63f0ba6c28..739c6f4650 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py @@ -27,7 +27,7 @@ class StableDiffusion3TextToImage(TextToImage): Use `generate()` to do image generation. ```python text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset( - "stable_diffusion_3_medium", height=512, width=512 + "stable_diffusion_3_medium", image_shape=(512, 512, 3) ) text_to_image.generate( "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" @@ -44,6 +44,14 @@ class StableDiffusion3TextToImage(TextToImage): num_steps=50, guidance_scale=5.0, ) + + # Generate with `negative_prompts`. + text_to_image.generate( + { + "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "negative_prompts": "green color", + } + ) ``` """ @@ -79,7 +87,6 @@ def generate_step( self, latents, token_ids, - negative_token_ids, num_steps, guidance_scale, ): @@ -92,10 +99,8 @@ def generate_step( latents: A (batch_size, height, width, channels) tensor containing the latents to start generation from. Typically, this tensor is sampled from the Gaussian distribution. 
- token_ids: A (batch_size, num_tokens) tensor containing the - tokens based on the input prompts. - negative_token_ids: A (batch_size, num_tokens) tensor - containing the negative tokens based on the input prompts. + token_ids: A pair of (batch_size, num_tokens) tensor containing the + tokens based on the input prompts and negative prompts. num_steps: int. The number of diffusion steps to take. guidance_scale: float. The classifier free guidance scale defined in [Classifier-Free Diffusion Guidance]( @@ -103,7 +108,9 @@ def generate_step( generate images that are closely linked to prompts, usually at the expense of lower image quality. """ - # Encode inputs. + token_ids, negative_token_ids = token_ids + + # Encode prompts. embeddings = self.backbone.encode_text_step( token_ids, negative_token_ids ) @@ -126,14 +133,12 @@ def body_fun(step, latents): def generate( self, inputs, - negative_inputs=None, num_steps=28, guidance_scale=7.0, seed=None, ): return super().generate( inputs, - negative_inputs=negative_inputs, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed, diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py index 837c95fa37..69d30de834 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py @@ -55,8 +55,7 @@ def setUp(self): clip_g=CLIPTextEncoder( 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g" ), - height=64, - width=64, + image_shape=(64, 64, 3), ) self.init_kwargs = { "preprocessor": self.preprocessor, @@ -93,7 +92,13 @@ def test_generate(self): # String input. prompt = ["airplane"] negative_prompt = [""] - output = text_to_image.generate(prompt, negative_prompt, seed=seed) + output = text_to_image.generate( + { + "prompts": prompt, + "negative_prompts": negative_prompt, + }, + seed=seed, + ) # Int tensor input. prompt_ids = self.preprocessor.generate_preprocess(prompt) negative_prompt_ids = self.preprocessor.generate_preprocess( @@ -101,7 +106,11 @@ def test_generate(self): ) text_to_image.preprocessor = None output2 = text_to_image.generate( - prompt_ids, negative_prompt_ids, seed=seed + { + "prompts": prompt_ids, + "negative_prompts": negative_prompt_ids, + }, + seed=seed, ) self.assertAllClose(output, output2) @@ -116,7 +125,11 @@ def test_generate_with_lower_precision(self): prompt = ["airplane"] negative_prompt = [""] output = text_to_image.generate( - prompt, negative_prompt, seed=seed + { + "prompts": prompt, + "negative_prompts": negative_prompt, + }, + seed=seed, ) # Int tensor input. prompt_ids = self.preprocessor.generate_preprocess(prompt) @@ -125,7 +138,11 @@ def test_generate_with_lower_precision(self): ) text_to_image.preprocessor = None output2 = text_to_image.generate( - prompt_ids, negative_prompt_ids, seed=seed + { + "prompts": prompt_ids, + "negative_prompts": negative_prompt_ids, + }, + seed=seed, ) self.assertAllClose(output, output2) finally: diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py index b107284444..af12f1cb1c 100644 --- a/keras_hub/src/models/task.py +++ b/keras_hub/src/models/task.py @@ -280,7 +280,7 @@ def summary( def highlight_number(x): if x is None: - f"[color(45)]{x}[/]" + return f"[color(45)]{x}[/]" return f"[color(34)]{x:,}[/]" # Format number with commas. 
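The hunk above restores a missing `return`: without it, the colored string built for `None` (for example, a missing parameter count) was discarded and execution fell through to the `:,` number format below, which raises a `TypeError` for `None`. A minimal sketch of the fixed helper's behavior:

```python
def highlight_number(x):
    if x is None:
        return f"[color(45)]{x}[/]"
    return f"[color(34)]{x:,}[/]"  # Format number with commas.

print(highlight_number(None))   # [color(45)]None[/]
print(highlight_number(12345))  # [color(34)]12,345[/]
```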
def highlight_symbol(x): @@ -339,7 +339,10 @@ def add_layer(layer, info): add_layer(layer, info) elif isinstance(layer, ImageConverter): info = "Image size: " - info += highlight_shape(layer.image_size()) + image_size = layer.image_size + if image_size is None: + image_size = (None, None) + info += highlight_shape(image_size) add_layer(layer, info) elif isinstance(layer, AudioConverter): info = "Audio shape: " diff --git a/keras_hub/src/models/text_to_image.py b/keras_hub/src/models/text_to_image.py index 291a4b023e..54b8dcdae2 100644 --- a/keras_hub/src/models/text_to_image.py +++ b/keras_hub/src/models/text_to_image.py @@ -56,6 +56,11 @@ def __init__(self, *args, **kwargs): # Default compilation. self.compile() + @property + def support_negative_prompts(self): + """Whether the model supports `negative_prompts` key in `generate()`.""" + return bool(True) + @property def latent_shape(self): return tuple(self.backbone.latent_shape) @@ -171,9 +176,26 @@ def _normalize_generate_inputs(self, inputs): This function converts all inputs to tensors, adds a batch dimension if necessary, and returns a iterable "dataset like" object (either an actual `tf.data.Dataset` or a list with a single batch element). + + The input format must be one of the following: + - A single string + - A list of strings + - A dict with "prompts" and/or "negative_prompts" keys + - A tf.data.Dataset with "prompts" and/or "negative_prompts" keys + + The output will be a dict with "prompts" and/or "negative_prompts" keys. """ if tf and isinstance(inputs, tf.data.Dataset): - return inputs.as_numpy_iterator(), False + _inputs = { + "prompts": inputs.map( + lambda x: x["prompts"] + ).as_numpy_iterator() + } + if self.support_negative_prompts: + _inputs["negative_prompts"] = inputs.map( + lambda x: x["negative_prompts"] + ).as_numpy_iterator() + return _inputs, False def normalize(x): if isinstance(x, str): @@ -182,13 +204,24 @@ def normalize(x): return x[tf.newaxis], True return x, False + def get_dummy_prompts(x): + dummy_prompts = [""] * len(x) + if tf and isinstance(x, tf.Tensor): + return tf.convert_to_tensor(dummy_prompts) + else: + return dummy_prompts + if isinstance(inputs, dict): for key in inputs: inputs[key], input_is_scalar = normalize(inputs[key]) else: inputs, input_is_scalar = normalize(inputs) + inputs = {"prompts": inputs} - return inputs, input_is_scalar + if self.support_negative_prompts and "negative_prompts" not in inputs: + inputs["negative_prompts"] = get_dummy_prompts(inputs["prompts"]) + + return [inputs], input_is_scalar def _normalize_generate_outputs(self, outputs, input_is_scalar): """Normalize user output from the generate function. @@ -199,12 +232,11 @@ def _normalize_generate_outputs(self, outputs, input_is_scalar): """ def normalize(x): - outputs = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0) + outputs = ops.concatenate(x, axis=0) + outputs = ops.clip(ops.divide(ops.add(outputs, 1.0), 2.0), 0.0, 1.0) outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8") - outputs = ops.convert_to_numpy(outputs) - if input_is_scalar: - outputs = outputs[0] - return outputs + outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs + return ops.convert_to_numpy(outputs) if isinstance(outputs[0], dict): normalized = {} @@ -216,23 +248,40 @@ def normalize(x): def generate( self, inputs, - negative_inputs, num_steps, guidance_scale, seed=None, ): - """Generate image based on the provided `inputs` and `negative_inputs`. + """Generate image based on the provided `inputs`. 
+ + Typically, `inputs` contains a text description (known as a prompt) used + to guide the image generation. + + Some models support a `negative_prompts` key, which helps steer the + model away from generating certain styles and elements. To enable this, + pass `prompts` and `negative_prompts` as a dict: + + ```python + text_to_image.generate( + { + "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "negative_prompts": "green color", + } + ) + ``` If `inputs` are a `tf.data.Dataset`, outputs will be generated "batch-by-batch" and concatenated. Otherwise, all inputs will be processed as batches. Args: - inputs: python data, tensor data, or a `tf.data.Dataset`. - negative_inputs: python data, tensor data, or a `tf.data.Dataset`. - Unlike `inputs`, these are used as negative inputs to guide the - generation. If not provided, it defaults to `""` for each input - in `inputs`. + inputs: python data, tensor data, or a `tf.data.Dataset`. The format + must be one of the following: + - A single string + - A list of strings + - A dict with "prompts" and/or "negative_prompts" keys + - A `tf.data.Dataset` with "prompts" and/or "negative_prompts" + keys num_steps: int. The number of diffusion steps to take. guidance_scale: float. The classifier free guidance scale defined in [Classifier-Free Diffusion Guidance]( @@ -251,32 +300,36 @@ def generate( generate_function = self.make_generate_function() def preprocess(x): - return self.preprocessor.generate_preprocess(x) + if self.preprocessor is not None: + return self.preprocessor.generate_preprocess(x) + else: + return x + + def generate(x): + token_ids = x[0] if self.support_negative_prompts else x + + # Initialize latents. + if isinstance(token_ids, dict): + arbitrary_key = list(token_ids.keys())[0] + batch_size = ops.shape(token_ids[arbitrary_key])[0] + else: + batch_size = ops.shape(token_ids)[0] + latent_shape = (batch_size,) + self.latent_shape[1:] + latents = random.normal(latent_shape, dtype="float32", seed=seed) + + return generate_function(latents, x, num_steps, guidance_scale) # Normalize and preprocess inputs. inputs, input_is_scalar = self._normalize_generate_inputs(inputs) - if negative_inputs is None: - negative_inputs = [""] * len(inputs) - negative_inputs, _ = self._normalize_generate_inputs(negative_inputs) - - if self.preprocessor is not None: - inputs = preprocess(inputs) - negative_inputs = preprocess(negative_inputs) - if isinstance(inputs, dict): - batch_size = len(inputs[list(inputs.keys())[0]]) + if self.support_negative_prompts: + token_ids = [preprocess(x["prompts"]) for x in inputs] + negative_token_ids = [ + preprocess(x["negative_prompts"]) for x in inputs + ] + inputs = [x for x in zip(token_ids, negative_token_ids)] else: - batch_size = len(inputs) - - # Initialize random latents. - latent_shape = (batch_size,) + self.latent_shape[1:] - latents = random.normal(latent_shape, dtype="float32", seed=seed) + inputs = [preprocess(x["prompts"]) for x in inputs] # Text-to-image. 
- outputs = generate_function( - latents, - inputs, - negative_inputs, - num_steps, - guidance_scale, - ) + outputs = [generate(x) for x in inputs] return self._normalize_generate_outputs(outputs, input_is_scalar) diff --git a/keras_hub/src/models/vae/vae_backbone.py b/keras_hub/src/models/vae/vae_backbone.py index c84986314d..606107d17f 100644 --- a/keras_hub/src/models/vae/vae_backbone.py +++ b/keras_hub/src/models/vae/vae_backbone.py @@ -10,7 +10,7 @@ class VAEBackbone(Backbone): - """VAE backbone used in latent diffusion models. + """Variational Autoencoder(VAE) backbone used in latent diffusion models. When encoding, this model generates mean and log variance of the input images. When decoding, it reconstructs images from the latent space. @@ -51,6 +51,18 @@ class VAEBackbone(Backbone): `"channels_last"`. dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype to use for the model's computations and weights. + + Example: + ```Python + backbone = VAEBackbone( + encoder_num_filters=[32, 32, 32, 32], + encoder_num_blocks=[1, 1, 1, 1], + decoder_num_filters=[32, 32, 32, 32], + decoder_num_blocks=[1, 1, 1, 1], + ) + input_data = ops.ones((2, self.height, self.width, 3)) + output = backbone(input_data) + ``` """ def __init__( diff --git a/keras_hub/src/models/vgg/__init__.py b/keras_hub/src/models/vgg/__init__.py index e69de29bb2..4850d0eab4 100644 --- a/keras_hub/src/models/vgg/__init__.py +++ b/keras_hub/src/models/vgg/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone +from keras_hub.src.models.vgg.vgg_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, VGGBackbone) diff --git a/keras_hub/src/models/vgg/vgg_backbone.py b/keras_hub/src/models/vgg/vgg_backbone.py index cf2638146e..ef91c8689d 100644 --- a/keras_hub/src/models/vgg/vgg_backbone.py +++ b/keras_hub/src/models/vgg/vgg_backbone.py @@ -20,7 +20,7 @@ class VGGBackbone(Backbone): stackwise_num_filters: list of ints, filter size for convolutional blocks per VGG block. For both VGG16 and VGG19 this is [ 64, 128, 256, 512, 512]. - image_shape: tuple, optional shape tuple, defaults to (224, 224, 3). + image_shape: tuple, optional shape tuple, defaults to (None, None, 3). 
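With the new `(None, None, 3)` default, the backbone is built without a fixed spatial size, so one instance can process inputs of different resolutions. A minimal sketch with an illustrative (non-preset) configuration:

```python
import numpy as np
import keras_hub

backbone = keras_hub.models.VGGBackbone(
    stackwise_num_repeats=[2, 2],
    stackwise_num_filters=[8, 16],
    image_shape=(None, None, 3),
)
# The same backbone accepts different input resolutions.
print(backbone.predict(np.ones((1, 32, 32, 3))).shape)
print(backbone.predict(np.ones((1, 64, 64, 3))).shape)
```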
Examples: ```python @@ -47,12 +47,11 @@ def __init__( image_shape=(None, None, 3), **kwargs, ): - # === Functional Model === img_input = keras.layers.Input(shape=image_shape) x = img_input - for stack_index in range(len(stackwise_num_repeats) - 1): + for stack_index in range(len(stackwise_num_repeats)): x = apply_vgg_block( x=x, num_layers=stackwise_num_repeats[stack_index], diff --git a/keras_hub/src/models/vgg/vgg_backbone_test.py b/keras_hub/src/models/vgg/vgg_backbone_test.py index 87e9ed6ef5..19dd7844da 100644 --- a/keras_hub/src/models/vgg/vgg_backbone_test.py +++ b/keras_hub/src/models/vgg/vgg_backbone_test.py @@ -19,7 +19,7 @@ def test_backbone_basics(self): cls=VGGBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 4, 4, 64), + expected_output_shape=(2, 2, 2, 64), run_mixed_precision_check=False, ) diff --git a/keras_hub/src/models/vgg/vgg_image_classifier.py b/keras_hub/src/models/vgg/vgg_image_classifier.py index 4d02f1ca5f..a72b256288 100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier.py @@ -4,6 +4,9 @@ from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.task import Task from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone +from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( + VGGImageClassifierPreprocessor, +) @keras_hub_export("keras_hub.models.VGGImageClassifier") @@ -96,13 +99,14 @@ class VGGImageClassifier(ImageClassifier): """ backbone_cls = VGGBackbone + preprocessor_cls = VGGImageClassifierPreprocessor def __init__( self, backbone, num_classes, preprocessor=None, - pooling="flatten", + pooling="avg", pooling_hidden_dim=4096, activation=None, dropout=0.0, @@ -141,24 +145,46 @@ def __init__( "Unknown `pooling` type. Polling should be either `'avg'` or " f"`'max'`. Received: pooling={pooling}." ) - self.output_dropout = keras.layers.Dropout( - dropout, - dtype=head_dtype, - name="output_dropout", - ) - self.output_dense = keras.layers.Dense( - num_classes, - activation=activation, - dtype=head_dtype, - name="predictions", + + self.head = keras.Sequential( + [ + keras.layers.Conv2D( + filters=4096, + kernel_size=7, + name="fc1", + activation=activation, + use_bias=True, + padding="same", + ), + keras.layers.Dropout( + rate=dropout, + dtype=head_dtype, + name="output_dropout", + ), + keras.layers.Conv2D( + filters=4096, + kernel_size=1, + name="fc2", + activation=activation, + use_bias=True, + padding="same", + ), + self.pooler, + keras.layers.Dense( + num_classes, + activation=activation, + dtype=head_dtype, + name="predictions", + ), + ], + name="head", ) # === Functional Model === inputs = self.backbone.input x = self.backbone(inputs) - x = self.pooler(x) - x = self.output_dropout(x) - outputs = self.output_dense(x) + outputs = self.head(x) + # Skip the parent class functional model. 
Task.__init__( self, @@ -173,6 +199,7 @@ def __init__( self.pooling = pooling self.pooling_hidden_dim = pooling_hidden_dim self.dropout = dropout + self.preprocessor = preprocessor def get_config(self): # Backbone serialized in `super` diff --git a/keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py b/keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py new file mode 100644 index 0000000000..f32f965095 --- /dev/null +++ b/keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py @@ -0,0 +1,12 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) +from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone +from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter + + +@keras_hub_export("keras_hub.models.VGGImageClassifierPreprocessor") +class VGGImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = VGGBackbone + image_converter_cls = VGGImageConverter diff --git a/keras_hub/src/models/vgg/vgg_image_classifier_test.py b/keras_hub/src/models/vgg/vgg_image_classifier_test.py index 6d95ddaac5..34bb7e3db8 100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier_test.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier_test.py @@ -3,24 +3,33 @@ from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier +from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( + VGGImageClassifierPreprocessor, +) +from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter from keras_hub.src.tests.test_case import TestCase class VGGImageClassifierTest(TestCase): def setUp(self): # Setup model. - self.images = np.ones((2, 4, 4, 3), dtype="float32") - self.labels = [0, 3] + self.images = np.ones((2, 8, 8, 3), dtype="float32") + self.labels = [0, 1] self.backbone = VGGBackbone( stackwise_num_repeats=[2, 4, 4], stackwise_num_filters=[2, 16, 16], - image_shape=(4, 4, 3), + image_shape=(8, 8, 3), + ) + image_converter = VGGImageConverter(image_size=(8, 8)) + self.preprocessor = VGGImageClassifierPreprocessor( + image_converter=image_converter, ) self.init_kwargs = { "backbone": self.backbone, "num_classes": 2, "activation": "softmax", "pooling": "flatten", + "preprocessor": self.preprocessor, } self.train_data = ( self.images, @@ -28,9 +37,6 @@ def setUp(self): ) def test_classifier_basics(self): - pytest.skip( - reason="TODO: enable after preprocessor flow is figured out" - ) self.run_task_test( cls=VGGImageClassifier, init_kwargs=self.init_kwargs, diff --git a/keras_hub/src/models/vgg/vgg_image_converter.py b/keras_hub/src/models/vgg/vgg_image_converter.py new file mode 100644 index 0000000000..69ccacbd1d --- /dev/null +++ b/keras_hub/src/models/vgg/vgg_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone + + +@keras_hub_export("keras_hub.layers.VGGImageConverter") +class VGGImageConverter(ImageConverter): + backbone_cls = VGGBackbone diff --git a/keras_hub/src/models/vgg/vgg_presets.py b/keras_hub/src/models/vgg/vgg_presets.py new file mode 100644 index 0000000000..e0379a8da0 --- /dev/null +++ b/keras_hub/src/models/vgg/vgg_presets.py @@ -0,0 +1,56 @@ +"""vgg preset configurations.""" + +backbone_presets = { + "vgg_11_imagenet": { + "metadata": { + "description": 
( + "11-layer vgg model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 9220480, + "official_name": "vgg", + "path": "vgg", + "model_card": "https://arxiv.org/abs/1409.1556", + }, + "kaggle_handle": "kaggle://keras/vgg/keras/vgg_11_imagenet/1", + }, + "vgg_13_imagenet": { + "metadata": { + "description": ( + "13-layer vgg model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 9404992, + "official_name": "vgg", + "path": "vgg", + "model_card": "https://arxiv.org/abs/1409.1556", + }, + "kaggle_handle": "kaggle://keras/vgg/keras/vgg_13_imagenet/1", + }, + "vgg_16_imagenet": { + "metadata": { + "description": ( + "16-layer vgg model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 14714688, + "official_name": "vgg", + "path": "vgg", + "model_card": "https://arxiv.org/abs/1409.1556", + }, + "kaggle_handle": "kaggle://keras/vgg/keras/vgg_16_imagenet/1", + }, + "vgg_19_imagenet": { + "metadata": { + "description": ( + "19-layer vgg model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 20024384, + "official_name": "vgg", + "path": "vgg", + "model_card": "https://arxiv.org/abs/1409.1556", + }, + "kaggle_handle": "kaggle://keras/vgg/keras/vgg_19_imagenet/1", + }, +} diff --git a/keras_hub/src/models/vit_det/vit_det_backbone.py b/keras_hub/src/models/vit_det/vit_det_backbone.py index 94f7887c44..7d5883409e 100644 --- a/keras_hub/src/models/vit_det/vit_det_backbone.py +++ b/keras_hub/src/models/vit_det/vit_det_backbone.py @@ -31,7 +31,7 @@ class ViTDetBackbone(Backbone): global_attention_layer_indices (list): Indexes for blocks using global attention. image_shape (tuple[int], optional): The size of the input image in - `(H, W, C)` format. Defaults to `(1024, 1024, 3)`. + `(H, W, C)` format. Defaults to `(None, None, 3)`. patch_size (int, optional): the patch size to be supplied to the Patching layer to turn input images into a flattened sequence of patches. Defaults to `16`. @@ -79,7 +79,7 @@ def __init__( intermediate_dim, num_heads, global_attention_layer_indices, - image_shape=(1024, 1024, 3), + image_shape=(None, None, 3), patch_size=16, num_output_channels=256, use_bias=True, diff --git a/keras_hub/src/models/whisper/whisper_audio_converter.py b/keras_hub/src/models/whisper/whisper_audio_converter.py index 633042f547..9890109bac 100644 --- a/keras_hub/src/models/whisper/whisper_audio_converter.py +++ b/keras_hub/src/models/whisper/whisper_audio_converter.py @@ -39,7 +39,7 @@ class WhisperAudioConverter(AudioConverter): audio_tensor = tf.ones((8000,), dtype="float32") # Compute the log-mel spectrogram. - audio_converter = keras_hub.models.WhisperAudioConverter.from_preset( + audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset( "whisper_base_en", ) audio_converter(audio_tensor) diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index 6d06c7266c..03c01cb24b 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -388,6 +388,8 @@ def run_model_saving_test( cls, init_kwargs, input_data, + atol=0.000001, + rtol=0.000001, ): """Save and load a model from disk and assert output is unchanged.""" model = cls(**init_kwargs) @@ -401,7 +403,7 @@ def run_model_saving_test( # Check that output matches. 
restored_output = restored_model(input_data) - self.assertAllClose(model_output, restored_output) + self.assertAllClose(model_output, restored_output, atol=atol, rtol=rtol) def run_backbone_test( self, @@ -567,6 +569,15 @@ def run_task_test( ds = tf.data.Dataset.from_tensor_slices(train_data).batch(batch_size) x, y, sw = keras.utils.unpack_x_y_sample_weight(train_data) + # Test: the tree struct output by the + # preprocessor must match what model expects. + preprocessed_data = preprocessor(*train_data)[0] + tree.assert_same_structure( + preprocessed_data, + task._inputs_struct, + check_types=False, + ) + # Test predict. output = task.predict(x) if expected_output_shape is not None: diff --git a/keras_hub/src/tokenizers/byte_pair_tokenizer.py b/keras_hub/src/tokenizers/byte_pair_tokenizer.py index 41cef2b652..a7447c562e 100644 --- a/keras_hub/src/tokenizers/byte_pair_tokenizer.py +++ b/keras_hub/src/tokenizers/byte_pair_tokenizer.py @@ -43,7 +43,11 @@ SPLIT_PATTERN_1 = SPLIT_PATTERN_1.replace( "{special_spaces}", SPECIAL_WHITESPACES ) -SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$""" + +# The pattern " \t\r\f\v" is the same as \s "all spaces" but without the \n. +# Multiple \n\n\n in sequence must not be split for Llama3. +# SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$""" +SPLIT_PATTERN_2 = rf"""[ \t\r\f\v६{SPECIAL_WHITESPACES}]$""" def create_alts_for_unsplittable_tokens(unsplittable_tokens): diff --git a/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py b/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py index 5995df2fed..1aef54e214 100644 --- a/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py +++ b/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py @@ -1,5 +1,4 @@ import keras -import pytest import tensorflow as tf from keras_hub.src.tests.test_case import TestCase @@ -15,7 +14,6 @@ ) -@pytest.mark.large class BytePairTokenizerTest(TestCase): def setUp(self): super().setUp() @@ -111,6 +109,24 @@ def test_whitespace_split(self): encoded = self.tokenizer(input_data) self.assertAllEqual(encoded, [1437, 1437, 50140, 50118, 29]) + # This is important for Llama3 which uses the \n\n sequence in chat + # templates: \n\n must be tokenized as a single token + input_data = "Hello\n\nHello" + encoded = self.tokenizer(input_data) + self.assertAllEqual(encoded, [31414, 50140, 31414]) + + input_data = "Hello\n\n\n\nHello" + encoded = self.tokenizer(input_data) + self.assertAllEqual(encoded, [31414, 50140, 50140, 31414]) + + input_data = "Hello\n\n" + encoded = self.tokenizer(input_data) + self.assertAllEqual(encoded, [31414, 50140]) + + input_data = "Hello\n\n\n\n" + encoded = self.tokenizer(input_data) + self.assertAllEqual(encoded, [31414, 50140, 50140]) + def test_special_whitespace(self): input_data = "\xa0 \xa0 \x3000 s" encoded = self.tokenizer(input_data) diff --git a/keras_hub/src/tokenizers/tokenizer.py b/keras_hub/src/tokenizers/tokenizer.py index b97efae444..5e8986a89e 100644 --- a/keras_hub/src/tokenizers/tokenizer.py +++ b/keras_hub/src/tokenizers/tokenizer.py @@ -66,7 +66,7 @@ def detokenize(self, inputs): backbone_cls = None def __init__(self, *args, **kwargs): - self.config_name = kwargs.pop("config_name", TOKENIZER_CONFIG_FILE) + self.config_file = kwargs.pop("config_file", TOKENIZER_CONFIG_FILE) super().__init__(*args, **kwargs) self.file_assets = None @@ -178,7 +178,7 @@ def get_config(self): config = super().get_config() config.update( { - "config_name": self.config_name, + "config_file": self.config_file, } ) return config @@ -199,11 +199,11 @@ def call(self, 
inputs, *args, training=None, **kwargs): def load_preset_assets(self, preset): asset_path = None for asset in self.file_assets: - subdir = self.config_name.split(".")[0] + subdir = self.config_file.split(".")[0] preset_path = os.path.join(ASSET_DIR, subdir, asset) asset_path = get_file(preset, preset_path) - tokenizer_config_name = os.path.dirname(asset_path) - self.load_assets(tokenizer_config_name) + tokenizer_config_file = os.path.dirname(asset_path) + self.load_assets(tokenizer_config_file) @classproperty def presets(cls): @@ -214,7 +214,7 @@ def presets(cls): def from_preset( cls, preset, - config_name=TOKENIZER_CONFIG_FILE, + config_file=TOKENIZER_CONFIG_FILE, **kwargs, ): """Instantiate a `keras_hub.models.Tokenizer` from a model preset. @@ -260,4 +260,4 @@ class like `keras_hub.models.Tokenizer.from_preset()`, or from backbone_cls = loader.check_backbone_class() if cls.backbone_cls != backbone_cls: cls = find_subclass(preset, cls, backbone_cls) - return loader.load_tokenizer(cls, config_name, **kwargs) + return loader.load_tokenizer(cls, config_file, **kwargs) diff --git a/keras_hub/src/utils/pipeline_model.py b/keras_hub/src/utils/pipeline_model.py index 68bc4d8877..f874b057fe 100644 --- a/keras_hub/src/utils/pipeline_model.py +++ b/keras_hub/src/utils/pipeline_model.py @@ -232,7 +232,7 @@ def train_on_batch( ): data = self.preprocess_samples(x, y, sample_weight) x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) - x = ops.convert_to_tensor(x) + x = tree.map_structure(ops.convert_to_tensor, x) if y is not None: y = ops.convert_to_tensor(y) if sample_weight is not None: @@ -253,7 +253,7 @@ def test_on_batch( ): data = self.preprocess_samples(x, y, sample_weight) x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) - x = ops.convert_to_tensor(x) + x = tree.map_structure(ops.convert_to_tensor, x) if y is not None: y = ops.convert_to_tensor(y) if sample_weight is not None: @@ -272,7 +272,7 @@ def predict_on_batch( ): data = self.preprocess_samples(x) x, _, _ = keras.utils.unpack_x_y_sample_weight(data) - x = ops.convert_to_tensor(x) + x = tree.map_structure(ops.convert_to_tensor, x) return super().predict_on_batch( x=x, **kwargs, diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py index 65af19df7f..52aad373a0 100644 --- a/keras_hub/src/utils/preset_utils.py +++ b/keras_hub/src/utils/preset_utils.py @@ -563,10 +563,8 @@ def get_backbone_kwargs(self, **kwargs): backbone_kwargs["dtype"] = kwargs.pop("dtype", None) # Forward `height` and `width` to backbone when using `TextToImage`. - if "height" in kwargs: - backbone_kwargs["height"] = kwargs.pop("height", None) - if "width" in kwargs: - backbone_kwargs["width"] = kwargs.pop("width", None) + if "image_shape" in kwargs: + backbone_kwargs["image_shape"] = kwargs.pop("image_shape", None) return backbone_kwargs, kwargs @@ -578,7 +576,7 @@ def load_backbone(self, cls, load_weights, **kwargs): """Load the backbone model from the preset.""" raise NotImplementedError - def load_tokenizer(self, cls, config_name=TOKENIZER_CONFIG_FILE, **kwargs): + def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs): """Load a tokenizer layer from the preset.""" raise NotImplementedError @@ -609,7 +607,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): return cls(**kwargs) def load_preprocessor( - self, cls, config_name=PREPROCESSOR_CONFIG_FILE, **kwargs + self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs ): """Load a prepocessor layer from the preset. 
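As a consequence of the `get_backbone_kwargs` change above, a single `image_shape` argument (rather than separate `height`/`width`) is now routed from a task's `from_preset()` call to the backbone, which is what enables calls like the one shown in the `StableDiffusion3TextToImage` docstring earlier in this patch:

```python
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium",
    image_shape=(512, 512, 3),
)
```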
@@ -632,8 +630,8 @@ def load_backbone(self, cls, load_weights, **kwargs): backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE)) return backbone - def load_tokenizer(self, cls, config_name=TOKENIZER_CONFIG_FILE, **kwargs): - tokenizer_config = load_json(self.preset, config_name) + def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs): + tokenizer_config = load_json(self.preset, config_file) tokenizer = load_serialized_object(tokenizer_config, **kwargs) if hasattr(tokenizer, "load_preset_assets"): tokenizer.load_preset_assets(self.preset) @@ -678,13 +676,13 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): return task def load_preprocessor( - self, cls, config_name=PREPROCESSOR_CONFIG_FILE, **kwargs + self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs ): # If there is no `preprocessing.json` or it's for the wrong class, # delegate to the super class loader. - if not check_file_exists(self.preset, config_name): + if not check_file_exists(self.preset, config_file): return super().load_preprocessor(cls, **kwargs) - preprocessor_json = load_json(self.preset, config_name) + preprocessor_json = load_json(self.preset, config_file) if not issubclass(check_config_class(preprocessor_json), cls): return super().load_preprocessor(cls, **kwargs) # We found a `preprocessing.json` with a complete config for our class. diff --git a/keras_hub/src/utils/preset_utils_test.py b/keras_hub/src/utils/preset_utils_test.py index 00baf28235..9d36428698 100644 --- a/keras_hub/src/utils/preset_utils_test.py +++ b/keras_hub/src/utils/preset_utils_test.py @@ -18,6 +18,7 @@ class PresetUtilsTest(TestCase): + @pytest.mark.large def test_preset_errors(self): with self.assertRaisesRegex(ValueError, "must be a string"): AlbertTextClassifier.from_preset(AlbertTextClassifier) @@ -34,6 +35,7 @@ def test_preset_errors(self): with self.assertRaisesRegex(ValueError, "class keras_hub>BortBackbone"): BertBackbone.from_preset(preset_dir) + @pytest.mark.large def test_upload_empty_preset(self): temp_dir = self.get_temp_dir() empty_preset = os.path.join(temp_dir, "empty") diff --git a/keras_hub/src/utils/timm/convert_mobilenet.py b/keras_hub/src/utils/timm/convert_mobilenet.py index 7b4cf4c8e1..e2de6d8e34 100644 --- a/keras_hub/src/utils/timm/convert_mobilenet.py +++ b/keras_hub/src/utils/timm/convert_mobilenet.py @@ -30,8 +30,8 @@ def convert_backbone_config(timm_config): stackwise_se_ratio = [ [None, None], [0.25, 0.25, 0.25], - [0.3, 0.3], - [0.3, 0.25, 0.25], + [0.25, 0.25], + [0.25, 0.25, 0.25], ] stackwise_activation = [ ["relu", "relu"], @@ -39,64 +39,12 @@ def convert_backbone_config(timm_config): ["hard_swish", "hard_swish"], ["hard_swish", "hard_swish", "hard_swish"], ] + stackwise_padding = [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]] output_num_filters = 1024 input_num_filters = 16 depthwise_filters = 8 squeeze_and_excite = 0.5 last_layer_filter = 288 - - # elif timm_architecture == "mobilenetv2_050": - # stackwise_num_blocks = ([2, 3, 4, 3, 3, 1],) - # stackwise_expansion = ( - # [ - # [48, 96], - # [96, 96, 96], - # [96, 192, 192, 192], - # [192, 288, 288], - # [288, 480, 480], - # [480], - # ], - # ) - # stackwise_num_filters = ( - # [ - # [16, 16], - # [16, 16, 16], - # [32, 32, 32, 32], - # [48, 48, 48], - # [80, 80, 80], - # [160], - # ], - # ) - # stackwise_kernel_size = ( - # [[3, 3], [3, 3, 3], [3, 3, 3, 3], [3, 3, 3], [3, 3, 3], [3]], - # ) - # stackwise_num_strides = ( - # [[2, 1], [2, 1, 1], [2, 1, 1, 1], [1, 1, 1], [2, 1, 1], [1]], - # ) - # 
stackwise_se_ratio = ( - # [ - # [None, None], - # [None, None, None], - # [None, None, None, None], - # [None, None, None], - # [None, None, None], - # [None], - # ], - # ) - # stackwise_activation = ( - # [ - # ["relu6", "relu6"], - # ["relu6", "relu6", "relu6"], - # ["relu6", "relu6", "relu6", "relu6"], - # ["relu6", "relu6", "relu6"], - # ["relu6", "relu6", "relu6"], - # ["relu6"], - # ], - # ) - # output_num_filters = 1280 - # input_num_filters = 16 - # depthwise_filters = 8 - # squeeze_and_excite = None else: raise ValueError( f"Currently, the architecture {timm_architecture} is not supported." @@ -114,6 +62,7 @@ def convert_backbone_config(timm_config): stackwise_num_strides=stackwise_num_strides, stackwise_se_ratio=stackwise_se_ratio, stackwise_activation=stackwise_activation, + stackwise_padding=stackwise_padding, output_num_filters=output_num_filters, output_activation=output_activation, last_layer_filter=last_layer_filter, @@ -122,6 +71,7 @@ def convert_backbone_config(timm_config): def convert_weights(backbone, loader, timm_config): def port_conv2d(keras_layer_name, hf_weight_prefix): + print(f"porting weights {hf_weight_prefix} -> {keras_layer_name}") loader.port_weight( backbone.get_layer(keras_layer_name).kernel, hf_weight_key=f"{hf_weight_prefix}.weight", @@ -129,6 +79,7 @@ def port_conv2d(keras_layer_name, hf_weight_prefix): ) def port_batch_normalization(keras_layer_name, hf_weight_prefix): + print(f"porting weights {hf_weight_prefix} -> {keras_layer_name}") loader.port_weight( backbone.get_layer(keras_layer_name).gamma, hf_weight_key=f"{hf_weight_prefix}.weight", @@ -145,9 +96,11 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): backbone.get_layer(keras_layer_name).moving_variance, hf_weight_key=f"{hf_weight_prefix}.running_var", ) - - version = "v3" if backbone.output_activation == "hard_swish" else "v2" - + loader.port_weight( + backbone.get_layer(keras_layer_name).moving_variance, + hf_weight_key=f"{hf_weight_prefix}.running_var", + ) + # Stem port_conv2d("input_conv", "conv_stem") port_batch_normalization("input_batch_norm", "bn1") @@ -155,6 +108,7 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): # DepthWise Block (block 0) hf_name = "blocks.0.0" keras_name = "block_0_0" + port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_dw") port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1") @@ -196,14 +150,10 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): f"block_{num_stacks+1}_0_bn", f"blocks.{num_stacks+1}.0.bn1" ) - if version == "v3": - hf_name = f"blocks.{num_stacks+1}.0" - keras_name = "Dfs" port_conv2d("output_conv", "conv_head") # if version == "v2": # port_batch_normalization("output_batch_norm", "bn2") - def convert_head(task, loader, timm_config): prefix = "classifier." 
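     # The classification head weights live under timm's `classifier.` prefix.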
loader.port_weight( diff --git a/keras_hub/src/utils/timm/convert_mobilenet_test.py b/keras_hub/src/utils/timm/convert_mobilenet_test.py index 4d036ae033..59c504b306 100644 --- a/keras_hub/src/utils/timm/convert_mobilenet_test.py +++ b/keras_hub/src/utils/timm/convert_mobilenet_test.py @@ -13,7 +13,7 @@ def test_convert_mobilenet_backbone(self): "hf://timm/mobilenetv3_small_050.lamb_in1k" ) outputs = model.predict(ops.ones((1, 224, 224, 3))) - self.assertEqual(outputs.shape, (1, 14, 14, 1024)) + self.assertEqual(outputs.shape, (1, 7, 7, 1024)) @pytest.mark.large def test_convert_mobilenet_classifier(self): diff --git a/keras_hub/src/utils/timm/convert_vgg.py b/keras_hub/src/utils/timm/convert_vgg.py new file mode 100644 index 0000000000..445d0ee436 --- /dev/null +++ b/keras_hub/src/utils/timm/convert_vgg.py @@ -0,0 +1,85 @@ +from typing import Any + +import numpy as np + +from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone +from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier + +backbone_cls = VGGBackbone + + +REPEATS_BY_SIZE = { + "vgg11": [1, 1, 2, 2, 2], + "vgg13": [2, 2, 2, 2, 2], + "vgg16": [2, 2, 3, 3, 3], + "vgg19": [2, 2, 4, 4, 4], +} + + +def convert_backbone_config(timm_config): + architecture = timm_config["architecture"] + stackwise_num_repeats = REPEATS_BY_SIZE[architecture] + return dict( + stackwise_num_repeats=stackwise_num_repeats, + stackwise_num_filters=[64, 128, 256, 512, 512], + ) + + +def convert_conv2d( + model, + loader, + keras_layer_name: str, + hf_layer_name: str, +): + loader.port_weight( + model.get_layer(keras_layer_name).kernel, + hf_weight_key=f"{hf_layer_name}.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + loader.port_weight( + model.get_layer(keras_layer_name).bias, + hf_weight_key=f"{hf_layer_name}.bias", + ) + + +def convert_weights( + backbone: VGGBackbone, + loader, + timm_config: dict[Any], +): + architecture = timm_config["architecture"] + stackwise_num_repeats = REPEATS_BY_SIZE[architecture] + + hf_index_to_keras_layer_name = {} + layer_index = 0 + for block_index, repeats_in_block in enumerate(stackwise_num_repeats): + for repeat_index in range(repeats_in_block): + hf_index = layer_index + layer_index += 2 # Conv + activation layers. + layer_name = f"block{block_index + 1}_conv{repeat_index + 1}" + hf_index_to_keras_layer_name[hf_index] = layer_name + layer_index += 1 # Pooling layer after blocks. 
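+    # At this point `hf_index_to_keras_layer_name` maps timm's flat
+    # `features.{index}` module indices onto the Keras `block{i}_conv{j}`
+    # layer names, so each conv layer can be ported by index below.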
+ + for hf_index, keras_layer_name in hf_index_to_keras_layer_name.items(): + convert_conv2d( + backbone, loader, keras_layer_name, f"features.{hf_index}" + ) + + +def convert_head( + task: VGGImageClassifier, + loader, + timm_config: dict[Any], +): + convert_conv2d(task.head, loader, "fc1", "pre_logits.fc1") + convert_conv2d(task.head, loader, "fc2", "pre_logits.fc2") + + loader.port_weight( + task.head.get_layer("predictions").kernel, + hf_weight_key="head.fc.weight", + hook_fn=lambda x, _: np.transpose(np.squeeze(x)), + ) + loader.port_weight( + task.head.get_layer("predictions").bias, + hf_weight_key="head.fc.bias", + ) diff --git a/keras_hub/src/utils/timm/preset_loader.py b/keras_hub/src/utils/timm/preset_loader.py index 65149b042f..392f432bb1 100644 --- a/keras_hub/src/utils/timm/preset_loader.py +++ b/keras_hub/src/utils/timm/preset_loader.py @@ -6,6 +6,7 @@ from keras_hub.src.utils.timm import convert_densenet from keras_hub.src.utils.timm import convert_mobilenet from keras_hub.src.utils.timm import convert_resnet +from keras_hub.src.utils.timm import convert_vgg from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -19,6 +20,8 @@ def __init__(self, preset, config): self.converter = convert_densenet elif "mobilenet" in architecture: self.converter = convert_mobilenet + elif "vgg" in architecture: + self.converter = convert_vgg else: raise ValueError( "KerasHub has no converter for timm models " diff --git a/keras_hub/src/utils/transformers/convert_llama3.py b/keras_hub/src/utils/transformers/convert_llama3.py index 08e982e862..75c7eb801c 100644 --- a/keras_hub/src/utils/transformers/convert_llama3.py +++ b/keras_hub/src/utils/transformers/convert_llama3.py @@ -107,10 +107,26 @@ def convert_tokenizer(cls, preset, **kwargs): vocab = tokenizer_config["model"]["vocab"] merges = tokenizer_config["model"]["merges"] - bot = tokenizer_config["added_tokens"][0] # begin of text - eot = tokenizer_config["added_tokens"][1] # end of text - - vocab[bot["content"]] = bot["id"] - vocab[eot["content"]] = eot["id"] + # Load all special tokens with the exception of "reserved" ones. + special_tokens = set() + for token in tokenizer_config["added_tokens"]: + if not token["content"].startswith("<|reserved_special_token_"): + vocab[token["content"]] = token["id"] + special_tokens.add(token["content"]) + + # Load text start and stop tokens from the config. + # Llama3 uses the <|end_of_text|> end token for regular models + # but uses <|eot_id|> for instruction-tuned variants. + tokenizer_config2 = load_json(preset, "tokenizer_config.json") + bos_token = tokenizer_config2["bos_token"] + eos_token = tokenizer_config2["eos_token"] + + kwargs.update( + { + "bos_token": bos_token, + "eos_token": eos_token, + "misc_special_tokens": special_tokens, + } + ) return cls(vocabulary=vocab, merges=merges, **kwargs) diff --git a/keras_hub/src/version_utils.py b/keras_hub/src/version_utils.py index 1b36b8e41f..0a67b13192 100644 --- a/keras_hub/src/version_utils.py +++ b/keras_hub/src/version_utils.py @@ -1,7 +1,7 @@ from keras_hub.src.api_export import keras_hub_export # Unique source of truth for the version number. 
-__version__ = "0.16.1" +__version__ = "0.17.0.dev0" @keras_hub_export("keras_hub.version") diff --git a/tools/checkpoint_conversion/convert_mix_transformer.py b/tools/checkpoint_conversion/convert_mix_transformer.py new file mode 100644 index 0000000000..6419cc405e --- /dev/null +++ b/tools/checkpoint_conversion/convert_mix_transformer.py @@ -0,0 +1,196 @@ +# Usage example +# python tools/checkpoint_conversion/convert_mix_transformer.py --preset "B0_ade_512" + +from absl import app +from absl import flags +from transformers import SegformerForSemanticSegmentation + +import keras_hub + +FLAGS = flags.FLAGS + + +DOWNLOAD_URLS = { + "B0_ade_512": "nvidia/segformer-b0-finetuned-ade-512-512", + "B1_ade_512": "nvidia/segformer-b1-finetuned-ade-512-512", + "B2_ade_512": "nvidia/segformer-b2-finetuned-ade-512-512", + "B3_ade_512": "nvidia/segformer-b3-finetuned-ade-512-512", + "B4_ade_512": "nvidia/segformer-b4-finetuned-ade-512-512", + "B5_ade_640": "nvidia/segformer-b5-finetuned-ade-640-640", + "B0_cityscapes_1024": "nvidia/segformer-b0-finetuned-cityscapes-1024-1024", + "B1_cityscapes_1024": "nvidia/segformer-b1-finetuned-cityscapes-1024-1024", + "B2_cityscapes_1024": "nvidia/segformer-b2-finetuned-cityscapes-1024-1024", + "B3_cityscapes_1024": "nvidia/segformer-b3-finetuned-cityscapes-1024-1024", + "B4_cityscapes_1024": "nvidia/segformer-b4-finetuned-cityscapes-1024-1024", + "B5_cityscapes_1024": "nvidia/segformer-b5-finetuned-cityscapes-1024-1024", +} + + +MODEL_CONFIGS = { + "B0": {"hidden_dims": [32, 64, 160, 256], "depths": [2, 2, 2, 2]}, + "B1": {"hidden_dims": [64, 128, 320, 512], "depths": [2, 2, 2, 2]}, + "B2": {"hidden_dims": [64, 128, 320, 512], "depths": [3, 4, 6, 3]}, + "B3": {"hidden_dims": [64, 128, 320, 512], "depths": [3, 4, 18, 3]}, + "B4": {"hidden_dims": [64, 128, 320, 512], "depths": [3, 8, 27, 3]}, + "B5": {"hidden_dims": [64, 128, 320, 512], "depths": [3, 6, 40, 3]}, +} + +flags.DEFINE_string( + "preset", None, f'Must be one of {",".join(DOWNLOAD_URLS.keys())}' +) + + +def get_indices_from_depths(depths): + proj_indices = [] + norm_indices = [] + hierarchical_encoder_indices = [] + + current_layer_idx = 1 + + for layer_idx, depth in enumerate(depths): + # Add projection index (before the hierarchical encoders) + proj_indices.append(current_layer_idx) + + # Hierarchical encoder block indices + for block_idx in range(depth): + hierarchical_encoder_indices.append( + (current_layer_idx + 1, layer_idx, block_idx) + ) + current_layer_idx += 1 + + # Add normalization index (after the hierarchical encoders) + norm_indices.append(current_layer_idx + 1) + + # Skip to the next layer after output_level + current_layer_idx += 3 + + return proj_indices, norm_indices, hierarchical_encoder_indices + + +def set_conv_weights(conv_layer, state_dict): + conv_weights = state_dict["weight"].numpy().transpose(2, 3, 1, 0) + conv_bias = state_dict["bias"].numpy() + conv_layer.set_weights([conv_weights, conv_bias]) + + +def set_dwconv_weights(conv_layer, state_dict): + conv_weights = state_dict["dwconv.weight"].numpy().transpose(2, 3, 0, 1) + conv_bias = state_dict["dwconv.bias"].numpy() + conv_layer.set_weights([conv_weights, conv_bias]) + + +def set_layer_norm_weights(layer_norm, state_dict): + gamma = state_dict["weight"].numpy() + beta = state_dict["bias"].numpy() + layer_norm.set_weights([gamma, beta]) + + +def set_dense_weights(dense_layer, state_dict): + weight = state_dict["weight"].numpy().T + bias = state_dict["bias"].numpy() + dense_layer.set_weights([weight, bias]) + + +def 
set_hierarchical_encoder_weights(keras_layer, pytorch_layer, key): + + set_layer_norm_weights( + keras_layer.norm1, pytorch_layer.layer_norm_1.state_dict() + ) + + set_dense_weights( + keras_layer.attn.q, pytorch_layer.attention.self.query.state_dict() + ) + set_dense_weights( + keras_layer.attn.k, pytorch_layer.attention.self.key.state_dict() + ) + set_dense_weights( + keras_layer.attn.v, pytorch_layer.attention.self.value.state_dict() + ) + set_dense_weights( + keras_layer.attn.proj, pytorch_layer.attention.output.dense.state_dict() + ) + + if keras_layer.attn.sr_ratio > 1: + set_conv_weights( + keras_layer.attn.sr, pytorch_layer.attention.self.sr.state_dict() + ) + set_layer_norm_weights( + keras_layer.attn.norm, + pytorch_layer.attention.self.layer_norm.state_dict(), + ) + + set_layer_norm_weights( + keras_layer.norm2, pytorch_layer.layer_norm_2.state_dict() + ) + + set_dense_weights( + keras_layer.mlp.fc1, pytorch_layer.mlp.dense1.state_dict() + ) + set_dwconv_weights( + keras_layer.mlp.dwconv, pytorch_layer.mlp.dwconv.state_dict() + ) + set_dense_weights( + keras_layer.mlp.fc2, pytorch_layer.mlp.dense2.state_dict() + ) + + +def main(_): + print("\n-> Loading HuggingFace model") + model = SegformerForSemanticSegmentation.from_pretrained( + DOWNLOAD_URLS[FLAGS.preset] + ) + original_mit = original_mit = model.segformer.encoder + + model_type = FLAGS.preset.split("_")[0] + print("\n-> Instantiating KerasHub Model") + keras_mit = keras_hub.models.MiTBackbone( + depths=MODEL_CONFIGS[model_type]["depths"], + image_shape=(224, 224, 3), + hidden_dims=MODEL_CONFIGS[model_type]["hidden_dims"], + num_layers=4, + blockwise_num_heads=[1, 2, 5, 8], + blockwise_sr_ratios=[8, 4, 2, 1], + max_drop_path_rate=0.1, + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + ) + + # Indices for the different patch embeddings and layer norms + proj_indices, layer_norm_indices, hierarchical_encoder_indices = ( + get_indices_from_depths(MODEL_CONFIGS[model_type]["depths"]) + ) + + print("\n-> Converting weights...") + # Loop through the indices to set convolutional and normalization weights + for i, idx in enumerate(proj_indices): + set_conv_weights( + keras_mit.layers[idx].proj, + original_mit.patch_embeddings[i].proj.state_dict(), + ) + set_layer_norm_weights( + keras_mit.layers[idx].norm, + original_mit.patch_embeddings[i].layer_norm.state_dict(), + ) + + # Set layer normalization weights + for i, idx in enumerate(layer_norm_indices): + set_layer_norm_weights( + keras_mit.layers[idx], original_mit.layer_norm[i].state_dict() + ) + + # Set hierarchical encoder weights + for layer_idx, block_idx, key in hierarchical_encoder_indices: + set_hierarchical_encoder_weights( + keras_mit.layers[layer_idx], + original_mit.block[block_idx][int(key)], + key=key, + ) + + directory = f"MiT_{FLAGS.preset}" + print(f"\n-> Saving converted KerasHub model in {directory}") + keras_mit.save_to_preset(directory) + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main) diff --git a/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py b/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py index d21bb8d82d..befb6093cf 100644 --- a/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py +++ b/tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py @@ -1,17 +1,47 @@ +""" +python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-mix-224.npz \ + --image_size=224 --checkpoint_name=pali_gemma_3b_mix_224 +python 
tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-mix-448.npz \ + --image_size=448 --checkpoint_name=pali_gemma_3b_mix_448 +python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-pt-224.npz \ + --image_size=224 --checkpoint_name=pali_gemma_3b_224 +python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-pt-448.npz \ + --image_size=448 --checkpoint_name=pali_gemma_3b_448 +python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \ + --weights_path=paligemma-3b-pt-896.npz \ + --image_size=896 --checkpoint_name=pali_gemma_3b_896 +""" + import argparse import os import numpy as np +from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( + PaliGemmaBackbone, +) +from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import ( + PaliGemmaCausalLM, +) +from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( + PaliGemmaCausalLMPreprocessor, +) +from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( + PaliGemmaImageConverter, +) +from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( + PaliGemmaTokenizer, +) + os.environ["KERAS_BACKEND"] = "jax" import keras # noqa: E402 from keras import ops # noqa: E402 -from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( # noqa: E402 - PaliGemmaBackbone, -) - # No GPU for conversion, makes memory management easier. os.environ["CUDA_VISIBLE_DEVICES"] = "-1" @@ -299,14 +329,39 @@ def main(args): pali_gemma_backbone_config = { "vit_num_layers": 27, "vit_hidden_dim": 1152, + "vocabulary_size": 257152, "image_size": args.image_size, + "num_layers": 18, + "num_query_heads": 8, + "num_key_value_heads": 1, + "hidden_dim": 2048, + "intermediate_dim": 32768, + "head_dim": 256, + "vit_patch_size": 14, + "vit_num_heads": 16, } - keras_model = PaliGemmaBackbone(**pali_gemma_backbone_config) + pg_image_converter = PaliGemmaImageConverter( + image_size=(args.image_size, args.image_size), + scale=1.0 / 127.5, + offset=-1, + ) + tokenizer = PaliGemmaTokenizer( + proto="vocabulary.spm", + ) + pg_presprocessor = PaliGemmaCausalLMPreprocessor( + tokenizer=tokenizer, image_converter=pg_image_converter + ) + pg_backbone = PaliGemmaBackbone(**pali_gemma_backbone_config) + keras_model = PaliGemmaCausalLM( + preprocessor=pg_presprocessor, backbone=pg_backbone + ) # This could be from kaggle or provide local dir path weights = np.load(args.weights_path) jax_weights = get_weights_as_numpy(weights, **pali_gemma_backbone_config) - keras_model = convert_pali_gemma_weights( - keras_model, jax_weights["params"], **pali_gemma_backbone_config + keras_model.backbone = convert_pali_gemma_weights( + keras_model.backbone, + jax_weights["params"], + **pali_gemma_backbone_config, ) # Specify preset name keras_model.save_to_preset(args.checkpoint_name) diff --git a/tools/checkpoint_conversion/convert_resnet_vd_checkpoints.py b/tools/checkpoint_conversion/convert_resnet_vd_checkpoints.py new file mode 100644 index 0000000000..a8a424f494 --- /dev/null +++ b/tools/checkpoint_conversion/convert_resnet_vd_checkpoints.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 + +"""Converts ResNet_vd models from PaddleClas. + +Usage: python3 convert_resnet_vd_checkpoints.py + +ResNet_vd model weights from PaddleClas listed in `configurations` below will +be downloaded, saved as Keras model files and the resulting models will be +verified for numerical agreement with PaddleClas. 
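+Each converted model is saved both as a plain `.keras` file and as a KerasHub
+preset directory via `save_to_preset`.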
+ +Requirements: +pip3 install -q git+/~https://github.com/keras-team/keras-hub.git +pip3 install -q paddleclas paddlepaddle +""" + +import os +import re +import tarfile +import urllib.request + +import keras +import numpy as np +import paddle +import paddleclas +from paddleclas.deploy.python import preprocess as pc_preproc +from PIL import Image + +import keras_hub + +"""Architecture Specifications""" + +configurations = { + "ResNet18_vd": { + "stackwise_num_blocks": [2, 2, 2, 2], + "block_type": "basic_block_vd", + }, + "ResNet34_vd": { + "stackwise_num_blocks": [3, 4, 6, 3], + "block_type": "basic_block_vd", + }, + "ResNet50_vd": { + "stackwise_num_blocks": [3, 4, 6, 3], + "block_type": "bottleneck_block_vd", + }, + "ResNet50_vd_ssld": { + "stackwise_num_blocks": [3, 4, 6, 3], + "block_type": "bottleneck_block_vd", + }, + "ResNet50_vd_ssld_v2": { + "stackwise_num_blocks": [3, 4, 6, 3], + "block_type": "bottleneck_block_vd", + }, + "Fix_ResNet50_vd_ssld_v2": { + "stackwise_num_blocks": [3, 4, 6, 3], + "block_type": "bottleneck_block_vd", + }, + "ResNet101_vd": { + "stackwise_num_blocks": [3, 4, 23, 3], + "block_type": "bottleneck_block_vd", + }, + "ResNet101_vd_ssld": { + "stackwise_num_blocks": [3, 4, 23, 3], + "block_type": "bottleneck_block_vd", + }, + "ResNet152_vd": { + "stackwise_num_blocks": [3, 8, 36, 3], + "block_type": "bottleneck_block_vd", + }, + "ResNet200_vd": { + "stackwise_num_blocks": [3, 12, 48, 3], + "block_type": "bottleneck_block_vd", + }, +} + + +"""Download Files""" + +# Create the directory if it doesn't exist +os.makedirs("pretrained_models", exist_ok=True) +base_url = "https://paddle-imagenet-models-name.bj.bcebos.com/" + +for arch in configurations.keys(): + tar_file = f"{arch}_pretrained.tar" + download_url = f"{base_url}{tar_file}" + file_path = os.path.join("pretrained_models", tar_file) + + # Download the tar file + print(f"Downloading {tar_file}...") + urllib.request.urlretrieve(download_url, file_path) + + # Extract the tar file + print(f"Extracting {tar_file}...") + with tarfile.open(file_path, "r") as tar: + tar.extractall(path="pretrained_models", filter="data") + + +"""Model Conversion""" + + +def convert_paddle_to_keras(paddle_weights: dict, keras_model: keras.Model): + """Ports a paddle weights dictionary to a Keras model.""" + + def map_residual_layer_name(name: str): + """Translate a Keras ResNet_vd layer name to a PaddleClas ResNet + layer name prefix for a residual block.""" + branch_mapping = { + # this suffix addresses the specific conv layer within a block + 0: "1", + 1: "2a", + 2: "2b", + 3: "2c", + } + match = re.match( + r"^stack(?P\d)_block(?P\d+)_(?P\d)_(?Pbn|conv)", + name, + ) + assert match is not None + + # ResNet models have two different formats of layer name encodings + # in PaddleClas. 
first try a mapping in the form + # stack2_block3_1_conv -> res4b2_branch2a + paddle_address = ( + f'{int(match["stack"])+2}b{int(match["block"])}' + f'_branch{branch_mapping[int(match["conv"])]}' + ) + if match["type"] == "bn": + paddle_name = f"bn{paddle_address}" + elif match["type"] == "conv": + paddle_name = f"res{paddle_address}" + if any(name.startswith(paddle_name) for name in paddle_weights): + return paddle_name + + # if that was not successful, try a mapping like + # stack2_block3_1_conv -> res4c_branch2a + paddle_address = ( + f'{int(match["stack"])+2}{"abcdefghijkl"[int(match["block"])]}' + f'_branch{branch_mapping[int(match["conv"])]}' + ) + if match["type"] == "bn": + paddle_name = f"bn{paddle_address}" + elif match["type"] == "conv": + paddle_name = f"res{paddle_address}" + return paddle_name + + def map_layer_name(name: str): + """Translate a Keras ResNet_vd layer name to a PaddleClas ResNet layer + name prefix.""" + mapping = { + # stem layers + "conv1_conv": "conv1_1", + "conv1_bn": "bnv1_1", + "conv2_conv": "conv1_2", + "conv2_bn": "bnv1_2", + "conv3_conv": "conv1_3", + "conv3_bn": "bnv1_3", + } + return mapping.get(name) or map_residual_layer_name(name) + + def set_batchnorm_layer( + paddle_name_prefix: str, target_layer: keras.layers.Layer + ): + """Assign Keras BatchNorm layer weigths from Paddle weights.""" + target_layer.set_weights( + [ + paddle_weights.pop(f"{paddle_name_prefix}_scale"), + paddle_weights.pop(f"{paddle_name_prefix}_offset"), + paddle_weights.pop(f"{paddle_name_prefix}_mean"), + paddle_weights.pop(f"{paddle_name_prefix}_variance"), + ] + ) + + def set_conv_layer( + paddle_name_prefix: str, target_layer: keras.layers.Layer + ): + """Assign Keras Conv2D layer weights from Paddle weights.""" + if target_layer.use_bias: + target_layer.set_weights( + [ + np.transpose( + paddle_weights.pop(f"{paddle_name_prefix}_weights"), + (2, 3, 1, 0), + ), + paddle_weights.pop(f"{paddle_name_prefix}_bias"), + ] + ) + else: + target_layer.set_weights( + [ + np.transpose( + paddle_weights.pop(f"{paddle_name_prefix}_weights"), + (2, 3, 1, 0), + ) + ] + ) + + def set_dense_layer( + paddle_name_prefix: str, target_layer: keras.layers.Layer + ): + """Assign Keras Dense layer weights from Paddle weights.""" + if target_layer.use_bias: + target_layer.set_weights( + [ + paddle_weights.pop(f"{paddle_name_prefix}.w_0"), + paddle_weights.pop(f"{paddle_name_prefix}.b_0"), + ] + ) + else: + target_layer.set_weights( + [paddle_weights.pop(f"{paddle_name_prefix}.w_0")] + ) + + for layer in keras_model.backbone.layers: + # iterate over all layers that have parameters in the keras model, + # to ensure we process all weights in the Keras model + if layer.variables: + if isinstance(layer, keras.layers.Conv2D): + set_conv_layer(map_layer_name(layer.name), layer) + elif isinstance(layer, keras.layers.BatchNormalization): + set_batchnorm_layer(map_layer_name(layer.name), layer) + else: + raise TypeError("Unexpected layer type encountered in model") + set_dense_layer("fc_0", keras_model.get_layer("predictions")) + + # ensure we have consumed all weights, i.e. 
there are no leftover + # weights in the paddle model + assert len(paddle_weights) == 0 + + +"""Instantiate model architectures as indicated above and load PaddleClas +weights into the Keras model""" + +for architecture_name, architecture_config in configurations.items(): + print(f"Converting {architecture_name}") + backbone_model = keras_hub.models.ResNetBackbone( + input_conv_filters=[32, 32, 64], + input_conv_kernel_sizes=[3, 3, 3], + stackwise_num_filters=[64, 128, 256, 512], + stackwise_num_strides=[1, 2, 2, 2], + **architecture_config, + ) + image_converter = keras_hub.layers.ResNetImageConverter( + height=224, + width=224, + mean=[0.485, 0.456, 0.406], + variance=[0.229**2, 0.224**2, 0.225**2], + scale=1 / 255.0, + ) + resnet_preprocessor = keras_hub.models.ResNetImageClassifierPreprocessor( + image_converter + ) + classifier_model = keras_hub.models.ResNetImageClassifier( + backbone=backbone_model, + preprocessor=resnet_preprocessor, + num_classes=1000, + ) + paddle_model = paddle.load( + f"pretrained_models/{architecture_name}_pretrained" + ) + convert_paddle_to_keras(paddle_model, classifier_model) + classifier_model.save(f"{architecture_name}.keras") + classifier_model.save_to_preset(f"{architecture_name}") + print(f"Parameter count: {classifier_model.count_params()}") + +"""Check for Numerical Agreement + +Compare results when using PaddleClas with results when using our Keras models. +In general, PaddleClas appears to mainly target command-line utilisation +rather than offering an API. While PaddleClas model architectures can directly +be instantiated, this interface strangely only provides some of the pretrained +models (and doesn't appear to be documented anywhere). + +To ensure behaviour and performances when using PaddleClas as command-line tool +match our observed results, we here use `PaddleClas` directly. +""" + +urllib.request.urlretrieve( + "https://storage.googleapis.com/tensorflow/keras-applications/tests/elephant.jpg", + "elephant.jpg", +) + +print(f'{"Model": <25}Error') +for architecture_name in configurations: + # PaddleClas prediction + predictor = paddleclas.PaddleClas(model_name=architecture_name).predictor + # PaddleClas selects the top 5 predictions during + # postprocessing. turn this off. 
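+    # With postprocessing disabled, predict() returns the full softmax
+    # probability vector, which is compared element-wise with the Keras
+    # output further below.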
+ predictor.postprocess = None + # for comparable results, manually perform resizing and cropping + preprocess_ops = [ + op + for op in predictor.preprocess_ops + if isinstance( + op, + ( + pc_preproc.NormalizeImage, + pc_preproc.ResizeImage, + pc_preproc.CropImage, + ), + ) + ] + predictor.preprocess_ops = [ + op for op in predictor.preprocess_ops if op not in preprocess_ops + ] + image = np.asarray(Image.open("elephant.jpg"), dtype=np.float32) + for op in preprocess_ops: + image = op(image) + paddle_prediction = predictor.predict(image) + + # Keras prediction + # in contrast to PaddleClas, Keras' predictions are not softmax'ed + keras_model = keras.saving.load_model(f"{architecture_name}.keras") + keras_prediction = keras_model(image[None]).numpy() + keras_prediction = keras.ops.softmax(keras_prediction) + + # compare + max_error = np.max(np.abs(paddle_prediction - keras_prediction)) + print(f"{architecture_name: <25}{max_error}") diff --git a/tools/checkpoint_conversion/convert_sam_checkpoints.py b/tools/checkpoint_conversion/convert_sam_checkpoints.py index 08f4f4a504..69cd1482cb 100644 --- a/tools/checkpoint_conversion/convert_sam_checkpoints.py +++ b/tools/checkpoint_conversion/convert_sam_checkpoints.py @@ -1,3 +1,7 @@ +# Get the huge PyTorch model weights from the following location +# curl -sSL https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -o sam_vit_h_4b8939.pth +# curl -sSL https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth -o sam_vit_l_0b3195.pth +# curl -sSL https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -o sam_vit_b_01ec64.pth import argparse import os diff --git a/tools/checkpoint_conversion/convert_segformer_checkpoints.py b/tools/checkpoint_conversion/convert_segformer_checkpoints.py new file mode 100644 index 0000000000..230cf5227d --- /dev/null +++ b/tools/checkpoint_conversion/convert_segformer_checkpoints.py @@ -0,0 +1,143 @@ +# Usage example +# python tools/checkpoint_conversion/convert_mix_transformer.py --preset "B0_ade_512" + +import numpy as np +from absl import app +from absl import flags +from transformers import SegformerForSemanticSegmentation + +import keras_hub +from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( + SegFormerImageSegmenterPreprocessor, +) + +FLAGS = flags.FLAGS + +PROJECTION_FILTERS = { + "b0_ade20k_512": 256, + "b1_ade20k_512": 256, + "b2_ade20k_512": 768, + "b3_ade20k_512": 768, + "b4_ade20k_512": 768, + "b5_ade20k_640": 768, + "b0_cityscapes_1024": 256, + "b1_cityscapes_1024": 256, + "b2_cityscapes_1024": 768, + "b3_cityscapes_1024": 768, + "b4_cityscapes_1024": 768, + "b5_cityscapes_1024": 768, +} + + +DOWNLOAD_URLS = { + "b0_ade20k_512": "nvidia/segformer-b0-finetuned-ade-512-512", + "b1_ade20k_512": "nvidia/segformer-b1-finetuned-ade-512-512", + "b2_ade20k_512": "nvidia/segformer-b2-finetuned-ade-512-512", + "b3_ade20k_512": "nvidia/segformer-b3-finetuned-ade-512-512", + "b4_ade20k_512": "nvidia/segformer-b4-finetuned-ade-512-512", + "b5_ade20k_640": "nvidia/segformer-b5-finetuned-ade-640-640", + "b0_cityscapes_1024": "nvidia/segformer-b0-finetuned-cityscapes-1024-1024", + "b1_cityscapes_1024": "nvidia/segformer-b1-finetuned-cityscapes-1024-1024", + "b2_cityscapes_1024": "nvidia/segformer-b2-finetuned-cityscapes-1024-1024", + "b3_cityscapes_1024": "nvidia/segformer-b3-finetuned-cityscapes-1024-1024", + "b4_cityscapes_1024": "nvidia/segformer-b4-finetuned-cityscapes-1024-1024", + "b5_cityscapes_1024": 
"nvidia/segformer-b5-finetuned-cityscapes-1024-1024", +} + +flags.DEFINE_string( + "preset", None, f'Must be one of {",".join(DOWNLOAD_URLS.keys())}' +) + + +def set_conv_weights(conv_layer, state_dict): + conv_weights = state_dict["weight"].numpy().transpose(2, 3, 1, 0) + bias = None + if "bias" in state_dict.keys(): + bias = state_dict["bias"].numpy() + conv_layer.set_weights([conv_weights, bias]) + else: + conv_layer.set_weights([conv_weights]) + + +def set_dense_weights(dense_layer, state_dict): + weight = state_dict["weight"].numpy().T + bias = state_dict["bias"].numpy() + dense_layer.set_weights([weight, bias]) + + +def set_batchnorm_weights(bn_layer, state_dict): + gamma = state_dict["weight"].numpy() + beta = state_dict["bias"].numpy() + running_mean = state_dict["running_mean"].numpy() + running_var = state_dict["running_var"].numpy() + + bn_layer.set_weights([gamma, beta, running_mean, running_var]) + + +def main(_): + print("\n-> Loading HuggingFace model") + original_segformer = SegformerForSemanticSegmentation.from_pretrained( + DOWNLOAD_URLS[FLAGS.preset] + ) + + print("\n-> Instantiating KerasHub Model") + + resolution = int(FLAGS.preset.split("_")[-1]) + + encoder = keras_hub.models.MiTBackbone.from_preset( + "mit_" + FLAGS.preset, image_shape=(resolution, resolution, 3) + ) + segformer_backbone = keras_hub.models.SegFormerBackbone( + image_encoder=encoder, + projection_filters=PROJECTION_FILTERS[FLAGS.preset], + ) + num_classes = 150 if "ade20k" in FLAGS.preset else 19 + + preprocessor = SegFormerImageSegmenterPreprocessor() + segformer_segmenter = keras_hub.models.SegFormerImageSegmenter( + backbone=segformer_backbone, + num_classes=num_classes, + preprocessor=preprocessor, + ) + segformer_backbone(np.random.rand(1, resolution, resolution, 3)) + + set_dense_weights( + segformer_backbone.layers[5], + original_segformer.decode_head.linear_c[0].proj.state_dict(), + ) + set_dense_weights( + segformer_backbone.layers[4], + original_segformer.decode_head.linear_c[1].proj.state_dict(), + ) + set_dense_weights( + segformer_backbone.layers[3], + original_segformer.decode_head.linear_c[2].proj.state_dict(), + ) + set_dense_weights( + segformer_backbone.layers[2], + original_segformer.decode_head.linear_c[3].proj.state_dict(), + ) + set_conv_weights( + segformer_backbone.layers[-1].layers[0], + original_segformer.decode_head.linear_fuse.state_dict(), + ) + set_batchnorm_weights( + segformer_backbone.layers[-1].layers[1], + original_segformer.decode_head.batch_norm.state_dict(), + ) + + set_conv_weights( + segformer_segmenter.layers[-2], + original_segformer.decode_head.classifier.state_dict(), + ) + + print("\n-> Converting weights...") + + directory = f"SegFormer_{FLAGS.preset}" + print(f"\n-> Saving converted KerasHub model in {directory}") + segformer_segmenter.save_to_preset(directory) + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main) diff --git a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py index 15b9691532..38e19cf107 100644 --- a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py +++ b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py @@ -113,8 +113,7 @@ def convert_model(preset, height, width): vae, clip_l, clip_g, - height=height, - width=width, + image_shape=(height, width, 3), name="stable_diffusion_3_backbone", ) return backbone @@ -130,23 +129,23 @@ def convert_preprocessor(): vocabulary, merges, 
pad_with_end_token=True, - config_name="clip_l_tokenizer.json", + config_file="clip_l_tokenizer.json", name="clip_l_tokenizer", ) clip_g_tokenizer = CLIPTokenizer( vocabulary, merges, - config_name="clip_g_tokenizer.json", + config_file="clip_g_tokenizer.json", name="clip_g_tokenizer", ) clip_l_preprocessor = CLIPPreprocessor( clip_l_tokenizer, - config_name="clip_l_preprocessor.json", + config_file="clip_l_preprocessor.json", name="clip_l_preprocessor", ) clip_g_preprocessor = CLIPPreprocessor( clip_g_tokenizer, - config_name="clip_g_preprocessor.json", + config_file="clip_g_preprocessor.json", name="clip_g_preprocessor", ) preprocessor = StableDiffusion3TextToImagePreprocessor( @@ -310,19 +309,19 @@ def port_diffuser(preset, filename, model): ) port_dense(loader, model.context_embedding, "context_embedder") port_dense( - loader, model.vector_embedding.layers[0], "y_embedder.mlp.0" + loader, model.vector_embedding.dense1, "y_embedder.mlp.0" ) port_dense( - loader, model.vector_embedding.layers[1], "y_embedder.mlp.2" + loader, model.vector_embedding.dense2, "y_embedder.mlp.2" ) port_dense( loader, - model.timestep_embedding.mlp.layers[0], + model.timestep_embedding.mlp.dense1, "t_embedder.mlp.0", ) port_dense( loader, - model.timestep_embedding.mlp.layers[1], + model.timestep_embedding.mlp.dense2, "t_embedder.mlp.2", ) @@ -338,7 +337,7 @@ def port_diffuser(preset, filename, model): prefix = f"joint_blocks.{i}.{block_name}" port_dense( loader, - block.adaptive_norm_modulation.layers[1], + block.ada_layer_norm.dense, f"{prefix}.adaLN_modulation.1", ) port_dense( @@ -351,18 +350,16 @@ def port_diffuser(preset, filename, model): port_dense( loader, block.attention_proj, f"{prefix}.attn.proj" ) - port_dense(loader, block.mlp.layers[0], f"{prefix}.mlp.fc1") - port_dense(loader, block.mlp.layers[1], f"{prefix}.mlp.fc2") + port_dense(loader, block.mlp.dense1, f"{prefix}.mlp.fc1") + port_dense(loader, block.mlp.dense2, f"{prefix}.mlp.fc2") # Output layer port_dense( loader, - model.output_layer.adaptive_norm_modulation.layers[1], + model.output_ada_layer_norm.dense, "final_layer.adaLN_modulation.1", ) - port_dense( - loader, model.output_layer.output_dense, "final_layer.linear" - ) + port_dense(loader, model.output_dense, "final_layer.linear") return model def port_vae(preset, filename, model): @@ -534,8 +531,7 @@ def main(_): keras_preprocessor.save_to_preset(preset) # Set the image size to 1024, the same as in huggingface/diffusers. - keras_model.height = 1024 - keras_model.width = 1024 + keras_model.image_shape = (1024, 1024, 3) keras_model.save_to_preset(preset) print(f"🏁 Preset saved to ./{preset}.") diff --git a/tools/checkpoint_conversion/convert_vgg_checkpoints.py b/tools/checkpoint_conversion/convert_vgg_checkpoints.py new file mode 100644 index 0000000000..fea9aaf01f --- /dev/null +++ b/tools/checkpoint_conversion/convert_vgg_checkpoints.py @@ -0,0 +1,116 @@ +"""Loads an external VGG model and saves it in Keras format. + +Optionally uploads the model to Keras if the `--upload_uri` flag is passed. + +python tools/checkpoint_conversion/convert_vgg_checkpoints.py \ + --preset vgg11 --upload_uri kaggle://kerashub/vgg/keras/vgg11 +""" + +import os +import shutil + +import keras +import numpy as np +import PIL +import timm +import torch +from absl import app +from absl import flags + +import keras_hub + +PRESET_MAP = { + "vgg11": "timm/vgg11.tv_in1k", + "vgg13": "timm/vgg13.tv_in1k", + "vgg16": "timm/vgg16.tv_in1k", + "vgg19": "timm/vgg19.tv_in1k", + # TODO(jeffcarp): Add BN variants. 
+} + + +PRESET = flags.DEFINE_string( + "preset", + None, + "Must be a valid `VGG` preset from KerasHub", + required=True, +) +UPLOAD_URI = flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/{variant}/keras/{preset}_int8"', +) + + +def validate_output(keras_model, timm_model): + file = keras.utils.get_file( + origin=( + "https://storage.googleapis.com/keras-cv/" + "models/paligemma/cow_beach_1.png" + ) + ) + image = PIL.Image.open(file) + batch = np.array([image]) + + # Preprocess with Timm. + data_config = timm.data.resolve_model_data_config(timm_model) + data_config["crop_pct"] = 1.0 # Stop timm from cropping. + transforms = timm.data.create_transform(**data_config, is_training=False) + timm_preprocessed = transforms(image) + timm_preprocessed = keras.ops.transpose(timm_preprocessed, axes=(1, 2, 0)) + timm_preprocessed = keras.ops.expand_dims(timm_preprocessed, 0) + + # Preprocess with Keras. + keras_preprocessed = keras_model.preprocessor(batch) + + # Call with Timm. Use the keras preprocessed image so we can keep modeling + # and preprocessing comparisons independent. + timm_batch = keras.ops.transpose(keras_preprocessed, axes=(0, 3, 1, 2)) + timm_batch = torch.from_numpy(np.array(timm_batch)) + timm_outputs = timm_model(timm_batch).detach().numpy() + timm_label = np.argmax(timm_outputs[0]) + + # Call with Keras. + keras_outputs = keras_model.predict(batch) + keras_label = np.argmax(keras_outputs[0]) + + print("🔶 Keras output:", keras_outputs[0, :10]) + print("🔶 TIMM output:", timm_outputs[0, :10]) + print("🔶 Keras label:", keras_label) + print("🔶 TIMM label:", timm_label) + modeling_diff = np.mean(np.abs(keras_outputs - timm_outputs)) + print("🔶 Modeling difference:", modeling_diff) + preprocessing_diff = np.mean(np.abs(keras_preprocessed - timm_preprocessed)) + print("🔶 Preprocessing difference:", preprocessing_diff) + + +def main(_): + preset = PRESET.value + if os.path.exists(preset): + shutil.rmtree(preset) + os.makedirs(preset) + + timm_name = PRESET_MAP[preset] + + timm_model = timm.create_model(timm_name, pretrained=True) + timm_model = timm_model.eval() + print("✅ Loaded TIMM model.") + print(timm_model) + + keras_model = keras_hub.models.ImageClassifier.from_preset( + "hf://" + timm_name, + ) + print("✅ Loaded KerasHub model.") + + keras_model.save_to_preset(f"./{preset}") + print(f"🏁 Preset saved to ./{preset}") + + validate_output(keras_model, timm_model) + + upload_uri = UPLOAD_URI.value + if upload_uri: + keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}") + print(f"🏁 Preset uploaded to {upload_uri}") + + +if __name__ == "__main__": + app.run(main) From 2d5568d7aa31439a971836de6fd84717e9c90e4a Mon Sep 17 00:00:00 2001 From: ushareng Date: Fri, 25 Oct 2024 00:32:31 +0530 Subject: [PATCH 13/21] rebase done --- keras_hub/api/layers/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 53e0074414..2060fb0da3 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -40,9 +40,7 @@ from keras_hub.src.models.densenet.densenet_image_converter import ( DenseNetImageConverter, ) -from keras_hub.src.models.mix_transformer.mix_transformer_image_converter import ( - MiTImageConverter, -) +from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter from keras_hub.src.models.mobilenet.mobilenet_image_converter import ( MobileNetImageConverter, ) From 9f4c7a3b69c00ec57e60a2640743f93f1da0d8ed Mon Sep 17 00:00:00 
2001 From: ushareng Date: Fri, 25 Oct 2024 01:07:29 +0530 Subject: [PATCH 14/21] code formatting --- keras_hub/src/utils/timm/convert_mobilenet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/utils/timm/convert_mobilenet.py b/keras_hub/src/utils/timm/convert_mobilenet.py index e2de6d8e34..307f4a4acc 100644 --- a/keras_hub/src/utils/timm/convert_mobilenet.py +++ b/keras_hub/src/utils/timm/convert_mobilenet.py @@ -31,7 +31,7 @@ def convert_backbone_config(timm_config): [None, None], [0.25, 0.25, 0.25], [0.25, 0.25], - [0.25, 0.25, 0.25], + [0.25, 0.25, 0.25], ] stackwise_activation = [ ["relu", "relu"], @@ -100,7 +100,7 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): backbone.get_layer(keras_layer_name).moving_variance, hf_weight_key=f"{hf_weight_prefix}.running_var", ) - + # Stem port_conv2d("input_conv", "conv_stem") port_batch_normalization("input_batch_norm", "bn1") @@ -154,6 +154,7 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): # if version == "v2": # port_batch_normalization("output_batch_norm", "bn2") + def convert_head(task, loader, timm_config): prefix = "classifier." loader.port_weight( From eb90095cabf763f7418b9bf071301e3d0991dae3 Mon Sep 17 00:00:00 2001 From: ushareng Date: Sat, 26 Oct 2024 12:45:39 +0530 Subject: [PATCH 15/21] preset path updated --- keras_hub/src/models/mobilenet/mobilenet_presets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/mobilenet/mobilenet_presets.py b/keras_hub/src/models/mobilenet/mobilenet_presets.py index e18364676f..172e7fdbf6 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_presets.py +++ b/keras_hub/src/models/mobilenet/mobilenet_presets.py @@ -10,6 +10,6 @@ "official_name": "MobileNet", "path": "mobilenet3", }, - "kaggle_handle": "kaggle://alexbutcher/mobilenet3/keras/mobilenetv3_small_050", + "kaggle_handle": "kaggle://keras/mobilenet/keras/mobilenetv3_small_050", }, } From 7cbfebc15208fe1b5f1755d99f27c14706756562 Mon Sep 17 00:00:00 2001 From: Piseth Ky Date: Wed, 15 Jan 2025 17:05:09 -0800 Subject: [PATCH 16/21] WIP mobilenet fixes, subblock refactoring --- .../models/mobilenet/depthwise_conv_block.py | 304 ++++++++++++++++++ .../models/mobilenet/mobilenet_backbone.py | 169 +--------- .../models/mobilenet/squeeze_and_excite_2d.py | 101 ++++++ keras_hub/src/utils/timm/convert_mobilenet.py | 16 +- .../convert_mobilenet_checkpoints.py | 1 + 5 files changed, 427 insertions(+), 164 deletions(-) create mode 100644 keras_hub/src/models/mobilenet/depthwise_conv_block.py create mode 100644 keras_hub/src/models/mobilenet/squeeze_and_excite_2d.py diff --git a/keras_hub/src/models/mobilenet/depthwise_conv_block.py b/keras_hub/src/models/mobilenet/depthwise_conv_block.py new file mode 100644 index 0000000000..9345c090c3 --- /dev/null +++ b/keras_hub/src/models/mobilenet/depthwise_conv_block.py @@ -0,0 +1,304 @@ +import keras + +BN_AXIS = 3 + + +class DepthwiseConvBlock(keras.layers.Layer): + """ + A depthwise convolution block consists of a depthwise conv, + batch normalization, relu6, pointwise convolution, + batch normalization and relu6 activation. + + Args: + x: Input tensor of shape `(rows, cols, channels) + filters: Integer, the dimensionality of the output space + (i.e. the number of output filters in the pointwise convolution). + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the width and height. 
+ Can be a single integer to specify the same value for + all spatial dimensions. Specifying any stride value != 1 is + incompatible with specifying any `dilation_rate` value != 1. + block_id: Integer, a unique identification designating the block number. + + Input shape: + 4D tensor with shape: `(batch, rows, cols, channels)` in "channels_last" + 4D tensor with shape: `(batch, channels, rows, cols)` in "channels_first" + Returns: + Output tensor of block. + """ + + def __init__( + self, + input_filters, + output_filters, + expand_ratio=1, + kernel_size=3, + strides=1, + data_format="channels_last", + se_ratio=0.0, + batch_norm_momentum=0.9, + batch_norm_epsilon=1e-3, + activation="swish", + projection_activation=None, + dropout=0.2, + nores=False, + projection_kernel_size=1, + **kwargs, + ): + super().__init__(**kwargs) + self.input_filters = input_filters + self.output_filters = output_filters + self.expand_ratio = expand_ratio + self.kernel_size = kernel_size + self.strides = strides + self.data_format = data_format + self.se_ratio = se_ratio + self.batch_norm_momentum = batch_norm_momentum + self.batch_norm_epsilon = batch_norm_epsilon + self.activation = activation + self.projection_activation = projection_activation + self.dropout = dropout + self.nores = nores + self.projection_kernel_size = projection_kernel_size + self.filters = self.input_filters * self.expand_ratio + self.filters_se = max(1, int(input_filters * se_ratio)) + + padding_pixels = kernel_size // 2 + self.conv1_pad = keras.layers.ZeroPadding2D( + padding=(padding_pixels, padding_pixels), + name=self.name + "expand_conv_pad", + ) + self.conv1 = keras.layers.Conv2D( + filters=self.filters, + kernel_size=kernel_size, + strides=strides, + kernel_initializer=self._conv_kernel_initializer(), + padding="valid", + data_format=data_format, + use_bias=False, + name=self.name + "expand_conv", + ) + self.bn1 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + epsilon=self.batch_norm_epsilon, + name=self.name + "expand_bn", + ) + self.act = keras.layers.Activation( + self.activation, name=self.name + "expand_activation" + ) + + self.se_conv1 = keras.layers.Conv2D( + self.filters_se, + 1, + padding="same", + data_format=data_format, + activation=self.activation, + kernel_initializer=self._conv_kernel_initializer(), + name=self.name + "se_reduce", + ) + + self.se_conv2 = keras.layers.Conv2D( + self.filters, + 1, + padding="same", + data_format=data_format, + activation="sigmoid", + kernel_initializer=self._conv_kernel_initializer(), + name=self.name + "se_expand", + ) + + padding_pixels = projection_kernel_size // 2 + self.output_conv_pad = keras.layers.ZeroPadding2D( + padding=(padding_pixels, padding_pixels), + name=self.name + "project_conv_pad", + ) + self.output_conv = keras.layers.Conv2D( + filters=self.output_filters, + kernel_size=projection_kernel_size, + strides=1, + kernel_initializer=self._conv_kernel_initializer(), + padding="valid", + data_format=data_format, + use_bias=False, + name=self.name + "project_conv", + ) + + self.bn2 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + epsilon=self.batch_norm_epsilon, + name=self.name + "project_bn", + ) + + if self.projection_activation: + self.projection_act = keras.layers.Activation( + self.projection_activation, name=self.name + "projection_act" + ) + + if self.dropout: + self.dropout_layer = keras.layers.Dropout( + self.dropout, + noise_shape=(None, 1, 1, 1), + name=self.name + "drop", + ) + + 
def _conv_kernel_initializer( + self, + scale=2.0, + mode="fan_out", + distribution="truncated_normal", + seed=None, + ): + return keras.initializers.VarianceScaling( + scale=scale, mode=mode, distribution=distribution, seed=seed + ) + + def build(self, input_shape): + if self.name is None: + self.name = keras.backend.get_uid("block0") + + def call(self, inputs): + # Expansion phase + x = self.conv1_pad(inputs) + x = self.conv1(x) + x = self.bn1(x) + x = self.act(x) + + # Squeeze and excite + if 0 < self.se_ratio <= 1: + se = keras.layers.GlobalAveragePooling2D( + name=self.name + "se_squeeze", + data_format=self.data_format, + )(x) + if BN_AXIS == 1: + se_shape = (self.filters, 1, 1) + else: + se_shape = (1, 1, self.filters) + + se = keras.layers.Reshape(se_shape, name=self.name + "se_reshape")( + se + ) + + se = self.se_conv1(se) + se = self.se_conv2(se) + + x = keras.layers.multiply([x, se], name=self.name + "se_excite") + + # Output phase: + x = self.output_conv_pad(x) + x = self.output_conv(x) + x = self.bn2(x) + if self.expand_ratio == 1 and self.projection_activation: + x = self.projection_act(x) + + # Residual: + if ( + self.strides == 1 + and self.input_filters == self.output_filters + and not self.nores + ): + if self.dropout: + x = self.dropout_layer(x) + x = keras.layers.Add(name=self.name + "add")([x, inputs]) + return x + + def get_config(self): + config = { + "input_filters": self.input_filters, + "output_filters": self.output_filters, + "expand_ratio": self.expand_ratio, + "kernel_size": self.kernel_size, + "strides": self.strides, + "data_format": self.data_format, + "se_ratio": self.se_ratio, + "batch_norm_momentum": self.batch_norm_momentum, + "batch_norm_epsilon": self.batch_norm_epsilon, + "activation": self.activation, + "projection_activation": self.projection_activation, + "dropout": self.dropout, + "nores": self.nores, + "projection_kernel_size": self.projection_kernel_size, + } + + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + +def apply_depthwise_conv_block( + x, filters, kernel_size=3, stride=2, se=None, name=None +): + """Adds a depthwise convolution block. + + A depthwise convolution block consists of a depthwise conv, + batch normalization, relu6, pointwise convolution, + batch normalization and relu6 activation. + + Args: + x: Input tensor of shape `(rows, cols, channels) + filters: Integer, the dimensionality of the output space + (i.e. the number of output filters in the pointwise convolution). + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the width and height. + Can be a single integer to specify the same value for + all spatial dimensions. Specifying any stride value != 1 is + incompatible with specifying any `dilation_rate` value != 1. + block_id: Integer, a unique identification designating the block number. + + Input shape: + 4D tensor with shape: `(batch, rows, cols, channels)` in "channels_last" + 4D tensor with shape: `(batch, channels, rows, cols)` in "channels_first" + Returns: + Output tensor of block. 
+ """ + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + infilters = x.shape[channel_axis] + name = f"{name}_0" + + x = keras.layers.ZeroPadding2D( + padding=(1, 1), + name=f"{name}_pad", + )(x) + x = keras.layers.Conv2D( + infilters, + kernel_size, + strides=stride, + padding="valid", + data_format=keras.config.image_data_format(), + groups=infilters, + use_bias=False, + name=f"{name}_conv1", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn1", + )(x) + x = keras.layers.ReLU()(x) + + if se: + x = SqueezeAndExcite2D( + input=x, + filters=infilters, + bottleneck_filters=adjust_channels(infilters * se), + squeeze_activation="relu", + excite_activation=keras.activations.hard_sigmoid, + name=f"{name}_se", + ) + + x = keras.layers.Conv2D( + filters, + kernel_size=1, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv2", + )(x) + x = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn2", + )(x) + return x \ No newline at end of file diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone.py b/keras_hub/src/models/mobilenet/mobilenet_backbone.py index e40eac32b1..d7f7ae894b 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone.py @@ -1,5 +1,4 @@ import keras -from keras import ops from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.backbone import Backbone @@ -135,13 +134,10 @@ def __init__( x = image_input input_num_filters = adjust_channels(input_num_filters) - pad_width = ( - (0, 0), # No padding for batch - (1, 1), # 1 pixel padding for height - (1, 1), # 1 pixel padding for width - (0, 0), - ) # No padding for channels - x = ops.pad(x, pad_width=pad_width) + x = keras.layers.ZeroPadding2D( + padding=(1,1), + name="input_pad", + )(x) x = keras.layers.Conv2D( input_num_filters, kernel_size=3, @@ -334,19 +330,10 @@ def apply_inverted_res_block( x = keras.layers.Activation(activation=activation)(x) - # if stride == 2: - # x = keras.layers.ZeroPadding2D( - # padding=correct_pad_downsample(x, kernel_size), - # )(x) - - # pad_width=[[padding, padding], [padding, padding]] - pad_width = ( - (0, 0), # No padding for batch - (padding, padding), # 1 pixel padding for height - (padding, padding), # 1 pixel padding for width - (0, 0), - ) # No padding for channels - x = ops.pad(x, pad_width=pad_width) + x = keras.layers.ZeroPadding2D( + padding=(padding, padding), + name=f"{name}_pad", + )(x) x = keras.layers.Conv2D( expanded_channels, @@ -397,146 +384,6 @@ def apply_inverted_res_block( return x -def apply_depthwise_conv_block( - x, filters, kernel_size=3, stride=2, se=None, name=None -): - """Adds a depthwise convolution block. - - A depthwise convolution block consists of a depthwise conv, - batch normalization, relu6, pointwise convolution, - batch normalization and relu6 activation. - - Args: - x: Input tensor of shape `(rows, cols, channels) - filters: Integer, the dimensionality of the output space - (i.e. the number of output filters in the pointwise convolution). - strides: An integer or tuple/list of 2 integers, specifying the strides - of the convolution along the width and height. - Can be a single integer to specify the same value for - all spatial dimensions. Specifying any stride value != 1 is - incompatible with specifying any `dilation_rate` value != 1. 
- block_id: Integer, a unique identification designating the block number. - - Input shape: - 4D tensor with shape: `(batch, rows, cols, channels)` in "channels_last" - 4D tensor with shape: `(batch, channels, rows, cols)` in "channels_first" - Returns: - Output tensor of block. - """ - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - infilters = x.shape[channel_axis] - name = f"{name}_0" - - # if stride == 2: - # x = keras.layers.ZeroPadding2D( - # padding=correct_pad_downsample(x, kernel_size), - # )(x) - pad_width = ( - (0, 0), # No padding for batch - (1, 1), # 1 pixel padding for height - (1, 1), # 1 pixel padding for width - (0, 0), - ) # No padding for channels - x = ops.pad(x, pad_width=pad_width) - x = keras.layers.Conv2D( - infilters, - kernel_size, - strides=stride, - padding="valid", - data_format=keras.config.image_data_format(), - groups=infilters, - use_bias=False, - name=f"{name}_conv1", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn1", - )(x) - x = keras.layers.ReLU(6.0)(x) - - if se: - x = SqueezeAndExcite2D( - input=x, - filters=infilters, - bottleneck_filters=adjust_channels(infilters * se), - squeeze_activation="relu", - excite_activation=keras.activations.hard_sigmoid, - name=f"{name}_se", - ) - - x = keras.layers.Conv2D( - filters, - kernel_size=1, - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv2", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn2", - )(x) - return x - - -def SqueezeAndExcite2D( - input, - filters, - bottleneck_filters=None, - squeeze_activation="relu", - excite_activation="sigmoid", - name=None, -): - """ - Description: - This layer applies a content-aware mechanism to adaptively assign - channel-wise weights. It uses global average pooling to compress - feature maps into single values, which are then processed by - two Conv1D layers: the first reduces the dimensionality, and - the second restores it. - Args: - filters: Number of input and output filters. The number of input and - output filters is same. - bottleneck_filters: (Optional) Number of bottleneck filters. Defaults - to `0.25 * filters` - squeeze_activation: (Optional) String, callable (or - keras.layers.Layer) or keras.activations.Activation instance - denoting activation to be applied after squeeze convolution. - Defaults to `relu`. - excite_activation: (Optional) String, callable (or - keras.layers.Layer) or keras.activations.Activation instance - denoting activation to be applied after excite convolution. - Defaults to `sigmoid`. 
- name: Name of the layer - """ - if not bottleneck_filters: - bottleneck_filters = filters // 4 - - x = input - x = keras.layers.Conv2D( - bottleneck_filters, - (1, 1), - data_format=keras.config.image_data_format(), - activation=squeeze_activation, - name=f"{name}_conv_reduce", - )(x) - x = keras.layers.Conv2D( - filters, - (1, 1), - data_format=keras.config.image_data_format(), - activation=excite_activation, - name=f"{name}_conv_expand", - )(x) - - x = ops.multiply(x, input) - return x - - def ConvBnAct(x, filter, activation, name=None): channel_axis = ( -1 if keras.config.image_data_format() == "channels_last" else 1 diff --git a/keras_hub/src/models/mobilenet/squeeze_and_excite_2d.py b/keras_hub/src/models/mobilenet/squeeze_and_excite_2d.py new file mode 100644 index 0000000000..b1301211e8 --- /dev/null +++ b/keras_hub/src/models/mobilenet/squeeze_and_excite_2d.py @@ -0,0 +1,101 @@ +import keras + +BN_AXIS = 3 + + +class SqueezeAndExcite2D(keras.layers.Layer): + """ + Description: + This layer applies a content-aware mechanism to adaptively assign + channel-wise weights. It uses global average pooling to compress + feature maps into single values, which are then processed by + two Conv1D layers: the first reduces the dimensionality, and + the second restores it. + Args: + filters: Number of input and output filters. The number of input and + output filters is same. + bottleneck_filters: (Optional) Number of bottleneck filters. Defaults + to `0.25 * filters` + squeeze_activation: (Optional) String, callable (or + keras.layers.Layer) or keras.activations.Activation instance + denoting activation to be applied after squeeze convolution. + Defaults to `relu`. + excite_activation: (Optional) String, callable (or + keras.layers.Layer) or keras.activations.Activation instance + denoting activation to be applied after excite convolution. + Defaults to `sigmoid`. 
+ name: Name of the layer + """ + + def __init__( + self, + input, + filters, + bottleneck_filters=None, + squeeze_activation="relu", + excite_activation="sigmoid", + name=None, + **kwargs, + ): + super().__init__(**kwargs) + self.input = input + self.filters = filters + self.bottleneck_filters = bottleneck_filters + self.squeeze_activation = squeeze_activation + self.excite_activation = excite_activation + self.name = name + + image_data_format = keras.config.image_data_format() + if image_data_format == "channels_last": + self.spatial_dims = (1, 2) + else: + self.spatial_dims = (2, 3) + + self.conv_reduce = keras.layers.Conv2D( + bottleneck_filters, + (1, 1), + data_format=image_data_format, + activation=squeeze_activation, + name=f"{name}_conv_reduce", + ) + self.act1 = keras.layers.Activation( + self.squeeze_activation, name=self.name + "squeeze_activation" + ) + + self.conv_expand = keras.layers.Conv2D( + filters, + (1, 1), + data_format=image_data_format, + activation=excite_activation, + name=f"{name}_conv_expand", + ) + self.gate = keras.layers.Activation( + self.excite_activation, name=self.name + "excite_activation" + ) + + def build(self, input_shape): + if self.name is None: + self.name = keras.backend.get_uid("block0") + + def call(self, inputs): + x_se = keras.ops.mean(inputs, axis=self.spatial_dims, keepdims=True) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + return inputs * self.gate(x_se) + + def get_config(self): + config = super().get_config() + config.update( + { + "input": self.input, + "filters": self.filters, + "bottleneck_filters": self.bottleneck_filters, + "squeeze_activation": self.squeeze_activation, + "excite_activation": self.excite_activation, + "name": self.name, + "spatial_dims": self.spatial_dims, + } + ) + + return config diff --git a/keras_hub/src/utils/timm/convert_mobilenet.py b/keras_hub/src/utils/timm/convert_mobilenet.py index 307f4a4acc..34914c4fe6 100644 --- a/keras_hub/src/utils/timm/convert_mobilenet.py +++ b/keras_hub/src/utils/timm/convert_mobilenet.py @@ -70,13 +70,21 @@ def convert_backbone_config(timm_config): def convert_weights(backbone, loader, timm_config): - def port_conv2d(keras_layer_name, hf_weight_prefix): + def port_conv2d(keras_layer_name, hf_weight_prefix, port_bias=False): print(f"porting weights {hf_weight_prefix} -> {keras_layer_name}") loader.port_weight( backbone.get_layer(keras_layer_name).kernel, hf_weight_key=f"{hf_weight_prefix}.weight", hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), ) + + if port_bias: + print(f"porting bias {hf_weight_prefix} -> {keras_layer_name}") + loader.port_weight( + backbone.get_layer(keras_layer_name).bias, + hf_weight_key=f"{hf_weight_prefix}.bias", + ) + def port_batch_normalization(keras_layer_name, hf_weight_prefix): print(f"porting weights {hf_weight_prefix} -> {keras_layer_name}") @@ -112,8 +120,8 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_dw") port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1") - port_conv2d(f"{keras_name}_se_conv_reduce", f"{hf_name}.se.conv_reduce") - port_conv2d(f"{keras_name}_se_conv_expand", f"{hf_name}.se.conv_expand") + port_conv2d(f"{keras_name}_se_conv_reduce", f"{hf_name}.se.conv_reduce", True) + port_conv2d(f"{keras_name}_se_conv_expand", f"{hf_name}.se.conv_expand", True) port_conv2d(f"{keras_name}_conv2", f"{hf_name}.conv_pw") port_batch_normalization(f"{keras_name}_bn2", f"{hf_name}.bn2") @@ -135,10 +143,12 @@ def 
port_batch_normalization(keras_layer_name, hf_weight_prefix): port_conv2d( f"{keras_name}_se_conv_reduce", f"{hf_name}.se.conv_reduce", + True, ) port_conv2d( f"{keras_name}_se_conv_expand", f"{hf_name}.se.conv_expand", + True, ) port_conv2d(f"{keras_name}_conv3", f"{hf_name}.conv_pwl") diff --git a/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py b/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py index 270d18eef9..3a47f7d779 100644 --- a/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py +++ b/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py @@ -56,6 +56,7 @@ def validate_output(keras_model, timm_model): timm_preprocessed = keras.ops.expand_dims(timm_preprocessed, 0) # Preprocess with Keras. + batch = keras.ops.cast(batch, "float32") keras_preprocessed = keras_model.preprocessor(batch) # Call with Timm. Use the keras preprocessed image so we can keep modeling From cfe4a4f36e7b8a581d2eec87cd3cc7ad033afc4f Mon Sep 17 00:00:00 2001 From: Piseth Ky Date: Thu, 16 Jan 2025 16:16:53 -0800 Subject: [PATCH 17/21] WIP refactored, classifier/task changes --- keras_hub/src/models/image_classifier.py | 24 ++ .../src/models/mobilenet/conv_bn_act_block.py | 58 ++++ .../models/mobilenet/depthwise_conv_block.py | 302 ++++-------------- .../mobilenet/inverted_residual_block.py | 161 ++++++++++ .../models/mobilenet/mobilenet_backbone.py | 221 +------------ .../models/mobilenet/squeeze_and_excite_2d.py | 5 - keras_hub/src/models/mobilenet/util.py | 23 ++ keras_hub/src/utils/timm/convert_mobilenet.py | 72 +++-- .../convert_mobilenet_checkpoints.py | 3 +- 9 files changed, 383 insertions(+), 486 deletions(-) create mode 100644 keras_hub/src/models/mobilenet/conv_bn_act_block.py create mode 100644 keras_hub/src/models/mobilenet/inverted_residual_block.py create mode 100644 keras_hub/src/models/mobilenet/util.py diff --git a/keras_hub/src/models/image_classifier.py b/keras_hub/src/models/image_classifier.py index e75e390899..878a087ed4 100644 --- a/keras_hub/src/models/image_classifier.py +++ b/keras_hub/src/models/image_classifier.py @@ -97,6 +97,8 @@ def __init__( activation=None, dropout=0.0, head_dtype=None, + include_conv=False, + flatten=False, **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy @@ -127,6 +129,20 @@ def __init__( dtype=head_dtype, name="output_dropout", ) + + if include_conv: + self.output_conv = keras.layers.Conv2D( + filters=1024, + kernel_size=(1, 1), + strides=(1, 1), + use_bias=True, + padding='valid', + activation="hardswish", + ) + + if flatten: + self.flatten = keras.layers.Flatten() + self.output_dense = keras.layers.Dense( num_classes, activation=activation, @@ -139,6 +155,10 @@ def __init__( x = self.backbone(inputs) x = self.pooler(x) x = self.output_dropout(x) + if include_conv: + x = self.output_conv(x) + if flatten: + x = self.flatten(x) outputs = self.output_dense(x) super().__init__( inputs=inputs, @@ -151,6 +171,8 @@ def __init__( self.activation = activation self.pooling = pooling self.dropout = dropout + self.include_conv = include_conv + self.flatten = flatten def get_config(self): # Backbone serialized in `super` @@ -161,6 +183,8 @@ def get_config(self): "pooling": self.pooling, "activation": self.activation, "dropout": self.dropout, + "include_conv": self.include_conv, + "flatten": self.flatten, } ) return config diff --git a/keras_hub/src/models/mobilenet/conv_bn_act_block.py b/keras_hub/src/models/mobilenet/conv_bn_act_block.py new file mode 100644 index 0000000000..030e92f7ec --- /dev/null +++ 
b/keras_hub/src/models/mobilenet/conv_bn_act_block.py @@ -0,0 +1,58 @@ +import keras + + +BN_EPSILON = 1e-5 +BN_MOMENTUM = 0.9 +BN_AXIS = 3 + + +class ConvBnActBlock(keras.layers.Layer): + def __init__( + self, + filter, + activation, + name=None, + **kwargs, + ): + super().__init__(**kwargs) + self.filter = filter + self.activation = activation + self.name = name + + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + self.conv = keras.layers.Conv2D( + filter, + kernel_size=1, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv", + ) + self.bn = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn", + ) + self.act = keras.layers.Activation(activation) + + def build(self, input_shape): + if self.name is None: + self.name = keras.backend.get_uid("block0") + + def call(self, inputs): + x = self.conv(inputs) + x = self.bn(x) + x = self.act(x) + return x + + def get_config(self): + config = { + "filter": self.filter, + "activation": self.activation, + "name": self.name, + } + + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/keras_hub/src/models/mobilenet/depthwise_conv_block.py b/keras_hub/src/models/mobilenet/depthwise_conv_block.py index 9345c090c3..feccdd1ffa 100644 --- a/keras_hub/src/models/mobilenet/depthwise_conv_block.py +++ b/keras_hub/src/models/mobilenet/depthwise_conv_block.py @@ -1,5 +1,11 @@ import keras +from keras_hub.src.models.mobilenet.squeeze_and_excite_2d import SqueezeAndExcite2D +from keras_hub.src.models.mobilenet.util import adjust_channels + + +BN_EPSILON = 1e-5 +BN_MOMENTUM = 0.9 BN_AXIS = 3 @@ -29,129 +35,70 @@ class DepthwiseConvBlock(keras.layers.Layer): def __init__( self, - input_filters, - output_filters, - expand_ratio=1, - kernel_size=3, - strides=1, - data_format="channels_last", - se_ratio=0.0, - batch_norm_momentum=0.9, - batch_norm_epsilon=1e-3, - activation="swish", - projection_activation=None, - dropout=0.2, - nores=False, - projection_kernel_size=1, + infilters, + filters, + kernel_size=3, + stride=2, + se=None, + name=None, **kwargs, ): super().__init__(**kwargs) - self.input_filters = input_filters - self.output_filters = output_filters - self.expand_ratio = expand_ratio + self.infilters = infilters + self.filters = filters self.kernel_size = kernel_size - self.strides = strides - self.data_format = data_format - self.se_ratio = se_ratio - self.batch_norm_momentum = batch_norm_momentum - self.batch_norm_epsilon = batch_norm_epsilon - self.activation = activation - self.projection_activation = projection_activation - self.dropout = dropout - self.nores = nores - self.projection_kernel_size = projection_kernel_size - self.filters = self.input_filters * self.expand_ratio - self.filters_se = max(1, int(input_filters * se_ratio)) + self.stride = stride + self.se = se + self.name = name - padding_pixels = kernel_size // 2 - self.conv1_pad = keras.layers.ZeroPadding2D( - padding=(padding_pixels, padding_pixels), - name=self.name + "expand_conv_pad", + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + self.name = name = f"{name}_0" + + self.pad = keras.layers.ZeroPadding2D( + padding=(1, 1), + name=f"{name}_pad", ) self.conv1 = keras.layers.Conv2D( - filters=self.filters, - kernel_size=kernel_size, - strides=strides, - kernel_initializer=self._conv_kernel_initializer(), + infilters, + kernel_size, + strides=stride, 
padding="valid", - data_format=data_format, + data_format=keras.config.image_data_format(), + groups=infilters, use_bias=False, - name=self.name + "expand_conv", + name=f"{name}_conv1", ) self.bn1 = keras.layers.BatchNormalization( - axis=BN_AXIS, - momentum=self.batch_norm_momentum, - epsilon=self.batch_norm_epsilon, - name=self.name + "expand_bn", - ) - self.act = keras.layers.Activation( - self.activation, name=self.name + "expand_activation" - ) - - self.se_conv1 = keras.layers.Conv2D( - self.filters_se, - 1, - padding="same", - data_format=data_format, - activation=self.activation, - kernel_initializer=self._conv_kernel_initializer(), - name=self.name + "se_reduce", - ) - - self.se_conv2 = keras.layers.Conv2D( - self.filters, - 1, - padding="same", - data_format=data_format, - activation="sigmoid", - kernel_initializer=self._conv_kernel_initializer(), - name=self.name + "se_expand", + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn1", ) + self.act1 = keras.layers.ReLU() + + if se: + self.se_layer = SqueezeAndExcite2D( + filters=infilters, + bottleneck_filters=adjust_channels(infilters * se), + squeeze_activation="relu", + excite_activation=keras.activations.hard_sigmoid, + name=f"{name}_se", + ) - padding_pixels = projection_kernel_size // 2 - self.output_conv_pad = keras.layers.ZeroPadding2D( - padding=(padding_pixels, padding_pixels), - name=self.name + "project_conv_pad", - ) - self.output_conv = keras.layers.Conv2D( - filters=self.output_filters, - kernel_size=projection_kernel_size, - strides=1, - kernel_initializer=self._conv_kernel_initializer(), - padding="valid", - data_format=data_format, + self.conv2 = keras.layers.Conv2D( + filters, + kernel_size=1, + data_format=keras.config.image_data_format(), use_bias=False, - name=self.name + "project_conv", + name=f"{name}_conv2", ) - self.bn2 = keras.layers.BatchNormalization( - axis=BN_AXIS, - momentum=self.batch_norm_momentum, - epsilon=self.batch_norm_epsilon, - name=self.name + "project_bn", - ) - - if self.projection_activation: - self.projection_act = keras.layers.Activation( - self.projection_activation, name=self.name + "projection_act" - ) - - if self.dropout: - self.dropout_layer = keras.layers.Dropout( - self.dropout, - noise_shape=(None, 1, 1, 1), - name=self.name + "drop", - ) - - def _conv_kernel_initializer( - self, - scale=2.0, - mode="fan_out", - distribution="truncated_normal", - seed=None, - ): - return keras.initializers.VarianceScaling( - scale=scale, mode=mode, distribution=distribution, seed=seed + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn2", ) def build(self, input_shape): @@ -159,146 +106,27 @@ def build(self, input_shape): self.name = keras.backend.get_uid("block0") def call(self, inputs): - # Expansion phase - x = self.conv1_pad(inputs) + x = self.pad(inputs) x = self.conv1(x) x = self.bn1(x) - x = self.act(x) - - # Squeeze and excite - if 0 < self.se_ratio <= 1: - se = keras.layers.GlobalAveragePooling2D( - name=self.name + "se_squeeze", - data_format=self.data_format, - )(x) - if BN_AXIS == 1: - se_shape = (self.filters, 1, 1) - else: - se_shape = (1, 1, self.filters) - - se = keras.layers.Reshape(se_shape, name=self.name + "se_reshape")( - se - ) - - se = self.se_conv1(se) - se = self.se_conv2(se) + x = self.act1(x) - x = keras.layers.multiply([x, se], name=self.name + "se_excite") + if self.se_layer: + x = self.se_layer(x) - # Output phase: - x = self.output_conv_pad(x) - x = self.output_conv(x) + x = self.conv2(x) x = 
self.bn2(x) - if self.expand_ratio == 1 and self.projection_activation: - x = self.projection_act(x) - - # Residual: - if ( - self.strides == 1 - and self.input_filters == self.output_filters - and not self.nores - ): - if self.dropout: - x = self.dropout_layer(x) - x = keras.layers.Add(name=self.name + "add")([x, inputs]) return x def get_config(self): config = { - "input_filters": self.input_filters, - "output_filters": self.output_filters, - "expand_ratio": self.expand_ratio, - "kernel_size": self.kernel_size, - "strides": self.strides, - "data_format": self.data_format, - "se_ratio": self.se_ratio, - "batch_norm_momentum": self.batch_norm_momentum, - "batch_norm_epsilon": self.batch_norm_epsilon, - "activation": self.activation, - "projection_activation": self.projection_activation, - "dropout": self.dropout, - "nores": self.nores, - "projection_kernel_size": self.projection_kernel_size, + "infilters": self.infilters, + "filters": self.filters, + "kernel_size": self.kernel_size, + "stride": self.stride, + "se": self.se, + "name": self.name, } base_config = super().get_config() return dict(list(base_config.items()) + list(config.items())) - -def apply_depthwise_conv_block( - x, filters, kernel_size=3, stride=2, se=None, name=None -): - """Adds a depthwise convolution block. - - A depthwise convolution block consists of a depthwise conv, - batch normalization, relu6, pointwise convolution, - batch normalization and relu6 activation. - - Args: - x: Input tensor of shape `(rows, cols, channels) - filters: Integer, the dimensionality of the output space - (i.e. the number of output filters in the pointwise convolution). - strides: An integer or tuple/list of 2 integers, specifying the strides - of the convolution along the width and height. - Can be a single integer to specify the same value for - all spatial dimensions. Specifying any stride value != 1 is - incompatible with specifying any `dilation_rate` value != 1. - block_id: Integer, a unique identification designating the block number. - - Input shape: - 4D tensor with shape: `(batch, rows, cols, channels)` in "channels_last" - 4D tensor with shape: `(batch, channels, rows, cols)` in "channels_first" - Returns: - Output tensor of block. 
- """ - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - infilters = x.shape[channel_axis] - name = f"{name}_0" - - x = keras.layers.ZeroPadding2D( - padding=(1, 1), - name=f"{name}_pad", - )(x) - x = keras.layers.Conv2D( - infilters, - kernel_size, - strides=stride, - padding="valid", - data_format=keras.config.image_data_format(), - groups=infilters, - use_bias=False, - name=f"{name}_conv1", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn1", - )(x) - x = keras.layers.ReLU()(x) - - if se: - x = SqueezeAndExcite2D( - input=x, - filters=infilters, - bottleneck_filters=adjust_channels(infilters * se), - squeeze_activation="relu", - excite_activation=keras.activations.hard_sigmoid, - name=f"{name}_se", - ) - - x = keras.layers.Conv2D( - filters, - kernel_size=1, - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv2", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn2", - )(x) - return x \ No newline at end of file diff --git a/keras_hub/src/models/mobilenet/inverted_residual_block.py b/keras_hub/src/models/mobilenet/inverted_residual_block.py new file mode 100644 index 0000000000..c95733744b --- /dev/null +++ b/keras_hub/src/models/mobilenet/inverted_residual_block.py @@ -0,0 +1,161 @@ +import keras + +from keras_hub.src.models.mobilenet.squeeze_and_excite_2d import SqueezeAndExcite2D +from keras_hub.src.models.mobilenet.util import adjust_channels + + +BN_EPSILON = 1e-5 +BN_MOMENTUM = 0.9 +BN_AXIS = 3 + + +class InvertedResidualBlock(keras.layers.Layer): + """An Inverted Residual Block. + + Args: + expansion: integer, the expansion ratio, multiplied with infilters to + get the minimum value passed to adjust_channels. + filters: integer, number of filters for convolution layer. + kernel_size: integer, the kernel size for DepthWise Convolutions. + stride: integer, the stride length for DepthWise Convolutions. + se_ratio: float, ratio for bottleneck filters. Number of bottleneck + filters = filters * se_ratio. + activation: the activation layer to use. + padding: padding in the conv2d layer + name: string, block label. + + Returns: + the updated input tensor. 
+ """ + + def __init__( + self, + expansion, + infilters, + filters, + kernel_size, + stride, + se_ratio, + activation, + padding, + name=None, + **kwargs, + ): + super().__init__(**kwargs) + self.expansion = expansion + self.infilters = infilters + self.filters = filters + self.kernel_size = kernel_size + self.stride = stride + self.se_ratio = se_ratio + self.activation = activation + self.padding = padding + self.name = name + + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) + expanded_channels = adjust_channels(expansion) + + self.conv1 = keras.layers.Conv2D( + expanded_channels, + kernel_size=1, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv1", + ) + + self.bn1 = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn1", + ) + + self.act1 = keras.layers.Activation(activation=activation) + + self.pad = keras.layers.ZeroPadding2D( + padding=(padding, padding), + name=f"{name}_pad", + ) + + self.conv2 = keras.layers.Conv2D( + expanded_channels, + kernel_size, + strides=stride, + padding="valid", + groups=expanded_channels, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv2", + ) + self.bn2 = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn2", + ) + + self.act2 = keras.layers.Activation(activation=activation) + + self.se = None + if self.se_ratio: + se_filters = expanded_channels + self.se = SqueezeAndExcite2D( + filters=se_filters, + bottleneck_filters=adjust_channels(se_filters * se_ratio), + squeeze_activation="relu", + excite_activation=keras.activations.hard_sigmoid, + name=f"{name}_se", + ) + + self.conv3 = keras.layers.Conv2D( + filters, + kernel_size=1, + data_format=keras.config.image_data_format(), + use_bias=False, + name=f"{name}_conv3", + ) + self.bn3 = keras.layers.BatchNormalization( + axis=channel_axis, + epsilon=BN_EPSILON, + momentum=BN_MOMENTUM, + name=f"{name}_bn3", + ) + + def build(self, input_shape): + if self.name is None: + self.name = keras.backend.get_uid("block0") + + def call(self, inputs): + x = inputs + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + x = self.pad(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + if self.se: + x = self.se(x) + x = self.conv3(x) + x = self.bn3(x) + if self.stride == 1 and self.infilters == self.filters: + x = inputs + x + return x + + def get_config(self): + config = { + "expansion": self.expansion, + "infilters": self.infilters, + "filters": self.filters, + "kernel_size": self.kernel_size, + "stride": self.stride, + "se_ratio": self.se_ratio, + "activation": self.activation, + "padding": self.padding, + "name": self.name, + } + + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone.py b/keras_hub/src/models/mobilenet/mobilenet_backbone.py index d7f7ae894b..8ccd8d04d0 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone.py @@ -2,6 +2,11 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.mobilenet.depthwise_conv_block import DepthwiseConvBlock +from keras_hub.src.models.mobilenet.inverted_residual_block import InvertedResidualBlock +from keras_hub.src.models.mobilenet.conv_bn_act_block import 
ConvBnActBlock +from keras_hub.src.models.mobilenet.util import adjust_channels + BN_EPSILON = 1e-5 BN_MOMENTUM = 0.9 @@ -154,15 +159,16 @@ def __init__( )(x) x = keras.layers.Activation(input_activation)(x) - x = apply_depthwise_conv_block( - x, depthwise_filters, se=squeeze_and_excite, name="block_0" - ) + x = DepthwiseConvBlock( + input_num_filters, depthwise_filters, se=squeeze_and_excite, name="block_0" + )(x) for block in range(len(stackwise_num_blocks)): for inverted_block in range(stackwise_num_blocks[block]): - x = apply_inverted_res_block( - x, + infilters = x.shape[channel_axis] + x = InvertedResidualBlock( expansion=stackwise_expansion[block][inverted_block], + infilters=infilters, filters=adjust_channels( stackwise_num_filters[block][inverted_block] ), @@ -172,36 +178,14 @@ def __init__( activation=stackwise_activation[block][inverted_block], padding=stackwise_padding[block][inverted_block], name=f"block_{block+1}_{inverted_block}", - ) + )(x) - x = ConvBnAct( - x, + x = ConvBnActBlock( filter=adjust_channels(last_layer_filter), activation="hard_swish", name=f"block_{len(stackwise_num_blocks)+1}_0", - ) - - last_conv_ch = adjust_channels(output_num_filters) - - x = keras.layers.Conv2D( - last_conv_ch, - kernel_size=1, - data_format=keras.config.image_data_format(), - use_bias=False, - name="output_conv", )(x) - # no output normalization in mobilenetv3 - if output_activation == "relu6": - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name="output_batch_norm", - )(x) - - x = keras.layers.Activation(output_activation)(x) - super().__init__(inputs=image_input, outputs=x, **kwargs) # === Config === @@ -249,182 +233,3 @@ def get_config(self): } ) return config - - -def adjust_channels(x, divisor=8, min_value=None): - """Ensure that all layers have a channel number divisible by the `divisor`. - - Args: - x: integer, input value. - divisor: integer, the value by which a channel number should be - divisible, defaults to 8. - min_value: float, optional minimum value for the new tensor. If None, - defaults to value of divisor. - - Returns: - the updated input scalar. - """ - - if min_value is None: - min_value = divisor - - new_x = max(min_value, int(x + divisor / 2) // divisor * divisor) - - # make sure that round down does not go down by more than 10%. - if new_x < 0.9 * x: - new_x += divisor - return new_x - - -def apply_inverted_res_block( - x, - expansion, - filters, - kernel_size, - stride, - se_ratio, - activation, - padding, - name=None, -): - """An Inverted Residual Block. - - Args: - x: input tensor. - expansion: integer, the expansion ratio, multiplied with infilters to - get the minimum value passed to adjust_channels. - filters: integer, number of filters for convolution layer. - kernel_size: integer, the kernel size for DepthWise Convolutions. - stride: integer, the stride length for DepthWise Convolutions. - se_ratio: float, ratio for bottleneck filters. Number of bottleneck - filters = filters * se_ratio. - activation: the activation layer to use. - padding: padding in the conv2d layer - name: string, block label. - - Returns: - the updated input tensor. 
- """ - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - activation = keras.activations.get(activation) - shortcut = x - infilters = x.shape[channel_axis] - expanded_channels = adjust_channels(expansion) - - x = keras.layers.Conv2D( - expanded_channels, - kernel_size=1, - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv1", - )(x) - - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn1", - )(x) - - x = keras.layers.Activation(activation=activation)(x) - - x = keras.layers.ZeroPadding2D( - padding=(padding, padding), - name=f"{name}_pad", - )(x) - - x = keras.layers.Conv2D( - expanded_channels, - kernel_size, - strides=stride, - padding="valid", - groups=expanded_channels, - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv2", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn2", - )(x) - - x = keras.layers.Activation(activation=activation)(x) - - if se_ratio: - se_filters = expanded_channels - x = SqueezeAndExcite2D( - input=x, - filters=se_filters, - bottleneck_filters=adjust_channels(se_filters * se_ratio), - squeeze_activation="relu", - excite_activation=keras.activations.hard_sigmoid, - name=f"{name}_se", - ) - - x = keras.layers.Conv2D( - filters, - kernel_size=1, - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv3", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn3", - )(x) - - if stride == 1 and infilters == filters: - x = keras.layers.Add(name=f"{name}_add")([shortcut, x]) - return x - - -def ConvBnAct(x, filter, activation, name=None): - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - x = keras.layers.Conv2D( - filter, - kernel_size=1, - data_format=keras.config.image_data_format(), - use_bias=False, - name=f"{name}_conv", - )(x) - x = keras.layers.BatchNormalization( - axis=channel_axis, - epsilon=BN_EPSILON, - momentum=BN_MOMENTUM, - name=f"{name}_bn", - )(x) - x = keras.layers.Activation(activation)(x) - return x - - -def correct_pad_downsample(inputs, kernel_size): - """Returns a tuple for zero-padding for 2D convolution with downsampling. - - Args: - inputs: Input tensor. - kernel_size: An integer or tuple/list of 2 integers. - - Returns: - A tuple. 
- """ - img_dim = 1 - input_size = inputs.shape[img_dim : (img_dim + 2)] - if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size) - if input_size[0] is None: - adjust = (1, 1) - else: - adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) - correct = (kernel_size[0] // 2, kernel_size[1] // 2) - return ( - (correct[0] - adjust[0], correct[0]), - (correct[1] - adjust[1], correct[1]), - ) diff --git a/keras_hub/src/models/mobilenet/squeeze_and_excite_2d.py b/keras_hub/src/models/mobilenet/squeeze_and_excite_2d.py index b1301211e8..30a07ace35 100644 --- a/keras_hub/src/models/mobilenet/squeeze_and_excite_2d.py +++ b/keras_hub/src/models/mobilenet/squeeze_and_excite_2d.py @@ -29,7 +29,6 @@ class SqueezeAndExcite2D(keras.layers.Layer): def __init__( self, - input, filters, bottleneck_filters=None, squeeze_activation="relu", @@ -38,7 +37,6 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.input = input self.filters = filters self.bottleneck_filters = bottleneck_filters self.squeeze_activation = squeeze_activation @@ -55,7 +53,6 @@ def __init__( bottleneck_filters, (1, 1), data_format=image_data_format, - activation=squeeze_activation, name=f"{name}_conv_reduce", ) self.act1 = keras.layers.Activation( @@ -66,7 +63,6 @@ def __init__( filters, (1, 1), data_format=image_data_format, - activation=excite_activation, name=f"{name}_conv_expand", ) self.gate = keras.layers.Activation( @@ -88,7 +84,6 @@ def get_config(self): config = super().get_config() config.update( { - "input": self.input, "filters": self.filters, "bottleneck_filters": self.bottleneck_filters, "squeeze_activation": self.squeeze_activation, diff --git a/keras_hub/src/models/mobilenet/util.py b/keras_hub/src/models/mobilenet/util.py new file mode 100644 index 0000000000..b17efc4b87 --- /dev/null +++ b/keras_hub/src/models/mobilenet/util.py @@ -0,0 +1,23 @@ +def adjust_channels(x, divisor=8, min_value=None): + """Ensure that all layers have a channel number divisible by the `divisor`. + + Args: + x: integer, input value. + divisor: integer, the value by which a channel number should be + divisible, defaults to 8. + min_value: float, optional minimum value for the new tensor. If None, + defaults to value of divisor. + + Returns: + the updated input scalar. + """ + + if min_value is None: + min_value = divisor + + new_x = max(min_value, int(x + divisor / 2) // divisor * divisor) + + # make sure that round down does not go down by more than 10%. 
+ if new_x < 0.9 * x: + new_x += divisor + return new_x \ No newline at end of file diff --git a/keras_hub/src/utils/timm/convert_mobilenet.py b/keras_hub/src/utils/timm/convert_mobilenet.py index 34914c4fe6..09545c6c42 100644 --- a/keras_hub/src/utils/timm/convert_mobilenet.py +++ b/keras_hub/src/utils/timm/convert_mobilenet.py @@ -70,61 +70,64 @@ def convert_backbone_config(timm_config): def convert_weights(backbone, loader, timm_config): - def port_conv2d(keras_layer_name, hf_weight_prefix, port_bias=False): - print(f"porting weights {hf_weight_prefix} -> {keras_layer_name}") + def port_conv2d(keras_layer, hf_weight_prefix, port_bias=False): + print(f"porting weights {hf_weight_prefix} -> {keras_layer}") loader.port_weight( - backbone.get_layer(keras_layer_name).kernel, + keras_layer.kernel, hf_weight_key=f"{hf_weight_prefix}.weight", hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), ) if port_bias: - print(f"porting bias {hf_weight_prefix} -> {keras_layer_name}") + print(f"porting bias {hf_weight_prefix} -> {keras_layer}") loader.port_weight( - backbone.get_layer(keras_layer_name).bias, + keras_layer.bias, hf_weight_key=f"{hf_weight_prefix}.bias", ) - def port_batch_normalization(keras_layer_name, hf_weight_prefix): - print(f"porting weights {hf_weight_prefix} -> {keras_layer_name}") + def port_batch_normalization(keras_layer, hf_weight_prefix): + print(f"porting weights {hf_weight_prefix} -> {keras_layer}") loader.port_weight( - backbone.get_layer(keras_layer_name).gamma, + keras_layer.gamma, hf_weight_key=f"{hf_weight_prefix}.weight", ) loader.port_weight( - backbone.get_layer(keras_layer_name).beta, + keras_layer.beta, hf_weight_key=f"{hf_weight_prefix}.bias", ) loader.port_weight( - backbone.get_layer(keras_layer_name).moving_mean, + keras_layer.moving_mean, hf_weight_key=f"{hf_weight_prefix}.running_mean", ) loader.port_weight( - backbone.get_layer(keras_layer_name).moving_variance, + keras_layer.moving_variance, hf_weight_key=f"{hf_weight_prefix}.running_var", ) loader.port_weight( - backbone.get_layer(keras_layer_name).moving_variance, + keras_layer.moving_variance, hf_weight_key=f"{hf_weight_prefix}.running_var", ) # Stem - port_conv2d("input_conv", "conv_stem") - port_batch_normalization("input_batch_norm", "bn1") + port_conv2d(backbone.get_layer("input_conv"), "conv_stem") + port_batch_normalization(backbone.get_layer("input_batch_norm"), "bn1") # DepthWise Block (block 0) hf_name = "blocks.0.0" keras_name = "block_0_0" - port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_dw") - port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1") + stem_block = backbone.get_layer(keras_name) - port_conv2d(f"{keras_name}_se_conv_reduce", f"{hf_name}.se.conv_reduce", True) - port_conv2d(f"{keras_name}_se_conv_expand", f"{hf_name}.se.conv_expand", True) + port_conv2d(stem_block.conv1, f"{hf_name}.conv_dw") + port_batch_normalization(stem_block.bn1, f"{hf_name}.bn1") - port_conv2d(f"{keras_name}_conv2", f"{hf_name}.conv_pw") - port_batch_normalization(f"{keras_name}_bn2", f"{hf_name}.bn2") + stem_se_block = stem_block.se_layer + port_conv2d(stem_se_block.conv_reduce, f"{hf_name}.se.conv_reduce", True) + port_conv2d(stem_se_block.conv_expand, f"{hf_name}.se.conv_expand", True) + + port_conv2d(stem_block.conv2, f"{hf_name}.conv_pw") + port_batch_normalization(stem_block.bn2, f"{hf_name}.bn2") # Stages num_stacks = len(backbone.stackwise_num_blocks) @@ -134,37 +137,36 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): hf_name = 
f"blocks.{block_idx+1}.{inverted_block}" # Inverted Residual Block - port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_pw") - port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1") - port_conv2d(f"{keras_name}_conv2", f"{hf_name}.conv_dw") - port_batch_normalization(f"{keras_name}_bn2", f"{hf_name}.bn2") + ir_block = backbone.get_layer(keras_name) + port_conv2d(ir_block.conv1, f"{hf_name}.conv_pw") + port_batch_normalization(ir_block.bn1, f"{hf_name}.bn1") + port_conv2d(ir_block.conv2, f"{hf_name}.conv_dw") + port_batch_normalization(ir_block.bn2, f"{hf_name}.bn2") if backbone.stackwise_se_ratio[block_idx][inverted_block]: + ir_se_block = ir_block.se port_conv2d( - f"{keras_name}_se_conv_reduce", + ir_se_block.conv_reduce, f"{hf_name}.se.conv_reduce", True, ) port_conv2d( - f"{keras_name}_se_conv_expand", + ir_se_block.conv_expand, f"{hf_name}.se.conv_expand", True, ) - port_conv2d(f"{keras_name}_conv3", f"{hf_name}.conv_pwl") - port_batch_normalization(f"{keras_name}_bn3", f"{hf_name}.bn3") + port_conv2d(ir_block.conv3, f"{hf_name}.conv_pwl") + port_batch_normalization(ir_block.bn3, f"{hf_name}.bn3") # ConvBnAct Block - port_conv2d(f"block_{num_stacks+1}_0_conv", f"blocks.{num_stacks+1}.0.conv") + cba_block_name = f"block_{num_stacks+1}_0" + cba_block = backbone.get_layer(cba_block_name) + port_conv2d(cba_block.conv, f"blocks.{num_stacks+1}.0.conv") port_batch_normalization( - f"block_{num_stacks+1}_0_bn", f"blocks.{num_stacks+1}.0.bn1" + cba_block.bn, f"blocks.{num_stacks+1}.0.bn1" ) - port_conv2d("output_conv", "conv_head") - # if version == "v2": - # port_batch_normalization("output_batch_norm", "bn2") - - def convert_head(task, loader, timm_config): prefix = "classifier." loader.port_weight( diff --git a/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py b/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py index 3a47f7d779..4992960f23 100644 --- a/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py +++ b/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py @@ -79,7 +79,6 @@ def validate_output(keras_model, timm_model): preprocessing_diff = np.mean(np.abs(keras_preprocessed - timm_preprocessed)) print("🔶 Preprocessing difference:", preprocessing_diff) - def main(_): preset = FLAGS.preset if os.path.exists(preset): @@ -95,6 +94,8 @@ def main(_): print("✅ Loaded KerasHub model.") keras_model = keras_hub.models.ImageClassifier.from_preset( "hf://" + timm_name, + include_conv=True, + flatten=True, ) keras_model.save_to_preset(f"./{preset}") From a5a0bb366d83d55871b8ab1788525e5004864b06 Mon Sep 17 00:00:00 2001 From: Piseth Ky Date: Fri, 17 Jan 2025 14:38:10 -0800 Subject: [PATCH 18/21] matched mobilenetv3 inference, working now --- keras_hub/src/models/image_classifier.py | 24 ------- .../src/models/mobilenet/conv_bn_act_block.py | 5 +- .../models/mobilenet/depthwise_conv_block.py | 21 +++--- .../mobilenet/inverted_residual_block.py | 5 +- .../models/mobilenet/mobilenet_backbone.py | 16 +++-- .../mobilenet/mobilenet_image_classifier.py | 66 +++++++++++++++++++ keras_hub/src/models/mobilenet/util.py | 2 +- keras_hub/src/utils/timm/convert_mobilenet.py | 32 +++++++-- .../convert_mobilenet_checkpoints.py | 3 +- 9 files changed, 122 insertions(+), 52 deletions(-) diff --git a/keras_hub/src/models/image_classifier.py b/keras_hub/src/models/image_classifier.py index 878a087ed4..e75e390899 100644 --- a/keras_hub/src/models/image_classifier.py +++ b/keras_hub/src/models/image_classifier.py @@ -97,8 +97,6 @@ def __init__( activation=None, dropout=0.0, 
head_dtype=None, - include_conv=False, - flatten=False, **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy @@ -129,20 +127,6 @@ def __init__( dtype=head_dtype, name="output_dropout", ) - - if include_conv: - self.output_conv = keras.layers.Conv2D( - filters=1024, - kernel_size=(1, 1), - strides=(1, 1), - use_bias=True, - padding='valid', - activation="hardswish", - ) - - if flatten: - self.flatten = keras.layers.Flatten() - self.output_dense = keras.layers.Dense( num_classes, activation=activation, @@ -155,10 +139,6 @@ def __init__( x = self.backbone(inputs) x = self.pooler(x) x = self.output_dropout(x) - if include_conv: - x = self.output_conv(x) - if flatten: - x = self.flatten(x) outputs = self.output_dense(x) super().__init__( inputs=inputs, @@ -171,8 +151,6 @@ def __init__( self.activation = activation self.pooling = pooling self.dropout = dropout - self.include_conv = include_conv - self.flatten = flatten def get_config(self): # Backbone serialized in `super` @@ -183,8 +161,6 @@ def get_config(self): "pooling": self.pooling, "activation": self.activation, "dropout": self.dropout, - "include_conv": self.include_conv, - "flatten": self.flatten, } ) return config diff --git a/keras_hub/src/models/mobilenet/conv_bn_act_block.py b/keras_hub/src/models/mobilenet/conv_bn_act_block.py index 030e92f7ec..d5e5b95be7 100644 --- a/keras_hub/src/models/mobilenet/conv_bn_act_block.py +++ b/keras_hub/src/models/mobilenet/conv_bn_act_block.py @@ -1,6 +1,5 @@ import keras - BN_EPSILON = 1e-5 BN_MOMENTUM = 0.9 BN_AXIS = 3 @@ -9,8 +8,8 @@ class ConvBnActBlock(keras.layers.Layer): def __init__( self, - filter, - activation, + filter, + activation, name=None, **kwargs, ): diff --git a/keras_hub/src/models/mobilenet/depthwise_conv_block.py b/keras_hub/src/models/mobilenet/depthwise_conv_block.py index feccdd1ffa..7658fadd32 100644 --- a/keras_hub/src/models/mobilenet/depthwise_conv_block.py +++ b/keras_hub/src/models/mobilenet/depthwise_conv_block.py @@ -1,9 +1,10 @@ import keras -from keras_hub.src.models.mobilenet.squeeze_and_excite_2d import SqueezeAndExcite2D +from keras_hub.src.models.mobilenet.squeeze_and_excite_2d import ( + SqueezeAndExcite2D, +) from keras_hub.src.models.mobilenet.util import adjust_channels - BN_EPSILON = 1e-5 BN_MOMENTUM = 0.9 BN_AXIS = 3 @@ -36,10 +37,10 @@ class DepthwiseConvBlock(keras.layers.Layer): def __init__( self, infilters, - filters, - kernel_size=3, - stride=2, - se=None, + filters, + kernel_size=3, + stride=2, + se=None, name=None, **kwargs, ): @@ -121,10 +122,10 @@ def call(self, inputs): def get_config(self): config = { "infilters": self.infilters, - "filters": self.filters, - "kernel_size": self.kernel_size, - "stride": self.stride, - "se": self.se, + "filters": self.filters, + "kernel_size": self.kernel_size, + "stride": self.stride, + "se": self.se, "name": self.name, } diff --git a/keras_hub/src/models/mobilenet/inverted_residual_block.py b/keras_hub/src/models/mobilenet/inverted_residual_block.py index c95733744b..0145d68b48 100644 --- a/keras_hub/src/models/mobilenet/inverted_residual_block.py +++ b/keras_hub/src/models/mobilenet/inverted_residual_block.py @@ -1,9 +1,10 @@ import keras -from keras_hub.src.models.mobilenet.squeeze_and_excite_2d import SqueezeAndExcite2D +from keras_hub.src.models.mobilenet.squeeze_and_excite_2d import ( + SqueezeAndExcite2D, +) from keras_hub.src.models.mobilenet.util import adjust_channels - BN_EPSILON = 1e-5 BN_MOMENTUM = 0.9 BN_AXIS = 3 diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone.py 
b/keras_hub/src/models/mobilenet/mobilenet_backbone.py index 8ccd8d04d0..36a9d7ec12 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone.py @@ -2,12 +2,15 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.backbone import Backbone -from keras_hub.src.models.mobilenet.depthwise_conv_block import DepthwiseConvBlock -from keras_hub.src.models.mobilenet.inverted_residual_block import InvertedResidualBlock from keras_hub.src.models.mobilenet.conv_bn_act_block import ConvBnActBlock +from keras_hub.src.models.mobilenet.depthwise_conv_block import ( + DepthwiseConvBlock, +) +from keras_hub.src.models.mobilenet.inverted_residual_block import ( + InvertedResidualBlock, +) from keras_hub.src.models.mobilenet.util import adjust_channels - BN_EPSILON = 1e-5 BN_MOMENTUM = 0.9 @@ -140,7 +143,7 @@ def __init__( input_num_filters = adjust_channels(input_num_filters) x = keras.layers.ZeroPadding2D( - padding=(1,1), + padding=(1, 1), name="input_pad", )(x) x = keras.layers.Conv2D( @@ -160,7 +163,10 @@ def __init__( x = keras.layers.Activation(input_activation)(x) x = DepthwiseConvBlock( - input_num_filters, depthwise_filters, se=squeeze_and_excite, name="block_0" + input_num_filters, + depthwise_filters, + se=squeeze_and_excite, + name="block_0", )(x) for block in range(len(stackwise_num_blocks)): diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py index e9cc0fc153..750bbc1245 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py @@ -1,3 +1,5 @@ +import keras + from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone @@ -10,3 +12,67 @@ class MobileNetImageClassifier(ImageClassifier): backbone_cls = MobileNetBackbone preprocessor_cls = MobileNetImageClassifierPreprocessor + + def __init__( + self, + backbone, + num_classes, + preprocessor=None, + head_dtype=None, + **kwargs, + ): + super().__init__( + backbone, + num_classes, + preprocessor=preprocessor, + head_dtype=head_dtype, + **kwargs, + ) + + head_dtype = head_dtype or backbone.dtype_policy + data_format = getattr(backbone, "data_format", None) + + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.pooler = keras.layers.GlobalAveragePooling2D( + data_format, keepdims=True, dtype=head_dtype, name="pooler" + ) + + self.output_conv = keras.layers.Conv2D( + filters=1024, + kernel_size=(1, 1), + strides=(1, 1), + use_bias=True, + padding="valid", + activation="hard_silu", + ) + + self.flatten = keras.layers.Flatten() + + self.output_dense = keras.layers.Dense( + num_classes, + dtype=head_dtype, + name="predictions", + ) + + # === Config === + self.num_classes = num_classes + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + } + ) + return config + + def call(self, inputs): + x = self.backbone(inputs) + x = self.pooler(x) + x = self.output_conv(x) + x = self.flatten(x) + x = self.output_dense(x) + return x diff --git a/keras_hub/src/models/mobilenet/util.py b/keras_hub/src/models/mobilenet/util.py index b17efc4b87..59896c209a 100644 --- a/keras_hub/src/models/mobilenet/util.py +++ 
b/keras_hub/src/models/mobilenet/util.py @@ -20,4 +20,4 @@ def adjust_channels(x, divisor=8, min_value=None): # make sure that round down does not go down by more than 10%. if new_x < 0.9 * x: new_x += divisor - return new_x \ No newline at end of file + return new_x diff --git a/keras_hub/src/utils/timm/convert_mobilenet.py b/keras_hub/src/utils/timm/convert_mobilenet.py index 09545c6c42..950e71c022 100644 --- a/keras_hub/src/utils/timm/convert_mobilenet.py +++ b/keras_hub/src/utils/timm/convert_mobilenet.py @@ -77,7 +77,7 @@ def port_conv2d(keras_layer, hf_weight_prefix, port_bias=False): hf_weight_key=f"{hf_weight_prefix}.weight", hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), ) - + if port_bias: print(f"porting bias {hf_weight_prefix} -> {keras_layer}") loader.port_weight( @@ -85,7 +85,6 @@ def port_conv2d(keras_layer, hf_weight_prefix, port_bias=False): hf_weight_key=f"{hf_weight_prefix}.bias", ) - def port_batch_normalization(keras_layer, hf_weight_prefix): print(f"porting weights {hf_weight_prefix} -> {keras_layer}") loader.port_weight( @@ -163,11 +162,34 @@ def port_batch_normalization(keras_layer, hf_weight_prefix): cba_block_name = f"block_{num_stacks+1}_0" cba_block = backbone.get_layer(cba_block_name) port_conv2d(cba_block.conv, f"blocks.{num_stacks+1}.0.conv") - port_batch_normalization( - cba_block.bn, f"blocks.{num_stacks+1}.0.bn1" - ) + port_batch_normalization(cba_block.bn, f"blocks.{num_stacks+1}.0.bn1") + def convert_head(task, loader, timm_config): + def port_conv2d(keras_layer, hf_weight_prefix, port_bias=False): + print(f"porting weights {hf_weight_prefix} -> {keras_layer}") + loader.port_weight( + keras_layer.kernel, + hf_weight_key=f"{hf_weight_prefix}.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + + if port_bias: + print(f"porting bias {hf_weight_prefix} -> {keras_layer}") + loader.port_weight( + keras_layer.bias, + hf_weight_key=f"{hf_weight_prefix}.bias", + ) + + data_format = getattr(task.backbone, "data_format", None) + if not data_format or data_format == "channels_last": + conv_head_input_shape = (None, 1, 1, task.backbone.output_shape[-1]) + else: + conv_head_input_shape = (None, task.backbone.output_shape[1], 1, 1) + task.output_conv.build(input_shape=conv_head_input_shape) + task.output_dense.build(input_shape=(None, 1024)) + + port_conv2d(task.output_conv, "conv_head", True) prefix = "classifier." 
loader.port_weight( task.output_dense.kernel, diff --git a/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py b/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py index 4992960f23..3a47f7d779 100644 --- a/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py +++ b/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py @@ -79,6 +79,7 @@ def validate_output(keras_model, timm_model): preprocessing_diff = np.mean(np.abs(keras_preprocessed - timm_preprocessed)) print("🔶 Preprocessing difference:", preprocessing_diff) + def main(_): preset = FLAGS.preset if os.path.exists(preset): @@ -94,8 +95,6 @@ def main(_): print("✅ Loaded KerasHub model.") keras_model = keras_hub.models.ImageClassifier.from_preset( "hf://" + timm_name, - include_conv=True, - flatten=True, ) keras_model.save_to_preset(f"./{preset}") From cc72ba427241b9ae7f493cf1c0866aab88ea2c88 Mon Sep 17 00:00:00 2001 From: Piseth Ky Date: Fri, 17 Jan 2025 15:28:46 -0800 Subject: [PATCH 19/21] format pass --- keras_hub/src/models/mobilenet/mobilenet_backbone.py | 4 ++-- .../src/models/mobilenet/mobilenet_backbone_test.py | 1 - keras_hub/src/utils/timm/convert_mobilenet.py | 10 +++++----- keras_hub/src/utils/timm/preset_loader.py | 2 +- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone.py b/keras_hub/src/models/mobilenet/mobilenet_backbone.py index 243cc74537..d36d911b3e 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone.py @@ -190,13 +190,13 @@ def __init__( se_ratio=stackwise_se_ratio[block][inverted_block], activation=stackwise_activation[block][inverted_block], padding=stackwise_padding[block][inverted_block], - name=f"block_{block+1}_{inverted_block}", + name=f"block_{block + 1}_{inverted_block}", )(x) x = ConvBnActBlock( filter=adjust_channels(last_layer_filter), activation="hard_swish", - name=f"block_{len(stackwise_num_blocks)+1}_0", + name=f"block_{len(stackwise_num_blocks) + 1}_0", )(x) super().__init__(inputs=image_input, outputs=x, **kwargs) diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py index 3d909c9221..c0033d52b0 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py @@ -7,7 +7,6 @@ class MobileNetBackboneTest(TestCase): def setUp(self): - self.init_kwargs = { "stackwise_expansion": [ [40, 56], diff --git a/keras_hub/src/utils/timm/convert_mobilenet.py b/keras_hub/src/utils/timm/convert_mobilenet.py index 950e71c022..1962da69a5 100644 --- a/keras_hub/src/utils/timm/convert_mobilenet.py +++ b/keras_hub/src/utils/timm/convert_mobilenet.py @@ -132,8 +132,8 @@ def port_batch_normalization(keras_layer, hf_weight_prefix): num_stacks = len(backbone.stackwise_num_blocks) for block_idx in range(num_stacks): for inverted_block in range(backbone.stackwise_num_blocks[block_idx]): - keras_name = f"block_{block_idx+1}_{inverted_block}" - hf_name = f"blocks.{block_idx+1}.{inverted_block}" + keras_name = f"block_{block_idx + 1}_{inverted_block}" + hf_name = f"blocks.{block_idx + 1}.{inverted_block}" # Inverted Residual Block ir_block = backbone.get_layer(keras_name) @@ -159,10 +159,10 @@ def port_batch_normalization(keras_layer, hf_weight_prefix): port_batch_normalization(ir_block.bn3, f"{hf_name}.bn3") # ConvBnAct Block - cba_block_name = f"block_{num_stacks+1}_0" + cba_block_name = f"block_{num_stacks + 1}_0" 
cba_block = backbone.get_layer(cba_block_name) - port_conv2d(cba_block.conv, f"blocks.{num_stacks+1}.0.conv") - port_batch_normalization(cba_block.bn, f"blocks.{num_stacks+1}.0.bn1") + port_conv2d(cba_block.conv, f"blocks.{num_stacks + 1}.0.conv") + port_batch_normalization(cba_block.bn, f"blocks.{num_stacks + 1}.0.bn1") def convert_head(task, loader, timm_config): diff --git a/keras_hub/src/utils/timm/preset_loader.py b/keras_hub/src/utils/timm/preset_loader.py index 15a8a6f135..0a6955ebda 100644 --- a/keras_hub/src/utils/timm/preset_loader.py +++ b/keras_hub/src/utils/timm/preset_loader.py @@ -4,8 +4,8 @@ from keras_hub.src.utils.preset_utils import PresetLoader from keras_hub.src.utils.preset_utils import jax_memory_cleanup from keras_hub.src.utils.timm import convert_densenet -from keras_hub.src.utils.timm import convert_mobilenet from keras_hub.src.utils.timm import convert_efficientnet +from keras_hub.src.utils.timm import convert_mobilenet from keras_hub.src.utils.timm import convert_resnet from keras_hub.src.utils.timm import convert_vgg from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader From c4214d5f82d600dea7a0156eec271f8c0424ef85 Mon Sep 17 00:00:00 2001 From: Piseth Ky Date: Fri, 17 Jan 2025 15:46:52 -0800 Subject: [PATCH 20/21] actual format pass --- keras_hub/src/models/mobilenet/depthwise_conv_block.py | 3 ++- .../src/models/mobilenet/mobilenet_image_classifier.py | 6 +++--- keras_hub/src/models/mobilenet/mobilenet_presets.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/keras_hub/src/models/mobilenet/depthwise_conv_block.py b/keras_hub/src/models/mobilenet/depthwise_conv_block.py index 7658fadd32..1bc8191a09 100644 --- a/keras_hub/src/models/mobilenet/depthwise_conv_block.py +++ b/keras_hub/src/models/mobilenet/depthwise_conv_block.py @@ -29,7 +29,8 @@ class DepthwiseConvBlock(keras.layers.Layer): Input shape: 4D tensor with shape: `(batch, rows, cols, channels)` in "channels_last" - 4D tensor with shape: `(batch, channels, rows, cols)` in "channels_first" + 4D tensor with shape: `(batch, channels, rows, cols)` in + "channels_first" Returns: Output tensor of block. """ diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py index 750bbc1245..050bb2de30 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py @@ -1,11 +1,11 @@ import keras +from mobilenet.mobilenet_image_classifier_preprocessor import ( + MobileNetImageClassifierPreprocessor, +) from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone -from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( - MobileNetImageClassifierPreprocessor, -) @keras_hub_export("keras_hub.models.MobileNetImageClassifier") diff --git a/keras_hub/src/models/mobilenet/mobilenet_presets.py b/keras_hub/src/models/mobilenet/mobilenet_presets.py index 172e7fdbf6..3f7c81ec61 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_presets.py +++ b/keras_hub/src/models/mobilenet/mobilenet_presets.py @@ -4,8 +4,8 @@ "mobilenetv3_small_050": { "metadata": { "description": ( - "Small MObilenet V3 model pre-trained on the ImageNet 1k dataset " - "at a 224x224 resolution." + "Small MObilenet V3 model pre-trained on the ImageNet 1k " + "dataset at a 224x224 resolution." 
), "official_name": "MobileNet", "path": "mobilenet3", From 3c8cbd58803ea582258070ff88a161cdb8ed28f5 Mon Sep 17 00:00:00 2001 From: Piseth Ky Date: Fri, 17 Jan 2025 23:00:23 -0800 Subject: [PATCH 21/21] fix import --- .../src/models/mobilenet/mobilenet_image_classifier.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py index 050bb2de30..750bbc1245 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py @@ -1,11 +1,11 @@ import keras -from mobilenet.mobilenet_image_classifier_preprocessor import ( - MobileNetImageClassifierPreprocessor, -) from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone +from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( + MobileNetImageClassifierPreprocessor, +) @keras_hub_export("keras_hub.models.MobileNetImageClassifier")