Skip to content

Commit

Permalink
fix: add server_type field to logo indexed in ES
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Apr 14, 2023
1 parent 20a58d5 commit 506ab02
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 13 deletions.
3 changes: 3 additions & 0 deletions doc/references/api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ paths:
description: Return ID and distance of each logo found, the number of neighbors returned and the ID of the query logo.
parameters:
- $ref: "#/components/parameters/ann_search_count"
- $ref: "#/components/parameters/server_type"
responses:
"200":
description: Response from ANN search
Expand All @@ -628,6 +629,7 @@ paths:
description: Return ID and distance of each logo found, the number of neighbors returned and the ID of the query logo.
parameters:
- $ref: "#/components/parameters/ann_search_count"
- $ref: "#/components/parameters/server_type"
responses:
"200":
description: Response from ANN search
Expand Down Expand Up @@ -923,6 +925,7 @@ components:
- 'obf'
- 'opff'
- 'opf'
- 'off_pro'
insight_types:
name: insight_types
in: query
Expand Down
5 changes: 4 additions & 1 deletion robotoff/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,7 @@ def on_get(
- a specific logo otherwise
"""
count = req.get_param_as_int("count", min_value=1, max_value=500, default=100)
server_type = get_server_type_from_req(req)

if logo_id is None:
logo_embeddings = list(
Expand All @@ -993,7 +994,9 @@ def on_get(

raw_results = [
item
for item in knn_search(es_client, logo_embedding.embedding, count)
for item in knn_search(
es_client, logo_embedding.embedding, count, server_type=server_type
)
if item[0] != logo_id
][:count]
results = [{"logo_id": item[0], "distance": item[1]} for item in raw_results]
Expand Down
7 changes: 5 additions & 2 deletions robotoff/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,12 @@ def init_elasticsearch(load_data: bool = True) -> None:

@app.command()
def add_logo_to_ann(
server_type: ServerType = typer.Option(
ServerType.off, help="Server type of the logos"
),
sleep_time: float = typer.Option(
0.0, help="Time to sleep between each query (in s)"
)
),
) -> None:
"""Index all missing logos in Elasticsearch ANN index."""
import logging
Expand Down Expand Up @@ -522,7 +525,7 @@ def add_logo_to_ann(
)
for logo_embedding_batch in chunked(logo_embedding_iter, 500):
try:
add_logos_to_ann(es_client, logo_embedding_batch)
add_logos_to_ann(es_client, logo_embedding_batch, server_type)
added += len(logo_embedding_batch)
except BulkIndexError as e:
logger.info("Request error during logo addition to ANN", exc_info=e)
Expand Down
3 changes: 2 additions & 1 deletion robotoff/elasticsearch/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@
"similarity": "dot_product",
"index_options": {"type": "hnsw", "m": 16, "ef_construction": 100},
},
}
},
"server_type": {"type": "keyword"},
},
},
}
Expand Down
41 changes: 35 additions & 6 deletions robotoff/logos.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,18 @@ def get_stored_logo_ids(es_client: elasticsearch.Elasticsearch) -> set[int]:


def add_logos_to_ann(
es_client: elasticsearch.Elasticsearch, logo_embeddings: list[LogoEmbedding]
es_client: elasticsearch.Elasticsearch,
logo_embeddings: list[LogoEmbedding],
server_type: ServerType,
) -> None:
"""Index logo embeddings in Elasticsearch ANN index."""
"""Index logo embeddings in Elasticsearch ANN index.
:param es_client: Elasticsearch client
:param logo_embeddings: a list of `LogoEmbedding`s model instances, the
fields `logo_id` and `embedding` should be available
:param server_type: the server type (project) associated with the logo
embeddings
"""
embeddings = [
np.frombuffer(logo_embedding.embedding, dtype=np.float32)
for logo_embedding in logo_embeddings
Expand All @@ -135,20 +144,26 @@ def add_logos_to_ann(
"_index": ElasticSearchIndex.logo.name,
"_id": logo_embedding.logo_id,
"embedding": embedding / np.linalg.norm(embedding),
"server_type": server_type.name,
}
for logo_embedding, embedding in zip(logo_embeddings, embeddings)
)
elasticsearch_bulk(es_client, actions)


def save_nearest_neighbors(
es_client: elasticsearch.Elasticsearch, logo_embeddings: list[LogoEmbedding]
es_client: elasticsearch.Elasticsearch,
logo_embeddings: list[LogoEmbedding],
server_type: ServerType,
) -> None:
"""Save nearest neighbors of a batch of logo embedding."""
updated = []
for logo_embedding in logo_embeddings:
results = knn_search(
es_client, logo_embedding.embedding, settings.K_NEAREST_NEIGHBORS
es_client,
logo_embedding.embedding,
settings.K_NEAREST_NEIGHBORS,
server_type,
)
results = [item for item in results if item[0] != logo_embedding.logo_id][
: settings.K_NEAREST_NEIGHBORS
Expand All @@ -171,8 +186,19 @@ def knn_search(
client: elasticsearch.Elasticsearch,
embedding_bytes: bytes,
k: int = settings.K_NEAREST_NEIGHBORS,
server_type: Optional[ServerType] = None,
) -> list[tuple[int, float]]:
"""Search for k approximate nearest neighbors of embedding_bytes in the elasticsearch logos index."""
"""Search for k approximate nearest neighbors of `embedding_bytes` in the
Elasticsearch logos index.
:param client: Elasticsearch client
:param embedding_bytes: 1d array of the logo embedding serialized using
`numpy.tobytes()`
:param k: number of nearest neighbors to return, defaults to
`settings.K_NEAREST_NEIGHBORS`
:param server_type: the server type (project) associated with the logos
to be returned. If not provided, logos from all projects are returned.
"""
embedding = np.frombuffer(embedding_bytes, dtype=np.float32)
knn_body = {
"field": "embedding",
Expand All @@ -181,6 +207,9 @@ def knn_search(
"num_candidates": k + 1,
}

if server_type is not None:
knn_body["filter"] = {"term": {"server_type": server_type.name}}

results = client.search(
index=ElasticSearchIndex.logo, knn=knn_body, source=False, size=k + 1
)
Expand Down Expand Up @@ -529,7 +558,7 @@ def refresh_nearest_neighbors(
.where(LogoEmbedding.logo_id.in_(logo_id_batch))
)
try:
save_nearest_neighbors(es_client, logo_embeddings)
save_nearest_neighbors(es_client, logo_embeddings, server_type)
except (
elasticsearch.ConnectionError,
elasticsearch.ConnectionTimeout,
Expand Down
4 changes: 2 additions & 2 deletions robotoff/workers/tasks/import_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,13 +385,13 @@ def process_created_logos(image_prediction_id: int, server_type: ServerType):

es_client = get_es_client()
try:
add_logos_to_ann(es_client, logo_embeddings)
add_logos_to_ann(es_client, logo_embeddings, server_type)
except BulkIndexError as e:
logger.info("Request error during logo addition to ANN", exc_info=e)
return

try:
save_nearest_neighbors(es_client, logo_embeddings)
save_nearest_neighbors(es_client, logo_embeddings, server_type)
except (elasticsearch.ConnectionError, elasticsearch.ConnectionTimeout) as e:
logger.info("Request error during ANN batch query", exc_info=e)
return
Expand Down
5 changes: 4 additions & 1 deletion tests/integration/test_import_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ def test_process_created_logos(peewee_db, mocker):
logo_embeddings = [LogoEmbeddingFactory(logo=logo) for logo in logos]
process_created_logos(image_prediction.id, DEFAULT_SERVER_TYPE)
add_logos_to_ann_mock.assert_called()
embedding_args = add_logos_to_ann_mock.mock_calls[0].args[1]
mock_call = add_logos_to_ann_mock.mock_calls[0]
embedding_args = mock_call.args[1]
server_type = mock_call.args[2]
assert server_type == DEFAULT_SERVER_TYPE
assert sorted(embedding_args, key=lambda x: x.logo_id) == logo_embeddings
save_nearest_neighbors_mock.assert_called()
get_logo_confidence_thresholds_mock.assert_called()
Expand Down

0 comments on commit 506ab02

Please sign in to comment.