trt_calibrator.py
from __future__ import print_function

import os

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # importing this module initializes the CUDA context
__all__ = [
    "TRTPercentileCalibrator",
    "TRTEntropyCalibrator",
    "TRTMinMaxCalibrator",
]
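

# ---------------------------------------------------------------------------
# The calibrators below are duck-typed against a "stream" object that feeds
# calibration batches. The class below is a minimal sketch of the interface
# they assume; it is illustrative only (not part of the original module), and
# the name `ImageBatchStream` is a placeholder. Any object exposing
# `batch_size`, `calibration_data`, `reset()` and `next_batch()` will work.
# ---------------------------------------------------------------------------
import numpy as np  # used only by the illustrative stream below


class ImageBatchStream(object):
    def __init__(self, images, batch_size):
        # `images` is a list of preprocessed CHW float32 arrays of equal shape
        self.images = images
        self.batch_size = batch_size
        # Sized like one full batch; the calibrators use .nbytes to allocate
        # the device buffer each batch is copied into
        self.calibration_data = np.zeros(
            (batch_size,) + tuple(images[0].shape), dtype=np.float32
        )
        self.idx = 0

    def reset(self):
        # Rewind so the calibrator can re-read the dataset
        self.idx = 0

    def next_batch(self):
        if self.idx + self.batch_size > len(self.images):
            # An empty array signals "no more data" to get_batch()
            return np.array([])
        batch = np.ascontiguousarray(
            np.stack(self.images[self.idx : self.idx + self.batch_size])
        ).astype(np.float32)
        self.idx += self.batch_size
        return batch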


class TRTEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, input_layers, stream, cache_file):
        super(TRTEntropyCalibrator, self).__init__()
        self.input_layers = input_layers
        # Data-reading object; acts as the callback that yields preprocessed image batches
        self.stream = stream
        # Allocate GPU memory for one calibration batch
        self.d_input = cuda.mem_alloc(self.stream.calibration_data.nbytes)
        # Path of the calibration cache file
        self.cache_file = cache_file
        # Rewind the calibration dataset
        self.stream.reset()

    def get_batch_size(self):
        return self.stream.batch_size

    def get_batch(self, names):
        try:
            batch = self.stream.next_batch()
            if not batch.size:
                return None
            # Copy the host batch to the device buffer and hand TensorRT its address
            cuda.memcpy_htod(self.d_input, batch)
            return [int(self.d_input)]
        except StopIteration:
            return None

    def read_calibration_cache(self):
        # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        else:
            return None

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)


class TRTMinMaxCalibrator(trt.IInt8MinMaxCalibrator):
    def __init__(self, input_layers, stream, cache_file):
        super(TRTMinMaxCalibrator, self).__init__()
        self.input_layers = input_layers
        # Data-reading object; acts as the callback that yields preprocessed image batches
        self.stream = stream
        # Allocate GPU memory for one calibration batch
        self.d_input = cuda.mem_alloc(self.stream.calibration_data.nbytes)
        # Path of the calibration cache file
        self.cache_file = cache_file
        # Rewind the calibration dataset
        self.stream.reset()

    def get_batch_size(self):
        return self.stream.batch_size

    def get_batch(self, names):
        try:
            batch = self.stream.next_batch()
            if not batch.size:
                return None
            # Copy the host batch to the device buffer and hand TensorRT its address
            cuda.memcpy_htod(self.d_input, batch)
            return [int(self.d_input)]
        except StopIteration:
            return None

    def read_calibration_cache(self):
        # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        else:
            return None

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)


class TRTPercentileCalibrator(trt.IInt8LegacyCalibrator):
    def __init__(
        self, input_layers, stream, cache_file, quantile=0.9995, regression_cutoff=1.0
    ):
        super(TRTPercentileCalibrator, self).__init__()
        self.input_layers = input_layers
        # Data-reading object; acts as the callback that yields preprocessed image batches
        self.stream = stream
        # Allocate GPU memory for one calibration batch
        self.d_input = cuda.mem_alloc(self.stream.calibration_data.nbytes)
        # Path of the calibration cache file
        self.cache_file = cache_file
        # Rewind the calibration dataset
        self.stream.reset()
        # Percentile of the activation histogram used to clamp the dynamic range
        self.quantile = quantile
        self.regression_cutoff = regression_cutoff

    def get_batch_size(self):
        return self.stream.batch_size

    def get_batch(self, names):
        try:
            batch = self.stream.next_batch()
            if not batch.size:
                return None
            # Copy the host batch to the device buffer and hand TensorRT its address
            cuda.memcpy_htod(self.d_input, batch)
            return [int(self.d_input)]
        except StopIteration:
            return None

    def read_calibration_cache(self):
        # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        else:
            return None

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)

    def get_quantile(self):
        return self.quantile

    def get_regression_cutoff(self):
        return self.regression_cutoff

    def read_histogram_cache(self, length):
        # Histogram caching is not used; always recompute
        return None

    def write_histogram_cache(self, ptr, length):
        return None
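

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): wiring one of
# the calibrators into an INT8 engine build with the TensorRT 8.x builder
# config API. `onnx_path` and `calib_stream` are placeholders, and
# "calib.cache" is an arbitrary cache file name.
#
#   logger = trt.Logger(trt.Logger.WARNING)
#   builder = trt.Builder(logger)
#   network = builder.create_network(
#       1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
#   )
#   parser = trt.OnnxParser(network, logger)
#   with open(onnx_path, "rb") as f:
#       parser.parse(f.read())
#
#   config = builder.create_builder_config()
#   config.set_flag(trt.BuilderFlag.INT8)
#   config.int8_calibrator = TRTEntropyCalibrator(
#       input_layers=[network.get_input(0).name],
#       stream=calib_stream,
#       cache_file="calib.cache",
#   )
#   engine_bytes = builder.build_serialized_network(network, config)
# ---------------------------------------------------------------------------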