-
Notifications
You must be signed in to change notification settings - Fork 535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ability to end agent runs as a result of a tool call #142
base: main
Are you sure you want to change the base?
Changes from 5 commits
8621d28
79b5845
ea16a9f
4e4ab04
d33a463
85526d0
489f25a
88d66e7
c8cb197
d20034c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from __future__ import annotations as _annotations | ||
|
||
from typing import Any | ||
|
||
|
||
class StopAgentRun(Exception): | ||
"""Exception raised to stop the agent run and use the provided result as the output.""" | ||
|
||
result: Any | ||
"""The value to use as the result of the agent run.""" | ||
tool_name: str | None | ||
"""The name of the tool call, if available.""" | ||
|
||
def __init__(self, result: Any, tool_name: str | None) -> None: | ||
self.result = result | ||
self.tool_name = tool_name |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,19 +3,17 @@ | |
import inspect | ||
from collections.abc import Awaitable | ||
from dataclasses import dataclass, field | ||
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast | ||
from typing import Any, Callable, Generic, NoReturn, Union, cast | ||
|
||
from pydantic import ValidationError | ||
from pydantic_core import SchemaValidator | ||
from typing_extensions import Concatenate, ParamSpec, TypeAlias | ||
from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar | ||
|
||
from . import _pydantic, _utils, messages | ||
from . import _exceptions, _pydantic, _utils, messages | ||
from .exceptions import ModelRetry, UnexpectedModelBehavior | ||
|
||
if TYPE_CHECKING: | ||
from .result import ResultData | ||
else: | ||
ResultData = Any | ||
ResultData = TypeVar('ResultData', default=Any) | ||
"""Type variable for the result data of a run.""" | ||
|
||
|
||
__all__ = ( | ||
|
@@ -38,7 +36,7 @@ | |
|
||
|
||
@dataclass | ||
class RunContext(Generic[AgentDeps]): | ||
class RunContext(Generic[AgentDeps, ResultData]): | ||
"""Information about the current call.""" | ||
|
||
deps: AgentDeps | ||
|
@@ -48,6 +46,11 @@ class RunContext(Generic[AgentDeps]): | |
tool_name: str | None = None | ||
"""Name of the tool being called.""" | ||
|
||
def stop_run(self, result: ResultData) -> NoReturn: | ||
"""Stop the call to `agent.run` as soon as possible, using the provided value as the result.""" | ||
# NOTE: this means we ignore any other tools called concurrently | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mention in docs what happens about cancelling other tools. Do the return values of functions that finish get added to messages? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I added this comment to match the comments near the result schema things: elif model_response.role == 'model-structured-response':
if self._result_schema is not None:
# if there's a result schema, and any of the calls match one of its tools, return the result
# NOTE: this means we ignore any other tools called here
if match := self._result_schema.find_tool(model_response): Is there a discussion of this in the docs? I didn't see one in a quick search, but if so I can adapt that; if not, I guess we could add a note about that too? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would propose that we wait to update the docs about this until after we merge Jeremiah's PR that adds the "eager" vs. "correct" mode, and support reflect those modes here too |
||
raise _exceptions.StopAgentRun(result, tool_name=self.tool_name) | ||
|
||
|
||
ToolParams = ParamSpec('ToolParams') | ||
"""Retrieval function param spec.""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not correct, and basically should be a type error, but the problem is that
e.result
has typeAny
(it is the value passed toctx.stop_run
). However, I am not sure what's going on here so not sure how to tweak things to fix this. Maybe we can discuss tomorrow.