Skip to content

Commit

Permalink
fix: SRCNN impl and tests (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tohrusky authored Nov 24, 2024
1 parent c7e8321 commit 23ece6a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
2 changes: 1 addition & 1 deletion ccrestoration/arch/srcnn_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, num_channels: int = 1, scale: int = 2) -> None:
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.interpolate(x, scale_factor=self.scale, mode="bilinear")

if self.num_channels == 1:
if self.num_channels == 1 and x.size(1) == 3:
# RGB -> YUV
x = rgb_to_yuv(x)
y, u, v = x[:, 0:1, ...], x[:, 1:2, ...], x[:, 2:3, ...]
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ exclude_also = [
[tool.coverage.run]
omit = [
"ccrestoration/arch/*",
"ccrestoration/vs/*"
"ccrestoration/vs/*",
"ccrestoration/util/device.py"
]

[tool.mypy]
Expand Down
40 changes: 40 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import cv2
import pytest
import torch
from torchvision import transforms

from ccrestoration.util.color import rgb_to_yuv, yuv_to_rgb
from ccrestoration.util.device import DEFAULT_DEVICE

from .util import calculate_image_similarity, load_image


def test_device() -> None:
print(DEFAULT_DEVICE)


def test_color() -> None:
with pytest.raises(TypeError):
rgb_to_yuv(1)
with pytest.raises(TypeError):
yuv_to_rgb(1)

with pytest.raises(ValueError):
rgb_to_yuv(torch.zeros(1, 1))
with pytest.raises(ValueError):
yuv_to_rgb(torch.zeros(1, 1))

img = load_image()
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

img = transforms.ToTensor()(img).unsqueeze(0).to("cpu")

img = rgb_to_yuv(img)
img = yuv_to_rgb(img)

img = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
img = (img * 255).clip(0, 255).astype("uint8")

img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

assert calculate_image_similarity(img, load_image())

0 comments on commit 23ece6a

Please sign in to comment.