diff --git a/robotoff/app/api.py b/robotoff/app/api.py index 995b81c2bb..28a53c93db 100644 --- a/robotoff/app/api.py +++ b/robotoff/app/api.py @@ -16,6 +16,7 @@ from falcon.media.validators import jsonschema from falcon_cors import CORS from falcon_multipart.middleware import MultipartMiddleware +from openfoodfacts import OCRResult from openfoodfacts.ocr import OCRParsingException, OCRResultGenerationException from openfoodfacts.types import COUNTRY_CODE_TO_NAME, Country from PIL import Image @@ -627,16 +628,8 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): ) ocr_url = req.get_param("ocr_url", required=True) - aggregation_strategy = req.get_param("aggregation_strategy", default="FIRST") - model_version = req.get_param("model_version", default="1") try: - output = ingredient_list.predict_from_ocr( - ocr_url, - aggregation_strategy=ingredient_list.AggregationStrategy[ - aggregation_strategy - ], - model_version=model_version, - ) + ocr_result = OCRResult.from_url(ocr_url, http_session, error_raise=True) except OCRResultGenerationException as e: error_message, _ = e.args resp.media = { @@ -644,7 +637,27 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): "description": error_message, } return - resp.media = dataclasses.asdict(output) + + aggregation_strategy = req.get_param("aggregation_strategy", default="FIRST") + model_version = req.get_param("model_version", default="1") + output = ingredient_list.predict_from_ocr( + ocr_result, + aggregation_strategy=ingredient_list.AggregationStrategy[ + aggregation_strategy + ], + model_version=model_version, + ) + + output_dict = dataclasses.asdict(output) + + if aggregation_strategy != "NONE": + # Add bounding boxes to entities + for entity in output_dict["entities"]: + entity["bounding_boxes"] = ocr_result.get_match_bounding_box( + entity["start"], entity["end"] + ) + + resp.media = output_dict class UpdateDatasetResource: