Skip to content

Commit

Permalink
feat: add api endpoints for amenities (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 authored Oct 23, 2023
1 parent 74db3ff commit bd1c411
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 1 deletion.
29 changes: 29 additions & 0 deletions extension_service/app/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,32 @@ def test_hello_world(app):
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Hello World"}


def test_get_amenity(app):
with TestClient(app) as client:
response = client.get(
"/amenities",
params={
"id": 1,
},
)
assert response.status_code == 200
output = response.json()
assert len(output) == 1
assert output[0]


def test_amenities_search(app):
with TestClient(app) as client:
response = client.get(
"/amenities/search",
params={
"query": "A place to get food.",
"top_k": 5,
},
)
assert response.status_code == 200
output = response.json()
assert len(output) == 5
assert output[0]
23 changes: 22 additions & 1 deletion extension_service/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,32 @@
# limitations under the License.


from fastapi import APIRouter
from fastapi import APIRouter, Request
from langchain.embeddings.base import Embeddings

import datastore

routes = APIRouter()


@routes.get("/")
async def root():
return {"message": "Hello World"}


@routes.get("/amenities")
async def get_amenity(id: int, request: Request):
ds: datastore.Client = request.app.state.datastore
results = await ds.get_amenity(id)
return results


@routes.get("/amenities/search")
async def amenities_search(query: str, top_k: int, request: Request):
ds: datastore.Client = request.app.state.datastore

embed_service: Embeddings = request.app.state.embed_service
query_embedding = embed_service.embed_query(query)

results = await ds.amenities_search(query_embedding, 0.7, top_k)
return results
10 changes: 10 additions & 0 deletions extension_service/datastore/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ async def export_data(
) -> tuple[list[models.Airport], list[models.Amenity]]:
pass

@abstractmethod
async def get_amenity(self, id: int) -> list[Dict[str, Any]]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def amenities_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> list[Dict[str, Any]]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def close(self):
pass
Expand Down
35 changes: 35 additions & 0 deletions extension_service/datastore/providers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,5 +141,40 @@ async def export_data(

return airports, amenities

async def get_amenity(self, id: int) -> list[Dict[str, Any]]:
results = await self.__pool.fetch(
"""
SELECT name, description, location, terminal, category, hour
FROM amenities WHERE id=$1
""",
id,
)

results = [dict(r) for r in results]
return results

async def amenities_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> list[Dict[str, Any]]:
results = await self.__pool.fetch(
"""
SELECT name, description, location, terminal, category, hour
FROM (
SELECT name, description, location, terminal, category, hour, 1 - (embedding <=> $1) AS similarity
FROM amenities
WHERE 1 - (embedding <=> $1) > $2
ORDER BY similarity DESC
LIMIT $3
) AS sorted_amenities
""",
query_embedding,
similarity_threshold,
top_k,
timeout=10,
)

results = [dict(r) for r in results]
return results

async def close(self):
await self.__pool.close()

0 comments on commit bd1c411

Please sign in to comment.