Skip to content

Commit a46cf2e

Browse files
committed
Use the right protocol when cloning the query runner for remote projections
(cherry picked from commit 3285d74)
1 parent 6b6af9a commit a46cf2e

File tree

7 files changed

+33
-12
lines changed

7 files changed

+33
-12
lines changed

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,10 @@ def close(self) -> None:
211211
self._fallback_query_runner.close()
212212
self._gds_arrow_client.close()
213213

214-
def clone(self, endpoint: str) -> "QueryRunner":
214+
def clone(self, host: str, port: int) -> "QueryRunner":
215215
return ArrowQueryRunner(
216216
self._gds_arrow_client,
217-
self._fallback_query_runner.clone(endpoint),
217+
self._fallback_query_runner.clone(host, port),
218218
self._server_version,
219219
)
220220

graphdatascience/query_runner/neo4j_query_runner.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def create_for_db(
4646

4747
query_runner = Neo4jQueryRunner(
4848
driver,
49+
Neo4jQueryRunner.parse_protocol(endpoint),
4950
auth,
5051
auto_close=True,
5152
bookmarks=bookmarks,
@@ -55,8 +56,14 @@ def create_for_db(
5556
)
5657

5758
elif isinstance(endpoint, neo4j.Driver):
59+
protocol = "neo4j+s" if endpoint.encrypted else "bolt"
5860
query_runner = Neo4jQueryRunner(
59-
endpoint, auto_close=False, bookmarks=bookmarks, database=database, show_progress=show_progress
61+
endpoint,
62+
protocol,
63+
auto_close=False,
64+
bookmarks=bookmarks,
65+
database=database,
66+
show_progress=show_progress,
6067
)
6168
else:
6269
raise ValueError(f"Invalid endpoint type: {type(endpoint)}")
@@ -77,6 +84,7 @@ def create_for_session(
7784

7885
query_runner = Neo4jQueryRunner(
7986
driver,
87+
Neo4jQueryRunner.parse_protocol(endpoint),
8088
auth,
8189
auto_close=True,
8290
show_progress=show_progress,
@@ -96,9 +104,17 @@ def _configure_aura(config: dict[str, Any]) -> None:
96104
config["keep_alive"] = True
97105
config["max_connection_pool_size"] = 50
98106

107+
@staticmethod
108+
def parse_protocol(endpoint: str) -> str:
109+
protocol_match = re.match(r"^([^:]+)://", endpoint)
110+
if not protocol_match:
111+
raise ValueError(f"Invalid endpoint URI format: {endpoint}")
112+
return protocol_match.group(1)
113+
99114
def __init__(
100115
self,
101116
driver: neo4j.Driver,
117+
protocol: str,
102118
auth: Optional[tuple[str, str]] = None,
103119
config: dict[str, Any] = {},
104120
database: Optional[str] = neo4j.DEFAULT_DATABASE,
@@ -108,6 +124,7 @@ def __init__(
108124
instance_description: str = "Neo4j DBMS",
109125
):
110126
self._driver = driver
127+
self._protocol = protocol
111128
self._auth = auth
112129
self._config = config
113130
self._auto_close = auto_close
@@ -283,10 +300,13 @@ def create_graph_constructor(
283300
def set_show_progress(self, show_progress: bool) -> None:
284301
self._show_progress = show_progress
285302

286-
def clone(self, endpoint: str) -> QueryRunner:
303+
def clone(self, host: str, port: int) -> QueryRunner:
304+
endpoint = "{}://{}:{}".format(self._protocol, host, port)
287305
driver = neo4j.GraphDatabase.driver(endpoint, auth=self._auth, **self.driver_config())
306+
288307
return Neo4jQueryRunner(
289308
driver,
309+
self._protocol,
290310
self._auth,
291311
self._config,
292312
self._database,

graphdatascience/query_runner/protocol/project_protocols.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def is_not_done(result: DataFrame) -> bool:
136136
ProtocolVersion.V3.versioned_procedure_name(endpoint), params, yields, database, logging, False
137137
).squeeze()["host"]
138138

139-
projection_query_runner = query_runner.clone(f"neo4j+s://{member_address}:7687")
139+
# TODO: retrieve the port from the server
140+
projection_query_runner = query_runner.clone(member_address, 7687)
140141

141142
@retry(
142143
reraise=True,

graphdatascience/query_runner/query_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def set_show_progress(self, show_progress: bool) -> None:
8181
pass
8282

8383
@abstractmethod
84-
def clone(self, endpoint: str) -> "QueryRunner":
84+
def clone(self, host: str, port: int) -> "QueryRunner":
8585
pass
8686

8787
def set_server_version(self, _: ServerVersion) -> None:

graphdatascience/query_runner/session_query_runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,10 @@ def set_show_progress(self, show_progress: bool) -> None:
120120
self._show_progress = show_progress
121121
self._gds_query_runner.set_show_progress(show_progress)
122122

123-
def clone(self, endpoint: str) -> QueryRunner:
123+
def clone(self, host: str, port: int) -> QueryRunner:
124124
return SessionQueryRunner(
125-
self._gds_query_runner.clone(endpoint),
126-
self._db_query_runner.clone(endpoint),
125+
self._gds_query_runner,
126+
self._db_query_runner.clone(host, port),
127127
self._gds_arrow_client,
128128
self._show_progress,
129129
)

graphdatascience/query_runner/standalone_session_query_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,5 +75,5 @@ def last_bookmarks(self) -> Optional[Any]:
7575
def set_server_version(self, _: ServerVersion) -> None:
7676
super().set_server_version(_)
7777

78-
def clone(self, endpoint: str) -> "QueryRunner":
79-
return StandaloneSessionQueryRunner(self._query_runner.clone(endpoint))
78+
def clone(self, host: str, port: int) -> "QueryRunner":
79+
return self

graphdatascience/tests/unit/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def create_graph_constructor(
126126
self, graph_name, concurrency, undirected_relationship_types, self._server_version
127127
)
128128

129-
def clone(self, endpoint: str) -> "QueryRunner":
129+
def clone(self, host: str, port: int) -> "QueryRunner":
130130
return self
131131

132132
def set__mock_result(self, result: DataFrame) -> None:

0 commit comments

Comments
 (0)