Skip to content

Commit

Permalink
fix: improve resources caching (#1492)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 authored Dec 9, 2024
1 parent 946c73b commit bb72dce
Show file tree
Hide file tree
Showing 20 changed files with 803 additions and 648 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ repos:
- id: flake8

- repo: /~https://github.com/timothycrosley/isort
rev: "5.9.3"
rev: "5.13.2"
hooks:
- id: isort
1,311 changes: 688 additions & 623 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ more-itertools = "~8.9.0"
matplotlib = "~3.9.1"
typer = "~0.7.0"
py-healthcheck = "^1.10.1"
cachetools = "^5.2.0"
cachetools = "~5.5.0"
tritonclient = {extras = ["grpc"], version = "2.38.0"}
rq = "~1.11.1"
python-redis-lock = "~4.0.0"
Expand Down Expand Up @@ -98,7 +98,7 @@ pytest = "~7.2.0"
pytest-mock = "~3.10.0"
pre-commit = "~2.20.0"
toml-sort = "~0.20.1"
isort = "~5.9.3"
isort = "~5.13.2"
flake8-bugbear = "~22.10.27"
flake8-github-actions = "~0.1.1"
pytest-cov = "~4.0.0"
Expand Down
4 changes: 3 additions & 1 deletion robotoff/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
update_logo_annotations,
validate_params,
)
from robotoff.app.middleware import DBConnectionMiddleware
from robotoff.app.middleware import CacheClearMiddleware, DBConnectionMiddleware
from robotoff.batch import import_batch_predictions
from robotoff.elasticsearch import get_es_client
from robotoff.insights.extraction import (
Expand Down Expand Up @@ -1849,6 +1849,8 @@ def custom_handle_uncaught_exception(
middleware=[
falcon.CORSMiddleware(allow_origins="*", allow_credentials="*"),
DBConnectionMiddleware(),
# Clear cache after the request, to keep RAM usage low
CacheClearMiddleware(),
],
)

Expand Down
6 changes: 6 additions & 0 deletions robotoff/app/middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from robotoff.models import db
from robotoff.utils.cache import function_cache_register


class DBConnectionMiddleware:
Expand All @@ -8,3 +9,8 @@ def process_resource(self, req, resp, resource, params):
def process_response(self, req, resp, resource, req_succeeded):
if not db.is_closed():
db.close()


class CacheClearMiddleware:
def process_response(self, req, resp, resource, req_succeeded):
function_cache_register.clear_all()
4 changes: 4 additions & 0 deletions robotoff/brands.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
load_json,
text_file_iter,
)
from robotoff.utils.cache import function_cache_register

logger = get_logger(__name__)

Expand Down Expand Up @@ -166,6 +167,9 @@ def load_resources():
get_brand_blacklist()


function_cache_register.register(get_brand_prefix)
function_cache_register.register(get_brand_blacklist)

if __name__ == "__main__":
blacklisted_brands = get_brand_blacklist()
dump_taxonomy_brands(
Expand Down
4 changes: 4 additions & 0 deletions robotoff/insights/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
ServerType,
)
from robotoff.utils import get_logger, text_file_iter
from robotoff.utils.cache import function_cache_register

logger = get_logger(__name__)

Expand Down Expand Up @@ -2166,3 +2167,6 @@ def get_product_predictions(
where_clauses.append(PredictionModel.type.in_(prediction_types))

yield from PredictionModel.select().where(*where_clauses).dicts().iterator()


function_cache_register.register(get_authorized_labels)
13 changes: 10 additions & 3 deletions robotoff/logos.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import datetime
import functools
import itertools
import operator
from typing import Optional

import cachetools
import elasticsearch
import numpy as np
from cachetools.func import ttl_cache
from elasticsearch.helpers import bulk as elasticsearch_bulk
from elasticsearch.helpers import scan as elasticsearch_scan
from more_itertools import chunked
Expand Down Expand Up @@ -34,6 +35,7 @@
ServerType,
)
from robotoff.utils import get_logger
from robotoff.utils.cache import function_cache_register
from robotoff.utils.text import get_tag

logger = get_logger(__name__)
Expand Down Expand Up @@ -106,7 +108,7 @@ def filter_logos(
return filtered


@cachetools.cached(cachetools.LRUCache(maxsize=1))
@functools.cache
def get_logo_confidence_thresholds() -> dict[LogoLabelType, float]:
logger.debug("Loading logo confidence thresholds from DB...")
thresholds = {}
Expand Down Expand Up @@ -245,7 +247,8 @@ def knn_search(
return []


@cachetools.cached(cachetools.TTLCache(maxsize=1, ttl=3600)) # 1h
# ttl: 1h
@ttl_cache(maxsize=1, ttl=3600)
def get_logo_annotations() -> dict[int, LogoLabelType]:
logger.debug("Loading logo annotations from DB...")
annotations: dict[int, LogoLabelType] = {}
Expand Down Expand Up @@ -634,3 +637,7 @@ def refresh_nearest_neighbors(
)

logger.info("refresh of logo nearest neighbors finished")


function_cache_register.register(get_logo_confidence_thresholds)
function_cache_register.register(get_logo_annotations)
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
)
from robotoff.types import JSONType, NeuralCategoryClassifierModel, ProductIdentifier
from robotoff.utils import get_image_from_url, get_logger, http_session
from robotoff.utils.cache import function_cache_register

from .preprocessing import (
IMAGE_EMBEDDING_DIM,
MAX_IMAGE_EMBEDDING,
NUTRIMENT_NAMES,
clear_ingredient_processing_cache,
generate_inputs_dict,
)

Expand Down Expand Up @@ -325,7 +325,8 @@ def predict(
break

if clear_cache:
clear_ingredient_processing_cache()
function_cache_register.clear("get_ingredient_taxonomy")
function_cache_register.clear("get_ingredient_processor")

return category_predictions, debug

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from robotoff import settings
from robotoff.taxonomy import Taxonomy
from robotoff.types import JSONType
from robotoff.utils.cache import function_cache_register
from robotoff.utils.text import KeywordProcessor

from .text_utils import fold, get_tag
Expand Down Expand Up @@ -46,19 +47,6 @@ def get_ingredient_processor():
)


def clear_ingredient_processing_cache():
"""Clear all ingredient processing cache:
- Ingredient processor
- Model ingredient taxonomy
As these resources are memory-hungry, it should be cleared from memory if
not used anymore.
"""
get_ingredient_taxonomy.cache_clear()
get_ingredient_processor.cache_clear()


def generate_inputs_dict(
product: JSONType,
ocr_texts: list[str],
Expand Down Expand Up @@ -339,3 +327,7 @@ def extract_ingredient_from_text(
"""
text = fold(text.lower())
return processor.extract_keywords(text, span_info=True)


function_cache_register.register(get_ingredient_taxonomy)
function_cache_register.register(get_ingredient_processor)
4 changes: 4 additions & 0 deletions robotoff/prediction/ingredient_list/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from robotoff.prediction.langid import LanguagePrediction, predict_lang_batch
from robotoff.triton import GRPCInferenceServiceStub, get_triton_inference_stub
from robotoff.utils import http_session
from robotoff.utils.cache import function_cache_register

from .transformers_pipeline import AggregationStrategy, TokenClassificationPipeline

Expand Down Expand Up @@ -300,3 +301,6 @@ def build_triton_request(
request.raw_input_contents.extend([attention_mask.tobytes()])

return request


function_cache_register.register(get_tokenizer)
4 changes: 4 additions & 0 deletions robotoff/prediction/ingredient_list/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from lark import Discard, Lark, Transformer

from robotoff import settings
from robotoff.utils.cache import function_cache_register

ASTERISK_SYMBOL = r"((\* ?=?|\(¹\)|\") ?)"
FROM_ORGANIC_FARMING_FR = r"issus? de l'agriculture (biologique|bio|durable)"
Expand Down Expand Up @@ -140,3 +141,6 @@ def detect_trace_mention(text: str, end_idx: int) -> int:

end_idx += end_idx_offset
return end_idx


function_cache_register.register(load_trace_grammar)
5 changes: 5 additions & 0 deletions robotoff/prediction/nutrition_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
get_triton_inference_stub,
)
from robotoff.types import JSONType
from robotoff.utils.cache import function_cache_register
from robotoff.utils.logger import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -679,3 +680,7 @@ def build_triton_request(
add_triton_infer_input_tensor(request, "pixel_values", pixel_values, "FP32")

return request


function_cache_register.register(get_processor)
function_cache_register.register(get_id2label)
6 changes: 6 additions & 0 deletions robotoff/prediction/ocr/brand.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from robotoff.brands import get_brand_blacklist, keep_brand_from_taxonomy
from robotoff.types import Prediction, PredictionType
from robotoff.utils import get_logger, text_file_iter
from robotoff.utils.cache import function_cache_register
from robotoff.utils.text import KeywordProcessor, get_tag

from .utils import generate_keyword_processor
Expand Down Expand Up @@ -135,3 +136,8 @@ def find_brands(content: Union[OCRResult, str]) -> list[Prediction]:
predictions += extract_brands_google_cloud_vision(content)

return predictions


function_cache_register.register(get_logo_annotation_brands)
function_cache_register.register(get_taxonomy_brand_processor)
function_cache_register.register(get_brand_processor)
4 changes: 4 additions & 0 deletions robotoff/prediction/ocr/image_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from robotoff import settings
from robotoff.types import Prediction, PredictionType
from robotoff.utils import text_file_iter
from robotoff.utils.cache import function_cache_register
from robotoff.utils.text import KeywordProcessor

# Increase version ID when introducing breaking change: changes for which we
Expand Down Expand Up @@ -133,3 +134,6 @@ def flag_image(content: Union[OCRResult, str]) -> list[Prediction]:
break

return predictions


function_cache_register.register(generate_image_flag_keyword_processor)
5 changes: 5 additions & 0 deletions robotoff/prediction/ocr/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from robotoff import settings
from robotoff.types import Prediction, PredictionType
from robotoff.utils import get_logger, text_file_iter
from robotoff.utils.cache import function_cache_register
from robotoff.utils.text import KeywordProcessor

from .utils import generate_keyword_processor
Expand Down Expand Up @@ -280,3 +281,7 @@ def find_labels(content: Union[OCRResult, str]) -> list[Prediction]:
)

return predictions


function_cache_register.register(get_logo_annotation_labels)
function_cache_register.register(generate_label_keyword_processor)
5 changes: 5 additions & 0 deletions robotoff/prediction/ocr/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from robotoff.taxonomy import TaxonomyType
from robotoff.types import PackagingElementProperty, Prediction, PredictionType
from robotoff.utils import get_logger, load_json
from robotoff.utils.cache import function_cache_register
from robotoff.utils.text import strip_consecutive_spaces

logger = get_logger(__name__)
Expand Down Expand Up @@ -237,3 +238,7 @@ def find_packaging(content: Union[OCRResult, str]) -> list[Prediction]:
return predictions

return []


function_cache_register.register(load_grammar)
function_cache_register.register(load_taxonomy_map)
6 changes: 6 additions & 0 deletions robotoff/prediction/ocr/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from robotoff import settings
from robotoff.types import Prediction, PredictionType
from robotoff.utils import text_file_iter
from robotoff.utils.cache import function_cache_register

# Increase version ID when introducing breaking change: changes for which we
# want old predictions to be removed in DB and replaced by newer ones
Expand Down Expand Up @@ -102,3 +103,8 @@ def find_stores(content: Union[OCRResult, str]) -> list[Prediction]:
break

return results


function_cache_register.register(get_sorted_stores)
function_cache_register.register(get_store_ocr_regex)
function_cache_register.register(get_notify_stores)
13 changes: 10 additions & 3 deletions robotoff/taxonomy.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import collections
from typing import Optional

import cachetools
from cachetools.func import ttl_cache
from openfoodfacts.taxonomy import Taxonomy
from openfoodfacts.taxonomy import get_taxonomy as _get_taxonomy
from openfoodfacts.types import TaxonomyType

from robotoff import settings
from robotoff.utils import get_logger
from robotoff.utils.cache import function_cache_register
from robotoff.utils.text import get_tag

logger = get_logger(__name__)
Expand Down Expand Up @@ -39,7 +40,8 @@ def generate_category_hierarchy(
return categories_hierarchy_list


@cachetools.cached(cache=cachetools.TTLCache(maxsize=100, ttl=12 * 60 * 60)) # 12h
# ttl: 12h
@ttl_cache(maxsize=100, ttl=12 * 60 * 60)
def get_taxonomy(taxonomy_type: TaxonomyType | str, offline: bool = False) -> Taxonomy:
"""Return the taxonomy of type `taxonomy_type`.
Expand Down Expand Up @@ -73,7 +75,8 @@ def is_prefixed_value(value: str) -> bool:
return len(value) > 3 and value[2] == ":"


@cachetools.cached(cachetools.TTLCache(maxsize=2, ttl=43200)) # 12h TTL
# ttl: 12h
@ttl_cache(maxsize=2, ttl=12 * 60 * 60)
def get_taxonomy_mapping(taxonomy_type: str) -> dict[str, str]:
"""Return for label type a mapping of prefixed taxonomy values in all
languages (such as `fr:bio-europeen` or `es:"ecologico-ue`) to their
Expand Down Expand Up @@ -121,3 +124,7 @@ def load_resources():

for taxonomy_type in (TaxonomyType.brand, TaxonomyType.label):
get_taxonomy_mapping(taxonomy_type.name)


function_cache_register.register(get_taxonomy)
function_cache_register.register(get_taxonomy_mapping)
Loading

0 comments on commit bb72dce

Please sign in to comment.