From 82a01f873f5bfc7617912d7543b88c87777bcfce Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Mon, 12 Feb 2024 11:56:58 -0800 Subject: [PATCH 01/18] update Dockerfile --- llm_demo/int.tests.cloudbuild.yaml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/llm_demo/int.tests.cloudbuild.yaml b/llm_demo/int.tests.cloudbuild.yaml index 7af70c2a..d6d162e3 100644 --- a/llm_demo/int.tests.cloudbuild.yaml +++ b/llm_demo/int.tests.cloudbuild.yaml @@ -12,6 +12,16 @@ # limitations under the License. steps: + - id: "Update config" + name: python:3.11 + dir: langchain_tools_demo + entrypoint: /bin/bash + args: + - "-c" + - | + # Create config + cp example-config.yml config.yml + - id: "Deploy to Cloud Run" name: "gcr.io/cloud-builders/gcloud:latest" dir: llm_demo From 08d8d31dca9f0b0b17f6fec31afaefc2cada7db5 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Tue, 13 Feb 2024 14:20:51 -0800 Subject: [PATCH 02/18] update config param --- llm_demo/app.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/llm_demo/app.py b/llm_demo/app.py index b0298b28..b5ed5a80 100644 --- a/llm_demo/app.py +++ b/llm_demo/app.py @@ -24,6 +24,8 @@ from google.auth.transport import requests # type:ignore from google.oauth2 import id_token # type:ignore from markdown import markdown +from piny import StrictMatcher, YamlLoader # type: ignore +from pydantic import BaseModel from starlette.middleware.sessions import SessionMiddleware from orchestrator import BaseOrchestrator, createOrchestrator @@ -32,6 +34,20 @@ templates = Jinja2Templates(directory="templates") +class AppConfig(BaseModel): + host: IPv4Address | IPv6Address = IPv4Address("0.0.0.0") + port: int = 8081 + clientId: Optional[str] = None + # TODO: Add this at the next PR when Orchestration interface is created + # orchestration: orchestration.Config + + +def parse_config(path: str) -> AppConfig: + with open(path, "r") as file: + config = YamlLoader(path=path, matcher=StrictMatcher).load() + return AppConfig(**config) + + @asynccontextmanager async def lifespan(app: FastAPI): # FastAPI app startup event From c564134c5340dde8bf28d8e382614fab9a2d506c Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Mon, 12 Feb 2024 22:11:52 -0800 Subject: [PATCH 03/18] add orchestration interface --- llm_demo/app.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 64 insertions(+), 3 deletions(-) diff --git a/llm_demo/app.py b/llm_demo/app.py index b5ed5a80..f41c1c90 100644 --- a/llm_demo/app.py +++ b/llm_demo/app.py @@ -29,7 +29,6 @@ from starlette.middleware.sessions import SessionMiddleware from orchestrator import BaseOrchestrator, createOrchestrator - routes = APIRouter() templates = Jinja2Templates(directory="templates") @@ -38,8 +37,7 @@ class AppConfig(BaseModel): host: IPv4Address | IPv6Address = IPv4Address("0.0.0.0") port: int = 8081 clientId: Optional[str] = None - # TODO: Add this at the next PR when Orchestration interface is created - # orchestration: orchestration.Config + orchestration: Optional[str] def parse_config(path: str) -> AppConfig: @@ -54,18 +52,34 @@ async def lifespan(app: FastAPI): print("Loading application...") yield # FastAPI app shutdown event +<<<<<<< HEAD:llm_demo/app.py app.state.orchestration_type.close_clients() +======= + close_client_tasks = [asyncio.create_task(a.close()) for a in ais.values()] + + asyncio.gather(*close_client_tasks) +>>>>>>> 30d0d92 (add orchestration interface):langchain_tools_demo/main.py @routes.get("/") @routes.post("/") async def index(request: Request): """Render the default template.""" +<<<<<<< HEAD:llm_demo/app.py # User session setup orchestrator = request.app.state.orchestration_type session = request.session if "uuid" not in session or not orchestrator.user_session_exist(session["uuid"]): await orchestrator.user_session_create(session) +======= + # Agent setup + agent = await get_agent( + request.session, + user_id_token=None, + orchestration=request.app.state.orchestration, + ) + templates = Jinja2Templates(directory="templates") +>>>>>>> 30d0d92 (add orchestration interface):langchain_tools_demo/main.py return templates.TemplateResponse( "index.html", { @@ -91,8 +105,12 @@ async def login_google( user_name = get_user_name(str(user_id_token), client_id) # create new request session +<<<<<<< HEAD:llm_demo/app.py orchestrator = request.app.state.orchestration_type orchestrator.set_user_session_header(request.session["uuid"], str(user_id_token)) +======= + _ = await get_agent(request.session, str(user_id_token), orchestration=None) +>>>>>>> 30d0d92 (add orchestration interface):langchain_tools_demo/main.py print("Logged in to Google.") welcome_text = f"Welcome to Cymbal Air, {user_name}! How may I assist you?" @@ -124,11 +142,45 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)): # Add user message to chat history request.session["history"].append({"type": "human", "data": {"content": prompt}}) +<<<<<<< HEAD:llm_demo/app.py orchestrator = request.app.state.orchestration_type output = await orchestrator.user_session_invoke(request.session["uuid"], prompt) # Return assistant response request.session["history"].append({"type": "ai", "data": {"content": output}}) return markdown(output) +======= + ai = await get_agent(request.session, user_id_token=None, orchestration=None) + try: + print(prompt) + # Send prompt to LLM + response = await ai.invoke(prompt) + # Return assistant response + request.session["history"].append( + {"type": "ai", "data": {"content": response["output"]}} + ) + return markdown(response["output"]) + except Exception as err: + raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}") + + +async def get_agent( + session: dict[str, Any], user_id_token: Optional[str], orchestration: Optional[str] +): + global ais + if "uuid" not in session: + session["uuid"] = str(uuid.uuid4()) + id = session["uuid"] + if "history" not in session: + session["history"] = BASE_HISTORY + if id not in ais: + if not orchestration: + raise HTTPException(status_code=500, detail="orchestration not provided.") + ais[id] = await create(orchestration, session["history"]) + ai = ais[id] + if user_id_token is not None: + ai.client.headers["User-Id-Token"] = f"Bearer {user_id_token}" + return ai +>>>>>>> 30d0d92 (add orchestration interface):langchain_tools_demo/main.py @routes.post("/reset") @@ -139,11 +191,20 @@ async def reset(request: Request): raise HTTPException(status_code=400, detail=f"No session to reset.") uuid = request.session["uuid"] +<<<<<<< HEAD:llm_demo/app.py orchestrator = request.app.state.orchestration_type if not orchestrator.user_session_exist(uuid): raise HTTPException(status_code=500, detail=f"Current user session not found") await orchestrator.user_session_reset(uuid) +======= + global ais + if uuid not in ais.keys(): + raise HTTPException(status_code=500, detail=f"Current agent not found") + + await ais[uuid].client.close() + del ais[uuid] +>>>>>>> 30d0d92 (add orchestration interface):langchain_tools_demo/main.py request.session.clear() From 87bfd6390567162c45a5e5e0abbf01f063429fd4 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Tue, 13 Feb 2024 21:14:49 -0800 Subject: [PATCH 04/18] update interface and resolve comments --- llm_demo/app.py | 83 +------------------ llm_demo/int.tests.cloudbuild.yaml | 10 --- llm_demo/orchestrator/__init__.py | 6 ++ .../orchestrator/langchain_tools/__init__.py | 4 + .../langchain_tools_orchestrator.py | 33 ++++++++ 5 files changed, 46 insertions(+), 90 deletions(-) diff --git a/llm_demo/app.py b/llm_demo/app.py index f41c1c90..0d125465 100644 --- a/llm_demo/app.py +++ b/llm_demo/app.py @@ -24,62 +24,32 @@ from google.auth.transport import requests # type:ignore from google.oauth2 import id_token # type:ignore from markdown import markdown -from piny import StrictMatcher, YamlLoader # type: ignore -from pydantic import BaseModel from starlette.middleware.sessions import SessionMiddleware from orchestrator import BaseOrchestrator, createOrchestrator + routes = APIRouter() templates = Jinja2Templates(directory="templates") -class AppConfig(BaseModel): - host: IPv4Address | IPv6Address = IPv4Address("0.0.0.0") - port: int = 8081 - clientId: Optional[str] = None - orchestration: Optional[str] - - -def parse_config(path: str) -> AppConfig: - with open(path, "r") as file: - config = YamlLoader(path=path, matcher=StrictMatcher).load() - return AppConfig(**config) - - @asynccontextmanager async def lifespan(app: FastAPI): # FastAPI app startup event print("Loading application...") yield # FastAPI app shutdown event -<<<<<<< HEAD:llm_demo/app.py - app.state.orchestration_type.close_clients() -======= - close_client_tasks = [asyncio.create_task(a.close()) for a in ais.values()] - - asyncio.gather(*close_client_tasks) ->>>>>>> 30d0d92 (add orchestration interface):langchain_tools_demo/main.py + app.state.orchestrator.close_clients() @routes.get("/") @routes.post("/") async def index(request: Request): """Render the default template.""" -<<<<<<< HEAD:llm_demo/app.py # User session setup orchestrator = request.app.state.orchestration_type session = request.session if "uuid" not in session or not orchestrator.user_session_exist(session["uuid"]): await orchestrator.user_session_create(session) -======= - # Agent setup - agent = await get_agent( - request.session, - user_id_token=None, - orchestration=request.app.state.orchestration, - ) - templates = Jinja2Templates(directory="templates") ->>>>>>> 30d0d92 (add orchestration interface):langchain_tools_demo/main.py return templates.TemplateResponse( "index.html", { @@ -105,12 +75,8 @@ async def login_google( user_name = get_user_name(str(user_id_token), client_id) # create new request session -<<<<<<< HEAD:llm_demo/app.py orchestrator = request.app.state.orchestration_type orchestrator.set_user_session_header(request.session["uuid"], str(user_id_token)) -======= - _ = await get_agent(request.session, str(user_id_token), orchestration=None) ->>>>>>> 30d0d92 (add orchestration interface):langchain_tools_demo/main.py print("Logged in to Google.") welcome_text = f"Welcome to Cymbal Air, {user_name}! How may I assist you?" @@ -142,45 +108,11 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)): # Add user message to chat history request.session["history"].append({"type": "human", "data": {"content": prompt}}) -<<<<<<< HEAD:llm_demo/app.py orchestrator = request.app.state.orchestration_type output = await orchestrator.user_session_invoke(request.session["uuid"], prompt) # Return assistant response request.session["history"].append({"type": "ai", "data": {"content": output}}) return markdown(output) -======= - ai = await get_agent(request.session, user_id_token=None, orchestration=None) - try: - print(prompt) - # Send prompt to LLM - response = await ai.invoke(prompt) - # Return assistant response - request.session["history"].append( - {"type": "ai", "data": {"content": response["output"]}} - ) - return markdown(response["output"]) - except Exception as err: - raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}") - - -async def get_agent( - session: dict[str, Any], user_id_token: Optional[str], orchestration: Optional[str] -): - global ais - if "uuid" not in session: - session["uuid"] = str(uuid.uuid4()) - id = session["uuid"] - if "history" not in session: - session["history"] = BASE_HISTORY - if id not in ais: - if not orchestration: - raise HTTPException(status_code=500, detail="orchestration not provided.") - ais[id] = await create(orchestration, session["history"]) - ai = ais[id] - if user_id_token is not None: - ai.client.headers["User-Id-Token"] = f"Bearer {user_id_token}" - return ai ->>>>>>> 30d0d92 (add orchestration interface):langchain_tools_demo/main.py @routes.post("/reset") @@ -191,20 +123,11 @@ async def reset(request: Request): raise HTTPException(status_code=400, detail=f"No session to reset.") uuid = request.session["uuid"] -<<<<<<< HEAD:llm_demo/app.py orchestrator = request.app.state.orchestration_type if not orchestrator.user_session_exist(uuid): raise HTTPException(status_code=500, detail=f"Current user session not found") await orchestrator.user_session_reset(uuid) -======= - global ais - if uuid not in ais.keys(): - raise HTTPException(status_code=500, detail=f"Current agent not found") - - await ais[uuid].client.close() - del ais[uuid] ->>>>>>> 30d0d92 (add orchestration interface):langchain_tools_demo/main.py request.session.clear() @@ -225,7 +148,7 @@ def init_app( raise HTTPException(status_code=500, detail="Orchestrator not found") app = FastAPI(lifespan=lifespan) app.state.client_id = client_id - app.state.orchestration_type = createOrchestrator(orchestration_type) + app.state.orchestrator = BaseOrchestrator.create(orchestrator) app.include_router(routes) app.mount("/static", StaticFiles(directory="static"), name="static") app.add_middleware(SessionMiddleware, secret_key=secret_key) diff --git a/llm_demo/int.tests.cloudbuild.yaml b/llm_demo/int.tests.cloudbuild.yaml index d6d162e3..7af70c2a 100644 --- a/llm_demo/int.tests.cloudbuild.yaml +++ b/llm_demo/int.tests.cloudbuild.yaml @@ -12,16 +12,6 @@ # limitations under the License. steps: - - id: "Update config" - name: python:3.11 - dir: langchain_tools_demo - entrypoint: /bin/bash - args: - - "-c" - - | - # Create config - cp example-config.yml config.yml - - id: "Deploy to Cloud Run" name: "gcr.io/cloud-builders/gcloud:latest" dir: llm_demo diff --git a/llm_demo/orchestrator/__init__.py b/llm_demo/orchestrator/__init__.py index 35b97a74..a19d7705 100644 --- a/llm_demo/orchestrator/__init__.py +++ b/llm_demo/orchestrator/__init__.py @@ -13,6 +13,12 @@ # limitations under the License. from . import langchain_tools +<<<<<<<< HEAD:llm_demo/orchestrator/__init__.py from .orchestrator import BaseOrchestrator, createOrchestrator __ALL__ = ["BaseOrchestrator", "createOrchestrator", "langchain_tools"] +======== +from .orchestrator import BaseOrchestrator + +__ALL__ = [BaseOrchestrator, langchain_tools] +>>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/__init__.py diff --git a/llm_demo/orchestrator/langchain_tools/__init__.py b/llm_demo/orchestrator/langchain_tools/__init__.py index 0633529c..cfefd3fc 100644 --- a/llm_demo/orchestrator/langchain_tools/__init__.py +++ b/llm_demo/orchestrator/langchain_tools/__init__.py @@ -14,4 +14,8 @@ from .langchain_tools_orchestrator import LangChainToolsOrchestrator +<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/__init__.py __ALL__ = ["LangChainToolsOrchestrator"] +======== +__ALL__ = [LangChainToolsOrchestrator] +>>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/__init__.py diff --git a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py index be147600..7dca74af 100644 --- a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py +++ b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py @@ -18,7 +18,11 @@ from datetime import date from typing import Any, Dict, List, Optional +<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py from aiohttp import ClientSession, TCPConnector +======== +from aiohttp import ClientSession +>>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py from fastapi import HTTPException from langchain.agents import AgentType, initialize_agent from langchain.agents.agent import AgentExecutor @@ -33,11 +37,14 @@ from .tools import initialize_tools set_verbose(bool(os.getenv("DEBUG", default=False))) +<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py MODEL = "gemini-pro" BASE_HISTORY = { "type": "ai", "data": {"content": "Welcome to Cymbal Air! How may I assist you?"}, } +======== +>>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py class UserAgent: @@ -56,7 +63,11 @@ def initialize_agent( history: List[BaseMessage], prompt: ChatPromptTemplate, ) -> "UserAgent": +<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py llm = VertexAI(max_output_tokens=512, model_name=MODEL) +======== + llm = VertexAI(max_output_tokens=512, model_name="gemini-pro") +>>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py memory = ConversationBufferMemory( chat_memory=ChatMessageHistory(messages=history), memory_key="chat_history", @@ -80,15 +91,23 @@ async def close(self): await self.client.close() async def invoke(self, prompt: str) -> Dict[str, Any]: +<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py try: response = await self.agent.ainvoke({"input": prompt}) except Exception as err: raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}") +======== + response = await self.agent.ainvoke({"input": prompt}) +>>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py return response class LangChainToolsOrchestrator(BaseOrchestrator): +<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py _user_sessions: Dict[str, UserAgent] = {} +======== + ais: Dict[str, UserAgent] = {} +>>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py # aiohttp context connector = None @@ -96,6 +115,7 @@ class LangChainToolsOrchestrator(BaseOrchestrator): def kind(cls): return "langchain-tools" +<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py def user_session_exist(self, uuid: str) -> bool: return uuid in self._user_sessions @@ -108,10 +128,17 @@ async def user_session_create(self, session: dict[str, Any]): if "history" not in session: session["history"] = [BASE_HISTORY] history = self.parse_messages(session["history"]) +======== + async def create_ai(self, base_history: List[Any]) -> UserAgent: + """Create and load an agent executor with tools and LLM.""" + print("Initializing agent..") + history = self.parse_messages(base_history) +>>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py client = await self.create_client_session() tools = await initialize_tools(client) prompt = self.create_prompt_template(tools) agent = UserAgent.initialize_agent(client, tools, history, prompt) +<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py self._user_sessions[id] = agent async def user_session_invoke(self, uuid: str, prompt: str) -> str: @@ -140,6 +167,9 @@ async def create_client_session(self) -> ClientSession: headers={}, raise_for_status=True, ) +======== + return agent +>>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py def create_prompt_template(self, tools: List[StructuredTool]) -> ChatPromptTemplate: # Create new prompt template @@ -172,12 +202,15 @@ def parse_messages(self, datas: List[Any]) -> List[BaseMessage]: raise Exception("Message type not found.") return messages +<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py def close_clients(self): close_client_tasks = [ asyncio.create_task(a.close()) for a in self._user_sessions.values() ] asyncio.gather(*close_client_tasks) +======== +>>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py PREFIX = """The Cymbal Air Customer Service Assistant helps customers of Cymbal Air with their travel needs. From 15de72ed75a6204b5588c35743ff5307b6006197 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Tue, 20 Feb 2024 20:40:41 -0800 Subject: [PATCH 05/18] update interface --- llm_demo/app.py | 30 ++++++++++++++++- llm_demo/orchestrator/__init__.py | 6 ---- .../langchain_tools_orchestrator.py | 33 ------------------- 3 files changed, 29 insertions(+), 40 deletions(-) diff --git a/llm_demo/app.py b/llm_demo/app.py index 0d125465..d65a8ab1 100644 --- a/llm_demo/app.py +++ b/llm_demo/app.py @@ -45,11 +45,18 @@ async def lifespan(app: FastAPI): @routes.post("/") async def index(request: Request): """Render the default template.""" +<<<<<<< HEAD:llm_demo/app.py # User session setup orchestrator = request.app.state.orchestration_type session = request.session if "uuid" not in session or not orchestrator.user_session_exist(session["uuid"]): await orchestrator.user_session_create(session) +======= + # Agent setup + orchestrator = request.app.state.orchestrator + await orchestrator.user_session_create(request.session) + templates = Jinja2Templates(directory="templates") +>>>>>>> 493e7c5 (update interface):langchain_tools_demo/main.py return templates.TemplateResponse( "index.html", { @@ -75,7 +82,11 @@ async def login_google( user_name = get_user_name(str(user_id_token), client_id) # create new request session +<<<<<<< HEAD:llm_demo/app.py orchestrator = request.app.state.orchestration_type +======= + orchestrator = request.app.state.orchestrator +>>>>>>> 493e7c5 (update interface):langchain_tools_demo/main.py orchestrator.set_user_session_header(request.session["uuid"], str(user_id_token)) print("Logged in to Google.") @@ -108,7 +119,11 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)): # Add user message to chat history request.session["history"].append({"type": "human", "data": {"content": prompt}}) +<<<<<<< HEAD:llm_demo/app.py orchestrator = request.app.state.orchestration_type +======= + orchestrator = request.app.state.orchestrator +>>>>>>> 493e7c5 (update interface):langchain_tools_demo/main.py output = await orchestrator.user_session_invoke(request.session["uuid"], prompt) # Return assistant response request.session["history"].append({"type": "ai", "data": {"content": output}}) @@ -123,9 +138,15 @@ async def reset(request: Request): raise HTTPException(status_code=400, detail=f"No session to reset.") uuid = request.session["uuid"] +<<<<<<< HEAD:llm_demo/app.py orchestrator = request.app.state.orchestration_type if not orchestrator.user_session_exist(uuid): raise HTTPException(status_code=500, detail=f"Current user session not found") +======= + orchestrator = request.app.state.orchestrator + if not orchestrator.user_session_exist(uuid): + raise HTTPException(status_code=500, detail=f"Current agent not found") +>>>>>>> 493e7c5 (update interface):langchain_tools_demo/main.py await orchestrator.user_session_reset(uuid) request.session.clear() @@ -139,16 +160,23 @@ def get_user_name(user_token_id: str, client_id: str) -> str: def init_app( +<<<<<<< HEAD:llm_demo/app.py orchestration_type: Optional[str], client_id: Optional[str], secret_key: Optional[str], ) -> FastAPI: # FastAPI setup if orchestration_type is None: +======= + orchestrator: Optional[str], client_id: Optional[str], secret_key: Optional[str] +) -> FastAPI: + # FastAPI setup + if orchestrator is None: +>>>>>>> 493e7c5 (update interface):langchain_tools_demo/main.py raise HTTPException(status_code=500, detail="Orchestrator not found") app = FastAPI(lifespan=lifespan) app.state.client_id = client_id - app.state.orchestrator = BaseOrchestrator.create(orchestrator) + app.state.orchestrator = createOrchestrator(orchestrator) app.include_router(routes) app.mount("/static", StaticFiles(directory="static"), name="static") app.add_middleware(SessionMiddleware, secret_key=secret_key) diff --git a/llm_demo/orchestrator/__init__.py b/llm_demo/orchestrator/__init__.py index a19d7705..35b97a74 100644 --- a/llm_demo/orchestrator/__init__.py +++ b/llm_demo/orchestrator/__init__.py @@ -13,12 +13,6 @@ # limitations under the License. from . import langchain_tools -<<<<<<<< HEAD:llm_demo/orchestrator/__init__.py from .orchestrator import BaseOrchestrator, createOrchestrator __ALL__ = ["BaseOrchestrator", "createOrchestrator", "langchain_tools"] -======== -from .orchestrator import BaseOrchestrator - -__ALL__ = [BaseOrchestrator, langchain_tools] ->>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/__init__.py diff --git a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py index 7dca74af..be147600 100644 --- a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py +++ b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py @@ -18,11 +18,7 @@ from datetime import date from typing import Any, Dict, List, Optional -<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py from aiohttp import ClientSession, TCPConnector -======== -from aiohttp import ClientSession ->>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py from fastapi import HTTPException from langchain.agents import AgentType, initialize_agent from langchain.agents.agent import AgentExecutor @@ -37,14 +33,11 @@ from .tools import initialize_tools set_verbose(bool(os.getenv("DEBUG", default=False))) -<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py MODEL = "gemini-pro" BASE_HISTORY = { "type": "ai", "data": {"content": "Welcome to Cymbal Air! How may I assist you?"}, } -======== ->>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py class UserAgent: @@ -63,11 +56,7 @@ def initialize_agent( history: List[BaseMessage], prompt: ChatPromptTemplate, ) -> "UserAgent": -<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py llm = VertexAI(max_output_tokens=512, model_name=MODEL) -======== - llm = VertexAI(max_output_tokens=512, model_name="gemini-pro") ->>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py memory = ConversationBufferMemory( chat_memory=ChatMessageHistory(messages=history), memory_key="chat_history", @@ -91,23 +80,15 @@ async def close(self): await self.client.close() async def invoke(self, prompt: str) -> Dict[str, Any]: -<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py try: response = await self.agent.ainvoke({"input": prompt}) except Exception as err: raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}") -======== - response = await self.agent.ainvoke({"input": prompt}) ->>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py return response class LangChainToolsOrchestrator(BaseOrchestrator): -<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py _user_sessions: Dict[str, UserAgent] = {} -======== - ais: Dict[str, UserAgent] = {} ->>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py # aiohttp context connector = None @@ -115,7 +96,6 @@ class LangChainToolsOrchestrator(BaseOrchestrator): def kind(cls): return "langchain-tools" -<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py def user_session_exist(self, uuid: str) -> bool: return uuid in self._user_sessions @@ -128,17 +108,10 @@ async def user_session_create(self, session: dict[str, Any]): if "history" not in session: session["history"] = [BASE_HISTORY] history = self.parse_messages(session["history"]) -======== - async def create_ai(self, base_history: List[Any]) -> UserAgent: - """Create and load an agent executor with tools and LLM.""" - print("Initializing agent..") - history = self.parse_messages(base_history) ->>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py client = await self.create_client_session() tools = await initialize_tools(client) prompt = self.create_prompt_template(tools) agent = UserAgent.initialize_agent(client, tools, history, prompt) -<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py self._user_sessions[id] = agent async def user_session_invoke(self, uuid: str, prompt: str) -> str: @@ -167,9 +140,6 @@ async def create_client_session(self) -> ClientSession: headers={}, raise_for_status=True, ) -======== - return agent ->>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py def create_prompt_template(self, tools: List[StructuredTool]) -> ChatPromptTemplate: # Create new prompt template @@ -202,15 +172,12 @@ def parse_messages(self, datas: List[Any]) -> List[BaseMessage]: raise Exception("Message type not found.") return messages -<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py def close_clients(self): close_client_tasks = [ asyncio.create_task(a.close()) for a in self._user_sessions.values() ] asyncio.gather(*close_client_tasks) -======== ->>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py PREFIX = """The Cymbal Air Customer Service Assistant helps customers of Cymbal Air with their travel needs. From 98010234bd2aee0de5b17478a84c0f98551bd264 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Sun, 11 Feb 2024 23:43:31 -0800 Subject: [PATCH 06/18] update app in langchain demo --- llm_demo/app.py | 6 ++++++ llm_demo/example-config.yml | 4 ++++ llm_demo/run_app.py | 1 - 3 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 llm_demo/example-config.yml diff --git a/llm_demo/app.py b/llm_demo/app.py index d65a8ab1..b311ff24 100644 --- a/llm_demo/app.py +++ b/llm_demo/app.py @@ -31,6 +31,12 @@ routes = APIRouter() templates = Jinja2Templates(directory="templates") +BASE_HISTORY: list[BaseMessage] = [ + AIMessage(content="I am an SFO Airport Assistant, ready to assist you.") +] +CLIENT_ID = os.getenv("CLIENT_ID") +routes = APIRouter() + @asynccontextmanager async def lifespan(app: FastAPI): diff --git a/llm_demo/example-config.yml b/llm_demo/example-config.yml new file mode 100644 index 00000000..24465a79 --- /dev/null +++ b/llm_demo/example-config.yml @@ -0,0 +1,4 @@ +host: 0.0.0.0 +port: 8081 +orchestration: + kind: ${LLM_FRAMEWORK} diff --git a/llm_demo/run_app.py b/llm_demo/run_app.py index 7fe1e5c9..7be2d849 100644 --- a/llm_demo/run_app.py +++ b/llm_demo/run_app.py @@ -15,7 +15,6 @@ import asyncio import os - import uvicorn from app import init_app From 1d19349387c156ec2e5edaf4e5e81076c811d3dc Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Mon, 12 Feb 2024 11:43:37 -0800 Subject: [PATCH 07/18] update requirement --- llm_demo/requirements.txt | 4 ++++ llm_demo/run_app.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/llm_demo/requirements.txt b/llm_demo/requirements.txt index 2dcb24d1..05095b87 100644 --- a/llm_demo/requirements.txt +++ b/llm_demo/requirements.txt @@ -9,3 +9,7 @@ markdown==3.5.2 types-Markdown==3.5.0.20240129 uvicorn[standard]==0.27.0.post1 python-multipart==0.0.7 +<<<<<<< HEAD +======= +piny==1.1.0 +>>>>>>> 94c687f (update requirement) diff --git a/llm_demo/run_app.py b/llm_demo/run_app.py index 7be2d849..f66d3777 100644 --- a/llm_demo/run_app.py +++ b/llm_demo/run_app.py @@ -16,6 +16,10 @@ import asyncio import os import uvicorn +<<<<<<< HEAD +======= +from piny import StrictMatcher, YamlLoader # type: ignore +>>>>>>> 94c687f (update requirement) from app import init_app From f8cc5ead769df0ac5564ade884756ef4d117d398 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Mon, 12 Feb 2024 11:56:58 -0800 Subject: [PATCH 08/18] update Dockerfile --- llm_demo/int.tests.cloudbuild.yaml | 10 ++++++++++ llm_demo/run_app.py | 4 ---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/llm_demo/int.tests.cloudbuild.yaml b/llm_demo/int.tests.cloudbuild.yaml index 7af70c2a..d6d162e3 100644 --- a/llm_demo/int.tests.cloudbuild.yaml +++ b/llm_demo/int.tests.cloudbuild.yaml @@ -12,6 +12,16 @@ # limitations under the License. steps: + - id: "Update config" + name: python:3.11 + dir: langchain_tools_demo + entrypoint: /bin/bash + args: + - "-c" + - | + # Create config + cp example-config.yml config.yml + - id: "Deploy to Cloud Run" name: "gcr.io/cloud-builders/gcloud:latest" dir: llm_demo diff --git a/llm_demo/run_app.py b/llm_demo/run_app.py index f66d3777..7be2d849 100644 --- a/llm_demo/run_app.py +++ b/llm_demo/run_app.py @@ -16,10 +16,6 @@ import asyncio import os import uvicorn -<<<<<<< HEAD -======= -from piny import StrictMatcher, YamlLoader # type: ignore ->>>>>>> 94c687f (update requirement) from app import init_app From 2ec70532c0b7c0fb48e49267ad81a077dc8e2da0 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Mon, 12 Feb 2024 22:11:52 -0800 Subject: [PATCH 09/18] add orchestration interface --- .../orchestration/__init__.py | 18 + .../orchestration/langchain_tools/__init__.py | 17 + .../orchestration/langchain_tools/agent.py | 169 +++++++++ .../orchestration/langchain_tools/tools.py | 357 ++++++++++++++++++ .../orchestration/orchestration.py | 82 ++++ llm_demo/app.py | 8 +- 6 files changed, 649 insertions(+), 2 deletions(-) create mode 100644 langchain_tools_demo/orchestration/__init__.py create mode 100644 langchain_tools_demo/orchestration/langchain_tools/__init__.py create mode 100644 langchain_tools_demo/orchestration/langchain_tools/agent.py create mode 100644 langchain_tools_demo/orchestration/langchain_tools/tools.py create mode 100644 langchain_tools_demo/orchestration/orchestration.py diff --git a/langchain_tools_demo/orchestration/__init__.py b/langchain_tools_demo/orchestration/__init__.py new file mode 100644 index 00000000..256e7d36 --- /dev/null +++ b/langchain_tools_demo/orchestration/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import langchain_tools +from .orchestration import Orchestration, ais, create + +__ALL__ = [Orchestration, create, ais, langchain_tools] diff --git a/langchain_tools_demo/orchestration/langchain_tools/__init__.py b/langchain_tools_demo/orchestration/langchain_tools/__init__.py new file mode 100644 index 00000000..5b965e2b --- /dev/null +++ b/langchain_tools_demo/orchestration/langchain_tools/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent + +__ALL__ = [agent] diff --git a/langchain_tools_demo/orchestration/langchain_tools/agent.py b/langchain_tools_demo/orchestration/langchain_tools/agent.py new file mode 100644 index 00000000..336fed2f --- /dev/null +++ b/langchain_tools_demo/orchestration/langchain_tools/agent.py @@ -0,0 +1,169 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from datetime import date +from typing import Any, Dict, List + +import aiohttp +from fastapi import HTTPException +from langchain.agents import AgentType, initialize_agent +from langchain.agents.agent import AgentExecutor +from langchain.globals import set_verbose # type: ignore +from langchain.memory import ChatMessageHistory, ConversationBufferMemory +from langchain.prompts.chat import ChatPromptTemplate +from langchain_core import messages +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_google_vertexai import VertexAI + +from .. import orchestration +from .tools import initialize_tools + +set_verbose(bool(os.getenv("DEBUG", default=False))) +BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") + +CLOUD_RUN_AUTHORIZATION_TOKEN = None + + +class Orchestration(orchestration.Orchestration): + client: aiohttp.ClientSession + agent: AgentExecutor + + @orchestration.classproperty + def kind(cls): + return "langchain-tools" + + def __init__(self, client: aiohttp.ClientSession, agent: AgentExecutor): + self.client = client + self.agent = agent + + @classmethod + async def create(cls, base_history: list[Any]) -> "Orchestration": + """Load an agent executor with tools and LLM""" + print("Initializing agent..") + history = Orchestration.parse_messages(base_history) + llm = VertexAI(max_output_tokens=512, model_name="gemini-pro") + memory = ConversationBufferMemory( + chat_memory=ChatMessageHistory(messages=history), + memory_key="chat_history", + input_key="input", + output_key="output", + ) + client = await Orchestration.create_client_session() + tools = await initialize_tools(client) + agent = initialize_agent( + tools, + llm, + agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, + memory=memory, + handle_parsing_errors=True, + max_iterations=3, + early_stopping_method="generate", + return_intermediate_steps=True, + ) + # Create new prompt template + tool_strings = "\n".join( + [f"> {tool.name}: {tool.description}" for tool in tools] + ) + tool_names = ", ".join([tool.name for tool in tools]) + format_instructions = FORMAT_INSTRUCTIONS.format( + tool_names=tool_names, + ) + today_date = date.today().strftime("%Y-%m-%d") + today = f"Today is {today_date}." + template = "\n\n".join( + [PREFIX, tool_strings, format_instructions, SUFFIX, today] + ) + human_message_template = "{input}\n\n{agent_scratchpad}" + prompt = ChatPromptTemplate.from_messages( + [("system", template), ("human", human_message_template)] + ) + agent.agent.llm_chain.prompt = prompt # type: ignore + + return Orchestration(client, agent) + + async def invoke(self, prompt: str) -> Dict[str, Any]: + response = await self.agent.ainvoke({"input": prompt}) + return response + + @staticmethod + def parse_messages(datas: List[Any]) -> List[BaseMessage]: + messages: List[BaseMessage] = [] + for data in datas: + if data["type"] == "human": + messages.append(HumanMessage(content=data["data"]["content"])) + if data["type"] == "ai": + messages.append(AIMessage(content=data["data"]["content"])) + return messages + + def close(self): + self.client.close() + + +PREFIX = """SFO Airport Assistant helps travelers find their way at the airport. + +Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to +complex multi-query questions that require passing results from one query to another. As a language model, Assistant is +able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding +conversations and provide responses that are coherent and relevant to the topic at hand. + +Overall, Assistant is a powerful tool that can help answer a wide range of questions pertaining to the San +Francisco Airport. SFO Airport Assistant is here to assist. It currently does not have access to user info. + +TOOLS: +------ + +Assistant has access to the following tools:""" + +FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) +and an action_input key (tool input). + +Valid "action" values: "Final Answer" or {tool_names} + +Provide only ONE action per $JSON_BLOB, as shown: + +``` +{{{{ + "action": $TOOL_NAME, + "action_input": $INPUT +}}}} +``` + +Follow this format: + +Question: input question to answer +Thought: consider previous and subsequent steps +Action: +``` +$JSON_BLOB +``` +Observation: action result +... (repeat Thought/Action/Observation N times) +Thought: I know what to respond +Action: +``` +{{{{ + "action": "Final Answer", + "action_input": "Final response to human" +}}}} +```""" + +SUFFIX = """Begin! Use tools if necessary. Respond directly if appropriate. +If using a tool, reminder to ALWAYS respond with a valid json blob of a single action. +Format is Action:```$JSON_BLOB```then Observation:. +Thought: + +Previous conversation history: +{chat_history} +""" diff --git a/langchain_tools_demo/orchestration/langchain_tools/tools.py b/langchain_tools_demo/orchestration/langchain_tools/tools.py new file mode 100644 index 00000000..52295065 --- /dev/null +++ b/langchain_tools_demo/orchestration/langchain_tools/tools.py @@ -0,0 +1,357 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from datetime import datetime +from typing import Optional + +import aiohttp +import google.oauth2.id_token # type: ignore +from google.auth import compute_engine # type: ignore +from google.auth.transport.requests import Request # type: ignore +from langchain.agents.agent import ExceptionTool # type: ignore +from langchain.tools import StructuredTool +from pydantic.v1 import BaseModel, Field + +BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") +CREDENTIALS = None + + +def filter_none_values(params: dict) -> dict: + return {key: value for key, value in params.items() if value is not None} + + +def get_id_token(): + global CREDENTIALS + if CREDENTIALS is None: + CREDENTIALS, _ = google.auth.default() + if not hasattr(CREDENTIALS, "id_token"): + # Use Compute Engine default credential + CREDENTIALS = compute_engine.IDTokenCredentials( + request=Request(), + target_audience=BASE_URL, + use_metadata_identity_endpoint=True, + ) + if not CREDENTIALS.valid: + CREDENTIALS.refresh(Request()) + if hasattr(CREDENTIALS, "id_token"): + return CREDENTIALS.id_token + else: + return CREDENTIALS.token + + +def get_headers(client: aiohttp.ClientSession): + """Helper method to generate ID tokens for authenticated requests""" + headers = client.headers + if not "http://" in BASE_URL: + # Append ID Token to make authenticated requests to Cloud Run services + headers["Authorization"] = f"Bearer {get_id_token()}" + return headers + + +# Tools +class AirportSearchInput(BaseModel): + country: Optional[str] = Field(description="Country") + city: Optional[str] = Field(description="City") + name: Optional[str] = Field(description="Airport name") + + +def generate_search_airports(client: aiohttp.ClientSession): + async def search_airports(country: str, city: str, name: str): + params = { + "country": country, + "city": city, + "name": name, + } + response = await client.get( + url=f"{BASE_URL}/airports/search", + params=filter_none_values(params), + headers=get_headers(client), + ) + + num = 2 + response_json = await response.json() + if len(response_json) < 1: + return "There are no airports matching that query. Let the user know there are no results." + elif len(response_json) > num: + return ( + f"There are {len(response_json)} airports matching that query. Here are the first {num} results:\n" + + " ".join([f"{response_json[i]}" for i in range(num)]) + ) + else: + return "\n".join([f"{r}" for r in response_json]) + + return search_airports + + +class FlightNumberInput(BaseModel): + airline: str = Field(description="Airline unique 2 letter identifier") + flight_number: str = Field(description="1 to 4 digit number") + + +def generate_search_flights_by_number(client: aiohttp.ClientSession): + async def search_flights_by_number(airline: str, flight_number: str): + response = await client.get( + url=f"{BASE_URL}/flights/search", + params={"airline": airline, "flight_number": flight_number}, + headers=get_headers(client), + ) + + return await response.json() + + return search_flights_by_number + + +class ListFlights(BaseModel): + departure_airport: Optional[str] = Field( + description="Departure airport 3-letter code", + ) + arrival_airport: Optional[str] = Field(description="Arrival airport 3-letter code") + date: Optional[str] = Field(description="Date of flight departure") + + +def generate_list_flights(client: aiohttp.ClientSession): + async def list_flights( + departure_airport: str, + arrival_airport: str, + date: str, + ): + params = { + "departure_airport": departure_airport, + "arrival_airport": arrival_airport, + "date": date, + } + response = await client.get( + url=f"{BASE_URL}/flights/search", + params=filter_none_values(params), + headers=get_headers(client), + ) + + num = 2 + response_json = await response.json() + if len(response_json) < 1: + return "There are no flights matching that query. Let the user know there are no results." + elif len(response_json) > num: + return ( + f"There are {len(response_json)} flights matching that query. Here are the first {num} results:\n" + + " ".join([f"{response_json[i]}" for i in range(num)]) + ) + else: + return "\n".join([f"{r}" for r in response_json]) + + return list_flights + + +class QueryInput(BaseModel): + query: str = Field(description="Search query") + + +def generate_search_amenities(client: aiohttp.ClientSession): + async def search_amenities(query: str): + response = await client.get( + url=f"{BASE_URL}/amenities/search", + params={"top_k": "5", "query": query}, + headers=get_headers(client), + ) + + response = await response.json() + return response + + return search_amenities + + +class TicketInput(BaseModel): + airline: str = Field(description="Airline unique 2 letter identifier") + flight_number: str = Field(description="1 to 4 digit number") + departure_airport: str = Field( + description="Departure airport 3-letter code", + ) + arrival_airport: str = Field(description="Arrival airport 3-letter code") + departure_time: datetime = Field(description="Flight departure datetime") + arrival_time: datetime = Field(description="Flight arrival datetime") + + +def generate_insert_ticket(client: aiohttp.ClientSession): + async def insert_ticket( + airline: str, + flight_number: str, + departure_airport: str, + arrival_airport: str, + departure_time: datetime, + arrival_time: datetime, + ): + response = await client.post( + url=f"{BASE_URL}/tickets/insert", + params={ + "airline": airline, + "flight_number": flight_number, + "departure_airport": departure_airport, + "arrival_airport": arrival_airport, + "departure_time": departure_time.strftime("%Y-%m-%d %H:%M:%S"), + "arrival_time": arrival_time.strftime("%Y-%m-%d %H:%M:%S"), + }, + headers=get_headers(client), + ) + + response = await response.json() + return response + + return insert_ticket + + +def generate_list_tickets(client: aiohttp.ClientSession): + async def list_tickets(): + response = await client.get( + url=f"{BASE_URL}/tickets/list", + headers=get_headers(client), + ) + + response = await response.json() + return response + + return list_tickets + + +# Tools for agent +async def initialize_tools(client: aiohttp.ClientSession): + return [ + StructuredTool.from_function( + coroutine=generate_search_airports(client), + name="Search Airport", + description=""" + Use this tool to list all airports matching search criteria. + Takes at least one of country, city, name, or all and returns all matching airports. + The agent can decide to return the results directly to the user. + Input of this tool must be in JSON format and include all three inputs - country, city, name. + Example: + {{ + "country": "United States", + "city": "San Francisco", + "name": null + }} + Example: + {{ + "country": null, + "city": "Goroka", + "name": "Goroka" + }} + Example: + {{ + "country": "Mexico", + "city": null, + "name": null + }} + """, + args_schema=AirportSearchInput, + ), + StructuredTool.from_function( + coroutine=generate_search_flights_by_number(client), + name="Search Flights By Flight Number", + description=""" + Use this tool to get info for a specific flight. Do NOT use this tool with a flight id. + Takes an airline and flight number and returns info on the flight. + Do NOT guess an airline or flight number. + A flight number is a code for an airline service consisting of two-character + airline designator and a 1 to 4 digit number ex. OO123, DL 1234, BA 405, AS 3452. + If the tool returns more than one option choose the date closes to today. + """, + args_schema=FlightNumberInput, + ), + StructuredTool.from_function( + coroutine=generate_list_flights(client), + name="List Flights", + description=""" + Use this tool to list all flights matching search criteria. + Takes an arrival airport, a departure airport, or both, filters by date and returns all matching flights. + The agent can decide to return the results directly to the user. + Input of this tool must be in JSON format and include all three inputs - arrival_airport, departure_airport, and date. + Example: + {{ + "departure_airport": "SFO", + "arrival_airport": null, + "date": null + }} + Example: + {{ + "departure_airport": "SFO", + "arrival_airport": "SEA", + "date": "2023-11-01" + }} + Example: + {{ + "departure_airport": null, + "arrival_airport": "SFO", + "date": "2023-01-01" + }} + """, + args_schema=ListFlights, + ), + StructuredTool.from_function( + coroutine=generate_search_amenities(client), + name="Search Amenities", + description=""" + Use this tool to search amenities by name or to recommended airport amenities at SFO. + If user provides flight info, use 'Get Flight' and 'Get Flights by Number' + first to get gate info and location. + Only recommend amenities that are returned by this query. + Find amenities close to the user by matching the terminal and then comparing + the gate numbers. Gate number iterate by letter and number, example A1 A2 A3 + B1 B2 B3 C1 C2 C3. Gate A3 is close to A2 and B1. + """, + args_schema=QueryInput, + ), + StructuredTool.from_function( + coroutine=generate_insert_ticket(client), + name="Insert Ticket", + description=""" + Use this tool to book a flight ticket for the user. + Example: + {{ + "airline": "AA", + "flight_number": "452", + "departure_airport": "LAX", + "arrival_airport": "SFO", + "departure_time": "2024-01-01 05:50:00", + "arrival_time": "2024-01-01 09:23:00" + }} + Example: + {{ + "airline": "UA", + "flight_number": "1532", + "departure_airport": "SFO", + "arrival_airport": "DEN", + "departure_time": "2024-01-08 05:50:00", + "arrival_time": "2024-01-08 09:23:00" + }} + Example: + {{ + "airline": "OO", + "flight_number": "6307", + "departure_airport": "SFO", + "arrival_airport": "MSP", + "departure_time": "2024-10-28 20:13:00", + "arrival_time": "2024-10-28 21:07:00" + }} + """, + args_schema=TicketInput, + ), + StructuredTool.from_function( + coroutine=generate_list_tickets(client), + name="List Tickets", + description=""" + Use this tool to list a user's flight tickets. + Takes no input and returns a list of current user's flight tickets. + """, + ), + ] diff --git a/langchain_tools_demo/orchestration/orchestration.py b/langchain_tools_demo/orchestration/orchestration.py new file mode 100644 index 00000000..87d77cad --- /dev/null +++ b/langchain_tools_demo/orchestration/orchestration.py @@ -0,0 +1,82 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, List, TypeVar + +import aiohttp + +# aiohttp context +connector = None + + +class classproperty: + def __init__(self, func): + self.fget = func + + def __get__(self, instance, owner): + return self.fget(owner) + + +class Orchestration(ABC): + client: aiohttp.ClientSession + + @classproperty + @abstractmethod + def kind(cls): + pass + + @classmethod + @abstractmethod + async def create(cls, history: List[Any]) -> "Orchestration": + pass + + async def invoke(self, prompt: str): + raise NotImplementedError("Subclass should implement this!") + + @staticmethod + async def get_connector(): + global connector + if connector is None: + connector = aiohttp.TCPConnector(limit=100) + return connector + + @staticmethod + async def handle_error_response(response): + if response.status != 200: + return f"Error sending {response.method} request to {str(response.url)}): {await response.text()}" + + @staticmethod + async def create_client_session() -> aiohttp.ClientSession: + return aiohttp.ClientSession( + connector=await Orchestration.get_connector(), + connector_owner=False, + headers={}, + raise_for_status=True, + ) + + def close(self): + raise NotImplementedError("Subclass should implement this!") + + +ais: Dict[str, Orchestration] = {} + + +async def create(config: Dict[str, str], history: List[Any]) -> Orchestration: + for cls in Orchestration.__subclasses__(): + s = f"{config['kind']} == {cls.kind}" + if config["kind"] == cls.kind: + return await cls.create(history) # type: ignore + raise TypeError(f"No orchestration of kind {config['kind']}") + return diff --git a/llm_demo/app.py b/llm_demo/app.py index b311ff24..dc6339fe 100644 --- a/llm_demo/app.py +++ b/llm_demo/app.py @@ -24,6 +24,7 @@ from google.auth.transport import requests # type:ignore from google.oauth2 import id_token # type:ignore from markdown import markdown +from piny import StrictMatcher, YamlLoader # type: ignore from starlette.middleware.sessions import SessionMiddleware from orchestrator import BaseOrchestrator, createOrchestrator @@ -31,8 +32,11 @@ routes = APIRouter() templates = Jinja2Templates(directory="templates") -BASE_HISTORY: list[BaseMessage] = [ - AIMessage(content="I am an SFO Airport Assistant, ready to assist you.") +BASE_HISTORY = [ + { + "type": "ai", + "data": {"content": "I am an SFO Airport Assistant, ready to assist you."}, + } ] CLIENT_ID = os.getenv("CLIENT_ID") routes = APIRouter() From 4e060e2ba5c0e3822e3f52e496ba5d75669a9599 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Tue, 13 Feb 2024 09:21:09 -0800 Subject: [PATCH 10/18] add function calling --- .../vertexai_function_calling/__init__.py | 17 ++ .../vertexai_function_calling/functions.py | 99 ++++++++++++ .../vertexai_function_calling/llm.py | 152 ++++++++++++++++++ 3 files changed, 268 insertions(+) create mode 100644 langchain_tools_demo/orchestration/vertexai_function_calling/__init__.py create mode 100644 langchain_tools_demo/orchestration/vertexai_function_calling/functions.py create mode 100644 langchain_tools_demo/orchestration/vertexai_function_calling/llm.py diff --git a/langchain_tools_demo/orchestration/vertexai_function_calling/__init__.py b/langchain_tools_demo/orchestration/vertexai_function_calling/__init__.py new file mode 100644 index 00000000..26322742 --- /dev/null +++ b/langchain_tools_demo/orchestration/vertexai_function_calling/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# from . import agent + +# __ALL__ = [agent] diff --git a/langchain_tools_demo/orchestration/vertexai_function_calling/functions.py b/langchain_tools_demo/orchestration/vertexai_function_calling/functions.py new file mode 100644 index 00000000..237f400a --- /dev/null +++ b/langchain_tools_demo/orchestration/vertexai_function_calling/functions.py @@ -0,0 +1,99 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from vertexai.preview import generative_models + +search_airports_func = generative_models.FunctionDeclaration( + name="airports_search", + description="Use this tool to list all airports matching search criteria. Takes at least one of country, city, name, or all of the above criteria. This function could also be used to search for airport information such as iata code.", + parameters={ + "type": "object", + "properties": { + "country": {"type": "string", "description": "country"}, + "city": {"type": "string", "description": "city"}, + "name": { + "type": "string", + "description": "Full or partial name of an airport", + }, + }, + }, +) + +search_amenities_func = generative_models.FunctionDeclaration( + name="amenities_search", + description="Use this tool to search amenities by name or to recommend airport amenities at SFO. If top_k is not specified, default to 5", + parameters={ + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + "top_k": { + "type": "integer", + "description": "Number of matching amenities to return. Default this value to 5.", + }, + }, + }, +) + +search_flights_by_number_func = generative_models.FunctionDeclaration( + name="flights_search", + description="Use this tool to get info for a specific flight. This function takes an airline and flight number and returns info on the flight.", + parameters={ + "type": "object", + "properties": { + "airline": { + "type": "string", + "description": "A code for an airline service consisting of two-character airline designator.", + }, + "flight_number": { + "type": "string", + "description": "A 1 to 4 digit number of the flight.", + }, + }, + }, +) + +list_flights_func = generative_models.FunctionDeclaration( + name="flights_search", + description="Use this tool to list all flights matching search criteria. This function takes an arrival airport, a departure airport, or both, filters by date and returns all matching flight. The format of date must be YYYY-MM-DD. Convert terms like today or yesterday to a valid date format.", + parameters={ + "type": "object", + "properties": { + "departure_airport": { + "type": "string", + "description": "The iata code for flight departure airport.", + }, + "arrival_airport": { + "type": "string", + "description": "The iata code for flight arrival airport.", + }, + "date": { + "type": "string", + "description": "The date of flight must be in the following format: YYYY-MM-DD.", + }, + }, + }, +) + + +def assistant_tool(): + return generative_models.Tool( + function_declarations=[ + search_airports_func, + search_amenities_func, + search_flights_by_number_func, + list_flights_func, + ], + ) diff --git a/langchain_tools_demo/orchestration/vertexai_function_calling/llm.py b/langchain_tools_demo/orchestration/vertexai_function_calling/llm.py new file mode 100644 index 00000000..25c51b13 --- /dev/null +++ b/langchain_tools_demo/orchestration/vertexai_function_calling/llm.py @@ -0,0 +1,152 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from datetime import date +from typing import Dict, Optional + +import aiohttp +import google.oauth2.id_token # type: ignore +from fastapi import HTTPException +from functions import assistant_tool +from google.auth.transport.requests import Request # type: ignore +from google.protobuf.json_format import MessageToDict +from vertexai.preview.generative_models import ( + ChatSession, + GenerationResponse, + GenerativeModel, + Part, +) + +MODEL = "gemini-pro" +BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") +CREDENTIALS = None + +# aiohttp context +connector = None + +CLOUD_RUN_AUTHORIZATION_TOKEN = None + +func_url = { + "airports_search": "airports/search", + "flights_search": "flights/search", + "list_flights": "flights/search", + "amenities_search": "amenities/search", +} + + +class ChatAssistant: + client: aiohttp.ClientSession + chat: ChatSession + + def __init__(self, client, chat: ChatSession) -> None: + self.client = client + self.chat = chat + + async def invoke(self, prompt: str): + model_response = self.request_chat_model(prompt) + print(f"function call response:\n{model_response}") + part_response = model_response.candidates[0].content.parts[0] + while "function_call" in part_response._raw_part: + function_call = MessageToDict(part_response.function_call._pb) + function_response = await self.request_function(function_call) + print(f"function response:\n{function_response}") + part = Part.from_function_response( + name=function_call["name"], + response={ + "content": function_response, + }, + ) + model_response = self.request_chat_model(part) + part_response = model_response.candidates[0].content.parts[0] + if "text" in part_response._raw_part: + content = part_response.text + print(f"output content: {content}") + return {"output": content} + else: + raise HTTPException( + status_code=500, detail="Error: Chat model response unknown" + ) + + def request_chat_model(self, prompt: str): + model_response = self.chat.send_message(prompt) + return model_response + + async def request_function(self, function_call): + function_name = func_url[function_call["name"]] + params = function_call["args"] + print(f"function name is {function_name}") + print(f"params is {params}") + response = await self.client.get( + url=f"{BASE_URL}/{function_name}", + params=params, + headers=get_headers(self.client), + ) + response = await response.json() + return response + + +chat_assistants: Dict[str, ChatAssistant] = {} + + +def get_headers(client: aiohttp.ClientSession): + """Helper method to generate ID tokens for authenticated requests""" + headers = client.headers + if not "http://" in BASE_URL: + # Append ID Token to make authenticated requests to Cloud Run services + headers["Authorization"] = f"Bearer {get_id_token()}" + return headers + + +def get_id_token(): + global CREDENTIALS + if CREDENTIALS is None: + CREDENTIALS, _ = google.auth.default() + if not CREDENTIALS.valid: + CREDENTIALS.refresh(Request()) + return CREDENTIALS.id_token + + +async def get_connector(): + global connector + if connector is None: + connector = aiohttp.TCPConnector(limit=100) + return connector + + +async def handle_error_response(response): + if response.status != 200: + return f"Error sending {response.method} request to {str(response.url)}): {await response.text()}" + + +async def create_client_session(user_id_token: Optional[str]) -> aiohttp.ClientSession: + headers = {} + if user_id_token is not None: + # user-specific query authentication + headers["User-Id-Token"] = user_id_token + + return aiohttp.ClientSession( + connector=await get_connector(), + connector_owner=False, + headers=headers, + raise_for_status=True, + ) + + +async def init_chat_assistant(user_id_token) -> ChatAssistant: + print("Initializing agent..") + client = await create_client_session(user_id_token) + model = GenerativeModel(MODEL, tools=[assistant_tool()]) + func_calling_chat = model.start_chat() + return ChatAssistant(client, func_calling_chat) From 3a179d19c4e9f7278d9856483280877a8702cd92 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Wed, 21 Feb 2024 14:08:42 -0800 Subject: [PATCH 11/18] rebase --- .../orchestration/__init__.py | 18 - .../orchestration/langchain_tools/__init__.py | 17 - .../orchestration/langchain_tools/agent.py | 169 --------- .../orchestration/langchain_tools/tools.py | 357 ------------------ .../orchestration/orchestration.py | 82 ---- .../vertexai_function_calling/__init__.py | 0 .../vertexai_function_calling/functions.py | 0 .../vertexai_function_calling/llm.py | 0 llm_demo/requirements.txt | 4 - llm_demo/run_app.py | 1 + 10 files changed, 1 insertion(+), 647 deletions(-) delete mode 100644 langchain_tools_demo/orchestration/__init__.py delete mode 100644 langchain_tools_demo/orchestration/langchain_tools/__init__.py delete mode 100644 langchain_tools_demo/orchestration/langchain_tools/agent.py delete mode 100644 langchain_tools_demo/orchestration/langchain_tools/tools.py delete mode 100644 langchain_tools_demo/orchestration/orchestration.py rename langchain_tools_demo/{orchestration => orchestrator}/vertexai_function_calling/__init__.py (100%) rename langchain_tools_demo/{orchestration => orchestrator}/vertexai_function_calling/functions.py (100%) rename langchain_tools_demo/{orchestration => orchestrator}/vertexai_function_calling/llm.py (100%) diff --git a/langchain_tools_demo/orchestration/__init__.py b/langchain_tools_demo/orchestration/__init__.py deleted file mode 100644 index 256e7d36..00000000 --- a/langchain_tools_demo/orchestration/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from . import langchain_tools -from .orchestration import Orchestration, ais, create - -__ALL__ = [Orchestration, create, ais, langchain_tools] diff --git a/langchain_tools_demo/orchestration/langchain_tools/__init__.py b/langchain_tools_demo/orchestration/langchain_tools/__init__.py deleted file mode 100644 index 5b965e2b..00000000 --- a/langchain_tools_demo/orchestration/langchain_tools/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from . import agent - -__ALL__ = [agent] diff --git a/langchain_tools_demo/orchestration/langchain_tools/agent.py b/langchain_tools_demo/orchestration/langchain_tools/agent.py deleted file mode 100644 index 336fed2f..00000000 --- a/langchain_tools_demo/orchestration/langchain_tools/agent.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from datetime import date -from typing import Any, Dict, List - -import aiohttp -from fastapi import HTTPException -from langchain.agents import AgentType, initialize_agent -from langchain.agents.agent import AgentExecutor -from langchain.globals import set_verbose # type: ignore -from langchain.memory import ChatMessageHistory, ConversationBufferMemory -from langchain.prompts.chat import ChatPromptTemplate -from langchain_core import messages -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage -from langchain_google_vertexai import VertexAI - -from .. import orchestration -from .tools import initialize_tools - -set_verbose(bool(os.getenv("DEBUG", default=False))) -BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") - -CLOUD_RUN_AUTHORIZATION_TOKEN = None - - -class Orchestration(orchestration.Orchestration): - client: aiohttp.ClientSession - agent: AgentExecutor - - @orchestration.classproperty - def kind(cls): - return "langchain-tools" - - def __init__(self, client: aiohttp.ClientSession, agent: AgentExecutor): - self.client = client - self.agent = agent - - @classmethod - async def create(cls, base_history: list[Any]) -> "Orchestration": - """Load an agent executor with tools and LLM""" - print("Initializing agent..") - history = Orchestration.parse_messages(base_history) - llm = VertexAI(max_output_tokens=512, model_name="gemini-pro") - memory = ConversationBufferMemory( - chat_memory=ChatMessageHistory(messages=history), - memory_key="chat_history", - input_key="input", - output_key="output", - ) - client = await Orchestration.create_client_session() - tools = await initialize_tools(client) - agent = initialize_agent( - tools, - llm, - agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, - memory=memory, - handle_parsing_errors=True, - max_iterations=3, - early_stopping_method="generate", - return_intermediate_steps=True, - ) - # Create new prompt template - tool_strings = "\n".join( - [f"> {tool.name}: {tool.description}" for tool in tools] - ) - tool_names = ", ".join([tool.name for tool in tools]) - format_instructions = FORMAT_INSTRUCTIONS.format( - tool_names=tool_names, - ) - today_date = date.today().strftime("%Y-%m-%d") - today = f"Today is {today_date}." - template = "\n\n".join( - [PREFIX, tool_strings, format_instructions, SUFFIX, today] - ) - human_message_template = "{input}\n\n{agent_scratchpad}" - prompt = ChatPromptTemplate.from_messages( - [("system", template), ("human", human_message_template)] - ) - agent.agent.llm_chain.prompt = prompt # type: ignore - - return Orchestration(client, agent) - - async def invoke(self, prompt: str) -> Dict[str, Any]: - response = await self.agent.ainvoke({"input": prompt}) - return response - - @staticmethod - def parse_messages(datas: List[Any]) -> List[BaseMessage]: - messages: List[BaseMessage] = [] - for data in datas: - if data["type"] == "human": - messages.append(HumanMessage(content=data["data"]["content"])) - if data["type"] == "ai": - messages.append(AIMessage(content=data["data"]["content"])) - return messages - - def close(self): - self.client.close() - - -PREFIX = """SFO Airport Assistant helps travelers find their way at the airport. - -Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to -complex multi-query questions that require passing results from one query to another. As a language model, Assistant is -able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding -conversations and provide responses that are coherent and relevant to the topic at hand. - -Overall, Assistant is a powerful tool that can help answer a wide range of questions pertaining to the San -Francisco Airport. SFO Airport Assistant is here to assist. It currently does not have access to user info. - -TOOLS: ------- - -Assistant has access to the following tools:""" - -FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) -and an action_input key (tool input). - -Valid "action" values: "Final Answer" or {tool_names} - -Provide only ONE action per $JSON_BLOB, as shown: - -``` -{{{{ - "action": $TOOL_NAME, - "action_input": $INPUT -}}}} -``` - -Follow this format: - -Question: input question to answer -Thought: consider previous and subsequent steps -Action: -``` -$JSON_BLOB -``` -Observation: action result -... (repeat Thought/Action/Observation N times) -Thought: I know what to respond -Action: -``` -{{{{ - "action": "Final Answer", - "action_input": "Final response to human" -}}}} -```""" - -SUFFIX = """Begin! Use tools if necessary. Respond directly if appropriate. -If using a tool, reminder to ALWAYS respond with a valid json blob of a single action. -Format is Action:```$JSON_BLOB```then Observation:. -Thought: - -Previous conversation history: -{chat_history} -""" diff --git a/langchain_tools_demo/orchestration/langchain_tools/tools.py b/langchain_tools_demo/orchestration/langchain_tools/tools.py deleted file mode 100644 index 52295065..00000000 --- a/langchain_tools_demo/orchestration/langchain_tools/tools.py +++ /dev/null @@ -1,357 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from datetime import datetime -from typing import Optional - -import aiohttp -import google.oauth2.id_token # type: ignore -from google.auth import compute_engine # type: ignore -from google.auth.transport.requests import Request # type: ignore -from langchain.agents.agent import ExceptionTool # type: ignore -from langchain.tools import StructuredTool -from pydantic.v1 import BaseModel, Field - -BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") -CREDENTIALS = None - - -def filter_none_values(params: dict) -> dict: - return {key: value for key, value in params.items() if value is not None} - - -def get_id_token(): - global CREDENTIALS - if CREDENTIALS is None: - CREDENTIALS, _ = google.auth.default() - if not hasattr(CREDENTIALS, "id_token"): - # Use Compute Engine default credential - CREDENTIALS = compute_engine.IDTokenCredentials( - request=Request(), - target_audience=BASE_URL, - use_metadata_identity_endpoint=True, - ) - if not CREDENTIALS.valid: - CREDENTIALS.refresh(Request()) - if hasattr(CREDENTIALS, "id_token"): - return CREDENTIALS.id_token - else: - return CREDENTIALS.token - - -def get_headers(client: aiohttp.ClientSession): - """Helper method to generate ID tokens for authenticated requests""" - headers = client.headers - if not "http://" in BASE_URL: - # Append ID Token to make authenticated requests to Cloud Run services - headers["Authorization"] = f"Bearer {get_id_token()}" - return headers - - -# Tools -class AirportSearchInput(BaseModel): - country: Optional[str] = Field(description="Country") - city: Optional[str] = Field(description="City") - name: Optional[str] = Field(description="Airport name") - - -def generate_search_airports(client: aiohttp.ClientSession): - async def search_airports(country: str, city: str, name: str): - params = { - "country": country, - "city": city, - "name": name, - } - response = await client.get( - url=f"{BASE_URL}/airports/search", - params=filter_none_values(params), - headers=get_headers(client), - ) - - num = 2 - response_json = await response.json() - if len(response_json) < 1: - return "There are no airports matching that query. Let the user know there are no results." - elif len(response_json) > num: - return ( - f"There are {len(response_json)} airports matching that query. Here are the first {num} results:\n" - + " ".join([f"{response_json[i]}" for i in range(num)]) - ) - else: - return "\n".join([f"{r}" for r in response_json]) - - return search_airports - - -class FlightNumberInput(BaseModel): - airline: str = Field(description="Airline unique 2 letter identifier") - flight_number: str = Field(description="1 to 4 digit number") - - -def generate_search_flights_by_number(client: aiohttp.ClientSession): - async def search_flights_by_number(airline: str, flight_number: str): - response = await client.get( - url=f"{BASE_URL}/flights/search", - params={"airline": airline, "flight_number": flight_number}, - headers=get_headers(client), - ) - - return await response.json() - - return search_flights_by_number - - -class ListFlights(BaseModel): - departure_airport: Optional[str] = Field( - description="Departure airport 3-letter code", - ) - arrival_airport: Optional[str] = Field(description="Arrival airport 3-letter code") - date: Optional[str] = Field(description="Date of flight departure") - - -def generate_list_flights(client: aiohttp.ClientSession): - async def list_flights( - departure_airport: str, - arrival_airport: str, - date: str, - ): - params = { - "departure_airport": departure_airport, - "arrival_airport": arrival_airport, - "date": date, - } - response = await client.get( - url=f"{BASE_URL}/flights/search", - params=filter_none_values(params), - headers=get_headers(client), - ) - - num = 2 - response_json = await response.json() - if len(response_json) < 1: - return "There are no flights matching that query. Let the user know there are no results." - elif len(response_json) > num: - return ( - f"There are {len(response_json)} flights matching that query. Here are the first {num} results:\n" - + " ".join([f"{response_json[i]}" for i in range(num)]) - ) - else: - return "\n".join([f"{r}" for r in response_json]) - - return list_flights - - -class QueryInput(BaseModel): - query: str = Field(description="Search query") - - -def generate_search_amenities(client: aiohttp.ClientSession): - async def search_amenities(query: str): - response = await client.get( - url=f"{BASE_URL}/amenities/search", - params={"top_k": "5", "query": query}, - headers=get_headers(client), - ) - - response = await response.json() - return response - - return search_amenities - - -class TicketInput(BaseModel): - airline: str = Field(description="Airline unique 2 letter identifier") - flight_number: str = Field(description="1 to 4 digit number") - departure_airport: str = Field( - description="Departure airport 3-letter code", - ) - arrival_airport: str = Field(description="Arrival airport 3-letter code") - departure_time: datetime = Field(description="Flight departure datetime") - arrival_time: datetime = Field(description="Flight arrival datetime") - - -def generate_insert_ticket(client: aiohttp.ClientSession): - async def insert_ticket( - airline: str, - flight_number: str, - departure_airport: str, - arrival_airport: str, - departure_time: datetime, - arrival_time: datetime, - ): - response = await client.post( - url=f"{BASE_URL}/tickets/insert", - params={ - "airline": airline, - "flight_number": flight_number, - "departure_airport": departure_airport, - "arrival_airport": arrival_airport, - "departure_time": departure_time.strftime("%Y-%m-%d %H:%M:%S"), - "arrival_time": arrival_time.strftime("%Y-%m-%d %H:%M:%S"), - }, - headers=get_headers(client), - ) - - response = await response.json() - return response - - return insert_ticket - - -def generate_list_tickets(client: aiohttp.ClientSession): - async def list_tickets(): - response = await client.get( - url=f"{BASE_URL}/tickets/list", - headers=get_headers(client), - ) - - response = await response.json() - return response - - return list_tickets - - -# Tools for agent -async def initialize_tools(client: aiohttp.ClientSession): - return [ - StructuredTool.from_function( - coroutine=generate_search_airports(client), - name="Search Airport", - description=""" - Use this tool to list all airports matching search criteria. - Takes at least one of country, city, name, or all and returns all matching airports. - The agent can decide to return the results directly to the user. - Input of this tool must be in JSON format and include all three inputs - country, city, name. - Example: - {{ - "country": "United States", - "city": "San Francisco", - "name": null - }} - Example: - {{ - "country": null, - "city": "Goroka", - "name": "Goroka" - }} - Example: - {{ - "country": "Mexico", - "city": null, - "name": null - }} - """, - args_schema=AirportSearchInput, - ), - StructuredTool.from_function( - coroutine=generate_search_flights_by_number(client), - name="Search Flights By Flight Number", - description=""" - Use this tool to get info for a specific flight. Do NOT use this tool with a flight id. - Takes an airline and flight number and returns info on the flight. - Do NOT guess an airline or flight number. - A flight number is a code for an airline service consisting of two-character - airline designator and a 1 to 4 digit number ex. OO123, DL 1234, BA 405, AS 3452. - If the tool returns more than one option choose the date closes to today. - """, - args_schema=FlightNumberInput, - ), - StructuredTool.from_function( - coroutine=generate_list_flights(client), - name="List Flights", - description=""" - Use this tool to list all flights matching search criteria. - Takes an arrival airport, a departure airport, or both, filters by date and returns all matching flights. - The agent can decide to return the results directly to the user. - Input of this tool must be in JSON format and include all three inputs - arrival_airport, departure_airport, and date. - Example: - {{ - "departure_airport": "SFO", - "arrival_airport": null, - "date": null - }} - Example: - {{ - "departure_airport": "SFO", - "arrival_airport": "SEA", - "date": "2023-11-01" - }} - Example: - {{ - "departure_airport": null, - "arrival_airport": "SFO", - "date": "2023-01-01" - }} - """, - args_schema=ListFlights, - ), - StructuredTool.from_function( - coroutine=generate_search_amenities(client), - name="Search Amenities", - description=""" - Use this tool to search amenities by name or to recommended airport amenities at SFO. - If user provides flight info, use 'Get Flight' and 'Get Flights by Number' - first to get gate info and location. - Only recommend amenities that are returned by this query. - Find amenities close to the user by matching the terminal and then comparing - the gate numbers. Gate number iterate by letter and number, example A1 A2 A3 - B1 B2 B3 C1 C2 C3. Gate A3 is close to A2 and B1. - """, - args_schema=QueryInput, - ), - StructuredTool.from_function( - coroutine=generate_insert_ticket(client), - name="Insert Ticket", - description=""" - Use this tool to book a flight ticket for the user. - Example: - {{ - "airline": "AA", - "flight_number": "452", - "departure_airport": "LAX", - "arrival_airport": "SFO", - "departure_time": "2024-01-01 05:50:00", - "arrival_time": "2024-01-01 09:23:00" - }} - Example: - {{ - "airline": "UA", - "flight_number": "1532", - "departure_airport": "SFO", - "arrival_airport": "DEN", - "departure_time": "2024-01-08 05:50:00", - "arrival_time": "2024-01-08 09:23:00" - }} - Example: - {{ - "airline": "OO", - "flight_number": "6307", - "departure_airport": "SFO", - "arrival_airport": "MSP", - "departure_time": "2024-10-28 20:13:00", - "arrival_time": "2024-10-28 21:07:00" - }} - """, - args_schema=TicketInput, - ), - StructuredTool.from_function( - coroutine=generate_list_tickets(client), - name="List Tickets", - description=""" - Use this tool to list a user's flight tickets. - Takes no input and returns a list of current user's flight tickets. - """, - ), - ] diff --git a/langchain_tools_demo/orchestration/orchestration.py b/langchain_tools_demo/orchestration/orchestration.py deleted file mode 100644 index 87d77cad..00000000 --- a/langchain_tools_demo/orchestration/orchestration.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, List, TypeVar - -import aiohttp - -# aiohttp context -connector = None - - -class classproperty: - def __init__(self, func): - self.fget = func - - def __get__(self, instance, owner): - return self.fget(owner) - - -class Orchestration(ABC): - client: aiohttp.ClientSession - - @classproperty - @abstractmethod - def kind(cls): - pass - - @classmethod - @abstractmethod - async def create(cls, history: List[Any]) -> "Orchestration": - pass - - async def invoke(self, prompt: str): - raise NotImplementedError("Subclass should implement this!") - - @staticmethod - async def get_connector(): - global connector - if connector is None: - connector = aiohttp.TCPConnector(limit=100) - return connector - - @staticmethod - async def handle_error_response(response): - if response.status != 200: - return f"Error sending {response.method} request to {str(response.url)}): {await response.text()}" - - @staticmethod - async def create_client_session() -> aiohttp.ClientSession: - return aiohttp.ClientSession( - connector=await Orchestration.get_connector(), - connector_owner=False, - headers={}, - raise_for_status=True, - ) - - def close(self): - raise NotImplementedError("Subclass should implement this!") - - -ais: Dict[str, Orchestration] = {} - - -async def create(config: Dict[str, str], history: List[Any]) -> Orchestration: - for cls in Orchestration.__subclasses__(): - s = f"{config['kind']} == {cls.kind}" - if config["kind"] == cls.kind: - return await cls.create(history) # type: ignore - raise TypeError(f"No orchestration of kind {config['kind']}") - return diff --git a/langchain_tools_demo/orchestration/vertexai_function_calling/__init__.py b/langchain_tools_demo/orchestrator/vertexai_function_calling/__init__.py similarity index 100% rename from langchain_tools_demo/orchestration/vertexai_function_calling/__init__.py rename to langchain_tools_demo/orchestrator/vertexai_function_calling/__init__.py diff --git a/langchain_tools_demo/orchestration/vertexai_function_calling/functions.py b/langchain_tools_demo/orchestrator/vertexai_function_calling/functions.py similarity index 100% rename from langchain_tools_demo/orchestration/vertexai_function_calling/functions.py rename to langchain_tools_demo/orchestrator/vertexai_function_calling/functions.py diff --git a/langchain_tools_demo/orchestration/vertexai_function_calling/llm.py b/langchain_tools_demo/orchestrator/vertexai_function_calling/llm.py similarity index 100% rename from langchain_tools_demo/orchestration/vertexai_function_calling/llm.py rename to langchain_tools_demo/orchestrator/vertexai_function_calling/llm.py diff --git a/llm_demo/requirements.txt b/llm_demo/requirements.txt index 05095b87..2dcb24d1 100644 --- a/llm_demo/requirements.txt +++ b/llm_demo/requirements.txt @@ -9,7 +9,3 @@ markdown==3.5.2 types-Markdown==3.5.0.20240129 uvicorn[standard]==0.27.0.post1 python-multipart==0.0.7 -<<<<<<< HEAD -======= -piny==1.1.0 ->>>>>>> 94c687f (update requirement) diff --git a/llm_demo/run_app.py b/llm_demo/run_app.py index 7be2d849..7fe1e5c9 100644 --- a/llm_demo/run_app.py +++ b/llm_demo/run_app.py @@ -15,6 +15,7 @@ import asyncio import os + import uvicorn from app import init_app From 622d4e5ddd56e735110acf1bbba0a15c8244db00 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Wed, 21 Feb 2024 14:18:38 -0800 Subject: [PATCH 12/18] update function calling orchestrator --- .../vertexai_function_calling/__init__.py | 4 +- .../function_calling_orchestrator.py | 158 ++++++++++++++++++ .../vertexai_function_calling/functions.py | 12 +- .../vertexai_function_calling/llm.py | 152 ----------------- 4 files changed, 171 insertions(+), 155 deletions(-) create mode 100644 langchain_tools_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py delete mode 100644 langchain_tools_demo/orchestrator/vertexai_function_calling/llm.py diff --git a/langchain_tools_demo/orchestrator/vertexai_function_calling/__init__.py b/langchain_tools_demo/orchestrator/vertexai_function_calling/__init__.py index 26322742..5b743df4 100644 --- a/langchain_tools_demo/orchestrator/vertexai_function_calling/__init__.py +++ b/langchain_tools_demo/orchestrator/vertexai_function_calling/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# from . import agent +from .function_calling_orchestrator import FunctionCallingOrchestrator -# __ALL__ = [agent] +__ALL__ = ["FunctionCallingOrchestrator"] diff --git a/langchain_tools_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py b/langchain_tools_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py new file mode 100644 index 00000000..30c5e8f1 --- /dev/null +++ b/langchain_tools_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py @@ -0,0 +1,158 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid +from datetime import date +from typing import Any, Dict + +from aiohttp import ClientSession, TCPConnector +from fastapi import HTTPException +from google.auth.transport.requests import Request # type: ignore +from google.protobuf.json_format import MessageToDict +from vertexai.preview.generative_models import ( # type: ignore + ChatSession, + GenerationResponse, + GenerativeModel, + Part, +) + +from ..orchestrator import BaseOrchestrator, classproperty +from .functions import assistant_tool, function_request + +MODEL = "gemini-pro" +BASE_HISTORY = { + "type": "ai", + "data": {"content": "I am an SFO Airport Assistant, ready to assist you."}, +} + + +class UserChatModel: + client: ClientSession + chat: ChatSession + + def __init__(self, client: ClientSession, chat: ChatSession): + self.client = client + self.chat = chat + + @classmethod + def initialize_chat_model(cls, client: ClientSession) -> "UserChatModel": + model = GenerativeModel(MODEL, tools=[assistant_tool()]) + function_calling_session = model.start_chat() + return UserChatModel(client, function_calling_session) + + async def close(self): + await self.client.close() + + async def invoke(self, prompt: str) -> Dict[str, Any]: + model_response = self.request_chat_model(prompt) + print(f"function call response:\n{model_response}") + part_response = model_response.candidates[0].content.parts[0] + while "function_call" in part_response._raw_part: + function_call = MessageToDict(part_response.function_call._pb) + function_response = await self.request_function(function_call) + print(f"function response:\n{function_response}") + part = Part.from_function_response( + name=function_call["name"], + response={ + "content": function_response, + }, + ) + model_response = self.request_chat_model(part) + part_response = model_response.candidates[0].content.parts[0] + if "text" in part_response._raw_part: + content = part_response.text + print(f"output content: {content}") + return {"output": content} + else: + raise HTTPException( + status_code=500, detail="Error: Chat model response unknown" + ) + + def request_chat_model(self, prompt: str): + try: + model_response = self.chat.send_message(prompt) + except Exception as err: + raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}") + return model_response + + async def request_function(self, function_call): + url = function_request(function_call["name"]) + params = function_call["args"] + print(f"function url is {url}") + print(f"params is {params}") + response = await self.client.get( + url=f"{BASE_URL}/{url}", + params=params, + headers=get_headers(self.client), + ) + response = await response.json() + return response + + +class FunctionCallingOrchestrator(BaseOrchestrator): + _user_sessions: Dict[str, UserChatModel] = {} + # aiohttp context + connector = None + + @classproperty + def kind(cls): + return "vertexai-function-calling" + + def user_session_exist(self, uuid: str) -> bool: + return uuid in self._user_sessions + + async def user_session_create(self, session: dict[str, Any]): + """Create and load an agent executor with tools and LLM.""" + print("Initializing agent..") + if "uuid" not in session: + session["uuid"] = str(uuid.uuid4()) + id = session["uuid"] + if "history" not in session: + session["history"] = [BASE_HISTORY] + client = await self.create_client_session() + chat = UserChatModel.initialize_chat_model(client) + self._user_sessions[id] = chat + + async def user_session_invoke(self, uuid: str, prompt: str) -> str: + user_session = self.get_user_session(uuid) + # Send prompt to LLM + response = await user_session.invoke(prompt) + return response["output"] + + async def user_session_reset(self, uuid: str): + user_session = self.get_user_session(uuid) + await user_session.close() + del user_session + + def get_user_session(self, uuid: str) -> UserChatModel: + return self._user_sessions[uuid] + + async def get_connector(self) -> TCPConnector: + if self.connector is None: + self.connector = TCPConnector(limit=100) + return self.connector + + async def create_client_session(self) -> ClientSession: + return ClientSession( + connector=await self.get_connector(), + connector_owner=False, + headers={}, + raise_for_status=True, + ) + + def close_clients(self): + close_client_tasks = [ + asyncio.create_task(a.close()) for a in self._user_sessions.values() + ] + asyncio.gather(*close_client_tasks) diff --git a/langchain_tools_demo/orchestrator/vertexai_function_calling/functions.py b/langchain_tools_demo/orchestrator/vertexai_function_calling/functions.py index 237f400a..8e6c499f 100644 --- a/langchain_tools_demo/orchestrator/vertexai_function_calling/functions.py +++ b/langchain_tools_demo/orchestrator/vertexai_function_calling/functions.py @@ -14,7 +14,7 @@ import os -from vertexai.preview import generative_models +from vertexai.preview import generative_models # type: ignore search_airports_func = generative_models.FunctionDeclaration( name="airports_search", @@ -88,6 +88,16 @@ ) +def function_request(function_call_name: str) -> str: + functions_url = { + "airports_search": "airports/search", + "flights_search": "flights/search", + "list_flights": "flights/search", + "amenities_search": "amenities/search", + } + return functions_url[function_call_name] + + def assistant_tool(): return generative_models.Tool( function_declarations=[ diff --git a/langchain_tools_demo/orchestrator/vertexai_function_calling/llm.py b/langchain_tools_demo/orchestrator/vertexai_function_calling/llm.py deleted file mode 100644 index 25c51b13..00000000 --- a/langchain_tools_demo/orchestrator/vertexai_function_calling/llm.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from datetime import date -from typing import Dict, Optional - -import aiohttp -import google.oauth2.id_token # type: ignore -from fastapi import HTTPException -from functions import assistant_tool -from google.auth.transport.requests import Request # type: ignore -from google.protobuf.json_format import MessageToDict -from vertexai.preview.generative_models import ( - ChatSession, - GenerationResponse, - GenerativeModel, - Part, -) - -MODEL = "gemini-pro" -BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") -CREDENTIALS = None - -# aiohttp context -connector = None - -CLOUD_RUN_AUTHORIZATION_TOKEN = None - -func_url = { - "airports_search": "airports/search", - "flights_search": "flights/search", - "list_flights": "flights/search", - "amenities_search": "amenities/search", -} - - -class ChatAssistant: - client: aiohttp.ClientSession - chat: ChatSession - - def __init__(self, client, chat: ChatSession) -> None: - self.client = client - self.chat = chat - - async def invoke(self, prompt: str): - model_response = self.request_chat_model(prompt) - print(f"function call response:\n{model_response}") - part_response = model_response.candidates[0].content.parts[0] - while "function_call" in part_response._raw_part: - function_call = MessageToDict(part_response.function_call._pb) - function_response = await self.request_function(function_call) - print(f"function response:\n{function_response}") - part = Part.from_function_response( - name=function_call["name"], - response={ - "content": function_response, - }, - ) - model_response = self.request_chat_model(part) - part_response = model_response.candidates[0].content.parts[0] - if "text" in part_response._raw_part: - content = part_response.text - print(f"output content: {content}") - return {"output": content} - else: - raise HTTPException( - status_code=500, detail="Error: Chat model response unknown" - ) - - def request_chat_model(self, prompt: str): - model_response = self.chat.send_message(prompt) - return model_response - - async def request_function(self, function_call): - function_name = func_url[function_call["name"]] - params = function_call["args"] - print(f"function name is {function_name}") - print(f"params is {params}") - response = await self.client.get( - url=f"{BASE_URL}/{function_name}", - params=params, - headers=get_headers(self.client), - ) - response = await response.json() - return response - - -chat_assistants: Dict[str, ChatAssistant] = {} - - -def get_headers(client: aiohttp.ClientSession): - """Helper method to generate ID tokens for authenticated requests""" - headers = client.headers - if not "http://" in BASE_URL: - # Append ID Token to make authenticated requests to Cloud Run services - headers["Authorization"] = f"Bearer {get_id_token()}" - return headers - - -def get_id_token(): - global CREDENTIALS - if CREDENTIALS is None: - CREDENTIALS, _ = google.auth.default() - if not CREDENTIALS.valid: - CREDENTIALS.refresh(Request()) - return CREDENTIALS.id_token - - -async def get_connector(): - global connector - if connector is None: - connector = aiohttp.TCPConnector(limit=100) - return connector - - -async def handle_error_response(response): - if response.status != 200: - return f"Error sending {response.method} request to {str(response.url)}): {await response.text()}" - - -async def create_client_session(user_id_token: Optional[str]) -> aiohttp.ClientSession: - headers = {} - if user_id_token is not None: - # user-specific query authentication - headers["User-Id-Token"] = user_id_token - - return aiohttp.ClientSession( - connector=await get_connector(), - connector_owner=False, - headers=headers, - raise_for_status=True, - ) - - -async def init_chat_assistant(user_id_token) -> ChatAssistant: - print("Initializing agent..") - client = await create_client_session(user_id_token) - model = GenerativeModel(MODEL, tools=[assistant_tool()]) - func_calling_chat = model.start_chat() - return ChatAssistant(client, func_calling_chat) From 71714b892fbdccd5f3d42c181f54a7de9ce2d251 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Fri, 23 Feb 2024 12:27:04 -0800 Subject: [PATCH 13/18] update function calling --- .../function_calling_orchestrator.py | 35 ++++++++++++++++++- llm_demo/orchestrator/__init__.py | 1 + 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/langchain_tools_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py b/langchain_tools_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py index 30c5e8f1..03f25a40 100644 --- a/langchain_tools_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py +++ b/langchain_tools_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import uuid from datetime import date from typing import Any, Dict @@ -31,10 +32,40 @@ from .functions import assistant_tool, function_request MODEL = "gemini-pro" +BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") BASE_HISTORY = { "type": "ai", "data": {"content": "I am an SFO Airport Assistant, ready to assist you."}, } +CREDENTIALS = None + + +def get_id_token(): + global CREDENTIALS + if CREDENTIALS is None: + CREDENTIALS, _ = google.auth.default() + if not hasattr(CREDENTIALS, "id_token"): + # Use Compute Engine default credential + CREDENTIALS = compute_engine.IDTokenCredentials( + request=Request(), + target_audience=BASE_URL, + use_metadata_identity_endpoint=True, + ) + if not CREDENTIALS.valid: + CREDENTIALS.refresh(Request()) + if hasattr(CREDENTIALS, "id_token"): + return CREDENTIALS.id_token + else: + return CREDENTIALS.token + + +def get_headers(client: ClientSession): + """Helper method to generate ID tokens for authenticated requests""" + headers = client.headers + if not "http://" in BASE_URL: + # Append ID Token to make authenticated requests to Cloud Run services + headers["Authorization"] = f"Bearer {get_id_token()}" + return headers class UserChatModel: @@ -55,7 +86,9 @@ async def close(self): await self.client.close() async def invoke(self, prompt: str) -> Dict[str, Any]: - model_response = self.request_chat_model(prompt) + today_date = date.today().strftime("%Y-%m-%d") + today = f"Today is {today_date}." + model_response = self.request_chat_model(prompt + today) print(f"function call response:\n{model_response}") part_response = model_response.candidates[0].content.parts[0] while "function_call" in part_response._raw_part: diff --git a/llm_demo/orchestrator/__init__.py b/llm_demo/orchestrator/__init__.py index 35b97a74..c338a248 100644 --- a/llm_demo/orchestrator/__init__.py +++ b/llm_demo/orchestrator/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from . import langchain_tools +from . import vertexai_function_calling from .orchestrator import BaseOrchestrator, createOrchestrator __ALL__ = ["BaseOrchestrator", "createOrchestrator", "langchain_tools"] From 2cfc26db3348539d13c67896c1ae0301451680ae Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Tue, 27 Feb 2024 17:18:26 -0800 Subject: [PATCH 14/18] update merge --- llm_demo/app.py | 44 ++------------- llm_demo/example-config.yml | 4 -- llm_demo/int.tests.cloudbuild.yaml | 10 ---- llm_demo/orchestrator/__init__.py | 2 +- .../orchestrator/langchain_tools/__init__.py | 4 -- .../vertexai_function_calling/__init__.py | 0 .../function_calling_orchestrator.py | 54 +++++-------------- .../vertexai_function_calling/functions.py | 0 8 files changed, 17 insertions(+), 101 deletions(-) delete mode 100644 llm_demo/example-config.yml rename {langchain_tools_demo => llm_demo}/orchestrator/vertexai_function_calling/__init__.py (100%) rename {langchain_tools_demo => llm_demo}/orchestrator/vertexai_function_calling/function_calling_orchestrator.py (78%) rename {langchain_tools_demo => llm_demo}/orchestrator/vertexai_function_calling/functions.py (100%) diff --git a/llm_demo/app.py b/llm_demo/app.py index dc6339fe..6382ac8b 100644 --- a/llm_demo/app.py +++ b/llm_demo/app.py @@ -24,7 +24,6 @@ from google.auth.transport import requests # type:ignore from google.oauth2 import id_token # type:ignore from markdown import markdown -from piny import StrictMatcher, YamlLoader # type: ignore from starlette.middleware.sessions import SessionMiddleware from orchestrator import BaseOrchestrator, createOrchestrator @@ -32,15 +31,6 @@ routes = APIRouter() templates = Jinja2Templates(directory="templates") -BASE_HISTORY = [ - { - "type": "ai", - "data": {"content": "I am an SFO Airport Assistant, ready to assist you."}, - } -] -CLIENT_ID = os.getenv("CLIENT_ID") -routes = APIRouter() - @asynccontextmanager async def lifespan(app: FastAPI): @@ -55,18 +45,11 @@ async def lifespan(app: FastAPI): @routes.post("/") async def index(request: Request): """Render the default template.""" -<<<<<<< HEAD:llm_demo/app.py # User session setup - orchestrator = request.app.state.orchestration_type + orchestrator = request.app.state.orchestrator session = request.session if "uuid" not in session or not orchestrator.user_session_exist(session["uuid"]): await orchestrator.user_session_create(session) -======= - # Agent setup - orchestrator = request.app.state.orchestrator - await orchestrator.user_session_create(request.session) - templates = Jinja2Templates(directory="templates") ->>>>>>> 493e7c5 (update interface):langchain_tools_demo/main.py return templates.TemplateResponse( "index.html", { @@ -92,11 +75,7 @@ async def login_google( user_name = get_user_name(str(user_id_token), client_id) # create new request session -<<<<<<< HEAD:llm_demo/app.py - orchestrator = request.app.state.orchestration_type -======= orchestrator = request.app.state.orchestrator ->>>>>>> 493e7c5 (update interface):langchain_tools_demo/main.py orchestrator.set_user_session_header(request.session["uuid"], str(user_id_token)) print("Logged in to Google.") @@ -129,11 +108,7 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)): # Add user message to chat history request.session["history"].append({"type": "human", "data": {"content": prompt}}) -<<<<<<< HEAD:llm_demo/app.py - orchestrator = request.app.state.orchestration_type -======= orchestrator = request.app.state.orchestrator ->>>>>>> 493e7c5 (update interface):langchain_tools_demo/main.py output = await orchestrator.user_session_invoke(request.session["uuid"], prompt) # Return assistant response request.session["history"].append({"type": "ai", "data": {"content": output}}) @@ -148,15 +123,9 @@ async def reset(request: Request): raise HTTPException(status_code=400, detail=f"No session to reset.") uuid = request.session["uuid"] -<<<<<<< HEAD:llm_demo/app.py - orchestrator = request.app.state.orchestration_type - if not orchestrator.user_session_exist(uuid): - raise HTTPException(status_code=500, detail=f"Current user session not found") -======= orchestrator = request.app.state.orchestrator if not orchestrator.user_session_exist(uuid): - raise HTTPException(status_code=500, detail=f"Current agent not found") ->>>>>>> 493e7c5 (update interface):langchain_tools_demo/main.py + raise HTTPException(status_code=500, detail=f"Current user session not found") await orchestrator.user_session_reset(uuid) request.session.clear() @@ -170,23 +139,16 @@ def get_user_name(user_token_id: str, client_id: str) -> str: def init_app( -<<<<<<< HEAD:llm_demo/app.py orchestration_type: Optional[str], client_id: Optional[str], secret_key: Optional[str], ) -> FastAPI: # FastAPI setup if orchestration_type is None: -======= - orchestrator: Optional[str], client_id: Optional[str], secret_key: Optional[str] -) -> FastAPI: - # FastAPI setup - if orchestrator is None: ->>>>>>> 493e7c5 (update interface):langchain_tools_demo/main.py raise HTTPException(status_code=500, detail="Orchestrator not found") app = FastAPI(lifespan=lifespan) app.state.client_id = client_id - app.state.orchestrator = createOrchestrator(orchestrator) + app.state.orchestrator = createOrchestrator(orchestration_type) app.include_router(routes) app.mount("/static", StaticFiles(directory="static"), name="static") app.add_middleware(SessionMiddleware, secret_key=secret_key) diff --git a/llm_demo/example-config.yml b/llm_demo/example-config.yml deleted file mode 100644 index 24465a79..00000000 --- a/llm_demo/example-config.yml +++ /dev/null @@ -1,4 +0,0 @@ -host: 0.0.0.0 -port: 8081 -orchestration: - kind: ${LLM_FRAMEWORK} diff --git a/llm_demo/int.tests.cloudbuild.yaml b/llm_demo/int.tests.cloudbuild.yaml index d6d162e3..7af70c2a 100644 --- a/llm_demo/int.tests.cloudbuild.yaml +++ b/llm_demo/int.tests.cloudbuild.yaml @@ -12,16 +12,6 @@ # limitations under the License. steps: - - id: "Update config" - name: python:3.11 - dir: langchain_tools_demo - entrypoint: /bin/bash - args: - - "-c" - - | - # Create config - cp example-config.yml config.yml - - id: "Deploy to Cloud Run" name: "gcr.io/cloud-builders/gcloud:latest" dir: llm_demo diff --git a/llm_demo/orchestrator/__init__.py b/llm_demo/orchestrator/__init__.py index c338a248..dd54d48c 100644 --- a/llm_demo/orchestrator/__init__.py +++ b/llm_demo/orchestrator/__init__.py @@ -16,4 +16,4 @@ from . import vertexai_function_calling from .orchestrator import BaseOrchestrator, createOrchestrator -__ALL__ = ["BaseOrchestrator", "createOrchestrator", "langchain_tools"] +__ALL__ = ["BaseOrchestrator", "createOrchestrator", "langchain_tools", "vertexai_function_calling"] diff --git a/llm_demo/orchestrator/langchain_tools/__init__.py b/llm_demo/orchestrator/langchain_tools/__init__.py index cfefd3fc..0633529c 100644 --- a/llm_demo/orchestrator/langchain_tools/__init__.py +++ b/llm_demo/orchestrator/langchain_tools/__init__.py @@ -14,8 +14,4 @@ from .langchain_tools_orchestrator import LangChainToolsOrchestrator -<<<<<<<< HEAD:llm_demo/orchestrator/langchain_tools/__init__.py __ALL__ = ["LangChainToolsOrchestrator"] -======== -__ALL__ = [LangChainToolsOrchestrator] ->>>>>>>> 6daec6e (update interface and resolve comments):langchain_tools_demo/orchestrator/langchain_tools/__init__.py diff --git a/langchain_tools_demo/orchestrator/vertexai_function_calling/__init__.py b/llm_demo/orchestrator/vertexai_function_calling/__init__.py similarity index 100% rename from langchain_tools_demo/orchestrator/vertexai_function_calling/__init__.py rename to llm_demo/orchestrator/vertexai_function_calling/__init__.py diff --git a/langchain_tools_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py similarity index 78% rename from langchain_tools_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py rename to llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py index 03f25a40..97715e4a 100644 --- a/langchain_tools_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py +++ b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py @@ -12,60 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import os import uuid from datetime import date -from typing import Any, Dict +from typing import Any, Dict, List, Optional from aiohttp import ClientSession, TCPConnector from fastapi import HTTPException -from google.auth.transport.requests import Request # type: ignore -from google.protobuf.json_format import MessageToDict -from vertexai.preview.generative_models import ( # type: ignore - ChatSession, - GenerationResponse, - GenerativeModel, - Part, -) +from langchain.agents import AgentType, initialize_agent +from langchain.agents.agent import AgentExecutor +from langchain.globals import set_verbose # type: ignore +from langchain.memory import ChatMessageHistory, ConversationBufferMemory +from langchain.prompts.chat import ChatPromptTemplate +from langchain.tools import StructuredTool +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_google_vertexai import VertexAI from ..orchestrator import BaseOrchestrator, classproperty -from .functions import assistant_tool, function_request +from .tools import initialize_tools +set_verbose(bool(os.getenv("DEBUG", default=False))) MODEL = "gemini-pro" -BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") BASE_HISTORY = { "type": "ai", - "data": {"content": "I am an SFO Airport Assistant, ready to assist you."}, + "data": {"content": "Welcome to Cymbal Air! How may I assist you?"}, } -CREDENTIALS = None - - -def get_id_token(): - global CREDENTIALS - if CREDENTIALS is None: - CREDENTIALS, _ = google.auth.default() - if not hasattr(CREDENTIALS, "id_token"): - # Use Compute Engine default credential - CREDENTIALS = compute_engine.IDTokenCredentials( - request=Request(), - target_audience=BASE_URL, - use_metadata_identity_endpoint=True, - ) - if not CREDENTIALS.valid: - CREDENTIALS.refresh(Request()) - if hasattr(CREDENTIALS, "id_token"): - return CREDENTIALS.id_token - else: - return CREDENTIALS.token - - -def get_headers(client: ClientSession): - """Helper method to generate ID tokens for authenticated requests""" - headers = client.headers - if not "http://" in BASE_URL: - # Append ID Token to make authenticated requests to Cloud Run services - headers["Authorization"] = f"Bearer {get_id_token()}" - return headers class UserChatModel: diff --git a/langchain_tools_demo/orchestrator/vertexai_function_calling/functions.py b/llm_demo/orchestrator/vertexai_function_calling/functions.py similarity index 100% rename from langchain_tools_demo/orchestrator/vertexai_function_calling/functions.py rename to llm_demo/orchestrator/vertexai_function_calling/functions.py From 234d698fa119bfd8ec394d79a0c2335f5067ed51 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Mon, 4 Mar 2024 10:59:15 -0800 Subject: [PATCH 15/18] resolve comments --- .../langchain_tools_orchestrator.py | 11 ++- llm_demo/orchestrator/orchestrator.py | 2 + .../function_calling_orchestrator.py | 81 +++++++++++++------ 3 files changed, 64 insertions(+), 30 deletions(-) diff --git a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py index be147600..4b020caa 100644 --- a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py +++ b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py @@ -33,7 +33,6 @@ from .tools import initialize_tools set_verbose(bool(os.getenv("DEBUG", default=False))) -MODEL = "gemini-pro" BASE_HISTORY = { "type": "ai", "data": {"content": "Welcome to Cymbal Air! How may I assist you?"}, @@ -55,8 +54,9 @@ def initialize_agent( tools: List[StructuredTool], history: List[BaseMessage], prompt: ChatPromptTemplate, + model: str ) -> "UserAgent": - llm = VertexAI(max_output_tokens=512, model_name=MODEL) + llm = VertexAI(max_output_tokens=512, model_name=model) memory = ConversationBufferMemory( chat_memory=ChatMessageHistory(messages=history), memory_key="chat_history", @@ -88,10 +88,13 @@ async def invoke(self, prompt: str) -> Dict[str, Any]: class LangChainToolsOrchestrator(BaseOrchestrator): - _user_sessions: Dict[str, UserAgent] = {} + _user_sessions: Dict[str, UserAgent] # aiohttp context connector = None + def __init__(): + self._user_sessions = {} + @classproperty def kind(cls): return "langchain-tools" @@ -111,7 +114,7 @@ async def user_session_create(self, session: dict[str, Any]): client = await self.create_client_session() tools = await initialize_tools(client) prompt = self.create_prompt_template(tools) - agent = UserAgent.initialize_agent(client, tools, history, prompt) + agent = UserAgent.initialize_agent(client, tools, history, prompt, self.MODEL) self._user_sessions[id] = agent async def user_session_invoke(self, uuid: str, prompt: str) -> str: diff --git a/llm_demo/orchestrator/orchestrator.py b/llm_demo/orchestrator/orchestrator.py index b0490450..c7e7d0d4 100644 --- a/llm_demo/orchestrator/orchestrator.py +++ b/llm_demo/orchestrator/orchestrator.py @@ -26,6 +26,8 @@ def __get__(self, instance, owner): class BaseOrchestrator(ABC): + MODEL = "gemini-pro" + @classproperty @abstractmethod def kind(cls): diff --git a/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py index 97715e4a..62010088 100644 --- a/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py +++ b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py @@ -16,24 +16,23 @@ import os import uuid from datetime import date -from typing import Any, Dict, List, Optional +from typing import Any, Dict from aiohttp import ClientSession, TCPConnector from fastapi import HTTPException -from langchain.agents import AgentType, initialize_agent -from langchain.agents.agent import AgentExecutor -from langchain.globals import set_verbose # type: ignore -from langchain.memory import ChatMessageHistory, ConversationBufferMemory -from langchain.prompts.chat import ChatPromptTemplate -from langchain.tools import StructuredTool -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage -from langchain_google_vertexai import VertexAI +from google.auth.transport.requests import Request # type: ignore +from google.protobuf.json_format import MessageToDict +from vertexai.preview.generative_models import ( + ChatSession, + GenerationResponse, + GenerativeModel, + Part, +) from ..orchestrator import BaseOrchestrator, classproperty -from .tools import initialize_tools +from .functions import assistant_tool -set_verbose(bool(os.getenv("DEBUG", default=False))) -MODEL = "gemini-pro" +DEBUG = os.getenv("DEBUG", default=False) BASE_HISTORY = { "type": "ai", "data": {"content": "Welcome to Cymbal Air! How may I assist you?"}, @@ -49,24 +48,26 @@ def __init__(self, client: ClientSession, chat: ChatSession): self.chat = chat @classmethod - def initialize_chat_model(cls, client: ClientSession) -> "UserChatModel": - model = GenerativeModel(MODEL, tools=[assistant_tool()]) - function_calling_session = model.start_chat() + def initialize_chat_model( + cls, client: ClientSession, model: str + ) -> "UserChatModel": + chat_model = GenerativeModel(model, tools=[assistant_tool()]) + function_calling_session = chat_model.start_chat() return UserChatModel(client, function_calling_session) async def close(self): await self.client.close() - async def invoke(self, prompt: str) -> Dict[str, Any]: - today_date = date.today().strftime("%Y-%m-%d") - today = f"Today is {today_date}." - model_response = self.request_chat_model(prompt + today) - print(f"function call response:\n{model_response}") + async def invoke(self, input_prompt: str) -> Dict[str, Any]: + prompt = self.get_prompt() + model_response = self.request_chat_model(prompt + input_prompt) + self.debug_log(f"Prompt:\n{prompt}.\nQuestion: {input_prompt}.") + self.debug_log(f"Function call response:\n{model_response}") part_response = model_response.candidates[0].content.parts[0] while "function_call" in part_response._raw_part: function_call = MessageToDict(part_response.function_call._pb) function_response = await self.request_function(function_call) - print(f"function response:\n{function_response}") + self.debug_log(f"Function response:\n{function_response}") part = Part.from_function_response( name=function_call["name"], response={ @@ -77,13 +78,22 @@ async def invoke(self, prompt: str) -> Dict[str, Any]: part_response = model_response.candidates[0].content.parts[0] if "text" in part_response._raw_part: content = part_response.text - print(f"output content: {content}") + self.debug_log(f"Output content: {content}") return {"output": content} else: raise HTTPException( status_code=500, detail="Error: Chat model response unknown" ) + def get_prompt(self) -> str: + today_date = date.today().strftime("%Y-%m-%d") + prompt = f"{PREFIX}. Today is {today_date}." + return prompt + + def debug_log(self, output: str) -> None: + if DEBUG: + print(output) + def request_chat_model(self, prompt: str): try: model_response = self.chat.send_message(prompt) @@ -94,8 +104,7 @@ def request_chat_model(self, prompt: str): async def request_function(self, function_call): url = function_request(function_call["name"]) params = function_call["args"] - print(f"function url is {url}") - print(f"params is {params}") + self.debug_log(f"Function url is {url}.\nParams is {params}.") response = await self.client.get( url=f"{BASE_URL}/{url}", params=params, @@ -106,10 +115,13 @@ async def request_function(self, function_call): class FunctionCallingOrchestrator(BaseOrchestrator): - _user_sessions: Dict[str, UserChatModel] = {} + _user_sessions: Dict[str, UserChatModel] # aiohttp context connector = None + def __init__(): + self._user_sessions = {} + @classproperty def kind(cls): return "vertexai-function-calling" @@ -126,7 +138,7 @@ async def user_session_create(self, session: dict[str, Any]): if "history" not in session: session["history"] = [BASE_HISTORY] client = await self.create_client_session() - chat = UserChatModel.initialize_chat_model(client) + chat = UserChatModel.initialize_chat_model(client, self.MODEL) self._user_sessions[id] = chat async def user_session_invoke(self, uuid: str, prompt: str) -> str: @@ -161,3 +173,20 @@ def close_clients(self): asyncio.create_task(a.close()) for a in self._user_sessions.values() ] asyncio.gather(*close_client_tasks) + + +PREFIX = """The Cymbal Air Customer Service Assistant helps customers of Cymbal Air with their travel needs. + +Cymbal Air is a passenger airline offering convenient flights to many cities around the world from its +hub in San Francisco. Cymbal Air takes pride in using the latest technology to offer the best customer +service! + +Cymbal Air Customer Service Assistant (or just "Assistant" for short) is designed to assist +with a wide range of tasks, from answering simple questions to complex multi-query questions that +require passing results from one query to another. Using the latest AI models, Assistant is able to +generate human-like text based on the input it receives, allowing it to engage in natural-sounding +conversations and provide responses that are coherent and relevant to the topic at hand. + +Assistant is a powerful tool that can help answer a wide range of questions pertaining to travel on Cymbal Air +as well as ammenities of San Francisco Airport. +""" From 2a5896d1c0daf62fe1afff4e44aced668499a581 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Tue, 5 Mar 2024 08:28:48 -0800 Subject: [PATCH 16/18] update lint --- llm_demo/orchestrator/__init__.py | 10 +++++++--- .../langchain_tools/langchain_tools_orchestrator.py | 4 ++-- .../function_calling_orchestrator.py | 9 ++++++--- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/llm_demo/orchestrator/__init__.py b/llm_demo/orchestrator/__init__.py index dd54d48c..ac65f8ff 100644 --- a/llm_demo/orchestrator/__init__.py +++ b/llm_demo/orchestrator/__init__.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import langchain_tools -from . import vertexai_function_calling +from . import langchain_tools, vertexai_function_calling from .orchestrator import BaseOrchestrator, createOrchestrator -__ALL__ = ["BaseOrchestrator", "createOrchestrator", "langchain_tools", "vertexai_function_calling"] +__ALL__ = [ + "BaseOrchestrator", + "createOrchestrator", + "langchain_tools", + "vertexai_function_calling", +] diff --git a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py index 4b020caa..9d73820e 100644 --- a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py +++ b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py @@ -54,7 +54,7 @@ def initialize_agent( tools: List[StructuredTool], history: List[BaseMessage], prompt: ChatPromptTemplate, - model: str + model: str, ) -> "UserAgent": llm = VertexAI(max_output_tokens=512, model_name=model) memory = ConversationBufferMemory( @@ -92,7 +92,7 @@ class LangChainToolsOrchestrator(BaseOrchestrator): # aiohttp context connector = None - def __init__(): + def __init__(self): self._user_sessions = {} @classproperty diff --git a/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py index 62010088..ad964194 100644 --- a/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py +++ b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py @@ -22,7 +22,7 @@ from fastapi import HTTPException from google.auth.transport.requests import Request # type: ignore from google.protobuf.json_format import MessageToDict -from vertexai.preview.generative_models import ( +from vertexai.preview.generative_models import ( # type: ignore ChatSession, GenerationResponse, GenerativeModel, @@ -30,8 +30,9 @@ ) from ..orchestrator import BaseOrchestrator, classproperty -from .functions import assistant_tool +from .functions import assistant_tool, function_request +BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") DEBUG = os.getenv("DEBUG", default=False) BASE_HISTORY = { "type": "ai", @@ -64,6 +65,8 @@ async def invoke(self, input_prompt: str) -> Dict[str, Any]: self.debug_log(f"Prompt:\n{prompt}.\nQuestion: {input_prompt}.") self.debug_log(f"Function call response:\n{model_response}") part_response = model_response.candidates[0].content.parts[0] + + # implement multi turn chat with while loop while "function_call" in part_response._raw_part: function_call = MessageToDict(part_response.function_call._pb) function_response = await self.request_function(function_call) @@ -119,7 +122,7 @@ class FunctionCallingOrchestrator(BaseOrchestrator): # aiohttp context connector = None - def __init__(): + def __init__(self): self._user_sessions = {} @classproperty From 99850f16c9ffbe66fd9566513d866c61018399be Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Tue, 5 Mar 2024 09:48:06 -0800 Subject: [PATCH 17/18] update functions --- .../function_calling_orchestrator.py | 4 +-- .../vertexai_function_calling/functions.py | 32 +++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py index ad964194..5c6b0aa4 100644 --- a/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py +++ b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py @@ -30,9 +30,8 @@ ) from ..orchestrator import BaseOrchestrator, classproperty -from .functions import assistant_tool, function_request +from .functions import BASE_URL, assistant_tool, function_request, get_headers -BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") DEBUG = os.getenv("DEBUG", default=False) BASE_HISTORY = { "type": "ai", @@ -79,6 +78,7 @@ async def invoke(self, input_prompt: str) -> Dict[str, Any]: ) model_response = self.request_chat_model(part) part_response = model_response.candidates[0].content.parts[0] + if "text" in part_response._raw_part: content = part_response.text self.debug_log(f"Output content: {content}") diff --git a/llm_demo/orchestrator/vertexai_function_calling/functions.py b/llm_demo/orchestrator/vertexai_function_calling/functions.py index 8e6c499f..61dcefcc 100644 --- a/llm_demo/orchestrator/vertexai_function_calling/functions.py +++ b/llm_demo/orchestrator/vertexai_function_calling/functions.py @@ -14,8 +14,12 @@ import os +import aiohttp from vertexai.preview import generative_models # type: ignore +BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") +CREDENTIALS = None + search_airports_func = generative_models.FunctionDeclaration( name="airports_search", description="Use this tool to list all airports matching search criteria. Takes at least one of country, city, name, or all of the above criteria. This function could also be used to search for airport information such as iata code.", @@ -88,6 +92,34 @@ ) +def get_id_token(): + global CREDENTIALS + if CREDENTIALS is None: + CREDENTIALS, _ = google.auth.default() + if not hasattr(CREDENTIALS, "id_token"): + # Use Compute Engine default credential + CREDENTIALS = compute_engine.IDTokenCredentials( + request=Request(), + target_audience=BASE_URL, + use_metadata_identity_endpoint=True, + ) + if not CREDENTIALS.valid: + CREDENTIALS.refresh(Request()) + if hasattr(CREDENTIALS, "id_token"): + return CREDENTIALS.id_token + else: + return CREDENTIALS.token + + +def get_headers(client: aiohttp.ClientSession): + """Helper method to generate ID tokens for authenticated requests""" + headers = client.headers + if not "http://" in BASE_URL: + # Append ID Token to make authenticated requests to Cloud Run services + headers["Authorization"] = f"Bearer {get_id_token()}" + return headers + + def function_request(function_call_name: str) -> str: functions_url = { "airports_search": "airports/search", From 356f3bad90be5e1018184549d5ff8e41d95443d1 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Tue, 5 Mar 2024 14:55:28 -0800 Subject: [PATCH 18/18] add ticket tools --- .../function_calling_orchestrator.py | 2 +- .../vertexai_function_calling/functions.py | 46 +++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py index 5c6b0aa4..edd82835 100644 --- a/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py +++ b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py @@ -78,7 +78,7 @@ async def invoke(self, input_prompt: str) -> Dict[str, Any]: ) model_response = self.request_chat_model(part) part_response = model_response.candidates[0].content.parts[0] - + if "text" in part_response._raw_part: content = part_response.text self.debug_log(f"Output content: {content}") diff --git a/llm_demo/orchestrator/vertexai_function_calling/functions.py b/llm_demo/orchestrator/vertexai_function_calling/functions.py index 61dcefcc..45461788 100644 --- a/llm_demo/orchestrator/vertexai_function_calling/functions.py +++ b/llm_demo/orchestrator/vertexai_function_calling/functions.py @@ -91,6 +91,48 @@ }, ) +insert_ticket_func = generative_models.FunctionDeclaration( + name="insert_ticket", + description="Use this tool to book a flight ticket for the user.", + parameters={ + "type": "object", + "properties": { + "airline": { + "type": "string", + "description": "A code for an airline service consisting of two-character airline designator.", + }, + "flight_number": { + "type": "string", + "description": "A 1 to 4 digit number of the flight.", + }, + "departure_airport": { + "type": "string", + "description": "The iata code for flight departure airport.", + }, + "arrival_airport": { + "type": "string", + "description": "The iata code for flight arrival airport.", + }, + "departure_time": { + "type": "string", + "description": "The departure time for flight.", + }, + "arrival_time": { + "type": "string", + "description": "The arrival time for flight.", + }, + }, + }, +) + +""" TODO: Remove this comment once the issue is solved (/~https://github.com/googleapis/python-aiplatform/issues/3405) +list_tickets_func = generative_models.FunctionDeclaration( + name="list_tickets", + description="Use this tool to list a user's flight tickets. This tool takes no input parameters and returns a list of current user's flight tickets.", + parameters=None, +) +""" + def get_id_token(): global CREDENTIALS @@ -126,6 +168,8 @@ def function_request(function_call_name: str) -> str: "flights_search": "flights/search", "list_flights": "flights/search", "amenities_search": "amenities/search", + "insert_ticket": "tickets/insert", + # "list_tickets": "tickets/list", } return functions_url[function_call_name] @@ -137,5 +181,7 @@ def assistant_tool(): search_amenities_func, search_flights_by_number_func, list_flights_func, + insert_ticket_func, + # list_tickets_func, ], )