Skip to content

Commit 03c62f7

Browse files
fix(mcp): auto cleanup on exceptions occurring in __enter__ (#833)
1 parent 8805021 commit 03c62f7

File tree

3 files changed

+125
-18
lines changed

3 files changed

+125
-18
lines changed

src/strands/tools/mcp/mcp_client.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from concurrent import futures
1717
from datetime import timedelta
1818
from types import TracebackType
19-
from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union
19+
from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union, cast
2020

2121
from mcp import ClientSession, ListToolsResult
2222
from mcp.types import CallToolResult as MCPCallToolResult
@@ -83,11 +83,15 @@ def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_ti
8383
self._transport_callable = transport_callable
8484

8585
self._background_thread: threading.Thread | None = None
86-
self._background_thread_session: ClientSession
87-
self._background_thread_event_loop: AbstractEventLoop
86+
self._background_thread_session: ClientSession | None = None
87+
self._background_thread_event_loop: AbstractEventLoop | None = None
8888

8989
def __enter__(self) -> "MCPClient":
90-
"""Context manager entry point which initializes the MCP server connection."""
90+
"""Context manager entry point which initializes the MCP server connection.
91+
92+
TODO: Refactor to lazy initialization pattern following idiomatic Python.
93+
Heavy work in __enter__ is non-idiomatic - should move connection logic to first method call instead.
94+
"""
9195
return self.start()
9296

9397
def __exit__(self, exc_type: BaseException, exc_val: BaseException, exc_tb: TracebackType) -> None:
@@ -118,9 +122,16 @@ def start(self) -> "MCPClient":
118122
self._init_future.result(timeout=self._startup_timeout)
119123
self._log_debug_with_thread("the client initialization was successful")
120124
except futures.TimeoutError as e:
121-
raise MCPClientInitializationError("background thread did not start in 30 seconds") from e
125+
logger.exception("client initialization timed out")
126+
# Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit
127+
self.stop(None, None, None)
128+
raise MCPClientInitializationError(
129+
f"background thread did not start in {self._startup_timeout} seconds"
130+
) from e
122131
except Exception as e:
123132
logger.exception("client failed to initialize")
133+
# Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit
134+
self.stop(None, None, None)
124135
raise MCPClientInitializationError("the client initialization failed") from e
125136
return self
126137

@@ -129,21 +140,29 @@ def stop(
129140
) -> None:
130141
"""Signals the background thread to stop and waits for it to complete, ensuring proper cleanup of all resources.
131142
143+
This method is defensive and can handle partial initialization states that may occur
144+
if start() fails partway through initialization.
145+
132146
Args:
133147
exc_type: Exception type if an exception was raised in the context
134148
exc_val: Exception value if an exception was raised in the context
135149
exc_tb: Exception traceback if an exception was raised in the context
136150
"""
137151
self._log_debug_with_thread("exiting MCPClient context")
138152

139-
async def _set_close_event() -> None:
140-
self._close_event.set()
141-
142-
self._invoke_on_background_thread(_set_close_event()).result()
143-
self._log_debug_with_thread("waiting for background thread to join")
153+
# Only try to signal close event if we have a background thread
144154
if self._background_thread is not None:
155+
# Signal close event if event loop exists
156+
if self._background_thread_event_loop is not None:
157+
158+
async def _set_close_event() -> None:
159+
self._close_event.set()
160+
161+
self._invoke_on_background_thread(_set_close_event()).result()
162+
163+
self._log_debug_with_thread("waiting for background thread to join")
145164
self._background_thread.join()
146-
self._log_debug_with_thread("background thread joined, MCPClient context exited")
165+
self._log_debug_with_thread("background thread is closed, MCPClient context exited")
147166

148167
# Reset fields to allow instance reuse
149168
self._init_future = futures.Future()
@@ -165,7 +184,7 @@ def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedLi
165184
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
166185

167186
async def _list_tools_async() -> ListToolsResult:
168-
return await self._background_thread_session.list_tools(cursor=pagination_token)
187+
return await cast(ClientSession, self._background_thread_session).list_tools(cursor=pagination_token)
169188

170189
list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result()
171190
self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools))
@@ -191,7 +210,7 @@ def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromp
191210
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
192211

193212
async def _list_prompts_async() -> ListPromptsResult:
194-
return await self._background_thread_session.list_prompts(cursor=pagination_token)
213+
return await cast(ClientSession, self._background_thread_session).list_prompts(cursor=pagination_token)
195214

196215
list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result()
197216
self._log_debug_with_thread("received %d prompts from MCP server", len(list_prompts_result.prompts))
@@ -215,7 +234,7 @@ def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResu
215234
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
216235

217236
async def _get_prompt_async() -> GetPromptResult:
218-
return await self._background_thread_session.get_prompt(prompt_id, arguments=args)
237+
return await cast(ClientSession, self._background_thread_session).get_prompt(prompt_id, arguments=args)
219238

220239
get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result()
221240
self._log_debug_with_thread("received prompt from MCP server")
@@ -250,7 +269,9 @@ def call_tool_sync(
250269
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
251270

252271
async def _call_tool_async() -> MCPCallToolResult:
253-
return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds)
272+
return await cast(ClientSession, self._background_thread_session).call_tool(
273+
name, arguments, read_timeout_seconds
274+
)
254275

255276
try:
256277
call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(_call_tool_async()).result()
@@ -285,7 +306,9 @@ async def call_tool_async(
285306
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
286307

287308
async def _call_tool_async() -> MCPCallToolResult:
288-
return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds)
309+
return await cast(ClientSession, self._background_thread_session).call_tool(
310+
name, arguments, read_timeout_seconds
311+
)
289312

290313
try:
291314
future = self._invoke_on_background_thread(_call_tool_async())

tests/strands/tools/mcp/test_mcp_client.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,8 +337,12 @@ def test_enter_with_initialization_exception(mock_transport):
337337

338338
client = MCPClient(mock_transport["transport_callable"])
339339

340-
with pytest.raises(MCPClientInitializationError, match="the client initialization failed"):
341-
client.start()
340+
with patch.object(client, "stop") as mock_stop:
341+
with pytest.raises(MCPClientInitializationError, match="the client initialization failed"):
342+
client.start()
343+
344+
# Verify stop() was called for cleanup
345+
mock_stop.assert_called_once_with(None, None, None)
342346

343347

344348
def test_mcp_tool_result_type():
@@ -466,3 +470,54 @@ def test_get_prompt_sync_session_not_active():
466470

467471
with pytest.raises(MCPClientInitializationError, match="client session is not running"):
468472
client.get_prompt_sync("test_prompt_id", {})
473+
474+
475+
def test_timeout_initialization_cleanup():
476+
"""Test that timeout during initialization properly cleans up."""
477+
478+
def slow_transport():
479+
time.sleep(5)
480+
return MagicMock()
481+
482+
client = MCPClient(slow_transport, startup_timeout=1)
483+
484+
with patch.object(client, "stop") as mock_stop:
485+
with pytest.raises(MCPClientInitializationError, match="background thread did not start in 1 seconds"):
486+
client.start()
487+
mock_stop.assert_called_once_with(None, None, None)
488+
489+
490+
def test_stop_with_no_background_thread():
491+
"""Test that stop() handles the case when no background thread exists."""
492+
client = MCPClient(MagicMock())
493+
494+
# Ensure no background thread exists
495+
assert client._background_thread is None
496+
497+
# Mock join to verify it's not called
498+
with patch("threading.Thread.join") as mock_join:
499+
client.stop(None, None, None)
500+
mock_join.assert_not_called()
501+
502+
# Verify cleanup occurred
503+
assert client._background_thread is None
504+
505+
506+
def test_stop_with_background_thread_but_no_event_loop():
507+
"""Test that stop() handles the case when background thread exists but event loop is None."""
508+
client = MCPClient(MagicMock())
509+
510+
# Mock a background thread without event loop
511+
mock_thread = MagicMock()
512+
mock_thread.join = MagicMock()
513+
client._background_thread = mock_thread
514+
client._background_thread_event_loop = None
515+
516+
# Should not raise any exceptions and should join the thread
517+
client.stop(None, None, None)
518+
519+
# Verify thread was joined
520+
mock_thread.join.assert_called_once()
521+
522+
# Verify cleanup occurred
523+
assert client._background_thread is None

tests_integ/test_mcp_client.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from strands.tools.mcp.mcp_client import MCPClient
1616
from strands.tools.mcp.mcp_types import MCPTransport
1717
from strands.types.content import Message
18+
from strands.types.exceptions import MCPClientInitializationError
1819
from strands.types.tools import ToolUse
1920

2021

@@ -268,3 +269,31 @@ def transport_callback() -> MCPTransport:
268269

269270
def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]:
270271
return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block]
272+
273+
274+
def test_mcp_client_timeout_integration():
275+
"""Integration test for timeout scenario that caused hanging."""
276+
import threading
277+
278+
from mcp import StdioServerParameters, stdio_client
279+
280+
def slow_transport():
281+
time.sleep(4) # Longer than timeout
282+
return stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
283+
284+
client = MCPClient(slow_transport, startup_timeout=2)
285+
initial_threads = threading.active_count()
286+
287+
# First attempt should timeout
288+
with pytest.raises(MCPClientInitializationError, match="background thread did not start in 2 seconds"):
289+
with client:
290+
pass
291+
292+
time.sleep(1) # Allow cleanup
293+
assert threading.active_count() == initial_threads # No thread leak
294+
295+
# Should be able to recover by increasing timeout
296+
client._startup_timeout = 60
297+
with client:
298+
tools = client.list_tools_sync()
299+
assert len(tools) >= 0 # Should work now

0 commit comments

Comments
 (0)