From 6b046920758d1ace4f3b4361917040da0110cc79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Tue, 27 Dec 2022 14:34:52 +0100 Subject: [PATCH] fix: allow to run object detection models on URL list file --- robotoff/cli/main.py | 63 +++++++++++++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 19 deletions(-) diff --git a/robotoff/cli/main.py b/robotoff/cli/main.py index 4e4ae8aea4..7636aabc8d 100644 --- a/robotoff/cli/main.py +++ b/robotoff/cli/main.py @@ -6,6 +6,7 @@ import typer from robotoff.elasticsearch.client import get_es_client +from robotoff.off import get_barcode_from_url from robotoff.types import ObjectDetectionModel, PredictionType, WorkerQueue app = typer.Typer() @@ -334,11 +335,21 @@ def run_object_detection_model( model_name: ObjectDetectionModel = typer.Argument( ..., help="Name of the object detection model" ), + input_path: Optional[Path] = typer.Option( + None, + exists=True, + file_okay=True, + dir_okay=False, + help="text file with image URLs to run object detection on. " + "If null, a query is performed in DB to fetch images without image predictions " + "for the specified model.", + ), limit: Optional[int] = typer.Option(None, help="Maximum numbers of job to launch"), ): """Launch object detection model jobs on all missing images (images without an ImagePrediction item for this model) in DB.""" from typing import Callable + from urllib.parse import urlparse import tqdm from peewee import JOIN @@ -346,6 +357,7 @@ def run_object_detection_model( from robotoff import settings from robotoff.models import ImageModel, ImagePrediction, db from robotoff.off import generate_image_url + from robotoff.utils import text_file_iter from robotoff.workers.queues import enqueue_job, low_queue from robotoff.workers.tasks.import_image import ( run_logo_object_detection, @@ -360,27 +372,40 @@ def run_object_detection_model( else: func = run_nutriscore_object_detection - with db: - query = ( - ImageModel.select(ImageModel.barcode, ImageModel.id) - .join( - ImagePrediction, - JOIN.LEFT_OUTER, - on=( - (ImagePrediction.image_id == ImageModel.id) - & (ImagePrediction.model_name == model_name.value) - ), + if input_path: + image_urls = list(text_file_iter(input_path)) + + for image_url in image_urls: + parsed_url = urlparse(image_url) + if not parsed_url.netloc or not parsed_url.scheme: + raise ValueError(f"invalid image URL: {image_url}") + + else: + with db: + query = ( + ImageModel.select(ImageModel.barcode, ImageModel.id) + .join( + ImagePrediction, + JOIN.LEFT_OUTER, + on=( + (ImagePrediction.image_id == ImageModel.id) + & (ImagePrediction.model_name == model_name.value) + ), + ) + .where(ImagePrediction.model_name.is_null()) + .tuples() ) - .where(ImagePrediction.model_name.is_null()) - .tuples() - ) - if limit: - query = query.limit(limit) - missing_items = list(query) + if limit: + query = query.limit(limit) + image_urls = [ + generate_image_url(barcode, image_id) + for barcode, image_id in query + if barcode.isdigit() + ] - if typer.confirm(f"{len(missing_items)} jobs are going to be launched, confirm?"): - for barcode, image_id in tqdm.tqdm(missing_items, desc="image"): - image_url = generate_image_url(barcode, image_id) + if typer.confirm(f"{len(image_urls)} jobs are going to be launched, confirm?"): + for image_url in tqdm.tqdm(image_urls, desc="image"): + barcode = get_barcode_from_url(image_url) enqueue_job( func, low_queue,