Skip to content

Commit

Permalink
add amenities embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 committed Oct 12, 2023
1 parent 31520f2 commit 3e2decd
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 16 deletions.
142 changes: 142 additions & 0 deletions data/amenity_embeddings_dataset.csv

Large diffs are not rendered by default.

27 changes: 15 additions & 12 deletions extension_service/datastore/providers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,19 @@ async def initialize_data(
)

await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
await conn.execute("DROP TABLE IF EXISTS product_embeddings")
await conn.execute("DROP TABLE IF EXISTS amenity_embeddings")
await conn.execute(
"""
CREATE TABLE product_embeddings(
product_id VARCHAR(1024) NOT NULL REFERENCES products(product_id),
CREATE TABLE amenity_embeddings(
amenity_id VARCHAR(1024) NOT NULL REFERENCES amenities(amenity_id),
content TEXT,
embedding vector(768))
"""
)
# Insert all the data
await conn.executemany(
"""INSERT INTO product_embeddings VALUES ($1, $2, $3)""",
[(e.product_id, e.content, e.embedding) for e in embeddings],
"""INSERT INTO amenity_embeddings VALUES ($1, $2, $3)""",
[(e.amenity_id, e.content, e.embedding) for e in embeddings],
)

async def export_data(
Expand All @@ -149,7 +149,7 @@ async def export_data(
self.__pool.fetch("""SELECT * FROM amenities""")
)
emb_task = asyncio.create_task(
self.__pool.fetch("""SELECT * FROM product_embeddings""")
self.__pool.fetch("""SELECT * FROM amenity_embeddings""")
)

toys = [models.Toy.model_validate(dict(t)) for t in await toy_task]
Expand All @@ -164,18 +164,21 @@ async def semantic_similarity_search(
results = await self.__pool.fetch(
"""
WITH vector_matches AS (
SELECT product_id, 1 - (embedding <=> $1) AS similarity
FROM product_embeddings
SELECT amenity_id, 1 - (embedding <=> $1) AS similarity
FROM amenity_embeddings
WHERE 1 - (embedding <=> $1) > $2
ORDER BY similarity DESC
LIMIT $3
)
SELECT
product_name,
list_price,
amenity_name,
description
FROM products
WHERE product_id IN (SELECT product_id FROM vector_matches)
location,
terminal,
amenity_type,
hour
FROM amenities
WHERE amenity_id IN (SELECT amenity_id FROM vector_matches)
""",
query_embedding,
similarity_threshold,
Expand Down
2 changes: 1 addition & 1 deletion extension_service/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Amenity(BaseModel):
class Embedding(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

product_id: str
amenity_id: str
content: str
embedding: List[float32]

Expand Down
4 changes: 2 additions & 2 deletions extension_service/run_database_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ async def main():
for a in amenities:
writer.writerow(a.model_dump())

with open("../data/product_embeddings_dataset.csv.new", "w") as f:
col_names = ["product_id", "content", "embedding"]
with open("../data/amenity_embeddings_dataset.csv.new", "w") as f:
col_names = ["amenity_id", "content", "embedding"]
writer = csv.DictWriter(f, col_names, delimiter=",")
writer.writeheader()
for e in embeddings:
Expand Down
2 changes: 1 addition & 1 deletion extension_service/run_database_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def main() -> None:
amenities = [models.Amenity.model_validate(line) for line in reader]

embeddings: List[models.Embedding] = []
with open("../data/product_embeddings_dataset.csv", "r") as f:
with open("../data/amenity_embeddings_dataset.csv", "r") as f:
reader = csv.DictReader(f, delimiter=",")
embeddings = [models.Embedding.model_validate(line) for line in reader]

Expand Down

0 comments on commit 3e2decd

Please sign in to comment.