Skip to content

Commit

Permalink
add amenities semantic search
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 committed Oct 23, 2023
1 parent e873c71 commit 674ddd4
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 0 deletions.
15 changes: 15 additions & 0 deletions extension_service/app/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,18 @@ def test_get_amenity(app):
output = response.json()
assert len(output) == 1
assert output[0]


def test_amenities_search(app):
with TestClient(app) as client:
response = client.get(
"/amenities/search",
params={
"query": "A place to get food.",
"top_k": 5,
},
)
assert response.status_code == 200
output = response.json()
assert len(output) == 5
assert output[0]
11 changes: 11 additions & 0 deletions extension_service/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,14 @@ async def get_amenity(id: int, request: Request):
ds: datastore.Client = request.app.state.datastore
results = await ds.get_amenity(id)
return results


@routes.get("/amenities/search")
async def amenities_search(query: str, top_k: int, request: Request):
ds: datastore.Client = request.app.state.datastore

embed_service: Embeddings = request.app.state.embed_service
query_embedding = embed_service.embed_query(query)

results = await ds.amenities_search(query_embedding, 0.7, top_k)
return results
6 changes: 6 additions & 0 deletions extension_service/datastore/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ async def export_data(
async def get_amenity(self, id: int) -> List[Dict[str, Any]]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def amenities_search(
self, query_embedding: List[float], similarity_threshold: float, top_k: int
) -> List[Dict[str, Any]]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def close(self):
pass
Expand Down
23 changes: 23 additions & 0 deletions extension_service/datastore/providers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,5 +153,28 @@ async def get_amenity(self, id: int) -> List[Dict[str, Any]]:
results = [dict(r) for r in results]
return results

async def amenities_search(
self, query_embedding: List[float], similarity_threshold: float, top_k: int
) -> List[Dict[str, Any]]:
results = await self.__pool.fetch(
"""
SELECT name, description, location, terminal, category, hour
FROM (
SELECT name, description, location, terminal, category, hour, 1 - (embedding <=> $1) AS similarity
FROM amenities
WHERE 1 - (embedding <=> $1) > $2
ORDER BY similarity DESC
LIMIT $3
) AS sorted_amenities
""",
query_embedding,
similarity_threshold,
top_k,
timeout=10,
)

results = [dict(r) for r in results]
return results

async def close(self):
await self.__pool.close()

0 comments on commit 674ddd4

Please sign in to comment.