diff --git a/langchain_tools_demo/agent.py b/langchain_tools_demo/agent.py index 36cba3e3..a17d46ce 100644 --- a/langchain_tools_demo/agent.py +++ b/langchain_tools_demo/agent.py @@ -13,7 +13,13 @@ # limitations under the License. import os +from datetime import date, timedelta +from typing import Dict, Optional +import aiohttp +import dateutil.parser as dparser +import google.auth.transport.requests # type: ignore +import google.oauth2.id_token # type: ignore from langchain.agents import AgentType, initialize_agent from langchain.agents.agent import AgentExecutor from langchain.globals import set_verbose # type: ignore @@ -21,20 +27,102 @@ from langchain.memory import ConversationBufferMemory from langchain.prompts.chat import ChatPromptTemplate -from tools import convert_date, tools +from tools import initialize_tools set_verbose(bool(os.getenv("DEBUG", default=False))) +BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") + +# aiohttp context +connector = None + + +# Class for setting up a dedicated llm agent for each individual user +class UserAgent: + client: aiohttp.ClientSession + agent: AgentExecutor + + def __init__(self, client, agent) -> None: + self.client = client + self.agent = agent + + +user_agents: Dict[str, UserAgent] = {} + + +def get_id_token(url: str) -> str: + """Helper method to generate ID tokens for authenticated requests""" + # Use Application Default Credentials on Cloud Run + if os.getenv("K_SERVICE"): + auth_req = google.auth.transport.requests.Request() + return google.oauth2.id_token.fetch_id_token(auth_req, url) + else: + # Use gcloud credentials locally + import subprocess + + return ( + subprocess.run( + ["gcloud", "auth", "print-identity-token"], + stdout=subprocess.PIPE, + check=True, + ) + .stdout.strip() + .decode() + ) + + +def convert_date(date_string: str) -> str: + """Convert date into appropriate date string""" + if date_string == "tomorrow": + converted = date.today() + timedelta(1) + elif date_string == "yesterday": + converted = date.today() - timedelta(1) + elif date_string != "null" and date_string != "today" and date_string is not None: + converted = dparser.parse(date_string, fuzzy=True).date() + else: + converted = date.today() + + return converted.strftime("%Y-%m-%d") + + +def get_header() -> Optional[dict]: + if "http://" in BASE_URL: + return None + else: + # Append ID Token to make authenticated requests to Cloud Run services + headers = {"Authorization": f"Bearer {get_id_token(BASE_URL)}"} + return headers + + +async def get_connector(): + global connector + if connector is None: + connector = aiohttp.TCPConnector(limit=100) + return connector + + +async def handle_error_response(response): + if response.status != 200: + return f"Error sending {response.method} request to {str(response.url)}): {await response.text()}" + + +async def create_client_session() -> aiohttp.ClientSession: + return aiohttp.ClientSession( + connector=await get_connector(), + headers=get_header(), + raise_for_status=handle_error_response, + ) # Agent -def init_agent() -> AgentExecutor: +async def init_agent() -> UserAgent: """Load an agent executor with tools and LLM""" print("Initializing agent..") llm = VertexAI(max_output_tokens=512) memory = ConversationBufferMemory( memory_key="chat_history", input_key="input", output_key="output" ) - + client = await create_client_session() + tools = await initialize_tools(client) agent = initialize_agent( tools, llm, @@ -59,7 +147,8 @@ def init_agent() -> AgentExecutor: [("system", template), ("human", human_message_template)] ) agent.agent.llm_chain.prompt = prompt # type: ignore - return agent + + return UserAgent(client, agent) PREFIX = """SFO Airport Assistant helps travelers find their way at the airport. diff --git a/langchain_tools_demo/main.py b/langchain_tools_demo/main.py index ed56801d..a78b5e31 100644 --- a/langchain_tools_demo/main.py +++ b/langchain_tools_demo/main.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import os import uuid +from contextlib import asynccontextmanager import uvicorn from fastapi import Body, FastAPI, HTTPException, Request @@ -24,27 +26,31 @@ from markdown import markdown from starlette.middleware.sessions import SessionMiddleware -from agent import init_agent -from tools import session +from agent import init_agent, user_agents -app = FastAPI() + +@asynccontextmanager +async def lifespan(app: FastAPI): + # FastAPI app startup event + print("Loading application...") + yield + # FastAPI app shutdown event + close_client_tasks = [ + asyncio.create_task(c.client.close()) for c in user_agents.values() + ] + + asyncio.gather(*close_client_tasks) + + +# FastAPI setup +app = FastAPI(lifespan=lifespan) app.mount("/static", StaticFiles(directory="static"), name="static") # TODO: set secret_key for production app.add_middleware(SessionMiddleware, secret_key="SECRET_KEY") templates = Jinja2Templates(directory="templates") - -agents: dict[str, AgentExecutor] = {} BASE_HISTORY = [{"role": "assistant", "content": "How can I help you?"}] -async def on_shutdown(): - if session is not None: - await session.close() - - -app.add_event_handler("shutdown", on_shutdown) - - @app.get("/", response_class=HTMLResponse) def index(request: Request): """Render the default template.""" @@ -71,14 +77,14 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)): # Add user message to chat history request.session["messages"] += [{"role": "user", "content": prompt}] # Agent setup - if request.session["uuid"] in agents: - agent = agents[request.session["uuid"]] + if request.session["uuid"] in user_agents: + user_agent = user_agents[request.session["uuid"]] else: - agent = init_agent() - agents[request.session["uuid"]] = agent + user_agent = await init_agent() + user_agents[request.session["uuid"]] = user_agent try: # Send prompt to LLM - response = await agent.ainvoke({"input": prompt}) + response = await user_agent.agent.ainvoke({"input": prompt}) request.session["messages"] += [ {"role": "assistant", "content": response["output"]} ] diff --git a/langchain_tools_demo/tools.py b/langchain_tools_demo/tools.py index 4e35884b..af9d772d 100644 --- a/langchain_tools_demo/tools.py +++ b/langchain_tools_demo/tools.py @@ -13,81 +13,17 @@ # limitations under the License. import os -from datetime import date, timedelta from typing import Optional import aiohttp -import dateutil.parser as dparser -import google.auth.transport.requests # type: ignore -import google.oauth2.id_token # type: ignore from langchain.tools import StructuredTool, tool from pydantic.v1 import BaseModel, Field BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") -session = None -# create a new client session -async def get_session(): - if session is None: - client_session = aiohttp.ClientSession() - return client_session - - -# Helper functions -async def get_request(url: str, params: dict) -> aiohttp.ClientResponse: - """Helper method to make backend requests""" - session = await get_session() - params = {key: value for key, value in params.items() if value is not None} - if "http://" in url: - response = await session.get( - url, - params=params, - ) - return response - else: - # Append ID Token to make authenticated requests to Cloud Run services - response = await session.get( - url, - params=params, - headers={"Authorization": f"Bearer {get_id_token(url)}"}, - ) - return response - - -def get_id_token(url: str) -> str: - """Helper method to generate ID tokens for authenticated requests""" - # Use Application Default Credentials on Cloud Run - if os.getenv("K_SERVICE"): - auth_req = google.auth.transport.requests.Request() - return google.oauth2.id_token.fetch_id_token(auth_req, url) - else: - # Use gcloud credentials locally - import subprocess - - return ( - subprocess.run( - ["gcloud", "auth", "print-identity-token"], - stdout=subprocess.PIPE, - check=True, - ) - .stdout.strip() - .decode() - ) - - -def convert_date(date_string: str) -> str: - """Convert date into appropriate date string""" - if date_string == "tomorrow": - converted = date.today() + timedelta(1) - elif date_string == "yesterday": - converted = date.today() - timedelta(1) - elif date_string != "null" and date_string != "today" and date_string is not None: - converted = dparser.parse(date_string, fuzzy=True).date() - else: - converted = date.today() - - return converted.strftime("%Y-%m-%d") +def filter_none_values(params: dict) -> dict: + return {key: value for key, value in params.items() if value is not None} # Tools @@ -95,15 +31,16 @@ class AirportIdInput(BaseModel): id: int = Field(description="Unique identifier") -async def get_airport(id: int): - response = await get_request( - f"{BASE_URL}/airports", - {"id": id}, - ) - if response.status != 200: - return f"Error trying to find airport: {response}" +async def generate_get_airport(client: aiohttp.ClientSession): + async def get_airport(id: int): + response = await client.get( + url=f"{BASE_URL}/airports", + params={"id": id}, + ) - return await response.json() + return await response.json() + + return get_airport class AirportSearchInput(BaseModel): @@ -112,44 +49,46 @@ class AirportSearchInput(BaseModel): name: Optional[str] = Field(description="Airport name") -async def search_airports(country: str, city: str, name: str): - response = await get_request( - f"{BASE_URL}/airports/search", - { - "country": country, - "city": city, - "name": name, - }, - ) - if response.status != 200: - return f"Error searching airports: {response}" - - num = 2 - response_json = await response.json() - if len(response_json) < 1: - return "There are no airports matching that query. Let the user know there are no results." - elif len(response_json) > num: - return ( - f"There are {len(response_json)} airports matching that query. Here are the first {num} results:\n" - + " ".join([f"{response_json[i]}" for i in range(num)]) +async def generate_search_airports(client: aiohttp.ClientSession): + async def search_airports(country: str, city: str, name: str): + response = await client.get( + url=f"{BASE_URL}/airports/search", + params={ + "country": country, + "city": city, + "name": name, + }, ) - else: - return "\n".join([f"{r}" for r in response_json]) + + num = 2 + response_json = await response.json() + if len(response_json) < 1: + return "There are no airports matching that query. Let the user know there are no results." + elif len(response_json) > num: + return ( + f"There are {len(response_json)} airports matching that query. Here are the first {num} results:\n" + + " ".join([f"{response_json[i]}" for i in range(num)]) + ) + else: + return "\n".join([f"{r}" for r in response_json]) + + return search_airports class FlightIdInput(BaseModel): id: int = Field(description="Unique identifier") -async def get_flight(id: int): - response = await get_request( - f"{BASE_URL}/flights", - {"flight_id": id}, - ) - if response.status != 200: - return f"Error trying to find flight: {response}" +async def generate_get_flight(client: aiohttp.ClientSession): + async def get_flight(id: int): + response = await client.get( + url=f"{BASE_URL}/flights", + params={"flight_id": id}, + ) + + return await response.json() - return await response.json() + return get_flight class FlightNumberInput(BaseModel): @@ -157,15 +96,16 @@ class FlightNumberInput(BaseModel): flight_number: str = Field(description="1 to 4 digit number") -async def search_flights_by_number(airline: str, flight_number: str): - response = await get_request( - f"{BASE_URL}/flights/search", - {"airline": airline, "flight_number": flight_number}, - ) - if response.status != 200: - return f"Error trying to find flight: {response}" +async def generate_search_flights_by_number(client: aiohttp.ClientSession): + async def search_flights_by_number(airline: str, flight_number: str): + response = await client.get( + url=f"{BASE_URL}/flights/search", + params={"airline": airline, "flight_number": flight_number}, + ) - return await response.json() + return await response.json() + + return search_flights_by_number class ListFlights(BaseModel): @@ -176,29 +116,35 @@ class ListFlights(BaseModel): date: Optional[str] = Field(description="Date of flight departure") -async def list_flights(departure_airport: str, arrival_airport: str, date: str): - response = await get_request( - f"{BASE_URL}/flights/search", - { +async def generate_list_flights(client: aiohttp.ClientSession): + async def list_flights( + departure_airport: Optional[str], + arrival_airport: Optional[str], + date: Optional[str], + ): + params = { "departure_airport": departure_airport, "arrival_airport": arrival_airport, "date": date, - }, - ) - if response.status != 200: - return f"Error searching flights: {response}" - - num = 2 - response_json = await response.json() - if len(response_json) < 1: - return "There are no flights matching that query. Let the user know there are no results." - elif len(response_json) > num: - return ( - f"There are {len(response_json)} flights matching that query. Here are the first {num} results:\n" - + " ".join([f"{response_json[i]}" for i in range(num)]) + } + response = await client.get( + url=f"{BASE_URL}/flights/search", + params=filter_none_values(params), ) - else: - return "\n".join([f"{r}" for r in response_json]) + + num = 2 + response_json = await response.json() + if len(response_json) < 1: + return "There are no flights matching that query. Let the user know there are no results." + elif len(response_json) > num: + return ( + f"There are {len(response_json)} flights matching that query. Here are the first {num} results:\n" + + " ".join([f"{response_json[i]}" for i in range(num)]) + ) + else: + return "\n".join([f"{r}" for r in response_json]) + + return list_flights # Amenities @@ -206,35 +152,44 @@ class AmenityIdInput(BaseModel): id: int = Field(description="Unique identifier") -async def get_amenity(id: int): - response = await get_request( - f"{BASE_URL}/amenities", - {"id": id}, - ) - if response.status != 200: - return f"Error trying to find amenity: {response}" - return await response.json() +async def generate_get_amenity(client: aiohttp.ClientSession): + async def get_amenity(id: int): + response = await client.get(url=f"{BASE_URL}/amenities", params={"id": id}) + return await response.json() + + return get_amenity class QueryInput(BaseModel): query: str = Field(description="Search query") -async def search_amenities(query: str): - response = await get_request( - f"{BASE_URL}/amenities/search", {"top_k": "5", "query": query} - ) - if response.status != 200: - return f"Error searching amenities: {response}" +async def generate_search_amenities(client: aiohttp.ClientSession): + async def search_amenities(query: str): + """ + Use this tool to search amenities by name or to recommended airport amenities at SFO. + If user provides flight info, use 'Get Flight' and 'Get Flights by Number' + first to get gate info and location. + Only recommend amenities that are returned by this query. + Find amenities close to the user by matching the terminal and then comparing + the gate numbers. Gate number iterate by letter and number, example A1 A2 A3 + B1 B2 B3 C1 C2 C3. Gate A3 is close to A2 and B1. + """ + response = await client.get( + url=f"{BASE_URL}/amenities/search", params={"top_k": "5", "query": query} + ) - return await response.json() + response = await response.json() + return response + + return search_amenities # Tools for agent -def initialize_tools(): +async def initialize_tools(client: aiohttp.ClientSession): return [ StructuredTool.from_function( - coroutine=get_airport, + coroutine=await generate_get_airport(client), name="Get Airport", description="""Use this tool to get info for a specific airport. Do NOT guess an airport id. @@ -243,7 +198,7 @@ def initialize_tools(): args_schema=AirportIdInput, ), StructuredTool.from_function( - coroutine=search_airports, + coroutine=generate_search_airports, name="Search Airport", description=""" Use this tool to list all airports matching search criteria. @@ -272,7 +227,7 @@ def initialize_tools(): args_schema=AirportSearchInput, ), StructuredTool.from_function( - coroutine=get_flight, + coroutine=await generate_get_flight(client), name="Get Flight", description=""" Use this tool to get info for a specific flight. @@ -286,7 +241,7 @@ def initialize_tools(): args_schema=FlightIdInput, ), StructuredTool.from_function( - coroutine=search_flights_by_number, + coroutine=await generate_search_flights_by_number(client), name="Search Flights By Flight Number", description=""" Use this tool to get info for a specific flight. Do NOT use this tool with a flight id. @@ -299,7 +254,7 @@ def initialize_tools(): args_schema=FlightNumberInput, ), StructuredTool.from_function( - coroutine=list_flights, + coroutine=await generate_list_flights(client), name="List Flights", description=""" Use this tool to list all flights matching search criteria. @@ -328,7 +283,7 @@ def initialize_tools(): args_schema=ListFlights, ), StructuredTool.from_function( - coroutine=get_amenity, + coroutine=await generate_get_amenity(client), name="Get Amenity", description=""" Use this tool to get info for a specific airport amenity. @@ -339,7 +294,7 @@ def initialize_tools(): args_schema=AmenityIdInput, ), StructuredTool.from_function( - coroutine=search_amenities, + coroutine=await generate_search_amenities(client), name="Search Amenities", description=""" Use this tool to search amenities by name or to recommended airport amenities at SFO. @@ -353,7 +308,3 @@ def initialize_tools(): args_schema=QueryInput, ), ] - - -# Tools for agent -tools = initialize_tools()