Skip to content

Commit

Permalink
fix: add integration tests to process_created_logos
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Dec 22, 2022
1 parent a3aa2c4 commit 0e02f7b
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 4 deletions.
2 changes: 1 addition & 1 deletion robotoff/workers/tasks/import_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def save_logo_embeddings(logos: list[LogoAnnotation], image: Image.Image):

@with_db
def process_created_logos(image_prediction_id: int, server_domain: str):
logo_embeddings = (
logo_embeddings = list(
LogoEmbedding.select()
.join(LogoAnnotation)
.join(ImagePrediction)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/models_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class Meta:
model = LogoEmbedding

logo = factory.SubFactory(LogoAnnotation)
embedding = factory.LazyFunction(lambda: np.random.rand(512))
embedding = factory.LazyFunction(lambda: np.random.rand(512).tobytes())


def clean_db():
Expand Down
49 changes: 47 additions & 2 deletions tests/integration/test_import_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,19 @@
import pytest
from PIL import Image

from robotoff import settings
from robotoff.models import LogoEmbedding
from robotoff.workers.tasks.import_image import save_logo_embeddings
from robotoff.workers.tasks.import_image import (
process_created_logos,
save_logo_embeddings,
)

from .models_utils import ImagePredictionFactory, LogoAnnotationFactory, clean_db
from .models_utils import (
ImagePredictionFactory,
LogoAnnotationFactory,
LogoEmbeddingFactory,
clean_db,
)


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -51,3 +60,39 @@ def test_save_logo_embeddings(peewee_db, mocker):
logo_id_to_logo_embedding[logo.id].embedding, dtype=np.float32
).reshape((1, 512))
assert (embedding == expected_embeddings[i]).all()


def test_process_created_logos(peewee_db, mocker):
add_logos_to_ann_mock = mocker.patch(
"robotoff.workers.tasks.import_image.add_logos_to_ann",
return_value=None,
)
save_nearest_neighbors_mock = mocker.patch(
"robotoff.workers.tasks.import_image.save_nearest_neighbors",
return_value=None,
)
get_logo_confidence_thresholds_mock = mocker.patch(
"robotoff.workers.tasks.import_image.get_logo_confidence_thresholds",
return_value=dict,
)
import_logo_insights_mock = mocker.patch(
"robotoff.workers.tasks.import_image.import_logo_insights",
return_value=1,
)

with peewee_db:
image_prediction = ImagePredictionFactory()
logos = [
LogoAnnotationFactory(image_prediction=image_prediction, index=i)
for i in range(5)
]
logo_embeddings = [LogoEmbeddingFactory(logo=logo) for logo in logos]
process_created_logos(
image_prediction.id, server_domain=settings.OFF_SERVER_DOMAIN
)
add_logos_to_ann_mock.assert_called()
embedding_args = add_logos_to_ann_mock.mock_calls[0].args[0]
assert sorted(embedding_args, key=lambda x: x.logo_id) == logo_embeddings
save_nearest_neighbors_mock.assert_called()
get_logo_confidence_thresholds_mock.assert_called()
import_logo_insights_mock.assert_called()

0 comments on commit 0e02f7b

Please sign in to comment.