From b4b583ce4c61917ee48f1f1d6fdf38948627d267 Mon Sep 17 00:00:00 2001 From: Dan Date: Wed, 14 Dec 2022 16:30:49 -0800 Subject: [PATCH] cog predict.py + readme modification --- README.md | 12 ++++++- cog.yaml | 43 ++++++++++++++++++++++++ predict.py | 96 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 1 deletion(-) create mode 100644 cog.yaml create mode 100644 predict.py diff --git a/README.md b/README.md index c90d3a6..3c416ef 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ ## LEDNet: Joint Low-light Enhancement and Deblurring in the Dark (ECCV 2022) -[Paper](https://arxiv.org/abs/2202.03373) | [Project Page](https://shangchenzhou.com/projects/LEDNet/) | [Video](https://youtu.be/450dkE-fOMY) +[Paper](https://arxiv.org/abs/2202.03373) | [Project Page](https://shangchenzhou.com/projects/LEDNet/) | [Video](https://youtu.be/450dkE-fOMY) | [Replicate Demo](https://replicate.com/sczhou/lednet) [Shangchen Zhou](https://shangchenzhou.com/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/) @@ -126,6 +126,16 @@ Run low-light generation: python scripts/generate_low_light_imgs.py --test_path 'IMG_ROOT' --result_path 'RESULT_ROOT' --model_path './weights/ce_zerodce.pth' ``` +### Inference with Cog +To run containerized local inference with LEDNet using [Cog](/~https://github.com/replicate/cog), run the following commands in the project root: + +``` +cog run python basicsr/setup.py develop +cog predict -i image=@'path/to/input_image.jpg' +``` + +You can view this demo running as an API [here on Replicate](https://replicate.com/sczhou/lednet). + ### License This project is licensed under S-Lab License 1.0. Redistribution and use for non-commercial purposes should follow this license. diff --git a/cog.yaml b/cog.yaml new file mode 100644 index 0000000..8fc5d4e --- /dev/null +++ b/cog.yaml @@ -0,0 +1,43 @@ +# Configuration for Cog ⚙️ +# Reference: /~https://github.com/replicate/cog/blob/main/docs/yaml.md + +build: + # set to true if your model requires a GPU + gpu: true + + # python version in the form '3.8' or '3.8.12' + python_version: "3.8" + + # opencv dependencies + system_packages: + - "ffmpeg" + - "libsm6" + - "libxext6" + + python_packages: + - "addict==2.4.0" + - "future==0.18.2" + - "lmdb==1.4.0" + - "lpips==0.1.4" + - "numpy==1.23.5" + - "opencv-python==4.6.0.66" + - "Pillow==9.3.0" + - "pyiqa==0.1.5" + - "PyYAML==6.0" + - "requests==2.28.1" + - "scikit-image==0.19.3" + - "scipy==1.9.3" + - "tb-nightly==2.12.0a20221213" + - "torch==1.13.0" + - "torchvision==0.14.0" + - "tqdm==4.64.1" + - "yapf==0.32.0" + + run: + # download models + - "mkdir -p weights" + - "curl -o weights/lednet.pth -L /~https://github.com/sczhou/LEDNet/releases/download/v0.1.0/lednet.pth" + - "curl -o weights/lednet_retrain_500000.pth -L /~https://github.com/sczhou/LEDNet/releases/download/v0.1.0/lednet_retrain_500000.pth" + +# predict.py defines how predictions are run on your model +predict: "predict.py:LEDNetPredictor" diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..8dd9359 --- /dev/null +++ b/predict.py @@ -0,0 +1,96 @@ +# Modified by Shangchen Zhou from: /~https://github.com/TencentARC/GFPGAN/blob/master/inference_gfpgan.py +import os + +import cv2 +import torch +from cog import BasePredictor, Input, Path +from torchvision.transforms.functional import normalize + +from basicsr.utils import img2tensor, imwrite, tensor2img +from basicsr.utils.download_util import load_file_from_url +from basicsr.utils.registry import ARCH_REGISTRY +from inference_lednet import check_image_size + +pretrain_model_url = { + "lednet": "/~https://github.com/sczhou/LEDNet/releases/download/v0.1.0/lednet.pth", + "lednet_retrain": "/~https://github.com/sczhou/LEDNet/releases/download/v0.1.0/lednet_retrain_500000.pth", +} + +POTENTIAL_MODELS = list(pretrain_model_url.keys()) +DOWN_FACTOR = 8 # check_image_size +OUT_PATH = "./results" + + +class LEDNetPredictor(BasePredictor): + """ + Predictor wrapper around LEDNet + """ + + def setup(self): + """ + One-time setup method to load and prep model for efficient prediction. + """ + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.models = {} + + for model in POTENTIAL_MODELS: + net = ARCH_REGISTRY.get("LEDNet")( + channels=[32, 64, 128, 128], connection=False + ).to(self.device) + + ckpt_path = load_file_from_url( + url=pretrain_model_url[model], + model_dir="/weights", + progress=True, + file_name=None, + ) + checkpoint = torch.load(ckpt_path, map_location=self.device)["params"] + net.load_state_dict(checkpoint) + net.eval() + self.models[model] = net + + def predict( + self, + model: str = Input( + default="lednet", + description="pretrained model to use for inference", + choices=POTENTIAL_MODELS, + ), + image: Path = Input(description="Input image"), + ) -> Path: + """ + Runs inference with selected model on input image. + """ + net = self.models[model] + + img = cv2.imread(str(image), cv2.IMREAD_COLOR) + # prepare data + img_t = img2tensor(img / 255.0, bgr2rgb=True, float32=True) + + # without [-1,1] normalization in lednet model (paper version) + if not model == "lednet": + normalize(img_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + img_t = img_t.unsqueeze(0).to(self.device) + + # lednet inference + with torch.no_grad(): + # check_image_size + H, W = img_t.shape[2:] + img_t = check_image_size(img_t, DOWN_FACTOR) + output_t = net(img_t) + output_t = output_t[:, :, :H, :W] + + if model == "lednet": + output = tensor2img(output_t, rgb2bgr=True, min_max=(0, 1)) + else: + output = tensor2img(output_t, rgb2bgr=True, min_max=(-1, 1)) + + del output_t + torch.cuda.empty_cache() + + output = output.astype("uint8") + # save restored img + save_restore_path = os.path.join(OUT_PATH, "out.jpg") + imwrite(output, save_restore_path) + + return Path(save_restore_path)