2424import java .util .concurrent .atomic .AtomicBoolean ;
2525
2626import graphql .schema .DataFetcher ;
27+ import jakarta .servlet .AsyncEvent ;
28+ import jakarta .servlet .AsyncListener ;
2729import jakarta .servlet .ServletException ;
2830import jakarta .servlet .ServletOutputStream ;
2931import jakarta .servlet .http .HttpServletResponse ;
3537import org .springframework .http .MediaType ;
3638import org .springframework .http .converter .HttpMessageConverter ;
3739import org .springframework .http .converter .json .MappingJackson2HttpMessageConverter ;
40+ import org .springframework .mock .web .MockAsyncContext ;
3841import org .springframework .mock .web .MockHttpServletRequest ;
3942import org .springframework .mock .web .MockHttpServletResponse ;
4043import org .springframework .web .servlet .function .AsyncServerResponse ;
@@ -72,7 +75,7 @@ class GraphQlSseHandlerTests {
7275 void shouldRejectQueryOperations () throws Exception {
7376 GraphQlSseHandler handler = createSseHandler (SEARCH_DATA_FETCHER );
7477 MockHttpServletRequest request = createServletRequest ("{ \" query\" : \" { bookById(id: 42) {name} }\" }" );
75- MockHttpServletResponse response = handleRequest (request , handler );
78+ MockHttpServletResponse response = handleAndAwait (request , handler );
7679
7780 assertThat (response .getContentType ()).isEqualTo (MediaType .TEXT_EVENT_STREAM_VALUE );
7881 assertThat (response .getContentAsString ()).isEqualTo ("""
@@ -91,7 +94,7 @@ void shouldWriteMultipleEventsForSubscription() throws Exception {
9194 MockHttpServletRequest request = createServletRequest ("""
9295 { "query": "subscription TestSubscription { bookSearch(author:\\ \" Orwell\\ \" ) { id name } }" }
9396 """ );
94- MockHttpServletResponse response = handleRequest (request , handler );
97+ MockHttpServletResponse response = handleAndAwait (request , handler );
9598
9699 assertThat (response .getContentType ()).isEqualTo (MediaType .TEXT_EVENT_STREAM_VALUE );
97100 assertThat (response .getContentAsString ()).isEqualTo ("""
@@ -117,7 +120,7 @@ void shouldWriteEventsAndTerminalError() throws Exception {
117120 MockHttpServletRequest request = createServletRequest ("""
118121 { "query": "subscription TestSubscription { bookSearch(author:\\ \" Orwell\\ \" ) { id name } }" }
119122 """ );
120- MockHttpServletResponse response = handleRequest (request , handler );
123+ MockHttpServletResponse response = handleAndAwait (request , handler );
121124
122125 assertThat (response .getContentType ()).isEqualTo (MediaType .TEXT_EVENT_STREAM_VALUE );
123126 assertThat (response .getContentAsString ()).isEqualTo ("""
@@ -153,7 +156,26 @@ void shouldCancelDataFetcherPublisherWhenWritingFails() throws Exception {
153156
154157 response .writeTo (servletRequest , servletResponse , new DefaultContext ());
155158 await ().atMost (Duration .ofMillis (500 )).until (DATA_FETCHER_CANCELLED ::get );
159+ }
160+
161+ @ Test
162+ void shouldCancelDataFetcherWhenAsyncTimeout () throws Exception {
163+ DataFetcher <?> errorDataFetcher = env -> Flux .just (BookSource .getBook (1L ))
164+ .delayElements (Duration .ofMillis (500 )).doOnCancel (() -> DATA_FETCHER_CANCELLED .set (true ));
165+
166+ GraphQlSseHandler handler = createSseHandler (errorDataFetcher );
167+ MockHttpServletRequest servletRequest = createServletRequest ("""
168+ { "query": "subscription TestSubscription { bookSearch(author:\\ \" Orwell\\ \" ) { id name } }" }
169+ """ );
156170
171+ MockHttpServletResponse servletResponse = handleRequest (servletRequest , handler );
172+ for (AsyncListener listener : ((MockAsyncContext ) servletRequest .getAsyncContext ()).getListeners ()) {
173+ listener .onTimeout (new AsyncEvent (servletRequest .getAsyncContext ()));
174+ }
175+
176+ assertThat (DATA_FETCHER_CANCELLED .get ()).isTrue ();
177+ assertThat (servletResponse .getContentType ()).isEqualTo (MediaType .TEXT_EVENT_STREAM_VALUE );
178+ assertThat (servletResponse .getContentAsString ()).isEmpty ();
157179 }
158180
159181 private GraphQlSseHandler createSseHandler (DataFetcher <?> dataFetcher ) {
@@ -174,15 +196,19 @@ private MockHttpServletRequest createServletRequest(String query) {
174196
175197 private MockHttpServletResponse handleRequest (
176198 MockHttpServletRequest servletRequest , GraphQlSseHandler handler ) throws ServletException , IOException {
177-
178199 ServerRequest request = ServerRequest .create (servletRequest , MESSAGE_READERS );
179200 ServerResponse response = handler .handleRequest (request );
180201 if (response instanceof AsyncServerResponse asyncResponse ) {
181202 asyncResponse .block ();
182203 }
183-
184204 MockHttpServletResponse servletResponse = new MockHttpServletResponse ();
185205 response .writeTo (servletRequest , servletResponse , new DefaultContext ());
206+ return servletResponse ;
207+ }
208+
209+ private MockHttpServletResponse handleAndAwait (
210+ MockHttpServletRequest servletRequest , GraphQlSseHandler handler ) throws ServletException , IOException {
211+ MockHttpServletResponse servletResponse = handleRequest (servletRequest , handler );
186212 await ().atMost (Duration .ofMillis (500 )).until (() -> servletResponse .getContentAsString ().contains ("complete" ));
187213 return servletResponse ;
188214 }
0 commit comments