From 02f7e2ed661dcced1fecf92a428f05c05d4f13d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Wed, 28 Dec 2022 16:12:17 +0100 Subject: [PATCH] fix: give credit to annotator when annotating logos --- robotoff/app/api.py | 127 ++++++++---------- robotoff/app/core.py | 36 +++++ robotoff/cli/main.py | 15 +-- robotoff/insights/importer.py | 67 ++++++--- robotoff/logos.py | 45 +++++-- robotoff/scheduler/__init__.py | 4 +- robotoff/types.py | 54 ++++++++ robotoff/workers/tasks/__init__.py | 4 +- robotoff/workers/tasks/import_image.py | 7 +- robotoff/workers/tasks/product_updated.py | 32 ++--- scripts/refresh_insights.py | 34 ----- .../insights/test_category_import.py | 27 ++-- tests/integration/test_import_image.py | 3 +- tests/unit/insights/test_importer.py | 24 ++-- .../workers/tasks/test_product_updated.py | 8 +- 15 files changed, 291 insertions(+), 196 deletions(-) delete mode 100644 scripts/refresh_insights.py diff --git a/robotoff/app/api.py b/robotoff/app/api.py index bce3511f18..30635a74be 100644 --- a/robotoff/app/api.py +++ b/robotoff/app/api.py @@ -29,6 +29,7 @@ get_logo_annotation, get_predictions, save_annotation, + update_logo_annotations, ) from robotoff.app.middleware import DBConnectionMiddleware from robotoff.elasticsearch import get_es_client @@ -45,6 +46,7 @@ LogoEmbedding, ProductInsight, batch_insert, + db, ) from robotoff.off import ( OFFAuthentication, @@ -749,6 +751,16 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): resp.media = {"logos": items, "count": query_count} +def check_logo_annotation(type_: str, value: Optional[str] = None): + if value is not None: + if type_ == "label" and not is_prefixed_value(value): + raise falcon.HTTPBadRequest( + description=f"language-prefixed value are required for label type (here: {value})" + ) + elif type_ in ("brand", "category", "label", "store"): + raise falcon.HTTPBadRequest(description=f"value required for type {type_})") + + class ImageLogoDetailResource: def on_get(self, req: falcon.Request, resp: falcon.Response, logo_id: int): logo = LogoAnnotation.get_or_none(id=logo_id) @@ -770,38 +782,25 @@ def on_put(self, req: falcon.Request, resp: falcon.Response, logo_id: int): description="authentication is required to annotate logos" ) - logo = LogoAnnotation.get_or_none(id=logo_id) - - if logo is None: - resp.status = falcon.HTTP_404 - return - - type_ = req.media["type"] - value = req.media["value"] or None - updated = False - - if type_ != logo.annotation_type: - logo.annotation_type = type_ - updated = True - - if value != logo.annotation_value: - logo.annotation_value = value - - if value is not None: - value_tag = get_tag(value) - logo.annotation_value_tag = value_tag - logo.taxonomy_value = match_taxonomized_value(value_tag, type_) - else: - logo.annotation_value_tag = None - logo.taxonomy_value = None + with db.atomic(): + logo = LogoAnnotation.get_or_none(id=logo_id) + if logo is None: + resp.status = falcon.HTTP_404 + return - updated = True + type_ = req.media["type"] + value = req.media["value"] or None + check_logo_annotation(type_, value) - if updated: - logo.username = auth.get_username() - logo.completed_at = datetime.datetime.utcnow() - logo.save() - generate_insights_from_annotated_logos([logo], settings.OFF_SERVER_DOMAIN) + if type_ != logo.annotation_type or value != logo.annotation_value: + annotated_logos = update_logo_annotations( + [(type_, value, logo)], + username=auth.get_username() or "unknown", + completed_at=datetime.datetime.utcnow(), + ) + generate_insights_from_annotated_logos( + annotated_logos, settings.OFF_SERVER_DOMAIN, auth + ) resp.status = falcon.HTTP_204 @@ -816,47 +815,37 @@ def on_post(self, req: falcon.Request, resp: falcon.Response): ) server_domain = req.media.get("server_domain", settings.OFF_SERVER_DOMAIN) annotations = req.media["annotations"] - username = auth.get_username() completed_at = datetime.datetime.utcnow() - annotated_logos = [] - - for annotation in annotations: - logo_id = annotation["logo_id"] - type_ = annotation["type"] - value = annotation["value"] or None - try: - logo = LogoAnnotation.get_by_id(logo_id) - except LogoAnnotation.DoesNotExist: - raise falcon.HTTPNotFound(description=f"logo {logo_id} not found") - - if logo.annotation_type is not None: - # Logo is already annotated, skip - continue - - if value is not None: - if type_ == "label" and not is_prefixed_value(value): - raise falcon.HTTPBadRequest( - description=f"language-prefixed value are required for label type (here: {value})" - ) - logo.annotation_value = value - value_tag = get_tag(value) - logo.annotation_value_tag = value_tag - logo.taxonomy_value = match_taxonomized_value(value_tag, type_) - elif type_ in ("brand", "category", "label", "store"): - raise falcon.HTTPBadRequest( - description=f"value required for type {type_} (logo {logo_id})" + annotation_logos = [] + + with db.atomic(): + for annotation in annotations: + logo_id = annotation["logo_id"] + type_ = annotation["type"] + value = annotation["value"] or None + check_logo_annotation(type_, value) + + try: + logo = LogoAnnotation.get_by_id(logo_id) + except LogoAnnotation.DoesNotExist: + raise falcon.HTTPNotFound(description=f"logo {logo_id} not found") + + if logo.annotation_type is None: + # Don't annotate already annotated logos + annotation_logos.append((type_, value, logo)) + + if annotation_logos: + annotated_logos = update_logo_annotations( + annotation_logos, + username=auth.get_username() or "unknown", + completed_at=completed_at, ) - - logo.annotation_type = type_ - logo.username = username - logo.completed_at = completed_at - annotated_logos.append(logo) - - for logo in annotated_logos: - logo.save() - - created = generate_insights_from_annotated_logos(annotated_logos, server_domain) - resp.media = {"created insights": created} + annotated = generate_insights_from_annotated_logos( + annotated_logos, server_domain, auth + ) + else: + annotated = 0 + resp.media = {"created insights": annotated} class ImageLogoUpdateResource: diff --git a/robotoff/app/core.py b/robotoff/app/core.py index 1f75ad2226..34f5401644 100644 --- a/robotoff/app/core.py +++ b/robotoff/app/core.py @@ -1,3 +1,4 @@ +import datetime import functools from enum import Enum from typing import Iterable, NamedTuple, Optional, Union @@ -24,7 +25,9 @@ db, ) from robotoff.off import OFFAuthentication +from robotoff.taxonomy import match_taxonomized_value from robotoff.utils import get_logger +from robotoff.utils.text import get_tag logger = get_logger(__name__) @@ -434,3 +437,36 @@ def get_logo_annotation( return query.count() else: return query.iterator() + + +def update_logo_annotations( + annotation_logos: list[tuple[str, Optional[str], LogoAnnotation]], + username: str, + completed_at: datetime.datetime, +) -> list[LogoAnnotation]: + updated_fields = set() + updated_logos = [] + for type_, value, logo in annotation_logos: + logo.annotation_type = type_ + updated_fields.add("annotation_type") + + if value is not None: + value_tag = get_tag(value) + logo.annotation_value = value + logo.annotation_value_tag = value_tag + logo.taxonomy_value = match_taxonomized_value(value_tag, type_) + logo.username = username + logo.completed_at = completed_at + updated_fields |= { + "annotation_value", + "annotation_value_tag", + "taxonomy_value", + "username", + "completed_at", + } + updated_logos.append(logo) + + if updated_logos: + LogoAnnotation.bulk_update(updated_logos, fields=list(updated_fields)) + + return updated_logos diff --git a/robotoff/cli/main.py b/robotoff/cli/main.py index ddd0456f77..51f3cc1246 100644 --- a/robotoff/cli/main.py +++ b/robotoff/cli/main.py @@ -72,9 +72,10 @@ def regenerate_ocr_insights( ) with db: - imported = importer.import_insights(predictions, settings.OFF_SERVER_DOMAIN) - - logger.info("Import finished, %s insights imported", imported) + import_result = importer.import_insights( + predictions, settings.OFF_SERVER_DOMAIN + ) + logger.info(import_result) @app.command() @@ -200,20 +201,16 @@ def import_insights( else: raise ValueError("--generate-from or --input-path must be provided") - imported = 0 with db.connection_context(): for prediction_batch in tqdm.tqdm( chunked(predictions, batch_size), desc="prediction batch" ): # Create a new transaction for every batch with db.atomic(): - batch_imported = importer.import_insights( + import_results = importer.import_insights( prediction_batch, settings.OFF_SERVER_DOMAIN ) - logger.info(f"{batch_imported} insights imported in batch") - imported += batch_imported - - logger.info(f"{imported} insights imported") + logger.info(import_results) @app.command() diff --git a/robotoff/insights/importer.py b/robotoff/insights/importer.py index b2e27bb5e2..5042465bfc 100644 --- a/robotoff/insights/importer.py +++ b/robotoff/insights/importer.py @@ -29,7 +29,13 @@ get_taxonomy, match_taxonomized_value, ) -from robotoff.types import InsightType, PredictionType +from robotoff.types import ( + InsightImportResult, + InsightType, + PredictionImportResult, + PredictionType, + ProductInsightImportResult, +) from robotoff.utils import get_logger, text_file_iter from robotoff.utils.cache import CachedStore @@ -246,7 +252,7 @@ def import_insights( predictions: list[Prediction], server_domain: str, product_store: DBProductStore, - ) -> int: + ) -> ProductInsightImportResult: """Import insights, this is the main method. :return: the number of insights that were imported. @@ -272,19 +278,21 @@ def import_insights( to_create, to_update, to_delete = cls.generate_insights( barcode, predictions, server_domain, product_store ) - if to_delete: - to_delete_ids = [insight.id for insight in to_delete] - logger.info("Deleting %s insights", len(to_delete_ids)) + to_delete_ids = [insight.id for insight in to_delete] + if to_delete_ids: ProductInsight.delete().where( ProductInsight.id.in_(to_delete_ids) ).execute() + if to_create: inserts += batch_insert( ProductInsight, (model_to_dict(insight) for insight in to_create), 50, ) + created_ids = [insight.id for insight in to_create] + updated_ids = [] for insight, reference_insight in to_update: update = {} for field_name in ( @@ -298,11 +306,18 @@ def import_insights( update[field_name] = getattr(insight, field_name) if update: + updated_ids.append(reference_insight.id) ProductInsight.update(**update).where( ProductInsight.id == reference_insight.id ).execute() - return inserts + return ProductInsightImportResult( + insight_created_ids=created_ids, + insight_deleted_ids=to_delete_ids, + insight_updated_ids=updated_ids, + barcode=barcode, + type=cls.get_type(), + ) @classmethod def generate_insights( @@ -999,7 +1014,7 @@ def import_insights( predictions: Iterable[Prediction], server_domain: str, product_store: Optional[DBProductStore] = None, -) -> int: +) -> InsightImportResult: """Import predictions and generate (and import) insights from these predictions. @@ -1008,19 +1023,23 @@ def import_insights( if product_store is None: product_store = get_product_store() - updated_prediction_types_by_barcode = import_predictions( + updated_prediction_types_by_barcode, prediction_import_results = import_predictions( predictions, product_store, server_domain ) - return import_insights_for_products( + product_insight_import_results = import_insights_for_products( updated_prediction_types_by_barcode, server_domain, product_store ) + return InsightImportResult( + product_insight_import_results=product_insight_import_results, + prediction_import_results=prediction_import_results, + ) def import_insights_for_products( prediction_types_by_barcode: dict[str, set[PredictionType]], server_domain: str, product_store: DBProductStore, -) -> int: +) -> list[ProductInsightImportResult]: """Re-compute insights for products with new predictions. :param prediction_types_by_barcode: a dict that associates each barcode @@ -1030,7 +1049,7 @@ def import_insights_for_products( :return: Number of imported insights """ - imported = 0 + import_results = [] for importer in IMPORTERS: required_prediction_types = importer.get_required_prediction_types() selected_barcodes: list[str] = [] @@ -1052,26 +1071,27 @@ def import_insights_for_products( ): try: with Lock(name=f"robotoff:import:{barcode}", expire=60, timeout=10): - imported += importer.import_insights( + result = importer.import_insights( barcode, list(product_predictions), server_domain, product_store, ) + import_results.append(result) except LockedResourceException: logger.info( "Couldn't acquire insight import lock, skipping insight import for product %s", barcode, ) continue - return imported + return import_results def import_predictions( predictions: Iterable[Prediction], product_store: DBProductStore, server_domain: str, -) -> dict[str, set[PredictionType]]: +) -> tuple[dict[str, set[PredictionType]], list[PredictionImportResult]]: """Check validity and import provided Prediction. :param predictions: the Predictions to import @@ -1086,28 +1106,31 @@ def import_predictions( if is_valid_product_prediction(p, product_store[p.barcode]) # type: ignore ] - predictions_imported = 0 + predictions_import_results = [] updated_prediction_types_by_barcode: dict[str, set[PredictionType]] = {} for barcode, product_predictions_iter in itertools.groupby( sorted(predictions, key=operator.attrgetter("barcode")), operator.attrgetter("barcode"), ): product_predictions_group = list(product_predictions_iter) - predictions_imported += import_product_predictions( + predictions_imported = import_product_predictions( barcode, product_predictions_group, server_domain ) + predictions_import_results.append( + PredictionImportResult(created=predictions_imported, barcode=barcode) + ) updated_prediction_types_by_barcode[barcode] = set( prediction.type for prediction in product_predictions_group ) logger.info("%s predictions imported", predictions_imported) - return updated_prediction_types_by_barcode + return updated_prediction_types_by_barcode, predictions_import_results def refresh_insights( barcode: str, server_domain: str, product_store: Optional[DBProductStore] = None, -) -> int: +) -> list[InsightImportResult]: """Refresh all insights for specific product. All predictions are fetched, and insights are created/deleted by each @@ -1129,18 +1152,18 @@ def refresh_insights( predictions = [Prediction(**p) for p in get_product_predictions([barcode])] prediction_types = set(p.type for p in predictions) - imported = 0 + import_results = [] for importer in IMPORTERS: required_prediction_types = importer.get_required_prediction_types() if prediction_types >= required_prediction_types: - imported += importer.import_insights( + import_result = importer.import_insights( barcode, [p for p in predictions if p.type in required_prediction_types], server_domain, product_store, ) - - return imported + import_results.append(import_result) + return import_results def get_product_predictions( diff --git a/robotoff/logos.py b/robotoff/logos.py index e983f02ab6..80e76db18a 100644 --- a/robotoff/logos.py +++ b/robotoff/logos.py @@ -1,4 +1,5 @@ import datetime +import itertools import operator from typing import Optional @@ -11,6 +12,10 @@ from robotoff import settings from robotoff.elasticsearch import get_es_client +from robotoff.insights.annotate import ( + UPDATED_ANNOTATION_RESULT, + InsightAnnotatorFactory, +) from robotoff.insights.importer import import_insights from robotoff.logo_label_type import LogoLabelType from robotoff.models import ( @@ -21,10 +26,11 @@ LogoEmbedding, ) from robotoff.models import Prediction as PredictionModel -from robotoff.models import db +from robotoff.models import ProductInsight, db +from robotoff.off import OFFAuthentication from robotoff.prediction.types import Prediction from robotoff.slack import NotifierFactory -from robotoff.types import ElasticSearchIndex, PredictionType +from robotoff.types import ElasticSearchIndex, InsightImportResult, PredictionType from robotoff.utils import get_logger from robotoff.utils.types import JSONType @@ -302,7 +308,7 @@ def import_logo_insights( thresholds: dict[LogoLabelType, float], default_threshold: float = 0.1, notify: bool = True, -): +) -> InsightImportResult: selected_logos = [] logo_probs = [] for logo in logos: @@ -325,7 +331,7 @@ def import_logo_insights( logo_probs.append(probs) if not logos: - return + return InsightImportResult() # Delete all predictions for these logos from universal logo detectors # that are not from a human annotator @@ -344,24 +350,25 @@ def import_logo_insights( ) ).execute() predictions = predict_logo_predictions(selected_logos, logo_probs) - imported = import_insights(predictions, server_domain) + import_result = import_insights(predictions, server_domain) if notify: for logo, probs in zip(selected_logos, logo_probs): NotifierFactory.get_notifier().send_logo_notification(logo, probs) - return imported + return import_result def generate_insights_from_annotated_logos( - logos: list[LogoAnnotation], server_domain: str + logos: list[LogoAnnotation], server_domain: str, auth: OFFAuthentication ) -> int: + """Generate and apply insights from annotated logos.""" predictions = [] for logo in logos: prediction = generate_prediction( logo_type=logo.annotation_type, logo_value=logo.taxonomy_value, - automatic_processing=True, # because this is a user annotation, which we trust. + automatic_processing=False, # we're going to apply it immediately data={ "confidence": 1.0, "logo_id": logo.id, @@ -379,11 +386,25 @@ def generate_insights_from_annotated_logos( prediction.source_image = image.source_image predictions.append(prediction) - imported = import_insights(predictions, server_domain) + import_result = import_insights(predictions, server_domain) + if import_result.created_predictions_count(): + logger.info(import_result) + + annotated = 0 + for created_id in itertools.chain.from_iterable( + insight_import_result.insight_created_ids + for insight_import_result in import_result.product_insight_import_results + ): + insight = ProductInsight.get_or_none(id=created_id) + if insight: + annotator = InsightAnnotatorFactory.get(insight.type) + logger.info( + "Annotating insight %s (product: %s)", insight.id, insight.barcode + ) + annotation_result = annotator.annotate(insight, 1, auth=auth) + annotated += int(annotation_result == UPDATED_ANNOTATION_RESULT) - if imported: - logger.info("%s logo insights imported after annotation", imported) - return imported + return annotated def predict_logo_predictions( diff --git a/robotoff/scheduler/__init__.py b/robotoff/scheduler/__init__.py index 15b1cca048..e394739e43 100644 --- a/robotoff/scheduler/__init__.py +++ b/robotoff/scheduler/__init__.py @@ -220,10 +220,10 @@ def generate_insights(): product_predictions_iter = predict_from_dataset(dataset, datetime_threshold) with db: - imported = import_insights( + import_result = import_insights( product_predictions_iter, server_domain=settings.OFF_SERVER_DOMAIN ) - logger.info("%s category insights imported", imported) + logger.info(import_result) def transform_insight_iter(insights_iter: Iterable[dict]): diff --git a/robotoff/types.py b/robotoff/types.py index d820538481..55705ac721 100644 --- a/robotoff/types.py +++ b/robotoff/types.py @@ -1,4 +1,6 @@ +import dataclasses import enum +import uuid class WorkerQueue(enum.Enum): @@ -136,3 +138,55 @@ class InsightType(str, enum.Enum): class ElasticSearchIndex(str, enum.Enum): product = "product" logo = "logo" + + +@dataclasses.dataclass +class ProductInsightImportResult: + insight_created_ids: list[uuid.UUID] + insight_updated_ids: list[uuid.UUID] + insight_deleted_ids: list[uuid.UUID] + barcode: str + type: InsightType + + +@dataclasses.dataclass +class PredictionImportResult: + created: int + barcode: str + + +@dataclasses.dataclass +class InsightImportResult: + product_insight_import_results: list[ + ProductInsightImportResult + ] = dataclasses.field(default_factory=list) + prediction_import_results: list[PredictionImportResult] = dataclasses.field( + default_factory=list + ) + + def created_predictions_count(self) -> int: + return sum(x.created for x in self.prediction_import_results) + + def created_insights_count(self) -> int: + return sum( + len(x.insight_created_ids) for x in self.product_insight_import_results + ) + + def deleted_insights_count(self) -> int: + return sum( + len(x.insight_deleted_ids) for x in self.product_insight_import_results + ) + + def updated_insights_count(self) -> int: + return sum( + len(x.insight_updated_ids) for x in self.product_insight_import_results + ) + + def __repr__(self) -> str: + return ( + f"" + ) diff --git a/robotoff/workers/tasks/__init__.py b/robotoff/workers/tasks/__init__.py index f0ba988154..c437cd64a0 100644 --- a/robotoff/workers/tasks/__init__.py +++ b/robotoff/workers/tasks/__init__.py @@ -55,4 +55,6 @@ def refresh_insights_job(barcodes: list[str], server_domain: str): f"Refreshing insights for {len(barcodes)} products, server_domain: {server_domain}" ) for barcode in barcodes: - refresh_insights(barcode, server_domain) + import_results = refresh_insights(barcode, server_domain) + for import_result in import_results: + logger.info(import_result) diff --git a/robotoff/workers/tasks/import_image.py b/robotoff/workers/tasks/import_image.py index b219abe2e2..9740022e16 100644 --- a/robotoff/workers/tasks/import_image.py +++ b/robotoff/workers/tasks/import_image.py @@ -145,8 +145,8 @@ def import_insights_from_image( ) with db: - imported = import_insights(predictions, server_domain) - logger.info("Import finished, %s insights imported", imported) + import_result = import_insights(predictions, server_domain) + logger.info(import_result) def save_image_job(batch: list[tuple[str, str]], server_domain: str): @@ -313,7 +313,8 @@ def run_nutriscore_object_detection(barcode: str, image_url: str, server_domain: "bounding_box": result["bounding_box"], }, ) - import_insights([prediction], server_domain) + import_result = import_insights([prediction], server_domain) + logger.info(import_result) def run_logo_object_detection( diff --git a/robotoff/workers/tasks/product_updated.py b/robotoff/workers/tasks/product_updated.py index d64883c034..0c28b58435 100644 --- a/robotoff/workers/tasks/product_updated.py +++ b/robotoff/workers/tasks/product_updated.py @@ -44,15 +44,16 @@ def update_insights_job(barcode: str, server_domain: str): updated_product_predict_insights(barcode, product_dict, server_domain) logger.info("Refreshing insights...") - imported = refresh_insights(barcode, server_domain) - logger.info("%s insights created after refresh", imported) + import_results = refresh_insights(barcode, server_domain) + for import_result in import_results: + logger.info(import_result) except LockedResourceException: logger.info( f"Couldn't acquire product_update lock, skipping product_update for product {barcode}" ) -def add_category_insight(barcode: str, product: JSONType, server_domain: str) -> bool: +def add_category_insight(barcode: str, product: JSONType, server_domain: str): """Predict categories for product and import predicted category insight. :param barcode: product barcode @@ -61,7 +62,7 @@ def add_category_insight(barcode: str, product: JSONType, server_domain: str) -> :return: True if at least one category insight was imported """ if get_server_type(server_domain) != ServerType.off: - return False + return logger.info("Predicting product categories...") # predict category using matching algorithm on product name @@ -79,32 +80,25 @@ def add_category_insight(barcode: str, product: JSONType, server_domain: str) -> ) if len(product_predictions) < 1: - return False + return for prediction in product_predictions: prediction.barcode = barcode - imported = import_insights(product_predictions, server_domain) - logger.info("%s category insight imported for product %s", imported, barcode) - - return bool(imported) + import_result = import_insights(product_predictions, server_domain) + logger.info(import_result) def updated_product_predict_insights( barcode: str, product: JSONType, server_domain: str -) -> bool: - updated = add_category_insight(barcode, product, server_domain) +) -> None: + add_category_insight(barcode, product, server_domain) product_name = product.get("product_name") if not product_name: - return updated + return logger.info("Generating predictions from product name...") predictions_all = get_predictions_from_product_name(barcode, product_name) - imported = import_insights(predictions_all, server_domain) - logger.info("%s insights imported for product %s", imported, barcode) - - if imported: - updated = True - - return updated + import_result = import_insights(predictions_all, server_domain) + logger.info(import_result) diff --git a/scripts/refresh_insights.py b/scripts/refresh_insights.py deleted file mode 100644 index 139e288efd..0000000000 --- a/scripts/refresh_insights.py +++ /dev/null @@ -1,34 +0,0 @@ -import tqdm -from more_itertools import chunked - -from robotoff import settings -from robotoff.insights.importer import refresh_insights -from robotoff.models import ProductInsight, db -from robotoff.utils import get_logger - -logger = get_logger() -logger.info("Refreshing insights of all products") - -imported = 0 - -with db: - barcodes = [ - barcode - for (barcode, _) in ProductInsight.select( - ProductInsight.barcode, ProductInsight.timestamp - ) - .where(ProductInsight.annotation.is_null()) - .order_by(ProductInsight.timestamp.asc(), ProductInsight.barcode.asc()) - .tuples() - .iterator() - ] - -barcodes = sorted(set(barcodes), key=lambda x: barcodes.index(x)) -logger.info(f"{len(barcodes)} products to refresh") -for barcode_batch in tqdm.tqdm(chunked(barcodes, 100)): - with db: - for barcode in barcode_batch: - logger.info(f"Refreshing insights for product {barcode}") - imported += refresh_insights(barcode, settings.OFF_SERVER_DOMAIN) - -logger.info(f"Refreshed insights: {imported}") diff --git a/tests/integration/insights/test_category_import.py b/tests/integration/insights/test_category_import.py index e382599b7d..78f60bc657 100644 --- a/tests/integration/insights/test_category_import.py +++ b/tests/integration/insights/test_category_import.py @@ -76,12 +76,11 @@ def fake_product_store(self): def _run_import(self, predictions, product_store=None): if product_store is None: product_store = self.fake_product_store() - imported = import_insights( + return import_insights( predictions, server_domain=settings.OFF_SERVER_DOMAIN, product_store=product_store, ) - return imported @pytest.mark.parametrize( "predictions", @@ -107,8 +106,10 @@ def test_import_one_same_value_tag(self, predictions): original_insights = ProductInsight.select() assert len(original_insights) == 1 original_timestamp = original_insights[0].timestamp - imported = self._run_import(predictions) - assert imported == 0 + import_result = self._run_import(predictions) + assert import_result.created_insights_count() == 0 + assert import_result.updated_insights_count() == 1 + assert import_result.deleted_insights_count() == 0 # no insight created insights = list(ProductInsight.select()) assert len(insights) == 1 @@ -133,8 +134,10 @@ def test_import_one_same_value_tag(self, predictions): def test_import_one_different_value_tag(self, predictions): """Test when a more precise category is available as prediction: the prediction should be used as insight instead of the less precise one.""" - imported = self._run_import(predictions) - assert imported == 1 + import_result = self._run_import(predictions) + assert import_result.created_insights_count() == 1 + assert import_result.updated_insights_count() == 0 + assert import_result.deleted_insights_count() == 1 # no insight created assert ProductInsight.select().count() == 1 inserted = ProductInsight.get(ProductInsight.id != insight_id1) @@ -143,10 +146,12 @@ def test_import_one_different_value_tag(self, predictions): assert not inserted.automatic_processing def test_import_auto(self): - imported = self._run_import( + import_result = self._run_import( [neural_prediction("en:smoked-salmons", confidence=0.91, auto=True)] ) - assert imported == 1 + assert import_result.created_insights_count() == 1 + assert import_result.updated_insights_count() == 0 + assert import_result.deleted_insights_count() == 1 # no insight created assert ProductInsight.select().count() == 1 inserted = ProductInsight.get(ProductInsight.id != insight_id1) @@ -165,7 +170,9 @@ def test_import_auto(self): ) def test_import_product_not_in_store(self, predictions): # we should not create insight for non existing products ! - imported = self._run_import(predictions, product_store={barcode1: None}) - assert imported == 0 + import_result = self._run_import(predictions, product_store={barcode1: None}) + assert import_result.created_insights_count() == 0 + assert import_result.updated_insights_count() == 0 + assert import_result.deleted_insights_count() == 0 # no insight created assert ProductInsight.select().count() == 1 diff --git a/tests/integration/test_import_image.py b/tests/integration/test_import_image.py index eb09a9f040..f83d9fa430 100644 --- a/tests/integration/test_import_image.py +++ b/tests/integration/test_import_image.py @@ -4,6 +4,7 @@ from robotoff import settings from robotoff.models import LogoEmbedding +from robotoff.types import InsightImportResult from robotoff.workers.tasks.import_image import ( process_created_logos, save_logo_embeddings, @@ -77,7 +78,7 @@ def test_process_created_logos(peewee_db, mocker): ) import_logo_insights_mock = mocker.patch( "robotoff.workers.tasks.import_image.import_logo_insights", - return_value=1, + return_value=InsightImportResult(), ) with peewee_db: diff --git a/tests/unit/insights/test_importer.py b/tests/unit/insights/test_importer.py index 85e2e63379..d1bb6f0937 100644 --- a/tests/unit/insights/test_importer.py +++ b/tests/unit/insights/test_importer.py @@ -24,7 +24,7 @@ from robotoff.prediction.types import Prediction from robotoff.products import Product from robotoff.taxonomy import get_taxonomy -from robotoff.types import InsightType, PredictionType +from robotoff.types import InsightType, PredictionType, ProductInsightImportResult DEFAULT_BARCODE = "3760094310634" DEFAULT_SERVER_DOMAIN = "api.openfoodfacts.org" @@ -667,13 +667,15 @@ def generate_insights( batch_insert_mock = mocker.patch( "robotoff.insights.importer.batch_insert", return_value=1 ) - imported = FakeImporter.import_insights( + import_result = FakeImporter.import_insights( DEFAULT_BARCODE, [Prediction(type=PredictionType.label)], DEFAULT_SERVER_DOMAIN, product_store=FakeProductStore(), ) - assert imported == 1 + assert len(import_result.insight_created_ids) == 1 + assert len(import_result.insight_updated_ids) == 0 + assert len(import_result.insight_deleted_ids) == 1 batch_insert_mock.assert_called_once() product_insight_delete_mock.assert_called_once() @@ -1165,15 +1167,17 @@ def test_import_insights_single_product(self, mocker): ) import_insights_mock = mocker.patch( "robotoff.insights.importer.InsightImporter.import_insights", - return_value=1, + return_value=ProductInsightImportResult( + [], [], [], DEFAULT_BARCODE, InsightType.category + ), ) product_store = FakeProductStore() - imported = import_insights_for_products( + import_result = import_insights_for_products( {DEFAULT_BARCODE: {PredictionType.category}}, DEFAULT_SERVER_DOMAIN, product_store=product_store, ) - assert imported == 1 + assert len(import_result) == 1 get_product_predictions_mock.assert_called_once() import_insights_mock.assert_called_once_with( DEFAULT_BARCODE, [prediction], DEFAULT_SERVER_DOMAIN, product_store @@ -1193,14 +1197,16 @@ def test_import_insights_type_mismatch(self, mocker): ) import_insights_mock = mocker.patch( "robotoff.insights.importer.InsightImporter.import_insights", - return_value=0, + return_value=ProductInsightImportResult( + [], [], [], DEFAULT_BARCODE, InsightType.image_orientation + ), ) product_store = FakeProductStore() - imported = import_insights_for_products( + import_results = import_insights_for_products( {DEFAULT_BARCODE: {PredictionType.image_orientation}}, DEFAULT_SERVER_DOMAIN, product_store=product_store, ) - assert imported == 0 + assert len(import_results) == 0 assert not get_product_predictions_mock.called assert not import_insights_mock.called diff --git a/tests/unit/workers/tasks/test_product_updated.py b/tests/unit/workers/tasks/test_product_updated.py index 581b036e96..31295a6027 100644 --- a/tests/unit/workers/tasks/test_product_updated.py +++ b/tests/unit/workers/tasks/test_product_updated.py @@ -1,6 +1,6 @@ from robotoff import settings from robotoff.prediction.types import Prediction -from robotoff.types import PredictionType +from robotoff.types import InsightImportResult, PredictionType from robotoff.workers.tasks.product_updated import add_category_insight # TODO: refactor function under test to make it easier to test @@ -46,10 +46,10 @@ def test_add_category_insight_with_ml_insights(mocker): ) import_insights_mock = mocker.patch( "robotoff.workers.tasks.product_updated.import_insights", - return_value=1, + return_value=InsightImportResult(), ) server_domain = settings.BaseURLProvider().get() - imported = add_category_insight("123", {"code": "123"}, server_domain) + add_category_insight("123", {"code": "123"}, server_domain) import_insights_mock.assert_called_once_with( [ @@ -64,5 +64,3 @@ def test_add_category_insight_with_ml_insights(mocker): ], server_domain, ) - - assert imported