diff --git a/robotoff/logos.py b/robotoff/logos.py index be2a64389d..36bad670f5 100644 --- a/robotoff/logos.py +++ b/robotoff/logos.py @@ -114,27 +114,28 @@ def get_stored_logo_ids() -> set[int]: return set(r.json()["stored"]) -def add_logos_to_ann(logos: list[LogoEmbedding]) -> None: +def add_logos_to_ann(logo_embeddings: list[LogoEmbedding]) -> None: es_client = get_es_client() - for logo in logos: - embedding = np.frombuffer(logo.embedding, dtype=np.float32) + for logo_embedding in logo_embeddings: + embedding = np.frombuffer(logo_embedding.embedding, dtype=np.float32) es_client.index( index=ElasticSearchIndex.logo, - id=logo.logo_id, + id=logo_embedding.logo_id, document={ "embedding": embedding / np.linalg.norm(embedding), }, ) -def save_nearest_neighbors(embeddings: list[LogoEmbedding]) -> None: +def save_nearest_neighbors(logo_embeddings: list[LogoEmbedding]) -> None: es_client = get_es_client() - for embedding in embeddings: + for logo_embedding in logo_embeddings: + embedding = np.frombuffer(logo_embedding.embedding, dtype=np.float32) knn_body = { "field": "embedding", - "query_vector": np.frombuffer(embedding.embedding, dtype=np.float32), + "query_vector": embedding / np.linalg.norm(embedding), "k": settings.K_NEAREST_NEIGHBORS + 1, "num_candidates": settings.NUM_CANDIDATES + 1, } @@ -145,11 +146,11 @@ def save_nearest_neighbors(embeddings: list[LogoEmbedding]) -> None: if hits := results["hits"]["hits"]: logo_ids, distances = zip(*[(hit["_id"], hit["_score"]) for hit in hits]) - embedding.logo.nearest_neighbors = { + logo_embedding.logo.nearest_neighbors = { "distances": distances, "logo_ids": logo_ids, } - embedding.logo.save() + logo_embedding.logo.save() @cachetools.cached(cachetools.LRUCache(maxsize=1))