diff --git a/llm_demo/evaluation.cloudbuild.yaml b/llm_demo/evaluation.cloudbuild.yaml
new file mode 100644
index 00000000..7a5136a1
--- /dev/null
+++ b/llm_demo/evaluation.cloudbuild.yaml
@@ -0,0 +1,49 @@
+# Copyright 2024 Google LLC
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+steps:
+  - id: "Install dependencies"
+    name: python:3.11
+    dir: llm_demo
+    script: pip install -r requirements.txt -r requirements-test.txt --user
+
+  - id: "Run evaluation service"
+    name: python:3.11
+    dir: llm_demo
+    env:  # Env vars expected by the evaluation script
+      - "ORCHESTRATION_TYPE=${_ORCHESTRATION_TYPE}"
+      - "RETRIEVAL_EXPERIMENT_NAME=${_RETRIEVAL_EXPERIMENT_NAME}"
+      - "RESPONSE_EXPERIMENT_NAME=${_RESPONSE_EXPERIMENT_NAME}"
+    secretEnv:
+      - CLIENT_ID
+      - BASE_URL
+    script: |
+      #!/usr/bin/env bash
+      python run_evaluation.py
+
+serviceAccount: "projects/$PROJECT_ID/serviceAccounts/evaluation-testing@retrieval-app-testing.iam.gserviceaccount.com"  # Necessary for ID token creation
+options:
+  logging: CLOUD_LOGGING_ONLY  # Necessary for custom service account
+  dynamic_substitutions: true
+
+substitutions:
+  _ORCHESTRATION_TYPE: "langchain-tools"
+  _RETRIEVAL_EXPERIMENT_NAME: "retrieval-phase-eval-${_PR_NUMBER}"
+  _RESPONSE_EXPERIMENT_NAME: "response-phase-eval-${_PR_NUMBER}"
+
+availableSecrets:
+  secretManager:
+    - versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
+      env: CLIENT_ID
+    - versionName: projects/$PROJECT_ID/secrets/retrieval_url/versions/latest
+      env: BASE_URL
diff --git a/llm_demo/evaluation/__init__.py b/llm_demo/evaluation/__init__.py
new file mode 100644
index 00000000..cb34a895
--- /dev/null
+++ b/llm_demo/evaluation/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .eval_golden import goldens
+from .evaluation import (
+    evaluate_response_phase,
+    evaluate_retrieval_phase,
+    run_llm_for_eval,
+)
+
+__all__ = [
+    "run_llm_for_eval",
+    "goldens",
+    "evaluate_retrieval_phase",
+    "evaluate_response_phase",
+]
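For orientation, a minimal sketch of how these exports are meant to be consumed. It mirrors the flow in run_evaluation.py further down this diff; the orchestrator and session wiring are elided, and the experiment names are the script's defaults:

```python
# Sketch only: assumes an orchestrator session was already created,
# as run_evaluation.py below does with createOrchestrator().
from evaluation import (
    evaluate_response_phase,
    evaluate_retrieval_phase,
    goldens,
    run_llm_for_eval,
)

async def score(orc, session, session_id):
    # Fill in predictions for every golden, then score both phases.
    results = await run_llm_for_eval(goldens, orc, session, session_id)
    retrieval = evaluate_retrieval_phase(results, "retrieval-phase-eval")
    response = evaluate_response_phase(results, "response-phase-eval")
    return retrieval.summary_metrics, response.summary_metrics
```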
diff --git a/llm_demo/evaluation/eval_golden.py b/llm_demo/evaluation/eval_golden.py
new file mode 100644
index 00000000..bd5c445b
--- /dev/null
+++ b/llm_demo/evaluation/eval_golden.py
@@ -0,0 +1,315 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import datetime, timedelta
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Field
+from pytz import timezone
+
+
+class ToolCall(BaseModel):
+    """
+    Represents a tool call made by the orchestration layer.
+    """
+
+    name: str
+    arguments: Dict[str, Any] = Field(
+        default={}, description="Query arguments for tool call"
+    )
+
+
+class EvalData(BaseModel):
+    """
+    Evaluation data model.
+    This model represents the information needed for running rapid evaluation with Vertex AI.
+    """
+
+    category: Optional[str] = Field(default=None, description="Evaluation category")
+    query: Optional[str] = Field(default=None, description="User query")
+    instruction: Optional[str] = Field(
+        default=None, description="Instruction to the LLM system"
+    )
+    content: Optional[str] = Field(
+        default=None,
+        description="Used in tool call evaluation. Content value is the text output from the model.",
+    )
+    tool_calls: List[ToolCall] = Field(
+        default=[], description="Golden tool call for evaluation"
+    )
+    context: Optional[List[Dict[str, Any] | List[Dict[str, Any]]]] = Field(
+        default=None, description="Context given to the LLM to answer the user query"
+    )
+    output: Optional[str] = Field(
+        default=None, description="Golden output for evaluation"
+    )
+    prediction_tool_calls: List[ToolCall] = Field(
+        default=[], description="Tool call output from LLM"
+    )
+    prediction_output: str = Field(default="", description="Final output from LLM")
+    reset: bool = Field(
+        default=True, description="Whether to reset the chat session after invoking"
+    )
+
+
+def get_date(day_delta: int) -> str:
+    DATE_FORMATTER = "%Y-%m-%d"
+    retrieved_date = datetime.now(timezone("US/Pacific")) + timedelta(days=day_delta)
+    return retrieved_date.strftime(DATE_FORMATTER)
+
+
+goldens = [
+    EvalData(
+        category="Search Airport Tool",
+        query="What is the airport located in San Francisco?",
+        tool_calls=[
+            ToolCall(
+                name="Search Airport",
+                arguments={"country": "United States", "city": "San Francisco"},
+            ),
+        ],
+    ),
+    EvalData(
+        category="Search Airport Tool",
+        query="Tell me more about Denver International Airport?",
+        tool_calls=[
+            ToolCall(
+                name="Search Airport",
+                arguments={
+                    "country": "United States",
+                    "city": "Denver",
+                    "name": "Denver International Airport",
+                },
+            ),
+        ],
+    ),
+    EvalData(
+        category="Search Flights By Flight Number Tool",
+        query="What is the departure gate for flight CY 922?",
+        tool_calls=[
+            ToolCall(
+                name="Search Flights By Flight Number",
+                arguments={
+                    "airline": "CY",
+                    "flight_number": "922",
+                },
+            ),
+        ],
+    ),
+    EvalData(
+        category="Search Flights By Flight Number Tool",
+        query="What is flight CY 888 flying to?",
+        tool_calls=[
+            ToolCall(
+                name="Search Flights By Flight Number",
+                arguments={
+                    "airline": "CY",
+                    "flight_number": "888",
+                },
+            ),
+        ],
+    ),
+    EvalData(
+        category="List Flights Tool",
+        query="What flights are headed to JFK tomorrow?",
+        tool_calls=[
+            ToolCall(
+                name="List Flights",
+                arguments={
+                    "arrival_airport": "JFK",
+                    "date": get_date(1),
+                },
+            ),
+        ],
+    ),
SFO to DEN?", + output="I will need the date to retrieve relevant flights.", + ), + EvalData( + category="Search Amenities Tool", + query="Are there any luxury shops?", + tool_calls=[ + ToolCall( + name="Search Amenities", + arguments={ + "query": "luxury shops", + }, + ), + ], + ), + EvalData( + category="Search Amenities Tool", + query="Where can I get coffee near gate A6?", + tool_calls=[ + ToolCall( + name="Search Amenities", + arguments={ + "query": "coffee near gate A6", + }, + ), + ], + ), + EvalData( + category="Search Policies Tool", + query="What is the flight cancellation policy?", + tool_calls=[ + ToolCall( + name="Search Policies", + arguments={ + "query": "flight cancellation policy", + }, + ), + ], + ), + EvalData( + category="Search Policies Tool", + query="How many checked bags can I bring?", + tool_calls=[ + ToolCall( + name="Search Policies", + arguments={ + "query": "checked baggage allowance", + }, + ), + ], + ), + EvalData( + category="Insert Ticket", + query="I would like to book flight CY 922 departing from SFO on 2024-01-01 at 6:38am.", + tool_calls=[ + ToolCall( + name="Insert Ticket", + arguments={ + "airline": "CY", + "flight_number": "922", + "departure_airport": "SFO", + "departure_time": "2024-01-01 06:38:00", + }, + ), + ], + ), + EvalData( + category="Insert Ticket", + query="What flights are headed from SFO to DEN on January 1 2024?", + tool_calls=[ + ToolCall( + name="List Flights", + arguments={ + "departure_airport": "SFO", + "arrival_airport": "DEN", + "date": "2024-01-01", + }, + ), + ], + reset=False, + ), + EvalData( + category="Insert Ticket", + query="I would like to book the first flight.", + tool_calls=[ + ToolCall( + name="Insert Ticket", + arguments={ + "airline": "UA", + "flight_number": "1532", + "departure_airport": "SFO", + "arrival_airport": "DEN", + "departure_time": "2024-01-01 05:50:00", + "arrival_time": "2024-01-01 09:23:00", + }, + ), + ], + ), + EvalData( + category="List Tickets", + query="Do I have any tickets?", + tool_calls=[ToolCall(name="List Tickets")], + ), + EvalData( + category="List Tickets", + query="When is my next flight?", + tool_calls=[ToolCall(name="List Tickets")], + ), + EvalData( + category="Airline Related Question", + query="What is Cymbal Air?", + output="Cymbal Air is a passenger airline offering convenient flights to many cities around the world from its hub in San Francisco.", + ), + EvalData( + category="Airline Related Question", + query="Where is the hub of cymbal air?", + output="The hub of Cymbal Air is in San Francisco.", + ), + EvalData( + category="Assistant Related Question", + query="What can you help me with?", + output="I can help to book flights and answer a wide range of questions pertaining to travel on Cymbal Air, as well as amenities of San Francisco Airport.", + ), + EvalData( + category="Assistant Related Question", + query="Can you help me book tickets?", + output="Yes, I can help with several tools such as search airports, list tickets, book tickets.", + ), + EvalData( + category="Out-Of-Context Question", + query="Can you help me solve math problems?", + output="Sorry, I am not given the tools for this.", + ), + EvalData( + category="Out-Of-Context Question", + query="Who is the CEO of Google?", + output="Sorry, I am not given the tools for this.", + ), + EvalData( + category="Multitool Selections", + query="Where can I get a snack near the gate for flight CY 352?", + tool_calls=[ + ToolCall( + name="Search Flights By Flight Number", + arguments={ + "airline": "CY", + "flight_number": 
"352", + }, + ), + ToolCall( + name="Search Amenities", + arguments={ + "query": "snack near gate A2.", + }, + ), + ], + ), + EvalData( + category="Multitool Selections", + query="What are some flights from SFO to Chicago tomorrow?", + tool_calls=[ + ToolCall( + name="Search Airport", + arguments={ + "city": "Chicago", + }, + ), + ToolCall( + name="List Flights", + arguments={ + "departure_airport": "SFO", + "arrival_airport": "ORD", + "date": f"{get_date(1)}", + }, + ), + ], + ), +] diff --git a/llm_demo/evaluation/evaluation.py b/llm_demo/evaluation/evaluation.py new file mode 100644 index 00000000..7772fc9d --- /dev/null +++ b/llm_demo/evaluation/evaluation.py @@ -0,0 +1,166 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import json +from typing import Dict, List + +import pandas as pd +from pydantic import BaseModel, Field +from vertexai.preview.evaluation import EvalTask # type: ignore +from vertexai.preview.evaluation import _base as evaluation_base + +from orchestrator import BaseOrchestrator + +from .eval_golden import EvalData, ToolCall + + +async def run_llm_for_eval( + eval_list: List[EvalData], orc: BaseOrchestrator, session: Dict, session_id: str +) -> List[EvalData]: + """ + Generate prediction_tool_calls and prediction_output for golden dataset query. + """ + agent = orc.get_user_session(session_id) + for eval_data in eval_list: + try: + query_response = await agent.invoke(eval_data.query) + except Exception as e: + print(f"error invoking agent: {e}") + else: + eval_data.prediction_output = query_response.get("output") + + # Retrieve prediction_tool_calls from query response + prediction_tool_calls = [] + contexts = [] + for step in query_response.get("intermediate_steps"): + called_tool = step[0] + tool_call = ToolCall( + name=called_tool.tool, + arguments=called_tool.tool_input, + ) + prediction_tool_calls.append(tool_call) + context = step[-1] + contexts.append(context) + + eval_data.prediction_tool_calls = prediction_tool_calls + eval_data.context = contexts + + if eval_data.reset: + orc.user_session_reset(session, session_id) + return eval_list + + +def evaluate_retrieval_phase( + eval_datas: List[EvalData], experiment_name: str +) -> evaluation_base.EvalResult: + """ + Run evaluation for the ability of a model to select the right tool and arguments (retrieval phase). 
+ """ + metrics = ["tool_call_quality"] + # Prepare evaluation task input + responses = [] + references = [] + for e in eval_datas: + references.append( + json.dumps( + { + "content": e.content, + "tool_calls": [t.model_dump() for t in e.tool_calls], + } + ) + ) + responses.append( + json.dumps( + { + "content": e.content, + "tool_calls": [t.model_dump() for t in e.prediction_tool_calls], + } + ) + ) + eval_dataset = pd.DataFrame( + { + "response": responses, + "reference": references, + } + ) + # Run evaluation + eval_result = EvalTask( + dataset=eval_dataset, + metrics=metrics, + experiment=experiment_name, + ).evaluate() + return eval_result + + +def evaluate_response_phase( + eval_datas: List[EvalData], experiment_name: str +) -> evaluation_base.EvalResult: + """ + Run evaluation for the ability of a model to generate a response based on the context given (response phase). + """ + metrics = [ + "text_generation_quality", + "text_generation_factuality", + "summarization_pointwise_reference_free", + "qa_pointwise_reference_free", + ] + # Prepare evaluation task input + instructions = [] + contexts = [] + responses = [] + + for e in eval_datas: + instructions.append( + f"Answer user query based on context given. User query is {e.query}." + ) + context_str = ( + [json.dumps(c) for c in e.context] if e.context else ["no data retrieved"] + ) + contexts.append(PROMPT + ", " + ", ".join(context_str)) + responses.append(e.prediction_output or "") + eval_dataset = pd.DataFrame( + { + "instruction": instructions, + "context": contexts, + "response": responses, + } + ) + # Run evaluation + eval_result = EvalTask( + dataset=eval_dataset, + metrics=metrics, + experiment=experiment_name, + ).evaluate() + return eval_result + + +PROMPT = """The Cymbal Air Customer Service Assistant helps customers of Cymbal Air with their travel needs. + +Cymbal Air (airline unique two letter identifier as CY) is a passenger airline offering convenient flights to many cities around the world from its +hub in San Francisco. Cymbal Air takes pride in using the latest technology to offer the best customer +service! + +Cymbal Air Customer Service Assistant (or just "Assistant" for short) is designed to assist +with a wide range of tasks, from answering simple questions to complex multi-query questions that +require passing results from one query to another. Using the latest AI models, Assistant is able to +generate human-like text based on the input it receives, allowing it to engage in natural-sounding +conversations and provide responses that are coherent and relevant to the topic at hand. The assistant should +not answer questions about other peoples information for privacy reasons. + +Assistant is a powerful tool that can help answer a wide range of questions pertaining to travel on Cymbal Air +as well as ammenities of San Francisco Airport. + +Answer user query based on context or information given. 
+""" diff --git a/llm_demo/requirements.txt b/llm_demo/requirements.txt index 25e99ad7..1446de08 100644 --- a/llm_demo/requirements.txt +++ b/llm_demo/requirements.txt @@ -1,5 +1,5 @@ fastapi==0.109.2 -google-cloud-aiplatform==1.60.0 +google-cloud-aiplatform[rapid_evaluation]==1.62.0 google-auth==2.32.0 itsdangerous==2.2.0 jinja2==3.1.4 @@ -14,3 +14,5 @@ pytz==2024.1 types-pytz==2024.1.0.20240417 langgraph==0.1.16 httpx==0.27.0 +pandas-stubs==2.2.2.240603 +pandas==2.1.4 diff --git a/llm_demo/run_evaluation.py b/llm_demo/run_evaluation.py new file mode 100644 index 00000000..ddde259d --- /dev/null +++ b/llm_demo/run_evaluation.py @@ -0,0 +1,91 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +import uuid + +import pandas as pd +from google.auth.transport.requests import Request +from google.oauth2.id_token import fetch_id_token + +from evaluation import ( + evaluate_response_phase, + evaluate_retrieval_phase, + goldens, + run_llm_for_eval, +) +from orchestrator import createOrchestrator + + +def export_metrics_table_csv(retrieval: pd.DataFrame, response: pd.DataFrame): + """ + Export detailed metrics table to csv file + """ + retrieval.to_csv("retrieval_eval.csv") + response.to_csv("response_eval.csv") + + +def fetch_user_id_token(client_id: str): + request = Request() + user_id_token = fetch_id_token(request, client_id) + return user_id_token + + +async def main(): + # allow user to set USER_ID_TOKEN directly on env var + USER_ID_TOKEN = os.getenv("USER_ID_TOKEN", default=None) + + CLIENT_ID = os.getenv("CLIENT_ID", default="") + ORCHESTRATION_TYPE = os.getenv("ORCHESTRATION_TYPE", default="langchain-tools") + EXPORT_CSV = bool(os.getenv("EXPORT_CSV", default=False)) + RETRIEVAL_EXPERIMENT_NAME = os.getenv( + "RETRIEVAL_EXPERIMENT_NAME", default="retrieval-phase-eval" + ) + RESPONSE_EXPERIMENT_NAME = os.getenv( + "RESPONSE_EXPERIMENT_NAME", default="response-phase-eval" + ) + + # Prepare orchestrator and session + orc = createOrchestrator(ORCHESTRATION_TYPE) + session_id = str(uuid.uuid4()) + session = {"uuid": session_id} + await orc.user_session_create(session) + + # Retrieve and set user id token for auth + if USER_ID_TOKEN: + user_id_token = USER_ID_TOKEN + else: + user_id_token = fetch_user_id_token(CLIENT_ID) + orc.set_user_session_header(session_id, user_id_token) + + # Run evaluation + eval_lists = await run_llm_for_eval(goldens, orc, session, session_id) + retrieval_eval_results = evaluate_retrieval_phase( + eval_lists, RETRIEVAL_EXPERIMENT_NAME + ) + response_eval_results = evaluate_response_phase( + eval_lists, RESPONSE_EXPERIMENT_NAME + ) + print(f"Retrieval phase eval results: {retrieval_eval_results.summary_metrics}") + print(f"Response phase eval results: {response_eval_results.summary_metrics}") + + if EXPORT_CSV: + export_metrics_table_csv( + retrieval_eval_results.metrics_table, response_eval_results.metrics_table + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/retrieval_service/app/routes.py 
diff --git a/retrieval_service/app/routes.py b/retrieval_service/app/routes.py
index f11199c0..7d62470b 100644
--- a/retrieval_service/app/routes.py
+++ b/retrieval_service/app/routes.py
@@ -47,9 +47,9 @@ async def get_user_info(request):
         )
         return {
-            "user_id": id_info["sub"],
-            "user_name": id_info["name"],
-            "user_email": id_info["email"],
+            "user_id": id_info.get("sub"),
+            "user_name": id_info.get("name"),
+            "user_email": id_info.get("email"),
         }
     except Exception as e:  # pylint: disable=broad-except
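For context on the routes.py change, a standalone illustration of plain dict semantics (independent of this codebase): indexing raises on a missing claim, while .get() degrades to None:

```python
id_info = {"sub": "12345"}  # decoded token missing "name"/"email" claims

id_info.get("name")  # -> None; the user info response carries a null field
# id_info["name"]    # -> KeyError, caught only by the broad except above
```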