Skip to content

Commit

Permalink
Merge pull request #46 from paganpasta/remove_torch
Browse files Browse the repository at this point in the history
refactored tests and better torch handling
  • Loading branch information
paganpasta authored Aug 19, 2022
2 parents efc5a9a + 29a3564 commit 6b29961
Show file tree
Hide file tree
Showing 15 changed files with 91 additions and 60 deletions.
6 changes: 6 additions & 0 deletions eqxvision/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import sys
import warnings
from pathlib import Path
from typing import NewType, Optional
Expand Down Expand Up @@ -106,6 +107,11 @@ def load_torch_weights(
**Returns:**
The model with weights loaded from the `PyTorch` checkpoint.
"""
if "torch" not in sys.modules:
raise RuntimeError(
" Torch package not found! Pretrained is only supported with the torch package."
)

if filepath is None and url is None:
raise ValueError("Both filepath and url cannot be empty!")
elif filepath and url:
Expand Down
44 changes: 19 additions & 25 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,28 @@ def _getkey():
return _getkey


@pytest.fixture(scope="session")
def demo_image():
img = Image.open("./tests/static/img.png")
img = img.convert("RGB")
@pytest.fixture()
def img_transform():
    """Factory fixture: build a torchvision preprocessing pipeline for a given size.

    Returns a callable taking ``img_size`` and producing a ``transforms.Compose``
    that resizes, converts to a tensor, and applies ImageNet mean/std normalisation.
    """

    def _transform(img_size):
        # Resize -> tensor -> ImageNet normalisation (mean, std per RGB channel).
        return transforms.Compose(
            [
                transforms.Resize(img_size),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )

    return _transform


@pytest.fixture(scope="session")
def demo_image_256():
img = Image.open("./tests/static/img.png")
img = img.convert("RGB")

transform = transforms.Compose(
[
transforms.Resize(256),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]
)
return jnp.asarray(transform(img).unsqueeze(0))
@pytest.fixture()
def demo_image(img_transform):
    """Factory fixture: load the static test image, preprocess it at ``img_size``,
    and return it as a batched JAX array."""

    def _load(img_size):
        raw = Image.open("./tests/static/img.png")
        rgb = raw.convert("RGB")
        # Torch tensor from the transform, add a leading batch axis, then hand to jnp.
        batched = img_transform(img_size)(rgb).unsqueeze(0)
        return jnp.asarray(batched)

    return _load


@pytest.fixture(scope="session")
Expand Down
20 changes: 8 additions & 12 deletions tests/test_models/test_alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import jax.numpy as jnp

import eqxvision.models as models
from eqxvision import utils


class TestAlexNet:
answer = (1, 1000)

def test_output_shape(self, getkey, demo_image):
img = demo_image(224)
c_counter = 0

@eqx.filter_jit
Expand All @@ -20,28 +20,24 @@ def forward(model, x, key):
return jax.vmap(model)(x, key=keys)

model = models.alexnet(num_classes=1000)
output = forward(model, demo_image, getkey())
output = forward(model, img, getkey())
assert output.shape == self.answer
forward(model, demo_image, getkey())
forward(model, img, getkey())
assert c_counter == 1

def test_pretrained(self, getkey, demo_image, net_preds):
    """Pretrained AlexNet feature maps must match the stored PyTorch activations."""
    img = demo_image(224)

    @eqx.filter_jit
    def forward(net, imgs, keys):
        # vmap over the batch axis; each sample gets its own PRNG key.
        outputs = jax.vmap(net, axis_name="batch")(imgs, key=keys)
        return outputs

    model = models.alexnet(pretrained=True)

    pt_outputs = net_preds["alexnet"]
    # Switch all inference-mode flags (dropout, batch-norm) to eval behaviour.
    new_model = eqx.tree_inference(model, True)
    keys = jax.random.split(getkey(), 1)
    eqx_outputs = forward(new_model.features, img, keys)

    assert jnp.isclose(pt_outputs, eqx_outputs, atol=1e-4).all()
7 changes: 5 additions & 2 deletions tests/test_models/test_convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,22 @@ class TestConvNext:

@pytest.mark.parametrize("model_func", model_list)
def test_convnext(self, model_func, demo_image, getkey):
    """Each ConvNeXt variant maps a 224px image to a (1, 1000) logit vector."""
    img = demo_image(224)

    @eqx.filter_jit
    def forward(net, x, key):
        # One PRNG key per batch element for the vmapped forward pass.
        keys = jax.random.split(key, x.shape[0])
        ans = jax.vmap(net, axis_name="batch")(x, key=keys)
        return ans

    model = model_func[1](num_classes=1000)
    output = forward(model, img, getkey())
    assert output.shape == self.answer

@pytest.mark.parametrize("model_func", model_list)
def test_pretrained(self, getkey, model_func, demo_image, net_preds):
keys = jax.random.split(getkey(), 1)
img = demo_image(224)

@eqx.filter_jit
def forward(net, imgs, keys):
Expand All @@ -35,7 +38,7 @@ def forward(net, imgs, keys):

model = model_func[1](pretrained=True)
model = eqx.tree_inference(model, True)
eqx_outputs = forward(model, demo_image, keys)
eqx_outputs = forward(model, img, keys)
pt_outputs = net_preds[model_func[0]]

assert jnp.argmax(eqx_outputs) == jnp.argmax(pt_outputs)
7 changes: 5 additions & 2 deletions tests/test_models/test_densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,21 @@ class TestDenseNet:

@pytest.mark.parametrize("model_func", model_list)
def test_densenets(self, model_func, demo_image, getkey):
    """Each DenseNet variant maps a 224px image to a (1, 1000) logit vector."""
    img = demo_image(224)

    @eqx.filter_jit
    def forward(net, x, key):
        # One PRNG key per batch element for the vmapped forward pass.
        keys = jax.random.split(key, x.shape[0])
        ans = jax.vmap(net, axis_name="batch")(x, key=keys)
        return ans

    model = model_func[1](num_classes=1000)
    output = forward(model, img, getkey())
    assert output.shape == self.answer

@pytest.mark.parametrize("model_func", model_list)
def test_pretrained(self, getkey, model_func, demo_image, net_preds):
img = demo_image(224)
keys = jax.random.split(getkey(), 1)

@eqx.filter_jit
Expand All @@ -36,6 +39,6 @@ def forward(net, imgs, keys):
model = model_func[1](pretrained=True)
model = eqx.tree_inference(model, True)
pt_outputs = net_preds[model_func[0]]
eqx_outputs = forward(model, demo_image, keys)
eqx_outputs = forward(model, img, keys)

assert jnp.isclose(eqx_outputs, pt_outputs, atol=1e-4).all()
7 changes: 5 additions & 2 deletions tests/test_models/test_efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@ class TestEfficientNet:

@pytest.mark.parametrize("model_func", model_list)
def test_efficientnet(self, model_func, demo_image, getkey):
    """Each EfficientNet variant maps a 224px image to a (1, 1000) logit vector."""
    img = demo_image(224)

    @eqx.filter_jit
    def forward(net, x, key):
        # One PRNG key per batch element for the vmapped forward pass.
        keys = jax.random.split(key, x.shape[0])
        ans = jax.vmap(net, axis_name="batch")(x, key=keys)
        return ans

    model = model_func[1](num_classes=1000)
    output = forward(model, img, getkey())
    assert output.shape == self.answer

@pytest.mark.parametrize("model_func", model_list)
def test_pretrained(self, getkey, model_func, demo_image, net_preds):
img = demo_image(224)
keys = jax.random.split(getkey(), 1)

@eqx.filter_jit
Expand All @@ -38,7 +41,7 @@ def forward(net, imgs, keys):

model = model_func[1](pretrained=True)
model = eqx.tree_inference(model, True)
eqx_outputs = forward(model, demo_image, keys)
eqx_outputs = forward(model, img, keys)
pt_outputs = net_preds[model_func[0]]

assert jnp.argmax(eqx_outputs) == jnp.argmax(pt_outputs)
4 changes: 3 additions & 1 deletion tests/test_models/test_googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ class TestGoogLeNet:
answer = (1, 1000)

def test_output_shape(self, demo_image, getkey):
    """GoogLeNet (aux logits disabled) maps a 224px image to a (1, 1000) logit vector."""
    img = demo_image(224)

    @eqx.filter_jit
    def forward(net, x, key):
        # One PRNG key per batch element for the vmapped forward pass.
        keys = jax.random.split(key, x.shape[0])
        ans = jax.vmap(net, axis_name="batch")(x, key=keys)
        return ans

    model = models.googlenet(num_classes=1000, aux_logits=False)
    output = forward(model, img, getkey())
    assert output.shape == self.answer
7 changes: 5 additions & 2 deletions tests/test_models/test_mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,21 @@ class TestMobileNetv2:

@pytest.mark.parametrize("model_func", model_list)
def test_mobilenet(self, model_func, demo_image, getkey):
    """Each MobileNetV2 variant maps a 224px image to a (1, 1000) logit vector."""
    img = demo_image(224)

    @eqx.filter_jit
    def forward(net, x, key):
        # One PRNG key per batch element for the vmapped forward pass.
        keys = jax.random.split(key, x.shape[0])
        ans = jax.vmap(net, axis_name="batch")(x, key=keys)
        return ans

    model = model_func[1](num_classes=1000)
    output = forward(model, img, getkey())
    assert output.shape == self.answer

@pytest.mark.parametrize("model_func", model_list)
def test_pretrained(self, getkey, model_func, demo_image, net_preds):
img = demo_image(224)
keys = jax.random.split(getkey(), 1)

@eqx.filter_jit
Expand All @@ -36,6 +39,6 @@ def forward(net, imgs, keys):
model = model_func[1](pretrained=True)
model = eqx.tree_inference(model, True)
pt_outputs = net_preds[model_func[0]]
eqx_outputs = forward(model, demo_image, keys)
eqx_outputs = forward(model, img, keys)

assert jnp.argmax(eqx_outputs, axis=1) == jnp.argmax(pt_outputs, axis=1)
4 changes: 3 additions & 1 deletion tests/test_models/test_mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ class TestMobileNetv3:

@pytest.mark.parametrize("model_func", model_list)
def test_mobilenet(self, model_func, demo_image, getkey):
    """Each MobileNetV3 variant maps a 224px image to a (1, 1000) logit vector."""
    img = demo_image(224)

    @eqx.filter_jit
    def forward(net, x, key):
        # One PRNG key per batch element for the vmapped forward pass.
        keys = jax.random.split(key, x.shape[0])
        ans = jax.vmap(net, axis_name="batch")(x, key=keys)
        return ans

    model = model_func[1](num_classes=1000)
    output = forward(model, img, getkey())
    assert output.shape == self.answer
7 changes: 5 additions & 2 deletions tests/test_models/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,21 @@ class TestResNet:

@pytest.mark.parametrize("model_func", model_list)
def test_resnets(self, model_func, demo_image, getkey):
    """Each ResNet variant maps a 224px image to a (1, 1000) logit vector."""
    img = demo_image(224)

    @eqx.filter_jit
    def forward(net, x, key):
        # One PRNG key per batch element for the vmapped forward pass.
        keys = jax.random.split(key, x.shape[0])
        ans = jax.vmap(net, axis_name="batch")(x, key=keys)
        return ans

    model = model_func[1](num_classes=1000)
    output = forward(model, img, getkey())
    assert output.shape == self.answer

@pytest.mark.parametrize("model_func", model_list)
def test_pretrained(self, getkey, model_func, demo_image, net_preds):
img = demo_image(224)
keys = jax.random.split(getkey(), 1)

@eqx.filter_jit
Expand All @@ -36,6 +39,6 @@ def forward(net, imgs, keys):
model = model_func[1](pretrained=True)
model = eqx.tree_inference(model, True)
pt_outputs = net_preds[model_func[0]]
eqx_outputs = forward(model, demo_image, keys)
eqx_outputs = forward(model, img, keys)

assert jnp.isclose(eqx_outputs, pt_outputs, atol=1e-4).all()
7 changes: 5 additions & 2 deletions tests/test_models/test_shufflenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,21 @@ class TestShuffleNetV2:

@pytest.mark.parametrize("model_func", model_list)
def test_shufflenet(self, model_func, demo_image, getkey):
    """Each ShuffleNetV2 variant maps a 224px image to a (1, 1000) logit vector."""
    img = demo_image(224)

    @eqx.filter_jit
    def forward(net, x, key):
        # One PRNG key per batch element for the vmapped forward pass.
        keys = jax.random.split(key, x.shape[0])
        ans = jax.vmap(net, axis_name="batch")(x, key=keys)
        return ans

    model = model_func[1](num_classes=1000)
    output = forward(model, img, getkey())
    assert output.shape == self.answer

@pytest.mark.parametrize("model_func", model_list)
def test_pretrained(self, getkey, model_func, demo_image, net_preds):
img = demo_image(224)
keys = jax.random.split(getkey(), 1)

@eqx.filter_jit
Expand All @@ -36,6 +39,6 @@ def forward(net, imgs, keys):
model = model_func[1](pretrained=True)
model = eqx.tree_inference(model, True)
pt_outputs = net_preds[model_func[0]]
eqx_outputs = forward(model, demo_image, keys)
eqx_outputs = forward(model, img, keys)

assert jnp.isclose(eqx_outputs, pt_outputs, atol=1e-4).all()
8 changes: 6 additions & 2 deletions tests/test_models/test_squeezenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,21 @@ class TestSqueezeNet:

@pytest.mark.parametrize("model_func", model_list)
def test_sneNet(self, model_func, demo_image, getkey):
    """Each SqueezeNet variant maps a 224px image to a (1, 1000) logit vector."""
    img = demo_image(224)

    @eqx.filter_jit
    def forward(net, x, key):
        # One PRNG key per batch element for the vmapped forward pass.
        keys = jax.random.split(key, x.shape[0])
        ans = jax.vmap(net, axis_name="batch")(x, key=keys)
        return ans

    # NOTE: this model_list holds constructors directly (no (name, fn) tuples).
    model = model_func(num_classes=1000)
    output = forward(model, img, getkey())
    assert output.shape == self.answer

def test_pretrained(self, getkey, demo_image, net_preds):
img = demo_image(224)

@eqx.filter_jit
def forward(net, imgs, keys):
outputs = jax.vmap(net, axis_name="batch")(imgs, key=keys)
Expand All @@ -37,6 +41,6 @@ def forward(net, imgs, keys):
pt_outputs = net_preds["squeezenet1_0"]
new_model = eqx.tree_inference(new_model, True)
keys = jax.random.split(getkey(), 1)
eqx_outputs = forward(new_model, demo_image, keys)
eqx_outputs = forward(new_model, img, keys)

assert jnp.argmax(pt_outputs) == jnp.argmax(eqx_outputs)
Loading

0 comments on commit 6b29961

Please sign in to comment.