5858 WaitQueueTimeoutError ,
5959)
6060from pymongo .hello import Hello , HelloCompat
61+ from pymongo .helpers_shared import _get_timeout_details , format_timeout_details
6162from pymongo .lock import (
6263 _async_cond_wait ,
6364 _async_create_condition ,
7980 SSLErrors ,
8081 _CancellationContext ,
8182 _configured_protocol_interface ,
82- _get_timeout_details ,
8383 _raise_connection_failure ,
84- format_timeout_details ,
8584)
8685from pymongo .read_preferences import ReadPreference
8786from pymongo .server_api import _add_to_command
@@ -124,7 +123,89 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
124123_IS_SYNC = False
125124
126125
127- class AsyncConnection :
126+ class AsyncBaseConnection :
127+ """A base connection object for server and kms connections."""
128+
129+ def __init__ (self , conn : AsyncNetworkingInterface , opts : PoolOptions ):
130+ self .conn = conn
131+ self .socket_checker : SocketChecker = SocketChecker ()
132+ self .cancel_context : _CancellationContext = _CancellationContext ()
133+ self .is_sdam = False
134+ self .closed = False
135+ self .last_timeout : float | None = None
136+ self .more_to_come = False
137+ self .opts = opts
138+ self .max_wire_version = - 1
139+
140+ def set_conn_timeout (self , timeout : Optional [float ]) -> None :
141+ """Cache last timeout to avoid duplicate calls to conn.settimeout."""
142+ if timeout == self .last_timeout :
143+ return
144+ self .last_timeout = timeout
145+ self .conn .get_conn .settimeout (timeout )
146+
147+ def apply_timeout (
148+ self , client : AsyncMongoClient [Any ], cmd : Optional [MutableMapping [str , Any ]]
149+ ) -> Optional [float ]:
150+ # CSOT: use remaining timeout when set.
151+ timeout = _csot .remaining ()
152+ if timeout is None :
153+ # Reset the socket timeout unless we're performing a streaming monitor check.
154+ if not self .more_to_come :
155+ self .set_conn_timeout (self .opts .socket_timeout )
156+ return None
157+ # RTT validation.
158+ rtt = _csot .get_rtt ()
159+ if rtt is None :
160+ rtt = self .connect_rtt
161+ max_time_ms = timeout - rtt
162+ if max_time_ms < 0 :
163+ timeout_details = _get_timeout_details (self .opts )
164+ formatted = format_timeout_details (timeout_details )
165+ # CSOT: raise an error without running the command since we know it will time out.
166+ errmsg = f"operation would exceed time limit, remaining timeout:{ timeout :.5f} <= network round trip time:{ rtt :.5f} { formatted } "
167+ if self .max_wire_version != - 1 :
168+ raise ExecutionTimeout (
169+ errmsg ,
170+ 50 ,
171+ {"ok" : 0 , "errmsg" : errmsg , "code" : 50 },
172+ self .max_wire_version ,
173+ )
174+ else :
175+ raise TimeoutError (errmsg )
176+ if cmd is not None :
177+ cmd ["maxTimeMS" ] = int (max_time_ms * 1000 )
178+ self .set_conn_timeout (timeout )
179+ return timeout
180+
181+ async def close_conn (self , reason : Optional [str ]) -> None :
182+ """Close this connection with a reason."""
183+ if self .closed :
184+ return
185+ await self ._close_conn ()
186+
187+ async def _close_conn (self ) -> None :
188+ """Close this connection."""
189+ if self .closed :
190+ return
191+ self .closed = True
192+ self .cancel_context .cancel ()
193+ # Note: We catch exceptions to avoid spurious errors on interpreter
194+ # shutdown.
195+ try :
196+ await self .conn .close ()
197+ except Exception : # noqa: S110
198+ pass
199+
200+ def conn_closed (self ) -> bool :
201+ """Return True if we know socket has been closed, False otherwise."""
202+ if _IS_SYNC :
203+ return self .socket_checker .socket_closed (self .conn .get_conn )
204+ else :
205+ return self .conn .is_closing ()
206+
207+
208+ class AsyncConnection (AsyncBaseConnection ):
128209 """Store a connection with some metadata.
129210
130211 :param conn: a raw connection object
@@ -142,29 +223,27 @@ def __init__(
142223 id : int ,
143224 is_sdam : bool ,
144225 ):
226+ super ().__init__ (conn , pool .opts )
145227 self .pool_ref = weakref .ref (pool )
146- self .conn = conn
147- self .address = address
148- self .id = id
228+ self .address : tuple [str , int ] = address
229+ self .id : int = id
149230 self .is_sdam = is_sdam
150- self .closed = False
151231 self .last_checkin_time = time .monotonic ()
152232 self .performed_handshake = False
153233 self .is_writable : bool = False
154234 self .max_wire_version = MAX_WIRE_VERSION
155- self .max_bson_size = MAX_BSON_SIZE
156- self .max_message_size = MAX_MESSAGE_SIZE
157- self .max_write_batch_size = MAX_WRITE_BATCH_SIZE
235+ self .max_bson_size : int = MAX_BSON_SIZE
236+ self .max_message_size : int = MAX_MESSAGE_SIZE
237+ self .max_write_batch_size : int = MAX_WRITE_BATCH_SIZE
158238 self .supports_sessions = False
159239 self .hello_ok : bool = False
160- self .is_mongos = False
240+ self .is_mongos : bool = False
161241 self .op_msg_enabled = False
162242 self .listeners = pool .opts ._event_listeners
163243 self .enabled_for_cmap = pool .enabled_for_cmap
164244 self .enabled_for_logging = pool .enabled_for_logging
165245 self .compression_settings = pool .opts ._compression_settings
166246 self .compression_context : Union [SnappyContext , ZlibContext , ZstdContext , None ] = None
167- self .socket_checker : SocketChecker = SocketChecker ()
168247 self .oidc_token_gen_id : Optional [int ] = None
169248 # Support for mechanism negotiation on the initial handshake.
170249 self .negotiated_mechs : Optional [list [str ]] = None
@@ -175,9 +254,6 @@ def __init__(
175254 self .pool_gen = pool .gen
176255 self .generation = self .pool_gen .get_overall ()
177256 self .ready = False
178- self .cancel_context : _CancellationContext = _CancellationContext ()
179- self .opts = pool .opts
180- self .more_to_come : bool = False
181257 # For load balancer support.
182258 self .service_id : Optional [ObjectId ] = None
183259 self .server_connection_id : Optional [int ] = None
@@ -193,44 +269,6 @@ def __init__(
193269 # For gossiping $clusterTime from the connection handshake to the client.
194270 self ._cluster_time = None
195271
196- def set_conn_timeout (self , timeout : Optional [float ]) -> None :
197- """Cache last timeout to avoid duplicate calls to conn.settimeout."""
198- if timeout == self .last_timeout :
199- return
200- self .last_timeout = timeout
201- self .conn .get_conn .settimeout (timeout )
202-
203- def apply_timeout (
204- self , client : AsyncMongoClient [Any ], cmd : Optional [MutableMapping [str , Any ]]
205- ) -> Optional [float ]:
206- # CSOT: use remaining timeout when set.
207- timeout = _csot .remaining ()
208- if timeout is None :
209- # Reset the socket timeout unless we're performing a streaming monitor check.
210- if not self .more_to_come :
211- self .set_conn_timeout (self .opts .socket_timeout )
212- return None
213- # RTT validation.
214- rtt = _csot .get_rtt ()
215- if rtt is None :
216- rtt = self .connect_rtt
217- max_time_ms = timeout - rtt
218- if max_time_ms < 0 :
219- timeout_details = _get_timeout_details (self .opts )
220- formatted = format_timeout_details (timeout_details )
221- # CSOT: raise an error without running the command since we know it will time out.
222- errmsg = f"operation would exceed time limit, remaining timeout:{ timeout :.5f} <= network round trip time:{ rtt :.5f} { formatted } "
223- raise ExecutionTimeout (
224- errmsg ,
225- 50 ,
226- {"ok" : 0 , "errmsg" : errmsg , "code" : 50 },
227- self .max_wire_version ,
228- )
229- if cmd is not None :
230- cmd ["maxTimeMS" ] = int (max_time_ms * 1000 )
231- self .set_conn_timeout (timeout )
232- return timeout
233-
234272 def pin_txn (self ) -> None :
235273 self .pinned_txn = True
236274 assert not self .pinned_cursor
@@ -574,26 +612,6 @@ async def close_conn(self, reason: Optional[str]) -> None:
574612 error = reason ,
575613 )
576614
577- async def _close_conn (self ) -> None :
578- """Close this connection."""
579- if self .closed :
580- return
581- self .closed = True
582- self .cancel_context .cancel ()
583- # Note: We catch exceptions to avoid spurious errors on interpreter
584- # shutdown.
585- try :
586- await self .conn .close ()
587- except Exception : # noqa: S110
588- pass
589-
590- def conn_closed (self ) -> bool :
591- """Return True if we know socket has been closed, False otherwise."""
592- if _IS_SYNC :
593- return self .socket_checker .socket_closed (self .conn .get_conn )
594- else :
595- return self .conn .is_closing ()
596-
597615 def send_cluster_time (
598616 self ,
599617 command : MutableMapping [str , Any ],
0 commit comments