Skip to content
22 changes: 21 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dataclasses import dataclass, field, replace
from typing import Any, Generic

import anyio
from opentelemetry.trace import Tracer
from pydantic import ValidationError
from typing_extensions import assert_never
Expand Down Expand Up @@ -35,6 +36,8 @@ class ToolManager(Generic[AgentDepsT]):
"""The cached tools for this run step."""
failed_tools: set[str] = field(default_factory=set)
"""Names of tools that failed in this run step."""
default_timeout: float | None = None
"""Default timeout in seconds for tool execution. None means no timeout."""

@classmethod
@contextmanager
Expand Down Expand Up @@ -62,6 +65,7 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe
toolset=self.toolset,
ctx=ctx,
tools=await self.toolset.get_tools(ctx),
default_timeout=self.default_timeout,
)

@property
Expand Down Expand Up @@ -172,7 +176,23 @@ async def _call_tool(
call.args or {}, allow_partial=pyd_allow_partial, context=ctx.validation_context
)

result = await self.toolset.call_tool(name, args_dict, ctx, tool)
# Determine effective timeout: per-tool timeout takes precedence over default
effective_timeout = tool.timeout if tool.timeout is not None else self.default_timeout

if effective_timeout is not None:
try:
with anyio.fail_after(effective_timeout):
result = await self.toolset.call_tool(name, args_dict, ctx, tool)
except TimeoutError:
m = _messages.RetryPromptPart(
tool_name=name,
content=f"Tool '{name}' timed out after {effective_timeout} seconds. Please try a different approach.",
tool_call_id=call.tool_call_id,
)
self.failed_tools.add(name)
raise ToolRetryError(m) from None
else:
result = await self.toolset.call_tool(name, args_dict, ctx, tool)

return result
except (ValidationError, ModelRetry) as e:
Expand Down
20 changes: 19 additions & 1 deletion pydantic_ai_slim/pydantic_ai/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
_max_result_retries: int = dataclasses.field(repr=False)
_max_tool_retries: int = dataclasses.field(repr=False)
_tool_timeout: float | None = dataclasses.field(repr=False)
_validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False)

_event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)
Expand Down Expand Up @@ -179,6 +180,7 @@ def __init__(
instrument: InstrumentationSettings | bool | None = None,
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
tool_timeout: float | None = None,
) -> None: ...

@overload
Expand Down Expand Up @@ -206,6 +208,7 @@ def __init__(
instrument: InstrumentationSettings | bool | None = None,
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
tool_timeout: float | None = None,
) -> None: ...

def __init__(
Expand All @@ -231,6 +234,7 @@ def __init__(
instrument: InstrumentationSettings | bool | None = None,
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
tool_timeout: float | None = None,
**_deprecated_kwargs: Any,
):
"""Create an agent.
Expand Down Expand Up @@ -285,6 +289,9 @@ def __init__(
Each processor takes a list of messages and returns a modified list of messages.
Processors can be sync or async and are applied in sequence.
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools.
tool_timeout: Default timeout in seconds for tool execution. If a tool takes longer than this,
a retry prompt is returned to the model. Individual tools can override this with their own timeout.
Defaults to None (no timeout).
"""
if model is None or defer_model_check:
self._model = model
Expand Down Expand Up @@ -318,6 +325,7 @@ def __init__(

self._max_result_retries = output_retries if output_retries is not None else retries
self._max_tool_retries = retries
self._tool_timeout = tool_timeout

self._validation_context = validation_context

Expand Down Expand Up @@ -569,7 +577,7 @@ async def main():
output_toolset.max_retries = self._max_result_retries
output_toolset.output_validators = output_validators
toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
tool_manager = ToolManager[AgentDepsT](toolset)
tool_manager = ToolManager[AgentDepsT](toolset, default_timeout=self._tool_timeout)

# Build the graph
graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)
Expand Down Expand Up @@ -1031,6 +1039,7 @@ def tool(
sequential: bool = False,
requires_approval: bool = False,
metadata: dict[str, Any] | None = None,
timeout: float | None = None,
) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...

def tool(
Expand All @@ -1049,6 +1058,7 @@ def tool(
sequential: bool = False,
requires_approval: bool = False,
metadata: dict[str, Any] | None = None,
timeout: float | None = None,
) -> Any:
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.

Expand Down Expand Up @@ -1098,6 +1108,8 @@ async def spam(ctx: RunContext[str], y: float) -> float:
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
timeout: Timeout in seconds for tool execution. If the tool takes longer, a retry prompt is returned to the model.
Defaults to None (no timeout). Overrides the agent-level `tool_timeout` if set.
"""

def tool_decorator(
Expand All @@ -1118,6 +1130,7 @@ def tool_decorator(
sequential=sequential,
requires_approval=requires_approval,
metadata=metadata,
timeout=timeout,
)
return func_

Expand All @@ -1142,6 +1155,7 @@ def tool_plain(
sequential: bool = False,
requires_approval: bool = False,
metadata: dict[str, Any] | None = None,
timeout: float | None = None,
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...

def tool_plain(
Expand All @@ -1160,6 +1174,7 @@ def tool_plain(
sequential: bool = False,
requires_approval: bool = False,
metadata: dict[str, Any] | None = None,
timeout: float | None = None,
) -> Any:
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.

Expand Down Expand Up @@ -1209,6 +1224,8 @@ async def spam(ctx: RunContext[str]) -> float:
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
timeout: Timeout in seconds for tool execution. If the tool takes longer, a retry prompt is returned to the model.
Defaults to None (no timeout). Overrides the agent-level `tool_timeout` if set.
"""

def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
Expand All @@ -1227,6 +1244,7 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams
sequential=sequential,
requires_approval=requires_approval,
metadata=metadata,
timeout=timeout,
)
return func_

Expand Down
13 changes: 13 additions & 0 deletions pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ class Tool(Generic[ToolAgentDepsT]):
sequential: bool
requires_approval: bool
metadata: dict[str, Any] | None
timeout: float | None
function_schema: _function_schema.FunctionSchema
"""
The base JSON schema for the tool's parameters.
Expand All @@ -285,6 +286,7 @@ def __init__(
sequential: bool = False,
requires_approval: bool = False,
metadata: dict[str, Any] | None = None,
timeout: float | None = None,
function_schema: _function_schema.FunctionSchema | None = None,
):
"""Create a new tool instance.
Expand Down Expand Up @@ -341,6 +343,8 @@ async def prep_my_tool(
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
timeout: Timeout in seconds for tool execution. If the tool takes longer, a retry prompt is returned to the model.
Defaults to None (no timeout). Overrides the agent-level `tool_timeout` if set.
function_schema: The function schema to use for the tool. If not provided, it will be generated.
"""
self.function = function
Expand All @@ -362,6 +366,7 @@ async def prep_my_tool(
self.sequential = sequential
self.requires_approval = requires_approval
self.metadata = metadata
self.timeout = timeout

@classmethod
def from_schema(
Expand Down Expand Up @@ -417,6 +422,7 @@ def tool_def(self):
strict=self.strict,
sequential=self.sequential,
metadata=self.metadata,
timeout=self.timeout,
kind='unapproved' if self.requires_approval else 'function',
)

Expand Down Expand Up @@ -503,6 +509,13 @@ class ToolDefinition:
For MCP tools, this contains the `meta`, `annotations`, and `output_schema` fields from the tool definition.
"""

timeout: float | None = None
"""Timeout in seconds for tool execution.

If the tool takes longer than this, a retry prompt is returned to the model.
Defaults to None (no timeout). Overrides the agent-level `tool_timeout` if set.
"""

@property
def defer(self) -> bool:
"""Whether calls to this tool will be deferred.
Expand Down
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/toolsets/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ class ToolsetTool(Generic[AgentDepsT]):

For example, a [`pydantic.TypeAdapter(...).validator`](https://docs.pydantic.dev/latest/concepts/type_adapter/) or [`pydantic_core.SchemaValidator`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.SchemaValidator).
"""
timeout: float | None = None
"""Timeout in seconds for tool execution.

If the tool takes longer than this, a retry prompt is returned to the model.
Defaults to None (no timeout). Overrides the agent-level `tool_timeout` if set.
"""


class AbstractToolset(ABC, Generic[AgentDepsT]):
Expand Down
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/toolsets/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
args_validator=tool.args_validator,
source_toolset=toolset,
source_tool=tool,
timeout=tool.timeout,
)
return all_tools

Expand Down
10 changes: 10 additions & 0 deletions pydantic_ai_slim/pydantic_ai/toolsets/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def tool(
sequential: bool | None = None,
requires_approval: bool | None = None,
metadata: dict[str, Any] | None = None,
timeout: float | None = None,
) -> Callable[[ToolFuncEither[AgentDepsT, ToolParams]], ToolFuncEither[AgentDepsT, ToolParams]]: ...

def tool(
Expand All @@ -137,6 +138,7 @@ def tool(
sequential: bool | None = None,
requires_approval: bool | None = None,
metadata: dict[str, Any] | None = None,
timeout: float | None = None,
) -> Any:
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.

Expand Down Expand Up @@ -193,6 +195,8 @@ async def spam(ctx: RunContext[str], y: float) -> float:
If `None`, the default value is determined by the toolset.
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
If `None`, the default value is determined by the toolset. If provided, it will be merged with the toolset's metadata.
timeout: Timeout in seconds for tool execution. If the tool takes longer, a retry prompt is returned to the model.
Defaults to None (no timeout). Overrides the agent-level `tool_timeout` if set.
"""

def tool_decorator(
Expand All @@ -213,6 +217,7 @@ def tool_decorator(
sequential=sequential,
requires_approval=requires_approval,
metadata=metadata,
timeout=timeout,
)
return func_

Expand All @@ -233,6 +238,7 @@ def add_function(
sequential: bool | None = None,
requires_approval: bool | None = None,
metadata: dict[str, Any] | None = None,
timeout: float | None = None,
) -> None:
"""Add a function as a tool to the toolset.

Expand Down Expand Up @@ -267,6 +273,8 @@ def add_function(
If `None`, the default value is determined by the toolset.
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
If `None`, the default value is determined by the toolset. If provided, it will be merged with the toolset's metadata.
timeout: Timeout in seconds for tool execution. If the tool takes longer, a retry prompt is returned to the model.
Defaults to None (no timeout). Overrides the agent-level `tool_timeout` if set.
"""
if docstring_format is None:
docstring_format = self.docstring_format
Expand Down Expand Up @@ -295,6 +303,7 @@ def add_function(
sequential=sequential,
requires_approval=requires_approval,
metadata=metadata,
timeout=timeout,
)
self.add_tool(tool)

Expand Down Expand Up @@ -340,6 +349,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
args_validator=tool.function_schema.validator,
call_func=tool.function_schema.call,
is_async=tool.function_schema.is_async,
timeout=tool_def.timeout,
)
return tools

Expand Down
2 changes: 2 additions & 0 deletions tests/models/test_model_request_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def test_model_request_parameters_are_serializable():
'sequential': False,
'kind': 'function',
'metadata': None,
'timeout': None,
}
],
'builtin_tools': [
Expand Down Expand Up @@ -131,6 +132,7 @@ def test_model_request_parameters_are_serializable():
'sequential': False,
'kind': 'function',
'metadata': None,
'timeout': None,
}
],
'prompted_output_template': None,
Expand Down
2 changes: 2 additions & 0 deletions tests/test_logfire.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ async def my_ret(x: int) -> str:
'sequential': False,
'kind': 'function',
'metadata': None,
'timeout': None,
}
],
'builtin_tools': [],
Expand Down Expand Up @@ -994,6 +995,7 @@ class MyOutput:
'sequential': False,
'kind': 'output',
'metadata': None,
'timeout': None,
}
],
'prompted_output_template': None,
Expand Down
Loading