Skip to content

Commit

Permalink
feat: add a threshold parameter to /predict/category endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Oct 4, 2022
1 parent 94a1057 commit 0f68e93
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
8 changes: 5 additions & 3 deletions robotoff/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,16 +402,18 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):

class CategoryPredictorResource:
def on_get(self, req: falcon.Request, resp: falcon.Response):
barcode = req.get_param("barcode", required=True)
deepest_only = req.get_param_as_bool("deepest_only", default=False)
"""Predict categories using the neural categorizer for a specific product."""
barcode: str = req.get_param("barcode", required=True)
deepest_only: bool = req.get_param_as_bool("deepest_only", default=False)
threshold: Optional[float] = req.get_param_as_float("threshold", default=None)

categories = []

product = get_product(barcode)
if product:
predictions = CategoryClassifier(
get_taxonomy(TaxonomyType.category.name)
).predict(product, deepest_only)
).predict(product, deepest_only, threshold)
categories = [p.to_dict() for p in predictions]

resp.media = {"categories": categories}
Expand Down
33 changes: 23 additions & 10 deletions robotoff/prediction/category/neural/category_classifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Dict, List, Optional

from robotoff import settings
from robotoff.prediction.types import Prediction, PredictionType
Expand Down Expand Up @@ -44,15 +44,28 @@ class CategoryClassifier:
def __init__(self, category_taxonomy: Taxonomy):
self.taxonomy = category_taxonomy

def predict(self, product: Dict, deepest_only: bool = False) -> List[Prediction]:
"""Returns an unordered list of category predictions for the given product.
:param deepest_only: controls whether the returned list should only contain the deepmost categories
for a predicted taxonomy chain.
For example, if we predict 'fresh vegetables' -> 'legumes' -> 'beans' for a product,
def predict(
self,
product: Dict,
deepest_only: bool = False,
threshold: Optional[float] = None,
) -> List[Prediction]:
"""Returns an unordered list of category predictions for the given
product.
:param product: the product to predict the categories from, should
have at least `product_name` and `ingredients_tags` fields
:param deepest_only: controls whether the returned list should only
contain the deepest categories for a predicted taxonomy chain.
For example, if we predict 'fresh vegetables' -> 'legumes' ->
'beans' for a product,
setting deepest_only=True will return ['beans'].
:param threshold: the minimum score (inclusive) at which we consider the
category to be detected (default: 0.5)
"""
if threshold is None:
threshold = 0.5

# the model was trained on products that have a name
if not product.get("product_name"):
Expand Down Expand Up @@ -95,9 +108,9 @@ def predict(self, product: Dict, deepest_only: bool = False) -> List[Prediction]

category_predictions = []

# We only consider predictions with a confidence score of 0.5 and above.
# We only consider predictions with a confidence score of `threshold` and above.
for idx, confidence in enumerate(prediction["output_mapper_layer"]):
if confidence >= 0.5:
if confidence >= threshold:
category_predictions.append(
CategoryPrediction(
category=prediction["output_mapper_layer_1"][idx],
Expand Down

0 comments on commit 0f68e93

Please sign in to comment.