1616from concurrent import futures
1717from datetime import timedelta
1818from types import TracebackType
19- from typing import Any , Callable , Coroutine , Dict , Optional , TypeVar , Union
19+ from typing import Any , Callable , Coroutine , Dict , Optional , TypeVar , Union , cast
2020
2121from mcp import ClientSession , ListToolsResult
2222from mcp .types import CallToolResult as MCPCallToolResult
@@ -83,11 +83,15 @@ def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_ti
8383 self ._transport_callable = transport_callable
8484
8585 self ._background_thread : threading .Thread | None = None
86- self ._background_thread_session : ClientSession
87- self ._background_thread_event_loop : AbstractEventLoop
86+ self ._background_thread_session : ClientSession | None = None
87+ self ._background_thread_event_loop : AbstractEventLoop | None = None
8888
8989 def __enter__ (self ) -> "MCPClient" :
90- """Context manager entry point which initializes the MCP server connection."""
90+ """Context manager entry point which initializes the MCP server connection.
91+
92+ TODO: Refactor to lazy initialization pattern following idiomatic Python.
93+ Heavy work in __enter__ is non-idiomatic - should move connection logic to first method call instead.
94+ """
9195 return self .start ()
9296
9397 def __exit__ (self , exc_type : BaseException , exc_val : BaseException , exc_tb : TracebackType ) -> None :
@@ -118,9 +122,16 @@ def start(self) -> "MCPClient":
118122 self ._init_future .result (timeout = self ._startup_timeout )
119123 self ._log_debug_with_thread ("the client initialization was successful" )
120124 except futures .TimeoutError as e :
121- raise MCPClientInitializationError ("background thread did not start in 30 seconds" ) from e
125+ logger .exception ("client initialization timed out" )
126+ # Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit
127+ self .stop (None , None , None )
128+ raise MCPClientInitializationError (
129+ f"background thread did not start in { self ._startup_timeout } seconds"
130+ ) from e
122131 except Exception as e :
123132 logger .exception ("client failed to initialize" )
133+ # Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit
134+ self .stop (None , None , None )
124135 raise MCPClientInitializationError ("the client initialization failed" ) from e
125136 return self
126137
@@ -129,21 +140,29 @@ def stop(
129140 ) -> None :
130141 """Signals the background thread to stop and waits for it to complete, ensuring proper cleanup of all resources.
131142
143+ This method is defensive and can handle partial initialization states that may occur
144+ if start() fails partway through initialization.
145+
132146 Args:
133147 exc_type: Exception type if an exception was raised in the context
134148 exc_val: Exception value if an exception was raised in the context
135149 exc_tb: Exception traceback if an exception was raised in the context
136150 """
137151 self ._log_debug_with_thread ("exiting MCPClient context" )
138152
139- async def _set_close_event () -> None :
140- self ._close_event .set ()
141-
142- self ._invoke_on_background_thread (_set_close_event ()).result ()
143- self ._log_debug_with_thread ("waiting for background thread to join" )
153+ # Only try to signal close event if we have a background thread
144154 if self ._background_thread is not None :
155+ # Signal close event if event loop exists
156+ if self ._background_thread_event_loop is not None :
157+
158+ async def _set_close_event () -> None :
159+ self ._close_event .set ()
160+
161+ self ._invoke_on_background_thread (_set_close_event ()).result ()
162+
163+ self ._log_debug_with_thread ("waiting for background thread to join" )
145164 self ._background_thread .join ()
146- self ._log_debug_with_thread ("background thread joined , MCPClient context exited" )
165+ self ._log_debug_with_thread ("background thread is closed , MCPClient context exited" )
147166
148167 # Reset fields to allow instance reuse
149168 self ._init_future = futures .Future ()
@@ -165,7 +184,7 @@ def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedLi
165184 raise MCPClientInitializationError (CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE )
166185
167186 async def _list_tools_async () -> ListToolsResult :
168- return await self ._background_thread_session .list_tools (cursor = pagination_token )
187+ return await cast ( ClientSession , self ._background_thread_session ) .list_tools (cursor = pagination_token )
169188
170189 list_tools_response : ListToolsResult = self ._invoke_on_background_thread (_list_tools_async ()).result ()
171190 self ._log_debug_with_thread ("received %d tools from MCP server" , len (list_tools_response .tools ))
@@ -191,7 +210,7 @@ def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromp
191210 raise MCPClientInitializationError (CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE )
192211
193212 async def _list_prompts_async () -> ListPromptsResult :
194- return await self ._background_thread_session .list_prompts (cursor = pagination_token )
213+ return await cast ( ClientSession , self ._background_thread_session ) .list_prompts (cursor = pagination_token )
195214
196215 list_prompts_result : ListPromptsResult = self ._invoke_on_background_thread (_list_prompts_async ()).result ()
197216 self ._log_debug_with_thread ("received %d prompts from MCP server" , len (list_prompts_result .prompts ))
@@ -215,7 +234,7 @@ def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResu
215234 raise MCPClientInitializationError (CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE )
216235
217236 async def _get_prompt_async () -> GetPromptResult :
218- return await self ._background_thread_session .get_prompt (prompt_id , arguments = args )
237+ return await cast ( ClientSession , self ._background_thread_session ) .get_prompt (prompt_id , arguments = args )
219238
220239 get_prompt_result : GetPromptResult = self ._invoke_on_background_thread (_get_prompt_async ()).result ()
221240 self ._log_debug_with_thread ("received prompt from MCP server" )
@@ -250,7 +269,9 @@ def call_tool_sync(
250269 raise MCPClientInitializationError (CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE )
251270
252271 async def _call_tool_async () -> MCPCallToolResult :
253- return await self ._background_thread_session .call_tool (name , arguments , read_timeout_seconds )
272+ return await cast (ClientSession , self ._background_thread_session ).call_tool (
273+ name , arguments , read_timeout_seconds
274+ )
254275
255276 try :
256277 call_tool_result : MCPCallToolResult = self ._invoke_on_background_thread (_call_tool_async ()).result ()
@@ -285,7 +306,9 @@ async def call_tool_async(
285306 raise MCPClientInitializationError (CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE )
286307
287308 async def _call_tool_async () -> MCPCallToolResult :
288- return await self ._background_thread_session .call_tool (name , arguments , read_timeout_seconds )
309+ return await cast (ClientSession , self ._background_thread_session ).call_tool (
310+ name , arguments , read_timeout_seconds
311+ )
289312
290313 try :
291314 future = self ._invoke_on_background_thread (_call_tool_async ())
0 commit comments