# test.py

import argparse
import os

import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

from models import RDN
from utils import convert_rgb_to_y, denormalize, calc_psnr
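
# Evaluate a trained RDN model on a single image: the input is resized to a
# multiple of the scale factor, bicubically downscaled to form the LR input,
# then super-resolved and compared against the original via PSNR.
# Example invocation (file names are illustrative):
#   python test.py --weights-file rdn_x4.pth --image-file butterfly.png --scale 4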

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--num-features', type=int, default=64)
    parser.add_argument('--growth-rate', type=int, default=64)
    parser.add_argument('--num-blocks', type=int, default=16)
    parser.add_argument('--num-layers', type=int, default=8)
    parser.add_argument('--scale', type=int, default=4)
    args = parser.parse_args()
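
    # Let cuDNN benchmark convolution algorithms (worthwhile here because the
    # input size is fixed), and run on the first GPU when one is available.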
    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
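
    # Build the network; these hyper-parameters must match the ones the
    # checkpoint was trained with, or the state-dict copy below will fail.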
    model = RDN(scale_factor=args.scale,
                num_channels=3,
                num_features=args.num_features,
                growth_rate=args.growth_rate,
                num_blocks=args.num_blocks,
                num_layers=args.num_layers).to(device)
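
    # Copy the checkpoint tensors into the freshly built model, raising on
    # any parameter name that does not exist in the current architecture.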
    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()
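
    # Resize the input so its dimensions are exact multiples of the scale
    # factor, then synthesise the LR input by bicubic downscaling. A plain
    # bicubic upscale is saved alongside as a baseline for comparison.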
    image = pil_image.open(args.image_file).convert('RGB')

    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale

    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)

    # splitext keeps paths with dots in directory components intact, which a
    # bare str.replace('.', ...) would corrupt.
    base, ext = os.path.splitext(args.image_file)
    bicubic.save('{}_bicubic_x{}{}'.format(base, args.scale, ext))
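
    # HWC uint8 -> NCHW float32 in [0, 1], batched for the network.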
    lr = np.expand_dims(np.array(lr).astype(np.float32).transpose([2, 0, 1]), 0) / 255.0
    hr = np.expand_dims(np.array(hr).astype(np.float32).transpose([2, 0, 1]), 0) / 255.0
    lr = torch.from_numpy(lr).to(device)
    hr = torch.from_numpy(hr).to(device)

    with torch.no_grad():
        preds = model(lr).squeeze(0)
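
    # Standard SR evaluation: compute PSNR on the luminance (Y) channel only,
    # after cropping a `scale`-pixel border to discard boundary artefacts.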
    preds_y = convert_rgb_to_y(denormalize(preds), dim_order='chw')
    hr_y = convert_rgb_to_y(denormalize(hr.squeeze(0)), dim_order='chw')

    preds_y = preds_y[args.scale:-args.scale, args.scale:-args.scale]
    hr_y = hr_y[args.scale:-args.scale, args.scale:-args.scale]

    psnr = calc_psnr(hr_y, preds_y)
    print('PSNR: {:.2f}'.format(psnr))
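
    # Convert the prediction back to an 8-bit HWC image and save it next to
    # the input, using the same suffix convention as the bicubic baseline.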
    output = pil_image.fromarray(denormalize(preds).permute(1, 2, 0).byte().cpu().numpy())
    output.save('{}_rdn_x{}{}'.format(base, args.scale, ext))