Skip to content

Commit

Permalink
fix: add new CLI command to rerun image import for all images (#1482)
Browse files: browse the repository at this point in the history
  • Loading branch information
raphael0202 authored Dec 4, 2024
1 parent fb84223 commit 98a1374
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 53 deletions.
36 changes: 36 additions & 0 deletions robotoff/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import typer

from robotoff.types import (
ImportImageFlag,
ObjectDetectionModel,
PredictionType,
ProductIdentifier,
Expand Down Expand Up @@ -617,6 +618,41 @@ def run_object_detection_model(
)


@app.command()
def rerun_import_all_images(
    server_type: Optional[ServerType] = typer.Option(
        None, help="Server type of the product"
    ),
    limit: Optional[int] = typer.Option(
        None, help="the maximum number of images to process, defaults to None (all)"
    ),
    flags: list[ImportImageFlag] = typer.Option(
        None, help="Flags to use for image import"
    ),
):
    """Rerun full image import on all images in DB.

    This includes launching all ML models and insight extraction from the image and
    associated OCR. To control which tasks are rerun, use the --flags option.
    """
    # Imported inside the command body (as the original does) and aliased so
    # it does not clash with this command, which has the same name.
    from robotoff.workers.tasks.import_image import (
        rerun_import_all_images as schedule_rerun,
    )

    # An empty/absent flag list is normalised to None, which the worker-side
    # function documents as "all tasks".
    selected_flags = flags if flags else None
    selection = {
        "limit": limit,
        "server_type": server_type,
        "flags": selected_flags,
    }

    # First pass: only count the matching images so that the user can confirm
    # before anything is scheduled.
    count = schedule_rerun(return_count=True, **selection)

    if selected_flags is None:
        message = f"rerunning full image import on {count} images, confirm?"
    else:
        task_names = ", ".join(flag.name for flag in selected_flags)
        message = (
            f"running following tasks ({task_names}) on {count} images, confirm?"
        )

    # Guard clause: do nothing (and print nothing) when the user declines.
    if not typer.confirm(message):
        return

    # Second pass: same selection, this time actually enqueuing the jobs.
    schedule_rerun(**selection)
    typer.echo("The task was successfully scheduled.")


@app.command()
def run_nutrition_extraction(
image_url: str = typer.Argument(
Expand Down
7 changes: 6 additions & 1 deletion robotoff/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,16 @@ def refresh_images_in_db(product_id: ProductIdentifier, images: JSONType):
save_image(product_id, source_image, image_url, images, use_cache=True)


def add_image_fingerprint(image_model: ImageModel):
def add_image_fingerprint(image_model: ImageModel, overwrite: bool = False) -> None:
"""Update image in DB to add the image fingerprint.
:param image_model: the image model to update
:param overwrite: whether to overwrite the existing fingerprint
"""
if not overwrite and image_model.fingerprint is not None:
logger.debug("image %s already has a fingerprint, skipping", image_model.id)
return

image_url = image_model.get_image_url()
image = get_image_from_url(
image_url, error_raise=False, session=http_session, use_cache=True
Expand Down
9 changes: 9 additions & 0 deletions robotoff/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,12 @@ def validate_nutrients(self) -> Self:
if k.endswith(self.nutrition_data_per) # type: ignore
}
return self


class ImportImageFlag(str, enum.Enum):
    """Names of the individual tasks that an image import can run.

    Every member's value is identical to its name. The members are listed in
    a fixed order, which is also the iteration order of the enum (used when
    "all flags" is built by iterating over the class).
    """

    # Image fingerprint computation.
    add_image_fingerprint = "add_image_fingerprint"

    # Insight import from the image.
    import_insights_from_image = "import_insights_from_image"

    # Ingredient and nutrition extraction.
    extract_ingredients = "extract_ingredients"
    extract_nutrition = "extract_nutrition"

    # Object detection models (logos, nutrition tables).
    run_logo_object_detection = "run_logo_object_detection"
    run_nutrition_table_object_detection = "run_nutrition_table_object_detection"
193 changes: 141 additions & 52 deletions robotoff/workers/tasks/import_image.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import dataclasses
import datetime
from pathlib import Path
from typing import Optional

import elasticsearch
import numpy as np
Expand Down Expand Up @@ -36,7 +35,12 @@
with_db,
)
from robotoff.notifier import NotifierFactory
from robotoff.off import generate_image_url, get_source_from_url, parse_ingredients
from robotoff.off import (
generate_image_url,
generate_json_ocr_url,
get_source_from_url,
parse_ingredients,
)
from robotoff.prediction import ingredient_list, nutrition_extraction
from robotoff.prediction.upc_image import UPCImageType, find_image_is_upc
from robotoff.products import get_product_store
Expand All @@ -47,6 +51,7 @@
get_triton_inference_stub,
)
from robotoff.types import (
ImportImageFlag,
JSONType,
ObjectDetectionModel,
Prediction,
Expand All @@ -61,7 +66,69 @@
logger = get_logger(__name__)


def run_import_image_job(product_id: ProductIdentifier, image_url: str, ocr_url: str):
@with_db
def rerun_import_all_images(
    limit: int | None = None,
    server_type: ServerType | None = None,
    return_count: bool = False,
    flags: list[ImportImageFlag] | None = None,
) -> None | int:
    """Rerun image import on all non-deleted images in DB.

    One `run_import_image_job` job is enqueued per selected image, most
    recently uploaded images first. Which tasks each job performs is
    controlled by the `flags` parameter.

    :param limit: the maximum number of images to process, defaults to None
        (all). A falsy value (including 0) is treated as "no limit".
    :param server_type: the server type (project) of the products, defaults
        to None (all)
    :param return_count: if True, return the number of images to process,
        without processing them, defaults to False
    :param flags: the list of flags to rerun, defaults to None (all)
    :raises ValueError: if a selected image has a barcode that is not a
        string made only of digits
    :return: the number of images to process, or None if return_count is
        False
    """
    where_clauses = [ImageModel.deleted == False]  # noqa: E712

    if server_type is not None:
        where_clauses.append(ImageModel.server_type == server_type.name)

    # Only the three columns needed to rebuild the product identifier and the
    # image/OCR URLs are selected, and rows are returned as plain tuples.
    query = (
        ImageModel.select(
            ImageModel.barcode, ImageModel.image_id, ImageModel.server_type
        )
        .where(*where_clauses)
        .order_by(ImageModel.uploaded_at.desc())
        .tuples()
    )
    if limit:
        query = query.limit(limit)

    if return_count:
        # Dry run: report how many images would be processed.
        return query.count()

    for barcode, image_id, server_type_str in query:
        # Reject anything that is not a digit-only string. The two checks must
        # be combined with `or`: with `and`, a non-digit string was silently
        # accepted and a non-string value reached `.isdigit()`.
        if not isinstance(barcode, str) or not barcode.isdigit():
            raise ValueError(f"Invalid barcode: {barcode}")

        product_id = ProductIdentifier(barcode, ServerType[server_type_str])
        enqueue_job(
            run_import_image_job,
            get_high_queue(product_id),
            job_kwargs={"result_ttl": 0},
            product_id=product_id,
            image_url=generate_image_url(product_id, image_id),
            ocr_url=generate_json_ocr_url(product_id, image_id),
            flags=flags,
        )
    return None


def run_import_image_job(
product_id: ProductIdentifier,
image_url: str,
ocr_url: str,
flags: list[ImportImageFlag] | None = None,
) -> None:
"""This job is triggered every time there is a new OCR image available for
processing by Robotoff, via an event published on the Redis stream.
Expand All @@ -73,8 +140,20 @@ def run_import_image_job(product_id: ProductIdentifier, image_url: str, ocr_url:
3. Triggers the 'object_detection' task
4. Stores the imported image metadata in the Robotoff DB.
5. Compute image fingerprint, for duplicate image detection.
What tasks are performed can be controlled using the `flags` parameter. By
default, all tasks are performed. A new rq job is enqueued for each task.
:param product_id: the product identifier
:param image_url: the URL of the image to import
:param ocr_url: the URL of the OCR JSON file
:param flags: the list of flags to run, defaults to None (all)
"""
logger.info("Running `import_image` for %s, image %s", product_id, image_url)

if flags is None:
flags = [flag for flag in ImportImageFlag]

source_image = get_source_from_url(image_url)
product = get_product_store(product_id.server_type)[product_id]
if product is None and settings.ENABLE_MONGODB_ACCESS:
Expand All @@ -85,7 +164,7 @@ def run_import_image_job(product_id: ProductIdentifier, image_url: str, ocr_url:
)
return

product_images: Optional[JSONType] = getattr(product, "images", None)
product_images: JSONType | None = getattr(product, "images", None)
with db:
image_model = save_image(
product_id, source_image, image_url, product_images, use_cache=True
Expand All @@ -106,66 +185,76 @@ def run_import_image_job(product_id: ProductIdentifier, image_url: str, ocr_url:
ImageModel.bulk_update([image_model], fields=["deleted"])
return

# Compute image fingerprint, this job is low priority
enqueue_job(
add_image_fingerprint_job,
low_queue,
job_kwargs={"result_ttl": 0},
image_model_id=image_model.id,
)

if product_id.server_type.is_food():
# Currently we don't support insight generation for projects other
# than OFF (OBF, OPF,...)
if ImportImageFlag.add_image_fingerprint in flags:
# Compute image fingerprint, this job is low priority
enqueue_job(
import_insights_from_image,
get_high_queue(product_id),
add_image_fingerprint_job,
low_queue,
job_kwargs={"result_ttl": 0},
product_id=product_id,
image_url=image_url,
ocr_url=ocr_url,
)
# Only extract ingredient lists for food products, as the model was not
# trained on non-food products
enqueue_job(
extract_ingredients_job,
get_high_queue(product_id),
# We add a higher timeout, as we request Product Opener to
# parse ingredient list, which may take a while depending on
# the number of ingredient list (~1s per ingredient list)
job_kwargs={"result_ttl": 0, "timeout": "2m"},
product_id=product_id,
ocr_url=ocr_url,
)
enqueue_job(
extract_nutrition_job,
get_high_queue(product_id),
job_kwargs={"result_ttl": 0, "timeout": "2m"},
product_id=product_id,
image_url=image_url,
ocr_url=ocr_url,
image_model_id=image_model.id,
)
# We make sure there are no concurrent insight processing by sending
# the job to the same queue. The queue is selected based on the product
# barcode. See `get_high_queue` documentation for more details.
enqueue_job(
run_logo_object_detection,
get_high_queue(product_id),
job_kwargs={"result_ttl": 0},
product_id=product_id,
image_url=image_url,
ocr_url=ocr_url,
)

if product_id.server_type.is_food():
if ImportImageFlag.import_insights_from_image in flags:
# Currently we don't support insight generation for projects other
# than OFF (OBF, OPF,...)
enqueue_job(
import_insights_from_image,
get_high_queue(product_id),
job_kwargs={"result_ttl": 0},
product_id=product_id,
image_url=image_url,
ocr_url=ocr_url,
)

if ImportImageFlag.extract_ingredients in flags:
# Only extract ingredient lists for food products, as the model was not
# trained on non-food products
enqueue_job(
extract_ingredients_job,
get_high_queue(product_id),
# We add a higher timeout, as we request Product Opener to
# parse ingredient list, which may take a while depending on
# the number of ingredient list (~1s per ingredient list)
job_kwargs={"result_ttl": 0, "timeout": "2m"},
product_id=product_id,
ocr_url=ocr_url,
)

if ImportImageFlag.extract_nutrition in flags:
enqueue_job(
extract_nutrition_job,
get_high_queue(product_id),
job_kwargs={"result_ttl": 0, "timeout": "2m"},
product_id=product_id,
image_url=image_url,
ocr_url=ocr_url,
)

if ImportImageFlag.run_logo_object_detection in flags:
# We make sure there are no concurrent insight processing by sending
# the job to the same queue. The queue is selected based on the product
# barcode. See `get_high_queue` documentation for more details.
enqueue_job(
run_nutrition_table_object_detection,
run_logo_object_detection,
get_high_queue(product_id),
job_kwargs={"result_ttl": 0},
product_id=product_id,
image_url=image_url,
ocr_url=ocr_url,
)

if product_id.server_type.is_food():
if ImportImageFlag.run_nutrition_table_object_detection in flags:
# Run object detection model that detects nutrition tables
enqueue_job(
run_nutrition_table_object_detection,
get_high_queue(product_id),
job_kwargs={"result_ttl": 0},
product_id=product_id,
image_url=image_url,
)

# Run UPC detection to detect if the image is dominated by a UPC (and thus
# should not be a product selected image)
# UPC detection is buggy since the upgrade to OpenCV 4.10
Expand Down

0 comments on commit 98a1374

Please sign in to comment.