
feat: Create individual user client session #137

Merged: 9 commits, Dec 20, 2023
97 changes: 93 additions & 4 deletions langchain_tools_demo/agent.py
@@ -13,28 +13,116 @@
# limitations under the License.

import os
from datetime import date, timedelta
from typing import Dict, Optional

import aiohttp
import dateutil.parser as dparser
import google.auth.transport.requests # type: ignore
import google.oauth2.id_token # type: ignore
from langchain.agents import AgentType, initialize_agent
from langchain.agents.agent import AgentExecutor
from langchain.globals import set_verbose # type: ignore
from langchain.llms.vertexai import VertexAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts.chat import ChatPromptTemplate

from tools import convert_date, tools
from tools import initialize_tools

set_verbose(bool(os.getenv("DEBUG", default=False)))
BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080")

# aiohttp context
connector = None


# Class for setting up a dedicated llm agent for each individual user
class UserAgent:
    client: aiohttp.ClientSession
    agent: AgentExecutor

    def __init__(self, client, agent) -> None:
        self.client = client
        self.agent = agent


user_agents: Dict[str, UserAgent] = {}


def get_id_token(url: str) -> str:
"""Helper method to generate ID tokens for authenticated requests"""
# Use Application Default Credentials on Cloud Run
if os.getenv("K_SERVICE"):
auth_req = google.auth.transport.requests.Request()
return google.oauth2.id_token.fetch_id_token(auth_req, url)
else:
# Use gcloud credentials locally
import subprocess

return (
subprocess.run(
["gcloud", "auth", "print-identity-token"],
stdout=subprocess.PIPE,
check=True,
)
.stdout.strip()
.decode()
)


def convert_date(date_string: str) -> str:
"""Convert date into appropriate date string"""
if date_string == "tomorrow":
converted = date.today() + timedelta(1)
elif date_string == "yesterday":
converted = date.today() - timedelta(1)
elif date_string != "null" and date_string != "today" and date_string is not None:
converted = dparser.parse(date_string, fuzzy=True).date()
else:
converted = date.today()

return converted.strftime("%Y-%m-%d")
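
For illustration only (not part of this diff), here is roughly how convert_date resolves relative and free-form dates; the outputs are hypothetical and assume today is 2023-12-20:

```python
# Hypothetical usage of convert_date; outputs assume today is 2023-12-20.
convert_date("tomorrow")         # "2023-12-21"
convert_date("yesterday")        # "2023-12-19"
convert_date("January 5 2024")   # "2024-01-05" (free-form dates go through dateutil)
convert_date("null")             # "2023-12-20" (falls back to today)
```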


def get_header() -> Optional[dict]:
if "http://" in BASE_URL:
return None
else:
# Append ID Token to make authenticated requests to Cloud Run services
headers = {"Authorization": f"Bearer {get_id_token(BASE_URL)}"}
return headers


async def get_connector():
    global connector
    if connector is None:
        connector = aiohttp.TCPConnector(limit=100)
    return connector


async def handle_error_response(response):
    if response.status != 200:
        return f"Error sending {response.method} request to {str(response.url)}): {await response.text()}"


async def create_client_session() -> aiohttp.ClientSession:
    return aiohttp.ClientSession(
        connector=await get_connector(),
        headers=get_header(),
        raise_for_status=handle_error_response,
    )


Collaborator: Let's keep an eye out if the nicer error message worked better than this.

Contributor Author: We could modify the error handler function if we think more details are needed in future usage.
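
Building on the thread above, a minimal sketch of what a more detailed handler could look like if it is ever needed; this is a hypothetical variant, not code from this PR:

```python
# Hypothetical, more detailed variant of handle_error_response (not in this PR).
async def handle_error_response_verbose(response: aiohttp.ClientResponse):
    if response.status != 200:
        body = (await response.text())[:500]  # truncate long error bodies
        return (
            f"{response.method} request to {response.url} failed "
            f"with HTTP {response.status}: {body}"
        )
```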


# Agent
def init_agent() -> AgentExecutor:
async def init_agent() -> UserAgent:
"""Load an agent executor with tools and LLM"""
print("Initializing agent..")
llm = VertexAI(max_output_tokens=512)
memory = ConversationBufferMemory(
memory_key="chat_history", input_key="input", output_key="output"
)

client = await create_client_session()
tools = await initialize_tools(client)
agent = initialize_agent(
tools,
llm,
@@ -59,7 +147,8 @@ def init_agent() -> AgentExecutor:
[("system", template), ("human", human_message_template)]
)
agent.agent.llm_chain.prompt = prompt # type: ignore
return agent

return UserAgent(client, agent)


PREFIX = """SFO Airport Assistant helps travelers find their way at the airport.
42 changes: 24 additions & 18 deletions langchain_tools_demo/main.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import uuid
from contextlib import asynccontextmanager

import uvicorn
from fastapi import Body, FastAPI, HTTPException, Request
@@ -24,27 +26,31 @@
from markdown import markdown
from starlette.middleware.sessions import SessionMiddleware

from agent import init_agent
from tools import session
from agent import init_agent, user_agents

app = FastAPI()

@asynccontextmanager
async def lifespan(app: FastAPI):
    # FastAPI app startup event
    print("Loading application...")
    yield
    # FastAPI app shutdown event
    close_client_tasks = [
        asyncio.create_task(c.client.close()) for c in user_agents.values()
    ]

    asyncio.gather(*close_client_tasks)


# FastAPI setup
app = FastAPI(lifespan=lifespan)
app.mount("/static", StaticFiles(directory="static"), name="static")
# TODO: set secret_key for production
app.add_middleware(SessionMiddleware, secret_key="SECRET_KEY")
templates = Jinja2Templates(directory="templates")

agents: dict[str, AgentExecutor] = {}
BASE_HISTORY = [{"role": "assistant", "content": "How can I help you?"}]


async def on_shutdown():
    if session is not None:
        await session.close()


app.add_event_handler("shutdown", on_shutdown)


@app.get("/", response_class=HTMLResponse)
def index(request: Request):
"""Render the default template."""
@@ -71,14 +77,14 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)):
    # Add user message to chat history
    request.session["messages"] += [{"role": "user", "content": prompt}]
    # Agent setup
    if request.session["uuid"] in agents:
        agent = agents[request.session["uuid"]]
    if request.session["uuid"] in user_agents:
        user_agent = user_agents[request.session["uuid"]]
    else:
        agent = init_agent()
        agents[request.session["uuid"]] = agent
        user_agent = await init_agent()
        user_agents[request.session["uuid"]] = user_agent
    try:
        # Send prompt to LLM
        response = await agent.ainvoke({"input": prompt})
        response = await user_agent.agent.ainvoke({"input": prompt})
        request.session["messages"] += [
            {"role": "assistant", "content": response["output"]}
        ]