Skip to content

Commit

Permalink
fix: create annotate function to centralize annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Dec 28, 2022
1 parent 20aab5a commit b55d72c
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 20 deletions.
5 changes: 2 additions & 3 deletions robotoff/app/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
SAVED_ANNOTATION_VOTE_RESULT,
UNKNOWN_INSIGHT_RESULT,
AnnotationResult,
InsightAnnotatorFactory,
annotate,
)
from robotoff.models import (
AnnotationVote,
Expand Down Expand Up @@ -385,8 +385,7 @@ def save_annotation(
if not verified:
return SAVED_ANNOTATION_VOTE_RESULT

annotator = InsightAnnotatorFactory.get(insight.type)
result = annotator.annotate(insight, annotation, update, data=data, auth=auth)
result = annotate(insight, annotation, update, data=data, auth=auth)
username = auth.get_username() if auth else "unknown annotator"
events.event_processor.send_async(
"question_answered", username, device_id, insight.barcode
Expand Down
19 changes: 19 additions & 0 deletions robotoff/insights/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,5 +504,24 @@ def get(cls, identifier: str) -> InsightAnnotator:
return cls.mapping[identifier]


def annotate(
insight: ProductInsight,
annotation: int,
update: bool = True,
data: Optional[dict] = None,
auth: Optional[OFFAuthentication] = None,
automatic: bool = False,
) -> AnnotationResult:
annotator = InsightAnnotatorFactory.get(insight.type)
return annotator.annotate(
insight=insight,
annotation=annotation,
update=update,
data=data,
auth=auth,
automatic=automatic,
)


class InvalidInsight(Exception):
pass
8 changes: 2 additions & 6 deletions robotoff/logos.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@

from robotoff import settings
from robotoff.elasticsearch import get_es_client
from robotoff.insights.annotate import (
UPDATED_ANNOTATION_RESULT,
InsightAnnotatorFactory,
)
from robotoff.insights.annotate import UPDATED_ANNOTATION_RESULT, annotate
from robotoff.insights.importer import import_insights
from robotoff.logo_label_type import LogoLabelType
from robotoff.models import (
Expand Down Expand Up @@ -397,11 +394,10 @@ def generate_insights_from_annotated_logos(
):
insight = ProductInsight.get_or_none(id=created_id)
if insight:
annotator = InsightAnnotatorFactory.get(insight.type)
logger.info(
"Annotating insight %s (product: %s)", insight.id, insight.barcode
)
annotation_result = annotator.annotate(insight, 1, auth=auth)
annotation_result = annotate(insight, 1, auth=auth)
annotated += int(annotation_result == UPDATED_ANNOTATION_RESULT)

return annotated
Expand Down
8 changes: 2 additions & 6 deletions robotoff/scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@
from robotoff import settings, slack
from robotoff.elasticsearch import get_es_client
from robotoff.elasticsearch.export import ElasticsearchExporter
from robotoff.insights.annotate import (
UPDATED_ANNOTATION_RESULT,
InsightAnnotatorFactory,
)
from robotoff.insights.annotate import UPDATED_ANNOTATION_RESULT, annotate
from robotoff.insights.importer import import_insights
from robotoff.metrics import (
ensure_influx_database,
Expand Down Expand Up @@ -56,11 +53,10 @@ def process_insights():
.iterator()
):
try:
annotator = InsightAnnotatorFactory.get(insight.type)
logger.info(
"Annotating insight %s (product: %s)", insight.id, insight.barcode
)
annotation_result = annotator.annotate(insight, 1, update=True)
annotation_result = annotate(insight, 1, update=True)
processed += 1

if annotation_result == UPDATED_ANNOTATION_RESULT and insight.data.get(
Expand Down
13 changes: 8 additions & 5 deletions tests/integration/test_annotate_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import robotoff.insights.importer
import robotoff.taxonomy
from robotoff.app.api import api
from robotoff.insights.annotate import UPDATED_ANNOTATION_RESULT
from robotoff.models import LogoAnnotation, Prediction, ProductInsight
from robotoff.products import Product

Expand Down Expand Up @@ -164,6 +165,7 @@ def test_logo_annotation_brand(client, peewee_db, monkeypatch, mocker, fake_taxo
mocker.patch(
"robotoff.brands.get_brand_prefix", return_value={("Etorki", "0000000xxxxxx")}
)
mocker.patch("robotoff.logos.annotate", return_value=UPDATED_ANNOTATION_RESULT)
start = datetime.utcnow()
result = client.simulate_post(
"/api/v1/images/logos/annotate",
Expand Down Expand Up @@ -203,7 +205,7 @@ def test_logo_annotation_brand(client, peewee_db, monkeypatch, mocker, fake_taxo
assert prediction.value_tag == "Etorki"
assert prediction.predictor == "universal-logo-detector"
assert start <= prediction.timestamp <= end
assert prediction.automatic_processing
assert prediction.automatic_processing is False
# We check that this prediction in turn generates an insight

with peewee_db:
Expand All @@ -222,12 +224,12 @@ def test_logo_annotation_brand(client, peewee_db, monkeypatch, mocker, fake_taxo
assert insight.value_tag == "Etorki"
assert insight.predictor == "universal-logo-detector"
assert start <= prediction.timestamp <= end
assert insight.automatic_processing
assert insight.automatic_processing is False
assert insight.username == "a"
assert insight.completed_at is None # we did not run annotate yet


def test_logo_annotation_label(client, peewee_db, monkeypatch, fake_taxonomy):
def test_logo_annotation_label(client, peewee_db, monkeypatch, fake_taxonomy, mocker):
"""This test will check that, given an image with a logo above the confidence threshold,
that is then fed into the ANN logos and labels model, we annotate properly a product.
"""
Expand All @@ -237,6 +239,7 @@ def test_logo_annotation_label(client, peewee_db, monkeypatch, fake_taxonomy):
)
barcode = ann.image_prediction.image.barcode
_fake_store(monkeypatch, barcode)
mocker.patch("robotoff.logos.annotate", return_value=UPDATED_ANNOTATION_RESULT)
start = datetime.utcnow()
result = client.simulate_post(
"/api/v1/images/logos/annotate",
Expand Down Expand Up @@ -276,7 +279,7 @@ def test_logo_annotation_label(client, peewee_db, monkeypatch, fake_taxonomy):
assert prediction.value_tag == "en:eu-organic"
assert prediction.predictor == "universal-logo-detector"
assert start <= prediction.timestamp <= end
assert prediction.automatic_processing
assert prediction.automatic_processing is False
# We check that this prediction in turn generates an insight
with peewee_db:
insights = list(ProductInsight.select().filter(barcode=barcode).execute())
Expand All @@ -294,6 +297,6 @@ def test_logo_annotation_label(client, peewee_db, monkeypatch, fake_taxonomy):
assert insight.value_tag == "en:eu-organic"
assert insight.predictor == "universal-logo-detector"
assert start <= prediction.timestamp <= end
assert insight.automatic_processing
assert insight.automatic_processing is False
assert insight.username == "a"
assert insight.completed_at is None

0 comments on commit b55d72c

Please sign in to comment.