Skip to content

Commit

Permalink
feat: add model with image embeddings as input
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Mar 12, 2023
1 parent a39d61a commit d79bbc2
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 21 deletions.
26 changes: 21 additions & 5 deletions robotoff/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import pathlib
import sys
from pathlib import Path
Expand All @@ -7,7 +8,12 @@

from robotoff.elasticsearch.client import get_es_client
from robotoff.off import get_barcode_from_url
from robotoff.types import ObjectDetectionModel, PredictionType, WorkerQueue
from robotoff.types import (
NeuralCategoryClassifierModel,
ObjectDetectionModel,
PredictionType,
WorkerQueue,
)

app = typer.Typer()

Expand Down Expand Up @@ -129,27 +135,37 @@ def download_dataset(minify: bool = False) -> None:


@app.command()
def categorize(barcode: str, deepest_only: bool = False) -> None:
def categorize(
barcode: str,
deepest_only: bool = False,
model_name: NeuralCategoryClassifierModel = typer.Option(
NeuralCategoryClassifierModel.keras_2_0, help="name of the model to use"
),
threshold: Optional[float] = typer.Option(0.5, help="detection threshold to use"),
) -> None:
"""Predict product categories based on the neural category classifier.
deepest_only: controls whether the returned predictions should only contain the deepmost
categories for a predicted taxonomy chain.
For example, if we predict 'fresh vegetables' -> 'legumes' -> 'beans' for a product,
setting deepest_only=True will return 'beans'."""
from robotoff.off import get_product
from robotoff.prediction.category.neural.category_classifier import (
CategoryClassifier,
)
from robotoff.products import get_product
from robotoff.taxonomy import TaxonomyType, get_taxonomy
from robotoff.utils import get_logger

get_logger(level=logging.DEBUG)

product = get_product(barcode)
if product is None:
print(f"Product {barcode} not found")
return

predictions, _ = CategoryClassifier(
get_taxonomy(TaxonomyType.category.name)
).predict(product, deepest_only)
get_taxonomy(TaxonomyType.category.name, offline=True)
).predict(product, deepest_only, threshold=threshold, model_name=model_name)

if predictions:
for prediction in predictions:
Expand Down
17 changes: 16 additions & 1 deletion robotoff/prediction/category/neural/category_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

from robotoff.prediction.types import Prediction
from robotoff.taxonomy import Taxonomy
from robotoff.triton import get_triton_inference_stub
from robotoff.types import JSONType, NeuralCategoryClassifierModel, PredictionType
from robotoff.utils import get_logger

from . import keras_category_classifier_2_0, keras_category_classifier_3_0

logger = get_logger(__name__)


class CategoryPrediction:
"""CategoryPrediction stores information about a category classification prediction."""
Expand Down Expand Up @@ -69,6 +73,8 @@ def predict(
:param model_name: the name of the neural model to use to perform
prediction. `keras_2_0` is used by default.
"""
logger.debug("predicting category with model %s", model_name)

if threshold is None:
threshold = 0.5

Expand Down Expand Up @@ -97,8 +103,17 @@ def predict(
else:
# Otherwise we fetch OCR texts from Product Opener
ocr_texts = keras_category_classifier_3_0.fetch_ocr_texts(product)

triton_stub = get_triton_inference_stub()
image_embeddings = keras_category_classifier_3_0.generate_image_embeddings(
product, triton_stub
)
raw_predictions, debug = keras_category_classifier_3_0.predict(
product, ocr_texts, model_name, threshold=threshold
product,
ocr_texts,
model_name,
threshold=threshold,
image_embeddings=image_embeddings,
)

category_predictions = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,65 @@
import numpy as np
from tritonclient.grpc import service_pb2

from robotoff.off import generate_json_ocr_url
from robotoff.off import generate_image_url, generate_json_ocr_url
from robotoff.prediction.ocr.core import get_ocr_result
from robotoff.triton import (
deserialize_byte_tensor,
generate_clip_embedding_request,
get_triton_inference_stub,
serialize_byte_tensor,
)
from robotoff.types import JSONType, NeuralCategoryClassifierModel
from robotoff.utils import http_session
from robotoff.utils import get_image_from_url, get_logger, http_session

from .preprocessing import NUTRIMENT_NAMES, generate_inputs_from_product
from .preprocessing import (
IMAGE_EMBEDDING_DIM,
MAX_IMAGE_EMBEDDING,
NUTRIMENT_NAMES,
generate_inputs_dict,
)

logger = get_logger(__name__)


def generate_image_embeddings(product: JSONType, stub) -> Optional[np.ndarray]:
    """Generate image embeddings using CLIP model for the `MAX_IMAGE_EMBEDDING`
    most recent images.

    Images that cannot be downloaded are skipped: the returned array only
    contains one row per image that was successfully fetched.

    :param product: product data, the `images` and `code` fields are used
    :param stub: the triton inference stub to use
    :return: None if the product has no raw image or if none of them could
        be fetched, otherwise a numpy array of shape
        (num_fetched_images, embedding_dim), where embedding_dim is inferred
        from the size of the model output
    """
    # Raw images are the ones whose ID is purely numeric. Keep the
    # `MAX_IMAGE_EMBEDDING` highest IDs (the most recent uploads).
    image_ids = sorted(
        (int(image_id) for image_id in product.get("images", {}) if image_id.isdigit()),
        reverse=True,
    )[:MAX_IMAGE_EMBEDDING]
    if not image_ids:
        return None

    barcode = product["code"]
    images = [
        get_image_from_url(
            generate_image_url(barcode, f"{image_id}.400"),
            error_raise=False,
            session=http_session,
        )
        for image_id in image_ids
    ]
    # With error_raise=False a failed download yields None instead of raising
    non_null_images = [image for image in images if image is not None]
    if len(images) != len(non_null_images):
        logger.info(
            "%d images could not be fetched (over %d)",
            len(images) - len(non_null_images),
            len(images),
        )

    if not non_null_images:
        # Nothing to embed: behave as if the product had no image rather
        # than sending an empty batch to the inference server
        return None

    request = generate_clip_embedding_request(non_null_images)
    response = stub.ModelInfer(request)
    # Only the successfully fetched images were sent to the model, so the
    # output holds one embedding per non-null image (not per requested image)
    return np.frombuffer(
        response.raw_output_contents[0],
        dtype=np.float32,
    ).reshape((len(non_null_images), -1))


def fetch_ocr_texts(product: JSONType) -> list[str]:
Expand All @@ -39,6 +87,7 @@ def predict(
ocr_texts: list[str],
model_name: NeuralCategoryClassifierModel,
threshold: Optional[float] = None,
image_embeddings: Optional[np.ndarray] = None,
) -> tuple[list[tuple[str, float]], JSONType]:
"""Predict categories using v3 model.
Expand All @@ -53,7 +102,7 @@ def predict(
if threshold is None:
threshold = 0.5

inputs = generate_inputs_from_product(product, ocr_texts)
inputs = generate_inputs_dict(product, ocr_texts, image_embeddings)
debug: JSONType = {
"model_name": model_name.value,
"threshold": threshold,
Expand All @@ -78,24 +127,31 @@ def predict(

# Parameters on how to prepare data for each model type, see `build_triton_request`
model_input_flags: dict[NeuralCategoryClassifierModel, dict] = {
NeuralCategoryClassifierModel.keras_sota_3_0: {},
NeuralCategoryClassifierModel.keras_ingredient_ocr_3_0: {},
NeuralCategoryClassifierModel.keras_image_embeddings_3_0: {},
NeuralCategoryClassifierModel.keras_300_epochs_3_0: {"add_image_embeddings": False},
NeuralCategoryClassifierModel.keras_ingredient_ocr_3_0: {
"add_image_embeddings": False,
},
NeuralCategoryClassifierModel.keras_baseline_3_0: {
"add_ingredients_ocr_tags": False
"add_ingredients_ocr_tags": False,
"add_image_embeddings": False,
},
NeuralCategoryClassifierModel.keras_original_3_0: {
"add_ingredients_ocr_tags": False,
"add_nutriments": False,
"add_image_embeddings": False,
},
NeuralCategoryClassifierModel.keras_product_name_only_3_0: {
"add_ingredients_ocr_tags": False,
"add_nutriments": False,
"add_ingredient_tags": False,
"add_image_embeddings": False,
},
}

triton_model_names = {
NeuralCategoryClassifierModel.keras_sota_3_0: "category-classifier-keras-sota-3.0",
NeuralCategoryClassifierModel.keras_image_embeddings_3_0: "category-classifier-keras-image-embeddings-3.0",
NeuralCategoryClassifierModel.keras_300_epochs_3_0: "category-classifier-keras-300-epochs-3.0",
NeuralCategoryClassifierModel.keras_ingredient_ocr_3_0: "category-classifier-keras-ingredient-ocr-3.0",
NeuralCategoryClassifierModel.keras_baseline_3_0: "category-classifier-keras-baseline-3.0",
NeuralCategoryClassifierModel.keras_original_3_0: "category-classifier-keras-original-3.0",
Expand Down Expand Up @@ -128,11 +184,12 @@ def build_triton_request(
add_ingredient_tags: bool = True,
add_nutriments: bool = True,
add_ingredients_ocr_tags: bool = True,
add_image_embeddings: bool = True,
):
"""Build a Triton ModelInferRequest gRPC request.
:param inputs: the input dict, as generated by
`generate_inputs_from_product`
`generate_inputs_dict`
:param model_name: the name of the model to use, see global variable
`triton_model_names` for possible values
:param add_product_name: if True, add product name as input, defaults to
Expand All @@ -143,6 +200,8 @@ def build_triton_request(
True
:param add_ingredients_ocr_tags: if True, add ingredients extracted from
OCR as input, defaults to True
:param add_image_embeddings: if True, add image embeddings as input,
defaults to True
:return: the gRPC ModelInferRequest
"""
product_name = inputs["product_name"]
Expand Down Expand Up @@ -193,4 +252,23 @@ def build_triton_request(
[serialize_byte_tensor(np.array([ingredients_ocr_tags], dtype=object))]
)

if add_image_embeddings:
image_embeddings_input = service_pb2.ModelInferRequest().InferInputTensor()
image_embeddings_input.name = "image_embeddings"
image_embeddings_input.datatype = "FP32"
image_embeddings_input.shape.extend(
[1, MAX_IMAGE_EMBEDDING, IMAGE_EMBEDDING_DIM]
)
request.inputs.extend([image_embeddings_input])
value = inputs["image_embeddings"]
request.raw_input_contents.extend([value.tobytes()])

image_embeddings_mask_input = service_pb2.ModelInferRequest().InferInputTensor()
image_embeddings_mask_input.name = "image_embeddings_mask"
image_embeddings_mask_input.datatype = "FP32"
image_embeddings_mask_input.shape.extend([1, MAX_IMAGE_EMBEDDING])
request.inputs.extend([image_embeddings_mask_input])
value = inputs["image_embeddings_mask"]
request.raw_input_contents.extend([value.tobytes()])

return request
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import defaultdict
from typing import Optional

import numpy as np
from flashtext import KeywordProcessor

from robotoff import settings
Expand All @@ -25,6 +26,8 @@
"energy_kcal",
"fruits_vegetables_nuts",
)
MAX_IMAGE_EMBEDDING = 10
IMAGE_EMBEDDING_DIM = 512


@functools.cache
Expand All @@ -44,12 +47,18 @@ def get_ingredient_processor():
)


def generate_inputs_from_product(product: JSONType, ocr_texts: list[str]) -> JSONType:
def generate_inputs_dict(
product: JSONType,
ocr_texts: list[str],
image_embeddings: Optional[np.ndarray] = None,
) -> JSONType:
"""Generate inputs for v3 category predictor model.
:param product: the product dict, the `product_name` and `ingredients`
fields are used, if provided
:param ocr_texts: a list of detected OCR texts, one per image
:param image_embeddings: embeddings generated by CLIP model of up to
the `MAX_IMAGE_EMBEDDING` most recent images.
:return: a dict containing inputs for v3 category predictor model
"""
ingredient_taxonomy = get_ingredient_taxonomy()
Expand All @@ -73,6 +82,41 @@ def generate_inputs_from_product(product: JSONType, ocr_texts: list[str]) -> JSO
nutriments.get(f"{nutriment_name.replace('_', '-')}_100g"),
nutriment_name=nutriment_name,
)

if image_embeddings is None:
# No image is available, so we provide zero-filled image embedding
# and embedding mask with a single non-zero element
# The GlobalAveragePooling1d that follows multi-head attention
# requires at least one non-masked step
image_embeddings = np.zeros(
(MAX_IMAGE_EMBEDDING, IMAGE_EMBEDDING_DIM), dtype=np.float32
)
num_images = 1
else:
if len(image_embeddings) < MAX_IMAGE_EMBEDDING:
num_images = len(image_embeddings)
# Fill padded positions with zero vectors
image_embeddings = np.concatenate(
[
image_embeddings,
np.zeros(
(
MAX_IMAGE_EMBEDDING - num_images,
IMAGE_EMBEDDING_DIM,
),
dtype=np.float32,
),
]
)
else:
num_images = MAX_IMAGE_EMBEDDING
image_embeddings = image_embeddings[:MAX_IMAGE_EMBEDDING]

inputs["image_embeddings"] = image_embeddings
# mask for multi-head attention and average pooling
inputs["image_embeddings_mask"] = np.array(
[1] * num_images + [0] * (MAX_IMAGE_EMBEDDING - num_images), dtype=np.float32
)
return inputs


Expand Down
2 changes: 1 addition & 1 deletion robotoff/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


@cachetools.cached(cachetools.Cache(maxsize=1))
def get_triton_inference_stub():
def get_triton_inference_stub() -> service_pb2_grpc.GRPCInferenceServiceStub:
channel = grpc.insecure_channel(settings.TRITON_URI)
return service_pb2_grpc.GRPCInferenceServiceStub(channel)

Expand Down
3 changes: 2 additions & 1 deletion robotoff/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class ObjectDetectionModel(enum.Enum):
@enum.unique
class NeuralCategoryClassifierModel(enum.Enum):
keras_2_0 = "keras-2.0"
keras_sota_3_0 = "keras-sota-3-0"
keras_image_embeddings_3_0 = "keras-image-embeddings-3.0"
keras_300_epochs_3_0 = "keras-300-epochs-3-0"
keras_ingredient_ocr_3_0 = "keras-ingredient-ocr-3.0"
keras_baseline_3_0 = "keras-baseline-3.0"
keras_original_3_0 = "keras-original-3.0"
Expand Down
7 changes: 4 additions & 3 deletions tests/unit/cli/test_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from robotoff.cli.main import categorize, init_elasticsearch
from robotoff.types import NeuralCategoryClassifierModel


def test_init_elasticsearch(mocker):
Expand Down Expand Up @@ -55,7 +56,7 @@ def test_categorize_no_product(mocker, capsys):
)
def test_categorize(mocker, capsys, confidence, want_nothing):
mocker.patch(
"robotoff.products.get_product",
"robotoff.off.get_product",
return_value={
"product_name": "Test Product",
"ingredients_tags": ["ingredient1"],
Expand All @@ -66,7 +67,7 @@ def test_categorize(mocker, capsys, confidence, want_nothing):
return_value=_construct_prediction_resp("en:chicken", confidence),
)

categorize("123")
categorize("123", threshold=0.5, model_name=NeuralCategoryClassifierModel.keras_2_0)
captured = capsys.readouterr()

assert captured.out.startswith("Nothing predicted") == want_nothing
assert ("Nothing predicted" in captured.out) == want_nothing

0 comments on commit d79bbc2

Please sign in to comment.