From 0e02f7b125f93fb6ed72c850d4127f4c05e5c3dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Mon, 19 Dec 2022 15:28:13 +0100 Subject: [PATCH] fix: add integration tests to process_created_logos --- robotoff/workers/tasks/import_image.py | 2 +- tests/integration/models_utils.py | 2 +- tests/integration/test_import_image.py | 49 ++++++++++++++++++++++++-- 3 files changed, 49 insertions(+), 4 deletions(-) diff --git a/robotoff/workers/tasks/import_image.py b/robotoff/workers/tasks/import_image.py index c288c8a54e..739baf937c 100644 --- a/robotoff/workers/tasks/import_image.py +++ b/robotoff/workers/tasks/import_image.py @@ -386,7 +386,7 @@ def save_logo_embeddings(logos: list[LogoAnnotation], image: Image.Image): @with_db def process_created_logos(image_prediction_id: int, server_domain: str): - logo_embeddings = ( + logo_embeddings = list( LogoEmbedding.select() .join(LogoAnnotation) .join(ImagePrediction) diff --git a/tests/integration/models_utils.py b/tests/integration/models_utils.py index b0b2e8d77c..f6195ec957 100644 --- a/tests/integration/models_utils.py +++ b/tests/integration/models_utils.py @@ -138,7 +138,7 @@ class Meta: model = LogoEmbedding logo = factory.SubFactory(LogoAnnotation) - embedding = factory.LazyFunction(lambda: np.random.rand(512)) + embedding = factory.LazyFunction(lambda: np.random.rand(512).tobytes()) def clean_db(): diff --git a/tests/integration/test_import_image.py b/tests/integration/test_import_image.py index e73af7bf5b..08d832ec56 100644 --- a/tests/integration/test_import_image.py +++ b/tests/integration/test_import_image.py @@ -2,10 +2,19 @@ import pytest from PIL import Image +from robotoff import settings from robotoff.models import LogoEmbedding -from robotoff.workers.tasks.import_image import save_logo_embeddings +from robotoff.workers.tasks.import_image import ( + process_created_logos, + save_logo_embeddings, +) -from .models_utils import ImagePredictionFactory, LogoAnnotationFactory, clean_db +from .models_utils import ( + ImagePredictionFactory, + LogoAnnotationFactory, + LogoEmbeddingFactory, + clean_db, +) @pytest.fixture(autouse=True) @@ -51,3 +60,39 @@ def test_save_logo_embeddings(peewee_db, mocker): logo_id_to_logo_embedding[logo.id].embedding, dtype=np.float32 ).reshape((1, 512)) assert (embedding == expected_embeddings[i]).all() + + +def test_process_created_logos(peewee_db, mocker): + add_logos_to_ann_mock = mocker.patch( + "robotoff.workers.tasks.import_image.add_logos_to_ann", + return_value=None, + ) + save_nearest_neighbors_mock = mocker.patch( + "robotoff.workers.tasks.import_image.save_nearest_neighbors", + return_value=None, + ) + get_logo_confidence_thresholds_mock = mocker.patch( + "robotoff.workers.tasks.import_image.get_logo_confidence_thresholds", + return_value=dict, + ) + import_logo_insights_mock = mocker.patch( + "robotoff.workers.tasks.import_image.import_logo_insights", + return_value=1, + ) + + with peewee_db: + image_prediction = ImagePredictionFactory() + logos = [ + LogoAnnotationFactory(image_prediction=image_prediction, index=i) + for i in range(5) + ] + logo_embeddings = [LogoEmbeddingFactory(logo=logo) for logo in logos] + process_created_logos( + image_prediction.id, server_domain=settings.OFF_SERVER_DOMAIN + ) + add_logos_to_ann_mock.assert_called() + embedding_args = add_logos_to_ann_mock.mock_calls[0].args[0] + assert sorted(embedding_args, key=lambda x: x.logo_id) == logo_embeddings + save_nearest_neighbors_mock.assert_called() + get_logo_confidence_thresholds_mock.assert_called() + import_logo_insights_mock.assert_called()