From 23ece6aaf2bcf8e903ac3b558040efdf41113de5 Mon Sep 17 00:00:00 2001 From: Tohru <65994850+Tohrusky@users.noreply.github.com> Date: Sun, 24 Nov 2024 23:15:35 +0000 Subject: [PATCH] fix: SRCNN impl and tests (#26) --- ccrestoration/arch/srcnn_arch.py | 2 +- pyproject.toml | 3 ++- tests/test_util.py | 40 ++++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 tests/test_util.py diff --git a/ccrestoration/arch/srcnn_arch.py b/ccrestoration/arch/srcnn_arch.py index dc45d83..0b94060 100644 --- a/ccrestoration/arch/srcnn_arch.py +++ b/ccrestoration/arch/srcnn_arch.py @@ -21,7 +21,7 @@ def __init__(self, num_channels: int = 1, scale: int = 2) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.interpolate(x, scale_factor=self.scale, mode="bilinear") - if self.num_channels == 1: + if self.num_channels == 1 and x.size(1) == 3: # RGB -> YUV x = rgb_to_yuv(x) y, u, v = x[:, 0:1, ...], x[:, 1:2, ...], x[:, 2:3, ...] diff --git a/pyproject.toml b/pyproject.toml index 8cab944..e84f048 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,8 @@ exclude_also = [ [tool.coverage.run] omit = [ "ccrestoration/arch/*", - "ccrestoration/vs/*" + "ccrestoration/vs/*", + "ccrestoration/util/device.py" ] [tool.mypy] diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 0000000..22098dd --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,40 @@ +import cv2 +import pytest +import torch +from torchvision import transforms + +from ccrestoration.util.color import rgb_to_yuv, yuv_to_rgb +from ccrestoration.util.device import DEFAULT_DEVICE + +from .util import calculate_image_similarity, load_image + + +def test_device() -> None: + print(DEFAULT_DEVICE) + + +def test_color() -> None: + with pytest.raises(TypeError): + rgb_to_yuv(1) + with pytest.raises(TypeError): + yuv_to_rgb(1) + + with pytest.raises(ValueError): + rgb_to_yuv(torch.zeros(1, 1)) + with pytest.raises(ValueError): + yuv_to_rgb(torch.zeros(1, 1)) + + img = load_image() + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + img = transforms.ToTensor()(img).unsqueeze(0).to("cpu") + + img = rgb_to_yuv(img) + img = yuv_to_rgb(img) + + img = img.squeeze(0).permute(1, 2, 0).cpu().numpy() + img = (img * 255).clip(0, 255).astype("uint8") + + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + assert calculate_image_similarity(img, load_image())