From 8cefed07c5e4cc4357d08fc3a29920dc2cfabd6a Mon Sep 17 00:00:00 2001 From: Yuan <45984206+Yuan325@users.noreply.github.com> Date: Tue, 30 Jul 2024 13:51:34 -0700 Subject: [PATCH] feat: Add langgraph orchestration (#447) Langgraph is introduced in langchain v0.2.0. The legacy langchain agent (implemented in the `langchain-tools` orchestration) is deprecated in langchain v0.2.0 and will soon be removed in the newer version of langchain. This PR adds a new orchestration - langgraph --- .github/sync-repo-settings.yaml | 1 + llm_demo/app.py | 2 + llm_demo/langgraph.int.tests.cloudbuild.yaml | 80 ++++ llm_demo/orchestrator/__init__.py | 3 +- .../langchain_tools_orchestrator.py | 9 +- llm_demo/orchestrator/langgraph/__init__.py | 17 + .../langgraph/langgraph_orchestrator.py | 287 ++++++++++++ .../orchestrator/langgraph/react_graph.py | 242 ++++++++++ llm_demo/orchestrator/langgraph/tool_node.py | 122 +++++ llm_demo/orchestrator/langgraph/tools.py | 437 ++++++++++++++++++ llm_demo/orchestrator/orchestrator.py | 4 + .../function_calling_orchestrator.py | 9 +- llm_demo/requirements.txt | 2 + 13 files changed, 1212 insertions(+), 3 deletions(-) create mode 100644 llm_demo/langgraph.int.tests.cloudbuild.yaml create mode 100644 llm_demo/orchestrator/langgraph/__init__.py create mode 100644 llm_demo/orchestrator/langgraph/langgraph_orchestrator.py create mode 100644 llm_demo/orchestrator/langgraph/react_graph.py create mode 100644 llm_demo/orchestrator/langgraph/tool_node.py create mode 100644 llm_demo/orchestrator/langgraph/tools.py diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml index 81bacdff..2263b666 100644 --- a/.github/sync-repo-settings.yaml +++ b/.github/sync-repo-settings.yaml @@ -36,6 +36,7 @@ branchProtectionRules: - "retrieval-service-alloydb-pr (retrieval-app-testing)" - "retrieval-service-cloudsql-pg-pr (retrieval-app-testing)" - "llm-demo-langchain-tools-pr (retrieval-app-testing)" + - "llm-demo-langgraph-pr (retrieval-app-testing)" - "llm-demo-vertexai-fc-pr (retrieval-app-testing)" # Set team access permissionRules: diff --git a/llm_demo/app.py b/llm_demo/app.py index 6fb7d9ae..1e1468aa 100644 --- a/llm_demo/app.py +++ b/llm_demo/app.py @@ -188,6 +188,8 @@ async def decline_flight(request: Request): """Handler for LangChain chat requests""" # Note in the history, that the ticket was not booked # This is helpful in case of reloads so there doesn't seem to be a break in communication. + orchestrator = request.app.state.orchestrator + response = await orchestrator.user_session_decline_ticket(request.session["uuid"]) request.session["history"].append( {"type": "ai", "data": {"content": "Please confirm if you would like to book."}} ) diff --git a/llm_demo/langgraph.int.tests.cloudbuild.yaml b/llm_demo/langgraph.int.tests.cloudbuild.yaml new file mode 100644 index 00000000..65c81762 --- /dev/null +++ b/llm_demo/langgraph.int.tests.cloudbuild.yaml @@ -0,0 +1,80 @@ +# Copyright 2024 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
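The Cloud Build steps below smoke-test the deployed service with `curl`. For local debugging, roughly the same checks can be scripted with the `httpx` dependency this PR adds to `llm_demo/requirements.txt`. The sketch below is not part of the patch; the local URL/port and the expectation that `/chat` returns 400 without a session cookie are assumptions that mirror the test logic rather than a documented API.

```python
# Hypothetical local equivalent of the "Test Frontend" step below.
import httpx

BASE_URL = "http://127.0.0.1:8081"  # assumed local address of the running demo app

with httpx.Client(base_url=BASE_URL) as client:
    # Without visiting `/` first there is no session, so `/chat` is expected to fail.
    unauthenticated = client.post("/chat", json={"prompt": "How can you help me?"})
    assert unauthenticated.status_code == 400, unauthenticated.text

    # Visiting `/` sets the session cookie; the client's cookie jar keeps it.
    client.get("/").raise_for_status()
    chat = client.post("/chat", json={"prompt": "How can you help me?"})
    chat.raise_for_status()
    print(chat.text)
```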
+ +steps: + - id: "Deploy to Cloud Run" + name: "gcr.io/cloud-builders/gcloud:latest" + dir: llm_demo + script: | + #!/usr/bin/env bash + gcloud run deploy ${_SERVICE} \ + --source . \ + --region ${_REGION} \ + --no-allow-unauthenticated \ + --update-env-vars ORCHESTRATION_TYPE=${_ORCHESTRATION_TYPE} + + - id: "Test Frontend" + name: "gcr.io/cloud-builders/gcloud:latest" + entrypoint: /bin/bash + env: # Set env var expected by app + args: + - "-c" + - | + export URL=$(gcloud run services describe ${_SERVICE} --region ${_REGION} --format 'value(status.url)') + export ID_TOKEN=$(gcloud auth print-identity-token --audiences $$URL) + export ORCHESTRATION_TYPE=${_ORCHESTRATION_TYPE} + + # Test `/` route + curl -c cookies.txt -si --fail --show-error -H "Authorization: Bearer $$ID_TOKEN" $$URL + + # Test `/chat` route should fail + msg=$(curl -si --show-error \ + -X POST \ + -H "Authorization: Bearer $$ID_TOKEN" \ + -H 'Content-Type: application/json' \ + -d '{"prompt":"How can you help me?"}' \ + $$URL/chat) + + if grep -q "400" <<< "$msg"; then + echo "Chat Handler Test: PASSED" + else + echo "Chat Handler Test: FAILED" + echo $msg && exit 1 + fi + + # Test `/chat` route + curl -b cookies.txt -si --fail --show-error \ + -X POST \ + -H "Authorization: Bearer $$ID_TOKEN" \ + -H 'Content-Type: application/json' \ + -d '{"prompt":"How can you help me?"}' \ + $$URL/chat + + - id: "Delete image and service" + name: "gcr.io/cloud-builders/gcloud" + script: | + #!/usr/bin/env bash + gcloud artifacts docker images delete $_GCR_HOSTNAME/$PROJECT_ID/cloud-run-source-deploy/$_SERVICE --quiet + gcloud run services delete ${_SERVICE} --region ${_REGION} --quiet + +serviceAccount: "projects/$PROJECT_ID/serviceAccounts/548341735270-compute@developer.gserviceaccount.com" # Necessary for ID token creation +options: + automapSubstitutions: true + logging: CLOUD_LOGGING_ONLY # Necessary for custom service account + dynamic_substitutions: true + +substitutions: + _GCR_HOSTNAME: ${_REGION}-docker.pkg.dev + _SERVICE: demo-service-${BUILD_ID} + _REGION: us-central1 + _ORCHESTRATION_TYPE: langgraph diff --git a/llm_demo/orchestrator/__init__.py b/llm_demo/orchestrator/__init__.py index ac65f8ff..cc4edd67 100644 --- a/llm_demo/orchestrator/__init__.py +++ b/llm_demo/orchestrator/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import langchain_tools, vertexai_function_calling +from . 
import langchain_tools, langgraph, vertexai_function_calling
 from .orchestrator import BaseOrchestrator, createOrchestrator
 
 __ALL__ = [
@@ -20,4 +20,5 @@
     "createOrchestrator",
     "langchain_tools",
     "vertexai_function_calling",
+    "langgraph",
 ]
diff --git a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py
index 62a2b1c6..b3af94c8 100644
--- a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py
+++ b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py
@@ -16,7 +16,7 @@
 import os
 import uuid
 from datetime import datetime
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 from aiohttp import ClientSession, TCPConnector
 from fastapi import HTTPException
@@ -127,6 +127,13 @@ async def user_session_insert_ticket(self, uuid: str, params: str) -> Any:
         response = await user_session.insert_ticket(params)
         return response
 
+    async def user_session_decline_ticket(self, uuid: str) -> Optional[dict[str, Any]]:
+        """
+        Used if there's a process to be done after the user declines a ticket.
+        Return None if nothing needs to be done.
+        """
+        return None
+
     async def check_and_add_confirmations(self, response: Dict[str, Any]):
         for step in response.get("intermediate_steps") or []:
             if len(step) > 0:
diff --git a/llm_demo/orchestrator/langgraph/__init__.py b/llm_demo/orchestrator/langgraph/__init__.py
new file mode 100644
index 00000000..f7e7d700
--- /dev/null
+++ b/llm_demo/orchestrator/langgraph/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .langgraph_orchestrator import LangGraphOrchestrator
+
+__ALL__ = ["LangGraphOrchestrator"]
diff --git a/llm_demo/orchestrator/langgraph/langgraph_orchestrator.py b/llm_demo/orchestrator/langgraph/langgraph_orchestrator.py
new file mode 100644
index 00000000..3c4f1914
--- /dev/null
+++ b/llm_demo/orchestrator/langgraph/langgraph_orchestrator.py
@@ -0,0 +1,287 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
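The orchestrator that follows keeps one LangGraph thread per browser session: the session UUID becomes the `thread_id` in the `RunnableConfig` (see `get_config`), so the `MemorySaver` checkpointer isolates each user's history, and invoking the compiled graph with `None` resumes a thread from its last checkpoint (which is how the booking confirmation and decline flows complete). A self-contained sketch of that thread-per-session pattern, not part of the patch, with a trivial node standing in for the real agent graph:

```python
# Minimal sketch: per-session threads with MemorySaver (echo node instead of an LLM).
from typing import Annotated, Sequence, TypedDict

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langgraph.checkpoint import MemorySaver
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages


class State(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]


def agent(state: State):
    # Stand-in for the real agent node: just reports how many messages it has seen.
    return {"messages": [AIMessage(content=f"{len(state['messages'])} messages so far")]}


graph = StateGraph(State)
graph.add_node("agent", agent)
graph.set_entry_point("agent")
graph.add_edge("agent", END)
app = graph.compile(checkpointer=MemorySaver())

# The orchestrator derives this config from the session UUID.
config = {"configurable": {"thread_id": "session-uuid"}}
app.invoke({"messages": [HumanMessage(content="hello")]}, config=config)
# Same thread_id: the checkpointed history is restored and extended.
app.invoke({"messages": [HumanMessage(content="hello again")]}, config=config)
```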
+ +import asyncio +import os +import uuid +from datetime import datetime +from typing import Annotated, Any, Dict, List, Literal, Optional, Sequence, TypedDict + +from aiohttp import ClientSession, TCPConnector +from fastapi import HTTPException +from langchain.globals import set_verbose # type: ignore +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import RunnableConfig, RunnableLambda +from langchain_core.tools import StructuredTool +from langgraph.checkpoint import MemorySaver +from langgraph.checkpoint.base import empty_checkpoint +from pytz import timezone + +from ..orchestrator import BaseOrchestrator, classproperty +from .react_graph import create_graph +from .tools import initialize_tools + +DEBUG = bool(os.getenv("DEBUG", default=False)) +set_verbose(DEBUG) +BASE_HISTORY = { + "type": "ai", + "data": {"content": "Welcome to Cymbal Air! How may I assist you?"}, +} + + +class LangGraphOrchestrator(BaseOrchestrator): + _user_sessions: Dict[str, str] + # aiohttp context + connector = None + + def __init__(self): + self._user_sessions = {} + self._langgraph_app = None + self._checkpointer = None + + @classproperty + def kind(cls): + return "langgraph" + + def user_session_exist(self, uuid: str) -> bool: + return uuid in self._user_sessions + + async def user_session_insert_ticket(self, uuid: str, params: str) -> Any: + response = await self.user_session_invoke(uuid, None) + return "ticket booking success" + + async def user_session_decline_ticket(self, uuid: str) -> dict[str, Any]: + config = self.get_config(uuid) + human_message = HumanMessage( + content="I changed my mind. Decline ticket booking." + ) + self._langgraph_app.update_state(config, {"messages": [human_message]}) + response = await self.user_session_invoke(uuid, None) + return response + + async def user_session_create(self, session: dict[str, Any]): + """Create and load an agent executor with tools and LLM.""" + client = await self.create_client_session() + if self._langgraph_app is None: + print("Initializing graph..") + tools = await initialize_tools(client) + prompt = self.create_prompt_template(tools) + checkpointer = MemorySaver() + langgraph_app = await create_graph( + tools, checkpointer, prompt, self.MODEL, client, DEBUG + ) + self._checkpointer = checkpointer + self._langgraph_app = langgraph_app + + print("Initializing session") + if "uuid" not in session: + session["uuid"] = str(uuid.uuid4()) + session_id = session["uuid"] + if "history" not in session: + session["history"] = [BASE_HISTORY] + history = self.parse_messages(session["history"]) + + config = self.get_config(session_id) + self._langgraph_app.update_state(config, {"messages": history}) + self._user_sessions[session_id] = "" + self.client = client + + async def user_session_invoke( + self, uuid: str, user_prompt: Optional[str] + ) -> dict[str, Any]: + config = self.get_config(uuid) + if user_prompt: + user_query = [HumanMessage(content=user_prompt)] + app_input = { + "messages": user_query, + "user_id_token": self.get_user_id_token(uuid), + } + else: + app_input = None + final_state = await self._langgraph_app.ainvoke( + app_input, + config=config, + ) + last_message = final_state["messages"][-1] + output = last_message.content + # Build final response + response = {} + response["output"] = output + # If needs ticket verification + has_add_kwargs = hasattr(last_message, "additional_kwargs") + if has_add_kwargs and 
last_message.additional_kwargs.get("confirmation"): + tool_call = last_message.tool_calls[0] + response["confirmation"] = { + "tool": tool_call.get("name"), + "params": tool_call.get("args"), + } + return response + response["state"] = final_state + return response + + def user_session_reset(self, session: dict[str, Any], uuid: str): + del session["history"] + base_history = self.get_base_history(session) + session["history"] = [base_history] + history = self.parse_messages(session["history"]) + + # Reset graph checkpointer + checkpoint = empty_checkpoint() + config = self.get_config(uuid) + self._checkpointer.put(config=config, checkpoint=checkpoint, metadata={}) + + # Update state with message history + self._langgraph_app.update_state(config, {"messages": history}) + + def get_user_session(self, uuid: str): + raise NotImplementedError("Irrelevant to LangGraph.") + + def set_user_session_header(self, uuid: str, user_id_token: str): + self._user_sessions[uuid] = user_id_token + + def get_user_id_token(self, uuid: str) -> Optional[str]: + return self._user_sessions.get(uuid) + + async def get_connector(self) -> TCPConnector: + if self.connector is None: + self.connector = TCPConnector(limit=100) + return self.connector + + async def create_client_session(self) -> ClientSession: + return ClientSession( + connector=await self.get_connector(), + connector_owner=False, + headers={}, + raise_for_status=True, + ) + + def create_prompt_template(self, tools: List[StructuredTool]) -> ChatPromptTemplate: + # Create new prompt template + tool_strings = "\n".join( + [f"> {tool.name}: {tool.description}" for tool in tools] + ) + tool_names = ", ".join([tool.name for tool in tools]) + format_instructions = FORMAT_INSTRUCTIONS.format( + tool_names=tool_names, + ) + current_datetime = "Today's date and current time is {cur_datetime}." + template = "\n\n".join( + [ + PREFIX, + current_datetime, + TOOLS_PREFIX, + tool_strings, + format_instructions, + SUFFIX, + ] + ) + + prompt = ChatPromptTemplate.from_messages( + [("system", template), ("placeholder", "{messages}")] + ) + prompt = prompt.partial(cur_datetime=self.get_datetime) + return prompt + + def get_datetime(self): + formatter = "%A, %m/%d/%Y, %H:%M:%S" + now = datetime.now(timezone("US/Pacific")) + return now.strftime(formatter) + + def parse_messages(self, datas: List[Any]) -> List[BaseMessage]: + messages: List[BaseMessage] = [] + for data in datas: + if data["type"] == "human": + messages.append(HumanMessage(content=data["data"]["content"])) + elif data["type"] == "ai": + messages.append(AIMessage(content=data["data"]["content"])) + else: + raise Exception("Message type not found.") + return messages + + def get_base_history(self, session: dict[str, Any]): + if "user_info" in session: + base_history = { + "type": "ai", + "data": { + "content": f"Welcome to Cymbal Air, {session['user_info']['name']}! How may I assist you?" + }, + } + return base_history + return BASE_HISTORY + + def get_config(self, uuid: str): + return {"configurable": {"thread_id": uuid}} + + async def user_session_signout(self, uuid: str): + checkpoint = empty_checkpoint() + config = self.get_config(uuid) + self._checkpointer.put(config=config, checkpoint=checkpoint, metadata={}) + del self._user_sessions[uuid] + + def close_clients(self): + self.client.close() + + +PREFIX = """The Cymbal Air Customer Service Assistant helps customers of Cymbal Air with their travel needs. 
+
+Cymbal Air (unique two-letter airline identifier: CY) is a passenger airline offering convenient flights to many cities around the world from its
+hub in San Francisco. Cymbal Air takes pride in using the latest technology to offer the best customer
+service!
+
+Cymbal Air Customer Service Assistant (or just "Assistant" for short) is designed to assist
+with a wide range of tasks, from answering simple questions to complex multi-query questions that
+require passing results from one query to another. Using the latest AI models, Assistant is able to
+generate human-like text based on the input it receives, allowing it to engage in natural-sounding
+conversations and provide responses that are coherent and relevant to the topic at hand. The assistant should
+not answer questions about other people's information for privacy reasons.
+
+Assistant is a powerful tool that can help answer a wide range of questions pertaining to travel on Cymbal Air
+as well as amenities of San Francisco Airport."""
+
+TOOLS_PREFIX = """
+TOOLS:
+------
+Assistant can ask the user to use tools to look up information that may be helpful in answering the user's original question. The tools the human can use are:
+
+"""
+
+FORMAT_INSTRUCTIONS = """
+When responding, please output a response in one of two formats:
+
+**Option 1:**
+Use this if you want to use a tool.
+Markdown code snippet formatted in the following schema:
+```json
+{{{{
+  "action": string, \ The action to take. Must be one of {tool_names}
+  "action_input": string \ The input to the action
+}}}}
+```
+
+**Option 2:**
+Use this if you want to respond directly to the human.
+Markdown code snippet formatted in the following schema:
+```json
+{{{{
+  "action": "Final Answer",
+  "action_input": string \ You should put what you want to return to the user here
+}}}}
+```
+"""
+
+SUFFIX = """Begin! Use tools if necessary. Respond directly if appropriate.
+
+Remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else.
+"""
diff --git a/llm_demo/orchestrator/langgraph/react_graph.py b/llm_demo/orchestrator/langgraph/react_graph.py
new file mode 100644
index 00000000..0c39832e
--- /dev/null
+++ b/llm_demo/orchestrator/langgraph/react_graph.py
@@ -0,0 +1,242 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
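`react_graph.py` below prompts the model (via `FORMAT_INSTRUCTIONS` above) to reply with a fenced JSON `action`/`action_input` blob, and `acall_model` turns that blob into either a plain `AIMessage` or one carrying a `ToolCall`. A self-contained sketch of that parsing step, not part of the patch, isolated so it can be run without a model:

```python
# Standalone version of the fence-stripping and JSON-action parsing done in acall_model.
import json
import uuid

from langchain_core.messages import AIMessage, ToolCall


def parse_action(raw: str) -> AIMessage:
    text = raw.replace("```json", "").replace("```", "")
    try:
        blob = json.loads(text)
        action = blob.get("action")
        action_input = blob.get("action_input")
    except Exception:
        # Same fallback the graph uses when the model breaks the format.
        return AIMessage(content="Sorry, failed to generate the right format for response")
    if action == "Final Answer":
        return AIMessage(content=action_input)
    return AIMessage(
        content="suggesting a tool call",
        tool_calls=[ToolCall(id=str(uuid.uuid4()), name=action, args=action_input)],
    )


print(parse_action('```json\n{"action": "Final Answer", "action_input": "Hi there!"}\n```'))
```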
+ +import json +import uuid +from typing import Annotated, Literal, Sequence, TypedDict + +from aiohttp import ClientSession +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + ToolCall, + ToolMessage, +) +from langchain_core.prompts.chat import ChatPromptTemplate +from langchain_core.runnables import RunnableConfig, RunnableLambda +from langchain_google_vertexai import VertexAI +from langgraph.checkpoint import MemorySaver +from langgraph.graph import END, StateGraph +from langgraph.graph.message import add_messages +from langgraph.managed import IsLastStep + +from .tool_node import ToolNode +from .tools import get_confirmation_needing_tools, insert_ticket, validate_ticket + + +class UserState(TypedDict): + """ + State with messages and ClientSession for each session/user. + """ + + messages: Annotated[Sequence[BaseMessage], add_messages] + user_id_token: str + is_last_step: IsLastStep + + +async def create_graph( + tools, + checkpointer: MemorySaver, + prompt: ChatPromptTemplate, + model_name: str, + client: ClientSession, + debug: bool, +): + """ + Creates a graph that works with a chat model that utilizes tool calling. + + Args: + tools: A list of StructuredTools that will bind with the chat model. + checkpointer: The checkpoint saver object. This is useful for persisting + the state of the graph (e.g., as chat memory). + prompt: Initial prompt for the model. This applies to messages before they + are passed into the LLM. + model_name: The chat model name. + + Returns: + A compilled LangChain runnable that can be used for chat interactions. + + The resulting graph looks like this: + [*] --> Start + Start --> Agent + Agent --> Tools : continue + Tools --> Agent + Agent --> End : end + End --> [*] + """ + # tool node + tool_node = ToolNode(tools) + + # model node + model = VertexAI(max_output_tokens=512, model_name=model_name, temperature=0.0) + + # Add the prompt to the model to create a model runnable + model_runnable = prompt | model + + async def acall_model(state: UserState, config: RunnableConfig): + """ + The node representing async function that calls the model. + After invoking model, it will return AIMessage back to the user. + """ + messages = state["messages"] + res = await model_runnable.ainvoke({"messages": messages}, config) + response = res.replace("```json", "").replace("```", "") + try: + json_response = json.loads(response) + action = json_response.get("action") + action_input = json_response.get("action_input") + if action == "Final Answer": + new_message = AIMessage(content=action_input) + else: + new_message = AIMessage( + content="suggesting a tool call", + tool_calls=[ + ToolCall(id=str(uuid.uuid4()), name=action, args=action_input) + ], + ) + except Exception as e: + json_response = response + new_message = AIMessage( + content="Sorry, failed to generate the right format for response" + ) + # if model exceed the number of steps and has not yet return a final answer + if state["is_last_step"] and hasattr(new_message, "tool_calls"): + return { + "messages": [ + AIMessage( + content="Sorry, need more steps to process this request.", + ) + ] + } + return {"messages": [new_message]} + + def agent_should_continue( + state: UserState, + ) -> Literal["booking_validation", "continue", "end"]: + """ + Function to determine which node is called after the agent node. 
+ """ + messages = state["messages"] + last_message = messages[-1] + # If the LLM makes a tool call, then we route to the "tools" node + if hasattr(last_message, "tool_calls") and len(last_message.tool_calls) > 0: + confirmation_needing_tools = get_confirmation_needing_tools() + for tool_call in last_message.tool_calls: + tool_name = tool_call["name"] + if tool_name in confirmation_needing_tools: + if tool_name == "Insert Ticket": + return "booking_validation" + return "continue" + # Otherwise, we stop (reply to the user) + return "end" + + async def booking_validation_node(state: UserState, config: RunnableConfig): + """ + The node representing async function that validate the ticket. + After ticket validation, it will return AIMessage with updated ticket args. + """ + messages = state["messages"] + last_message = messages[-1] + user_id_token = state["user_id_token"] + if hasattr(last_message, "tool_calls") and len(last_message.tool_calls) > 0: + tool_call = last_message.tool_calls[0] + # Run ticket validation and return the correct ticket information + flight_info = await validate_ticket( + client, tool_call.get("args"), user_id_token + ) + + new_message = AIMessage( + content="Please confirm if you would like to book the ticket.", + tool_calls=[ + ToolCall( + id=str(uuid.uuid4()), + name=tool_call.get("name"), + args=flight_info, + ) + ], + additional_kwargs={"confirmation": True}, + ) + return {"messages": [new_message]} + + def booking_should_continue(state: UserState) -> Literal["continue", "agent"]: + """ + Function to determine which node is called after human response on ticket booking. + """ + messages = state["messages"] + last_message = messages[-1] + # If last message makes a tool call, then we route to the "tools" node to proceed with booking + if hasattr(last_message, "tool_calls") and len(last_message.tool_calls) > 0: + return "continue" + # Otherwise, send response back to agent + return "agent" + + async def insert_ticket_node(state: UserState, config: RunnableConfig): + """ + Node to update human response to prevent + """ + messages = state["messages"] + last_message = messages[-1] + user_id_token = state["user_id_token"] + # Run insert ticket + if hasattr(last_message, "tool_calls") and len(last_message.tool_calls) > 0: + tool_call = last_message.tool_calls[0] + output = await insert_ticket(client, tool_call.get("args"), user_id_token) + tool_call_id = tool_call.get("id") + tool_message = ToolMessage( + content=output, name="Insert Ticket", tool_call_id=tool_call_id + ) + human_message = HumanMessage(content="Looks good to me.") + ai_message = AIMessage(content=output) + return {"messages": [human_message, tool_message, ai_message]} + + # Define constant node strings + AGENT_NODE = "agent" + TOOL_NODE = "tools" + BOOKING_VALIDATION_NODE = "booking_validation" + INSERT_TICKET_NODE = "insert_ticket" + + # Define a new graph + llm_graph = StateGraph(UserState) + llm_graph.add_node(AGENT_NODE, RunnableLambda(acall_model)) + llm_graph.add_node(TOOL_NODE, tool_node) + llm_graph.add_node(BOOKING_VALIDATION_NODE, RunnableLambda(booking_validation_node)) + llm_graph.add_node(INSERT_TICKET_NODE, RunnableLambda(insert_ticket_node)) + + # Set agent node as the first node to call + llm_graph.set_entry_point(AGENT_NODE) + + # Add edges + llm_graph.add_conditional_edges( + AGENT_NODE, + agent_should_continue, + { + "continue": TOOL_NODE, + "booking_validation": BOOKING_VALIDATION_NODE, + "end": END, + }, + ) + llm_graph.add_edge(TOOL_NODE, AGENT_NODE) + 
llm_graph.add_conditional_edges( + BOOKING_VALIDATION_NODE, + booking_should_continue, + {"continue": INSERT_TICKET_NODE, "agent": AGENT_NODE}, + ) + llm_graph.add_edge(INSERT_TICKET_NODE, END) + + # Compile graph into a LangChain Runnable + langgraph_app = llm_graph.compile( + checkpointer=checkpointer, debug=debug, interrupt_after=["booking_validation"] + ) + return langgraph_app diff --git a/llm_demo/orchestrator/langgraph/tool_node.py b/llm_demo/orchestrator/langgraph/tool_node.py new file mode 100644 index 00000000..3093cc79 --- /dev/null +++ b/llm_demo/orchestrator/langgraph/tool_node.py @@ -0,0 +1,122 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import copy +import json +import uuid +from itertools import repeat +from typing import Any, Callable, Dict, Optional, Sequence, Union + +from langchain_core.messages import AIMessage, AnyMessage, ToolCall, ToolMessage +from langchain_core.runnables import RunnableConfig +from langchain_core.runnables.config import get_executor_for_config +from langchain_core.tools import BaseTool +from langchain_core.tools import tool as create_tool +from langgraph.utils import RunnableCallable + + +def str_output(output: Any) -> str: + if isinstance(output, str): + return output + else: + try: + return json.dumps(output) + except Exception: + return str(output) + + +class ToolNode(RunnableCallable): + """ + A node that runs the tools requested in the last AIMessage. It can be used + either in StateGraph with a "messages" key or in MessageGraph. If multiple + tool calls are requested, they will be run in parallel. The output will be + a list of ToolMessages, one for each tool call. 
+ """ + + def __init__( + self, + tools: Sequence[Union[BaseTool, Callable]], + *, + name: str = "tools", + tags: Optional[list[str]] = None, + ) -> None: + super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False) + self.tools_by_name: Dict[str, BaseTool] = {} + for tool_ in tools: + if not isinstance(tool_, BaseTool): + tool_ = create_tool(tool_) + else: + base_tool_ = tool_ + if hasattr(tool_, "name"): + self.tools_by_name[tool_.name] = base_tool_ + + def _func(self, input: dict[str, Any], config: RunnableConfig) -> Any: + if messages := input.get("messages", []): + output_type = "dict" + message = messages[-1] + else: + raise ValueError("No message found in input") + + if not isinstance(message, AIMessage): + raise ValueError("Last message is not an AIMessage") + + user_id_token = input.get("user_id_token") + + def run_one(call: ToolCall, user_id_token: Optional[str]): + args = copy.copy(call["args"]) or {} + args["user_id_token"] = user_id_token + output = self.tools_by_name[call["name"]].invoke(args, config) + tool_call_id = call.get("id") or str(uuid.uuid4()) + return ToolMessage( + content=str_output(output), name=call["name"], tool_call_id=tool_call_id + ) + + with get_executor_for_config(config) as executor: + outputs = [ + *executor.map(run_one, message.tool_calls, repeat(user_id_token)) + ] + if output_type == "list": + return outputs + else: + return {"messages": outputs} + + async def _afunc(self, input: dict[str, Any], config: RunnableConfig) -> Any: + if messages := input.get("messages", []): + output_type = "dict" + message = messages[-1] + else: + raise ValueError("No message found in input") + + if not isinstance(message, AIMessage): + raise ValueError("Last message is not an AIMessage") + + user_id_token = input.get("user_id_token") + + async def run_one(call: ToolCall, user_id_token: Optional[str]): + args = copy.copy(call["args"]) or {} + args["user_id_token"] = user_id_token + output = await self.tools_by_name[call["name"]].ainvoke(args, config) + tool_call_id = call.get("id") or str(uuid.uuid4()) + return ToolMessage( + content=str_output(output), name=call["name"], tool_call_id=tool_call_id + ) + + outputs = await asyncio.gather( + *(run_one(call, user_id_token) for call in message.tool_calls) + ) + if output_type == "list": + return outputs + else: + return {"messages": outputs} diff --git a/llm_demo/orchestrator/langgraph/tools.py b/llm_demo/orchestrator/langgraph/tools.py new file mode 100644 index 00000000..5661f919 --- /dev/null +++ b/llm_demo/orchestrator/langgraph/tools.py @@ -0,0 +1,437 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +from datetime import date, datetime +from typing import Any, Dict, Optional + +import aiohttp +import google.oauth2.id_token # type: ignore +from google.auth import compute_engine # type: ignore +from google.auth.transport.requests import Request # type: ignore +from langchain_core.tools import StructuredTool +from pydantic.v1 import BaseModel, Field + +BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") +CREDENTIALS = None + + +def filter_none_values(params: Dict) -> Dict: + return {key: value for key, value in params.items() if value is not None} + + +def get_id_token(): + global CREDENTIALS + if CREDENTIALS is None: + CREDENTIALS, _ = google.auth.default() + if not hasattr(CREDENTIALS, "id_token"): + # Use Compute Engine default credential + CREDENTIALS = compute_engine.IDTokenCredentials( + request=Request(), + target_audience=BASE_URL, + use_metadata_identity_endpoint=True, + ) + if not CREDENTIALS.valid: + CREDENTIALS.refresh(Request()) + if hasattr(CREDENTIALS, "id_token"): + return CREDENTIALS.id_token + else: + return CREDENTIALS.token + + +def get_headers(client: aiohttp.ClientSession, user_id_token: str): + """Helper method to generate ID tokens for authenticated requests""" + headers = client.headers + headers["User-Id-Token"] = f"Bearer {user_id_token}" + if not "http://" in BASE_URL: + # Append ID Token to make authenticated requests to Cloud Run services + headers["Authorization"] = f"Bearer {get_id_token()}" + return headers + + +# Tools +class AirportSearchInput(BaseModel): + country: Optional[str] = Field(description="Country") + city: Optional[str] = Field(description="City") + name: Optional[str] = Field(description="Airport name") + user_id_token: Optional[str] + + +def generate_search_airports(client: aiohttp.ClientSession): + async def search_airports(country: str, city: str, name: str, user_id_token: str): + params = { + "country": country, + "city": city, + "name": name, + } + response = await client.get( + url=f"{BASE_URL}/airports/search", + params=filter_none_values(params), + headers=get_headers(client, user_id_token), + ) + + response_json = await response.json() + if len(response_json) < 1: + return "There are no airports matching that query. Let the user know there are no results." 
+ else: + return response_json + + return search_airports + + +class FlightNumberInput(BaseModel): + airline: str = Field(description="Airline unique 2 letter identifier") + flight_number: str = Field(description="1 to 4 digit number") + user_id_token: Optional[str] + + +def generate_search_flights_by_number(client: aiohttp.ClientSession): + async def search_flights_by_number( + airline: str, flight_number: str, user_id_token: str + ): + response = await client.get( + url=f"{BASE_URL}/flights/search", + params={"airline": airline, "flight_number": flight_number}, + headers=get_headers(client, user_id_token), + ) + + return await response.json() + + return search_flights_by_number + + +class ListFlightsInput(BaseModel): + departure_airport: Optional[str] = Field( + description="Departure airport 3-letter code", + ) + arrival_airport: Optional[str] = Field(description="Arrival airport 3-letter code") + date: str = Field(description="Date of flight departure") + user_id_token: Optional[str] + + +def generate_list_flights(client: aiohttp.ClientSession): + async def list_flights( + departure_airport: str, + arrival_airport: str, + date: str, + user_id_token: str, + ): + params = { + "departure_airport": departure_airport, + "arrival_airport": arrival_airport, + "date": date, + } + response = await client.get( + url=f"{BASE_URL}/flights/search", + params=filter_none_values(params), + headers=get_headers(client, user_id_token), + ) + + response_json = await response.json() + if len(response_json) < 1: + return "There are no flights matching that query. Let the user know there are no results." + else: + return response_json + + return list_flights + + +class QueryInput(BaseModel): + query: str = Field(description="Search query") + user_id_token: Optional[str] + + +def generate_search_amenities(client: aiohttp.ClientSession): + async def search_amenities(query: str, user_id_token: str): + response = await client.get( + url=f"{BASE_URL}/amenities/search", + params={"top_k": "5", "query": query}, + headers=get_headers(client, user_id_token), + ) + + response = await response.json() + return response + + return search_amenities + + +def generate_search_policies(client: aiohttp.ClientSession): + async def search_policies(query: str, user_id_token: str): + response = await client.get( + url=f"{BASE_URL}/policies/search", + params={"top_k": "5", "query": query}, + headers=get_headers(client, user_id_token), + ) + + response = await response.json() + return response + + return search_policies + + +class TicketInput(BaseModel): + airline: str = Field(description="Airline unique 2 letter identifier") + flight_number: str = Field(description="1 to 4 digit number") + departure_airport: str = Field( + description="Departure airport 3-letter code", + ) + departure_time: datetime = Field(description="Flight departure datetime") + arrival_airport: Optional[str] = Field(description="Arrival airport 3-letter code") + arrival_time: Optional[datetime] = Field(description="Flight arrival datetime") + + +def generate_insert_ticket(client: aiohttp.ClientSession): + async def insert_ticket( + airline: str, + flight_number: str, + departure_airport: str, + arrival_airport: str, + departure_time: datetime, + arrival_time: datetime, + ): + return f"Booking ticket on {airline} {flight_number}" + + return insert_ticket + + +async def insert_ticket( + client: aiohttp.ClientSession, params: dict[str, str], user_id_token: str +): + response = await client.post( + url=f"{BASE_URL}/tickets/insert", + params={ + "airline": 
params.get("airline"), + "flight_number": params.get("flight_number"), + "departure_airport": params.get("departure_airport"), + "arrival_airport": params.get("arrival_airport"), + "departure_time": params.get("departure_time"), + "arrival_time": params.get("arrival_time"), + }, + headers=get_headers(client, user_id_token), + ) + response = await response.json() + return "Flight booking successful." + + +async def validate_ticket( + client: aiohttp.ClientSession, ticket_info: Dict[Any, Any], user_id_token: str +): + response = await client.get( + url=f"{BASE_URL}/tickets/validate", + params=filter_none_values( + { + "airline": ticket_info.get("airline"), + "flight_number": ticket_info.get("flight_number"), + "departure_airport": ticket_info.get("departure_airport"), + "departure_time": ticket_info.get("departure_time", "").replace( + "T", " " + ), + } + ), + headers=get_headers(client, user_id_token), + ) + response_json = await response.json() + + flight_info = { + "airline": response_json.get("airline"), + "flight_number": response_json.get("flight_number"), + "departure_airport": response_json.get("departure_airport"), + "arrival_airport": response_json.get("arrival_airport"), + "departure_time": response_json.get("departure_time").replace("T", " "), + "arrival_time": response_json.get("arrival_time").replace("T", " "), + } + return flight_info + + +def generate_list_tickets(client: aiohttp.ClientSession): + async def list_tickets(user_id_token: str): + response = await client.get( + url=f"{BASE_URL}/tickets/list", + headers=get_headers(client, user_id_token), + ) + + response_json = await response.json() + return { + "number of tickets booked": len(response_json), + "user's ticket": response_json, + } + + return list_tickets + + +# Tools for agent +async def initialize_tools(client: aiohttp.ClientSession): + return [ + StructuredTool.from_function( + coroutine=generate_search_airports(client), + name="Search Airport", + description=""" + Use this tool to list all airports matching search criteria. + Takes at least one of country, city, name, or all and returns all matching airports. + The agent can decide to return the results directly to the user. + Input of this tool must be in JSON format and include all three inputs - country, city, name. + Example: + {{ + "country": "United States", + "city": "San Francisco", + "name": null + }} + Example: + {{ + "country": null, + "city": "Goroka", + "name": "Goroka" + }} + Example: + {{ + "country": "Mexico", + "city": null, + "name": null + }} + """, + args_schema=AirportSearchInput, + ), + StructuredTool.from_function( + coroutine=generate_search_flights_by_number(client), + name="Search Flights By Flight Number", + description=""" + Use this tool to get information for a specific flight. + Takes an airline code and flight number and returns info on the flight. + Do NOT use this tool with a flight id. Do NOT guess an airline code or flight number. + A airline code is a code for an airline service consisting of two-character + airline designator and followed by flight number, which is 1 to 4 digit number. + For example, if given CY 0123, the airline is "CY", and flight_number is "123". + Another example for this is DL 1234, the airline is "DL", and flight_number is "1234". + If the tool returns more than one option choose the date closes to today. 
+ Example: + {{ + "airline": "CY", + "flight_number": "888", + }} + Example: + {{ + "airline": "DL", + "flight_number": "1234", + }} + """, + args_schema=FlightNumberInput, + ), + StructuredTool.from_function( + coroutine=generate_list_flights(client), + name="List Flights", + description=""" + Use this tool to list flights information matching search criteria. + Takes an arrival airport, a departure airport, or both, filters by date and returns all matching flights. + If 3-letter iata code is not provided for departure_airport or arrival_airport, use search airport tools to get iata code information. + Do NOT guess a date, ask user for date input if it is not given. Date must be in the following format: YYYY-MM-DD. + The agent can decide to return the results directly to the user. + Input of this tool must be in JSON format and include all three inputs - arrival_airport, departure_airport, and date. + Example: + {{ + "departure_airport": "SFO", + "arrival_airport": null, + "date": 2023-10-30" + }} + Example: + {{ + "departure_airport": "SFO", + "arrival_airport": "SEA", + "date": "2023-11-01" + }} + Example: + {{ + "departure_airport": null, + "arrival_airport": "SFO", + "date": "2023-01-01" + }} + """, + args_schema=ListFlightsInput, + ), + StructuredTool.from_function( + coroutine=generate_search_amenities(client), + name="Search Amenities", + description=""" + Use this tool to search amenities by name or to recommended airport amenities at SFO. + If user provides flight info, use 'Search Flights by Flight Number' + first to get gate info and location. + Only recommend amenities that are returned by this query. + Find amenities close to the user by matching the terminal and then comparing + the gate numbers. Gate number iterate by letter and number, example A1 A2 A3 + B1 B2 B3 C1 C2 C3. Gate A3 is close to A2 and B1. + Input of this tool must be in JSON format and include one `query` input. + """, + args_schema=QueryInput, + ), + StructuredTool.from_function( + coroutine=generate_search_policies(client), + name="Search Policies", + description=""" + Use this tool to search for cymbal air passenger policy. + Policy that are listed is unchangeable. + You will not answer any questions outside of the policy given. + Policy includes information on ticket purchase and changes, baggage, check-in and boarding, special assistance, overbooking, flight delays and cancellations. + Input of this tool must be in JSON format and include one `query` input. + """, + args_schema=QueryInput, + ), + StructuredTool.from_function( + coroutine=generate_insert_ticket(client), + name="Insert Ticket", + description=""" + Use this tool to book a flight ticket for the user. + Example: + {{ + "airline": "AA", + "flight_number": "452", + "departure_airport": "LAX", + "arrival_airport": "SFO", + "departure_time": "2024-01-01 05:50:00", + "arrival_time": "2024-01-01 09:23:00" + }} + Example: + {{ + "airline": "UA", + "flight_number": "1532", + "departure_airport": "SFO", + "arrival_airport": "DEN", + "departure_time": "2024-01-08 05:50:00", + "arrival_time": "2024-01-08 09:23:00" + }} + Example: + {{ + "airline": "OO", + "flight_number": "6307", + "departure_airport": "SFO", + "arrival_airport": "MSP", + "departure_time": "2024-10-28 20:13:00", + "arrival_time": "2024-10-28 21:07:00" + }} + """, + args_schema=TicketInput, + ), + StructuredTool.from_function( + coroutine=generate_list_tickets(client), + name="List Tickets", + description=""" + Use this tool to list a user's flight tickets. 
+            Takes no input and returns a list of the current user's flight tickets.
+            Input is always an empty JSON blob. Example: {{}}
+            """,
+        ),
+    ]
+
+
+def get_confirmation_needing_tools():
+    return ["Insert Ticket"]
diff --git a/llm_demo/orchestrator/orchestrator.py b/llm_demo/orchestrator/orchestrator.py
index 5f7bd397..244cff75 100644
--- a/llm_demo/orchestrator/orchestrator.py
+++ b/llm_demo/orchestrator/orchestrator.py
@@ -60,6 +60,10 @@ def get_user_session(self, uuid: str) -> Any:
     async def user_session_insert_ticket(self, uuid: str, params: str) -> Any:
         raise NotImplementedError("Subclass should implement this!")
 
+    @abstractmethod
+    async def user_session_decline_ticket(self, uuid: str) -> Optional[dict[str, Any]]:
+        raise NotImplementedError("Subclass should implement this!")
+
     @abstractmethod
     async def user_session_signout(self, uuid: str):
         """Sign out from user session. Clear and restart session."""
diff --git a/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py
index 6c921c8b..102f42b6 100644
--- a/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py
+++ b/llm_demo/orchestrator/vertexai_function_calling/function_calling_orchestrator.py
@@ -16,7 +16,7 @@
 import os
 import uuid
 from datetime import datetime
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 from aiohttp import ClientSession, TCPConnector
 from fastapi import HTTPException
@@ -192,6 +192,13 @@ async def user_session_insert_ticket(self, uuid: str, params: str) -> Any:
         response = await user_session.insert_ticket(params)
         return response
 
+    async def user_session_decline_ticket(self, uuid: str) -> Optional[dict[str, Any]]:
+        """
+        Used if there's a process to be done after the user declines a ticket.
+        Return None if nothing needs to be done.
+        """
+        return None
+
     async def user_session_create(self, session: dict[str, Any]):
         """Create and load an agent executor with tools and LLM."""
         print("Initializing agent..")
diff --git a/llm_demo/requirements.txt b/llm_demo/requirements.txt
index c37b0514..e3063511 100644
--- a/llm_demo/requirements.txt
+++ b/llm_demo/requirements.txt
@@ -12,3 +12,5 @@ uvicorn[standard]==0.27.0.post1
 python-multipart==0.0.7
 pytz==2024.1
 types-pytz==2024.1.0.20240417
+langgraph==0.1.5
+httpx==0.27.0
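One note on the new pins: the code under `orchestrator/langgraph/` imports `MemorySaver`, `empty_checkpoint`, `IsLastStep`, and `RunnableCallable` from the module paths exposed by the langgraph 0.1.x line, so `langgraph==0.1.5` and those imports need to be bumped together. A quick local check, not part of the build config, that a candidate upgrade still provides them:

```python
# Run inside the llm_demo environment to confirm the import paths this PR relies on.
from langgraph.checkpoint import MemorySaver
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.managed import IsLastStep
from langgraph.utils import RunnableCallable

print("langgraph import paths OK:", MemorySaver, empty_checkpoint, IsLastStep, RunnableCallable)
```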