
Merge pull request #26 from paganpasta/dev
Dev
paganpasta authored Aug 7, 2022
2 parents 0f07b3d + 36913d4 commit 062c7e9
Showing 10 changed files with 417 additions and 3 deletions.
32 changes: 32 additions & 0 deletions docs/api/models/classification/shufflenetv2.md
@@ -0,0 +1,32 @@
# ShuffleNet-V2

---

::: eqxvision.models.ShuffleNetV2
selection:
members:
- __init__
- __call__

---


::: eqxvision.models.shufflenet_v2_x0_5


---


::: eqxvision.models.shufflenet_v2_x1_0


---


::: eqxvision.models.shufflenet_v2_x1_5


---


::: eqxvision.models.shufflenet_v2_x2_0
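
The entries above document the `ShuffleNetV2` class plus the four width variants added in this PR. As a quick orientation, a minimal construction sketch (hypothetical usage; only the constructors listed above are assumed):

```python
import jax.random as jrandom
from eqxvision.models import shufflenet_v2_x0_5

# Build the 0.5x-width variant; `key` seeds the parameter initialisation.
net = shufflenet_v2_x0_5(num_classes=10, key=jrandom.PRNGKey(0))
```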
2 changes: 1 addition & 1 deletion eqxvision/__init__.py
@@ -1,4 +1,4 @@
r"""Root package info."""
__version__ = "0.1.3"
__version__ = "0.1.4"

from . import layers, models, utils
7 changes: 7 additions & 0 deletions eqxvision/models/__init__.py
@@ -25,6 +25,13 @@
wide_resnet50_2,
wide_resnet101_2,
)
from .classification.shufflenetv2 import (
shufflenet_v2_x0_5,
shufflenet_v2_x1_0,
shufflenet_v2_x1_5,
shufflenet_v2_x2_0,
ShuffleNetV2,
)
from .classification.squeezenet import SqueezeNet, squeezenet1_0, squeezenet1_1
from .classification.vgg import (
VGG,
1 change: 1 addition & 0 deletions eqxvision/models/classification/__init__.py
@@ -5,6 +5,7 @@
mobilenetv2,
mobilenetv3,
resnet,
shufflenetv2,
squeezenet,
vgg,
vit,
4 changes: 2 additions & 2 deletions eqxvision/models/classification/googlenet.py
@@ -38,7 +38,7 @@ class GoogLeNet(eqx.Module):
def __init__(
self,
num_classes: int = 1000,
aux_logits: bool = True,
aux_logits: bool = False,
blocks: Optional[List[Callable[..., eqx.Module]]] = None,
dropout: float = 0.2,
dropout_aux: float = 0.7,
@@ -50,7 +50,7 @@ def __init__(
- `num_classes`: Number of classes in the classification task.
Also controls the final output shape `(num_classes,)`. Defaults to `1000`
- `aux_logits`: If `True`, two auxiliary branches are added to the network. Defaults to `True`
- `aux_logits`: If `True`, two auxiliary branches are added to the network. Defaults to `False`
- `blocks`: Blocks for constructing the network
- `dropout`: Dropout applied on the `main` branch. Defaults to `0.2`
- `dropout_aux`: Dropout applied on the `aux` branches. Defaults to `0.7`
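
The substantive change in this file is the `aux_logits` default flipping from `True` to `False`, with the docstring updated to match. Callers who relied on getting the two auxiliary branches by default now need to opt in; a small sketch (assuming `GoogLeNet` is re-exported from `eqxvision.models` like the other classifiers and takes the usual keyword-only `key`):

```python
import jax.random as jrandom
from eqxvision.models import GoogLeNet

# The auxiliary classifier branches are now opt-in rather than the default.
net = GoogLeNet(aux_logits=True, key=jrandom.PRNGKey(0))
```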
291 changes: 291 additions & 0 deletions eqxvision/models/classification/shufflenetv2.py
@@ -0,0 +1,291 @@
from typing import Any, Callable, List, Optional

import equinox as eqx
import equinox.experimental as eqxex
import equinox.nn as nn
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
from equinox.custom_types import Array


def channel_shuffle(x: Array, groups: int) -> Array:
num_channels, height, width = x.shape
channels_per_group = num_channels // groups
x = jnp.reshape(x, (groups, channels_per_group, height, width))
x = jnp.transpose(x, axes=(1, 0, 2, 3))
x = jnp.reshape(x, (-1, height, width))
return x
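# Worked example (illustration only, not part of the committed file): with
# 4 channels and groups=2, channel_shuffle interleaves the two halves, so
# channel order (0, 1, 2, 3) becomes (0, 2, 1, 3).
#
#     x = jnp.arange(4 * 2 * 2).reshape(4, 2, 2)
#     channel_shuffle(x, 2)[1] == x[2]   # all True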


class InvertedResidual(eqx.Module):
stride: int
branch1: nn.Sequential
branch2: nn.Sequential

def __init__(
self,
inp: int,
oup: int,
stride: int,
*,
key: "jax.random.PRNGKey" = None,
) -> None:
super().__init__()

keys = jrandom.split(key, 5)

if not (1 <= stride <= 3):
raise ValueError("illegal stride value")

branch_features = oup // 2
assert (stride != 1) or (inp == branch_features << 1)

self.stride = stride
if stride > 1:
self.branch1 = nn.Sequential(
[
self.depthwise_conv(
inp,
inp,
kernel_size=3,
stride=self.stride,
padding=1,
key=keys[0],
),
eqxex.BatchNorm(inp, axis_name="batch"),
nn.Conv2d(
inp,
branch_features,
kernel_size=1,
stride=1,
padding=0,
use_bias=False,
key=keys[1],
),
eqxex.BatchNorm(branch_features, axis_name="batch"),
nn.Lambda(jnn.relu),
]
)
else:
self.branch1 = nn.Sequential([nn.Identity()])

self.branch2 = nn.Sequential(
[
nn.Conv2d(
inp if (self.stride > 1) else branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
use_bias=False,
key=keys[2],
),
eqxex.BatchNorm(branch_features, axis_name="batch"),
nn.Lambda(jnn.relu),
self.depthwise_conv(
branch_features,
branch_features,
kernel_size=3,
stride=self.stride,
padding=1,
key=keys[3],
),
eqxex.BatchNorm(branch_features, axis_name="batch"),
nn.Conv2d(
branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
use_bias=False,
key=keys[4],
),
eqxex.BatchNorm(branch_features, axis_name="batch"),
nn.Lambda(jnn.relu),
]
)

@staticmethod
def depthwise_conv(
i: int,
o: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
bias: bool = False,
key=None,
) -> nn.Conv2d:
return nn.Conv2d(
i, o, kernel_size, stride, padding, use_bias=bias, groups=i, key=key
)

def __call__(self, x, *, key: "jax.random.PRNGKey") -> Array:
if self.stride == 1:
x1, x2 = jnp.split(x, 2, axis=0)
out = jnp.concatenate((x1, self.branch2(x2)), axis=0)
else:
out = jnp.concatenate((self.branch1(x), self.branch2(x)), axis=0)

out = channel_shuffle(out, 2)
return out


class ShuffleNetV2(eqx.Module):
"""A simple port of `torchvision.models.shufflenetv2`"""

conv1: nn.Sequential
maxpool: nn.MaxPool2d
stage2: nn.Sequential
stage3: nn.Sequential
stage4: nn.Sequential
conv5: nn.Sequential
pool: nn.AdaptiveAvgPool2d
fc: nn.Linear

def __init__(
self,
stages_repeats: List[int],
stages_out_channels: List[int],
num_classes: int = 1000,
inverted_residual: Callable[..., eqx.Module] = InvertedResidual,
*,
key: Optional["jax.random.PRNGKey"] = None,
) -> None:
"""**Arguments:**
- `stages_repeats`: Number of times a block is repeated for each stage
- `stages_out_channels`: Number of output channels for `conv1`, each of the three stages, and `conv5`
- `num_classes`: Number of classes in the classification task.
Also controls the final output shape `(num_classes,)`. Defaults to `1000`
- `inverted_residual`: Block used to construct each stage. Defaults to `InvertedResidual`
- key: A `jax.random.PRNGKey` used to provide randomness for parameter
initialisation. (Keyword only argument.)
"""
super().__init__()
if key is None:
key = jrandom.PRNGKey(0)
keys = jrandom.split(key, 2)

if len(stages_repeats) != 3:
raise ValueError("expected stages_repeats as list of 3 positive ints")
if len(stages_out_channels) != 5:
raise ValueError("expected stages_out_channels as list of 5 positive ints")

input_channels = 3
output_channels = stages_out_channels[0]
self.conv1 = nn.Sequential(
[
nn.Conv2d(
input_channels,
output_channels,
3,
2,
1,
use_bias=False,
key=keys[0],
),
eqxex.BatchNorm(output_channels, axis_name="batch"),
nn.Lambda(jnn.relu),
]
)
input_channels = output_channels

self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

# Static annotations for mypy
self.stage2: nn.Sequential
self.stage3: nn.Sequential
self.stage4: nn.Sequential
stage_names = [f"stage{i}" for i in [2, 3, 4]]
for name, repeats, output_channels in zip(
stage_names, stages_repeats, stages_out_channels[1:]
):
keys = jrandom.split(keys[1], 2)
seq = [inverted_residual(input_channels, output_channels, 2, key=keys[0])]
for i in range(repeats - 1):
keys = jrandom.split(keys[1], 2)
seq.append(
inverted_residual(output_channels, output_channels, 1, key=keys[0])
)
setattr(self, name, nn.Sequential(seq))
input_channels = output_channels

keys = jrandom.split(keys[1], 2)
output_channels = stages_out_channels[-1]
self.conv5 = nn.Sequential(
[
nn.Conv2d(
input_channels,
output_channels,
1,
1,
0,
use_bias=False,
key=keys[0],
),
eqxex.BatchNorm(output_channels, axis_name="batch"),
nn.Lambda(jnn.relu),
]
)
self.pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(output_channels, num_classes, key=keys[1])

def __call__(self, x, *, key: Optional["jax.random.PRNGKey"] = None) -> Array:
"""**Arguments:**
- `x`: The input `JAX` array
- `key`: Required parameter. Utilised by a few layers such as `Dropout` or `DropPath`.
"""
keys = jrandom.split(key, 5)
x = self.conv1(x, key=keys[0])
x = self.maxpool(x)
x = self.stage2(x, key=keys[1])
x = self.stage3(x, key=keys[2])
x = self.stage4(x, key=keys[3])
x = self.conv5(x, key=keys[4])
x = jnp.ravel(self.pool(x))
x = self.fc(x)
return x


def _shufflenetv2(*args: Any, **kwargs: Any) -> ShuffleNetV2:
model = ShuffleNetV2(*args, **kwargs)
return model


def shufflenet_v2_x0_5(**kwargs: Any) -> ShuffleNetV2:
"""
Constructs a ShuffleNetV2 with 0.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_.
"""
return _shufflenetv2([4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)


def shufflenet_v2_x1_0(**kwargs: Any) -> ShuffleNetV2:
"""
Constructs a ShuffleNetV2 with 1.0x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_.
"""
return _shufflenetv2([4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)


def shufflenet_v2_x1_5(**kwargs: Any) -> ShuffleNetV2:
"""
Constructs a ShuffleNetV2 with 1.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_.
"""
return _shufflenetv2([4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)


def shufflenet_v2_x2_0(**kwargs: Any) -> ShuffleNetV2:
"""
Constructs a ShuffleNetV2 with 2.0x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_.
"""
return _shufflenetv2([4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
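
Since every stage uses `equinox.experimental.BatchNorm` with `axis_name="batch"`, a forward pass has to run under `jax.vmap` over that named axis. A minimal inference sketch (input shapes are illustrative assumptions; `eqx.filter_jit` is omitted for brevity):

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom
from eqxvision.models import shufflenet_v2_x1_0

key = jrandom.PRNGKey(0)
net = shufflenet_v2_x1_0(num_classes=1000, key=key)

# Dummy batch of four 224x224 RGB images in CHW layout, plus one key per sample.
images = jnp.zeros((4, 3, 224, 224))
keys = jrandom.split(key, images.shape[0])

# BatchNorm is tied to the "batch" axis name, so vmap over it explicitly.
logits = jax.vmap(net, axis_name="batch")(images, key=keys)  # shape (4, 1000)
```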
4 changes: 4 additions & 0 deletions eqxvision/utils.py
@@ -33,6 +33,10 @@
"resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
"resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
"shufflenetv2_x1.5": None,
"shufflenetv2_x2.0": None,
"squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
"squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
"vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
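
The two new `None` entries mark variants for which torchvision publishes no pretrained weights, so any loading helper has to special-case them before fetching. A hedged sketch of how such a URL table is typically consumed (the helper name is hypothetical; only `torch.hub.load_state_dict_from_url` is an existing API):

```python
import torch

def fetch_torch_weights(urls: dict, name: str):
    """Return the torchvision state dict for `name`, or fail clearly."""
    url = urls.get(name)
    if url is None:
        raise ValueError(f"No pretrained weights are published for {name!r}.")
    # torch.hub caches the download locally and returns a dict of tensors.
    return torch.hub.load_state_dict_from_url(url, progress=True)
```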
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -92,6 +92,7 @@ nav:
- 'api/models/classification/mobilenetv2.md'
- 'api/models/classification/mobilenetv3.md'
- 'api/models/classification/resnets.md'
- 'api/models/classification/shufflenetv2.md'
- 'api/models/classification/squeeze.md'
- 'api/models/classification/vit.md'
- 'api/models/classification/vgg.md'