Merge pull request #41 from paganpasta/dev

Dev
paganpasta authored Aug 14, 2022
2 parents 83e545e + 5c60f73 commit 0ba09cf
Showing 26 changed files with 884 additions and 24 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -43,7 +43,7 @@ Picking a model and doing a forward pass is as simple as ...

## Get Started!

Start with any one of these easy to follow [tutorials](https://eqxvision.readthedocs.io/en/latest/getting_started/Transfer_Learning.ipynb).
Start with any one of these easy to follow [tutorials](https://eqxvision.readthedocs.io/en/latest/getting_started/Transfer_Learning/).


## Tips
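The context line above notes that picking a model and doing a forward pass is "as simple as ...". A minimal sketch of that usage pattern, assuming randomly initialised weights and that batches are mapped with `jax.vmap` over an axis named `"batch"` (the constructor arguments and the axis name are assumptions, not taken from this diff):

```python
import jax
import jax.numpy as jnp
import eqxvision

key = jax.random.PRNGKey(0)
net = eqxvision.models.resnet18()  # weights are randomly initialised in this sketch

# eqxvision models act on a single (channels, height, width) image; batches are
# mapped with jax.vmap, with the axis name assumed to match the BatchNorm layers.
images = jnp.zeros((4, 3, 224, 224))
keys = jax.random.split(key, images.shape[0])
logits = jax.vmap(net, axis_name="batch")(images, key=keys)
print(logits.shape)  # (4, 1000) with the default ImageNet head
```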
2 changes: 1 addition & 1 deletion docs/api/models/classification/convnext.md
@@ -1,4 +1,4 @@
# ConvNeXts
# ConvNeXt

---

2 changes: 1 addition & 1 deletion docs/api/models/classification/densenet.md
@@ -1,4 +1,4 @@
# DenseNets
# DenseNet

---

42 changes: 42 additions & 0 deletions docs/api/models/classification/efficientnet.md
@@ -0,0 +1,42 @@
# EfficientNet-V1

---

::: eqxvision.models.EfficientNet
    selection:
        members:
            - __init__
            - __call__

---

::: eqxvision.models.efficientnet_b0

---

::: eqxvision.models.efficientnet_b1

---

::: eqxvision.models.efficientnet_b2

---

::: eqxvision.models.efficientnet_b3

---

::: eqxvision.models.efficientnet_b4

---

::: eqxvision.models.efficientnet_b5

---

::: eqxvision.models.efficientnet_b6

---

::: eqxvision.models.efficientnet_b7

14 changes: 14 additions & 0 deletions docs/api/models/classification/efficientnet_v2.md
@@ -0,0 +1,14 @@
# EfficientNet-V2


::: eqxvision.models.efficientnet_v2_s

---

::: eqxvision.models.efficientnet_v2_m

---

::: eqxvision.models.efficientnet_v2_l

---
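The newly documented EfficientNet builders follow the same call convention as the other classifiers. A minimal sketch, assuming the constructors accept a `num_classes` override and that batches are mapped with `jax.vmap` over a `"batch"` axis (both assumptions, not confirmed by this diff):

```python
import jax
import jax.numpy as jnp
import eqxvision

key = jax.random.PRNGKey(0)

# V1 and V2 variants are built the same way; weights are randomly initialised here.
b0 = eqxvision.models.efficientnet_b0(num_classes=10)
v2_s = eqxvision.models.efficientnet_v2_s(num_classes=10)

images = jnp.zeros((2, 3, 224, 224))
keys = jax.random.split(key, images.shape[0])
print(jax.vmap(b0, axis_name="batch")(images, key=keys).shape)    # (2, 10)
print(jax.vmap(v2_s, axis_name="batch")(images, key=keys).shape)  # (2, 10)
```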
2 changes: 1 addition & 1 deletion docs/api/models/classification/resnets.md
@@ -1,4 +1,4 @@
# ResNets
# ResNet

::: eqxvision.models.ResNet
    selection:
2 changes: 1 addition & 1 deletion docs/api/models/classification/squeeze.md
@@ -1,4 +1,4 @@
# SqueezeNets
# SqueezeNet

---

2 changes: 1 addition & 1 deletion docs/api/models/classification/vgg.md
@@ -1,4 +1,4 @@
# VGGs
# VGG

::: eqxvision.models.VGG
    selection:
2 changes: 1 addition & 1 deletion docs/api/models/classification/vit.md
@@ -1,4 +1,4 @@
# Vision Transformers
# Vision Transformer

---

8 changes: 5 additions & 3 deletions docs/comparison.md
@@ -7,15 +7,17 @@

- For `Vgg` and `Googlenet`, there's a big gap in performance of
pre-trained networks. The difference arises after the `adaptive-pooling`,
which implies the networks can still be used as feature extractors (see results [here](./getting_started/Transfer_Learning.ipynb).
which implies the networks can still be used as feature extractors
(see results [here](./getting_started/Transfer_Learning.ipynb)).

- As `Mobilenet-v3` uses `maxpool` with `ceil` and a number of adaptive-pooling` layers,
the pretrained models are provided with no guarantees.
- As `Mobilenet-v3` uses `maxpool` with `ceil` mode and a number of `adaptive-pooling` layers,
which are not fully supported (yet), the pretrained models are provided with no guarantees.


| Method | Torchvision | Eqxvision |
|--------------------|-------------|------------|
| Alexnet | 56.518 | 56.522 |
| Convnext_tiny | 82.132 | 82.120 |
| Densenet121 | 74.432 | 74.434 |
| Googlenet | 69.774 | 61.046 |
| Mobilenet_v2 | 71.878 | 71.856 |
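A short sketch of the feature-extractor use mentioned in the bullet above, assuming the eqxvision `VGG` mirrors torchvision's `features`/`avgpool`/`classifier` split (the attribute name is an assumption, not confirmed by this diff):

```python
import jax
import jax.numpy as jnp
import eqxvision

key = jax.random.PRNGKey(0)
vgg = eqxvision.models.vgg11()  # pretrained-weight loading omitted in this sketch

image = jnp.zeros((3, 224, 224))
# Take activations from the convolutional trunk, i.e. before the adaptive pooling
# and classifier where the reported numerical gap appears.
features = vgg.features(image, key=key)  # assumes a torchvision-style `features` attribute
print(features.shape)
```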
16 changes: 8 additions & 8 deletions docs/getting_started/Transfer_Learning.ipynb
@@ -781,10 +781,10 @@
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[K |████████████████████████████████| 145 kB 13.4 MB/s \n",
"\u001b[K |████████████████████████████████| 66 kB 5.1 MB/s \n",
"\u001b[K |████████████████████████████████| 76 kB 6.0 MB/s \n",
"\u001b[?25h"
"\u001B[K |████████████████████████████████| 145 kB 13.4 MB/s \n",
"\u001B[K |████████████████████████████████| 66 kB 5.1 MB/s \n",
"\u001B[K |████████████████████████████████| 76 kB 6.0 MB/s \n",
"\u001B[?25h"
]
}
],
@@ -956,7 +956,7 @@
{
"cell_type": "markdown",
"source": [
"# Model Prep.\n",
"### Model Prep.\n",
"\n",
"We need to perform two steps after initialising the model.\n",
"\n",
@@ -1035,7 +1035,7 @@
{
"cell_type": "markdown",
"source": [
"# Utility Methods\n",
"### Utility Methods\n",
"\n",
"The `filter_spec` decides the params w.r.t to which the gradient is computed.\n",
"Here, we will be computing gradient w.r.t to only the `classifier` module.\n",
@@ -1093,7 +1093,7 @@
{
"cell_type": "markdown",
"source": [
"# Optimizer & Scheduler\n",
"### Optimizer & Scheduler\n",
"\n",
"The important bit to remember is wrapping the model in `eqx.filter` before passing it on to the optimizer. This step will `fail` if you forget the filter."
],
@@ -1128,7 +1128,7 @@
{
"cell_type": "markdown",
"source": [
"# The Training"
"### The Training"
],
"metadata": {
"id": "d21ktOAAZzv3",
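The markdown cells above describe two ideas: a `filter_spec` that marks which parameters receive gradients, and filtering the model with `eqx.filter` before it reaches the optimizer. A minimal, self-contained sketch of that pattern; the `Net` class, data, and hyper-parameters are placeholders rather than the notebook's code:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import optax


class Net(eqx.Module):
    """Stand-in for a pretrained backbone plus a trainable classifier head."""

    features: eqx.nn.Linear
    classifier: eqx.nn.Linear

    def __init__(self, key):
        fkey, ckey = jax.random.split(key)
        self.features = eqx.nn.Linear(8, 4, key=fkey)
        self.classifier = eqx.nn.Linear(4, 2, key=ckey)

    def __call__(self, x):
        return self.classifier(jax.nn.relu(self.features(x)))


model = Net(jax.random.PRNGKey(0))

# filter_spec: gradients flow only w.r.t. the classifier; everything else stays frozen.
filter_spec = jtu.tree_map(lambda _: False, model)
filter_spec = eqx.tree_at(lambda m: m.classifier, filter_spec, replace=True)


@eqx.filter_grad
def loss_fn(diff_model, static_model, x, y):
    net = eqx.combine(diff_model, static_model)
    pred = jax.vmap(net)(x)
    return jnp.mean((pred - y) ** 2)


x, y = jnp.ones((16, 8)), jnp.zeros((16, 2))

# Filter the model before handing it to the optimizer, so the optimizer state
# only tracks the trainable (classifier) arrays.
optim = optax.adam(1e-3)
opt_state = optim.init(eqx.filter(model, filter_spec))

diff_model, static_model = eqx.partition(model, filter_spec)
grads = loss_fn(diff_model, static_model, x, y)
updates, opt_state = optim.update(grads, opt_state)
model = eqx.apply_updates(model, updates)
```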
2 changes: 1 addition & 1 deletion eqxvision/__init__.py
@@ -1,4 +1,4 @@
r"""Root package info."""
__version__ = "0.1.7"
__version__ = "0.1.8"

from . import layers, models, utils
2 changes: 2 additions & 0 deletions eqxvision/layers/conv_norm_activation.py
@@ -50,6 +50,8 @@ def __init__(
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        """
        if key is None:
            key = jax.random.PRNGKey(0)

        if padding is None:
            padding = (kernel_size - 1) // 2 * dilation
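The two lines added above give the layer a deterministic fallback key. A tiny sketch of the implied behaviour, assuming the layer is exported as `eqxvision.layers.ConvNormActivation` with a torchvision-style signature (both the export path and the argument names are assumptions):

```python
import equinox as eqx
import jax
import jax.numpy as jnp
from eqxvision.layers import ConvNormActivation  # assumed export path

# With the fallback, omitting `key` initialises parameters from PRNGKey(0), so two
# layers built without explicit keys start from identical weights.
layer_a = ConvNormActivation(in_channels=3, out_channels=16, kernel_size=3)
layer_b = ConvNormActivation(in_channels=3, out_channels=16, kernel_size=3)

same = jax.tree_util.tree_all(
    jax.tree_util.tree_map(
        lambda a, b: jnp.array_equal(a, b),
        eqx.filter(layer_a, eqx.is_array),
        eqx.filter(layer_b, eqx.is_array),
    )
)
print(same)  # True: both fell back to jax.random.PRNGKey(0)
```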
15 changes: 11 additions & 4 deletions eqxvision/layers/drop_path.py
@@ -14,14 +14,21 @@ class DropPath(eqx.Module):
    inference: bool
    mode: str

    def __init__(self, p: float = 0.0, inference: bool = False, mode="atomic"):
    def __init__(self, p: float = 0.0, inference: bool = False, mode="global"):
        """**Arguments:**
        - `p`: The probability to drop a sample entirely during forward pass
        - `inference`: Defaults to `False`. If `True`, then the input is returned unchanged
            This may be toggled with `equinox.tree_inference`
        - `mode`: Can be set to `atomic` or `per_channel`. When `atomic`, the whole input is dropped or kept.
            If `per_channel`, then the decision on each channel is computed independently. Defaults to `atomic`
        - `mode`: Can be set to `global` or `local`. If `global`, the whole input is dropped or retained.
            If `local`, then the decision on each input unit is computed independently. Defaults to `global`
        !!! note
            For `mode = local`, an input `(channels, dim_0, dim_1, ...)` is reshaped and transposed to
            `(channels, dims).transpose()`. For each `dim x channels` element
            the decision is made independently.
        """
        self.p = p
        self.inference = inference
@@ -42,7 +49,7 @@ def __call__(self, x, *, key: "jax.random.PRNGKey") -> Array:
)

        keep_prob = 1 - self.p
        if self.mode == "atomic":
        if self.mode == "global":
            return x * jrandom.bernoulli(key, p=keep_prob)
        else:
            return x * jnp.expand_dims(
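A short usage sketch of the two modes described in the updated docstring, assuming `DropPath` is exported from `eqxvision.layers` (the export path is an assumption; the ConvNeXt block changed below constructs `DropPath` the same way):

```python
import jax
import jax.numpy as jnp
from eqxvision.layers import DropPath  # assumed export path

key = jax.random.PRNGKey(0)
x = jnp.ones((4, 8, 8))  # (channels, dim_0, dim_1)

# mode="global": the whole input is either kept or zeroed with probability p.
drop_global = DropPath(p=0.5, mode="global")
y_global = drop_global(x, key=key)

# mode="local": the keep/drop decision is made independently per input unit.
drop_local = DropPath(p=0.5, mode="local")
y_local = drop_local(x, key=key)

# With inference=True the input passes through unchanged; per the docstring this
# can be toggled with equinox.tree_inference.
drop_eval = DropPath(p=0.5, inference=True)
assert jnp.array_equal(drop_eval(x, key=key), x)
```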
14 changes: 14 additions & 0 deletions eqxvision/models/__init__.py
@@ -13,6 +13,20 @@
    densenet169,
    densenet201,
)
from .classification.efficientnet import (
    EfficientNet,
    efficientnet_b0,
    efficientnet_b1,
    efficientnet_b2,
    efficientnet_b3,
    efficientnet_b4,
    efficientnet_b5,
    efficientnet_b6,
    efficientnet_b7,
    efficientnet_v2_l,
    efficientnet_v2_m,
    efficientnet_v2_s,
)
from .classification.googlenet import GoogLeNet, googlenet
from .classification.mobilenetv2 import mobilenet_v2, MobileNetV2
from .classification.mobilenetv3 import (
1 change: 1 addition & 0 deletions eqxvision/models/classification/__init__.py
@@ -2,6 +2,7 @@
    alexnet,
    convnext,
    densenet,
    efficientnet,
    googlenet,
    mobilenetv2,
    mobilenetv3,
2 changes: 1 addition & 1 deletion eqxvision/models/classification/convnext.py
@@ -56,7 +56,7 @@ def __init__(
        self.layer_scale = jnp.asarray(
            jnp.ones(shape=(dim, 1, 1)) * layer_scale, dtype=jnp.float32
        )
        self.stochastic_depth = DropPath(p=stochastic_depth_prob, mode="per_channel")
        self.stochastic_depth = DropPath(p=stochastic_depth_prob, mode="local")

    def __call__(self, x: Array, *, key: "jax.random.PRNGKey") -> Array:
        """**Arguments:**