Skip to content

Commit

Permalink
fix: don't keep numpy ndarray in debug.inputs dict
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Mar 12, 2023
1 parent c925558 commit 64e7124
Showing 1 changed file with 35 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def generate_image_embeddings(product: JSONType, stub) -> Optional[np.ndarray]:
return np.frombuffer(
response.raw_output_contents[0],
dtype=np.float32,
).reshape((len(images), -1))
).reshape((len(non_null_images), -1))
return None


Expand Down Expand Up @@ -103,11 +103,7 @@ def predict(
threshold = 0.5

inputs = generate_inputs_dict(product, ocr_texts, image_embeddings)
debug: JSONType = {
"model_name": model_name.value,
"threshold": threshold,
"inputs": inputs,
}
debug = generate_debug_dict(model_name, threshold, inputs)
scores, labels = _predict(inputs, model_name)
indices = np.argsort(-scores)

Expand All @@ -125,6 +121,39 @@ def predict(
return category_predictions, debug


def generate_debug_dict(
model_name: NeuralCategoryClassifierModel, threshold: float, inputs: JSONType
) -> JSONType:
"""Generate dict containing debug information.
:param model_name: name of the model used during prediction
:param threshold: detection threshold used
:param inputs: inputs dict used for inference
:return: the debug dict
"""
debug = {
"model_name": model_name.value,
"threshold": threshold,
"inputs": {
k: v
for k, v in inputs.items()
# Don't keep numpy ndarray in debug.inputs dict
if k not in ("image_embeddings", "image_embeddings_mask")
},
}

if inputs["image_embeddings_mask"].sum() == 1:
# `image_embeddings_mask` always has at least one non-zero element,
# check whether there is an image or not by checking if
# `image_embeddings` is zero-filled
num_images = 0 if np.all(inputs["image_embeddings"][0] == 0) else 1
else:
num_images = int(inputs["image_embeddings_mask"].sum())

debug["inputs"]["num_images"] = num_images # type: ignore
return debug


# Parameters on how to prepare data for each model type, see `build_triton_request`
model_input_flags: dict[NeuralCategoryClassifierModel, dict] = {
NeuralCategoryClassifierModel.keras_image_embeddings_3_0: {},
Expand Down

0 comments on commit 64e7124

Please sign in to comment.