-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtesting.py
84 lines (67 loc) · 2.86 KB
/
testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np
import os, sys
from tqdm import tqdm
from options.options import parse
import argparse
# Command-line interface: a single option selecting the YAML file that
# describes this evaluation run.
parser = argparse.ArgumentParser(description="Script for testing")
parser.add_argument(
    '-p', '--config',
    type=str,
    default='./options/test/LOLBlur.yml',
    help='Config file of testing',
)
args = parser.parse_args()

# The network, dataset and checkpoint settings used below all come from the
# options file (opt['network'], opt['datasets'], opt['save']); edit that file
# rather than this script to change what is evaluated.
path_options = args.config
opt = parse(path_options)

# Expose only GPU 0. This has to be set before CUDA is initialised, which is
# why it happens here, ahead of the torch import that follows.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# PyTorch library
import torch
import torch.optim
import torch.multiprocessing as mp
import torch.distributed as dist
from data.dataset_reader.datapipeline import *
from archs import *
from losses import *
from data import *
from utils.utils import create_path_models
from utils.test_utils import *
from ptflops import get_model_complexity_info
# Parameters for saving the model.
# NOTE(review): PATH_MODEL is not referenced anywhere else in the visible code;
# the call appears to be kept only for whatever side effect
# create_path_models(opt['save']) has (possibly creating output directories).
# Confirm before removing it.
PATH_MODEL= create_path_models(opt['save'])
def _has_wrapper_prefix(keys, prefix='module.'):
    """Return True if ``keys`` is non-empty and every key starts with ``prefix``."""
    keys = list(keys)
    return bool(keys) and all(key.startswith(prefix) for key in keys)


def load_model(model, path_weights, complexity_input=(3, 256, 256)):
    """Load the ``'params'`` state dict stored at ``path_weights`` into ``model``.

    The checkpoint is read onto the CPU. ``load_state_dict`` then copies each
    tensor into the matching parameter/buffer of ``model``. Checkpoint keys are
    re-prefixed so that weights saved with or without the ``'module.'`` prefix
    (the one added by wrappers such as DistributedDataParallel) match the keys
    of ``model``, whether or not ``model`` itself is wrapped.

    Args:
        model: the ``torch.nn.Module`` to fill (possibly wrapped).
        path_weights: path to a checkpoint whose top level has a ``'params'``
            state dict.
        complexity_input: input shape handed to ptflops'
            ``get_model_complexity_info`` to print MACs / parameter counts
            before loading. Pass ``None`` to skip this profiling step, which
            runs an extra forward pass.

    Returns:
        The same ``model`` object, with the checkpoint weights loaded.

    Raises:
        KeyError: if the checkpoint has no ``'params'`` entry.
        RuntimeError: from ``load_state_dict`` if keys or shapes still differ.
    """
    prefix = 'module.'
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code. Only load checkpoints from trusted sources.
    checkpoints = torch.load(path_weights, map_location='cpu', weights_only=False)
    state = checkpoints['params']

    # Reduce the checkpoint to bare parameter names, then add the wrapper prefix
    # back only if the model's own keys carry it. Blindly prepending 'module.'
    # would produce 'module.module.*' for checkpoints saved from a wrapped model
    # and would never match an unwrapped model.
    if _has_wrapper_prefix(state, prefix):
        state = {key[len(prefix):]: value for key, value in state.items()}
    if _has_wrapper_prefix(model.state_dict(), prefix):
        state = {prefix + key: value for key, value in state.items()}

    if complexity_input is not None:
        macs, params = get_model_complexity_info(
            model, tuple(complexity_input), print_per_layer_stat=False, verbose=False
        )
        print('Network complexity: ', macs, params)

    model.load_state_dict(state)
    print('Loaded weights correctly')
    return model
def _print_metrics(metrics_eval, datasets_opt):
    """Print PSNR/SSIM/LPIPS per dataset (nested dict) or for one dataset (flat dict)."""
    if not metrics_eval:
        # Peeking at the first value of an empty dict would raise StopIteration.
        print('No evaluation metrics were returned')
        return
    first_value = next(iter(metrics_eval.values()))
    if isinstance(first_value, dict):
        # One entry per test set: {name: {'valid_psnr': ..., ...}, ...}
        for key, metric_eval in metrics_eval.items():
            print(f" \t {key} --- PSNR: {metric_eval['valid_psnr']}, SSIM: {metric_eval['valid_ssim']}, LPIPS: {metric_eval['valid_lpips']}")
    else:
        # A single test set; its display name comes from the options file.
        print(f" \t {datasets_opt['name']} --- PSNR: {metrics_eval['valid_psnr']}, SSIM: {metrics_eval['valid_ssim']}, LPIPS: {metrics_eval['valid_lpips']}")


def run_evaluation(rank, world_size):
    """Evaluate the configured model on the test data for process ``rank``.

    Meant to be the target of ``mp.spawn``, which supplies ``rank``. Builds the
    test loader and network from the module-level ``opt``, loads the checkpoint
    at ``opt['save']['path']``, runs ``eval_model`` and, on rank 0 only, prints
    the PSNR/SSIM/LPIPS results.

    ``cleanup()`` runs in a ``finally`` block so the teardown is not skipped
    when loading or evaluation raises.
    NOTE(review): setup/cleanup come from a star import; presumably they
    initialise/destroy the torch.distributed process group. Confirm.
    """
    setup(rank, world_size=world_size)
    try:
        # Data loaders
        test_loader, _ = create_test_data(rank, world_size=world_size, opt=opt['datasets'])
        # Network and trained weights
        model, _, _ = create_model(opt['network'], rank=rank)
        model = load_model(model, opt['save']['path'])

        # Make sure every process has finished loading before evaluating.
        dist.barrier()
        model.eval()
        metrics_eval, _ = eval_model(model, test_loader, {}, rank=rank, world_size=world_size, eta=True)
        # Make sure every process has finished evaluating before reporting.
        dist.barrier()

        if rank == 0:
            _print_metrics(metrics_eval, opt['datasets'])
    finally:
        cleanup()
def main():
    """Run ``run_evaluation`` in worker processes and wait for them to finish.

    A single process is used. ``mp.spawn`` calls each worker with its rank
    first, followed by the entries of ``args`` (here the process count).
    """
    n_procs = 1
    mp.spawn(run_evaluation, args=(n_procs,), nprocs=n_procs, join=True)


# Only launch evaluation when executed as a script, not on import.
if __name__ == '__main__':
    main()