Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
duwenxin99 committed Dec 15, 2023
1 parent 2f2455e commit 0508a88
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
10 changes: 6 additions & 4 deletions langchain_tools_demo/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
from langchain.prompts.chat import ChatPromptTemplate

from tools import initialize_tools
from typing import Dict, Optional

set_verbose(bool(os.getenv("DEBUG", default=False)))
BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080")

# aiohttp context
connector = None
client_agents = {}


class ClientAgent:
Expand All @@ -45,6 +45,9 @@ def __init__(self, client, agent) -> None:
self.agent = agent


client_agents: Dict[str, ClientAgent] = {}


def get_id_token(url: str) -> str:
"""Helper method to generate ID tokens for authenticated requests"""
# Use Application Default Credentials on Cloud Run
Expand Down Expand Up @@ -80,10 +83,9 @@ def convert_date(date_string: str) -> str:
return converted.strftime("%Y-%m-%d")


def get_header() -> dict:
headers = {}
def get_header() -> Optional[dict]:
if "http://" in BASE_URL:
return headers
return None
else:
# Append ID Token to make authenticated requests to Cloud Run services
headers = {"Authorization": f"Bearer {get_id_token(BASE_URL)}"}
Expand Down
10 changes: 4 additions & 6 deletions langchain_tools_demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from markdown import markdown
from starlette.middleware.sessions import SessionMiddleware

from agent import ClientAgent, client_agents, init_agent
from agent import client_agents, init_agent


@asynccontextmanager
Expand All @@ -35,9 +35,9 @@ async def lifespan(app: FastAPI):
print("Loading application...")
yield
# FastAPI app shutdown event
close_client_tasks = []
for c in client_agents.items:
tasks += asyncio.ensure_task(c.session.close())
close_client_tasks = [
asyncio.create_task(c.session.close()) for c in client_agents.values()
]

asyncio.gather(close_client_tasks)

Expand All @@ -48,8 +48,6 @@ async def lifespan(app: FastAPI):
# TODO: set secret_key for production
app.add_middleware(SessionMiddleware, secret_key="SECRET_KEY")
templates = Jinja2Templates(directory="templates")

client_agents: dict[str, ClientAgent] = {}
BASE_HISTORY = [{"role": "assistant", "content": "How can I help you?"}]


Expand Down
2 changes: 1 addition & 1 deletion langchain_tools_demo/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def list_flights(
"date": date,
}
response = await client.get(
url=f"{BASE_URL}/flights/search?depar",
url=f"{BASE_URL}/flights/search",
params=filter_none_values(params),
)

Expand Down

0 comments on commit 0508a88

Please sign in to comment.