Skip to content

Commit

Permalink
Define the public API by what's documented (#287)
Browse files Browse the repository at this point in the history
* Define the public API by what's documented

* Fix inits

* Fixed arch doc strings

* Missing four
  • Loading branch information
RunDevelopment authored Jul 11, 2024
1 parent ce00b4a commit 7c1094f
Show file tree
Hide file tree
Showing 106 changed files with 228 additions and 113 deletions.
Empty file.
7 changes: 4 additions & 3 deletions libs/spandrel/spandrel/architectures/ATD/__arch/atd_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,9 +881,10 @@ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):

@store_hyperparameters()
class ATD(nn.Module):
r"""ATD
A PyTorch impl of : `Transcending the Limit of Local Window: Advanced Super-Resolution Transformer
with Adaptive Token Dictionary`.
r"""
ATD
A PyTorch impl of : `Transcending the Limit of Local Window: Advanced Super-Resolution Transformer with Adaptive Token Dictionary`.
Args:
img_size (int | tuple(int)): Input image size. Default 64
Expand Down
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/ATD/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[ATD]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=8),
)


__all__ = ["ATDArch", "ATD"]
4 changes: 3 additions & 1 deletion libs/spandrel/spandrel/architectures/CRAFT/__arch/CRAFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,10 @@ def forward(self, biases):


class Attention_regular(nn.Module):
"""Regular Rectangle-Window (regular-Rwin) self-attention with dynamic relative position bias.
"""
Regular Rectangle-Window (regular-Rwin) self-attention with dynamic relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
resolution (int): Input resolution.
Expand Down
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/CRAFT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[CRAFT]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16, multiple_of=16),
)


__all__ = ["CRAFTArch", "CRAFT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/Compact/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Compact]:
input_channels=in_nc,
output_channels=out_nc,
)


__all__ = ["CompactArch", "Compact"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DAT]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["DATArch", "DAT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DCTLSA/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DCTLSA]:
output_channels=out_nc,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["DCTLSAArch", "DCTLSA"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DITN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DITN]:
output_channels=3, # hard-coded in the architecture
size_requirements=SizeRequirements(multiple_of=patch_size),
)


__all__ = ["DITNArch", "DITN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DRCT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DRCT]:
output_channels=in_chans,
size_requirements=SizeRequirements(multiple_of=16),
)


__all__ = ["DRCTArch", "DRCT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DRUNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,6 @@ def call(model: DRUNet, image: torch.Tensor) -> torch.Tensor:
size_requirements=SizeRequirements(multiple_of=8),
call_fn=call,
)


__all__ = ["DRUNetArch", "DRUNet"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DnCNN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,6 @@ def call(model: DnCNN, image: torch.Tensor) -> torch.Tensor:
size_requirements=SizeRequirements(),
call_fn=call,
)


__all__ = ["DnCNNArch", "DnCNN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/ESRGAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[ESRGAN]:
multiple_of=4 if shuffle_factor else 1,
),
)


__all__ = ["ESRGANArch", "ESRGAN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/FBCNN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FBCNN]:
output_channels=out_nc,
call_fn=lambda model, image: model(image)[0],
)


__all__ = ["FBCNNArch", "FBCNN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/FFTformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FFTformer]:
output_channels=out_channels,
size_requirements=SizeRequirements(multiple_of=32),
)


__all__ = ["FFTformerArch", "FFTformer"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/GFPGAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[GFPGAN]:
size_requirements=SizeRequirements(minimum=512),
call_fn=lambda model, image: model(image)[0],
)


__all__ = ["GFPGANArch", "GFPGAN"]
Empty file.
29 changes: 18 additions & 11 deletions libs/spandrel/spandrel/architectures/GRL/__arch/grl.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@


class TransformerStage(nn.Module):
"""Transformer stage.
"""
Transformer stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
Expand All @@ -58,11 +60,13 @@ class TransformerStage(nn.Module):
pretrained_stripe_size (list[int]): pretrained stripe size. This is actually not used. Default: [0, 0].
conv_type: The convolutional block before residual connection.
init_method: initialization method of the weight parameters used to train large scale models.
Choices: n, normal -- Swin V1 init method.
l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer.
r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1
w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1
t, trunc_normal_ -- nn.Linear, trunc_normal; nn.Conv2d, weight_rescale
Choices:
* n, normal -- Swin V1 init method.
* l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer.
* r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1
* w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1
* t, `trunc_normal_` -- nn.Linear, trunc_normal, nn.Conv2d, weight_rescale
fairscale_checkpoint (bool): Whether to use fairscale checkpoint.
offload_to_cpu (bool): used by fairscale_checkpoint
args:
Expand Down Expand Up @@ -185,6 +189,7 @@ def flops(self):
@store_hyperparameters()
class GRL(nn.Module):
r"""Image restoration transformer with global, non-local, and local connections
Args:
img_size (int | list[int]): Input image size. Default 64
in_channels (int): Number of input image channels. Default: 3
Expand Down Expand Up @@ -216,11 +221,13 @@ class GRL(nn.Module):
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
conv_type (str): The convolutional block before residual connection. Default: 1conv. Choices: 1conv, 3conv, 1conv1x1, linear
init_method: initialization method of the weight parameters used to train large scale models.
Choices: n, normal -- Swin V1 init method.
l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer.
r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1
w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1
t, trunc_normal_ -- nn.Linear, trunc_normal; nn.Conv2d, weight_rescale
Choices:
* n, normal -- Swin V1 init method.
* l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer.
* r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1
* w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1
* t, `trunc_normal_` -- nn.Linear, trunc_normal, nn.Conv2d, weight_rescale
fairscale_checkpoint (bool): Whether to use fairscale checkpoint.
offload_to_cpu (bool): used by fairscale_checkpoint
euclidean_dist (bool): use Euclidean distance or inner product as the similarity metric. An ablation study.
Expand Down
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/GRL/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[GRL]:
input_channels=in_channels,
output_channels=out_channels,
)


__all__ = ["GRLArch", "GRL"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/HAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[HAT]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["HATArch", "HAT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/HVICIDNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[HVICIDNet]:
size_requirements=SizeRequirements(multiple_of=8),
tiling=ModelTiling.DISCOURAGED,
)


__all__ = ["HVICIDNetArch", "HVICIDNet"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/IPT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,6 @@ def call(model: IPT, x: torch.Tensor):
size_requirements=SizeRequirements(minimum=patch_size),
call_fn=call,
)


__all__ = ["IPTArch", "IPT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/KBNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_KBNet]:
return self._load_l(state_dict)
else:
return self._load_s(state_dict)


__all__ = ["KBNetArch", "KBNet_s", "KBNet_l"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/LaMa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,6 @@ def load(self, state_dict: StateDict) -> MaskedImageModelDescriptor[LaMa]:
output_channels=out_nc,
size_requirements=SizeRequirements(minimum=16, multiple_of=8),
)


__all__ = ["LaMaArch", "LaMa"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/MMRealSR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[MMRealSR]:
size_requirements=SizeRequirements(minimum=16),
call_fn=lambda model, image: model(image)[0],
)


__all__ = ["MMRealSRArch", "MMRealSR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/MixDehazeNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[MixDehazeNet]:
tiling=ModelTiling.DISCOURAGED,
call_fn=lambda model, image: model(image) * 0.5 + 0.5,
)


__all__ = ["MixDehazeNetArch", "MixDehazeNet"]
20 changes: 10 additions & 10 deletions libs/spandrel/spandrel/architectures/NAFNet/__arch/NAFNet_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------

"""
Simple Baselines for Image Restoration
@article{chen2022simple,
title={Simple Baselines for Image Restoration},
author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
journal={arXiv preprint arXiv:2204.04676},
year={2022}
}
"""
# """
# Simple Baselines for Image Restoration

# @article{chen2022simple,
# title={Simple Baselines for Image Restoration},
# author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
# journal={arXiv preprint arXiv:2204.04676},
# year={2022}
# }
# """

from __future__ import annotations

Expand Down
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/NAFNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[NAFNet]:
input_channels=img_channel,
output_channels=img_channel,
)


__all__ = ["NAFNetArch", "NAFNet"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/OmniSR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[OmniSR]:
output_channels=num_out_ch,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["OmniSRArch", "OmniSR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/PLKSR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
input_channels=3,
output_channels=3,
)


__all__ = ["PLKSRArch", "PLKSR", "RealPLKSR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/RGT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RGT]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["RGTArch", "RGT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/RealCUGAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_RealCUGAN]:
output_channels=out_channels,
size_requirements=size_requirements,
)


__all__ = ["RealCUGANArch", "UpCunet2x", "UpCunet3x", "UpCunet4x", "UpCunet2x_fast"]
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@

class VectorQuantizer(nn.Module):
"""
see /~https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
____________________________________________
Discretization bottleneck part of the VQ-VAE.
Inputs:
- n_e : number of embeddings
- e_dim : dimension of embedding
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
_____________________________________________
see /~https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
Args:
n_e : number of embeddings
e_dim : dimension of embedding
beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
"""

def __init__(self, n_e, e_dim, beta):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,6 @@ def call(model: RestoreFormer, x: torch.Tensor) -> torch.Tensor:
size_requirements=SizeRequirements(multiple_of=32),
call_fn=call,
)


__all__ = ["RestoreFormerArch", "RestoreFormer"]
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RetinexFormer]:
tiling=ModelTiling.DISCOURAGED,
call_fn=_call_fn,
)


__all__ = ["RetinexFormerArch", "RetinexFormer"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SAFMN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SAFMN]:
output_channels=3, # hard-coded in the arch
size_requirements=SizeRequirements(multiple_of=8),
)


__all__ = ["SAFMNArch", "SAFMN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SAFMNBCIE/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SAFMNBCIE]:
output_channels=3, # hard-coded in the arch
size_requirements=SizeRequirements(multiple_of=16),
)


__all__ = ["SAFMNBCIEArch", "SAFMNBCIE"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SCUNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SCUNet]:
size_requirements=SizeRequirements(minimum=40),
tiling=ModelTiling.DISCOURAGED,
)


__all__ = ["SCUNetArch", "SCUNet"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SPAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SPAN]:
input_channels=num_in_ch,
output_channels=num_out_ch,
)


__all__ = ["SPANArch", "SPAN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SwiftSRGAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SwiftSRGAN]:
input_channels=in_channels,
output_channels=in_channels,
)


__all__ = ["SwiftSRGANArch", "SwiftSRGAN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/Swin2SR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Swin2SR]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["Swin2SRArch", "Swin2SR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SwinIR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SwinIR]:
output_channels=out_nc,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["SwinIRArch", "SwinIR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/Uformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]:
output_channels=dd_in,
size_requirements=SizeRequirements(multiple_of=128, square=True),
)


__all__ = ["UformerArch", "Uformer"]
2 changes: 2 additions & 0 deletions libs/spandrel/spandrel/architectures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""
The package containing the implementations of all supported architectures. Not necessary for most user code.
"""

__docformat__ = "google"
Loading

0 comments on commit 7c1094f

Please sign in to comment.