Add docstrings for kimm.blocks.* (#50)
* Add docstrings for `kimm.blocks.*`

* Fix argument
james77777778 authored Jun 2, 2024
1 parent ac667c4 commit 927370b
Showing 10 changed files with 74 additions and 24 deletions.
21 changes: 13 additions & 8 deletions kimm/_src/blocks/conv2d.py
@@ -2,6 +2,7 @@
 
 from keras import backend
 from keras import layers
+from keras.src.utils.argument_validation import standardize_tuple
 
 from kimm._src.kimm_export import kimm_export
 
@@ -10,29 +11,33 @@
 def apply_conv2d_block(
     inputs,
     filters: typing.Optional[int] = None,
-    kernel_size: typing.Optional[
-        typing.Union[int, typing.Sequence[int]]
-    ] = None,
+    kernel_size: typing.Union[int, typing.Sequence[int]] = 1,
     strides: int = 1,
     groups: int = 1,
     activation: typing.Optional[str] = None,
     use_depthwise: bool = False,
-    add_skip: bool = False,
+    has_skip: bool = False,
     bn_momentum: float = 0.9,
     bn_epsilon: float = 1e-5,
     padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name="conv2d_block",
 ):
+    """(ZeroPadding) + Conv2D/DepthwiseConv2D + BN + (Activation)."""
     if kernel_size is None:
         raise ValueError(
             f"kernel_size must be passed. Received: kernel_size={kernel_size}"
         )
-    if isinstance(kernel_size, int):
-        kernel_size = [kernel_size, kernel_size]
+    kernel_size = standardize_tuple(kernel_size, 2, "kernel_size")
 
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
-    input_channels = inputs.shape[channels_axis]
-    has_skip = add_skip and strides == 1 and input_channels == filters
+    input_filters = inputs.shape[channels_axis]
+    if has_skip and (strides != 1 or input_filters != filters):
+        raise ValueError(
+            "If `has_skip=True`, strides must be 1 and `filters` must be the "
+            "same as input_filters. "
+            f"Received: strides={strides}, filters={filters}, "
+            f"input_filters={input_filters}"
+        )
     x = inputs
 
     if padding is None:
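Note: the net effect of the `add_skip` → `has_skip` change is that skip eligibility is no longer inferred silently inside the block; the caller asserts it and the block validates it. A minimal sketch of the new calling convention (the input tensor and shapes are illustrative, not part of the commit):

import keras

from kimm._src.blocks.conv2d import apply_conv2d_block

x = keras.random.normal([1, 32, 32, 64])  # NHWC input with 64 channels

# Eligible residual: strides == 1 and filters equals the input channel count.
y = apply_conv2d_block(x, filters=64, kernel_size=3, strides=1, has_skip=True)

# Under the old add_skip=True this call would silently drop the residual;
# it now raises ValueError because filters != input channels.
# y = apply_conv2d_block(x, filters=128, kernel_size=3, strides=1, has_skip=True)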
17 changes: 12 additions & 5 deletions kimm/_src/blocks/depthwise_separation.py
@@ -11,7 +11,7 @@
 @kimm_export(parent_path=["kimm.blocks"])
 def apply_depthwise_separation_block(
     inputs,
-    output_channels: int,
+    filters: int,
     depthwise_kernel_size: int = 3,
     pointwise_kernel_size: int = 1,
     strides: int = 1,
@@ -21,14 +21,21 @@ def apply_depthwise_separation_block(
     se_gate_activation: typing.Optional[str] = "sigmoid",
     se_make_divisible_number: typing.Optional[int] = None,
     pw_activation: typing.Optional[str] = None,
-    skip: bool = True,
+    has_skip: bool = True,
     bn_epsilon: float = 1e-5,
     padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name: str = "depthwise_separation_block",
 ):
+    """Conv2D block + (SqueezeAndExcitation) + Conv2D."""
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
-    input_channels = inputs.shape[channels_axis]
-    has_skip = skip and (strides == 1 and input_channels == output_channels)
+    input_filters = inputs.shape[channels_axis]
+    if has_skip and (strides != 1 or input_filters != filters):
+        raise ValueError(
+            "If `has_skip=True`, strides must be 1 and `filters` must be the "
+            "same as input_filters. "
+            f"Received: strides={strides}, filters={filters}, "
+            f"input_filters={input_filters}"
+        )
 
     x = inputs
     x = apply_conv2d_block(
@@ -52,7 +59,7 @@
     )
     x = apply_conv2d_block(
         x,
-        output_channels,
+        filters,
         pointwise_kernel_size,
         1,
         activation=pw_activation,
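Note: `has_skip` defaults to True here, so callers that change the spatial or channel shape must now opt out explicitly rather than relying on the block to notice. A hedged example of a call that satisfies the new validation (shapes illustrative):

import keras

from kimm._src.blocks.depthwise_separation import apply_depthwise_separation_block

x = keras.random.normal([1, 56, 56, 32])
# strides=2 halves the spatial size, so a residual add is impossible and
# has_skip must be disabled to pass the new check.
y = apply_depthwise_separation_block(x, filters=64, strides=2, has_skip=False)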
7 changes: 4 additions & 3 deletions kimm/_src/blocks/inverted_residual.py
@@ -12,7 +12,7 @@
 @kimm_export(parent_path=["kimm.blocks"])
 def apply_inverted_residual_block(
     inputs,
-    output_channels: int,
+    filters: int,
     depthwise_kernel_size: int = 3,
     expansion_kernel_size: int = 1,
     pointwise_kernel_size: int = 1,
@@ -28,10 +28,11 @@ def apply_inverted_residual_block(
     padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name: str = "inverted_residual_block",
 ):
+    """Conv2D block + DepthwiseConv2D block + (SE) + Conv2D."""
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
     input_channels = inputs.shape[channels_axis]
     hidden_channels = make_divisible(input_channels * expansion_ratio)
-    has_skip = strides == 1 and input_channels == output_channels
+    has_skip = strides == 1 and input_channels == filters
 
     x = inputs
     # Point-wise expansion
@@ -70,7 +71,7 @@
     # Point-wise linear projection
     x = apply_conv2d_block(
         x,
-        output_channels,
+        filters,
         pointwise_kernel_size,
         1,
         activation=None,
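Note: unlike the two blocks above, this block still derives `has_skip` internally from strides and channel count, so only the `output_channels` → `filters` rename is visible to callers. A keyword call site changes accordingly (sketch, not from the commit):

import keras

from kimm._src.blocks.inverted_residual import apply_inverted_residual_block

x = keras.random.normal([1, 28, 28, 64])
# Before: apply_inverted_residual_block(x, output_channels=64)
y = apply_inverted_residual_block(x, filters=64)  # skip applied automatically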
1 change: 1 addition & 0 deletions kimm/_src/blocks/squeeze_and_excitation.py
@@ -17,6 +17,7 @@ def apply_se_block(
     se_input_channels: typing.Optional[int] = None,
     name: str = "se_block",
 ):
+    """Squeeze and Excitation."""
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
     input_channels = inputs.shape[channels_axis]
     if se_input_channels is None:
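Note: the new docstring names the standard Squeeze-and-Excitation pattern: a global pool to squeeze, a bottleneck to excite, and a channel-wise gate. A generic sketch of that pattern, not kimm's exact implementation:

from keras import layers

def se_sketch(x, channels, se_ratio=0.25):
    # Squeeze: global spatial average, kept 4D so 1x1 convs can follow.
    s = layers.GlobalAveragePooling2D(keepdims=True)(x)
    # Excite: bottleneck 1x1 conv + nonlinearity, then expand back.
    s = layers.Conv2D(int(channels * se_ratio), 1, activation="relu")(s)
    s = layers.Conv2D(channels, 1, activation="sigmoid")(s)
    # Gate: rescale the input channel-wise.
    return layers.Multiply()([x, s])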
2 changes: 2 additions & 0 deletions kimm/_src/blocks/transformer.py
@@ -19,6 +19,7 @@ def apply_mlp_block(
     data_format: typing.Optional[str] = None,
     name: str = "mlp_block",
 ):
+    """Dense/Conv2D + Activation + Dense/Conv2D."""
     if data_format is None:
         data_format = backend.image_data_format()
     dim_axis = -1 if data_format == "channels_last" else 1
@@ -56,6 +57,7 @@ def apply_transformer_block(
     activation: str = "gelu",
     name: str = "transformer_block",
 ):
+    """LN + Attention + LN + MLP block."""
     # data_format must be "channels_last"
     x = inputs
     residual_1 = x
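Note: the added docstrings describe a pre-norm transformer block: LayerNorm before each sublayer, residual add after. A generic sketch of that structure (the 4x MLP expansion and epsilon are conventional assumptions, not taken from the commit):

from keras import layers

def transformer_block_sketch(x, dim, num_heads):
    # LN + Attention + residual
    r = x
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.MultiHeadAttention(num_heads, dim // num_heads)(x, x)
    x = layers.Add()([x, r])
    # LN + MLP (Dense + activation + Dense) + residual
    r = x
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.Dense(dim * 4, activation="gelu")(x)
    x = layers.Dense(dim)(x)
    return layers.Add()([x, r])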
2 changes: 1 addition & 1 deletion kimm/_src/layers/reparameterizable_conv2d.py
@@ -19,7 +19,7 @@ def __init__(
         self,
         filters,
         kernel_size,
-        strides=(1, 1),
+        strides=1,
         padding=None,
         has_skip: bool = True,
         has_scale: bool = True,
16 changes: 14 additions & 2 deletions kimm/_src/models/efficientnet.py
@@ -245,16 +245,28 @@ def __init__(
                 "activation": activation,
             }
             if block_type == "ds":
+                has_skip = x.shape[channels_axis] == c and s == 1
                 x = apply_depthwise_separation_block(
-                    x, c, k, 1, s, se, se_activation=activation, **_kwargs
+                    x,
+                    c,
+                    k,
+                    1,
+                    s,
+                    se,
+                    se_activation=activation,
+                    has_skip=has_skip,
+                    **_kwargs,
                 )
             elif block_type == "ir":
                 se_c = x.shape[channels_axis]
                 x = apply_inverted_residual_block(
                     x, c, k, 1, 1, s, e, se, se_channels=se_c, **_kwargs
                 )
             elif block_type == "cn":
-                x = apply_conv2d_block(x, c, k, s, add_skip=True, **_kwargs)
+                has_skip = x.shape[channels_axis] == c and s == 1
+                x = apply_conv2d_block(
+                    x, c, k, s, has_skip=has_skip, **_kwargs
+                )
             elif block_type == "er":
                 x = apply_edge_residual_block(x, c, k, 1, s, e, **_kwargs)
             current_stride *= s
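Note: the model now computes the same eligibility test that apply_conv2d_block used to run internally, so "ds" and "cn" behavior is unchanged; the test simply fails loudly if it ever disagrees with has_skip. The rule, restated as a small helper (hypothetical name, mirroring the diff):

from keras import backend

def skip_eligible(x, filters, strides):
    # A residual add is legal only when both the spatial size (strides == 1)
    # and the channel count are preserved.
    channels_axis = (
        -1 if backend.image_data_format() == "channels_last" else -3
    )
    return x.shape[channels_axis] == filters and strides == 1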
6 changes: 3 additions & 3 deletions kimm/_src/models/hgnet.py
@@ -267,7 +267,7 @@ def apply_high_perf_gpu_block(
     hidden_channels,
     output_channels,
     kernel_size,
-    add_skip=False,
+    has_skip=False,
     use_light_block=False,
     use_learnable_affine=False,
     aggregation="ese",
@@ -329,7 +329,7 @@ def apply_high_perf_gpu_block(
         name=f"{name}_aggregation_0",
     )
     x = apply_ese_module(x, output_channels, name=f"{name}_aggregation_1")
-    if add_skip:
+    if has_skip:
         x = layers.Add()([x, inputs])
     return x
 
@@ -375,7 +375,7 @@ def apply_high_perf_gpu_stage(
         hidden_channels,
         output_channels,
         kernel_size,
-        add_skip=False if i == 0 else True,
+        has_skip=False if i == 0 else True,
         use_light_block=use_light_block,
         use_learnable_affine=use_learnable_affine,
         aggregation=aggregation,
15 changes: 14 additions & 1 deletion kimm/_src/models/mobilenet_v2.py
@@ -3,6 +3,7 @@
 import typing
 
 import keras
+from keras import backend
 
 from kimm._src.blocks.conv2d import apply_conv2d_block
 from kimm._src.blocks.depthwise_separation import (
@@ -55,6 +56,10 @@ def __init__(
         )
 
         self.set_properties(kwargs)
+        channels_axis = (
+            -1 if backend.image_data_format() == "channels_last" else -3
+        )
+
         inputs = self.determine_input_tensor(
             input_tensor,
             self._input_shape,
@@ -93,8 +98,16 @@ def __init__(
             s = s if current_layer_idx == 0 else 1
             name = f"blocks_{current_block_idx}_{current_layer_idx}"
             if block_type == "ds":
+                has_skip = x.shape[channels_axis] == c and s == 1
                 x = apply_depthwise_separation_block(
-                    x, c, k, 1, s, activation="relu6", name=name
+                    x,
+                    c,
+                    k,
+                    1,
+                    s,
+                    activation="relu6",
+                    has_skip=has_skip,
+                    name=name,
                 )
             elif block_type == "ir":
                 x = apply_inverted_residual_block(
11 changes: 10 additions & 1 deletion kimm/_src/models/mobilenet_v3.py
@@ -4,6 +4,7 @@
 import warnings
 
 import keras
+from keras import backend
 from keras import layers
 
 from kimm._src.blocks.conv2d import apply_conv2d_block
@@ -124,6 +125,10 @@ def __init__(
         padding = kwargs.pop("padding", None)
 
         self.set_properties(kwargs)
+        channels_axis = (
+            -1 if backend.image_data_format() == "channels_last" else -3
+        )
+
         inputs = self.determine_input_tensor(
             input_tensor,
             self._input_shape,
@@ -181,6 +186,10 @@ def __init__(
                 ),
             }
             if block_type in ("ds", "dsa"):
+                if block_type == "dsa":
+                    has_skip = False
+                else:
+                    has_skip = x.shape[channels_axis] == c and s == 1
                 x = apply_depthwise_separation_block(
                     x,
                     c,
@@ -193,7 +202,7 @@ def __init__(
                     se_gate_activation="hard_sigmoid",
                     se_make_divisible_number=8,
                     pw_activation=act if block_type == "dsa" else None,
-                    skip=False if block_type == "dsa" else True,
+                    has_skip=has_skip,
                     **_kwargs,
                 )
             elif block_type == "ir":
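Note: in the MobileNetV3 config, "dsa" marks a depthwise-separation block whose pointwise output is activated and which never takes a residual, so has_skip is forced to False; plain "ds" blocks fall back to the usual shape test. A condensed restatement of the branch added above (hypothetical helper name):

def v3_has_skip(block_type, x, c, s, channels_axis):
    # "dsa" blocks activate the pointwise projection and use no skip.
    if block_type == "dsa":
        return False
    # "ds" blocks skip only when shape is preserved.
    return x.shape[channels_axis] == c and s == 1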
