@@ -248,3 +248,62 @@ async def test_subscribe(self, mock_connect):
248248 call ({"data" : {"messageAdded" : "two" }}),
249249 ]
250250 )
251+
252+ @patch ("logging.info" )
253+ @patch ("websockets.connect" )
254+ async def test_does_not_crash_with_keep_alive (self , mock_connect , mock_info ):
255+ """Subsribe a GraphQL subscription."""
256+ mock_websocket = mock_connect .return_value .__aenter__ .return_value
257+ mock_websocket .send = AsyncMock ()
258+ mock_websocket .__aiter__ .return_value = [
259+ '{"type": "ka"}' ,
260+ ]
261+
262+ client = GraphqlClient (endpoint = "ws://www.test-api.com/graphql" )
263+ query = """
264+ subscription onMessageAdded {
265+ messageAdded
266+ }
267+ """
268+
269+ await client .subscribe (query = query , handle = MagicMock ())
270+
271+ mock_info .assert_has_calls ([call ("the server sent a keep alive message" )])
272+
273+ @patch ("websockets.connect" )
274+ async def test_headers_passed_to_websocket_connect (self , mock_connect ):
275+ """Subsribe a GraphQL subscription."""
276+ mock_websocket = mock_connect .return_value .__aenter__ .return_value
277+ mock_websocket .send = AsyncMock ()
278+ mock_websocket .__aiter__ .return_value = [
279+ '{"type": "data", "id": "1", "payload": {"data": {"messageAdded": "one"}}}' ,
280+ ]
281+
282+ expected_endpoint = "ws://www.test-api.com/graphql"
283+ client = GraphqlClient (endpoint = expected_endpoint )
284+
285+ query = """
286+ subscription onMessageAdded {
287+ messageAdded
288+ }
289+ """
290+
291+ mock_handle = MagicMock ()
292+
293+ expected_headers = {"some" : "header" }
294+
295+ await client .subscribe (
296+ query = query , handle = mock_handle , headers = expected_headers
297+ )
298+
299+ mock_connect .assert_called_with (
300+ expected_endpoint ,
301+ subprotocols = ["graphql-ws" ],
302+ extra_headers = expected_headers ,
303+ )
304+
305+ mock_handle .assert_has_calls (
306+ [
307+ call ({"data" : {"messageAdded" : "one" }}),
308+ ]
309+ )
0 commit comments