fix: fix refresh-insight scheduled job
- Loading the min dataset in memory took more than 20 GB of RAM, which made the
scheduler crash. By allowing a projection (a list of selected fields) to be passed
as a parameter, we reduce memory usage to ~2 GB.
- Use a server-side cursor (ServerSide) to avoid excessive memory consumption when
iterating over insights and predictions (see the sketch below).
- Also delete predictions associated with deleted products.
raphael0202 committed Apr 16, 2023
1 parent d94f3d1 commit d363cf8
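
A minimal sketch of the two memory-saving techniques combined (not part of the diff). It only uses names that appear in this commit: get_min_product_store, the field list used by the scheduler job, and peewee's ServerSide from playhouse.postgres_ext; the query filter is simplified.

from playhouse.postgres_ext import ServerSide

from robotoff.models import ProductInsight
from robotoff.products import get_min_product_store

# Load only the fields the job actually reads, instead of the full product
# documents: this is what brings memory usage from >20 GB down to ~2 GB.
product_store = get_min_product_store(
    ["code", "brands_tags", "countries_tags", "unique_scans_n"]
)

# Iterate with a PostgreSQL server-side cursor: rows are streamed in chunks
# instead of being fetched into memory all at once.
for insight in ServerSide(
    ProductInsight.select().where(ProductInsight.annotation.is_null())
):
    product = product_store[insight.barcode]
    if product is None:
        # Product was deleted from Open Food Facts: drop the insight.
        insight.delete_instance()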
Showing 2 changed files with 69 additions and 44 deletions.
48 changes: 27 additions & 21 deletions robotoff/products.py
@@ -5,9 +5,9 @@
import gzip
import json
import os
import pathlib
import shutil
import tempfile
from pathlib import Path
from typing import Iterable, Iterator, Optional, Union

import requests
@@ -36,7 +36,7 @@ def get_image_id(image_path: str) -> Optional[str]:
:return: the image ID ("2" in the previous example) or None if the image
is not "raw" (not digit-numbered)
"""
image_id = pathlib.Path(image_path).stem
image_id = Path(image_path).stem

if image_id.isdigit():
return image_id
@@ -45,7 +45,7 @@ def get_image_id(image_path: str) -> Optional[str]:


def is_valid_image(images: JSONType, image_path: str) -> bool:
image_id = pathlib.Path(image_path).stem
image_id = Path(image_path).stem
if not image_id.isdigit():
return False

@@ -77,7 +77,7 @@ def is_special_image(
if not is_valid_image(images, image_path):
return False

image_id = pathlib.Path(image_path).stem
image_id = Path(image_path).stem

for image_key, image_data in images.items():
if (
@@ -93,7 +93,7 @@ def is_special_image(
return False


def minify_product_dataset(dataset_path: pathlib.Path, output_path: pathlib.Path):
def minify_product_dataset(dataset_path: Path, output_path: Path):
if dataset_path.suffix == ".gz":
jsonl_iter_func = gzip_jsonl_iter
else:
@@ -128,7 +128,7 @@ def save_product_dataset_etag(etag: str):

def fetch_dataset(minify: bool = True) -> bool:
with tempfile.TemporaryDirectory() as tmp_dir:
output_dir = pathlib.Path(tmp_dir)
output_dir = Path(tmp_dir)
output_path = output_dir / "products.jsonl.gz"
etag = download_dataset(output_path)

@@ -180,7 +180,7 @@ def download_dataset(output_path: os.PathLike) -> str:
return current_etag


def is_valid_dataset(dataset_path: pathlib.Path) -> bool:
def is_valid_dataset(dataset_path: Path) -> bool:
"""Check the dataset integrity: readable end to end and with a minimum number of products.
This is used to spot corrupted archive files."""
dataset = ProductDataset(dataset_path)
@@ -347,9 +347,14 @@ def take(self, count: int):
def iter(self) -> Iterable[JSONType]:
return iter(self)

def iter_product(self) -> Iterable["Product"]:
def iter_product(
self, projection: Optional[list[str]] = None
) -> Iterable["Product"]:
for item in self:
yield Product(item)
projected_item = (
{k: item[k] for k in projection if k in item} if projection else item
)
yield Product(projected_item)

def collect(self) -> list[JSONType]:
return list(self)
@@ -364,12 +369,10 @@ def __init__(self, jsonl_path):
self.jsonl_path = jsonl_path

def stream(self) -> ProductStream:
json_path_str = str(self.jsonl_path)

if json_path_str.endswith(".gz"):
iterator = gzip_jsonl_iter(json_path_str)
if str(self.jsonl_path).endswith(".gz"):
iterator = gzip_jsonl_iter(self.jsonl_path)
else:
iterator = jsonl_iter(json_path_str)
iterator = jsonl_iter(self.jsonl_path)

return ProductStream(iterator)

@@ -452,22 +455,26 @@ def __len__(self):
return len(self.store)

@classmethod
def load_from_path(cls, path: str):
def load_from_path(cls, path: Path, projection: Optional[list[str]] = None):
logger.info("Loading product store")

if projection is not None and "code" not in projection:
raise ValueError("at least `code` must be in projection")

ds = ProductDataset(path)
stream = ds.stream()

store: dict[str, Product] = {}

for product in stream.iter_product():
for product in stream.iter_product(projection):
if product.barcode:
store[product.barcode] = product

return cls(store)

@classmethod
def load_min(cls):
return cls.load_from_path(settings.JSONL_MIN_DATASET_PATH)
def load_min(cls, projection: Optional[list[str]] = None):
return cls.load_from_path(settings.JSONL_MIN_DATASET_PATH, projection)

@classmethod
def load_full(cls):
@@ -513,10 +520,9 @@ def iter_product(self, projection: Optional[list[str]] = None):
yield from (Product(p) for p in self.collection.find(projection=projection))


@functools.cache
def get_min_product_store() -> ProductStore:
def get_min_product_store(projection: Optional[list[str]] = None) -> ProductStore:
logger.info("Loading product store in memory...")
ps = MemoryProductStore.load_min()
ps = MemoryProductStore.load_min(projection)
logger.info("product store loaded (%s items)", len(ps))
return ps

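As a usage note on the projection parameter added above, here is a short sketch; it assumes only names shown in this diff (ProductDataset, iter_product, settings.JSONL_MIN_DATASET_PATH, and the Product attributes already used by the scheduler).

from robotoff import settings
from robotoff.products import ProductDataset

ds = ProductDataset(settings.JSONL_MIN_DATASET_PATH)
# Keys outside the projection are dropped before each Product object is
# built, so only the selected fields are kept in memory.
for product in ds.stream().iter_product(projection=["code", "brands_tags"]):
    print(product.barcode, product.brands_tags)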
65 changes: 42 additions & 23 deletions robotoff/scheduler/__init__.py
@@ -1,5 +1,4 @@
import datetime
import functools
import os
import uuid
from typing import Iterable
@@ -9,6 +8,7 @@
from apscheduler.executors.pool import ThreadPoolExecutor
from apscheduler.jobstores.memory import MemoryJobStore
from apscheduler.schedulers.blocking import BlockingScheduler
from playhouse.postgres_ext import ServerSide
from sentry_sdk import capture_exception

from robotoff import settings, slack
@@ -21,7 +21,7 @@
save_facet_metrics,
save_insight_metrics,
)
from robotoff.models import ProductInsight, db, with_db
from robotoff.models import Prediction, ProductInsight, db, with_db
from robotoff.prediction.category.matcher import predict_from_dataset
from robotoff.products import (
Product,
@@ -30,7 +30,7 @@
get_min_product_store,
has_dataset_changed,
)
from robotoff.types import ServerType
from robotoff.types import ProductIdentifier, ServerType
from robotoff.utils import get_logger

from .latent import generate_quality_facets
@@ -79,10 +79,10 @@ def process_insights():


@with_db
def refresh_insights(with_deletion: bool = False):
deleted = 0
updated = 0
product_store = get_min_product_store()
def refresh_insights(with_deletion: bool = True):
product_store = get_min_product_store(
["code", "brands_tags", "countries_tags", "unique_scans_n"]
)
# Only OFF is currently supported
server_type = ServerType.off

@@ -100,22 +100,22 @@ def refresh_insights(with_deletion: bool = False):
return

insight: ProductInsight
for insight in (
ProductInsight.select()
.where(
deleted = 0
updated = 0
for insight in ServerSide(
ProductInsight.select().where(
ProductInsight.annotation.is_null(),
ProductInsight.timestamp <= datetime_threshold,
ProductInsight.server_type == server_type.name,
)
.iterator()
):
product_id = insight.get_product_id()
product: Product = product_store[insight.barcode]

if product is None:
if with_deletion:
# Product has been deleted from OFF
logger.info("%s deleted", product_id)
logger.info("%s deleted, deleting insight %s", product_id, insight)
deleted += 1
insight.delete_instance()
else:
@@ -124,20 +124,38 @@ def update_insight_attributes(product: Product, insight: ProductInsight) -> bool
if insight_updated:
updated += 1

logger.info("{} insights deleted".format(deleted))
logger.info("{} insights updated".format(updated))
prediction: Prediction
deleted = 0
for prediction in ServerSide(
Prediction.select().where(
Prediction.timestamp <= datetime_threshold,
Prediction.server_type == server_type.name,
)
):
product = product_store[ProductIdentifier(prediction.barcode, server_type)]

if product is None:
if with_deletion:
# Product has been deleted from OFF
logger.info(
"%s deleted, deleting prediction %s", product_id, prediction
)
deleted += 1
prediction.delete_instance()

logger.info("%s prediction deleted", deleted)


def update_insight_attributes(product: Product, insight: ProductInsight) -> bool:
to_update = False
updated_fields = []
if insight.brands != product.brands_tags:
logger.info(
"Updating brand %s -> %s (%s)",
insight.brands,
product.brands_tags,
insight.get_product_id(),
)
to_update = True
updated_fields.append("brands")
insight.brands = product.brands_tags

if insight.countries != product.countries_tags:
@@ -147,7 +165,7 @@ def update_insight_attributes(product: Product, insight: ProductInsight) -> bool
product.countries_tags,
insight.get_product_id(),
)
to_update = True
updated_fields.append("countries")
insight.countries = product.countries_tags

if insight.unique_scans_n != product.unique_scans_n:
@@ -157,17 +175,18 @@ def update_insight_attributes(product: Product, insight: ProductInsight) -> bool
product.unique_scans_n,
insight.get_product_id(),
)
to_update = True
updated_fields.append("unique_scans_n")
insight.unique_scans_n = product.unique_scans_n

if to_update:
insight.save()
if updated_fields:
# Only update the columns that actually changed, using bulk_update with the list of updated fields
ProductInsight.bulk_update([insight], fields=updated_fields)

return to_update
return bool(updated_fields)


@with_db
def mark_insights():
def mark_insights() -> int:
marked = 0
insight: ProductInsight
for insight in (
@@ -285,7 +304,7 @@ def run():
# are no longer applicable.
# - Updating insight attributes.
scheduler.add_job(
functools.partial(refresh_insights, with_deletion=True),
refresh_insights,
"cron",
day="*",
hour="4",
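One note on the bulk_update change in update_insight_attributes above: a small sketch of writing back only the columns that changed. The lookup and the brand value are illustrative; only the model and field names come from the diff.

from robotoff.models import ProductInsight

# Illustrative lookup: any previously fetched insight would do.
insight = ProductInsight.get_or_none(ProductInsight.annotation.is_null())
if insight is not None:
    insight.brands = ["new-brand"]  # illustrative value
    # Unlike insight.save(), bulk_update() with an explicit field list only
    # writes the listed columns back to the database.
    ProductInsight.bulk_update([insight], fields=["brands"])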
