From a1e63d99ef0847f38805fbef02846d3e9d399175 Mon Sep 17 00:00:00 2001 From: paganpasta Date: Thu, 4 Aug 2022 22:10:50 +0100 Subject: [PATCH 1/7] added release action --- .github/workflows/run_release.yml | 32 +++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/workflows/run_release.yml diff --git a/.github/workflows/run_release.yml b/.github/workflows/run_release.yml new file mode 100644 index 0000000..a231064 --- /dev/null +++ b/.github/workflows/run_release.yml @@ -0,0 +1,32 @@ +name: Upload Python Package + +on: + push: + branches: + - main + +permissions: + contents: read + +jobs: + deploy: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: '3.8' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install build + - name: Build package + run: python -m build + - name: Publish package + uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} From 554e4e499aa196336f1ad6db9c60dbb648229477 Mon Sep 17 00:00:00 2001 From: Aditya Singh Date: Thu, 4 Aug 2022 22:14:25 +0100 Subject: [PATCH 2/7] Create python-publish.yml --- .github/workflows/python-publish.yml | 32 ++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/workflows/python-publish.yml diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml new file mode 100644 index 0000000..a231064 --- /dev/null +++ b/.github/workflows/python-publish.yml @@ -0,0 +1,32 @@ +name: Upload Python Package + +on: + push: + branches: + - main + +permissions: + contents: read + +jobs: + deploy: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: '3.8' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install build + - name: Build package + run: python -m build + - name: Publish package + uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} From 3dcbdf673713e367e8d3aac08dd091ed2d622df9 Mon Sep 17 00:00:00 2001 From: Aditya Singh Date: Thu, 4 Aug 2022 22:29:00 +0100 Subject: [PATCH 3/7] Update README.md --- README.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 5547d2c..a44fc77 100644 --- a/README.md +++ b/README.md @@ -13,9 +13,10 @@ pip install eqxvision *requires:* `python>=3.7` ## Usage -???+ Example - Importing and doing a forward pass is as simple as - ```python + +Picking a model and doing a forward pass is as simple as ... + +```python import jax import jax.random as jr import equinox as eqx @@ -31,14 +32,14 @@ pip install eqxvision images = jr.uniform(jr.PRNGKey(0), shape=(1,3,224,224)) output = forward(net, images, jr.PRNGKey(0)) - ``` +``` ## What's New? - `[Experimental]`Now supports loading PyTorch weights from `torchvision` for models **without** BatchNorm !!! note Due to slight differences in the implementation of underlying operations, - the output can differ for pretrained versions of the network. + differences in the output values can be expected from `torchvision` models. ## Tips - Better to use `@equinox.jit_filter` instead of `@jax.jit` @@ -59,4 +60,4 @@ Please make sure to update tests as appropriate. - [Torchvision](https://pytorch.org/vision/stable/index.html) ## License -[MIT](https://choosealicense.com/licenses/mit/) \ No newline at end of file +[MIT](https://choosealicense.com/licenses/mit/) From 7539a9b5a24ac4d950d7d4bd182231e952fc1c46 Mon Sep 17 00:00:00 2001 From: Aditya Singh Date: Thu, 4 Aug 2022 22:40:52 +0100 Subject: [PATCH 4/7] Extended contributions --- README.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a44fc77..1700387 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,15 @@ Picking a model and doing a forward pass is as simple as ... ## Contributing Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change. -Please make sure to update tests as appropriate. +### Development Process +If you plan to modify the code or documentation, please follow the steps below: + +1. Fork the repository and create your branch from `dev`. +2. If you have modified the code (new feature or bug-fix), please add unit tests. +3. If you have changed APIs, update the documentation. Make sure the documentation builds. `mkdocs serve` +4. Ensure the test suite passes. `pytest` +5. Make sure your code passes the formatting checks. Automatically checked with a `pre-commit` hook. + ## Acknowledgements - [Equinox](/~https://github.com/patrick-kidger/equinox) From 703f493ea3a996591c7311f7ad7df773251a3404 Mon Sep 17 00:00:00 2001 From: paganpasta Date: Sat, 6 Aug 2022 01:52:58 +0100 Subject: [PATCH 5/7] densenet and trimming tests --- docs/api/layers/activation.md | 8 - docs/api/models/classification/densenet.md | 33 +++ eqxvision/layers/__init__.py | 1 - eqxvision/layers/activations.py | 31 --- eqxvision/models/__init__.py | 7 + eqxvision/models/classification/__init__.py | 2 +- eqxvision/models/classification/alexnet.py | 17 +- eqxvision/models/classification/densenet.py | 272 +++++++++++++++++++ eqxvision/models/classification/googlenet.py | 2 +- eqxvision/models/classification/vgg.py | 7 +- eqxvision/utils.py | 12 +- mkdocs.yml | 2 +- tests/test_layers.py | 26 -- tests/test_models/test_densenets.py | 25 ++ tests/test_models/test_resnet.py | 12 +- tests/test_models/test_squeezenet.py | 5 +- tests/test_models/test_vgg.py | 6 - 17 files changed, 361 insertions(+), 107 deletions(-) delete mode 100644 docs/api/layers/activation.md create mode 100644 docs/api/models/classification/densenet.md delete mode 100644 eqxvision/layers/activations.py create mode 100644 eqxvision/models/classification/densenet.py create mode 100644 tests/test_models/test_densenets.py diff --git a/docs/api/layers/activation.md b/docs/api/layers/activation.md deleted file mode 100644 index 00fabad..0000000 --- a/docs/api/layers/activation.md +++ /dev/null @@ -1,8 +0,0 @@ -# Activation - -::: eqxvision.layers.Activation - selection: - members: - - __init__ - - __call__ - diff --git a/docs/api/models/classification/densenet.md b/docs/api/models/classification/densenet.md new file mode 100644 index 0000000..977f34c --- /dev/null +++ b/docs/api/models/classification/densenet.md @@ -0,0 +1,33 @@ +# DenseNets + +--- + +::: eqxvision.models.DenseNet + selection: + members: + - __init__ + - __call__ + +--- + + +::: eqxvision.models.densenet121 + + +--- + + +::: eqxvision.models.densenet161 + + +--- + + +::: eqxvision.models.densenet169 + + +--- + + +::: eqxvision.models.densenet201 + diff --git a/eqxvision/layers/__init__.py b/eqxvision/layers/__init__.py index 8e61f90..8f46c27 100644 --- a/eqxvision/layers/__init__.py +++ b/eqxvision/layers/__init__.py @@ -1,4 +1,3 @@ -from .activations import Activation from .drop_path import DropPath from .mlps import MlpProjection from .patch_embed import PatchEmbed diff --git a/eqxvision/layers/activations.py b/eqxvision/layers/activations.py deleted file mode 100644 index f876243..0000000 --- a/eqxvision/layers/activations.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Any, Callable, Optional - -import equinox as eqx -import jax - - -class Activation(eqx.Module): - """Wrapper around Callables to make them eqx.Modules. Useful for `nn.Sequential`.""" - - activation: Callable - - def __init__( - self, - activation: Callable, - ): - """**Arguments:** - - - `activation`: The `callable` to be wrapped in `equinox.Module`. - """ - self.activation = activation - - def __call__(self, x: Any, *, key: Optional["jax.random.PRNGKey"] = None) -> Any: - """**Arguments:** - - - `x`: A JAX `ndarray`. - - `key`: Ignored. - **Returns:** - - The output of the activation function. - """ - return self.activation(x) diff --git a/eqxvision/models/__init__.py b/eqxvision/models/__init__.py index df392a0..59c1e4a 100644 --- a/eqxvision/models/__init__.py +++ b/eqxvision/models/__init__.py @@ -1,4 +1,11 @@ from .classification.alexnet import AlexNet, alexnet +from .classification.densenet import ( + DenseNet, + densenet121, + densenet161, + densenet169, + densenet201, +) from .classification.googlenet import GoogLeNet, googlenet from .classification.resnet import ( ResNet, diff --git a/eqxvision/models/classification/__init__.py b/eqxvision/models/classification/__init__.py index 856af1f..d389180 100644 --- a/eqxvision/models/classification/__init__.py +++ b/eqxvision/models/classification/__init__.py @@ -1 +1 @@ -from . import alexnet, googlenet, resnet, squeezenet, vgg, vit +from . import alexnet, densenet, googlenet, resnet, squeezenet, vgg, vit diff --git a/eqxvision/models/classification/alexnet.py b/eqxvision/models/classification/alexnet.py index 184380a..d9e8015 100644 --- a/eqxvision/models/classification/alexnet.py +++ b/eqxvision/models/classification/alexnet.py @@ -8,12 +8,11 @@ import jax.random as jrandom from equinox.custom_types import Array -from ...layers import Activation from ...utils import load_torch_weights, MODEL_URLS class AlexNet(eqx.Module): - """A simple port of torchvision.models.alexnet""" + """A simple port of `torchvision.models.alexnet`""" features: eqx.Module avgpool: eqx.Module @@ -43,17 +42,17 @@ def __init__( self.features = nn.Sequential( [ nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2, key=keys[0]), - Activation(jnn.relu), + nn.Lambda(jnn.relu), nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(64, 192, kernel_size=5, padding=2, key=keys[1]), - Activation(jnn.relu), + nn.Lambda(jnn.relu), nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(192, 384, kernel_size=3, padding=1, key=keys[2]), - Activation(jnn.relu), + nn.Lambda(jnn.relu), nn.Conv2d(384, 256, kernel_size=3, padding=1, key=keys[3]), - Activation(jnn.relu), + nn.Lambda(jnn.relu), nn.Conv2d(256, 256, kernel_size=3, padding=1, key=keys[4]), - Activation(jnn.relu), + nn.Lambda(jnn.relu), nn.MaxPool2d(kernel_size=3, stride=2), ] ) @@ -62,10 +61,10 @@ def __init__( [ nn.Dropout(p=dropout), nn.Linear(256 * 6 * 6, 4096, key=keys[5]), - Activation(jnn.relu), + nn.Lambda(jnn.relu), nn.Dropout(p=dropout), nn.Linear(4096, 4096, key=keys[6]), - Activation(jnn.relu), + nn.Lambda(jnn.relu), nn.Linear(4096, num_classes, key=keys[7]), ] ) diff --git a/eqxvision/models/classification/densenet.py b/eqxvision/models/classification/densenet.py new file mode 100644 index 0000000..728ed8c --- /dev/null +++ b/eqxvision/models/classification/densenet.py @@ -0,0 +1,272 @@ +from typing import Any, Optional, Sequence, Tuple, Union + +import equinox as eqx +import equinox.experimental as eqxex +import equinox.nn as nn +import jax +import jax.nn as jnn +import jax.numpy as jnp +import jax.random as jrandom +from equinox.custom_types import Array + + +class _DenseLayer(eqx.Module): + norm1: eqxex.BatchNorm + relu: nn.Lambda + conv1: nn.Conv2d + norm2: eqxex.BatchNorm + conv2: nn.Conv2d + dropout: nn.Dropout + + def __init__( + self, + num_input_features: int, + growth_rate: int, + bn_size: int, + drop_rate: float, + key: "jax.random.PRNGKey", + ) -> None: + super().__init__() + keys = jrandom.split(key, 2) + self.norm1 = eqxex.BatchNorm(num_input_features, axis_name="batch") + self.relu = nn.Lambda(jnn.relu) + self.conv1 = nn.Conv2d( + num_input_features, + bn_size * growth_rate, + kernel_size=1, + stride=1, + use_bias=False, + key=keys[0], + ) + self.norm2 = eqxex.BatchNorm(bn_size * growth_rate, axis_name="batch") + self.conv2 = nn.Conv2d( + bn_size * growth_rate, + growth_rate, + kernel_size=3, + stride=1, + padding=1, + use_bias=False, + key=keys[1], + ) + self.dropout = nn.Dropout(p=float(drop_rate)) + + def __call__( + self, x: Union[Array, Sequence[Array]], *, key: "jax.random.PRNGKey" + ) -> Array: + if isinstance(x, Array): + prev_features = [x] + else: + prev_features = x + + concated_features = jnp.concatenate(prev_features, axis=0) + bottleneck_output = self.conv1(self.relu(self.norm1(concated_features))) + new_features = self.conv2(self.relu(self.norm2(bottleneck_output))) + new_features = self.dropout(new_features, key=key) + return new_features + + +class _DenseBlock(eqx.Module): + layers: Sequence[eqx.Module] + num_layers: int + + def __init__( + self, + num_layers: int, + num_input_features: int, + bn_size: int, + growth_rate: int, + drop_rate: float, + key: "jax.random.PRNGKey" = None, + ) -> None: + super().__init__() + self.layers = [] + self.num_layers = num_layers + keys = jrandom.split(key, num_layers) + for i in range(num_layers): + layer = _DenseLayer( + num_input_features + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + drop_rate=drop_rate, + key=keys[i], + ) + self.layers.append(layer) + + def __call__(self, x: Array, *, key: "jax.random.PRNGKey") -> Array: + features = [x] + keys = jrandom.split(key, self.num_layers) + for i in range(self.num_layers): + new_features = self.layers[i](features, key=keys[i]) + features.append(new_features) + return jnp.concatenate(features, 0) + + +class _Transition(eqx.Module): + layers: nn.Sequential + + def __init__( + self, + num_input_features: int, + num_output_features: int, + key: "jax.random.PRNGKey" = None, + ) -> None: + super().__init__() + self.layers = nn.Sequential( + [ + eqxex.BatchNorm(num_input_features, axis_name="batch"), + nn.Lambda(jnn.relu), + nn.Conv2d( + num_input_features, + num_output_features, + kernel_size=1, + stride=1, + use_bias=False, + key=key, + ), + nn.AvgPool2d(kernel_size=2, stride=2), + ] + ) + + def __call__(self, x: Array, *, key: "jax.random.PRNGKey") -> Array: + return self.layers(x, key=key) + + +class DenseNet(eqx.Module): + """A simple port of `torchvision.models.densenet`.""" + + features: nn.Sequential + classifier: nn.Linear + + def __init__( + self, + growth_rate: int = 32, + block_config: Tuple[int, int, int, int] = (6, 12, 24, 16), + num_init_features: int = 64, + bn_size: int = 4, + drop_rate: float = 0, + num_classes: int = 1000, + *, + key: Optional["jax.random.PRNGKey"] = None, + ) -> None: + """ + **Arguments:** + + - `growth_rate`: Number of filters to add in each layer (`k` in paper) + - `block_config`: Number of layers in each pooling block + - `num_init_features` - The number of filters to learn in the first convolution layer + - `bn_size`: Multiplicative factor for number of bottle neck layers + (i.e. bn_size * k features in the bottleneck layer) + - `drop_rate`: Dropout rate after each dense layer + - `num_classes`: Number of classes in the classification task. + Also controls the final output shape `(num_classes,)`. Defaults to `1000`. + """ + super().__init__() + if key is None: + key = jrandom.PRNGKey(0) + # First convolution + keys = jrandom.split(key, 2 * len(block_config) + 2) + self.features = nn.Sequential( + [ + nn.Conv2d( + 3, + num_init_features, + kernel_size=7, + stride=2, + padding=3, + use_bias=False, + key=keys[0], + ), + eqxex.BatchNorm(num_init_features, axis_name="batch"), + nn.Lambda(jnn.relu), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ] + ) + + # Each denseblock + num_features = num_init_features + for i, num_layers in enumerate(block_config): + keys = jrandom.split(keys[i * 2 + 1], 3) + block = _DenseBlock( + num_layers=num_layers, + num_input_features=num_features, + bn_size=bn_size, + growth_rate=growth_rate, + drop_rate=drop_rate, + key=keys[0], + ) + self.features.layers.append(block) + num_features = num_features + num_layers * growth_rate + if i != len(block_config) - 1: + trans = _Transition( + num_input_features=num_features, + num_output_features=num_features // 2, + key=keys[i * 2 + 2], + ) + self.features.layers.append(trans) + num_features = num_features // 2 + + # Final batch norm, relu and pooling + self.features.layers.extend( + [ + eqxex.BatchNorm(num_features, axis_name="batch"), + nn.Lambda(jnn.relu), + nn.AdaptiveAvgPool2d((1, 1)), + ] + ) + # Linear layer + self.classifier = nn.Linear(num_features, num_classes, key=keys[-1]) + + def __call__(self, x: Array, *, key: "jax.random.PRNGKey") -> Array: + """**Arguments:** + + - `x`: The input. Should be a JAX array with `3` channels. + - `key`: Required parameter. Utilised by few layers such as `Dropout` or `DropPath`. + """ + out = self.features(x, key=key) + out = jnp.ravel(out) + out = self.classifier(out) + return out + + +def _densenet( + growth_rate: int, + block_config: Tuple[int, int, int, int], + num_init_features: int, + **kwargs: Any, +) -> DenseNet: + model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) + return model + + +def densenet121(**kwargs: Any) -> DenseNet: + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" `_. + The required minimum input size of the model is 29x29. + + """ + return _densenet(32, (6, 12, 24, 16), 64, **kwargs) + + +def densenet161(**kwargs: Any) -> DenseNet: + r"""Densenet-161 model from + `"Densely Connected Convolutional Networks" `_. + The required minimum input size of the model is 29x29. + + """ + return _densenet(48, (6, 12, 36, 24), 96, **kwargs) + + +def densenet169(**kwargs: Any) -> DenseNet: + r"""Densenet-169 model from + `"Densely Connected Convolutional Networks" `_. + The required minimum input size of the model is 29x29. + """ + return _densenet(32, (6, 12, 32, 32), 64, **kwargs) + + +def densenet201(**kwargs: Any) -> DenseNet: + r"""Densenet-201 model from + `"Densely Connected Convolutional Networks" `_. + The required minimum input size of the model is 29x29. + """ + return _densenet(32, (6, 12, 48, 32), 64, **kwargs) diff --git a/eqxvision/models/classification/googlenet.py b/eqxvision/models/classification/googlenet.py index f6bcf41..2713b41 100644 --- a/eqxvision/models/classification/googlenet.py +++ b/eqxvision/models/classification/googlenet.py @@ -10,7 +10,7 @@ class GoogLeNet(eqx.Module): - """A simple port of torchvision.models.GoogLeNet""" + """A simple port of `torchvision.models.GoogLeNet`""" aux_logits: bool conv1: eqx.Module diff --git a/eqxvision/models/classification/vgg.py b/eqxvision/models/classification/vgg.py index 969a6f0..dd3f16d 100644 --- a/eqxvision/models/classification/vgg.py +++ b/eqxvision/models/classification/vgg.py @@ -9,7 +9,6 @@ import jax.random as jrandom from equinox.custom_types import Array -from ...layers import Activation from ...utils import load_torch_weights, MODEL_URLS @@ -100,7 +99,7 @@ def __init__( nn.Linear(512 * 7 * 7, 4096, key=keys[1]), nn.Dropout(p=dropout), nn.Linear(4096, 4096, key=keys[2]), - Activation(jnn.relu), + nn.Lambda(jnn.relu), nn.Dropout(p=dropout), nn.Linear(4096, num_classes, key=keys[3]), ] @@ -142,10 +141,10 @@ def _make_layers( layers += [ conv2d, eqex.BatchNorm(v, axis_name="batch"), - Activation(jnn.relu), + nn.Lambda(jnn.relu), ] else: - layers += [conv2d, Activation(jnn.relu)] + layers += [conv2d, nn.Lambda(jnn.relu)] in_channels = v count += 1 return nn.Sequential(layers) diff --git a/eqxvision/utils.py b/eqxvision/utils.py index 1388c90..7e858a2 100644 --- a/eqxvision/utils.py +++ b/eqxvision/utils.py @@ -18,6 +18,10 @@ _Url = NewType("_Url", str) MODEL_URLS = { "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", + "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", + "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", + "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", + "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth", "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth", "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", @@ -26,8 +30,8 @@ "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", - "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", + "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", + "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth", "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth", "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", @@ -36,8 +40,8 @@ "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", - "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", - "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", + "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", + "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", } diff --git a/mkdocs.yml b/mkdocs.yml index a287616..2ec41f5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -87,13 +87,13 @@ nav: - Full API: - Classification models: - 'api/models/classification/alexnet.md' + - 'api/models/classification/densenet.md' - 'api/models/classification/googlenet.md' - 'api/models/classification/resnets.md' - 'api/models/classification/squeeze.md' - 'api/models/classification/vit.md' - 'api/models/classification/vgg.md' - Vision Layers: - - 'api/layers/activation.md' - 'api/layers/drop_path.md' - 'api/layers/mlp.md' - 'api/layers/patch.md' diff --git a/tests/test_layers.py b/tests/test_layers.py index 782c317..e0bcc06 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -1,37 +1,11 @@ import equinox as eqx import jax import jax.nn as jnn -import jax.numpy as jnp import jax.random as jrandom import eqxvision.layers as layers -def test_activation(getkey): - c_counter = 0 - - @eqx.filter_jit - def forward(net, xs, keys): - nonlocal c_counter - c_counter += 1 - return jax.vmap(net)(xs, key=keys) - - chain = eqx.nn.Sequential( - [ - eqx.nn.Identity(), - layers.Activation(jnn.relu), - ] - ) - x = jnp.array([[-1, -2, -3], [1, 2, 3]]) - keys = jrandom.split(getkey(), 2) - output = forward(chain, x, keys) - assert output.shape == (2, 3) - assert (output >= 0).all() - - forward(chain, x, keys) - assert c_counter == 1 - - def test_patch_embed(getkey): @eqx.filter_jit def forward(net, xs): diff --git a/tests/test_models/test_densenets.py b/tests/test_models/test_densenets.py new file mode 100644 index 0000000..156d0de --- /dev/null +++ b/tests/test_models/test_densenets.py @@ -0,0 +1,25 @@ +import equinox as eqx +import jax +import pytest + +import eqxvision.models as models + + +model_list = [models.densenet121] + + +class TestDenseNet: + random_image = jax.random.uniform(key=jax.random.PRNGKey(0), shape=(1, 3, 224, 224)) + answer = (1, 1000) + + @pytest.mark.parametrize("model_func", model_list) + def test_densenets(self, model_func, getkey): + @eqx.filter_jit + def forward(net, x, key): + keys = jax.random.split(key, x.shape[0]) + ans = jax.vmap(net, axis_name="batch")(x, key=keys) + return ans + + model = model_func(num_classes=1000) + output = forward(model, self.random_image, getkey()) + assert output.shape == self.answer diff --git a/tests/test_models/test_resnet.py b/tests/test_models/test_resnet.py index 6a7620a..c81beb4 100644 --- a/tests/test_models/test_resnet.py +++ b/tests/test_models/test_resnet.py @@ -5,17 +5,7 @@ import eqxvision.models as models -model_list = [ - models.resnet18, - models.resnet34, - models.resnet50, - models.resnet101, - models.resnet152, - models.resnext50_32x4d, - models.resnext101_32x8d, - models.wide_resnet50_2, - models.wide_resnet101_2, -] +model_list = [models.resnet18] class TestResNet: diff --git a/tests/test_models/test_squeezenet.py b/tests/test_models/test_squeezenet.py index 62c6cd9..04331c3 100644 --- a/tests/test_models/test_squeezenet.py +++ b/tests/test_models/test_squeezenet.py @@ -6,10 +6,7 @@ import eqxvision.models as models -model_list = [ - models.squeezenet1_0, - models.squeezenet1_1, -] +model_list = [models.squeezenet1_0] class TestSqueezeNet: diff --git a/tests/test_models/test_vgg.py b/tests/test_models/test_vgg.py index cc4029c..4c91895 100644 --- a/tests/test_models/test_vgg.py +++ b/tests/test_models/test_vgg.py @@ -10,12 +10,6 @@ model_list = [ models.vgg11, models.vgg11_bn, - models.vgg13, - models.vgg13_bn, - models.vgg16, - models.vgg16_bn, - models.vgg19, - models.vgg19_bn, ] From c4e91ad7f51a216543a4ea73e2e2ed2812be5938 Mon Sep 17 00:00:00 2001 From: paganpasta Date: Sat, 6 Aug 2022 01:53:49 +0100 Subject: [PATCH 6/7] bumping version --- eqxvision/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eqxvision/__init__.py b/eqxvision/__init__.py index 6f5a824..023fd13 100644 --- a/eqxvision/__init__.py +++ b/eqxvision/__init__.py @@ -1,4 +1,4 @@ r"""Root package info.""" -__version__ = "0.1.1" +__version__ = "0.1.2" from . import layers, models, utils From 2d9989eb915f71547ef169f9df3a2d5dffcfa800 Mon Sep 17 00:00:00 2001 From: paganpasta Date: Sat, 6 Aug 2022 01:55:22 +0100 Subject: [PATCH 7/7] removed run release workflow --- .github/workflows/run_release.yml | 32 ------------------------------- 1 file changed, 32 deletions(-) delete mode 100644 .github/workflows/run_release.yml diff --git a/.github/workflows/run_release.yml b/.github/workflows/run_release.yml deleted file mode 100644 index a231064..0000000 --- a/.github/workflows/run_release.yml +++ /dev/null @@ -1,32 +0,0 @@ -name: Upload Python Package - -on: - push: - branches: - - main - -permissions: - contents: read - -jobs: - deploy: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v3 - with: - python-version: '3.8' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install build - - name: Build package - run: python -m build - - name: Publish package - uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 - with: - user: __token__ - password: ${{ secrets.PYPI_API_TOKEN }}