Skip to content

Commit

Permalink
fix: improve /predict/nutrition route
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed May 19, 2023
1 parent f06a45a commit 95d953b
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 64 deletions.
80 changes: 56 additions & 24 deletions doc/references/api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -918,34 +918,42 @@ paths:
get:
tags:
- Predict
summary: Extract nutritional information from an OCR JSON
summary: Extract nutritional information from an image
description: |
We currently only use the OCR text as input, and detect nutrient-value pairs if they are consecutive
to each other in the OCR text (ex: "protein: 10.5g, fat: 2.1g").
parameters:
- name: ocr_url
- $ref: "#/components/parameters/barcode"
- $ref: "#/components/parameters/server_type"
- name: image_ids
in: query
required: true
description: The URL of the OCR JSON to use for extraction
required: false
description: |
a comma-separated list of IDs of images to extract nutritional information from.
If not provided, the 10 most recent images will be used.
schema:
type: string
example: https://static.openfoodfacts.org/images/products/216/124/000/3089/1.json
format: uri
example: "1,2,5"
responses:
"200":
description: the extracted nutritional information or an error message
description: the extracted nutritional information
content:
application/json:
schema:
oneOf:
- type: object
title: Successful response
description: the extracted nutritional information
type: object
description: |
the nutritional information extracted from the provided images (or from the 10 most recent images by default).
Predictions are ordered so that those from the most recent images come first.
properties:
predictions:
type: object
properties:
nutrients:
type: object
description: |
a dict mapping nutrient name (`energy`, `fat`,...) to a dict containing the following fields:
a dict mapping nutrient name (`energy`, `fat`,...) to a list of dicts containing the detected nutritional information.
The list contains as many elements as the number of detected values for this nutrient. Each element of the list
has the following fields:
- `raw`: string of the full detected pattern (ex: `Valeur énergétique: 245 kj`)
- `nutrient`: nutrient mention (`energy`, `saturated_fat`,...)
- `value`: nutrient value, should be a number (example: `245`)
Expand All @@ -958,9 +966,35 @@ paths:
type: string
description: predictor used to generate this prediction
example: "regex"
source_image:
type: string
description: the path of the image the nutrient prediction was generated from
required:
- "nutrients"
- $ref: "#/components/schemas/ErrorResponse"
- "predictor"
- "predictor_version"
- "source_image"
image_ids:
type: array
description: |
list of the IDs of the images that were analyzed
items:
type: number
errors:
type: array
description: a list of errors that occurred during processing
items:
type: object
properties:
error:
type: string
description: the identifier of the error
error_description:
type: string
description: a full description of the error that occurred
required:
- "predictions"
- "image_ids"
"400":
description: "An HTTP 400 is returned if the provided parameters are invalid"
/predict/ocr_prediction:
Expand Down Expand Up @@ -1013,17 +1047,6 @@ paths:

components:
schemas:
ErrorResponse:
type: object
title: Error response
description: this response is returned in case of error
properties:
error:
type: string
description: the identifier of the error
error_description:
type: string
description: a full description of the error that occured
LogoANNSearchResponse:
type: object
properties:
Expand Down Expand Up @@ -1207,6 +1230,7 @@ components:
server_type:
name: server_type
in: query
required: false
description: The server type (=project) to use, such as 'off' (Open Food Facts), 'obf' (Open Beauty Facts),...
schema:
type: string
Expand Down Expand Up @@ -1298,6 +1322,14 @@ components:
schema:
type: integer
example: 5410041040807
barcode:
name: barcode
in: query
required: true
description: The barcode of the product
schema:
type: integer
example: 5410041040807
tags:
- name: Questions
- name: Insights
Expand Down
155 changes: 115 additions & 40 deletions robotoff/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@
from robotoff.off import (
OFFAuthentication,
generate_image_path,
generate_json_ocr_url,
get_barcode_from_url,
get_product,
)
from robotoff.prediction.category import predict_category
from robotoff.prediction.object_detection import ObjectDetectionModelRegistry
from robotoff.prediction.ocr.dataclass import OCRParsingException
from robotoff.products import get_product_dataset_etag
from robotoff.products import get_image_id, get_product, get_product_dataset_etag
from robotoff.spellcheck import SPELLCHECKERS, Spellchecker
from robotoff.taxonomy import is_prefixed_value, match_taxonomized_value
from robotoff.types import (
Expand Down Expand Up @@ -387,53 +387,128 @@ def spellcheck(self, req: falcon.Request, resp: falcon.Response):

class NutritionPredictorResource:
def on_get(self, req: falcon.Request, resp: falcon.Response):
ocr_url = req.get_param("ocr_url", required=True)
barcode = req.get_param("barcode", required=True)
# we transform image IDs to int to be sure to have "raw" image IDs as
# input
image_ids = req.get_param_as_list("image_ids", required=False, transform=int)

if not ocr_url.endswith(".json"):
raise falcon.HTTPBadRequest("a JSON file is expected")
if image_ids is not None:
# convert image IDs back to string
image_ids = set(str(x) for x in image_ids)

barcode = get_barcode_from_url(ocr_url)
server_type = get_server_type_from_req(req)

if barcode is None:
raise falcon.HTTPBadRequest(f"invalid OCR URL: {ocr_url}")
if server_type not in (ServerType.off, ServerType.off_pro):
raise falcon.HTTPBadRequest(f"invalid server type: {server_type}")

try:
predictions = extract_ocr_predictions(
ProductIdentifier(
barcode,
# Nutritional values only makes sense for off
ServerType.off,
product_id = ProductIdentifier(barcode, server_type)
product = get_product(product_id, ["images"])

if product is None:
raise falcon.HTTPBadRequest(f"product not found: {barcode}")

errors = []
existing_image_ids: set[str] = set(x for x in product["images"] if x.isdigit())
# We only keep image IDs that actually exist for this product
if image_ids:
target_image_ids = existing_image_ids & image_ids
if missing_image_ids := image_ids - existing_image_ids:
for missing_image_id in missing_image_ids:
errors.append(
{
"error": "unknown_image",
"error_description": f"the image {missing_image_id} for product "
f"{product_id} does not exist",
}
)
else:
target_image_ids = existing_image_ids

# We only keep the 10 most recent images (a higher image ID means that
# the image is more recent)
target_image_ids = sorted(target_image_ids, key=lambda x: int(x), reverse=True)[
:10
]
predictions: list[JSONType] = []

# Fetch existing predictions in DB to avoid recomputing them
existing_predictions = list(
Prediction.select(
Prediction.data["nutrients"].as_json().alias("nutrients"),
Prediction.source_image,
Prediction.predictor,
Prediction.predictor_version,
)
.where(
Prediction.barcode == barcode,
Prediction.server_type == server_type.name,
Prediction.type == PredictionType.nutrient.name,
Prediction.source_image.in_(
[
generate_image_path(barcode, image_id)
for image_id in target_image_ids
]
),
ocr_url,
[PredictionType.nutrient],
)
.dicts()
)
predictions += existing_predictions

except requests.exceptions.RequestException:
resp.media = {
"error": "download_error",
"error_description": "an error occurred during OCR JSON download",
}
return
# Exclude image IDs whose predictions were already fetched from DB
remaining_image_ids = set(target_image_ids) - set(
get_image_id(p["source_image"]) for p in existing_predictions
)
for image_id in remaining_image_ids:
# Perform detection on remaining images
ocr_url = generate_json_ocr_url(product_id, image_id)
try:
predictions += [
{
# Only keep relevant fields for nutrition information
"nutrients": p.data["nutrients"],
"source_image": p.source_image,
"predictor": p.predictor,
"predictor_version": p.predictor_version,
}
for p in extract_ocr_predictions(
product_id,
ocr_url,
[PredictionType.nutrient],
)
]

except requests.exceptions.RequestException:
errors.append(
{
"error": "download_error",
"error_description": f"an error occurred during OCR JSON download: {ocr_url}",
}
)

except OCRParsingException as e:
logger.error(e)
resp.media = {
"error": "invalid_ocr",
"error_description": "an error occurred during OCR parsing",
}
return
except OCRParsingException as e:
logger.error(e)
errors.append(
{
"error": "invalid_ocr",
"error_description": f"an error occurred during OCR parsing: {ocr_url}",
}
)

# `predictions` is either empty or contains a single item
# (see `find_nutrient_values` in robotoff.prediction.ocr.nutrient)
if not predictions:
resp.media = {"nutrients": {}}
else:
prediction = predictions[0]
resp.media = {
"nutrients": prediction.data["nutrients"],
"predictor": prediction.predictor,
"predictor_version": prediction.predictor_version,
}
# Sort predictions to have most recent images (higher image ID) first
predictions = sorted(
predictions,
key=lambda x: int(get_image_id(x["source_image"])), # type: ignore
reverse=True,
)
response = {
"predictions": predictions,
"image_ids": target_image_ids,
}

if errors:
response["errors"] = errors

resp.media = response


def transform_to_prediction_type(value: str) -> PredictionType:
Expand Down

0 comments on commit 95d953b

Please sign in to comment.