Skip to content

Commit 6b77476

Browse files
soerenreichardtFlorentinD
authored andcommitted
Allow arrow authentication via aura api token
Always pass api token from DedicatedSessions
1 parent eb95b15 commit 6b77476

File tree

10 files changed

+92
-39
lines changed

10 files changed

+92
-39
lines changed

graphdatascience/graph_data_science.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .query_runner.neo4j_query_runner import Neo4jQueryRunner
1717
from .query_runner.query_runner import QueryRunner
1818
from .server_version.server_version import ServerVersion
19+
from .session.arrow_authentication import UsernamePasswordAuthentication
1920
from .utils.util_proc_runner import UtilProcRunner
2021
from .version import __min_server_version__
2122

@@ -93,10 +94,15 @@ def __init__(
9394

9495
arrow_info = ArrowInfo.create(self._query_runner)
9596
if arrow and arrow_info.enabled and self._server_version >= ServerVersion(2, 1, 0):
97+
arrow_auth = None
98+
if auth is not None:
99+
username, password = auth
100+
arrow_auth = UsernamePasswordAuthentication(username, password)
101+
96102
self._query_runner = ArrowQueryRunner.create(
97103
self._query_runner,
98104
arrow_info,
99-
auth,
105+
arrow_auth,
100106
self._query_runner.encrypted(),
101107
arrow_disable_server_verification,
102108
arrow_tls_root_certs,

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414
from .gds_arrow_client import GdsArrowClient
1515
from .graph_constructor import GraphConstructor
1616
from .query_runner import QueryRunner
17+
from ..session.arrow_authentication import ArrowAuthentication
1718

1819

1920
class ArrowQueryRunner(QueryRunner):
2021
@staticmethod
2122
def create(
2223
fallback_query_runner: QueryRunner,
2324
arrow_info: ArrowInfo,
24-
auth: Optional[tuple[str, str]] = None,
25+
arrow_authentication: Optional[ArrowAuthentication] = None,
2526
encrypted: bool = False,
2627
disable_server_verification: bool = False,
2728
tls_root_certs: Optional[bytes] = None,
@@ -33,7 +34,7 @@ def create(
3334

3435
gds_arrow_client = GdsArrowClient.create(
3536
arrow_info,
36-
auth,
37+
arrow_authentication,
3738
encrypted,
3839
disable_server_verification,
3940
tls_root_certs,

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from graphdatascience.retry_utils.retry_utils import before_log
4040

4141
from ..semantic_version.semantic_version import SemanticVersion
42+
from ..session.arrow_authentication import ArrowAuthentication
4243
from ..version import __version__
4344
from .arrow_endpoint_version import ArrowEndpointVersion
4445
from .arrow_info import ArrowInfo
@@ -48,7 +49,7 @@ class GdsArrowClient:
4849
@staticmethod
4950
def create(
5051
arrow_info: ArrowInfo,
51-
auth: Optional[tuple[str, str]] = None,
52+
arrow_authentication: Optional[ArrowAuthentication] = None,
5253
encrypted: bool = False,
5354
disable_server_verification: bool = False,
5455
tls_root_certs: Optional[bytes] = None,
@@ -80,7 +81,7 @@ def create(
8081
host,
8182
retry_config,
8283
int(port),
83-
auth,
84+
arrow_authentication,
8485
encrypted,
8586
disable_server_verification,
8687
tls_root_certs,
@@ -92,7 +93,7 @@ def __init__(
9293
host: str,
9394
retry_config: RetryConfig,
9495
port: int = 8491,
95-
auth: Optional[tuple[str, str]] = None,
96+
auth: Optional[ArrowAuthentication] = None,
9697
encrypted: bool = False,
9798
disable_server_verification: bool = False,
9899
tls_root_certs: Optional[bytes] = None,
@@ -107,8 +108,8 @@ def __init__(
107108
The host address of the GDS Arrow server
108109
port: int
109110
The host port of the GDS Arrow server (default is 8491)
110-
auth: Optional[tuple[str, str]]
111-
A tuple containing the username and password for authentication
111+
auth: Optional[ArrowAuthentication]
112+
An implementation of ArrowAuthentication providing a pair to be used for basic authentication
112113
encrypted: bool
113114
A flag that indicates whether the connection should be encrypted (default is False)
114115
disable_server_verification: bool
@@ -189,7 +190,8 @@ def request_token(self) -> Optional[str]:
189190
def auth_with_retry() -> None:
190191
client = self._client()
191192
if self._auth:
192-
client.authenticate_basic_token(self._auth[0], self._auth[1])
193+
auth_pair = self._auth.auth_pair()
194+
client.authenticate_basic_token(auth_pair[0], auth_pair[1])
193195

194196
if self._auth:
195197
auth_with_retry()
@@ -884,7 +886,7 @@ def start_call(self, info: Any) -> AuthMiddleware:
884886

885887

886888
class AuthMiddleware(ClientMiddleware): # type: ignore
887-
def __init__(self, auth: tuple[str, str], *args: Any, **kwargs: Any) -> None:
889+
def __init__(self, auth: ArrowAuthentication, *args: Any, **kwargs: Any) -> None:
888890
super().__init__(*args, **kwargs)
889891
self._auth = auth
890892
self._token: Optional[str] = None
@@ -918,15 +920,15 @@ def received_headers(self, headers: dict[str, Any]) -> None:
918920

919921
def sending_headers(self) -> dict[str, str]:
920922
token = self.token()
921-
if not token:
922-
username, password = self._auth
923-
auth_token = f"{username}:{password}"
924-
auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
925-
# There seems to be a bug, `authorization` must be lower key
926-
return {"authorization": auth_token}
927-
else:
923+
if token is not None:
928924
return {"authorization": "Bearer " + token}
929925

926+
auth_pair = self._auth.auth_pair()
927+
auth_token = f"{auth_pair[0]}:{auth_pair[1]}"
928+
auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
929+
# There seems to be a bug, `authorization` must be lower key
930+
return {"authorization": auth_token}
931+
930932

931933
@dataclass(repr=True, frozen=True)
932934
class NodeLoadDoneResult:
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Callable
3+
4+
from graphdatascience.session.aura_api import AuraApi
5+
6+
7+
class ArrowAuthentication(ABC):
8+
type AuthTokenFn = Callable[[], str]
9+
10+
@abstractmethod
11+
def auth_pair(self) -> tuple[str, str]:
12+
"""Returns the auth pair used for authentication."""
13+
pass
14+
15+
16+
class UsernamePasswordAuthentication(ArrowAuthentication):
17+
def __init__(self, username: str, password: str):
18+
self._username = username
19+
self._password = password
20+
21+
def auth_pair(self) -> tuple[str, str]:
22+
return self._username, self._password
23+
24+
25+
class AuraApiTokenAuthentication(ArrowAuthentication):
26+
def __init__(self, aura_api: AuraApi):
27+
self._aura_api = aura_api
28+
29+
def auth_pair(self) -> tuple[str, str]:
30+
return "", self._aura_api._auth._auth_token()

graphdatascience/session/aura_graph_data_science.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
2020
from graphdatascience.query_runner.session_query_runner import SessionQueryRunner
2121
from graphdatascience.query_runner.standalone_session_query_runner import StandaloneSessionQueryRunner
22+
from graphdatascience.session.arrow_authentication import ArrowAuthentication
2223
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
2324
from graphdatascience.utils.util_remote_proc_runner import UtilRemoteProcRunner
2425

@@ -32,7 +33,8 @@ class AuraGraphDataScience(DirectEndpoints, UncallableNamespace):
3233
@classmethod
3334
def create(
3435
cls,
35-
gds_session_connection_info: DbmsConnectionInfo,
36+
session_bolt_connection_info: DbmsConnectionInfo,
37+
arrow_authentication: Optional[ArrowAuthentication],
3638
db_endpoint: Optional[Union[Neo4jQueryRunner, DbmsConnectionInfo]],
3739
delete_fn: Callable[[], bool],
3840
arrow_disable_server_verification: bool = False,
@@ -41,16 +43,16 @@ def create(
4143
show_progress: bool = True,
4244
) -> AuraGraphDataScience:
4345
session_bolt_query_runner = Neo4jQueryRunner.create_for_session(
44-
endpoint=gds_session_connection_info.uri,
45-
auth=gds_session_connection_info.auth(),
46+
endpoint=session_bolt_connection_info.uri,
47+
auth=session_bolt_connection_info.auth(),
4648
show_progress=show_progress,
4749
)
4850

4951
arrow_info = ArrowInfo.create(session_bolt_query_runner)
5052
session_arrow_query_runner = ArrowQueryRunner.create(
5153
fallback_query_runner=session_bolt_query_runner,
5254
arrow_info=arrow_info,
53-
auth=gds_session_connection_info.auth(),
55+
arrow_authentication=arrow_authentication,
5456
encrypted=session_bolt_query_runner.encrypted(),
5557
disable_server_verification=arrow_disable_server_verification,
5658
tls_root_certs=arrow_tls_root_certs,
@@ -59,7 +61,7 @@ def create(
5961
# TODO: merge with the gds_arrow_client created inside ArrowQueryRunner
6062
session_arrow_client = GdsArrowClient.create(
6163
arrow_info,
62-
gds_session_connection_info.auth(),
64+
arrow_authentication,
6365
session_bolt_query_runner.encrypted(),
6466
arrow_disable_server_verification,
6567
arrow_tls_root_certs,

graphdatascience/session/dedicated_sessions.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
99
from graphdatascience.session.algorithm_category import AlgorithmCategory
10+
from graphdatascience.session.arrow_authentication import AuraApiTokenAuthentication, ArrowAuthentication
1011
from graphdatascience.session.aura_api import AuraApi
1112
from graphdatascience.session.aura_api_responses import SessionDetails
1213
from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience
@@ -88,15 +89,18 @@ def get_or_create(
8889

8990
self._await_session_running(session_details, timeout)
9091

91-
session_connection = DbmsConnectionInfo(
92+
session_bolt_connection_info = DbmsConnectionInfo(
9293
uri=session_details.bolt_connection_url(),
93-
username="",
94-
password=self._aura_api._auth._auth_token(),
94+
username=self._aura_api._credentials[0],
95+
password=self._aura_api._credentials[1],
9596
)
9697

98+
arrow_authentication = AuraApiTokenAuthentication(self._aura_api)
99+
97100
return self._construct_client(
98101
session_id=session_details.id,
99-
session_connection=session_connection,
102+
session_bolt_connection_info=session_bolt_connection_info,
103+
arrow_authentication=arrow_authentication,
100104
db_runner=db_runner,
101105
)
102106

@@ -197,11 +201,13 @@ def _get_or_create_self_managed_session(
197201
def _construct_client(
198202
self,
199203
session_id: str,
200-
session_connection: DbmsConnectionInfo,
204+
session_bolt_connection_info: DbmsConnectionInfo,
205+
arrow_authentication: ArrowAuthentication,
201206
db_runner: Optional[Neo4jQueryRunner],
202207
) -> AuraGraphDataScience:
203208
return AuraGraphDataScience.create(
204-
gds_session_connection_info=session_connection,
209+
session_bolt_connection_info=session_bolt_connection_info,
210+
arrow_authentication=arrow_authentication,
205211
db_endpoint=db_runner,
206212
delete_fn=lambda: self._aura_api.delete_session(session_id=session_id),
207213
)

graphdatascience/tests/integration/conftest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from graphdatascience.graph_data_science import GraphDataScience
1010
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
1111
from graphdatascience.server_version.server_version import ServerVersion
12+
from graphdatascience.session.arrow_authentication import UsernamePasswordAuthentication
1213
from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience
1314
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
1415

@@ -92,7 +93,8 @@ def gds_without_arrow() -> Generator[GraphDataScience, None, None]:
9293
@pytest.fixture(scope="package", autouse=False)
9394
def gds_with_cloud_setup(request: pytest.FixtureRequest) -> Generator[AuraGraphDataScience, None, None]:
9495
_gds = AuraGraphDataScience.create(
95-
gds_session_connection_info=DbmsConnectionInfo(URI, AUTH[0], AUTH[1]),
96+
session_bolt_connection_info=DbmsConnectionInfo(URI, AUTH[0], AUTH[1]),
97+
arrow_authentication=UsernamePasswordAuthentication(AUTH[0], AUTH[1]),
9698
db_endpoint=DbmsConnectionInfo(AURA_DB_URI, AURA_DB_AUTH[0], AURA_DB_AUTH[1]),
9799
delete_fn=lambda: True,
98100
)
@@ -106,7 +108,8 @@ def gds_with_cloud_setup(request: pytest.FixtureRequest) -> Generator[AuraGraphD
106108
@pytest.fixture(scope="package", autouse=False)
107109
def standalone_aura_gds() -> Generator[AuraGraphDataScience, None, None]:
108110
_gds = AuraGraphDataScience.create(
109-
gds_session_connection_info=DbmsConnectionInfo(URI, AUTH[0], AUTH[1]),
111+
session_bolt_connection_info=DbmsConnectionInfo(URI, AUTH[0], AUTH[1]),
112+
arrow_authentication=UsernamePasswordAuthentication(AUTH[0], AUTH[1]),
110113
db_endpoint=None,
111114
delete_fn=lambda: True,
112115
)

graphdatascience/tests/unit/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from graphdatascience.query_runner.graph_constructor import GraphConstructor
1818
from graphdatascience.server_version.server_version import ServerVersion
19+
from graphdatascience.session.arrow_authentication import UsernamePasswordAuthentication
1920
from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience
2021
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
2122

@@ -179,7 +180,8 @@ def aura_gds(runner: CollectingQueryRunner, mocker: MockerFixture) -> Generator[
179180
mocker.patch("graphdatascience.query_runner.gds_arrow_client.GdsArrowClient.create", return_value=None)
180181

181182
aura_gds = AuraGraphDataScience.create(
182-
gds_session_connection_info=DbmsConnectionInfo("address", "some", "auth"),
183+
session_bolt_connection_info=DbmsConnectionInfo("address", "some", "auth"),
184+
arrow_authentication=UsernamePasswordAuthentication("some", "auth"),
183185
db_endpoint=DbmsConnectionInfo("address", "some", "auth"),
184186
delete_fn=lambda: True,
185187
)

graphdatascience/tests/unit/test_aura_api.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -807,14 +807,14 @@ def test_auth_token(requests_mock: Mocker) -> None:
807807
json={"access_token": "very_short_token", "expires_in": 0, "token_type": "Bearer"},
808808
)
809809

810-
assert api._request_session.auth._auth_token() == "very_short_token" # type: ignore
810+
assert api._request_session.auth_pair._auth_token() == "very_short_token" # type: ignore
811811

812812
requests_mock.post(
813813
"https://api.neo4j.io/oauth/token",
814814
json={"access_token": "longer_token", "expires_in": 3600, "token_type": "Bearer"},
815815
)
816816

817-
assert api._request_session.auth._auth_token() == "longer_token" # type: ignore
817+
assert api._request_session.auth_pair._auth_token() == "longer_token" # type: ignore
818818

819819

820820
def test_auth_token_reused(requests_mock: Mocker) -> None:
@@ -825,15 +825,15 @@ def test_auth_token_reused(requests_mock: Mocker) -> None:
825825
json={"access_token": "one_token", "expires_in": 3600, "token_type": "Bearer"},
826826
)
827827

828-
assert api._request_session.auth._auth_token() == "one_token" # type: ignore
828+
assert api._request_session.auth_pair._auth_token() == "one_token" # type: ignore
829829

830830
requests_mock.post(
831831
"https://api.neo4j.io/oauth/token",
832832
json={"access_token": "new_token", "expires_in": 3600, "token_type": "Bearer"},
833833
)
834834

835835
# no new token requested
836-
assert api._request_session.auth._auth_token() == "one_token" # type: ignore
836+
assert api._request_session.auth_pair._auth_token() == "one_token" # type: ignore
837837

838838

839839
def test_auth_token_use_short_token(requests_mock: Mocker) -> None:
@@ -844,8 +844,8 @@ def test_auth_token_use_short_token(requests_mock: Mocker) -> None:
844844
json={"access_token": "one_token", "expires_in": 10, "token_type": "Bearer"},
845845
)
846846

847-
assert api._request_session.auth._auth_token() == "one_token" # type: ignore
848-
assert api._request_session.auth._auth_token() == "one_token" # type: ignore
847+
assert api._request_session.auth_pair._auth_token() == "one_token" # type: ignore
848+
assert api._request_session.auth_pair._auth_token() == "one_token" # type: ignore
849849

850850

851851
def test_derive_tenant(requests_mock: Mocker) -> None:

graphdatascience/tests/unit/test_gds_arrow_client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from graphdatascience.query_runner.arrow_info import ArrowInfo
1818
from graphdatascience.query_runner.gds_arrow_client import AuthMiddleware, GdsArrowClient
19+
from graphdatascience.session.session_connection_info import UsernamePasswordAuthentication
1920

2021
ActionParam = Union[str, tuple[str, Any], Action]
2122

@@ -382,7 +383,7 @@ def test_get_relationship_topologys(flight_server: FlightServer, flight_client:
382383

383384

384385
def test_auth_middleware() -> None:
385-
middleware = AuthMiddleware(("user", "password"))
386+
middleware = AuthMiddleware(UsernamePasswordAuthentication("user", "password"))
386387

387388
first_header = middleware.sending_headers()
388389
assert first_header == {"authorization": "Basic dXNlcjpwYXNzd29yZA=="}
@@ -401,7 +402,7 @@ def test_auth_middleware() -> None:
401402

402403

403404
def test_auth_middleware_bad_headers() -> None:
404-
middleware = AuthMiddleware(("user", "password"))
405+
middleware = AuthMiddleware(UsernamePasswordAuthentication("user", "password"))
405406

406407
with pytest.raises(ValueError, match="Incompatible header value received from server: `12342`"):
407408
middleware.received_headers({"authorization": [12342]})

0 commit comments

Comments
 (0)