Skip to content

Commit

Permalink
feat: implement signin signout (#258)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 authored Mar 7, 2024
1 parent 8363eaf commit 88b6c83
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 39 deletions.
76 changes: 60 additions & 16 deletions llm_demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import os
from contextlib import asynccontextmanager
from typing import Optional
from typing import Any, Optional

import uvicorn
from fastapi import APIRouter, Body, FastAPI, HTTPException, Request
Expand Down Expand Up @@ -48,14 +48,29 @@ async def index(request: Request):
# User session setup
orchestrator = request.app.state.orchestrator
session = request.session

if "uuid" not in session or not orchestrator.user_session_exist(session["uuid"]):
await orchestrator.user_session_create(session)

# recheck if token and user info is still valid
user_id_token = orchestrator.get_user_id_token(session["uuid"])
if user_id_token:
if not get_user_info(user_id_token, request.app.state.client_id):
clear_user_info(session)
elif not user_id_token and "user_info" in session:
clear_user_info(session)

return templates.TemplateResponse(
"index.html",
{
"request": request,
"messages": request.session["history"],
"client_id": request.app.state.client_id,
"user_img": (
request.session["user_info"]["user_img"]
if "user_info" in request.session
else None
),
},
)

Expand All @@ -72,29 +87,49 @@ async def login_google(
client_id = request.app.state.client_id
if not client_id:
raise HTTPException(status_code=400, detail="Client id not found")
user_name = get_user_name(str(user_id_token), client_id)

session = request.session
user_info = get_user_info(str(user_id_token), client_id)
session["user_info"] = user_info

# create new request session
orchestrator = request.app.state.orchestrator
orchestrator.set_user_session_header(request.session["uuid"], str(user_id_token))
orchestrator.set_user_session_header(session["uuid"], str(user_id_token))
print("Logged in to Google.")

welcome_text = f"Welcome to Cymbal Air, {user_name}! How may I assist you?"
welcome_text = (
f"Welcome to Cymbal Air, {session['user_info']['name']}! How may I assist you?"
)
if len(request.session["history"]) == 1:
request.session["history"][0] = {
session["history"][0] = {
"type": "ai",
"data": {"content": welcome_text},
}
else:
request.session["history"].append(
{"type": "ai", "data": {"content": welcome_text}}
)
session["history"].append({"type": "ai", "data": {"content": welcome_text}})

# Redirect to source URL
source_url = request.headers["Referer"]
return RedirectResponse(url=source_url)


@routes.post("/logout/google")
async def logout_google(
request: Request,
):
"""Logout google account from user session and clear user session"""
if "uuid" not in request.session:
raise HTTPException(status_code=400, detail=f"No session to reset.")

uuid = request.session["uuid"]
orchestrator = request.app.state.orchestrator
if not orchestrator.user_session_exist(uuid):
raise HTTPException(status_code=500, detail=f"Current user session not found")

orchestrator.user_session_signout(uuid)
request.session.clear()


@routes.post("/chat", response_class=PlainTextResponse)
async def chat_handler(request: Request, prompt: str = Body(embed=True)):
"""Handler for LangChain chat requests"""
Expand All @@ -116,7 +151,7 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)):


@routes.post("/reset")
async def reset(request: Request):
def reset(request: Request):
"""Reset user session"""

if "uuid" not in request.session:
Expand All @@ -127,15 +162,24 @@ async def reset(request: Request):
if not orchestrator.user_session_exist(uuid):
raise HTTPException(status_code=500, detail=f"Current user session not found")

await orchestrator.user_session_reset(uuid)
request.session.clear()
orchestrator.user_session_reset(request.session, uuid)


def get_user_name(user_token_id: str, client_id: str) -> str:
id_info = id_token.verify_oauth2_token(
user_token_id, requests.Request(), audience=client_id
)
return id_info["name"]
def get_user_info(user_id_token: str, client_id: str) -> dict[str, str]:
try:
id_info = id_token.verify_oauth2_token(
user_id_token, requests.Request(), audience=client_id
)
return {
"user_img": id_info["picture"],
"name": id_info["name"],
}
except ValueError as err:
return {}


def clear_user_info(session: dict[str, Any]):
del session["user_info"]


def init_app(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,15 @@ class UserAgent:
client: ClientSession
agent: AgentExecutor

def __init__(self, client: ClientSession, agent: AgentExecutor):
def __init__(
self,
client: ClientSession,
agent: AgentExecutor,
memory: ConversationBufferMemory,
):
self.client = client
self.agent = agent
self.memory = memory

@classmethod
def initialize_agent(
Expand Down Expand Up @@ -74,7 +80,7 @@ def initialize_agent(
return_intermediate_steps=True,
)
agent.agent.llm_chain.prompt = prompt # type: ignore
return UserAgent(client, agent)
return UserAgent(client, agent, memory)

async def close(self):
await self.client.close()
Expand All @@ -86,6 +92,10 @@ async def invoke(self, prompt: str) -> Dict[str, Any]:
raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}")
return response

def reset_memory(self, base_message: List[BaseMessage]):
self.memory.clear()
self.memory.chat_memory = ChatMessageHistory(messages=base_message)


class LangChainToolsOrchestrator(BaseOrchestrator):
_user_sessions: Dict[str, UserAgent]
Expand Down Expand Up @@ -123,10 +133,13 @@ async def user_session_invoke(self, uuid: str, prompt: str) -> str:
response = await user_session.invoke(prompt)
return response["output"]

async def user_session_reset(self, uuid: str):
def user_session_reset(self, session: dict[str, Any], uuid: str):
user_session = self.get_user_session(uuid)
await user_session.close()
del user_session
del session["history"]
base_history = self.get_base_history(session)
session["history"] = [base_history]
history = self.parse_messages(session["history"])
user_session.reset_memory(history)

def get_user_session(self, uuid: str) -> UserAgent:
return self._user_sessions[uuid]
Expand Down Expand Up @@ -175,6 +188,17 @@ def parse_messages(self, datas: List[Any]) -> List[BaseMessage]:
raise Exception("Message type not found.")
return messages

def get_base_history(self, session: dict[str, Any]):
if "user_info" in session:
base_history = {
"type": "ai",
"data": {
"content": f"Welcome to Cymbal Air, {session['user_info']['name']}! How may I assist you?"
},
}
return base_history
return BASE_HISTORY

def close_clients(self):
close_client_tasks = [
asyncio.create_task(a.close()) for a in self._user_sessions.values()
Expand Down
22 changes: 19 additions & 3 deletions llm_demo/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import asyncio
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Optional


class classproperty:
Expand Down Expand Up @@ -49,8 +49,8 @@ async def user_session_invoke(self, uuid: str, prompt: str) -> str:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
def user_session_reset(self, uuid: str):
"""Clear and reset user session."""
def user_session_reset(self, session: dict[str, Any], uuid: str):
"""Reset and clear history from user session."""
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
Expand All @@ -61,6 +61,22 @@ def set_user_session_header(self, uuid: str, user_id_token: str):
user_session = self.get_user_session(uuid)
user_session.client.headers["User-Id-Token"] = f"Bearer {user_id_token}"

def get_user_id_token(self, uuid: str) -> Optional[str]:
user_session = self.get_user_session(uuid)
if user_session.client and "User-Id-Token" in user_session.client.headers:
token = user_session.client.headers["User-Id-Token"]
parts = str(token).split(" ")
if len(parts) != 2 or parts[0] != "Bearer":
raise Exception("Invalid ID token")
return parts[1]
return None

async def user_session_signout(self, uuid: str):
"""Sign out from user session. Clear and restart session."""
user_session = self.get_user_session(uuid)
await user_session.close()
del user_session


def createOrchestrator(orchestration_type: str) -> "BaseOrchestrator":
for cls in BaseOrchestrator.__subclasses__():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.protobuf.json_format import MessageToDict
from vertexai.preview.generative_models import ( # type: ignore
ChatSession,
Content,
GenerationResponse,
GenerativeModel,
Part,
Expand Down Expand Up @@ -116,6 +117,12 @@ async def request_function(self, function_call):
response = await response.json()
return response

def reset_memory(self, model: str):
"""reinitiate chat model to reset memory."""
del self.chat
chat_model = GenerativeModel(model, tools=[assistant_tool()])
self.chat = chat_model.start_chat()


class FunctionCallingOrchestrator(BaseOrchestrator):
_user_sessions: Dict[str, UserChatModel]
Expand Down Expand Up @@ -150,10 +157,12 @@ async def user_session_invoke(self, uuid: str, prompt: str) -> str:
response = await user_session.invoke(prompt)
return response["output"]

async def user_session_reset(self, uuid: str):
def user_session_reset(self, session: dict[str, Any], uuid: str):
user_session = self.get_user_session(uuid)
await user_session.close()
del user_session
del session["history"]
base_history = self.get_base_history(session)
session["history"] = [base_history]
user_session.reset_memory(self.MODEL)

def get_user_session(self, uuid: str) -> UserChatModel:
return self._user_sessions[uuid]
Expand All @@ -171,6 +180,17 @@ async def create_client_session(self) -> ClientSession:
raise_for_status=True,
)

def get_base_history(self, session: dict[str, Any]):
if "user_info" in session:
base_history = {
"type": "ai",
"data": {
"content": f"Welcome to Cymbal Air, {session['user_info']['name']}! How may I assist you?"
},
}
return base_history
return BASE_HISTORY

def close_clients(self):
close_client_tasks = [
asyncio.create_task(a.close()) for a in self._user_sessions.values()
Expand Down
19 changes: 17 additions & 2 deletions llm_demo/static/index.css
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,23 @@ body {
#g_id_onload,
.g_id_signin {
position: absolute;
top: 8px;
top: 6px;
right: 10px;
.chat-user-state {
top: 6px;
position: relative;
.chat-user-image {
border-radius: 50%;
height: 36px;
width: 36px;
}
.chat-signout-btn {
border-radius: 5px;
margin-left: 10px;
height: 32px;
font-family: "Roboto";
}
}
}

#menuButton {
Expand Down Expand Up @@ -191,4 +206,4 @@ div.chat-wrapper div.chat-content div .sender-icon img {
.send-button {
margin-top: 12px;
margin-right: 12px;
}
}
8 changes: 8 additions & 0 deletions llm_demo/static/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ async function reset() {
})
}

async function signout() {
await fetch('logout/google', {
method: 'POST',
}).then(()=>{
window.location.reload()
})
}

// Helper function to print to chatroom
function log(name, msg) {
let message = `<div class="chat-bubble ${name}">
Expand Down
25 changes: 15 additions & 10 deletions llm_demo/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
</head>

<body>

<div class="container">
<div class="chat-header">
<span class="material-symbols-outlined" id="menuButton">menu</span>
Expand All @@ -44,15 +43,21 @@
data-context="signin"
data-ux_mode="popup"
data-auto_prompt="false">
</div>
<div class="g_id_signin"
data-type="standard"
data-shape="rectangular"
data-theme="outline"
data-text="signin_with"
data-size="large"
data-logo_alignment="left"
data-onsuccess="onSignIn">
{% if user_img %}
<div class="chat-user-state">
<img class="chat-user-image" src="{{ user_img }}" alt="user image"></img>
<button type="button" class="btn btn-default chat-signout-btn" id="signoutButton" onclick="signout()">Sign out</button>
</div>
{% else %}
<div class="g_id_signin"
data-type="standard"
data-shape="rectangular"
data-theme="outline"
data-text="signin"
data-size="medium"
data-logo_alignment="left">
</div>
{% endif %}
</div>
<div class="chat-wrapper">
<div class="chat-content">
Expand Down

0 comments on commit 88b6c83

Please sign in to comment.