-
Notifications
You must be signed in to change notification settings - Fork 123
/
Copy pathinference.py
118 lines (94 loc) · 3.92 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from __future__ import print_function
import argparse
import os
import sys
import time
import tensorflow as tf
import numpy as np
from scipy import misc
from model import PSPNet101, PSPNet50
from tools import *
# Per-dataset inference configuration: network input crop size (H, W),
# number of semantic classes, and which PSPNet backbone variant to build.
ADE20k_param = {'crop_size': [473, 473],
                'num_classes': 150,
                'model': PSPNet50}
cityscapes_param = {'crop_size': [720, 720],
                    'num_classes': 19,
                    'model': PSPNet101}

# Default directory where the colorized prediction image is written.
SAVE_DIR = './output/'
# Default directory searched for trained checkpoint files.
SNAPSHOT_DIR = './model/'
def get_arguments():
    """Define and parse the command-line interface of this script.

    Returns:
        argparse.Namespace with img_path, checkpoints, save_dir,
        flipped_eval and dataset attributes.
    """
    cli = argparse.ArgumentParser(description="Reproduced PSPNet")
    cli.add_argument(
        "--img-path", type=str, default='',
        help="Path to the RGB image file.")
    cli.add_argument(
        "--checkpoints", type=str, default=SNAPSHOT_DIR,
        help="Path to restore weights.")
    cli.add_argument(
        "--save-dir", type=str, default=SAVE_DIR,
        help="Path to save output.")
    cli.add_argument(
        "--flipped-eval", action="store_true",
        help="whether to evaluate with flipped img.")
    cli.add_argument(
        "--dataset", type=str, default='',
        choices=['ade20k', 'cityscapes'],
        required=True)
    return cli.parse_args()
def save(saver, sess, logdir, step):
    """Write a checkpoint for the current session state.

    Creates *logdir* first if it does not exist yet, then saves under
    the fixed file name ``model.ckpt`` suffixed with *step*.
    """
    checkpoint_path = os.path.join(logdir, 'model.ckpt')
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    saver.save(sess, checkpoint_path, global_step=step)
    print('The checkpoint has been created.')
def load(saver, sess, ckpt_path):
    """Restore model weights from the checkpoint at *ckpt_path* into *sess*."""
    saver.restore(sess, ckpt_path)
    print("Restored model parameters from {}".format(ckpt_path))
def main():
    """Run single-image PSPNet inference and save the colorized prediction.

    Builds the network for one image (optionally averaging logits with a
    horizontally flipped pass), restores weights from the newest checkpoint
    in ``--checkpoints``, and writes the decoded label map into
    ``--save-dir`` under the input image's file name.
    """
    args = get_arguments()

    # Select dataset-specific parameters; argparse `choices` + `required`
    # guarantee args.dataset is exactly one of these two values.
    if args.dataset == 'ade20k':
        param = ADE20k_param
    elif args.dataset == 'cityscapes':
        param = cityscapes_param
    crop_size = param['crop_size']
    num_classes = param['num_classes']
    PSPNet = param['model']

    # Load the image and pad it so it is at least as large as the crop size.
    img, filename = load_img(args.img_path)
    img_shape = tf.shape(img)
    h, w = (tf.maximum(crop_size[0], img_shape[0]),
            tf.maximum(crop_size[1], img_shape[1]))
    img = preprocess(img, h, w)

    # Create network. The flipped branch reuses the same variables.
    net = PSPNet({'data': img}, is_training=False, num_classes=num_classes)
    with tf.variable_scope('', reuse=True):
        flipped_img = tf.image.flip_left_right(tf.squeeze(img))
        flipped_img = tf.expand_dims(flipped_img, dim=0)
        net2 = PSPNet({'data': flipped_img}, is_training=False, num_classes=num_classes)

    raw_output = net.layers['conv6']

    # Optionally sum logits with the mirrored pass (flipping the mirrored
    # logits back first) for a small accuracy boost.
    if args.flipped_eval:
        flipped_output = tf.image.flip_left_right(tf.squeeze(net2.layers['conv6']))
        flipped_output = tf.expand_dims(flipped_output, dim=0)
        raw_output = tf.add_n([raw_output, flipped_output])

    # Predictions: upsample logits to the padded size, crop back to the
    # original image extent, then take the per-pixel argmax.
    raw_output_up = tf.image.resize_bilinear(raw_output, size=[h, w], align_corners=True)
    raw_output_up = tf.image.crop_to_bounding_box(raw_output_up, 0, 0, img_shape[0], img_shape[1])
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    pred = decode_labels(raw_output_up, img_shape, num_classes)

    # Init tf Session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Restore trained weights. Without them inference would run on randomly
    # initialized variables and save garbage, so abort if none are found.
    ckpt = tf.train.get_checkpoint_state(args.checkpoints)
    if ckpt and ckpt.model_checkpoint_path:
        loader = tf.train.Saver(var_list=tf.global_variables())
        load(loader, sess, ckpt.model_checkpoint_path)
    else:
        print('No checkpoint file found.')
        sys.exit(1)

    preds = sess.run(pred)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    # os.path.join also handles a --save-dir given without a trailing slash,
    # which the original string concatenation silently mangled.
    misc.imsave(os.path.join(args.save_dir, filename), preds[0])


if __name__ == '__main__':
    main()