Skip to content

Commit

Permalink
Merge pull request #16 from paganpasta/models/squeezenet
Browse files Browse the repository at this point in the history
added squeezenet
  • Loading branch information
paganpasta authored Aug 4, 2022
2 parents fcb9174 + e92a89b commit 002964f
Show file tree
Hide file tree
Showing 14 changed files with 305 additions and 59 deletions.
52 changes: 26 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,40 @@ pip install eqxvision
*requires:* `python>=3.7`

## Usage

```python title="forward.py"
import jax
import jax.random as jr
import equinox as eqx
from eqxvision.models import alexnet

@eqx.filter_jit
def forward(net, images, key):
keys = jax.random.split(key, images.shape[0])
output = jax.vmap(net, axis_name=('batch'))(images, key=keys)
...
???+ Example
Importing and doing a forward pass is as simple as
```python
import jax
import jax.random as jr
import equinox as eqx
from eqxvision.models import alexnet

net = alexnet(num_classes=1000)

images = jr.uniform(jr.PRNGKey(0), shape=(1,3,224,224))
output = forward(net, images, jr.PRNGKey(0))
```

```python title="set_inference.py"
import equinox as eqx
from eqxvision.models import alexnet

net = alexnet(num_classes=1000)
net = eqx.tree_inference(net, True)
```
@eqx.filter_jit
def forward(net, images, key):
keys = jax.random.split(key, images.shape[0])
output = jax.vmap(net, axis_name=('batch'))(images, key=keys)
...

net = alexnet(num_classes=1000)

images = jr.uniform(jr.PRNGKey(0), shape=(1,3,224,224))
output = forward(net, images, jr.PRNGKey(0))
```

## What's New?
- `[Experimental]`Now supports loading PyTorch weights from `torchvision` for models **without** BatchNorm

!!! note
Due to slight differences in the implementation of underlying operations,
the output can differ for pretrained versions of the network.

## Tips
- Better to use `@equinox.jit_filter` instead of `@jax.jit`
- Advisable to use `jax.vmap` with `axis_name='batch'` for all models
- Advisable to use `jax.{v,p}map` with `axis_name='batch'` for all models
- Don't forget to switch to `inference` mode for evaluations
- Wrap with `eqx.filter(net, eqx.is_array)` for `Optax` initialisation.



## Contributing
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
Expand Down
20 changes: 20 additions & 0 deletions docs/api/models/classification/squeeze.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# SqueezeNets

---

::: eqxvision.models.SqueezeNet
selection:
members:
- __init__
- __call__

---


::: eqxvision.models.squeezenet1_0


---


::: eqxvision.models.squeezenet1_1
50 changes: 24 additions & 26 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,38 @@ pip install eqxvision
*requires:* `python>=3.7`

## Usage

```python title="forward.py"
import jax
import jax.random as jr
import equinox as eqx
from eqxvision.models import alexnet

@eqx.filter_jit
def forward(net, images, key):
keys = jax.random.split(key, images.shape[0])
output = jax.vmap(net, axis_name=('batch'))(images, key=keys)
...
???+ Example
Importing and doing a forward pass is as simple as
```python
import jax
import jax.random as jr
import equinox as eqx
from eqxvision.models import alexnet

net = alexnet(num_classes=1000)

images = jr.uniform(jr.PRNGKey(0), shape=(1,3,224,224))
output = forward(net, images, jr.PRNGKey(0))
```

```python title="set_inference.py"
import equinox as eqx
from eqxvision.models import alexnet

net = alexnet(num_classes=1000)
net = eqx.tree_inference(net, True)
```
@eqx.filter_jit
def forward(net, images, key):
keys = jax.random.split(key, images.shape[0])
output = jax.vmap(net, axis_name=('batch'))(images, key=keys)
...

net = alexnet(num_classes=1000)

images = jr.uniform(jr.PRNGKey(0), shape=(1,3,224,224))
output = forward(net, images, jr.PRNGKey(0))
```

## What's New?
- `[Experimental]`Now supports loading PyTorch weights from `torchvision` for models **without** BatchNorm

!!! note
Due to slight differences in the implementation of underlying operations,
the output can differ for pretrained versions of the network.

## Tips
- Better to use `@equinox.jit_filter` instead of `@jax.jit`
- Advisable to use `jax.vmap` with `axis_name='batch'` for all models
- Advisable to use `jax.{v,p}map` with `axis_name='batch'` for all models
- Don't forget to switch to `inference` mode for evaluations
- Wrap with `eqx.filter(net, eqx.is_array)` for `Optax` initialisation.



Expand Down
2 changes: 1 addition & 1 deletion eqxvision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
r"""Root package info."""
__version__ = "0.1.0"
__version__ = "0.1.1"

from . import layers, models, utils
1 change: 1 addition & 0 deletions eqxvision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
wide_resnet50_2,
wide_resnet101_2,
)
from .classification.squeezenet import SqueezeNet, squeezenet1_0, squeezenet1_1
from .classification.vgg import (
VGG,
vgg11,
Expand Down
2 changes: 1 addition & 1 deletion eqxvision/models/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import alexnet, googlenet, resnet, vgg, vit
from . import alexnet, googlenet, resnet, squeezenet, vgg, vit
174 changes: 174 additions & 0 deletions eqxvision/models/classification/squeezenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from typing import Any, Optional

import equinox as eqx
import equinox.nn as nn
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
from equinox.custom_types import Array

from ...utils import load_torch_weights, MODEL_URLS


class Fire(eqx.Module):
inplanes: int
squeeze: nn.Conv2d
squeeze_activation: nn.Lambda
expand1x1: nn.Conv2d
expand1x1_activation: nn.Lambda
expand3x3: nn.Conv2d
expand3x3_activation: nn.Lambda

def __init__(
self,
inplanes: int,
squeeze_planes: int,
expand1x1_planes: int,
expand3x3_planes: int,
key=None,
) -> None:
super().__init__()
keys = jrandom.split(key, 3)
self.inplanes = inplanes
self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1, key=keys[0])
self.squeeze_activation = nn.Lambda(jnn.relu)
self.expand1x1 = nn.Conv2d(
squeeze_planes, expand1x1_planes, kernel_size=1, key=keys[1]
)
self.expand1x1_activation = nn.Lambda(jnn.relu)
self.expand3x3 = nn.Conv2d(
squeeze_planes, expand3x3_planes, kernel_size=3, padding=1, key=keys[2]
)
self.expand3x3_activation = nn.Lambda(jnn.relu)

def __call__(self, x: Array, *, key: "jax.random.PRNGKey") -> Array:
x = self.squeeze_activation(self.squeeze(x))
return jnp.concatenate(
[
self.expand1x1_activation(self.expand1x1(x)),
self.expand3x3_activation(self.expand3x3(x)),
],
axis=0,
)


class SqueezeNet(eqx.Module):
"""A simple port of `torchvision.models.squeezenet`"""

features: nn.Sequential
classifier: nn.Sequential

def __init__(
self,
version: str = "1_0",
num_classes: int = 1000,
dropout: float = 0.5,
*,
key: Optional["jax.random.PRNGKey"] = None
) -> None:
"""**Arguments:**
- `version`: Specifies the version of the network. Defaults to `1_0`.
- `num_classes`: Number of classes in the classification task.
Also controls the final output shape `(num_classes,)`. Defaults to `1000`.
- `dropout`: The probability parameter for `equinox.nn.Dropout`.
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
initialisation. (Keyword only argument.)
"""
super().__init__()
if key is None:
key = jrandom.PRNGKey(0)
keys = jrandom.split(key, 10)
if version == "1_0":
self.features = nn.Sequential(
[
nn.Conv2d(3, 96, kernel_size=7, stride=2, key=keys[0]),
nn.Lambda(jnn.relu),
nn.MaxPool2d(kernel_size=3, stride=2),
Fire(96, 16, 64, 64, key=keys[1]),
Fire(128, 16, 64, 64, key=keys[2]),
Fire(128, 32, 128, 128, key=keys[3]),
nn.MaxPool2d(kernel_size=3, stride=2),
Fire(256, 32, 128, 128, key=keys[4]),
Fire(256, 48, 192, 192, key=keys[5]),
Fire(384, 48, 192, 192, key=keys[6]),
Fire(384, 64, 256, 256, key=keys[7]),
nn.MaxPool2d(kernel_size=3, stride=2),
Fire(512, 64, 256, 256, key=keys[8]),
]
)
elif version == "1_1":
self.features = nn.Sequential(
[
nn.Conv2d(3, 64, kernel_size=3, stride=2, key=keys[0]),
nn.Lambda(jnn.relu),
nn.MaxPool2d(kernel_size=3, stride=2),
Fire(64, 16, 64, 64, key=keys[1]),
Fire(128, 16, 64, 64, key=keys[2]),
nn.MaxPool2d(kernel_size=3, stride=2),
Fire(128, 32, 128, 128, key=keys[3]),
Fire(256, 32, 128, 128, key=keys[4]),
nn.MaxPool2d(kernel_size=3, stride=2),
Fire(256, 48, 192, 192, key=keys[5]),
Fire(384, 48, 192, 192, key=keys[6]),
Fire(384, 64, 256, 256, key=keys[7]),
Fire(512, 64, 256, 256, key=keys[8]),
]
)

# Final convolution is initialized differently from the rest
final_conv = nn.Conv2d(512, num_classes, kernel_size=1, key=keys[9])
self.classifier = nn.Sequential(
[
nn.Dropout(p=dropout),
final_conv,
nn.Lambda(jnn.relu),
nn.AdaptiveAvgPool2d((1, 1)),
]
)

def __call__(self, x: Array, *, key: "jax.random.PRNGKey") -> Array:
"""**Arguments:**
- `x`: The input. Should be a JAX array with `3` channels.
- `key`: Required parameter. Utilised by few layers such as `Dropout` or `DropPath`.
"""
x = self.features(x)
x = self.classifier(x, key=key)
return jnp.ravel(x)


def _squeezenet(version: str, pretrained: bool, **kwargs: Any) -> SqueezeNet:
model = SqueezeNet(version, **kwargs)
if pretrained:
arch = "squeezenet" + version
model = load_torch_weights(model, url=MODEL_URLS[arch])
return model


def squeezenet1_0(pretrained: bool = False, **kwargs: Any) -> SqueezeNet:
r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
accuracy with 50x fewer parameters and <0.5MB model size"
<https://arxiv.org/abs/1602.07360>`_ paper.
The required minimum input size of the model is 21x21.
**Arguments:**
- `pretrained`: If `True`, the weights are loaded from `PyTorch` saved checkpoint.
"""
return _squeezenet("1_0", pretrained, **kwargs)


def squeezenet1_1(pretrained: bool = False, **kwargs: Any) -> SqueezeNet:
r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
</~https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
than SqueezeNet 1.0, without sacrificing accuracy.
The required minimum input size of the model is 17x17.
**Arguments:**
- `pretrained`: If `True`, the weights are loaded from `PyTorch` saved checkpoint.
"""
return _squeezenet("1_1", pretrained, **kwargs)
6 changes: 5 additions & 1 deletion eqxvision/models/classification/vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ def vgg11(pretrained=False, **kwargs: Any) -> VGG:
r"""VGG 11-layer model (configuration "A") from
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32.
**Arguments:**
**Arguments:**
- `pretrained`: If `True`, the weights are loaded from `PyTorch` saved checkpoint.
"""
Expand All @@ -182,6 +183,7 @@ def vgg13(pretrained=False, **kwargs: Any) -> VGG:
r"""VGG 13-layer model (configuration "B")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32.
**Arguments:**
- `pretrained`: If `True`, the weights are loaded from `PyTorch` saved checkpoint.
Expand All @@ -204,6 +206,7 @@ def vgg16(pretrained=False, **kwargs: Any) -> VGG:
r"""VGG 16-layer model (configuration "D")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32.
**Arguments:**
- `pretrained`: If `True`, the weights are loaded from `PyTorch` saved checkpoint.
Expand All @@ -226,6 +229,7 @@ def vgg19(pretrained=False, **kwargs: Any) -> VGG:
r"""VGG 19-layer model (configuration "E")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32.
**Arguments:**
- `pretrained`: If `True`, the weights are loaded from `PyTorch` saved checkpoint.
Expand Down
2 changes: 2 additions & 0 deletions eqxvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
"squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
"squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
}


Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ nav:
- 'api/models/classification/alexnet.md'
- 'api/models/classification/googlenet.md'
- 'api/models/classification/resnets.md'
- 'api/models/classification/squeeze.md'
- 'api/models/classification/vit.md'
- 'api/models/classification/vgg.md'
- Vision Layers:
Expand Down
Loading

0 comments on commit 002964f

Please sign in to comment.