@@ -323,3 +323,40 @@ async def test_headers_passed_to_websocket_connect(self, mock_connect):
323323 )
324324
325325 mock_handle .assert_has_calls ([call ({"data" : {"messageAdded" : "one" }})])
326+
327+ @patch ("websockets.connect" )
328+ async def test_init_payload_passed_in_init_message (self , mock_connect ):
329+ """Subsribe a GraphQL subscription."""
330+ mock_websocket = mock_connect .return_value .__aenter__ .return_value
331+ mock_websocket .send = AsyncMock ()
332+ mock_websocket .__aiter__ .return_value = [
333+ '{"type": "connection_init", "payload": '
334+ '{"init": "this is the init_payload"}}' ,
335+ '{"type": "data", "id": "1", "payload": {"data": {"messageAdded": "one"}}}' ,
336+ ]
337+ expected_endpoint = "ws://www.test-api.com/graphql"
338+ client = GraphqlClient (endpoint = expected_endpoint )
339+
340+ query = """
341+ subscription onMessageAdded {
342+ messageAdded
343+ }
344+ """
345+ init_payload = '{"init": "this is the init_payload"}'
346+
347+ mock_handle = MagicMock ()
348+
349+ await client .subscribe (
350+ query = query , handle = mock_handle , init_payload = init_payload
351+ )
352+
353+ mock_connect .assert_called_with (
354+ expected_endpoint , subprotocols = ["graphql-ws" ], extra_headers = {}
355+ )
356+
357+ mock_handle .assert_has_calls (
358+ [
359+ call ({"init" : "this is the init_payload" }),
360+ call ({"data" : {"messageAdded" : "one" }}),
361+ ]
362+ )
0 commit comments