Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: vertex ai function calling llm #188

Merged
merged 18 commits into from
Mar 6, 2024
12 changes: 6 additions & 6 deletions llm_demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ async def lifespan(app: FastAPI):
print("Loading application...")
yield
# FastAPI app shutdown event
app.state.orchestration_type.close_clients()
app.state.orchestrator.close_clients()


@routes.get("/")
@routes.post("/")
async def index(request: Request):
"""Render the default template."""
# User session setup
orchestrator = request.app.state.orchestration_type
orchestrator = request.app.state.orchestrator
session = request.session
if "uuid" not in session or not orchestrator.user_session_exist(session["uuid"]):
await orchestrator.user_session_create(session)
Expand Down Expand Up @@ -75,7 +75,7 @@ async def login_google(
user_name = get_user_name(str(user_id_token), client_id)

# create new request session
orchestrator = request.app.state.orchestration_type
orchestrator = request.app.state.orchestrator
orchestrator.set_user_session_header(request.session["uuid"], str(user_id_token))
print("Logged in to Google.")

Expand Down Expand Up @@ -108,7 +108,7 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)):

# Add user message to chat history
request.session["history"].append({"type": "human", "data": {"content": prompt}})
orchestrator = request.app.state.orchestration_type
orchestrator = request.app.state.orchestrator
output = await orchestrator.user_session_invoke(request.session["uuid"], prompt)
# Return assistant response
request.session["history"].append({"type": "ai", "data": {"content": output}})
Expand All @@ -123,7 +123,7 @@ async def reset(request: Request):
raise HTTPException(status_code=400, detail=f"No session to reset.")

uuid = request.session["uuid"]
orchestrator = request.app.state.orchestration_type
orchestrator = request.app.state.orchestrator
if not orchestrator.user_session_exist(uuid):
raise HTTPException(status_code=500, detail=f"Current user session not found")

Expand All @@ -148,7 +148,7 @@ def init_app(
raise HTTPException(status_code=500, detail="Orchestrator not found")
app = FastAPI(lifespan=lifespan)
app.state.client_id = client_id
app.state.orchestration_type = createOrchestrator(orchestration_type)
app.state.orchestrator = createOrchestrator(orchestration_type)
app.include_router(routes)
app.mount("/static", StaticFiles(directory="static"), name="static")
app.add_middleware(SessionMiddleware, secret_key=secret_key)
Expand Down
9 changes: 7 additions & 2 deletions llm_demo/orchestrator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import langchain_tools
from . import langchain_tools, vertexai_function_calling
from .orchestrator import BaseOrchestrator, createOrchestrator

__ALL__ = ["BaseOrchestrator", "createOrchestrator", "langchain_tools"]
__ALL__ = [
"BaseOrchestrator",
"createOrchestrator",
"langchain_tools",
"vertexai_function_calling",
]
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from .tools import initialize_tools

set_verbose(bool(os.getenv("DEBUG", default=False)))
MODEL = "gemini-pro"
BASE_HISTORY = {
"type": "ai",
"data": {"content": "Welcome to Cymbal Air! How may I assist you?"},
Expand All @@ -55,8 +54,9 @@ def initialize_agent(
tools: List[StructuredTool],
history: List[BaseMessage],
prompt: ChatPromptTemplate,
model: str,
) -> "UserAgent":
llm = VertexAI(max_output_tokens=512, model_name=MODEL)
llm = VertexAI(max_output_tokens=512, model_name=model)
memory = ConversationBufferMemory(
chat_memory=ChatMessageHistory(messages=history),
memory_key="chat_history",
Expand Down Expand Up @@ -88,10 +88,13 @@ async def invoke(self, prompt: str) -> Dict[str, Any]:


class LangChainToolsOrchestrator(BaseOrchestrator):
_user_sessions: Dict[str, UserAgent] = {}
_user_sessions: Dict[str, UserAgent]
# aiohttp context
connector = None

def __init__(self):
self._user_sessions = {}

@classproperty
def kind(cls):
return "langchain-tools"
Expand All @@ -111,7 +114,7 @@ async def user_session_create(self, session: dict[str, Any]):
client = await self.create_client_session()
tools = await initialize_tools(client)
prompt = self.create_prompt_template(tools)
agent = UserAgent.initialize_agent(client, tools, history, prompt)
agent = UserAgent.initialize_agent(client, tools, history, prompt, self.MODEL)
self._user_sessions[id] = agent

async def user_session_invoke(self, uuid: str, prompt: str) -> str:
Expand Down
2 changes: 2 additions & 0 deletions llm_demo/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def __get__(self, instance, owner):


class BaseOrchestrator(ABC):
MODEL = "gemini-pro"

@classproperty
@abstractmethod
def kind(cls):
Expand Down
17 changes: 17 additions & 0 deletions llm_demo/orchestrator/vertexai_function_calling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .function_calling_orchestrator import FunctionCallingOrchestrator

# Declare the package's public API. The original used `__ALL__`, which has no
# meaning to Python; the star-import/export convention is lowercase `__all__`.
__all__ = ["FunctionCallingOrchestrator"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import uuid
from datetime import date
from typing import Any, Dict

from aiohttp import ClientSession, TCPConnector
from fastapi import HTTPException
from google.auth.transport.requests import Request # type: ignore
from google.protobuf.json_format import MessageToDict
from vertexai.preview.generative_models import ( # type: ignore
ChatSession,
GenerationResponse,
GenerativeModel,
Part,
)

from ..orchestrator import BaseOrchestrator, classproperty
from .functions import BASE_URL, assistant_tool, function_request, get_headers

# Debug flag read once at import time.
# NOTE(review): os.getenv returns the raw string, so ANY non-empty value —
# including "False" or "0" — makes DEBUG truthy; confirm this is intended
# (the langchain_tools module wraps the same call in bool(), which has the
# same issue).
DEBUG = os.getenv("DEBUG", default=False)
# Greeting message seeded as the first "ai" entry of a new user's chat history.
BASE_HISTORY = {
    "type": "ai",
    "data": {"content": "Welcome to Cymbal Air! How may I assist you?"},
}


class UserChatModel:
    """One user's Vertex AI chat session plus the HTTP client used to
    execute the model's function calls against the backend service.
    """

    # aiohttp session used to issue the function-call HTTP requests
    client: ClientSession
    # Vertex AI multi-turn chat session bound to a GenerativeModel
    chat: ChatSession

    def __init__(self, client: ClientSession, chat: ChatSession):
        self.client = client
        self.chat = chat

    @classmethod
    def initialize_chat_model(
        cls, client: ClientSession, model: str
    ) -> "UserChatModel":
        """Build a UserChatModel whose chat session has the assistant's
        function-calling tools attached.

        Args:
            client: aiohttp session used for function-call requests.
            model: Vertex AI model name passed to GenerativeModel.
        """
        chat_model = GenerativeModel(model, tools=[assistant_tool()])
        function_calling_session = chat_model.start_chat()
        return UserChatModel(client, function_calling_session)

    async def close(self):
        """Close the underlying aiohttp client session."""
        await self.client.close()

    async def invoke(self, input_prompt: str) -> Dict[str, Any]:
        """Send the user's prompt to the model, resolving any function calls
        it requests, and return {"output": <final text>}.

        Raises:
            HTTPException: 500 if the final model response has no text part.
        """
        prompt = self.get_prompt()
        model_response = self.request_chat_model(prompt + input_prompt)
        self.debug_log(f"Prompt:\n{prompt}.\nQuestion: {input_prompt}.")
        self.debug_log(f"Function call response:\n{model_response}")
        part_response = model_response.candidates[0].content.parts[0]

        # Multi-turn loop: while the model keeps asking for a function call,
        # execute it over HTTP and feed the result back as a function
        # response. NOTE(review): relies on the private `_raw_part` /
        # `function_call._pb` attributes of the preview SDK — fragile across
        # SDK upgrades.
        while "function_call" in part_response._raw_part:
            function_call = MessageToDict(part_response.function_call._pb)
            function_response = await self.request_function(function_call)
            self.debug_log(f"Function response:\n{function_response}")
            part = Part.from_function_response(
                name=function_call["name"],
                response={
                    "content": function_response,
                },
            )
            model_response = self.request_chat_model(part)
            part_response = model_response.candidates[0].content.parts[0]

        if "text" in part_response._raw_part:
            content = part_response.text
            self.debug_log(f"Output content: {content}")
            return {"output": content}
        else:
            raise HTTPException(
                status_code=500, detail="Error: Chat model response unknown"
            )

    def get_prompt(self) -> str:
        """Return the system prompt (PREFIX) with today's date appended."""
        today_date = date.today().strftime("%Y-%m-%d")
        prompt = f"{PREFIX}. Today is {today_date}."
        return prompt

    def debug_log(self, output: str) -> None:
        # Print only when the module-level DEBUG env flag is set.
        if DEBUG:
            print(output)

    def request_chat_model(self, prompt: str):
        """Send one message to the chat session, mapping any SDK failure to
        an HTTP 500.

        NOTE(review): despite the `str` annotation, the multi-turn loop also
        passes a `Part` here; send_message accepts both.
        """
        try:
            model_response = self.chat.send_message(prompt)
        except Exception as err:
            raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}")
        return model_response

    async def request_function(self, function_call):
        """Execute one model-requested function call as a GET against the
        backend service and return the decoded JSON body.
        """
        url = function_request(function_call["name"])
        params = function_call["args"]
        self.debug_log(f"Function url is {url}.\nParams is {params}.")
        response = await self.client.get(
            url=f"{BASE_URL}/{url}",
            params=params,
            headers=get_headers(self.client),
        )
        response = await response.json()
        return response


class FunctionCallingOrchestrator(BaseOrchestrator):
    """Orchestrator backed by Vertex AI function calling.

    Maps a session uuid to a UserChatModel; all per-user aiohttp sessions
    share one lazily-created TCPConnector.
    """

    # uuid -> per-user chat model
    _user_sessions: Dict[str, UserChatModel]
    # aiohttp context: shared connector, created on first use
    connector = None

    def __init__(self):
        self._user_sessions = {}

    @classproperty
    def kind(cls):
        return "vertexai-function-calling"

    def user_session_exist(self, uuid: str) -> bool:
        return uuid in self._user_sessions

    async def user_session_create(self, session: dict[str, Any]):
        """Create and load an agent executor with tools and LLM."""
        print("Initializing agent..")
        if "uuid" not in session:
            session["uuid"] = str(uuid.uuid4())
        id = session["uuid"]
        if "history" not in session:
            session["history"] = [BASE_HISTORY]
        client = await self.create_client_session()
        chat = UserChatModel.initialize_chat_model(client, self.MODEL)
        self._user_sessions[id] = chat

    async def user_session_invoke(self, uuid: str, prompt: str) -> str:
        """Send the prompt to the user's chat model and return its text."""
        user_session = self.get_user_session(uuid)
        # Send prompt to LLM
        response = await user_session.invoke(prompt)
        return response["output"]

    async def user_session_reset(self, uuid: str):
        """Close and discard the chat session for `uuid`.

        Bug fix: the original `del user_session` only unbound the local
        variable, leaving the closed session in `_user_sessions` so
        user_session_exist() still returned True for a dead client. Remove
        the mapping entry instead.
        """
        user_session = self.get_user_session(uuid)
        await user_session.close()
        del self._user_sessions[uuid]

    def get_user_session(self, uuid: str) -> UserChatModel:
        # Raises KeyError if no session exists for this uuid.
        return self._user_sessions[uuid]

    async def get_connector(self) -> TCPConnector:
        """Return the shared TCPConnector, creating it on first use."""
        if self.connector is None:
            self.connector = TCPConnector(limit=100)
        return self.connector

    async def create_client_session(self) -> ClientSession:
        """Create a per-user HTTP session on the shared connector."""
        return ClientSession(
            connector=await self.get_connector(),
            # The connector is shared — closing one session must not close it.
            connector_owner=False,
            headers={},
            raise_for_status=True,
        )

    def close_clients(self):
        """Schedule closing of every user's HTTP session.

        NOTE(review): the gather() future is never awaited, so shutdown does
        not wait for the close tasks to finish — confirm this is intended.
        """
        close_client_tasks = [
            asyncio.create_task(a.close()) for a in self._user_sessions.values()
        ]
        asyncio.gather(*close_client_tasks)


# System prompt; get_prompt() appends today's date before each user message.
# Fix: corrected the typo "ammenities" -> "amenities" in the final sentence.
PREFIX = """The Cymbal Air Customer Service Assistant helps customers of Cymbal Air with their travel needs.

Cymbal Air is a passenger airline offering convenient flights to many cities around the world from its
hub in San Francisco. Cymbal Air takes pride in using the latest technology to offer the best customer
service!

Cymbal Air Customer Service Assistant (or just "Assistant" for short) is designed to assist
with a wide range of tasks, from answering simple questions to complex multi-query questions that
require passing results from one query to another. Using the latest AI models, Assistant is able to
generate human-like text based on the input it receives, allowing it to engage in natural-sounding
conversations and provide responses that are coherent and relevant to the topic at hand.

Assistant is a powerful tool that can help answer a wide range of questions pertaining to travel on Cymbal Air
as well as amenities of San Francisco Airport.
"""
Loading