
Commit

update the code
anjieyang committed Oct 31, 2024
1 parent bfd778d commit 3d58477
Showing 2 changed files with 92 additions and 13 deletions.
65 changes: 65 additions & 0 deletions gui/gui_experiment.py
@@ -26,6 +26,10 @@ def __init__(
         log_dir: Path | None = None,
     ) -> None:
         super().__init__(benchmark, task_id, agent_policy, log_dir)
+        self.display_callback = None
+
+    def set_display_callback(self, callback):
+        self.display_callback = callback
 
     def get_prompt(self):
         observation, ob_prompt = self.benchmark.observe_with_prompt()
@@ -47,3 +51,64 @@ def get_prompt(self):
             (marked_screenshot, MessageType.IMAGE_JPG_BASE64),
         ]
         return result_prompt
+
+    def step(self, it) -> bool:
+        if self.display_callback:
+            self.display_callback(f"Step {self.step_cnt}:", "ai")
+
+        prompt = self.get_prompt()
+        self.log_prompt(prompt)
+
+        try:
+            response = self.agent_policy.chat(prompt)
+            if self.display_callback:
+                self.display_callback("Planning next action...", "ai")
+        except Exception as e:
+            if self.display_callback:
+                self.display_callback(f"Error: {str(e)}", "ai")
+            self.write_main_csv_row("agent_exception")
+            return True
+
+        if self.display_callback:
+            self.display_callback(f"Executing: {response}", "ai")
+        return self.execute_action(response)
+
+    def execute_action(self, response: list[ActionOutput]) -> bool:
+        for action in response:
+            benchmark_result = self.benchmark.step(
+                action=action.name,
+                parameters=action.arguments,
+                env_name=action.env,
+            )
+            self.metrics = benchmark_result.evaluation_results
+
+            if benchmark_result.terminated:
+                if self.display_callback:
+                    self.display_callback(
+                        f"✓ Task completed! Results: {self.metrics}", "ai"
+                    )
+                self.write_current_log_row(action)
+                self.write_current_log_row(benchmark_result.info["terminate_reason"])
+                return True
+
+            if self.display_callback:
+                self.display_callback(
+                    f'Action "{action.name}" completed in {action.env}. '
+                    f"Progress: {self.metrics}", "ai"
+                )
+            self.write_current_log_row(action)
+            self.step_cnt += 1
+        return False
+
+    def start_benchmark(self):
+        if self.display_callback:
+            self.display_callback("Starting benchmark...", "ai")
+        try:
+            super().start_benchmark()
+        except KeyboardInterrupt:
+            if self.display_callback:
+                self.display_callback("Experiment interrupted.", "ai")
+            self.write_main_csv_row("experiment_interrupted")
+        finally:
+            if self.display_callback:
+                self.display_callback("Experiment finished.", "ai")
40 changes: 27 additions & 13 deletions gui/main.py
@@ -53,19 +53,33 @@ def assign_task():
     input_entry.delete(0, "end")
     display_message(task_description)
 
-    model = get_model_instance(model_dropdown.get())
-    agent_policy = SingleAgentPolicy(model_backend=model)
-
-    task_id = str(uuid4())
-    benchmark = get_benchmark(task_id, task_description)
-    experiment = GuiExperiment(
-        benchmark=benchmark,
-        task_id=task_id,
-        agent_policy=agent_policy,
-        log_dir=log_dir,
-    )
-    # TODO: redirect the output to the GUI
-    experiment.start_benchmark()
+    try:
+        model = get_model_instance(model_dropdown.get())
+        agent_policy = SingleAgentPolicy(model_backend=model)
+
+        task_id = str(uuid4())
+        benchmark = get_benchmark(task_id, task_description)
+        experiment = GuiExperiment(
+            benchmark=benchmark,
+            task_id=task_id,
+            agent_policy=agent_policy,
+            log_dir=log_dir,
+        )
+
+        experiment.set_display_callback(display_message)
+
+        def run_experiment():
+            try:
+                experiment.start_benchmark()
+            except Exception as e:
+                display_message(f"Error: {str(e)}", "ai")
+
+        import threading
+        thread = threading.Thread(target=run_experiment, daemon=True)
+        thread.start()
+
+    except Exception as e:
+        display_message(f"Error: {str(e)}", "ai")
 
 
 def display_message(message, sender="user"):
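A design note on the threading change above: run_experiment executes on a daemon thread, so the display_message callback fires off the Tk main thread, and Tkinter widgets are generally not safe to touch from worker threads. If updates ever misbehave, a common remedy is to let workers enqueue messages and have the main thread drain the queue via after(). A self-contained sketch of that pattern (the widget names here are illustrative, not taken from this repo):

    import queue
    import threading
    import tkinter as tk

    root = tk.Tk()
    text = tk.Text(root)
    text.pack()

    msg_queue: queue.Queue = queue.Queue()

    def display_message_threadsafe(message: str, sender: str = "user") -> None:
        # Worker threads only enqueue; they never touch Tk widgets directly.
        msg_queue.put((message, sender))

    def drain_queue() -> None:
        # Runs on the Tk main thread, where widget updates are safe.
        while not msg_queue.empty():
            message, sender = msg_queue.get_nowait()
            text.insert("end", f"[{sender}] {message}\n")
        root.after(100, drain_queue)  # re-poll every 100 ms

    threading.Thread(
        target=lambda: display_message_threadsafe("hello from a worker", "ai"),
        daemon=True,
    ).start()
    root.after(100, drain_queue)
    root.mainloop()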
