Skip to content

Commit

Permalink
feat: improve CLI import commands
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Dec 9, 2022
1 parent f24a600 commit f5491d1
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 33 deletions.
110 changes: 82 additions & 28 deletions robotoff/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def refresh_insights(
help="Refresh a specific product. If not provided, all products are updated",
),
batch_size: int = typer.Option(
50, help="Number of products to send in a worker tasks"
100, help="Number of products to send in a worker tasks"
),
):
"""Refresh insights based on available predictions.
Expand Down Expand Up @@ -268,69 +268,123 @@ def refresh_insights(


@app.command()
def import_images_in_db():
def import_images_in_db(
batch_size: int = typer.Option(
500, help="Number of items to send in a worker tasks"
),
):
"""Make sure that every image available in MongoDB is saved in `image`
table."""
import tqdm
from more_itertools import chunked

from robotoff import settings
from robotoff.models import ImageModel, db
from robotoff.off import generate_image_path
from robotoff.products import get_product_store
from robotoff.utils import get_logger
from robotoff.workers.queues import enqueue_job, low_queue
from robotoff.workers.tasks.import_image import save_image_job

logger = get_logger()

with db:
logger.info("Fetching existing images in DB...")
existing_images = set(
ImageModel.select(ImageModel.barcode, ImageModel.image_id).tuples()
)

store = get_product_store()
for product in tqdm.tqdm(store.iter_product(projection=["images", "code"])):
to_add = []
for product in tqdm.tqdm(
store.iter_product(projection=["images", "code"]), desc="product"
):
barcode = product.barcode
for image_id in (id_ for id_ in product.images.keys() if id_.isdigit()):
source_image = generate_image_path(product.barcode, image_id)
if (barcode, image_id) not in existing_images:
to_add.append((barcode, generate_image_path(barcode, image_id)))

batches = list(chunked(to_add, batch_size))
if typer.confirm(
f"{len(batches)} add image jobs are going to be launched, confirm?"
):
for batch in tqdm.tqdm(batches, desc="job"):
enqueue_job(
save_image_job,
low_queue,
job_kwargs={"result_ttl": 0},
barcode=product.barcode,
source_image=source_image,
batch=batch,
server_domain=settings.OFF_SERVER_DOMAIN,
)


class ObjectDetectionModel(enum.Enum):
nutriscore = "nutriscore"
universal_logo_detector = "universal-logo-detector"
nutrition_table = "nutrition-table"


@app.command()
def run_object_detection_models():
def run_object_detection_model(
model_name: ObjectDetectionModel = typer.Argument(
..., help="Name of the object detection model"
),
limit: Optional[int] = typer.Option(None, help="Maximum numbers of job to launch"),
):
"""Run universal-logo-detector and nutrition-table object detection models
on all images in DB."""
import tqdm
from peewee import JOIN

from robotoff import settings
from robotoff.models import ImageModel, db
from robotoff.models import ImageModel, ImagePrediction, db
from robotoff.off import generate_image_url
from robotoff.workers.queues import enqueue_job, low_queue
from robotoff.workers.tasks.import_image import (
run_logo_object_detection,
run_nutriscore_object_detection,
run_nutrition_table_object_detection,
)

if model_name == ObjectDetectionModel.universal_logo_detector:
func = run_logo_object_detection
elif model_name == ObjectDetectionModel.nutrition_table:
func = run_nutrition_table_object_detection
else:
func = run_nutriscore_object_detection

with db:
items = list(
ImageModel.select(ImageModel.barcode, ImageModel.image_id).tuples()
query = (
ImageModel.select(ImageModel.barcode, ImageModel.id)
.join(
ImagePrediction,
JOIN.LEFT_OUTER,
on=(
(ImagePrediction.image_id == ImageModel.id)
& (ImagePrediction.model_name == model_name)
),
)
.where(ImagePrediction.model_name.is_null())
.tuples()
)
if limit:
query = query.limit(limit)
missing_items = list(query)

for barcode, image_id in tqdm.tqdm(items, desc="barcode"):
image_url = generate_image_url(barcode, image_id)
enqueue_job(
run_logo_object_detection,
low_queue,
job_kwargs={"result_ttl": 0},
barcode=barcode,
image_url=image_url,
server_domain=settings.OFF_SERVER_DOMAIN,
)
enqueue_job(
run_nutrition_table_object_detection,
low_queue,
job_kwargs={"result_ttl": 0},
barcode=barcode,
image_url=image_url,
server_domain=settings.OFF_SERVER_DOMAIN,
)
if limit:
missing_items = missing_items[:limit]

if typer.confirm(f"{len(missing_items)} jobs are going to be launched, confirm?"):
for barcode, image_id in tqdm.tqdm(missing_items, desc="image"):
image_url = generate_image_url(barcode, image_id)
enqueue_job(
func,
low_queue,
job_kwargs={"result_ttl": 0},
barcode=barcode,
image_url=image_url,
server_domain=settings.OFF_SERVER_DOMAIN,
)


@app.command()
Expand Down
16 changes: 11 additions & 5 deletions robotoff/workers/tasks/import_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,17 @@ def import_insights_from_image(


@with_db
def save_image_job(barcode: str, source_image: str, server_domain: str):
product = get_product_store()[barcode]
if product is None:
return
save_image(barcode, source_image, product, server_domain)
def save_image_job(batch: list[tuple[str, str]], server_domain: str):
"""Save a batch of images in DB.
:param batch: a batch of (barcode, source_image) tuples
:param server_domain: the server domain to use
"""
for barcode, source_image in batch:
product = get_product_store()[barcode]
if product is None:
continue
save_image(barcode, source_image, product, server_domain)


def save_image(
Expand Down

0 comments on commit f5491d1

Please sign in to comment.