Merge pull request #20 from paganpasta/dev
Dev
paganpasta authored Aug 6, 2022
2 parents 002964f + caad298 commit b824b86
Showing 20 changed files with 410 additions and 115 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/python-publish.yml
@@ -0,0 +1,32 @@
name: Upload Python Package

on:
  push:
    branches:
      - main

permissions:
  contents: read

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python
      uses: actions/setup-python@v3
      with:
        python-version: '3.8'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
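The publish step above only runs on GitHub, where the `PYPI_API_TOKEN` secret is available, but the build half of the workflow can be reproduced locally. A minimal sketch, driving the same commands the workflow runs from the repository root:

```python
# Local dry run of the "Install dependencies" and "Build package" steps above.
import subprocess
import sys

subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "pip"], check=True)
subprocess.run([sys.executable, "-m", "pip", "install", "build"], check=True)
# Same command the workflow runs; the sdist and wheel land in ./dist.
subprocess.run([sys.executable, "-m", "build"], check=True)
```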
23 changes: 16 additions & 7 deletions README.md
@@ -13,9 +13,10 @@ pip install eqxvision
*requires:* `python>=3.7`

## Usage
-???+ Example
-Importing and doing a forward pass is as simple as
-```python
+
+Picking a model and doing a forward pass is as simple as ...
+
+```python
import jax
import jax.random as jr
import equinox as eqx
@@ -31,14 +32,14 @@ pip install eqxvision

images = jr.uniform(jr.PRNGKey(0), shape=(1,3,224,224))
output = forward(net, images, jr.PRNGKey(0))
-```
+```

## What's New?
- `[Experimental]` Now supports loading PyTorch weights from `torchvision` for models **without** BatchNorm

!!! note
Due to slight differences in the implementation of underlying operations,
-the output can differ for pretrained versions of the network.
+differences in the output values can be expected from `torchvision` models.

## Tips
- Better to use `@equinox.jit_filter` instead of `@jax.jit`
@@ -51,12 +52,20 @@ pip install eqxvision
## Contributing
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

-Please make sure to update tests as appropriate.
### Development Process
If you plan to modify the code or documentation, please follow the steps below:

1. Fork the repository and create your branch from `dev`.
2. If you have modified the code (new feature or bug-fix), please add unit tests.
3. If you have changed APIs, update the documentation and make sure it builds (`mkdocs serve`).
4. Ensure the test suite passes (`pytest`).
5. Make sure your code passes the formatting checks (enforced automatically by a `pre-commit` hook).


## Acknowledgements
- [Equinox](/~https://github.com/patrick-kidger/equinox)
- [Patrick Kidger](/~https://github.com/patrick-kidger)
- [Torchvision](https://pytorch.org/vision/stable/index.html)

## License
[MIT](https://choosealicense.com/licenses/mit/)
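For context, the collapsed middle of the Usage example above constructs a model and a batched forward function before the `forward(net, images, ...)` call shown in the second hunk. A minimal sketch of that pattern, assuming the `alexnet` constructor exported from `eqxvision.models`, a `num_classes` keyword argument, a `key` keyword on the model's `__call__`, and Equinox's `filter_jit`; the exact body in the README may differ:

```python
import equinox as eqx
import jax
import jax.random as jr

import eqxvision

# Assumed constructor argument; `alexnet` is exported from eqxvision.models.
net = eqxvision.models.alexnet(num_classes=1000)

@eqx.filter_jit  # jit-compiles while filtering out the non-array parts of `net`
def forward(model, images, key):
    # One dropout key per image in the batch; the model is assumed to accept
    # a `key` keyword argument, as Equinox modules with dropout typically do.
    keys = jr.split(key, images.shape[0])
    return jax.vmap(lambda x, k: model(x, key=k))(images, keys)

images = jr.uniform(jr.PRNGKey(0), shape=(1, 3, 224, 224))
output = forward(net, images, jr.PRNGKey(0))  # expected shape: (1, 1000)
```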
8 changes: 0 additions & 8 deletions docs/api/layers/activation.md

This file was deleted.

33 changes: 33 additions & 0 deletions docs/api/models/classification/densenet.md
@@ -0,0 +1,33 @@
# DenseNets

---

::: eqxvision.models.DenseNet
    selection:
        members:
            - __init__
            - __call__

---


::: eqxvision.models.densenet121


---


::: eqxvision.models.densenet161


---


::: eqxvision.models.densenet169


---


::: eqxvision.models.densenet201
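The page above documents the `DenseNet` class plus four constructors newly exported from `eqxvision.models`. A minimal sketch of inspecting them; calling the constructors with default arguments is an assumption, as their signatures are not part of this diff:

```python
import equinox as eqx
import jax

from eqxvision.models import densenet121, densenet169

def n_params(model):
    # Keep only the array leaves (weights, biases, etc.) and sum their sizes.
    leaves = jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_array))
    return sum(leaf.size for leaf in leaves)

# Assumed to work with default arguments, mirroring the other classification models.
print(n_params(densenet121()))
print(n_params(densenet169()))
```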

2 changes: 1 addition & 1 deletion eqxvision/__init__.py
@@ -1,4 +1,4 @@
r"""Root package info."""
__version__ = "0.1.1"
__version__ = "0.1.2"

from . import layers, models, utils
1 change: 0 additions & 1 deletion eqxvision/layers/__init__.py
@@ -1,4 +1,3 @@
-from .activations import Activation
from .drop_path import DropPath
from .mlps import MlpProjection
from .patch_embed import PatchEmbed
31 changes: 0 additions & 31 deletions eqxvision/layers/activations.py

This file was deleted.

7 changes: 7 additions & 0 deletions eqxvision/models/__init__.py
@@ -1,4 +1,11 @@
from .classification.alexnet import AlexNet, alexnet
from .classification.densenet import (
DenseNet,
densenet121,
densenet161,
densenet169,
densenet201,
)
from .classification.googlenet import GoogLeNet, googlenet
from .classification.resnet import (
ResNet,
2 changes: 1 addition & 1 deletion eqxvision/models/classification/__init__.py
@@ -1 +1 @@
-from . import alexnet, googlenet, resnet, squeezenet, vgg, vit
+from . import alexnet, densenet, googlenet, resnet, squeezenet, vgg, vit
17 changes: 8 additions & 9 deletions eqxvision/models/classification/alexnet.py
@@ -8,12 +8,11 @@
import jax.random as jrandom
from equinox.custom_types import Array

-from ...layers import Activation
from ...utils import load_torch_weights, MODEL_URLS


class AlexNet(eqx.Module):
"""A simple port of torchvision.models.alexnet"""
"""A simple port of `torchvision.models.alexnet`"""

features: eqx.Module
avgpool: eqx.Module
@@ -43,17 +42,17 @@ def __init__(
self.features = nn.Sequential(
[
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2, key=keys[0]),
-Activation(jnn.relu),
+nn.Lambda(jnn.relu),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(64, 192, kernel_size=5, padding=2, key=keys[1]),
-Activation(jnn.relu),
+nn.Lambda(jnn.relu),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(192, 384, kernel_size=3, padding=1, key=keys[2]),
-Activation(jnn.relu),
+nn.Lambda(jnn.relu),
nn.Conv2d(384, 256, kernel_size=3, padding=1, key=keys[3]),
-Activation(jnn.relu),
+nn.Lambda(jnn.relu),
nn.Conv2d(256, 256, kernel_size=3, padding=1, key=keys[4]),
-Activation(jnn.relu),
+nn.Lambda(jnn.relu),
nn.MaxPool2d(kernel_size=3, stride=2),
]
)
@@ -62,10 +61,10 @@ def __init__(
[
nn.Dropout(p=dropout),
nn.Linear(256 * 6 * 6, 4096, key=keys[5]),
-Activation(jnn.relu),
+nn.Lambda(jnn.relu),
nn.Dropout(p=dropout),
nn.Linear(4096, 4096, key=keys[6]),
-Activation(jnn.relu),
+nn.Lambda(jnn.relu),
nn.Linear(4096, num_classes, key=keys[7]),
]
)
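The custom `Activation` wrapper deleted in this commit is replaced throughout by `equinox.nn.Lambda`, which wraps a plain function as a module so it can sit inside `nn.Sequential`. A minimal standalone sketch of that pattern, unrelated to AlexNet's exact architecture:

```python
import equinox as eqx
import jax.nn as jnn
import jax.random as jr

key1, key2 = jr.split(jr.PRNGKey(0))

# nn.Lambda turns jax.nn.relu into an Equinox module, taking the place of the
# removed custom Activation layer inside a Sequential container.
block = eqx.nn.Sequential(
    [
        eqx.nn.Conv2d(3, 16, kernel_size=3, padding=1, key=key1),
        eqx.nn.Lambda(jnn.relu),
        eqx.nn.Conv2d(16, 16, kernel_size=3, padding=1, key=key2),
        eqx.nn.Lambda(jnn.relu),
    ]
)

x = jr.normal(jr.PRNGKey(1), (3, 32, 32))  # a single (unbatched) image
y = block(x)  # expected shape: (16, 32, 32)
```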
