1111from typing import Any
1212
1313import orjson
14+ from asgi_tools import ResponseWebSocket
1415from asgiref import typing as asgi_types
1516from asgiref .compatibility import guarantee_single_callable
1617from servestatic import ServeStaticASGI
2627 AsgiHttpApp ,
2728 AsgiLifespanApp ,
2829 AsgiWebsocketApp ,
30+ AsgiWebsocketReceive ,
31+ AsgiWebsocketSend ,
2932 Connection ,
3033 Location ,
3134 ReactPyConfig ,
@@ -153,41 +156,56 @@ async def __call__(
153156 send : asgi_types .ASGISendCallable ,
154157 ) -> None :
155158 """ASGI app for rendering ReactPy Python components."""
156- dispatcher : asyncio .Task [Any ] | None = None
157- recv_queue : asyncio .Queue [dict [str , Any ]] = asyncio .Queue ()
158-
159159 # Start a loop that handles ASGI websocket events
160- while True :
161- event = await receive ()
162- if event ["type" ] == "websocket.connect" :
163- await send (
164- {"type" : "websocket.accept" , "subprotocol" : None , "headers" : []}
165- )
166- dispatcher = asyncio .create_task (
167- self .run_dispatcher (scope , receive , send , recv_queue )
168- )
169-
170- elif event ["type" ] == "websocket.disconnect" :
171- if dispatcher :
172- dispatcher .cancel ()
173- break
174-
175- elif event ["type" ] == "websocket.receive" and event ["text" ]:
176- queue_put_func = recv_queue .put (orjson .loads (event ["text" ]))
177- await queue_put_func
178-
179- async def run_dispatcher (
160+ async with ReactPyWebsocket (scope , receive , send , parent = self .parent ) as ws : # type: ignore
161+ while True :
162+ # Wait for the webserver to notify us of a new event
163+ event : dict [str , Any ] = await ws .receive (raw = True ) # type: ignore
164+
165+ # If the event is a `receive` event, parse the message and send it to the rendering queue
166+ if event ["type" ] == "websocket.receive" :
167+ msg : dict [str , str ] = orjson .loads (event ["text" ])
168+ if msg .get ("type" ) == "layout-event" :
169+ await ws .rendering_queue .put (msg )
170+ else : # pragma: no cover
171+ await asyncio .to_thread (
172+ _logger .warning , f"Unknown message type: { msg .get ('type' )} "
173+ )
174+
175+ # If the event is a `disconnect` event, break the rendering loop and close the connection
176+ elif event ["type" ] == "websocket.disconnect" :
177+ break
178+
179+
180+ class ReactPyWebsocket (ResponseWebSocket ):
181+ def __init__ (
180182 self ,
181183 scope : asgi_types .WebSocketScope ,
182- receive : asgi_types . ASGIReceiveCallable ,
183- send : asgi_types . ASGISendCallable ,
184- recv_queue : asyncio . Queue [ dict [ str , Any ]] ,
184+ receive : AsgiWebsocketReceive ,
185+ send : AsgiWebsocketSend ,
186+ parent : ReactPyMiddleware ,
185187 ) -> None :
186- """Asyncio background task that renders and transmits layout updates of ReactPy components."""
188+ super ().__init__ (scope = scope , receive = receive , send = send ) # type: ignore
189+ self .scope = scope
190+ self .parent = parent
191+ self .rendering_queue : asyncio .Queue [dict [str , str ]] = asyncio .Queue ()
192+ self .dispatcher : asyncio .Task [Any ] | None = None
193+
194+ async def __aenter__ (self ) -> ReactPyWebsocket :
195+ self .dispatcher = asyncio .create_task (self .run_dispatcher ())
196+ return await super ().__aenter__ () # type: ignore
197+
198+ async def __aexit__ (self , * _ : Any ) -> None :
199+ if self .dispatcher :
200+ self .dispatcher .cancel ()
201+ await super ().__aexit__ () # type: ignore
202+
203+ async def run_dispatcher (self ) -> None :
204+ """Async background task that renders ReactPy components over a websocket."""
187205 try :
188206 # Determine component to serve by analyzing the URL and/or class parameters.
189207 if self .parent .multiple_root_components :
190- url_match = re .match (self .parent .dispatcher_pattern , scope ["path" ])
208+ url_match = re .match (self .parent .dispatcher_pattern , self . scope ["path" ])
191209 if not url_match : # pragma: no cover
192210 raise RuntimeError ("Could not find component in URL path." )
193211 dotted_path = url_match ["dotted_path" ]
@@ -203,10 +221,10 @@ async def run_dispatcher(
203221
204222 # Create a connection object by analyzing the websocket's query string.
205223 ws_query_string = urllib .parse .parse_qs (
206- scope ["query_string" ].decode (), strict_parsing = True
224+ self . scope ["query_string" ].decode (), strict_parsing = True
207225 )
208226 connection = Connection (
209- scope = scope ,
227+ scope = self . scope ,
210228 location = Location (
211229 path = ws_query_string .get ("http_pathname" , ["" ])[0 ],
212230 query_string = ws_query_string .get ("http_query_string" , ["" ])[0 ],
@@ -217,20 +235,19 @@ async def run_dispatcher(
217235 # Start the ReactPy component rendering loop
218236 await serve_layout (
219237 Layout (ConnectionContext (component (), value = connection )),
220- lambda msg : send (
221- {
222- "type" : "websocket.send" ,
223- "text" : orjson .dumps (msg ).decode (),
224- "bytes" : None ,
225- }
226- ),
227- recv_queue .get , # type: ignore
238+ self .send_json ,
239+ self .rendering_queue .get , # type: ignore
228240 )
229241
230242 # Manually log exceptions since this function is running in a separate asyncio task.
231243 except Exception as error :
232244 await asyncio .to_thread (_logger .error , f"{ error } \n { traceback .format_exc ()} " )
233245
246+ async def send_json (self , data : Any ) -> None :
247+ return await self ._send (
248+ {"type" : "websocket.send" , "text" : orjson .dumps (data ).decode ()}
249+ )
250+
234251
235252@dataclass
236253class StaticFileApp :
0 commit comments