Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to end agent runs as a result of a tool call #142

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from __future__ import annotations as _annotations

from typing import Any


class StopAgentRun(Exception):
"""Exception raised to stop the agent run and use the provided result as the output."""

result: Any
"""The value to use as the result of the agent run."""
tool_name: str | None
"""The name of the tool call, if available."""

def __init__(self, result: Any, tool_name: str | None) -> None:
self.result = result
self.tool_name = tool_name
21 changes: 15 additions & 6 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing_extensions import assert_never

from . import (
_exceptions,
_result,
_system_prompt,
_utils,
Expand Down Expand Up @@ -787,9 +788,13 @@ async def _handle_model_response(
else:
messages.append(self._unknown_tool(call.tool_name))

with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
messages.extend(task_results)
with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]) as span:
try:
task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
messages.extend(task_results)
except _exceptions.StopAgentRun as e:
span.set_attribute('stop_agent_run', e.tool_name)
return _MarkFinalResult(data=e.result), []
return None, messages
else:
assert_never(model_response)
Expand Down Expand Up @@ -854,9 +859,13 @@ async def _handle_streamed_model_response(
else:
messages.append(self._unknown_tool(call.tool_name))

with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
messages.extend(task_results)
with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]) as span:
try:
task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
messages.extend(task_results)
except _exceptions.StopAgentRun as e:
span.set_attribute('stop_agent_run', e.tool_name)
return _MarkFinalResult(data=e.result), []
Copy link
Contributor Author

@dmontagu dmontagu Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not correct and should really be a type error; the problem is that e.result has type Any (it is the value passed to ctx.stop_run), so the type checker does not flag it. I am not sure what is going on here, so I am not sure how to tweak things to fix this. Maybe we can discuss tomorrow.

return None, messages

async def _validate_result(
Expand Down
19 changes: 11 additions & 8 deletions pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
from typing import Any, Callable, Generic, NoReturn, Union, cast

from pydantic import ValidationError
from pydantic_core import SchemaValidator
from typing_extensions import Concatenate, ParamSpec, TypeAlias
from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar

from . import _pydantic, _utils, messages
from . import _exceptions, _pydantic, _utils, messages
from .exceptions import ModelRetry, UnexpectedModelBehavior

if TYPE_CHECKING:
from .result import ResultData
else:
ResultData = Any
ResultData = TypeVar('ResultData', default=Any)
"""Type variable for the result data of a run."""


__all__ = (
Expand All @@ -38,7 +36,7 @@


@dataclass
class RunContext(Generic[AgentDeps]):
class RunContext(Generic[AgentDeps, ResultData]):
"""Information about the current call."""

deps: AgentDeps
Expand All @@ -48,6 +46,11 @@ class RunContext(Generic[AgentDeps]):
tool_name: str | None = None
"""Name of the tool being called."""

def stop_run(self, result: ResultData) -> NoReturn:
"""Stop the call to `agent.run` as soon as possible, using the provided value as the result."""
# NOTE: this means we ignore any other tools called concurrently
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mention in docs what happens about cancelling other tools. Do the return values of functions that finish get added to messages?

Copy link
Contributor Author

@dmontagu dmontagu Dec 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mention in docs what happens about cancelling other tools.

I added this comment to match the comments near the result schema things:

        elif model_response.role == 'model-structured-response':
            if self._result_schema is not None:
                # if there's a result schema, and any of the calls match one of its tools, return the result
                # NOTE: this means we ignore any other tools called here
                if match := self._result_schema.find_tool(model_response):

Is there a discussion of this in the docs? I didn't see one in a quick search, but if so I can adapt that; if not, I guess we could add a note about that too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would propose that we wait to update the docs about this until after we merge Jeremiah's PR that adds the "eager" vs. "correct" modes, and then support and reflect those modes here too.

raise _exceptions.StopAgentRun(result, tool_name=self.tool_name)


ToolParams = ParamSpec('ToolParams')
"""Retrieval function param spec."""
Expand Down
9 changes: 9 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,12 @@ def ctx_tool(ctx: RunContext[int], x: int) -> int:
""")
result = mod.agent.run_sync('foobar', deps=5)
assert result.data == snapshot('{"ctx_tool":5}')

def test_ctx_stop_run():
    """Check that calling `ctx.stop_run` inside a tool ends the run with the supplied value."""

    def ctx_stop_run_tool(ctx: RunContext[int]):
        # `deps` is 2 in this test, so the value handed to `stop_run` is 'abc' repeated twice.
        ctx.stop_run(ctx.deps * 'abc')

    stop_tool = Tool(ctx_stop_run_tool, takes_ctx=True)
    agent = Agent('test', result_type=str, tools=[stop_tool], deps_type=int)

    result = agent.run_sync('foobar', deps=2)

    assert result.data == snapshot('abcabc')
Loading