# torchserve_custom_handler.py
import abc
import base64
import importlib.util
import inspect
import io
import logging
import os
import time

import numpy as np
import torch
from PIL import Image

from pre_post_processing import build_inference_transform, post_process_handle

logger = logging.getLogger(__name__)


def list_classes_from_module(module, parent_class=None):
    """
    Parse a user-defined module to get all model service classes defined in it.

    :param module: the imported module to inspect
    :param parent_class: optional base class used to filter the results
    :return: list of model service class definitions
    """
    # Collect all classes defined directly in the module (not imported into it)
    classes = [
        cls[1]
        for cls in inspect.getmembers(
            module,
            lambda member: inspect.isclass(member) and member.__module__ == module.__name__,
        )
    ]
    # Keep only subclasses of parent_class, if one was given
    if parent_class is not None:
        return [c for c in classes if issubclass(c, parent_class)]
    return classes


class CustomHandler(abc.ABC):
    """
    Base default handler that loads either a TorchScript model or an eager-mode
    (state_dict) model.

    It also provides a ``handle`` method that follows the TorchServe custom
    handler specification.
    """

    def __init__(self):
        self.model = None
        self.mapping = None
        self.device = None
        self.initialized = False
        self.context = None
        self.manifest = None
        self.map_location = None
        self.explain = False
        self.target = 0

    def initialize(self, context):
        """Load the serialized model file and initialize the model object.

        If a model definition file is listed in the manifest, the model is loaded in
        eager mode from its state_dict; otherwise a TorchScript model is loaded.

        Args:
            context (context): A JSON object containing information
                pertaining to the model artifact parameters.

        Raises:
            RuntimeError: Raised when the model.pt file is missing.
        """
        properties = context.system_properties
        self.map_location = "cuda" if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu"
        self.device = torch.device(
            self.map_location + ":" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else self.map_location
        )
        self.manifest = context.manifest

        model_dir = properties.get("model_dir")
        model_pt_path = None
        if "serializedFile" in self.manifest["model"]:
            serialized_file = self.manifest["model"]["serializedFile"]
            model_pt_path = os.path.join(model_dir, serialized_file)

        # Model definition file (present only for eager-mode models)
        model_file = self.manifest["model"].get("modelFile", "")

        if model_file:
            logger.debug("Loading eager model")
            self.model = self._load_pickled_model(model_dir, model_file, model_pt_path)
            self.model.to(self.device)
        else:
            logger.debug("Loading torchscript model")
            if not os.path.isfile(model_pt_path):
                raise RuntimeError("Missing the model.pt file")
            self.model = self._load_torchscript_model(model_pt_path)

        self.model.eval()
        logger.debug('Model file %s loaded successfully', model_pt_path)
        self.initialized = True

        self.image_processing = build_inference_transform(self.manifest['base_model'], size=224)

    def _load_torchscript_model(self, model_pt_path):
        """Load a TorchScript model and return the model object.

        Args:
            model_pt_path (str): path of the serialized model file.

        Returns:
            (NN Model Object): the loaded model object.
        """
        return torch.jit.load(model_pt_path, map_location=self.device)

    def _load_pickled_model(self, model_dir, model_file, model_pt_path):
        """
        Load an eager-mode model from the model definition file and, optionally, a state_dict.

        Args:
            model_dir (str): location of the model artifacts.
            model_file (str): name of the .py file which contains the model class.
            model_pt_path (str): location of the serialized state_dict file.

        Raises:
            RuntimeError: raised when the model.py file is missing.
            ValueError: raised when the model definition file contains more than one class,
                since the handler cannot tell which one to instantiate.

        Returns:
            nn.Module: the instantiated PyTorch model.
        """
        model_def_path = os.path.join(model_dir, model_file)
        if not os.path.isfile(model_def_path):
            raise RuntimeError("Missing the model.py file")

        module = importlib.import_module(model_file.split(".")[0])
        model_class_definitions = list_classes_from_module(module)
        if len(model_class_definitions) != 1:
            raise ValueError(
                "Expected only one class as model definition. {}".format(
                    model_class_definitions
                )
            )

        model_class = model_class_definitions[0]
        model = model_class()
        if model_pt_path:
            state_dict = torch.load(model_pt_path, map_location=self.device)
            model.load_state_dict(state_dict)
        return model
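
    # A minimal sketch (illustrative only, not part of this handler) of the kind of
    # model definition file the eager-mode path above expects: exactly one nn.Module
    # subclass defined in model.py. The name ``ImageClassifier`` and its layers are
    # hypothetical placeholders.
    #
    #     import torch.nn as nn
    #
    #     class ImageClassifier(nn.Module):
    #         def __init__(self, num_classes=2):
    #             super().__init__()
    #             self.backbone = nn.Sequential(
    #                 nn.Conv2d(3, 16, kernel_size=3, padding=1),
    #                 nn.ReLU(),
    #                 nn.AdaptiveAvgPool2d(1),
    #             )
    #             self.head = nn.Linear(16, num_classes)
    #
    #         def forward(self, x):
    #             return self.head(self.backbone(x).flatten(1))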

    def preprocess(self, data):
        """Convert the raw request data into a batched float tensor.

        Args:
            data (list): input data from the request, one entry per request in the batch.

        Returns:
            torch.Tensor: a batch of preprocessed images stacked on the target device.
        """
        images = []
        for row in data:
            # Compat layer: normally the envelope should just return the data
            # directly, but older versions of TorchServe didn't have envelopes.
            image = row.get("data") or row.get("body")

            if isinstance(image, str):
                # The image arrived as a base64-encoded string.
                image = base64.b64decode(image)

            if isinstance(image, (bytearray, bytes)):
                # The image arrived as raw bytes.
                image = Image.open(io.BytesIO(image))
                image = self.image_processing(image=np.array(image))['image']
            else:
                # The image arrived as a (nested) list of floats.
                image = torch.FloatTensor(image)

            images.append(image)

        return torch.stack(images).to(self.device)
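
    # Illustrative examples (not part of the handler) of payloads the preprocess
    # method above can accept, assuming the model is served under the hypothetical
    # name "my_model". The JSON form relies on TorchServe passing the "data" field
    # through as a base64 string.
    #
    #     # raw image bytes
    #     curl http://localhost:8080/predictions/my_model -T dog.jpg
    #
    #     # base64-encoded image inside a JSON body
    #     curl -X POST http://localhost:8080/predictions/my_model \
    #          -H "Content-Type: application/json" \
    #          -d '{"data": "<base64-encoded image>"}'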

    def inference(self, data, *args, **kwargs):
        """
        Make a prediction call on the given input request.

        Override this method to customize the inference behaviour.

        Args:
            data (torch.Tensor): input tensor passed to the model;
                its shape should match the model input shape.

        Returns:
            torch.Tensor: the predicted output tensor.
        """
        marshalled_data = data.to(self.device)
        with torch.no_grad():
            results = self.model(marshalled_data, *args, **kwargs)
        return results

    def postprocess(self, data):
        """
        Convert the inference output into a TorchServe-supported response.

        Args:
            data (torch.Tensor): the tensor produced by the model's prediction.

        Returns:
            list: the post-processed predicted output.
        """
        return post_process_handle(data)

    def handle(self, data, context):
        """Entry point for the default handler. Takes the data from the input request
        and returns the predicted outcome.

        Args:
            data (list): the input data on which a prediction request is made.
            context (Context): a JSON object containing information pertaining to
                the model artifact parameters.

        Returns:
            list: a list of dictionaries with the predicted response.
        """
        # The context is stored so that pre- or post-processing can access additional
        # request information if needed.
        start_time = time.time()
        self.context = context
        metrics = self.context.metrics

        data_preprocess = self.preprocess(data)
        output = self.inference(data_preprocess)
        output = self.postprocess(output)

        stop_time = time.time()
        metrics.add_time('HandlerTime', round((stop_time - start_time) * 1000, 2), None, 'ms')
        return output
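

# A minimal sketch (assumptions, not part of this file) of how a handler like this is
# typically packaged and served with TorchServe. The model name, file names, and the
# extra pre_post_processing.py module are placeholders for this project's artifacts.
#
#     torch-model-archiver --model-name my_model \
#         --version 1.0 \
#         --serialized-file model.pt \
#         --handler torchserve_custom_handler.py \
#         --extra-files pre_post_processing.py \
#         --export-path model_store
#
#     torchserve --start --model-store model_store --models my_model=my_model.mar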