diff --git a/extension_service/datastore/datastore.py b/extension_service/datastore/datastore.py index 259c5767..e2b52933 100644 --- a/extension_service/datastore/datastore.py +++ b/extension_service/datastore/datastore.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, List, Tuple, TypeVar +from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar import models @@ -59,6 +59,10 @@ async def export_data( ) -> tuple[list[models.Airport], list[models.Amenity], List[models.Flight]]: pass + @abstractmethod + async def get_airport(self, id: int) -> Optional[models.Airport]: + raise NotImplementedError("Subclass should implement this!") + @abstractmethod async def get_amenity(self, id: int) -> list[Dict[str, Any]]: raise NotImplementedError("Subclass should implement this!") @@ -69,10 +73,6 @@ async def amenities_search( ) -> list[Dict[str, Any]]: raise NotImplementedError("Subclass should implement this!") - @abstractmethod - async def get_airport(self, id: int) -> list[models.Airport]: - raise NotImplementedError("Subclass should implement this!") - @abstractmethod async def close(self): pass diff --git a/extension_service/datastore/providers/postgres.py b/extension_service/datastore/providers/postgres.py index aa51b474..f0ee19e9 100644 --- a/extension_service/datastore/providers/postgres.py +++ b/extension_service/datastore/providers/postgres.py @@ -14,7 +14,7 @@ import asyncio from ipaddress import IPv4Address, IPv6Address -from typing import Any, Dict, List, Literal +from typing import Any, Dict, List, Literal, Optional import asyncpg from pgvector.asyncpg import register_vector @@ -184,6 +184,20 @@ async def export_data( flights = [models.Flight.model_validate(dict(f)) for f in await flights_task] return airports, amenities, flights + async def get_airport(self, id: int) -> Optional[models.Airport]: + result = await self.__pool.fetchrow( + """ + SELECT id, iata, name, city, country FROM airports WHERE id=$1 + """, + id, + ) + + if result is None: + return None + + result = models.Airport.model_validate(dict(result)) + return result + async def get_amenity(self, id: int) -> list[Dict[str, Any]]: results = await self.__pool.fetch( """ @@ -219,16 +233,5 @@ async def amenities_search( results = [dict(r) for r in results] return results - async def get_airport(self, id: int) -> list[models.Airport]: - results = await self.__pool.fetch( - """ - SELECT id, iata, name, city, country FROM airports WHERE id=$1 - """, - id, - ) - - airports = [models.Airport.model_validate(dict(r)) for r in results] - return airports - async def close(self): await self.__pool.close() diff --git a/extension_service/datastore/providers/postgres_test.py b/extension_service/datastore/providers/postgres_test.py index 85067f91..c57ab630 100644 --- a/extension_service/datastore/providers/postgres_test.py +++ b/extension_service/datastore/providers/postgres_test.py @@ -68,14 +68,12 @@ async def test_get_airport(): mockCl = await mock_postgres_provider(mocks) res = await mockCl.get_airport(1) expected_res = [ - models.Airport.model_validate( - { - "id": 1, - "iata": "FOO", - "name": "Foo Bar", - "city": "baz", - "country": "bundy", - } + models.Airport( + id=1, + iata="FOO", + name="Foo Bar", + city="baz", + country="bundy", ) ] assert res == expected_res