3939from graphdatascience .retry_utils .retry_utils import before_log
4040
4141from ..semantic_version .semantic_version import SemanticVersion
42+ from ..session .arrow_authentication import ArrowAuthentication
4243from ..version import __version__
4344from .arrow_endpoint_version import ArrowEndpointVersion
4445from .arrow_info import ArrowInfo
@@ -48,7 +49,7 @@ class GdsArrowClient:
4849 @staticmethod
4950 def create (
5051 arrow_info : ArrowInfo ,
51- auth : Optional [tuple [ str , str ] ] = None ,
52+ arrow_authentication : Optional [ArrowAuthentication ] = None ,
5253 encrypted : bool = False ,
5354 disable_server_verification : bool = False ,
5455 tls_root_certs : Optional [bytes ] = None ,
@@ -80,7 +81,7 @@ def create(
8081 host ,
8182 retry_config ,
8283 int (port ),
83- auth ,
84+ arrow_authentication ,
8485 encrypted ,
8586 disable_server_verification ,
8687 tls_root_certs ,
@@ -92,7 +93,7 @@ def __init__(
9293 host : str ,
9394 retry_config : RetryConfig ,
9495 port : int = 8491 ,
95- auth : Optional [tuple [ str , str ] ] = None ,
96+ auth : Optional [ArrowAuthentication ] = None ,
9697 encrypted : bool = False ,
9798 disable_server_verification : bool = False ,
9899 tls_root_certs : Optional [bytes ] = None ,
@@ -107,8 +108,8 @@ def __init__(
107108 The host address of the GDS Arrow server
108109 port: int
109110 The host port of the GDS Arrow server (default is 8491)
110- auth: Optional[tuple[str, str] ]
111- A tuple containing the username and password for authentication
111+ auth: Optional[ArrowAuthentication ]
112+ An implementation of ArrowAuthentication providing a pair to be used for basic authentication
112113 encrypted: bool
113114 A flag that indicates whether the connection should be encrypted (default is False)
114115 disable_server_verification: bool
@@ -189,7 +190,8 @@ def request_token(self) -> Optional[str]:
189190 def auth_with_retry () -> None :
190191 client = self ._client ()
191192 if self ._auth :
192- client .authenticate_basic_token (self ._auth [0 ], self ._auth [1 ])
193+ auth_pair = self ._auth .auth_pair ()
194+ client .authenticate_basic_token (auth_pair [0 ], auth_pair [1 ])
193195
194196 if self ._auth :
195197 auth_with_retry ()
@@ -884,7 +886,7 @@ def start_call(self, info: Any) -> AuthMiddleware:
884886
885887
886888class AuthMiddleware (ClientMiddleware ): # type: ignore
887- def __init__ (self , auth : tuple [ str , str ] , * args : Any , ** kwargs : Any ) -> None :
889+ def __init__ (self , auth : ArrowAuthentication , * args : Any , ** kwargs : Any ) -> None :
888890 super ().__init__ (* args , ** kwargs )
889891 self ._auth = auth
890892 self ._token : Optional [str ] = None
@@ -918,15 +920,15 @@ def received_headers(self, headers: dict[str, Any]) -> None:
918920
919921 def sending_headers (self ) -> dict [str , str ]:
920922 token = self .token ()
921- if not token :
922- username , password = self ._auth
923- auth_token = f"{ username } :{ password } "
924- auth_token = "Basic " + base64 .b64encode (auth_token .encode ("utf-8" )).decode ("ASCII" )
925- # There seems to be a bug, `authorization` must be lower key
926- return {"authorization" : auth_token }
927- else :
923+ if token is not None :
928924 return {"authorization" : "Bearer " + token }
929925
926+ auth_pair = self ._auth .auth_pair ()
927+ auth_token = f"{ auth_pair [0 ]} :{ auth_pair [1 ]} "
928+ auth_token = "Basic " + base64 .b64encode (auth_token .encode ("utf-8" )).decode ("ASCII" )
929+ # There seems to be a bug, `authorization` must be lower key
930+ return {"authorization" : auth_token }
931+
930932
931933@dataclass (repr = True , frozen = True )
932934class NodeLoadDoneResult :
0 commit comments