|
3 | 3 | import pytest |
4 | 4 |
|
5 | 5 | from strands.agent import Agent, AgentResult |
| 6 | +from strands.hooks import AgentInitializedEvent |
| 7 | +from strands.hooks.registry import HookProvider, HookRegistry |
6 | 8 | from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult |
7 | | -from strands.multiagent.graph import GraphBuilder, GraphEdge, GraphNode, GraphResult, GraphState, Status |
| 9 | +from strands.multiagent.graph import Graph, GraphBuilder, GraphEdge, GraphNode, GraphResult, GraphState, Status |
| 10 | +from strands.session.session_manager import SessionManager |
8 | 11 |
|
9 | 12 |
|
10 | 13 | def create_mock_agent(name, response_text="Default response", metrics=None, agent_id=None): |
11 | 14 | """Create a mock Agent with specified properties.""" |
12 | 15 | agent = Mock(spec=Agent) |
13 | 16 | agent.name = name |
14 | 17 | agent.id = agent_id or f"{name}_id" |
| 18 | + agent._session_manager = None |
| 19 | + agent.hooks = HookRegistry() |
15 | 20 |
|
16 | 21 | if metrics is None: |
17 | 22 | metrics = Mock( |
@@ -261,6 +266,10 @@ async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span) |
261 | 266 | failing_agent.id = "fail_node" |
262 | 267 | failing_agent.__call__ = Mock(side_effect=Exception("Simulated failure")) |
263 | 268 |
|
| 269 | + # Add required attributes for validation |
| 270 | + failing_agent._session_manager = None |
| 271 | + failing_agent.hooks = HookRegistry() |
| 272 | + |
264 | 273 | async def mock_invoke_failure(*args, **kwargs): |
265 | 274 | raise Exception("Simulated failure") |
266 | 275 |
|
@@ -489,3 +498,51 @@ def test_graph_synchronous_execution(mock_strands_tracer, mock_use_span, mock_ag |
489 | 498 |
|
490 | 499 | mock_strands_tracer.start_multiagent_span.assert_called() |
491 | 500 | mock_use_span.assert_called_once() |
| 501 | + |
| 502 | + |
| 503 | +def test_graph_validate_unsupported_features(): |
| 504 | + """Test Graph validation for session persistence and callbacks.""" |
| 505 | + # Test with normal agent (should work) |
| 506 | + normal_agent = create_mock_agent("normal_agent") |
| 507 | + normal_agent._session_manager = None |
| 508 | + normal_agent.hooks = HookRegistry() |
| 509 | + |
| 510 | + builder = GraphBuilder() |
| 511 | + builder.add_node(normal_agent) |
| 512 | + graph = builder.build() |
| 513 | + assert len(graph.nodes) == 1 |
| 514 | + |
| 515 | + # Test with session manager (should fail in GraphBuilder.add_node) |
| 516 | + mock_session_manager = Mock(spec=SessionManager) |
| 517 | + agent_with_session = create_mock_agent("agent_with_session") |
| 518 | + agent_with_session._session_manager = mock_session_manager |
| 519 | + agent_with_session.hooks = HookRegistry() |
| 520 | + |
| 521 | + builder = GraphBuilder() |
| 522 | + with pytest.raises(ValueError, match="Session persistence is not supported for Graph agents yet"): |
| 523 | + builder.add_node(agent_with_session) |
| 524 | + |
| 525 | + # Test with callbacks (should fail in GraphBuilder.add_node) |
| 526 | + class TestHookProvider(HookProvider): |
| 527 | + def register_hooks(self, registry, **kwargs): |
| 528 | + registry.add_callback(AgentInitializedEvent, lambda e: None) |
| 529 | + |
| 530 | + agent_with_hooks = create_mock_agent("agent_with_hooks") |
| 531 | + agent_with_hooks._session_manager = None |
| 532 | + agent_with_hooks.hooks = HookRegistry() |
| 533 | + agent_with_hooks.hooks.add_hook(TestHookProvider()) |
| 534 | + |
| 535 | + builder = GraphBuilder() |
| 536 | + with pytest.raises(ValueError, match="Agent callbacks are not supported for Graph agents yet"): |
| 537 | + builder.add_node(agent_with_hooks) |
| 538 | + |
| 539 | + # Test validation in Graph constructor (when nodes are passed directly) |
| 540 | + # Test with session manager in Graph constructor |
| 541 | + node_with_session = GraphNode("node_with_session", agent_with_session) |
| 542 | + with pytest.raises(ValueError, match="Session persistence is not supported for Graph agents yet"): |
| 543 | + Graph(nodes={"node_with_session": node_with_session}, edges=set(), entry_points=set()) |
| 544 | + |
| 545 | + # Test with callbacks in Graph constructor |
| 546 | + node_with_hooks = GraphNode("node_with_hooks", agent_with_hooks) |
| 547 | + with pytest.raises(ValueError, match="Agent callbacks are not supported for Graph agents yet"): |
| 548 | + Graph(nodes={"node_with_hooks": node_with_hooks}, edges=set(), entry_points=set()) |
0 commit comments