Skip to content

Commit

Permalink
fix: allow to run object detection models on URL list file
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Dec 27, 2022
1 parent 5a392db commit 6b04692
Showing 1 changed file with 44 additions and 19 deletions.
63 changes: 44 additions & 19 deletions robotoff/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import typer

from robotoff.elasticsearch.client import get_es_client
from robotoff.off import get_barcode_from_url
from robotoff.types import ObjectDetectionModel, PredictionType, WorkerQueue

app = typer.Typer()
Expand Down Expand Up @@ -334,18 +335,29 @@ def run_object_detection_model(
model_name: ObjectDetectionModel = typer.Argument(
..., help="Name of the object detection model"
),
input_path: Optional[Path] = typer.Option(
None,
exists=True,
file_okay=True,
dir_okay=False,
help="text file with image URLs to run object detection on. "
"If null, a query is performed in DB to fetch images without image predictions "
"for the specified model.",
),
limit: Optional[int] = typer.Option(None, help="Maximum numbers of job to launch"),
):
"""Launch object detection model jobs on all missing images (images
without an ImagePrediction item for this model) in DB, or on the image
URLs listed in the `input_path` text file when one is provided."""
from typing import Callable
from urllib.parse import urlparse

import tqdm
from peewee import JOIN

from robotoff import settings
from robotoff.models import ImageModel, ImagePrediction, db
from robotoff.off import generate_image_url
from robotoff.utils import text_file_iter
from robotoff.workers.queues import enqueue_job, low_queue
from robotoff.workers.tasks.import_image import (
run_logo_object_detection,
Expand All @@ -360,27 +372,40 @@ def run_object_detection_model(
else:
func = run_nutriscore_object_detection

with db:
query = (
ImageModel.select(ImageModel.barcode, ImageModel.id)
.join(
ImagePrediction,
JOIN.LEFT_OUTER,
on=(
(ImagePrediction.image_id == ImageModel.id)
& (ImagePrediction.model_name == model_name.value)
),
if input_path:
image_urls = list(text_file_iter(input_path))

for image_url in image_urls:
parsed_url = urlparse(image_url)
if not parsed_url.netloc or not parsed_url.scheme:
raise ValueError(f"invalid image URL: {image_url}")

else:
with db:
query = (
ImageModel.select(ImageModel.barcode, ImageModel.id)
.join(
ImagePrediction,
JOIN.LEFT_OUTER,
on=(
(ImagePrediction.image_id == ImageModel.id)
& (ImagePrediction.model_name == model_name.value)
),
)
.where(ImagePrediction.model_name.is_null())
.tuples()
)
.where(ImagePrediction.model_name.is_null())
.tuples()
)
if limit:
query = query.limit(limit)
missing_items = list(query)
if limit:
query = query.limit(limit)
image_urls = [
generate_image_url(barcode, image_id)
for barcode, image_id in query
if barcode.isdigit()
]

if typer.confirm(f"{len(missing_items)} jobs are going to be launched, confirm?"):
for barcode, image_id in tqdm.tqdm(missing_items, desc="image"):
image_url = generate_image_url(barcode, image_id)
if typer.confirm(f"{len(image_urls)} jobs are going to be launched, confirm?"):
for image_url in tqdm.tqdm(image_urls, desc="image"):
barcode = get_barcode_from_url(image_url)
enqueue_job(
func,
low_queue,
Expand Down

0 comments on commit 6b04692

Please sign in to comment.