Skip to content

Commit

Permalink
Merge pull request #13 from daanelson/cogify
Browse files Browse the repository at this point in the history
Add Replicate demo
  • Loading branch information
sczhou authored Jan 25, 2023
2 parents b505e65 + b4b583c commit 9fa3c0a
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 1 deletion.
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

## LEDNet: Joint Low-light Enhancement and Deblurring in the Dark (ECCV 2022)

[Paper](https://arxiv.org/abs/2202.03373) | [Project Page](https://shangchenzhou.com/projects/LEDNet/) | [Video](https://youtu.be/450dkE-fOMY)
[Paper](https://arxiv.org/abs/2202.03373) | [Project Page](https://shangchenzhou.com/projects/LEDNet/) | [Video](https://youtu.be/450dkE-fOMY) | [Replicate Demo](https://replicate.com/sczhou/lednet)

[Shangchen Zhou](https://shangchenzhou.com/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)

Expand Down Expand Up @@ -126,6 +126,16 @@ Run low-light generation:
python scripts/generate_low_light_imgs.py --test_path 'IMG_ROOT' --result_path 'RESULT_ROOT' --model_path './weights/ce_zerodce.pth'
```

### Inference with Cog
To run containerized local inference with LEDNet using [Cog](/~https://github.com/replicate/cog), run the following commands in the project root:

```
cog run python basicsr/setup.py develop
cog predict -i image=@'path/to/input_image.jpg'
```

You can view this demo running as an API [here on Replicate](https://replicate.com/sczhou/lednet).

### License

This project is licensed under <a rel="license" href="/~https://github.com/sczhou/LEDNet/blob/master/LICENSE">S-Lab License 1.0</a>. Redistribution and use for non-commercial purposes should follow this license.
Expand Down
43 changes: 43 additions & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Configuration for Cog ⚙️
# Reference: /~https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
# set to true if your model requires a GPU
gpu: true

# python version in the form '3.8' or '3.8.12'
python_version: "3.8"

# opencv dependencies
system_packages:
- "ffmpeg"
- "libsm6"
- "libxext6"

python_packages:
- "addict==2.4.0"
- "future==0.18.2"
- "lmdb==1.4.0"
- "lpips==0.1.4"
- "numpy==1.23.5"
- "opencv-python==4.6.0.66"
- "Pillow==9.3.0"
- "pyiqa==0.1.5"
- "PyYAML==6.0"
- "requests==2.28.1"
- "scikit-image==0.19.3"
- "scipy==1.9.3"
- "tb-nightly==2.12.0a20221213"
- "torch==1.13.0"
- "torchvision==0.14.0"
- "tqdm==4.64.1"
- "yapf==0.32.0"

run:
# download models
- "mkdir -p weights"
- "curl -o weights/lednet.pth -L /~https://github.com/sczhou/LEDNet/releases/download/v0.1.0/lednet.pth"
- "curl -o weights/lednet_retrain_500000.pth -L /~https://github.com/sczhou/LEDNet/releases/download/v0.1.0/lednet_retrain_500000.pth"

# predict.py defines how predictions are run on your model
predict: "predict.py:LEDNetPredictor"
96 changes: 96 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Modified by Shangchen Zhou from: /~https://github.com/TencentARC/GFPGAN/blob/master/inference_gfpgan.py
import os

import cv2
import torch
from cog import BasePredictor, Input, Path
from torchvision.transforms.functional import normalize

from basicsr.utils import img2tensor, imwrite, tensor2img
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils.registry import ARCH_REGISTRY
from inference_lednet import check_image_size

pretrain_model_url = {
"lednet": "/~https://github.com/sczhou/LEDNet/releases/download/v0.1.0/lednet.pth",
"lednet_retrain": "/~https://github.com/sczhou/LEDNet/releases/download/v0.1.0/lednet_retrain_500000.pth",
}

POTENTIAL_MODELS = list(pretrain_model_url.keys())
DOWN_FACTOR = 8 # check_image_size
OUT_PATH = "./results"


class LEDNetPredictor(BasePredictor):
"""
Predictor wrapper around LEDNet
"""

def setup(self):
"""
One-time setup method to load and prep model for efficient prediction.
"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.models = {}

for model in POTENTIAL_MODELS:
net = ARCH_REGISTRY.get("LEDNet")(
channels=[32, 64, 128, 128], connection=False
).to(self.device)

ckpt_path = load_file_from_url(
url=pretrain_model_url[model],
model_dir="/weights",
progress=True,
file_name=None,
)
checkpoint = torch.load(ckpt_path, map_location=self.device)["params"]
net.load_state_dict(checkpoint)
net.eval()
self.models[model] = net

def predict(
self,
model: str = Input(
default="lednet",
description="pretrained model to use for inference",
choices=POTENTIAL_MODELS,
),
image: Path = Input(description="Input image"),
) -> Path:
"""
Runs inference with selected model on input image.
"""
net = self.models[model]

img = cv2.imread(str(image), cv2.IMREAD_COLOR)
# prepare data
img_t = img2tensor(img / 255.0, bgr2rgb=True, float32=True)

# without [-1,1] normalization in lednet model (paper version)
if not model == "lednet":
normalize(img_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
img_t = img_t.unsqueeze(0).to(self.device)

# lednet inference
with torch.no_grad():
# check_image_size
H, W = img_t.shape[2:]
img_t = check_image_size(img_t, DOWN_FACTOR)
output_t = net(img_t)
output_t = output_t[:, :, :H, :W]

if model == "lednet":
output = tensor2img(output_t, rgb2bgr=True, min_max=(0, 1))
else:
output = tensor2img(output_t, rgb2bgr=True, min_max=(-1, 1))

del output_t
torch.cuda.empty_cache()

output = output.astype("uint8")
# save restored img
save_restore_path = os.path.join(OUT_PATH, "out.jpg")
imwrite(output, save_restore_path)

return Path(save_restore_path)

0 comments on commit 9fa3c0a

Please sign in to comment.