Add RealPLKSR LayerNorm #313

Merged · 3 commits · Jan 3, 2025
34 changes: 29 additions & 5 deletions libs/spandrel/spandrel/architectures/PLKSR/__arch/RealPLKSR.py
@@ -8,6 +8,20 @@
from spandrel.util import store_hyperparameters


class LayerNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim))
self.eps = eps

def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
return self.weight[:, None, None] * x + self.bias[:, None, None]


class DCCM(nn.Sequential):
"Doubled Convolutional Channel Mixer"

@@ -56,11 +70,15 @@ def __init__(
dim: int,
kernel_size: int,
split_ratio: float,
norm_groups: int,
use_ea: bool = True,
norm_groups: int = 4,
use_layer_norm: bool = False,
):
super().__init__()

# Layer Norm
self.layer_norm = LayerNorm(dim) if use_layer_norm else nn.Identity()

# Local Texture
self.channel_mixer = DCCM(dim)

@@ -80,11 +98,16 @@ def __init__(
self.refine = nn.Conv2d(dim, dim, 1, 1, 0)
trunc_normal_(self.refine.weight, std=0.02)

# Group Normalization
self.norm = nn.GroupNorm(norm_groups, dim)
if not use_layer_norm:
self.norm = nn.GroupNorm(norm_groups, dim)
nn.init.constant_(self.norm.bias, 0)
nn.init.constant_(self.norm.weight, 1.0)
else:
self.norm = nn.Identity()

def forward(self, x: torch.Tensor) -> torch.Tensor:
x_skip = x
x = self.layer_norm(x)
x = self.channel_mixer(x)
x = self.lk(x)
x = self.attn(x)
@@ -114,6 +137,7 @@ def __init__(
norm_groups: int = 4,
dropout: float = 0,
dysample: bool = False,
layer_norm: bool = False,
):
super().__init__()

@@ -128,11 +152,11 @@ def __init__(
self.feats = nn.Sequential(
*[nn.Conv2d(in_ch, dim, 3, 1, 1)]
+ [
PLKBlock(dim, kernel_size, split_ratio, norm_groups, use_ea)
PLKBlock(dim, kernel_size, split_ratio, use_ea, norm_groups, layer_norm)
for _ in range(n_blocks)
]
+ [nn.Dropout2d(dropout)]
+ [nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1)]
+ [nn.Conv2d(dim, out_ch * upscaling_factor**2, 3, 1, 1)]
)
trunc_normal_(self.feats[0].weight, std=0.02)
trunc_normal_(self.feats[-1].weight, std=0.02)
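
For context, the new LayerNorm module added above is a channels-first layer norm: it normalizes each NCHW activation across the channel dimension and applies a per-channel affine. When use_layer_norm is set, the block applies it at the start of its forward pass and swaps the GroupNorm out for nn.Identity. Below is a minimal sketch of exercising the new flag, mirroring the RealPLKSR(layer_norm=True) case added to the tests; the import path is taken from this diff's file header, and the half-precision line is an assumption about available CUDA hardware, not something this PR runs.

import torch

# Import path per this diff's file header; the package may also re-export
# RealPLKSR from spandrel.architectures.PLKSR.
from spandrel.architectures.PLKSR.__arch.RealPLKSR import RealPLKSR

# Every other constructor argument is left at its default, mirroring the new
# test entry `lambda: RealPLKSR(layer_norm=True)`.
model = RealPLKSR(layer_norm=True).eval()

with torch.no_grad():
    x = torch.rand(1, 3, 32, 32)
    y = model(x)
    # Shape is (1, 3, 32 * s, 32 * s), where s is the default upscaling factor.
    print(y.shape)

# Per the loader change in this PR, only the LayerNorm variant is reported as
# supports_half=True; an fp16 pass would look like (CUDA assumed):
#   y_half = model.half().cuda()(x.half().cuda())
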
13 changes: 10 additions & 3 deletions libs/spandrel/spandrel/architectures/PLKSR/__init__.py
@@ -1,7 +1,8 @@
from __future__ import annotations

import math
from typing import Literal, Sequence, Union
from collections.abc import Sequence
from typing import Literal, Union

from typing_extensions import override

@@ -41,6 +42,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
kernel_size = 17
split_ratio = 0.25
use_ea = True
supports_half = True

dim = state_dict["feats.0.weight"].shape[0]

@@ -118,10 +120,14 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
n_blocks = total_feat_layers - 3
kernel_size = state_dict["feats.1.lk.conv.weight"].shape[2]
split_ratio = state_dict["feats.1.lk.conv.weight"].shape[0] / dim

use_layer_norm = "feats.1.layer_norm.bias" in state_dict
use_dysample = "to_img.init_pos" in state_dict
if use_dysample:
more_tags.append("DySample")
if use_layer_norm:
more_tags.append("LayerNorm")
else:
supports_half = False

model = RealPLKSR(
dim=dim,
@@ -132,6 +138,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
use_ea=use_ea,
norm_groups=4, # un-detectable
dysample=use_dysample,
layer_norm=use_layer_norm,
)
else:
raise ValueError("Unknown model type")
@@ -142,7 +149,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
architecture=self,
purpose="Restoration" if scale == 1 else "SR",
tags=[f"{dim}dim", f"{n_blocks}nb", f"{kernel_size}ks", *more_tags],
supports_half=False,
supports_half=supports_half,
supports_bfloat16=True,
scale=scale,
input_channels=3,
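
The loader now detects the variant from the checkpoint itself: if a feats.1.layer_norm.bias key is present, it constructs the model with layer_norm=True, adds a "LayerNorm" tag, and leaves supports_half=True; RealPLKSR checkpoints without that key fall back to GroupNorm and keep supports_half=False. A rough sketch of checking this through spandrel's public loader follows; ModelLoader and the descriptor attributes are assumed from spandrel's existing API, and the path is a hypothetical local copy of the 4x checkpoint referenced in the tests below.

from spandrel import ModelLoader

# Hypothetical local path to the 4x LayerNorm checkpoint used in the new tests.
desc = ModelLoader().load_from_file("4x_DF2K_Redux_RealPLKSRLayerNorm_50k.safetensors")

print(desc.architecture.id)  # "PLKSR"
print(desc.tags)             # expected to include "Real" and "LayerNorm"
print(desc.supports_half)    # True for the LayerNorm variant after this PR
print(desc.scale)            # 4
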
54 changes: 50 additions & 4 deletions tests/__snapshots__/test_PLKSR.ambr
@@ -11,7 +11,7 @@
scale=4,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
supports_half=True,
tags=list([
'64dim',
'12nb',
@@ -33,7 +33,7 @@
scale=2,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
supports_half=True,
tags=list([
'64dim',
'28nb',
@@ -55,7 +55,7 @@
scale=3,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
supports_half=True,
tags=list([
'64dim',
'28nb',
@@ -77,7 +77,7 @@
scale=4,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
supports_half=True,
tags=list([
'64dim',
'28nb',
@@ -154,3 +154,49 @@
tiling=<ModelTiling.SUPPORTED: 1>,
)
# ---
# name: test_RealPLKSR_LayerNorm_2x
ImageModelDescriptor(
architecture=PLKSRArch(
id='PLKSR',
name='PLKSR',
),
input_channels=3,
output_channels=3,
purpose='SR',
scale=2,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
'64dim',
'28nb',
'17ks',
'Real',
'LayerNorm',
]),
tiling=<ModelTiling.SUPPORTED: 1>,
)
# ---
# name: test_RealPLKSR_LayerNorm_4x
ImageModelDescriptor(
architecture=PLKSRArch(
id='PLKSR',
name='PLKSR',
),
input_channels=3,
output_channels=3,
purpose='SR',
scale=4,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
'64dim',
'28nb',
'17ks',
'Real',
'LayerNorm',
]),
tiling=<ModelTiling.SUPPORTED: 1>,
)
# ---
33 changes: 33 additions & 0 deletions tests/test_PLKSR.py
@@ -52,6 +52,7 @@ def test_load():
lambda: RealPLKSR(split_ratio=0.75),
lambda: RealPLKSR(use_ea=False),
lambda: RealPLKSR(dysample=True),
lambda: RealPLKSR(layer_norm=True),
)


@@ -68,6 +69,12 @@ def test_size_requirements():
)
assert_size_requirements(file.load_model())

file = ModelFile.from_url(
"/~https://github.com/the-database/traiNNer-redux/releases/download/pretrained-models/4x_DF2K_Redux_RealPLKSRLayerNorm_50k.safetensors",
name="4x_DF2K_Redux_RealPLKSRLayerNorm_50k.safetensors",
)
assert_size_requirements(file.load_model())


def test_PLKSR_official_x4(snapshot):
file = ModelFile.from_url(
@@ -172,3 +179,29 @@ def test_RealPLKSR_DySample(snapshot):
model,
[TestImage.SR_16, TestImage.SR_32, TestImage.SR_64],
)


def test_RealPLKSR_LayerNorm_4x(snapshot):
file = ModelFile.from_url(
"/~https://github.com/the-database/traiNNer-redux/releases/download/pretrained-models/4x_DF2K_Redux_RealPLKSRLayerNorm_50k.safetensors",
name="4x_DF2K_Redux_RealPLKSRLayerNorm_50k.safetensors",
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RealPLKSR)
assert_image_inference(
file, model, [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64]
)


def test_RealPLKSR_LayerNorm_2x(snapshot):
file = ModelFile.from_url(
"/~https://github.com/the-database/traiNNer-redux/releases/download/pretrained-models/2x_DF2K_Redux_RealPLKSRLayerNorm_450k.safetensors",
name="2x_DF2K_Redux_RealPLKSRLayerNorm_450k.safetensors",
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RealPLKSR)
assert_image_inference(
file, model, [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64]
)