94 changes: 94 additions & 0 deletions haystack/components/agents/agent.py
@@ -155,6 +155,7 @@ def __init__(
exit_conditions: Optional[list[str]] = None,
state_schema: Optional[dict[str, Any]] = None,
max_agent_steps: int = 100,
final_answer_on_max_steps: bool = True,
streaming_callback: Optional[StreamingCallbackT] = None,
raise_on_tool_invocation_failure: bool = False,
tool_invoker_kwargs: Optional[dict[str, Any]] = None,
@@ -171,6 +172,10 @@
:param state_schema: The schema for the runtime state used by the tools.
:param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
If the agent exceeds this number of steps, it will stop and return the current state.
:param final_answer_on_max_steps: If True, generates a final text response when max_agent_steps
is reached and the last message is a tool result. This ensures the agent always returns a
natural language response instead of raw tool output. Adds one additional LLM call that doesn't
count toward max_agent_steps. Defaults to True.
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
The same callback can be configured to emit tool results when a tool is called.
:param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
@@ -213,6 +218,7 @@ def __init__(
self.system_prompt = system_prompt
self.exit_conditions = exit_conditions
self.max_agent_steps = max_agent_steps
self.final_answer_on_max_steps = final_answer_on_max_steps
self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure
self.streaming_callback = streaming_callback

@@ -520,6 +526,90 @@ def _check_tool_invoker_breakpoint(
llm_messages=execution_context.state.data["messages"][-1:], pipeline_snapshot=pipeline_snapshot
)

def _generate_final_answer(self, exe_context: _ExecutionContext, span) -> None:
"""Generate a final text response when max steps is reached with a tool result as last message."""
if not self.final_answer_on_max_steps or not exe_context.state.data.get("messages"):
return

last_msg = exe_context.state.data["messages"][-1]
if not last_msg.tool_call_result:
return

try:
logger.info("Generating final text response after max steps reached.")

# Add system message for context
final_prompt = ChatMessage.from_system(
"You have reached the maximum number of reasoning steps. "
"Based on the information gathered so far, provide a final answer "
"to the user's question. Tools are no longer available."
)

# Make final call with tools disabled
final_inputs = {k: v for k, v in exe_context.chat_generator_inputs.items() if k != "tools"}
final_result = self.chat_generator.run(
messages=exe_context.state.data["messages"] + [final_prompt], tools=[], **final_inputs
)

# Append final response
if final_result and "replies" in final_result:
for msg in final_result["replies"]:
exe_context.state.data["messages"].append(msg)

span.set_tag("haystack.agent.final_answer_generated", True)

except Exception as e:
logger.warning(
"Failed to generate final answer: {error}. Returning with tool result as last message.", error=str(e)
)
span.set_tag("haystack.agent.final_answer_failed", True)

async def _generate_final_answer_async(self, exe_context: _ExecutionContext, span) -> None:
"""
Async version: Generate a final text response when max steps is reached with a tool result as the last message.
"""
if not self.final_answer_on_max_steps or not exe_context.state.data.get("messages"):
return

last_msg = exe_context.state.data["messages"][-1]
if not last_msg.tool_call_result:
return

try:
logger.info("Generating final text response after max steps reached.")

# Add system message for context
final_prompt = ChatMessage.from_system(
"You have reached the maximum number of reasoning steps. "
"Based on the information gathered so far, provide a final answer "
"to the user's question. Tools are no longer available."
)

# Make final call with tools disabled using AsyncPipeline
final_inputs = {k: v for k, v in exe_context.chat_generator_inputs.items() if k != "tools"}
final_inputs["tools"] = []

final_result = await AsyncPipeline._run_component_async(
component_name="chat_generator",
component={"instance": self.chat_generator},
component_inputs={"messages": exe_context.state.data["messages"] + [final_prompt], **final_inputs},
component_visits=exe_context.component_visits,
parent_span=span,
)

# Append final response
if final_result and "replies" in final_result:
for msg in final_result["replies"]:
exe_context.state.data["messages"].append(msg)

span.set_tag("haystack.agent.final_answer_generated", True)

except Exception as e:
logger.warning(
"Failed to generate final answer: {error}. Returning with tool result as last message.", error=str(e)
)
span.set_tag("haystack.agent.final_answer_failed", True)

def run( # noqa: PLR0915
self,
messages: list[ChatMessage],
@@ -677,6 +767,8 @@ def run( # noqa: PLR0915
"Agent reached maximum agent steps of {max_agent_steps}, stopping.",
max_agent_steps=self.max_agent_steps,
)
self._generate_final_answer(exe_context, span)

span.set_content_tag("haystack.agent.output", exe_context.state.data)
span.set_tag("haystack.agent.steps_taken", exe_context.counter)

@@ -820,6 +912,8 @@ async def run_async(
"Agent reached maximum agent steps of {max_agent_steps}, stopping.",
max_agent_steps=self.max_agent_steps,
)
await self._generate_final_answer_async(exe_context, span)

span.set_content_tag("haystack.agent.output", exe_context.state.data)
span.set_tag("haystack.agent.steps_taken", exe_context.counter)

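For intuition, here is a minimal, self-contained sketch of the message sequence the new helper handles when the step cap is hit. The ChatMessage and ToolCall factories are the real Haystack dataclass constructors; the weather content is invented for illustration.

from haystack.dataclasses import ChatMessage, ToolCall

tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
messages = [
    ChatMessage.from_user("What's the weather in Berlin?"),
    ChatMessage.from_assistant(tool_calls=[tool_call]),  # the step that exhausted max_agent_steps
    ChatMessage.from_tool(tool_result="Sunny, 20C", origin=tool_call),  # would otherwise be the last message
]
# The helper sends this history plus a system nudge ("Tools are no longer
# available...") to the chat generator with tools=[], then appends the reply:
messages.append(ChatMessage.from_assistant("It's sunny and around 20C in Berlin."))
assert messages[-1].text  # the run now ends on natural language, not a raw tool result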
@@ -0,0 +1,8 @@
---
enhancements:
- |
Add `final_answer_on_max_steps` parameter to Agent component. When enabled (default: True),
the agent will generate a final natural language response if it reaches max_agent_steps with
a tool result as the last message. This ensures the agent always returns a user-friendly text
response instead of raw tool output, improving user experience when step limits are reached.
The feature adds one additional LLM call that doesn't count toward max_agent_steps.
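A hedged usage sketch of the new parameter (the `get_weather` tool and the OpenAI-backed generator are assumptions for the example, not part of this change):

from haystack.components.agents import Agent
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.tools import Tool

def get_weather(location: str) -> str:
    """Toy tool used only for this example."""
    return f"Sunny, 20C in {location}"

weather_tool = Tool(
    name="weather_tool",
    description="Get the current weather for a location.",
    parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]},
    function=get_weather,
)

agent = Agent(
    chat_generator=OpenAIChatGenerator(),
    tools=[weather_tool],
    max_agent_steps=3,
    final_answer_on_max_steps=True,  # default; adds one tool-free LLM call if the run ends on a tool result
)
agent.warm_up()

result = agent.run([ChatMessage.from_user("What's the weather in Berlin?")])
print(result["last_message"].text)  # natural-language answer even when the step cap was hit

With final_answer_on_max_steps=False, result["last_message"] could instead be a raw tool result.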
66 changes: 66 additions & 0 deletions test/components/agents/test_agent.py
@@ -741,6 +741,72 @@ def test_exceed_max_steps(self, monkeypatch, weather_tool, caplog):
agent.run([ChatMessage.from_user("Hello")])
assert "Agent reached maximum agent steps" in caplog.text

def test_final_answer_on_max_steps_enabled(self, monkeypatch, weather_tool):
"""Test that final answer is generated when max steps is reached with tool result as last message."""
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
generator = OpenAIChatGenerator()

# Mock responses: first returns tool call, then after tools run, we hit max steps
agent = Agent(chat_generator=generator, tools=[weather_tool], max_agent_steps=1, final_answer_on_max_steps=True)
agent.warm_up()

call_count = 0

def mock_run(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
# First call: LLM wants to call tool
return {
"replies": [
ChatMessage.from_assistant(
tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
)
]
}
else:
# Final answer call (no tools available)
return {"replies": [ChatMessage.from_assistant("Based on the weather data, it's 20C in Berlin.")]}

agent.chat_generator.run = mock_run

result = agent.run([ChatMessage.from_user("What's the weather in Berlin?")])

# Last message should be text response, not tool result
assert result["last_message"].text
assert "Berlin" in result["last_message"].text

def test_final_answer_on_max_steps_disabled(self, monkeypatch, weather_tool):
"""Test that no final answer is generated when final_answer_on_max_steps=False."""
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
generator = OpenAIChatGenerator()

agent = Agent(
chat_generator=generator, tools=[weather_tool], max_agent_steps=1, final_answer_on_max_steps=False
)
agent.warm_up()

call_count = 0

def mock_run(*args, **kwargs):
nonlocal call_count
call_count += 1
# Always return tool call to ensure we'd end with tool result
return {
"replies": [
ChatMessage.from_assistant(
tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
)
]
}

agent.chat_generator.run = mock_run

agent.run([ChatMessage.from_user("What's the weather?")])

# Should have ended without final answer call (only 1 LLM call, not 2)
assert call_count == 1

def test_exit_conditions_checked_across_all_llm_messages(self, monkeypatch, weather_tool):
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
generator = OpenAIChatGenerator()