Skip to content

Commit

Permalink
fix: load lazily all resources in Robotoff
Browse files Browse the repository at this point in the history
We used to load many resources when importing the robotoff module; this
had a negative impact on package loading time and on memory consumption.
Load all resources lazily instead.
  • Loading branch information
raphael0202 committed Apr 6, 2023
1 parent 6866739 commit 4dfa93f
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 52 deletions.
7 changes: 3 additions & 4 deletions robotoff/brands.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import functools
import operator
from typing import Optional

import cachetools

from robotoff import settings
from robotoff.products import ProductDataset
from robotoff.taxonomy import TaxonomyType, get_taxonomy
Expand All @@ -18,7 +17,7 @@
logger = get_logger(__name__)


@cachetools.cached(cachetools.LRUCache(maxsize=1))
@functools.cache
def get_brand_prefix() -> set[tuple[str, str]]:
"""Get a set of brand prefix tuples found in Open Food Facts databases.
Expand All @@ -29,7 +28,7 @@ def get_brand_prefix() -> set[tuple[str, str]]:
return set(tuple(x) for x in load_json(settings.BRAND_PREFIX_PATH, compressed=True)) # type: ignore


@cachetools.cached(cachetools.LRUCache(maxsize=1))
@functools.cache
def get_brand_blacklist() -> set[str]:
    """Load (once, lazily) the set of blacklisted brand names from disk."""
    logger.info("Loading brand blacklist...")
    blacklisted = {brand for brand in text_file_iter(settings.OCR_TAXONOMY_BRANDS_BLACKLIST_PATH)}
    return blacklisted
Expand Down
6 changes: 3 additions & 3 deletions robotoff/prediction/category/matcher.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import datetime
import functools
import itertools
import operator
import re
from typing import Iterable, Optional

import cachetools
from flashtext import KeywordProcessor

from robotoff import settings
Expand Down Expand Up @@ -230,7 +230,7 @@ def get_match_maps(taxonomy_type: str) -> MatchMapType:
)


@cachetools.cached(cache=cachetools.TTLCache(maxsize=1, ttl=3600))
@functools.cache
def get_processors() -> dict[str, KeywordProcessor]:
"""Return a dict mapping lang to flashtext KeywordProcessor used to
perform category matching.
Expand Down Expand Up @@ -273,7 +273,7 @@ def generate_intersect_categories_ingredients() -> dict[str, set[str]]:
return matches


@cachetools.cached(cache=cachetools.Cache(maxsize=1))
@functools.cache
def get_intersect_categories_ingredients():
"""Return intersection between category and ingredient maps saved on-disk
for supported language.
Expand Down
29 changes: 18 additions & 11 deletions robotoff/prediction/ocr/brand.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def generate_brand_keyword_processor(
return generate_keyword_processor(brands, keep_func=keep_func)


@functools.cache
def get_logo_annotation_brands() -> dict[str, str]:
brands: dict[str, str] = {}

Expand All @@ -45,13 +46,18 @@ def get_logo_annotation_brands() -> dict[str, str]:
return brands


LOGO_ANNOTATION_BRANDS: dict[str, str] = get_logo_annotation_brands()
TAXONOMY_BRAND_PROCESSOR = generate_brand_keyword_processor(
text_file_iter(settings.OCR_TAXONOMY_BRANDS_PATH)
)
BRAND_PROCESSOR = generate_brand_keyword_processor(
text_file_iter(settings.OCR_BRANDS_PATH),
)
@functools.cache
def get_taxonomy_brand_processor():
    """Build (once, lazily) the keyword processor for taxonomy brands."""
    brand_lines = text_file_iter(settings.OCR_TAXONOMY_BRANDS_PATH)
    return generate_brand_keyword_processor(brand_lines)


@functools.cache
def get_brand_processor():
    """Build (once, lazily) the keyword processor for the curated brand list."""
    brand_lines = text_file_iter(settings.OCR_BRANDS_PATH)
    return generate_brand_keyword_processor(brand_lines)


def extract_brands(
Expand Down Expand Up @@ -89,9 +95,10 @@ def extract_brands(

def extract_brands_google_cloud_vision(ocr_result: OCRResult) -> list[Prediction]:
predictions = []
logo_annotation_brands = get_logo_annotation_brands()
for logo_annotation in ocr_result.logo_annotations:
if logo_annotation.description in LOGO_ANNOTATION_BRANDS:
brand = LOGO_ANNOTATION_BRANDS[logo_annotation.description]
if logo_annotation.description in logo_annotation_brands:
brand = logo_annotation_brands[logo_annotation.description]

predictions.append(
Prediction(
Expand All @@ -112,10 +119,10 @@ def find_brands(content: Union[OCRResult, str]) -> list[Prediction]:
predictions: list[Prediction] = []

predictions += extract_brands(
BRAND_PROCESSOR, content, "curated-list", automatic_processing=True
get_brand_processor(), content, "curated-list", automatic_processing=True
)
predictions += extract_brands(
TAXONOMY_BRAND_PROCESSOR, content, "taxonomy", automatic_processing=False
get_taxonomy_brand_processor(), content, "taxonomy", automatic_processing=False
)

if isinstance(content, OCRResult):
Expand Down
8 changes: 4 additions & 4 deletions robotoff/prediction/ocr/image_flag.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
from typing import Optional, Union

from flashtext import KeywordProcessor
Expand Down Expand Up @@ -47,6 +48,7 @@
}


@functools.cache
def generate_image_flag_keyword_processor() -> KeywordProcessor:
processor = KeywordProcessor()

Expand All @@ -60,9 +62,6 @@ def generate_image_flag_keyword_processor() -> KeywordProcessor:
return processor


PROCESSOR = generate_image_flag_keyword_processor()


def extract_image_flag_flashtext(
processor: KeywordProcessor, text: str
) -> Optional[Prediction]:
Expand All @@ -82,7 +81,8 @@ def flag_image(content: Union[OCRResult, str]) -> list[Prediction]:
predictions: list[Prediction] = []

text = get_text(content)
prediction = extract_image_flag_flashtext(PROCESSOR, text)
processor = generate_image_flag_keyword_processor()
prediction = extract_image_flag_flashtext(processor, text)

if prediction is not None:
predictions.append(prediction)
Expand Down
17 changes: 7 additions & 10 deletions robotoff/prediction/ocr/label.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import re
from typing import Iterable, Optional, Union

Expand All @@ -6,7 +7,6 @@
from robotoff import settings
from robotoff.types import Prediction, PredictionType
from robotoff.utils import get_logger, text_file_iter
from robotoff.utils.cache import CachedStore

from .dataclass import OCRField, OCRRegex, OCRResult, get_match_bounding_box, get_text
from .utils import generate_keyword_processor
Expand Down Expand Up @@ -182,6 +182,7 @@ def process_es_bio_label_code(match) -> str:
}


@functools.cache
def get_logo_annotation_labels() -> dict[str, str]:
labels: dict[str, str] = {}

Expand All @@ -197,6 +198,7 @@ def get_logo_annotation_labels() -> dict[str, str]:
return labels


@functools.cache
def generate_label_keyword_processor(labels: Optional[Iterable[str]] = None):
if labels is None:
labels = text_file_iter(settings.OCR_LABEL_FLASHTEXT_DATA_PATH)
Expand Down Expand Up @@ -234,14 +236,9 @@ def extract_label_flashtext(
return predictions


LOGO_ANNOTATION_LABELS: dict[str, str] = get_logo_annotation_labels()
LABEL_KEYWORD_PROCESSOR_STORE = CachedStore(
fetch_func=generate_label_keyword_processor, expiration_interval=None
)


def find_labels(content: Union[OCRResult, str]) -> list[Prediction]:
predictions = []
logo_annotation_labels = get_logo_annotation_labels()

for label_tag, regex_list in LABELS_REGEX.items():
for ocr_regex in regex_list:
Expand Down Expand Up @@ -277,13 +274,13 @@ def find_labels(content: Union[OCRResult, str]) -> list[Prediction]:
)
)

processor = LABEL_KEYWORD_PROCESSOR_STORE.get()
processor = generate_label_keyword_processor()
predictions += extract_label_flashtext(processor, content)

if isinstance(content, OCRResult):
for logo_annotation in content.logo_annotations:
if logo_annotation.description in LOGO_ANNOTATION_LABELS:
label_tag = LOGO_ANNOTATION_LABELS[logo_annotation.description]
if logo_annotation.description in logo_annotation_labels:
label_tag = logo_annotation_labels[logo_annotation.description]

predictions.append(
Prediction(
Expand Down
6 changes: 3 additions & 3 deletions robotoff/prediction/ocr/packaging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools
from typing import Optional, Union

import cachetools
from lark import Discard, Lark, Transformer

from robotoff import settings
Expand Down Expand Up @@ -60,7 +60,7 @@ def generate_packaging_lark_file(lang: str):
)


@cachetools.cached(cachetools.LRUCache(maxsize=5))
@functools.cache
def load_grammar(lang: str, start: str = "value", **kwargs) -> Lark:
return Lark.open(
str(settings.GRAMMARS_DIR / f"packaging_{lang}.lark"),
Expand Down Expand Up @@ -153,7 +153,7 @@ def _match_tag(self, type_: str, value: str) -> Optional[str]:
return value_tags[0]


@cachetools.cached(cachetools.LRUCache(maxsize=10))
@functools.cache
def load_taxonomy_map(lang: str) -> dict[str, dict[str, list[str]]]:
return {
TaxonomyType.packaging_shape.name: load_json( # type: ignore
Expand Down
36 changes: 23 additions & 13 deletions robotoff/prediction/ocr/store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import re
from typing import Union

Expand All @@ -22,6 +23,7 @@ def store_sort_key(item):
return -len(store), store


@functools.cache
def get_sorted_stores() -> list[tuple[str, str]]:
sorted_stores: dict[str, str] = {}

Expand All @@ -37,31 +39,39 @@ def get_sorted_stores() -> list[tuple[str, str]]:
return sorted(sorted_stores.items(), key=store_sort_key)


SORTED_STORES = get_sorted_stores()
STORE_REGEX_STR = "|".join(
r"((?<!\w){}(?!\w))".format(pattern) for _, pattern in SORTED_STORES
)
NOTIFY_STORES: set[str] = set(text_file_iter(settings.OCR_STORES_NOTIFY_DATA_PATH))
STORE_REGEX = OCRRegex(
re.compile(STORE_REGEX_STR, re.I), field=OCRField.full_text_contiguous
)
@functools.cache
def get_store_ocr_regex() -> OCRRegex:
    """Compile (once, lazily) the OCR regex matching any known store name.

    Each store pattern is wrapped in its own capturing group so the match
    index maps back to an entry of `get_sorted_stores()`.
    """
    wrapped = [
        r"((?<!\w){}(?!\w))".format(pattern)
        for _, pattern in get_sorted_stores()
    ]
    alternation = "|".join(wrapped)
    return OCRRegex(
        re.compile(alternation, re.I), field=OCRField.full_text_contiguous
    )


@functools.cache
def get_notify_stores() -> set[str]:
    """Load (once, lazily) the set of store names that should trigger a notification."""
    notify = {store for store in text_file_iter(settings.OCR_STORES_NOTIFY_DATA_PATH)}
    return notify


def find_stores(content: Union[OCRResult, str]) -> list[Prediction]:
results = []

text = get_text(content, STORE_REGEX)
store_ocr_regex = get_store_ocr_regex()
sorted_stores = get_sorted_stores()
notify_stores = get_notify_stores()
text = get_text(content, store_ocr_regex)

if not text:
return []

for match in STORE_REGEX.regex.finditer(text):
for match in store_ocr_regex.regex.finditer(text):
groups = match.groups()

for idx, match_str in enumerate(groups):
if match_str is not None:
store, _ = SORTED_STORES[idx]
data = {"text": match_str, "notify": store in NOTIFY_STORES}
store, _ = sorted_stores[idx]
data = {"text": match_str, "notify": store in notify_stores}
if (
bounding_box := get_match_bounding_box(
content, match.start(), match.end()
Expand Down
4 changes: 2 additions & 2 deletions robotoff/products.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import datetime
import enum
import functools
import gzip
import json
import os
Expand All @@ -9,7 +10,6 @@
import tempfile
from typing import Iterable, Iterator, Optional, Union

import cachetools
import requests
from pymongo import MongoClient

Expand Down Expand Up @@ -501,7 +501,7 @@ def iter_product(self, projection: Optional[list[str]] = None):
yield from (Product(p) for p in self.collection.find(projection=projection))


@cachetools.cached(cachetools.LRUCache(maxsize=1))
@functools.cache
def get_min_product_store() -> ProductStore:
logger.info("Loading product store in memory...")
ps = MemoryProductStore.load_min()
Expand Down
4 changes: 2 additions & 2 deletions robotoff/triton.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools
import struct

import cachetools
import grpc
import numpy as np
from more_itertools import chunked
Expand All @@ -17,7 +17,7 @@
CLIP_MAX_BATCH_SIZE = 32


@cachetools.cached(cachetools.Cache(maxsize=1))
@functools.cache
def get_triton_inference_stub() -> service_pb2_grpc.GRPCInferenceServiceStub:
    """Create (once, lazily) a gRPC stub connected to the Triton inference server.

    The underlying insecure channel targets `settings.TRITON_URI`; the cached
    stub is shared by all callers.
    """
    grpc_channel = grpc.insecure_channel(settings.TRITON_URI)
    stub = service_pb2_grpc.GRPCInferenceServiceStub(grpc_channel)
    return stub
Expand Down

0 comments on commit 4dfa93f

Please sign in to comment.