diff --git a/langchain_tools_demo/agent.py b/langchain_tools_demo/agent.py index dad530487..9d15241c0 100644 --- a/langchain_tools_demo/agent.py +++ b/langchain_tools_demo/agent.py @@ -27,13 +27,13 @@ from langchain.prompts.chat import ChatPromptTemplate from tools import initialize_tools +from typing import Dict, Optional set_verbose(bool(os.getenv("DEBUG", default=False))) BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") # aiohttp context connector = None -client_agents = {} class ClientAgent: @@ -45,6 +45,9 @@ def __init__(self, client, agent) -> None: self.agent = agent +client_agents: Dict[str, ClientAgent] = {} + + def get_id_token(url: str) -> str: """Helper method to generate ID tokens for authenticated requests""" # Use Application Default Credentials on Cloud Run @@ -80,10 +83,9 @@ def convert_date(date_string: str) -> str: return converted.strftime("%Y-%m-%d") -def get_header() -> dict: - headers = {} +def get_header() -> Optional[dict]: if "http://" in BASE_URL: - return headers + return None else: # Append ID Token to make authenticated requests to Cloud Run services headers = {"Authorization": f"Bearer {get_id_token(BASE_URL)}"} diff --git a/langchain_tools_demo/main.py b/langchain_tools_demo/main.py index dd87b3738..31899a513 100644 --- a/langchain_tools_demo/main.py +++ b/langchain_tools_demo/main.py @@ -26,7 +26,7 @@ from markdown import markdown from starlette.middleware.sessions import SessionMiddleware -from agent import ClientAgent, client_agents, init_agent +from agent import client_agents, init_agent @asynccontextmanager @@ -35,9 +35,9 @@ async def lifespan(app: FastAPI): print("Loading application...") yield # FastAPI app shutdown event - close_client_tasks = [] - for c in client_agents.items: - tasks += asyncio.ensure_task(c.session.close()) + close_client_tasks = [ + asyncio.create_task(c.session.close()) for c in client_agents.values() + ] asyncio.gather(close_client_tasks) @@ -48,8 +48,6 @@ async def lifespan(app: FastAPI): # TODO: set secret_key for production app.add_middleware(SessionMiddleware, secret_key="SECRET_KEY") templates = Jinja2Templates(directory="templates") - -client_agents: dict[str, ClientAgent] = {} BASE_HISTORY = [{"role": "assistant", "content": "How can I help you?"}] diff --git a/langchain_tools_demo/tools.py b/langchain_tools_demo/tools.py index 63c2ab347..cbb516497 100644 --- a/langchain_tools_demo/tools.py +++ b/langchain_tools_demo/tools.py @@ -96,7 +96,7 @@ async def list_flights( "date": date, } response = await client.get( - url=f"{BASE_URL}/flights/search?depar", + url=f"{BASE_URL}/flights/search", params=filter_none_values(params), )