diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index b5e5b1a1e4..e0ccab28f5 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -155,6 +155,7 @@ def __init__( exit_conditions: Optional[list[str]] = None, state_schema: Optional[dict[str, Any]] = None, max_agent_steps: int = 100, + final_answer_on_max_steps: bool = True, streaming_callback: Optional[StreamingCallbackT] = None, raise_on_tool_invocation_failure: bool = False, tool_invoker_kwargs: Optional[dict[str, Any]] = None, @@ -171,6 +172,10 @@ def __init__( :param state_schema: The schema for the runtime state used by the tools. :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100. If the agent exceeds this number of steps, it will stop and return the current state. + :param final_answer_on_max_steps: If True, generates a final text response when max_agent_steps + is reached and the last message is a tool result. This ensures the agent always returns a + natural language response instead of raw tool output. Adds one additional LLM call that doesn't + count toward max_agent_steps. Defaults to True. :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails? @@ -213,6 +218,7 @@ def __init__( self.system_prompt = system_prompt self.exit_conditions = exit_conditions self.max_agent_steps = max_agent_steps + self.final_answer_on_max_steps = final_answer_on_max_steps self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure self.streaming_callback = streaming_callback @@ -520,6 +526,90 @@ def _check_tool_invoker_breakpoint( llm_messages=execution_context.state.data["messages"][-1:], pipeline_snapshot=pipeline_snapshot ) + def _generate_final_answer(self, exe_context: _ExecutionContext, span) -> None: + """Generate a final text response when max steps is reached with a tool result as last message.""" + if not self.final_answer_on_max_steps or not exe_context.state.data.get("messages"): + return + + last_msg = exe_context.state.data["messages"][-1] + if not last_msg.tool_call_result: + return + + try: + logger.info("Generating final text response after max steps reached.") + + # Add system message for context + final_prompt = ChatMessage.from_system( + "You have reached the maximum number of reasoning steps. " + "Based on the information gathered so far, provide a final answer " + "to the user's question. Tools are no longer available." + ) + + # Make final call with tools disabled + final_inputs = {k: v for k, v in exe_context.chat_generator_inputs.items() if k != "tools"} + final_result = self.chat_generator.run( + messages=exe_context.state.data["messages"] + [final_prompt], tools=[], **final_inputs + ) + + # Append final response + if final_result and "replies" in final_result: + for msg in final_result["replies"]: + exe_context.state.data["messages"].append(msg) + + span.set_tag("haystack.agent.final_answer_generated", True) + + except Exception as e: + logger.warning( + "Failed to generate final answer: {error}. Returning with tool result as last message.", error=str(e) + ) + span.set_tag("haystack.agent.final_answer_failed", True) + + async def _generate_final_answer_async(self, exe_context: _ExecutionContext, span) -> None: + """ + Async version: Generate a final text response when max steps is reached with tool result as last message. + """ + if not self.final_answer_on_max_steps or not exe_context.state.data.get("messages"): + return + + last_msg = exe_context.state.data["messages"][-1] + if not last_msg.tool_call_result: + return + + try: + logger.info("Generating final text response after max steps reached.") + + # Add system message for context + final_prompt = ChatMessage.from_system( + "You have reached the maximum number of reasoning steps. " + "Based on the information gathered so far, provide a final answer " + "to the user's question. Tools are no longer available." + ) + + # Make final call with tools disabled using AsyncPipeline + final_inputs = {k: v for k, v in exe_context.chat_generator_inputs.items() if k != "tools"} + final_inputs["tools"] = [] + + final_result = await AsyncPipeline._run_component_async( + component_name="chat_generator", + component={"instance": self.chat_generator}, + component_inputs={"messages": exe_context.state.data["messages"] + [final_prompt], **final_inputs}, + component_visits=exe_context.component_visits, + parent_span=span, + ) + + # Append final response + if final_result and "replies" in final_result: + for msg in final_result["replies"]: + exe_context.state.data["messages"].append(msg) + + span.set_tag("haystack.agent.final_answer_generated", True) + + except Exception as e: + logger.warning( + "Failed to generate final answer: {error}. Returning with tool result as last message.", error=str(e) + ) + span.set_tag("haystack.agent.final_answer_failed", True) + def run( # noqa: PLR0915 self, messages: list[ChatMessage], @@ -677,6 +767,8 @@ def run( # noqa: PLR0915 "Agent reached maximum agent steps of {max_agent_steps}, stopping.", max_agent_steps=self.max_agent_steps, ) + self._generate_final_answer(exe_context, span) + span.set_content_tag("haystack.agent.output", exe_context.state.data) span.set_tag("haystack.agent.steps_taken", exe_context.counter) @@ -820,6 +912,8 @@ async def run_async( "Agent reached maximum agent steps of {max_agent_steps}, stopping.", max_agent_steps=self.max_agent_steps, ) + await self._generate_final_answer_async(exe_context, span) + span.set_content_tag("haystack.agent.output", exe_context.state.data) span.set_tag("haystack.agent.steps_taken", exe_context.counter) diff --git a/releasenotes/notes/agent-final-answer-on-max-steps-a1b2c3d4e5f6g7h8.yaml b/releasenotes/notes/agent-final-answer-on-max-steps-a1b2c3d4e5f6g7h8.yaml new file mode 100644 index 0000000000..fa3afa5edf --- /dev/null +++ b/releasenotes/notes/agent-final-answer-on-max-steps-a1b2c3d4e5f6g7h8.yaml @@ -0,0 +1,8 @@ +--- +enhancements: + - | + Add `final_answer_on_max_steps` parameter to Agent component. When enabled (default: True), + the agent will generate a final natural language response if it reaches max_agent_steps with + a tool result as the last message. This ensures the agent always returns a user-friendly text + response instead of raw tool output, improving user experience when step limits are reached. + The feature adds one additional LLM call that doesn't count toward max_agent_steps. diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py index e896b715dd..a7cb715118 100644 --- a/test/components/agents/test_agent.py +++ b/test/components/agents/test_agent.py @@ -741,6 +741,72 @@ def test_exceed_max_steps(self, monkeypatch, weather_tool, caplog): agent.run([ChatMessage.from_user("Hello")]) assert "Agent reached maximum agent steps" in caplog.text + def test_final_answer_on_max_steps_enabled(self, monkeypatch, weather_tool): + """Test that final answer is generated when max steps is reached with tool result as last message.""" + monkeypatch.setenv("OPENAI_API_KEY", "fake-key") + generator = OpenAIChatGenerator() + + # Mock responses: first returns tool call, then after tools run, we hit max steps + agent = Agent(chat_generator=generator, tools=[weather_tool], max_agent_steps=1, final_answer_on_max_steps=True) + agent.warm_up() + + call_count = 0 + + def mock_run(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call: LLM wants to call tool + return { + "replies": [ + ChatMessage.from_assistant( + tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})] + ) + ] + } + else: + # Final answer call (no tools available) + return {"replies": [ChatMessage.from_assistant("Based on the weather data, it's 20C in Berlin.")]} + + agent.chat_generator.run = mock_run + + result = agent.run([ChatMessage.from_user("What's the weather in Berlin?")]) + + # Last message should be text response, not tool result + assert result["last_message"].text + assert "Berlin" in result["last_message"].text + + def test_final_answer_on_max_steps_disabled(self, monkeypatch, weather_tool): + """Test that no final answer is generated when final_answer_on_max_steps=False.""" + monkeypatch.setenv("OPENAI_API_KEY", "fake-key") + generator = OpenAIChatGenerator() + + agent = Agent( + chat_generator=generator, tools=[weather_tool], max_agent_steps=1, final_answer_on_max_steps=False + ) + agent.warm_up() + + call_count = 0 + + def mock_run(*args, **kwargs): + nonlocal call_count + call_count += 1 + # Always return tool call to ensure we'd end with tool result + return { + "replies": [ + ChatMessage.from_assistant( + tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})] + ) + ] + } + + agent.chat_generator.run = mock_run + + agent.run([ChatMessage.from_user("What's the weather?")]) + + # Should have ended without final answer call (only 1 LLM call, not 2) + assert call_count == 1 + def test_exit_conditions_checked_across_all_llm_messages(self, monkeypatch, weather_tool): monkeypatch.setenv("OPENAI_API_KEY", "fake-key") generator = OpenAIChatGenerator()