1515"""MONGODB-OIDC Authentication helpers."""
1616from __future__ import annotations
1717
18+ import asyncio
1819import threading
1920import time
2021from dataclasses import dataclass , field
3637)
3738from pymongo .errors import ConfigurationError , OperationFailure
3839from pymongo .helpers_shared import _AUTHENTICATION_FAILURE_CODE
40+ from pymongo .lock import Lock , _async_create_lock
3941
4042if TYPE_CHECKING :
4143 from pymongo .asynchronous .pool import AsyncConnection
@@ -81,7 +83,11 @@ class _OIDCAuthenticator:
8183 access_token : Optional [str ] = field (default = None )
8284 idp_info : Optional [OIDCIdPInfo ] = field (default = None )
8385 token_gen_id : int = field (default = 0 )
84- lock : threading .Lock = field (default_factory = threading .Lock )
86+ if not _IS_SYNC :
87+ lock : Lock = field (default_factory = _async_create_lock ) # type: ignore[assignment]
88+ else :
89+ lock : threading .Lock = field (default_factory = _async_create_lock ) # type: ignore[assignment, no-redef]
90+
8591 last_call_time : float = field (default = 0 )
8692
8793 async def reauthenticate (self , conn : AsyncConnection ) -> Optional [Mapping [str , Any ]]:
@@ -164,7 +170,7 @@ async def _authenticate_human(self, conn: AsyncConnection) -> Optional[Mapping[s
164170 # Attempt to authenticate with a JwtStepRequest.
165171 return await self ._sasl_continue_jwt (conn , start_resp )
166172
167- def _get_access_token (self ) -> Optional [str ]:
173+ async def _get_access_token (self ) -> Optional [str ]:
168174 properties = self .properties
169175 cb : Union [None , OIDCCallback ]
170176 resp : OIDCCallbackResult
@@ -186,7 +192,7 @@ def _get_access_token(self) -> Optional[str]:
186192 return None
187193
188194 if not prev_token and cb is not None :
189- with self .lock :
195+ async with self .lock : # type: ignore[attr-defined]
190196 # See if the token was changed while we were waiting for the
191197 # lock.
192198 new_token = self .access_token
@@ -196,7 +202,7 @@ def _get_access_token(self) -> Optional[str]:
196202 # Ensure that we are waiting a min time between callback invocations.
197203 delta = time .time () - self .last_call_time
198204 if delta < TIME_BETWEEN_CALLS_SECONDS :
199- time .sleep (TIME_BETWEEN_CALLS_SECONDS - delta )
205+ await asyncio .sleep (TIME_BETWEEN_CALLS_SECONDS - delta )
200206 self .last_call_time = time .time ()
201207
202208 if is_human :
@@ -211,7 +217,10 @@ def _get_access_token(self) -> Optional[str]:
211217 idp_info = self .idp_info ,
212218 username = self .properties .username ,
213219 )
214- resp = cb .fetch (context )
220+ if not _IS_SYNC :
221+ resp = await asyncio .get_running_loop ().run_in_executor (None , cb .fetch , context ) # type: ignore[assignment]
222+ else :
223+ resp = cb .fetch (context )
215224 if not isinstance (resp , OIDCCallbackResult ):
216225 raise ValueError (
217226 f"Callback result must be of type OIDCCallbackResult, not { type (resp )} "
@@ -253,13 +262,13 @@ async def _sasl_continue_jwt(
253262 start_payload : dict = bson .decode (start_resp ["payload" ])
254263 if "issuer" in start_payload :
255264 self .idp_info = OIDCIdPInfo (** start_payload )
256- access_token = self ._get_access_token ()
265+ access_token = await self ._get_access_token ()
257266 conn .oidc_token_gen_id = self .token_gen_id
258267 cmd = self ._get_continue_command ({"jwt" : access_token }, start_resp )
259268 return await self ._run_command (conn , cmd )
260269
261270 async def _sasl_start_jwt (self , conn : AsyncConnection ) -> Mapping [str , Any ]:
262- access_token = self ._get_access_token ()
271+ access_token = await self ._get_access_token ()
263272 conn .oidc_token_gen_id = self .token_gen_id
264273 cmd = self ._get_start_command ({"jwt" : access_token })
265274 return await self ._run_command (conn , cmd )
0 commit comments