refactor: remove CachedStore class
This was a legacy class; we now use the built-in functools.cache decorator or cachetools classes instead.
raphael0202 committed Oct 27, 2023
1 parent 5e86341 commit 5d8c007
Showing 4 changed files with 12 additions and 62 deletions.
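For context, the commit swaps a hand-rolled memoization wrapper for the standard-library decorator. Below is a minimal sketch of the before/after pattern, using a placeholder loader rather than the real Robotoff functions:

from functools import cache

# Before: callers fetched the value through a module-level CachedStore, e.g.
#   CITIES_STORE = CachedStore(fetch_func=load_cities, expiration_interval=None)
#   cities = CITIES_STORE.get()
# After: the loader itself is memoized, so callers simply call it.
@cache
def load_cities() -> frozenset[str]:
    # Placeholder for an expensive load (the real code parses a gzipped dataset).
    return frozenset({"paris", "lyon", "toulouse"})

cities = load_cities()        # first call computes and caches the result
cities_again = load_cities()  # later calls return the same cached object
assert cities is cities_again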
10 changes: 3 additions & 7 deletions robotoff/prediction/ocr/location.py
@@ -2,6 +2,7 @@
import gzip
import json
import re
+from functools import cache
from pathlib import Path
from typing import BinaryIO, Iterable, Optional, Union

@@ -10,7 +11,6 @@
from robotoff import settings
from robotoff.types import Prediction, PredictionType
from robotoff.utils import get_logger
-from robotoff.utils.cache import CachedStore
from robotoff.utils.text import KeywordProcessor, strip_accents_v1

# Increase version ID when introducing breaking change: changes for which we
@@ -31,6 +31,7 @@ class City:
coordinates: Optional[tuple[float, float]]


+@cache
def load_cities_fr(source: Union[Path, BinaryIO, None] = None) -> set[City]:
"""Load French cities dataset.
@@ -254,11 +255,6 @@ def find_nearby_postal_code(
return match.group(1), sub_start + match.start(1), sub_start + match.end(1)


-ADDRESS_EXTRACTOR_STORE = CachedStore(
-    lambda: AddressExtractor(load_cities_fr()), expiration_interval=None
-)


def find_locations(content: Union[OCRResult, str]) -> list[Prediction]:
"""Find location predictions in the text content.
@@ -270,5 +266,5 @@ def find_locations(content: Union[OCRResult, str]) -> list[Prediction]:
Returns:
list of Prediction: See :meth:`.AddressExtractor.extract_addresses`.
"""
-location_extractor: AddressExtractor = ADDRESS_EXTRACTOR_STORE.get()
+location_extractor = AddressExtractor(load_cities_fr())
return location_extractor.extract_addresses(content)
23 changes: 6 additions & 17 deletions robotoff/prediction/ocr/packager_code.py
@@ -1,4 +1,5 @@
import re
+from functools import cache
from typing import Optional, Union

from openfoodfacts.ocr import (
@@ -12,7 +13,6 @@
from robotoff import settings
from robotoff.types import Prediction, PredictionType
from robotoff.utils import text_file_iter
-from robotoff.utils.cache import CachedStore
from robotoff.utils.text import KeywordProcessor

from .utils import generate_keyword_processor
@@ -52,16 +52,14 @@ def process_USDA_match_to_flashtext(match) -> Optional[str]:
unchecked_code = match.group().upper()
unchecked_code = re.sub(r"\s*\.*", "", unchecked_code)

-processor = USDA_CODE_KEYWORD_PROCESSOR_STORE.get()
+processor = generate_USDA_code_keyword_processor()
USDA_code = extract_USDA_code(processor, unchecked_code)
return USDA_code


+@cache
def generate_USDA_code_keyword_processor() -> KeywordProcessor:
"""Builds the KeyWordProcessor for USDA codes
This will be called only once thanks to CachedStore
"""
"""Builds the KeyWordProcessor for USDA codes."""

codes = text_file_iter(settings.OCR_USDA_CODE_FLASHTEXT_DATA_PATH)
return generate_keyword_processor(codes)
Expand All @@ -79,11 +77,6 @@ def extract_USDA_code(processor: KeywordProcessor, text: str) -> Optional[str]:
return USDA_code


-USDA_CODE_KEYWORD_PROCESSOR_STORE = CachedStore(
-    fetch_func=generate_USDA_code_keyword_processor, expiration_interval=None
-)


PACKAGER_CODE = {
"fr_emb": [
OCRRegex(
@@ -191,6 +184,7 @@ def find_packager_codes_regex(content: Union[OCRResult, str]) -> list[Prediction
return results


+@cache
def generate_fishing_code_keyword_processor() -> KeywordProcessor:
codes = text_file_iter(settings.OCR_FISHING_FLASHTEXT_DATA_PATH)
return generate_keyword_processor(("{}||{}".format(c.upper(), c) for c in codes))
@@ -226,13 +220,8 @@ def extract_fishing_code(
return predictions


-FISHING_KEYWORD_PROCESSOR_STORE = CachedStore(
-    fetch_func=generate_fishing_code_keyword_processor, expiration_interval=None
-)


def find_packager_codes(content: Union[OCRResult, str]) -> list[Prediction]:
predictions = find_packager_codes_regex(content)
-processor = FISHING_KEYWORD_PROCESSOR_STORE.get()
+processor = generate_fishing_code_keyword_processor()
predictions += extract_fishing_code(processor, content)
return predictions
9 changes: 3 additions & 6 deletions robotoff/prediction/ocr/trace.py
@@ -1,4 +1,5 @@
import re
+from functools import cache
from typing import Optional, Union

from openfoodfacts.ocr import (
@@ -12,7 +13,6 @@
from robotoff import settings
from robotoff.types import Prediction, PredictionType
from robotoff.utils import text_file_iter
-from robotoff.utils.cache import CachedStore

from .utils import generate_keyword_processor

@@ -21,6 +21,7 @@
PREDICTOR_VERSION = "1"


+@cache
def generate_trace_keyword_processor(labels: Optional[list[str]] = None):
if labels is None:
labels = list(text_file_iter(settings.OCR_TRACE_ALLERGEN_DATA_PATH))
@@ -36,10 +37,6 @@ def generate_trace_keyword_processor(labels: Optional[list[str]] = None):
field=OCRField.full_text_contiguous,
)

-TRACE_KEYWORD_PROCESSOR_STORE = CachedStore(
-    fetch_func=generate_trace_keyword_processor, expiration_interval=None
-)


def find_traces(content: Union[OCRResult, str]) -> list[Prediction]:
predictions = []
@@ -49,7 +46,7 @@ def find_traces(content: Union[OCRResult, str]) -> list[Prediction]:
if not text:
return []

-processor = TRACE_KEYWORD_PROCESSOR_STORE.get()
+processor = generate_trace_keyword_processor()

for match in TRACES_REGEX.regex.finditer(text):
prompt = match.group()
32 changes: 0 additions & 32 deletions robotoff/utils/cache.py
@@ -1,35 +1,3 @@
-import datetime
-from typing import Callable, Optional

from robotoff.utils import get_logger

logger = get_logger(__name__)


-class CachedStore:
-    def __init__(self, fetch_func: Callable, expiration_interval: Optional[int] = 30):
-        self.store = None
-        self.expires_after: Optional[datetime.datetime] = None
-        self.fetch_func: Callable = fetch_func
-        self.expiration_timedelta: Optional[datetime.timedelta]
-
-        if expiration_interval is not None:
-            self.expiration_timedelta = datetime.timedelta(minutes=expiration_interval)
-        else:
-            self.expiration_timedelta = None
-
-    def get(self, **kwargs):
-        if self.store is None or (
-            self.expiration_timedelta is not None
-            and datetime.datetime.utcnow() >= self.expires_after
-        ):
-            if self.store is not None:
-                logger.info("CachedStore expired, reloading...")
-
-            if self.expiration_timedelta is not None:
-                self.expires_after = (
-                    datetime.datetime.utcnow() + self.expiration_timedelta
-                )
-            self.store = self.fetch_func(**kwargs)
-
-        return self.store
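The removed CachedStore also supported time-based expiry through expiration_interval, which functools.cache does not offer; that is where the cachetools classes mentioned in the commit message come in. A hedged sketch of an equivalent, with illustrative names rather than actual Robotoff code:

import datetime
from cachetools import TTLCache, cached

# A TTLCache gives behaviour close to CachedStore(fetch_func=..., expiration_interval=30):
# the cached value is dropped and recomputed once it is older than the TTL
# (30 minutes here, expressed in seconds).
@cached(cache=TTLCache(maxsize=1, ttl=30 * 60))
def load_expensive_resource() -> datetime.datetime:
    # Placeholder for an expensive fetch; returning the load time makes
    # cache hits and expiry easy to observe.
    return datetime.datetime.utcnow()

first = load_expensive_resource()
second = load_expensive_resource()  # cache hit until the TTL elapses
assert first == second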
