1+ use core:: future:: IntoFuture as _;
12use core:: net:: { Ipv6Addr , SocketAddr } ;
23use core:: pin:: pin;
34use core:: time:: Duration ;
@@ -19,7 +20,7 @@ use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, Server
1920use rustls:: pki_types:: { CertificateDer , ServerName , UnixTime } ;
2021use rustls:: version:: TLS13 ;
2122use rustls:: { DigitallySignedStruct , SignatureScheme } ;
22- use tokio:: sync:: RwLock ;
23+ use tokio:: sync:: { Notify , RwLock } ;
2324use tokio:: task:: JoinSet ;
2425use tokio:: { select, signal} ;
2526use tower_http:: trace:: TraceLayer ;
@@ -31,7 +32,7 @@ use uuid::Uuid;
3132use wrpc_transport:: { ResourceBorrow , ResourceOwn } ;
3233use wrpc_wasi_keyvalue:: exports:: wasi:: keyvalue:: store;
3334use wtransport:: tls:: { Sha256DigestFmt , WEBTRANSPORT_ALPN } ;
34- use wtransport:: { Endpoint , Identity , ServerConfig } ;
35+ use wtransport:: { Endpoint , Identity , ServerConfig , VarInt } ;
3536
3637pub type Result < T , E = store:: Error > = core:: result:: Result < T , E > ;
3738
@@ -695,45 +696,6 @@ export const PORT = "{port}"
695696 ) ;
696697
697698 let srv = Arc :: new ( wrpc_transport_web:: Server :: new ( ) ) ;
698- let webt = tokio:: spawn ( {
699- let mut tasks = JoinSet :: < anyhow:: Result < ( ) > > :: new ( ) ;
700- let srv = Arc :: clone ( & srv) ;
701- async move {
702- loop {
703- select ! {
704- conn = ep. accept( ) => {
705- let srv = Arc :: clone( & srv) ;
706- tasks. spawn( async move {
707- let req = conn
708- . await
709- . context( "failed to accept WebTransport connection" ) ?;
710- let conn = req
711- . accept( )
712- . await
713- . context( "failed to establish WebTransport connection" ) ?;
714- let wrpc = wrpc_transport_web:: Client :: from( conn) ;
715- loop {
716- srv. accept( & wrpc)
717- . await
718- . context( "failed to accept wRPC connection" ) ?;
719- }
720- } ) ;
721- }
722- Some ( res) = tasks. join_next( ) => {
723- match res {
724- Ok ( Ok ( ( ) ) ) => { }
725- Ok ( Err ( err) ) => {
726- warn!( ?err, "failed to serve connection" )
727- }
728- Err ( err) => {
729- error!( ?err, "failed to join task" )
730- }
731- }
732- }
733- }
734- }
735- }
736- } ) ;
737699
738700 let invocations = wrpc_wasi_keyvalue:: exports:: wasi:: keyvalue:: store:: serve_interface (
739701 srv. as_ref ( ) ,
@@ -746,7 +708,7 @@ export const PORT = "{port}"
746708 . into_iter ( )
747709 . map ( |( instance, name, invocations) | invocations. map ( move |res| ( instance, name, res) ) ) ,
748710 ) ;
749- let wrpc = tokio:: spawn ( async move {
711+ let mut wrpc = tokio:: spawn ( async move {
750712 let mut tasks = JoinSet :: new ( ) ;
751713 loop {
752714 select ! {
@@ -779,26 +741,75 @@ export const PORT = "{port}"
779741 }
780742 } ) ;
781743
744+ let http_shutdown = Arc :: new ( Notify :: new ( ) ) ;
745+ let http = http. with_graceful_shutdown ( {
746+ let http_shutdown = Arc :: clone ( & http_shutdown) ;
747+ async move { http_shutdown. notified ( ) . await }
748+ } ) ;
749+ let mut http = http. into_future ( ) ;
782750 let shutdown = signal:: ctrl_c ( ) ;
783751 let mut shutdown = pin ! ( shutdown) ;
752+ let mut tasks = JoinSet :: < anyhow:: Result < ( ) > > :: new ( ) ;
784753 info ! ( "serving HTTP on: http://{addr}" ) ;
785- select ! {
786- res = http => {
787- trace!( "HTTP serving stopped" ) ;
788- res. context( "failed to serve HTTP" ) ?;
789- }
790- res = webt => {
791- trace!( "WebTransport serving task stopped" ) ;
792- res. context( "failed to serve WebTransport" ) ?;
793- }
794- res = wrpc => {
795- trace!( "wRPC serving task stopped" ) ;
796- res. context( "failed to serve wRPC invocations" ) ?;
797- }
798- res = & mut shutdown => {
799- trace!( "^C received" ) ;
800- res. context( "failed to listen for ^C" ) ?;
754+ loop {
755+ select ! {
756+ res = & mut http => {
757+ trace!( "HTTP serving stopped" ) ;
758+ return res. context( "failed to serve HTTP" ) ;
759+ }
760+ res = & mut wrpc => {
761+ trace!( "wRPC serving task stopped" ) ;
762+ return res. context( "failed to serve wRPC invocations" ) ;
763+ }
764+ conn = ep. accept( ) => {
765+ let srv = Arc :: clone( & srv) ;
766+ tasks. spawn( async move {
767+ let req = conn
768+ . await
769+ . context( "failed to accept WebTransport connection" ) ?;
770+ let conn = req
771+ . accept( )
772+ . await
773+ . context( "failed to establish WebTransport connection" ) ?;
774+ let wrpc = wrpc_transport_web:: Client :: from( conn) ;
775+ loop {
776+ srv. accept( & wrpc)
777+ . await
778+ . context( "failed to accept wRPC connection" ) ?;
779+ }
780+ } ) ;
781+ }
782+ Some ( res) = tasks. join_next( ) => {
783+ match res {
784+ Ok ( Ok ( ( ) ) ) => { }
785+ Ok ( Err ( err) ) => {
786+ warn!( ?err, "failed to serve WebTransport invocation" )
787+ }
788+ Err ( err) => {
789+ error!( ?err, "failed to join WebTransport invocation task" )
790+ }
791+ }
792+ }
793+ res = & mut shutdown => {
794+ trace!( "^C received" ) ;
795+ http_shutdown. notify_waiters( ) ;
796+ ep. close( VarInt :: from_u32( 0 ) , b"shutdown" ) ;
797+ ep. wait_idle( ) . await ;
798+ // wait for all WebTransport invocations to complete
799+ while let Some ( res) = tasks. join_next( ) . await {
800+ if let Err ( err) = res {
801+ error!( ?err, "failed to WebTransport invocation task" )
802+ }
803+ }
804+ http. await . context( "HTTP server failed" ) ?;
805+ // Drop the last wRPC server reference to shutdown invocation handling task
806+ drop( srv) ;
807+ if let Err ( err) = wrpc. await {
808+ error!( ?err, "wRPC serving task failed" ) ;
809+ }
810+ res. context( "failed to listen for ^C" ) ?;
811+ return Ok ( ( ) )
812+ }
801813 }
802814 }
803- Ok ( ( ) )
804815}
0 commit comments