tensorrt_optimize_tf1.15.py
# coding=utf-8
"""Given the tensorflow frozen graph, use TensorRT to optimize,
get a new frozen graph."""
from __future__ import print_function
import argparse
import time
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt
parser = argparse.ArgumentParser()
parser.add_argument("pbfile")
parser.add_argument("newpbfile")
parser.add_argument("--precision_mode", default="FP32",
help="FP32, FP16, or INT8")
# parameter
# https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html
if __name__ == "__main__":
    args = parser.parse_args()

    # Additional converter parameters are left at their defaults; see the
    # TF-TRT user guide linked above for what they control.
    # max_batch_size = 1
    # minimum_segment_size = 2
    # max_workspace_size_bytes = 1 << 32
    # maximum_cached_engines = 1

    # Output nodes of the detection graph; they are excluded from TensorRT
    # segment conversion so they keep their original names.
    output_names = [
        "final_boxes",
        "final_labels",
        "final_probs",
        "fpn_box_feat",
    ]

    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True

    with tf.Graph().as_default() as tf_graph:
        with tf.Session(config=tf_config) as tf_sess:
            # Load the original frozen graph.
            with tf.gfile.GFile(args.pbfile, "rb") as f:
                frozen_graph = tf.GraphDef()
                frozen_graph.ParseFromString(f.read())

            converter = trt.TrtGraphConverter(
                input_graph_def=frozen_graph,
                nodes_blacklist=output_names,  # output nodes
                is_dynamic_op=False,
                precision_mode=args.precision_mode)
            trt_graph = converter.convert()

            # converter.save(args.newpbfile)
            with open(args.newpbfile, "wb") as f:
                f.write(trt_graph.SerializeToString())
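
For reference, a minimal sketch of loading the optimized graph back and running it under TF 1.15. This is not part of the repository: the input placeholder name "image:0" and the dummy image shape are assumptions that must match the actual detection model, while the output tensor names follow output_names above.

# load_trt_graph.py -- hypothetical usage sketch, not part of the original script.
import numpy as np
import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile("model_trt.pb", "rb") as f:  # the file written by the script above
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(graph=graph, config=config) as sess:
        # "image:0" is an assumed input placeholder name; inspect the real
        # model's input node (e.g. via graph.get_operations()) before use.
        dummy_image = np.zeros((1, 800, 1333, 3), dtype=np.float32)
        boxes, labels, probs = sess.run(
            ["final_boxes:0", "final_labels:0", "final_probs:0"],
            feed_dict={"image:0": dummy_image})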