Commit: tau-bench rewrite

- Uses [litellm](/~https://github.com/BerriAI/litellm) to provide a
  standard LLM interface (see the usage sketch below the file summary)
- Improves reliability of the `react` and `act` agent strategies
- Improves typing
noahshinn committed Sep 3, 2024
1 parent da9678c commit 043b544
Showing 56 changed files with 14,516 additions and 9,107 deletions.
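The switch to litellm matters because it routes every provider through one OpenAI-style `completion()` call, which is what lets run.py pass `--model-provider` straight through to the agents. A minimal sketch of that calling pattern, assuming an API key is set in the environment; the model string and message content are illustrative, not taken from this commit:

```python
# Minimal litellm sketch (illustrative, not part of this commit).
# litellm dispatches on the model string, so the same call shape works
# across OpenAI, Anthropic, Google, Mistral, and other providers.
from litellm import completion

response = completion(
    model="gpt-4o",  # any litellm-supported model string; provider inferred from it
    messages=[{"role": "user", "content": "Say hello."}],
    temperature=0.0,
)
print(response.choices[0].message.content)
```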
273 changes: 116 additions & 157 deletions run.py
@@ -4,30 +4,39 @@
import json
import random
import argparse
import multiprocessing
import traceback
from math import comb
from typing import Any
import multiprocessing
from typing import List, Dict, Any
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

from tau_bench.agents.base import BaseAgent
from tau_bench.envs import get_env
from tau_bench.agents.base import Agent
from tau_bench.types import EnvRunResult
from litellm import provider_list


def run(
args: argparse.Namespace,
ckpt_path,
):
ckpt_path: str,
) -> List[EnvRunResult]:
env = get_env(
args.env,
user_mode="naive",
user_strategy="llm",
user_model=args.user_model,
user_provider=args.user_model_provider,
task_split=args.task_split,
)
agent = agent_factory(
tools_info=env.tools_info,
wiki=env.wiki,
args=args,
)
end_index = (
len(env.tasks) if args.end_index == -1 else min(args.end_index, len(env.tasks))
)
results = []
results: List[EnvRunResult] = []
lock = multiprocessing.Lock()
print(
f"Running tasks {args.start_index} to {end_index} (checkpoint path: {ckpt_path})"
@@ -37,46 +46,40 @@ def run(
if args.shuffle:
random.shuffle(idxs)

def _run(idx: int) -> dict:
def _run(idx: int) -> EnvRunResult:
isolated_env = get_env(
args.env,
user_mode="naive",
user_strategy="llm",
user_model=args.user_model,
task_split=args.task_split,
)
isolated_agent = agent_factory(
tools_info=env.tools_info,
wiki=env.wiki,
args=args,
user_provider=args.user_model_provider,
)

print(f"Running task {idx}")
try:
reward, info = isolated_agent.act(
res = agent.solve(
isolated_env,
idx,
verbose=args.verbose,
temperature=args.temperature,
)
result = {
"task_id": idx,
"reward": reward,
"info": info,
"traj": isolated_agent.get_messages(),
"trial": i,
}
result = EnvRunResult(
task_id=idx,
reward=res.reward,
info=res.info,
traj=[msg.model_dump() for msg in res.messages],
trial=i,
)
except Exception as e:
result = {
"task_id": idx,
"reward": 0,
"info": "Error: " + str(e),
"traj": isolated_agent.get_messages(),
"trial": i,
}
result = EnvRunResult(
task_id=idx,
reward=0,
info={"error": str(e), "traceback": traceback.format_exc()},
traj=[],
trial=i,
)
print(
"✅" if result["reward"] == 1 else "❌",
"✅" if result.reward == 1 else "❌",
f"task_id={idx}",
result["info"],
result.info,
)
print("-----")
with lock:
@@ -85,7 +88,7 @@ def _run(idx: int) -> dict:
with open(ckpt_path, "r") as f:
data = json.load(f)
with open(ckpt_path, "w") as f:
json.dump(data + [result], f, indent=2)
json.dump(data + [result.model_dump()], f, indent=2)
return result

with ThreadPoolExecutor(max_workers=args.max_concurrency) as executor:
@@ -95,92 +98,62 @@
return results


def agent_factory(tools_info, wiki, args: argparse.Namespace) -> BaseAgent:
# only add think as a tool for function calling
if not (args.agent_strategy == "function_calling" and args.think):
tools_info = [
tool for tool in tools_info if tool["function"]["name"] != "think"
]

def agent_factory(
tools_info: List[Dict[str, Any]], wiki, args: argparse.Namespace
) -> Agent:
if args.agent_strategy == "function_calling":
if (
"gpt" in args.model
or "mistralai/Mi" in args.model
or "meta-llama/Meta-Llama-3-" in args.model
):
from tau_bench.agents.gpt_function_calling_agent import (
GPTFunctionCallingAgent,
initialize_client,
)

if "gpt" in args.model:
initialize_client()
elif (
"mistralai/Mi" in args.model or "meta-llama/Meta-Llama-3-" in args.model
):
initialize_client(
api_key=os.getenv("ANYSCALE_API_KEY"),
base_url="https://api.endpoints.anyscale.com/v1",
)

return GPTFunctionCallingAgent(tools_info, wiki, model=args.model)
elif "claude" in args.model:
from tau_bench.agents.claude_function_calling_agent import (
ClaudeFunctionCallingAgent,
)

return ClaudeFunctionCallingAgent(tools_info, wiki, model=args.model)
elif "mistral" in args.model or "mixtral" in args.model:
from tau_bench.agents.mistral_function_calling_agent import (
MistralFunctionCallingAgent,
)

return MistralFunctionCallingAgent(tools_info, wiki, model=args.model)
elif "gemini" in args.model:
from tau_bench.agents.gemini_function_calling_agent import (
GeminiFunctionCallingAgent,
)

return GeminiFunctionCallingAgent(tools_info, wiki, model=args.model)
else:
from tau_bench.agents.custom_function_calling_agent import (
CustomFunctionCallingAgent,
)

return CustomFunctionCallingAgent(
tools_info, wiki, model_name_or_path=args.model, num_gpus=args.num_gpus
)
# native function calling
from tau_bench.agents.function_calling_agent import FunctionCallingAgent

return FunctionCallingAgent(
tools_info=tools_info,
wiki=wiki,
model=args.model,
provider=args.model_provider,
temperature=args.temperature,
)
elif args.agent_strategy == "act":
# `act` from https://arxiv.org/abs/2210.03629
from tau_bench.agents.chat_react_agent import ChatReActAgent

return ChatReActAgent(
tools_info=tools_info,
wiki=wiki,
model=args.model,
provider=args.model_provider,
use_reasoning=False,
temperature=args.temperature,
)
elif args.agent_strategy == "react":
from tau_bench.agents.chat_react_agent import ChatReActAgent, initialize_create

if "gpt" in args.model:
initialize_create(mode="openai")
elif "claude" in args.model:
initialize_create(mode="anthropic")
elif "gemini" in args.model:
initialize_create(mode="google")
else: # anyscale
initialize_create(
mode="openai",
api_key=os.getenv("ANYSCALE_API_KEY"),
base_url="https://api.endpoints.anyscale.com/v1",
)
return ChatReActAgent(tools_info, wiki, model=args.model, reason=args.think)
# `react` from https://arxiv.org/abs/2210.03629
from tau_bench.agents.chat_react_agent import ChatReActAgent

return ChatReActAgent(
tools_info=tools_info,
wiki=wiki,
model=args.model,
provider=args.model_provider,
use_reasoning=True,
temperature=args.temperature,
)
else:
raise ValueError(f"Unknown agent strategy: {args.agent_strategy}")


def display_metrics(results: dict[str, Any]) -> None:
num_trials = len(set([r["trial"] for r in results]))
rewards = [r["reward"] for r in results]
def display_metrics(results: List[EnvRunResult]) -> None:
def is_successful(reward: float) -> bool:
return (1 - 1e-6) <= reward <= (1 + 1e-6)

num_trials = len(set([r.trial for r in results]))
rewards = [r.reward for r in results]
avg_reward = sum(rewards) / len(rewards)
# c from https://arxiv.org/pdf/2406.12045
c_per_task_id: dict[int, int] = {}
for r in results:
if r["task_id"] not in c_per_task_id:
c_per_task_id[r["task_id"]] = r["reward"]
for result in results:
if result.task_id not in c_per_task_id:
c_per_task_id[result.task_id] = 1 if is_successful(result.reward) else 0
else:
c_per_task_id[r["task_id"]] += r["reward"]
c_per_task_id[result.task_id] += 1 if is_successful(result.reward) else 0
pass_hat_ks: dict[int, float] = {}
for k in range(1, num_trials + 1):
sum_task_pass_hat_k = 0
@@ -195,71 +168,57 @@ def display_metrics(results: dict[str, Any]) -> None:

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--num_trials", type=int, default=1)
parser.add_argument("--num-trials", type=int, default=1)
parser.add_argument(
"--env", type=str, choices=["retail", "airline"], default="retail"
)
parser.add_argument(
"--model",
type=str,
help="The model to use for the agent",
)
parser.add_argument(
"--model-provider",
type=str,
choices=provider_list,
help="The model provider for the agent",
)
parser.add_argument(
"--user-model",
type=str,
default="gpt-4o",
choices=[
# openai api models
"gpt-4-turbo",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4-32k-0613",
"gpt-3.5-turbo",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-0125",
"gpt-4o",
"gpt-4o-mini",
# anthropic api models
"claude-3-5-sonnet-20240620",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
"claude-3-5-sonnet-20240620",
# google api models
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
"gemini-1.0-pro",
# mistral api models,
"open-mixtral-8x22b",
"mistral-large-latest",
# anyscale api models
"meta-llama/Meta-Llama-3-8B-Instruct",
"meta-llama/Meta-Llama-3-70B-Instruct",
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
"mistralai/Mixtral-8x22B-Instruct-v0.1",
],
help="The model to use for the user simulator",
)
parser.add_argument(
"--user_model",
"--user-model-provider",
type=str,
default="gpt-4",
choices=provider_list,
help="The model provider for the user simulator",
)
parser.add_argument(
"--agent_strategy",
"--agent-strategy",
type=str,
default="function_calling",
choices=["function_calling", "react"],
choices=["function_calling", "act", "react"],
)
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument(
"--task_split", type=str, default="test", choices=["train", "test", "dev"]
"--temperature",
type=float,
default=0.0,
help="The sampling temperature for the action model",
)
parser.add_argument(
"--think", type=int, default=0, help="Add think for function calling"
"--task-split",
type=str,
default="test",
choices=["train", "test", "dev"],
help="The split of tasks to run (only applies to the retail domain for now",
)
parser.add_argument("--start_index", type=int, default=0)
parser.add_argument("--end_index", type=int, default=-1, help="Run all tasks if -1")
parser.add_argument("--verbose", action="store_true", default=False)
parser.add_argument("--log_dir", type=str, default="results")
parser.add_argument("--num_gpus", type=int, default=None)
parser.add_argument("--start-index", type=int, default=0)
parser.add_argument("--end-index", type=int, default=-1, help="Run all tasks if -1")
parser.add_argument("--log-dir", type=str, default="results")
parser.add_argument(
"--max_concurrency",
"--max-concurrency",
type=int,
default=1,
help="Number of tasks to run in parallel",
@@ -272,7 +231,7 @@ def main():
random.seed(args.seed)

time_str = datetime.now().strftime("%m%d%H%M%S")
file_str = f"{args.log_dir}/{args.agent_strategy}{args.think}-{args.model.split('/')[-1]}-{args.temperature}_range_{args.start_index}-{args.end_index}_user{args.user_model}_{time_str}.json"
file_str = f"{args.log_dir}/{args.agent_strategy}-{args.model.split('/')[-1]}-{args.temperature}_range_{args.start_index}-{args.end_index}_user{args.user_model}_{time_str}.json"

if not os.path.exists(args.log_dir):
os.makedirs(args.log_dir)
@@ -285,7 +244,7 @@
display_metrics(results)

with open(file_str, "w") as f:
json.dump(results, f, indent=2)
json.dump([result.model_dump() for result in results], f, indent=2)
print(f"\n📄 Results saved to {file_str}\n")


3 changes: 3 additions & 0 deletions tau_bench/__init__.py
@@ -1 +1,4 @@
# Copyright Sierra

from tau_bench.envs.base import Env as Env
from tau_bench.agents.base import Agent as Agent
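These two re-exports let downstream code import the core interfaces from the package root rather than the internal module paths. A hedged example of what that enables; the function below is hypothetical and only illustrates the import style:

```python
# Illustrative only: relies on the re-exports added above.
from tau_bench import Agent, Env  # instead of tau_bench.agents.base / tau_bench.envs.base


def evaluate(agent: Agent, env: Env, task_index: int):
    # run.py calls agent.solve(env, task_index) in the same spirit
    return agent.solve(env, task_index)
```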