@@ -46,6 +46,7 @@ def create_for_db(
4646
4747 query_runner = Neo4jQueryRunner (
4848 driver ,
49+ Neo4jQueryRunner .parse_protocol (endpoint ),
4950 auth ,
5051 auto_close = True ,
5152 bookmarks = bookmarks ,
@@ -55,8 +56,14 @@ def create_for_db(
5556 )
5657
5758 elif isinstance (endpoint , neo4j .Driver ):
59+ protocol = "neo4j+s" if endpoint .encrypted else "bolt"
5860 query_runner = Neo4jQueryRunner (
59- endpoint , auto_close = False , bookmarks = bookmarks , database = database , show_progress = show_progress
61+ endpoint ,
62+ protocol ,
63+ auto_close = False ,
64+ bookmarks = bookmarks ,
65+ database = database ,
66+ show_progress = show_progress ,
6067 )
6168 else :
6269 raise ValueError (f"Invalid endpoint type: { type (endpoint )} " )
@@ -77,6 +84,7 @@ def create_for_session(
7784
7885 query_runner = Neo4jQueryRunner (
7986 driver ,
87+ Neo4jQueryRunner .parse_protocol (endpoint ),
8088 auth ,
8189 auto_close = True ,
8290 show_progress = show_progress ,
@@ -96,9 +104,17 @@ def _configure_aura(config: dict[str, Any]) -> None:
96104 config ["keep_alive" ] = True
97105 config ["max_connection_pool_size" ] = 50
98106
107+ @staticmethod
108+ def parse_protocol (endpoint : str ) -> str :
109+ protocol_match = re .match (r"^([^:]+)://" , endpoint )
110+ if not protocol_match :
111+ raise ValueError (f"Invalid endpoint URI format: { endpoint } " )
112+ return protocol_match .group (1 )
113+
99114 def __init__ (
100115 self ,
101116 driver : neo4j .Driver ,
117+ protocol : str ,
102118 auth : Optional [tuple [str , str ]] = None ,
103119 config : dict [str , Any ] = {},
104120 database : Optional [str ] = neo4j .DEFAULT_DATABASE ,
@@ -108,6 +124,7 @@ def __init__(
108124 instance_description : str = "Neo4j DBMS" ,
109125 ):
110126 self ._driver = driver
127+ self ._protocol = protocol
111128 self ._auth = auth
112129 self ._config = config
113130 self ._auto_close = auto_close
@@ -283,10 +300,13 @@ def create_graph_constructor(
283300 def set_show_progress (self , show_progress : bool ) -> None :
284301 self ._show_progress = show_progress
285302
286- def clone (self , endpoint : str ) -> QueryRunner :
303+ def clone (self , host : str , port : int ) -> QueryRunner :
304+ endpoint = "{}://{}:{}" .format (self ._protocol , host , port )
287305 driver = neo4j .GraphDatabase .driver (endpoint , auth = self ._auth , ** self .driver_config ())
306+
288307 return Neo4jQueryRunner (
289308 driver ,
309+ self ._protocol ,
290310 self ._auth ,
291311 self ._config ,
292312 self ._database ,
0 commit comments