Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 committed Oct 23, 2023
1 parent 2953f8e commit b2ac1bc
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 25 deletions.
10 changes: 5 additions & 5 deletions extension_service/datastore/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Tuple, TypeVar
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar

import models

Expand Down Expand Up @@ -59,6 +59,10 @@ async def export_data(
) -> tuple[list[models.Airport], list[models.Amenity], List[models.Flight]]:
pass

@abstractmethod
async def get_airport(self, id: int) -> Optional[models.Airport]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def get_amenity(self, id: int) -> list[Dict[str, Any]]:
raise NotImplementedError("Subclass should implement this!")
Expand All @@ -69,10 +73,6 @@ async def amenities_search(
) -> list[Dict[str, Any]]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def get_airport(self, id: int) -> list[models.Airport]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def close(self):
pass
Expand Down
27 changes: 15 additions & 12 deletions extension_service/datastore/providers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import asyncio
from ipaddress import IPv4Address, IPv6Address
from typing import Any, Dict, List, Literal
from typing import Any, Dict, List, Literal, Optional

import asyncpg
from pgvector.asyncpg import register_vector
Expand Down Expand Up @@ -184,6 +184,20 @@ async def export_data(
flights = [models.Flight.model_validate(dict(f)) for f in await flights_task]
return airports, amenities, flights

async def get_airport(self, id: int) -> Optional[models.Airport]:
result = await self.__pool.fetchrow(
"""
SELECT id, iata, name, city, country FROM airports WHERE id=$1
""",
id,
)

if result is None:
return None

result = models.Airport.model_validate(dict(result))
return result

async def get_amenity(self, id: int) -> list[Dict[str, Any]]:
results = await self.__pool.fetch(
"""
Expand Down Expand Up @@ -219,16 +233,5 @@ async def amenities_search(
results = [dict(r) for r in results]
return results

async def get_airport(self, id: int) -> list[models.Airport]:
results = await self.__pool.fetch(
"""
SELECT id, iata, name, city, country FROM airports WHERE id=$1
""",
id,
)

airports = [models.Airport.model_validate(dict(r)) for r in results]
return airports

async def close(self):
await self.__pool.close()
14 changes: 6 additions & 8 deletions extension_service/datastore/providers/postgres_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,12 @@ async def test_get_airport():
mockCl = await mock_postgres_provider(mocks)
res = await mockCl.get_airport(1)
expected_res = [
models.Airport.model_validate(
{
"id": 1,
"iata": "FOO",
"name": "Foo Bar",
"city": "baz",
"country": "bundy",
}
models.Airport(
id=1,
iata="FOO",
name="Foo Bar",
city="baz",
country="bundy",
)
]
assert res == expected_res

0 comments on commit b2ac1bc

Please sign in to comment.