From 88b6c83b5c83d1ff0957ed5b5daeef9e1f04e905 Mon Sep 17 00:00:00 2001 From: Yuan <45984206+Yuan325@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:51:44 -0800 Subject: [PATCH] feat: implement signin signout (#258) --- llm_demo/app.py | 76 +++++++++++++++---- .../langchain_tools_orchestrator.py | 34 +++++++-- llm_demo/orchestrator/orchestrator.py | 22 +++++- .../function_calling_orchestrator.py | 26 ++++++- llm_demo/static/index.css | 19 ++++- llm_demo/static/index.js | 8 ++ llm_demo/templates/index.html | 25 +++--- 7 files changed, 171 insertions(+), 39 deletions(-) diff --git a/llm_demo/app.py b/llm_demo/app.py index 6382ac8b..6fbec3f4 100644 --- a/llm_demo/app.py +++ b/llm_demo/app.py @@ -14,7 +14,7 @@ import os from contextlib import asynccontextmanager -from typing import Optional +from typing import Any, Optional import uvicorn from fastapi import APIRouter, Body, FastAPI, HTTPException, Request @@ -48,14 +48,29 @@ async def index(request: Request): # User session setup orchestrator = request.app.state.orchestrator session = request.session + if "uuid" not in session or not orchestrator.user_session_exist(session["uuid"]): await orchestrator.user_session_create(session) + + # recheck if token and user info is still valid + user_id_token = orchestrator.get_user_id_token(session["uuid"]) + if user_id_token: + if not get_user_info(user_id_token, request.app.state.client_id): + clear_user_info(session) + elif not user_id_token and "user_info" in session: + clear_user_info(session) + return templates.TemplateResponse( "index.html", { "request": request, "messages": request.session["history"], "client_id": request.app.state.client_id, + "user_img": ( + request.session["user_info"]["user_img"] + if "user_info" in request.session + else None + ), }, ) @@ -72,29 +87,49 @@ async def login_google( client_id = request.app.state.client_id if not client_id: raise HTTPException(status_code=400, detail="Client id not found") - user_name = get_user_name(str(user_id_token), client_id) + + session = request.session + user_info = get_user_info(str(user_id_token), client_id) + session["user_info"] = user_info # create new request session orchestrator = request.app.state.orchestrator - orchestrator.set_user_session_header(request.session["uuid"], str(user_id_token)) + orchestrator.set_user_session_header(session["uuid"], str(user_id_token)) print("Logged in to Google.") - welcome_text = f"Welcome to Cymbal Air, {user_name}! How may I assist you?" + welcome_text = ( + f"Welcome to Cymbal Air, {session['user_info']['name']}! How may I assist you?" + ) if len(request.session["history"]) == 1: - request.session["history"][0] = { + session["history"][0] = { "type": "ai", "data": {"content": welcome_text}, } else: - request.session["history"].append( - {"type": "ai", "data": {"content": welcome_text}} - ) + session["history"].append({"type": "ai", "data": {"content": welcome_text}}) # Redirect to source URL source_url = request.headers["Referer"] return RedirectResponse(url=source_url) +@routes.post("/logout/google") +async def logout_google( + request: Request, +): + """Logout google account from user session and clear user session""" + if "uuid" not in request.session: + raise HTTPException(status_code=400, detail=f"No session to reset.") + + uuid = request.session["uuid"] + orchestrator = request.app.state.orchestrator + if not orchestrator.user_session_exist(uuid): + raise HTTPException(status_code=500, detail=f"Current user session not found") + + orchestrator.user_session_signout(uuid) + request.session.clear() + + @routes.post("/chat", response_class=PlainTextResponse) async def chat_handler(request: Request, prompt: str = Body(embed=True)): """Handler for LangChain chat requests""" @@ -116,7 +151,7 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)): @routes.post("/reset") -async def reset(request: Request): +def reset(request: Request): """Reset user session""" if "uuid" not in request.session: @@ -127,15 +162,24 @@ async def reset(request: Request): if not orchestrator.user_session_exist(uuid): raise HTTPException(status_code=500, detail=f"Current user session not found") - await orchestrator.user_session_reset(uuid) - request.session.clear() + orchestrator.user_session_reset(request.session, uuid) -def get_user_name(user_token_id: str, client_id: str) -> str: - id_info = id_token.verify_oauth2_token( - user_token_id, requests.Request(), audience=client_id - ) - return id_info["name"] +def get_user_info(user_id_token: str, client_id: str) -> dict[str, str]: + try: + id_info = id_token.verify_oauth2_token( + user_id_token, requests.Request(), audience=client_id + ) + return { + "user_img": id_info["picture"], + "name": id_info["name"], + } + except ValueError as err: + return {} + + +def clear_user_info(session: dict[str, Any]): + del session["user_info"] def init_app( diff --git a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py index 9d73820e..f4315049 100644 --- a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py +++ b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py @@ -43,9 +43,15 @@ class UserAgent: client: ClientSession agent: AgentExecutor - def __init__(self, client: ClientSession, agent: AgentExecutor): + def __init__( + self, + client: ClientSession, + agent: AgentExecutor, + memory: ConversationBufferMemory, + ): self.client = client self.agent = agent + self.memory = memory @classmethod def initialize_agent( @@ -74,7 +80,7 @@ def initialize_agent( return_intermediate_steps=True, ) agent.agent.llm_chain.prompt = prompt # type: ignore - return UserAgent(client, agent) + return UserAgent(client, agent, memory) async def close(self): await self.client.close() @@ -86,6 +92,10 @@ async def invoke(self, prompt: str) -> Dict[str, Any]: raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}") return response + def reset_memory(self, base_message: List[BaseMessage]): + self.memory.clear() + self.memory.chat_memory = ChatMessageHistory(messages=base_message) + class LangChainToolsOrchestrator(BaseOrchestrator): _user_sessions: Dict[str, UserAgent] @@ -123,10 +133,13 @@ async def user_session_invoke(self, uuid: str, prompt: str) -> str: response = await user_session.invoke(prompt) return response["output"] - async def user_session_reset(self, uuid: str): + def user_session_reset(self, session: dict[str, Any], uuid: str): user_session = self.get_user_session(uuid) - await user_session.close() - del user_session + del session["history"] + base_history = self.get_base_history(session) + session["history"] = [base_history] + history = self.parse_messages(session["history"]) + user_session.reset_memory(history) def get_user_session(self, uuid: str) -> UserAgent: return self._user_sessions[uuid] @@ -175,6 +188,17 @@ def parse_messages(self, datas: List[Any]) -> List[BaseMessage]: raise Exception("Message type not found.") return messages + def get_base_history(self, session: dict[str, Any]): + if "user_info" in session: + base_history = { + "type": "ai", + "data": { + "content": f"Welcome to Cymbal Air, {session['user_info']['name']}! How may I assist you?" + }, + } + return base_history + return BASE_HISTORY + def close_clients(self): close_client_tasks = [ asyncio.create_task(a.close()) for a in self._user_sessions.values() diff --git a/llm_demo/orchestrator/orchestrator.py b/llm_demo/orchestrator/orchestrator.py index c7e7d0d4..c4fc7f0a 100644 --- a/llm_demo/orchestrator/orchestrator.py +++ b/llm_demo/orchestrator/orchestrator.py @@ -14,7 +14,7 @@ import asyncio from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Optional class classproperty: @@ -49,8 +49,8 @@ async def user_session_invoke(self, uuid: str, prompt: str) -> str: raise NotImplementedError("Subclass should implement this!") @abstractmethod - def user_session_reset(self, uuid: str): - """Clear and reset user session.""" + def user_session_reset(self, session: dict[str, Any], uuid: str): + """Reset and clear history from user session.""" raise NotImplementedError("Subclass should implement this!") @abstractmethod @@ -61,6 +61,22 @@ def set_user_session_header(self, uuid: str, user_id_token: str): user_session = self.get_user_session(uuid) user_session.client.headers["User-Id-Token"] = f"Bearer {user_id_token}" + def get_user_id_token(self, uuid: str) -> Optional[str]: + user_session = self.get_user_session(uuid) + if user_session.client and "User-Id-Token" in user_session.client.headers: + token = user_session.client.headers["User-Id-Token"] + parts = str(token).split(" ") + if len(parts) != 2 or parts[0] != "Bearer": + raise Exception("Invalid ID token") + return parts[1] + return None + + async def user_session_signout(self, uuid: str): + """Sign out from user session. Clear and restart session.""" + user_session = self.get_user_session(uuid) + await user_session.close() + del user_session + def createOrchestrator(orchestration_type: str) -> "BaseOrchestrator": for cls in BaseOrchestrator.__subclasses__(): diff --git a/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py index edd82835..551980dc 100644 --- a/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py +++ b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py @@ -24,6 +24,7 @@ from google.protobuf.json_format import MessageToDict from vertexai.preview.generative_models import ( # type: ignore ChatSession, + Content, GenerationResponse, GenerativeModel, Part, @@ -116,6 +117,12 @@ async def request_function(self, function_call): response = await response.json() return response + def reset_memory(self, model: str): + """reinitiate chat model to reset memory.""" + del self.chat + chat_model = GenerativeModel(model, tools=[assistant_tool()]) + self.chat = chat_model.start_chat() + class FunctionCallingOrchestrator(BaseOrchestrator): _user_sessions: Dict[str, UserChatModel] @@ -150,10 +157,12 @@ async def user_session_invoke(self, uuid: str, prompt: str) -> str: response = await user_session.invoke(prompt) return response["output"] - async def user_session_reset(self, uuid: str): + def user_session_reset(self, session: dict[str, Any], uuid: str): user_session = self.get_user_session(uuid) - await user_session.close() - del user_session + del session["history"] + base_history = self.get_base_history(session) + session["history"] = [base_history] + user_session.reset_memory(self.MODEL) def get_user_session(self, uuid: str) -> UserChatModel: return self._user_sessions[uuid] @@ -171,6 +180,17 @@ async def create_client_session(self) -> ClientSession: raise_for_status=True, ) + def get_base_history(self, session: dict[str, Any]): + if "user_info" in session: + base_history = { + "type": "ai", + "data": { + "content": f"Welcome to Cymbal Air, {session['user_info']['name']}! How may I assist you?" + }, + } + return base_history + return BASE_HISTORY + def close_clients(self): close_client_tasks = [ asyncio.create_task(a.close()) for a in self._user_sessions.values() diff --git a/llm_demo/static/index.css b/llm_demo/static/index.css index c74b2a68..dac9308f 100644 --- a/llm_demo/static/index.css +++ b/llm_demo/static/index.css @@ -51,8 +51,23 @@ body { #g_id_onload, .g_id_signin { position: absolute; - top: 8px; + top: 6px; right: 10px; + .chat-user-state { + top: 6px; + position: relative; + .chat-user-image { + border-radius: 50%; + height: 36px; + width: 36px; + } + .chat-signout-btn { + border-radius: 5px; + margin-left: 10px; + height: 32px; + font-family: "Roboto"; + } + } } #menuButton { @@ -191,4 +206,4 @@ div.chat-wrapper div.chat-content div .sender-icon img { .send-button { margin-top: 12px; margin-right: 12px; -} \ No newline at end of file +} diff --git a/llm_demo/static/index.js b/llm_demo/static/index.js index 19b06de9..7a219f59 100644 --- a/llm_demo/static/index.js +++ b/llm_demo/static/index.js @@ -75,6 +75,14 @@ async function reset() { }) } +async function signout() { + await fetch('logout/google', { + method: 'POST', + }).then(()=>{ + window.location.reload() + }) +} + // Helper function to print to chatroom function log(name, msg) { let message = `
diff --git a/llm_demo/templates/index.html b/llm_demo/templates/index.html index 41ac0a4e..c1bb3890 100644 --- a/llm_demo/templates/index.html +++ b/llm_demo/templates/index.html @@ -33,7 +33,6 @@ -
menu @@ -44,15 +43,21 @@ data-context="signin" data-ux_mode="popup" data-auto_prompt="false"> -
-