|
5 | 5 | from mcp.types import JSONRPCMessage, JSONRPCRequest |
6 | 6 | from opentelemetry import context, propagate |
7 | 7 |
|
| 8 | +from strands.tools.mcp.mcp_client import MCPClient |
8 | 9 | from strands.tools.mcp.mcp_instrumentation import ( |
9 | 10 | ItemWithContext, |
10 | 11 | SessionContextAttachingReader, |
|
14 | 15 | ) |
15 | 16 |
|
16 | 17 |
|
| 18 | +@pytest.fixture(autouse=True) |
| 19 | +def reset_mcp_instrumentation(): |
| 20 | + """Reset MCP instrumentation state before each test.""" |
| 21 | + import strands.tools.mcp.mcp_instrumentation as mcp_inst |
| 22 | + |
| 23 | + mcp_inst._instrumentation_applied = False |
| 24 | + yield |
| 25 | + # Reset after test too |
| 26 | + mcp_inst._instrumentation_applied = False |
| 27 | + |
| 28 | + |
17 | 29 | class TestItemWithContext: |
18 | 30 | def test_item_with_context_creation(self): |
19 | 31 | """Test that ItemWithContext correctly stores item and context.""" |
@@ -328,6 +340,27 @@ def __getattr__(self, name): |
328 | 340 |
|
329 | 341 |
|
330 | 342 | class TestMCPInstrumentation: |
| 343 | + def test_mcp_instrumentation_idempotent_with_multiple_clients(self): |
| 344 | + """Test that mcp_instrumentation is only called once even with multiple MCPClient instances.""" |
| 345 | + |
| 346 | + # Mock the wrap_function_wrapper to count calls |
| 347 | + with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: |
| 348 | + # Mock transport |
| 349 | + def mock_transport(): |
| 350 | + read_stream = AsyncMock() |
| 351 | + write_stream = AsyncMock() |
| 352 | + return read_stream, write_stream |
| 353 | + |
| 354 | + # Create first MCPClient instance - should apply instrumentation |
| 355 | + MCPClient(mock_transport) |
| 356 | + first_call_count = mock_wrap.call_count |
| 357 | + |
| 358 | + # Create second MCPClient instance - should NOT apply instrumentation again |
| 359 | + MCPClient(mock_transport) |
| 360 | + |
| 361 | + # wrap_function_wrapper should not be called again for the second client |
| 362 | + assert mock_wrap.call_count == first_call_count |
| 363 | + |
331 | 364 | def test_mcp_instrumentation_calls_wrap_function_wrapper(self): |
332 | 365 | """Test that mcp_instrumentation calls the expected wrapper functions.""" |
333 | 366 | with ( |
|
0 commit comments