Skip to content

Commit

Permalink
update test and dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 committed Oct 24, 2023
1 parent dfc3c30 commit 85c3003
Show file tree
Hide file tree
Showing 7 changed files with 1,603 additions and 1,605 deletions.
3,024 changes: 1,512 additions & 1,512 deletions data/airport_dataset.csv

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion extension_service/app/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_get_airport(app):
assert output[0]


def test_airports_search(app):
def test_airports_semantic_lookup(app):
with TestClient(app) as client:
response = client.get(
"/airports/semantic_lookup",
Expand Down
2 changes: 1 addition & 1 deletion extension_service/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def airports_semantic_lookup(query: str, top_k: int, request: Request):
embed_service: Embeddings = request.app.state.embed_service
query_embedding = embed_service.embed_query(query)

results = await ds.airports_semantic_lookup(query_embedding, 0.9, top_k)
results = await ds.airports_semantic_lookup(query_embedding, 0.7, top_k)
return results


Expand Down
22 changes: 9 additions & 13 deletions extension_service/datastore/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
from typing import Any, Dict, Generic, Optional, TypeVar

import models

Expand Down Expand Up @@ -49,38 +49,34 @@ async def initialize_data(
self,
airports: list[models.Airport],
amenities: list[models.Amenity],
flights: List[models.Flight],
flights: list[models.Flight],
) -> None:
pass

@abstractmethod
async def export_data(
self,
) -> tuple[list[models.Airport], list[models.Amenity], List[models.Flight]]:
) -> tuple[list[models.Airport], list[models.Amenity], list[models.Flight]]:
pass

@abstractmethod
async def get_airport(self, id: int) -> Optional[models.Airport]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def get_amenity(self, id: int) -> list[Dict[str, Any]]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def amenities_search(
async def airports_semantic_lookup(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> list[Dict[str, Any]]:
) -> Optional[list[models.Airport]]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def get_airport(self, id: int) -> list[models.Airport]:
async def get_amenity(self, id: int) -> list[Dict[str, Any]]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def airports_semantic_lookup(
self, query_embedding: List[float], similarity_threshold: float, top_k: int
) -> list[models.Airport]:
async def amenities_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> list[Dict[str, Any]]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
Expand Down
75 changes: 28 additions & 47 deletions extension_service/datastore/providers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import asyncio
from ipaddress import IPv4Address, IPv6Address
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, Literal, Optional

import asyncpg
from pgvector.asyncpg import register_vector
Expand Down Expand Up @@ -67,7 +67,7 @@ async def initialize_data(
self,
airports: list[models.Airport],
amenities: list[models.Amenity],
flights: List[models.Flight],
flights: list[models.Flight],
) -> None:
async with self.__pool.acquire() as conn:
# If the table already exists, drop it to avoid conflicts
Expand Down Expand Up @@ -120,8 +120,8 @@ async def initialize_data(
name TEXT,
city TEXT,
country TEXT,
content TEXT,
embedding vector(768)
content TEXT NOT NULL,
embedding vector(768) NOT NULL
)
"""
)
Expand All @@ -147,8 +147,8 @@ async def initialize_data(
terminal TEXT,
category TEXT,
hour TEXT,
content TEXT,
embedding vector(768)
content TEXT NOT NULL,
embedding vector(768) NOT NULL
)
"""
)
Expand Down Expand Up @@ -203,85 +203,66 @@ async def get_airport(self, id: int) -> Optional[models.Airport]:
result = models.Airport.model_validate(dict(result))
return result

async def get_amenity(self, id: int) -> list[Dict[str, Any]]:
results = await self.__pool.fetch(
"""
SELECT name, description, location, terminal, category, hour
FROM amenities WHERE id=$1
""",
id,
)

results = [dict(r) for r in results]
return results

async def amenities_search(
async def airports_semantic_lookup(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> list[Dict[str, Any]]:
) -> Optional[list[models.Airport]]:
results = await self.__pool.fetch(
"""
SELECT name, description, location, terminal, category, hour
SELECT id, iata, name, city, country
FROM (
SELECT name, description, location, terminal, category, hour, 1 - (embedding <=> $1) AS similarity
FROM amenities
SELECT id, iata, name, city, country, 1 - (embedding <=> $1) AS similarity
FROM airports
WHERE 1 - (embedding <=> $1) > $2
ORDER BY similarity DESC
LIMIT $3
) AS sorted_amenities
) AS sorted_airports
""",
query_embedding,
similarity_threshold,
top_k,
timeout=10,
)

results = [dict(r) for r in results]
return results

async def get_airport(self, id: int) -> list[models.Airport]:
results = await self.__pool.fetch(
"""
SELECT id, iata, name, city, country FROM airports WHERE id=$1
""",
id,
)
if results is []:
return None

airports = [models.Airport.model_validate(dict(r)) for r in results]
return airports
results = [models.Airport.model_validate(dict(r)) for r in results]
return results

async def get_airport(self, id: int) -> List[Dict[str, Any]]:
async def get_amenity(self, id: int) -> list[Dict[str, Any]]:
results = await self.__pool.fetch(
"""
SELECT iata, name, city, country FROM airports WHERE id=$1
SELECT name, description, location, terminal, category, hour
FROM amenities WHERE id=$1
""",
id,
)

results = [dict(r) for r in results]
return results

async def airports_semantic_lookup(
self, query_embedding: List[float], similarity_threshold: float, top_k: int
) -> List[models.Airport]:
async def amenities_search(
self, query_embedding: list[float], similarity_threshold: float, top_k: int
) -> list[Dict[str, Any]]:
results = await self.__pool.fetch(
"""
SELECT iata, name, city, country
SELECT name, description, location, terminal, category, hour
FROM (
SELECT iata, name, city, country, 1 - (embedding <=> $1) AS similarity
FROM airports
SELECT name, description, location, terminal, category, hour, 1 - (embedding <=> $1) AS similarity
FROM amenities
WHERE 1 - (embedding <=> $1) > $2
ORDER BY similarity DESC
LIMIT $3
) AS sorted_airports
) AS sorted_amenities
""",
query_embedding,
similarity_threshold,
top_k,
timeout=10,
)

airports = [models.Airport.model_validate(dict(r)) for r in results]
return airports
results = [dict(r) for r in results]
return results

async def close(self):
await self.__pool.close()
77 changes: 49 additions & 28 deletions extension_service/datastore/providers/postgres_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,13 @@ class MockAsyncpgPool(asyncpg.Pool):
def __init__(self, mocks: Dict[str, MockRecord]):
self.mocks = mocks

async def fetch(self, query, *args):
return self.mocks.get(query.strip())
async def fetch(self, query, *args, timeout=None):
query = " ".join(q.strip() for q in query.splitlines()).strip()
return self.mocks.get(query)

async def fetchrow(self, query, *args, timeout=None):
query = " ".join(q.strip() for q in query.splitlines()).strip()
return self.mocks.get(query)


async def mock_postgres_provider(mocks: Dict[str, MockRecord]) -> postgres.Client:
Expand All @@ -51,31 +56,26 @@ async def mock_postgres_provider(mocks: Dict[str, MockRecord]) -> postgres.Clien

@pytest.mark.asyncio
async def test_get_airport():
mockRecord = [
MockRecord(
[
("id", 1),
("iata", "FOO"),
("name", "Foo Bar"),
("city", "baz"),
("country", "bundy"),
]
)
]
mocks = {
"SELECT id, iata, name, city, country FROM airports WHERE id=$1": mockRecord
}
mockRecord = MockRecord(
[
("id", 1),
("iata", "FOO"),
("name", "Foo Bar"),
("city", "baz"),
("country", "bundy"),
]
)
query = "SELECT id, iata, name, city, country FROM airports WHERE id=$1"
mocks = {query: mockRecord}
mockCl = await mock_postgres_provider(mocks)
res = await mockCl.get_airport(1)
expected_res = [
models.Airport(
id=1,
iata="FOO",
name="Foo Bar",
city="baz",
country="bundy",
)
]
expected_res = models.Airport(
id=1,
iata="FOO",
name="Foo Bar",
city="baz",
country="bundy",
)
assert res == expected_res


Expand All @@ -84,16 +84,37 @@ async def test_airport_search():
mockRecord = [
MockRecord(
[
("id", 1),
("iata", "FOO"),
("name", "Foo Bar"),
("city", "baz"),
("country", "bundy"),
]
)
]
mockCl = await create_postgres_provider(mockRecord)
res = await mockCl.airports_search(1, 0.7, 1)
query = """
SELECT id, iata, name, city, country
FROM (
SELECT id, iata, name, city, country, 1 - (embedding <=> $1) AS similarity
FROM airports
WHERE 1 - (embedding <=> $1) > $2
ORDER BY similarity DESC
LIMIT $3
) AS sorted_airports
"""
query = " ".join(q.strip() for q in query.splitlines()).strip()
mocks = {query: mockRecord}
mockCl = await mock_postgres_provider(mocks)
res = await mockCl.airports_semantic_lookup(1, 0.7, 1)
expected_res = [
{"iata": "FOO", "name": "Foo Bar", "city": "baz", "country": "bundy"}
models.Airport.model_validate(
{
"id": 1,
"iata": "FOO",
"name": "Foo Bar",
"city": "baz",
"country": "bundy",
}
)
]
assert res == expected_res
6 changes: 3 additions & 3 deletions extension_service/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import ast
import datetime
from decimal import Decimal
from typing import List
from typing import Optional

from numpy import float32
from pydantic import BaseModel, ConfigDict, FieldValidationInfo, field_validator
Expand All @@ -29,8 +29,8 @@ class Airport(BaseModel):
name: str
city: str
country: str
content: str
embedding: list[float32]
content: Optional[str] = None
embedding: Optional[list[float32]] = None

@field_validator("embedding", mode="before")
def validate(cls, v):
Expand Down

0 comments on commit 85c3003

Please sign in to comment.