diff --git a/CLAUDE.md b/CLAUDE.md index ad057c6673..27f523d2a7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -125,7 +125,7 @@ Key points: - Add `Serialize, Deserialize` derives for errors with metadata fields - Always return anyhow errors from failable functions - For example: `fn foo() -> Result { /* ... */ }` -- Import anyhow using `use anyhow::*` instead of importing individual types +- Do not glob import (`::*`) from anyhow. Instead, import individual types and traits **Dependency Management** - When adding a dependency, check for a workspace dependency in Cargo.toml diff --git a/engine/artifacts/errors/guard.websocket_service_hibernate.json b/engine/artifacts/errors/guard.websocket_service_hibernate.json new file mode 100644 index 0000000000..a9647dfd97 --- /dev/null +++ b/engine/artifacts/errors/guard.websocket_service_hibernate.json @@ -0,0 +1,5 @@ +{ + "code": "websocket_service_hibernate", + "group": "guard", + "message": "Initiate WebSocket service hibernation." +} \ No newline at end of file diff --git a/engine/artifacts/errors/guard.websocket_service_retry.json b/engine/artifacts/errors/guard.websocket_service_retry.json deleted file mode 100644 index e73bbbc507..0000000000 --- a/engine/artifacts/errors/guard.websocket_service_retry.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "code": "websocket_service_retry", - "group": "guard", - "message": "WebSocket service retry." 
-} \ No newline at end of file diff --git a/engine/packages/gasoline/src/builder/workflow/sub_workflow.rs b/engine/packages/gasoline/src/builder/workflow/sub_workflow.rs index 6048648b34..6c1180ae2f 100644 --- a/engine/packages/gasoline/src/builder/workflow/sub_workflow.rs +++ b/engine/packages/gasoline/src/builder/workflow/sub_workflow.rs @@ -266,7 +266,7 @@ where tracing::debug!("waiting for sub workflow"); - let mut wake_sub = self.ctx.db().wake_sub().await?; + let mut bump_sub = self.ctx.db().bump_sub().await?; let mut retries = self.ctx.db().max_sub_workflow_poll_retries(); let mut interval = tokio::time::interval(self.ctx.db().sub_workflow_poll_interval()); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); @@ -294,7 +294,7 @@ where // Poll and wait for a wake at the same time tokio::select! { - _ = wake_sub.next() => {}, + _ = bump_sub.next() => {}, _ = interval.tick() => {}, } } diff --git a/engine/packages/gasoline/src/ctx/common.rs b/engine/packages/gasoline/src/ctx/common.rs index e1f470315f..49a3a34efe 100644 --- a/engine/packages/gasoline/src/ctx/common.rs +++ b/engine/packages/gasoline/src/ctx/common.rs @@ -26,7 +26,7 @@ pub async fn wait_for_workflow_output( ) -> Result { tracing::debug!(?workflow_id, "waiting for workflow"); - let mut wake_sub = db.wake_sub().await?; + let mut bump_sub = db.bump_sub().await?; let mut interval = tokio::time::interval(db.sub_workflow_poll_interval()); // Skip first tick, we wait after the db call instead of before @@ -47,7 +47,7 @@ pub async fn wait_for_workflow_output( // Poll and wait for a wake at the same time tokio::select! 
{ - _ = wake_sub.next() => {}, + _ = bump_sub.next() => {}, _ = interval.tick() => {}, } } diff --git a/engine/packages/gasoline/src/ctx/workflow.rs b/engine/packages/gasoline/src/ctx/workflow.rs index 086104c34c..8d184036a4 100644 --- a/engine/packages/gasoline/src/ctx/workflow.rs +++ b/engine/packages/gasoline/src/ctx/workflow.rs @@ -716,7 +716,7 @@ impl WorkflowCtx { else { tracing::debug!("listening for signal"); - let mut wake_sub = self.db.wake_sub().await?; + let mut bump_sub = self.db.bump_sub().await?; let mut retries = self.db.max_signal_poll_retries(); let mut interval = tokio::time::interval(self.db.signal_poll_interval()); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); @@ -742,7 +742,7 @@ impl WorkflowCtx { // Poll and wait for a wake at the same time tokio::select! { - _ = wake_sub.next() => {}, + _ = bump_sub.next() => {}, _ = interval.tick() => {}, res = self.wait_stop() => res?, } @@ -779,7 +779,7 @@ impl WorkflowCtx { else { tracing::debug!("listening for signal"); - let mut wake_sub = self.db.wake_sub().await?; + let mut bump_sub = self.db.bump_sub().await?; let mut retries = self.db.max_signal_poll_retries(); let mut interval = tokio::time::interval(self.db.signal_poll_interval()); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); @@ -805,7 +805,7 @@ impl WorkflowCtx { // Poll and wait for a wake at the same time tokio::select! { - _ = wake_sub.next() => {}, + _ = bump_sub.next() => {}, _ = interval.tick() => {}, res = self.wait_stop() => res?, } @@ -1186,7 +1186,7 @@ impl WorkflowCtx { (async { tracing::debug!("listening for signal with timeout"); - let mut wake_sub = self.db.wake_sub().await?; + let mut bump_sub = self.db.bump_sub().await?; let mut interval = tokio::time::interval(self.db.signal_poll_interval()); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); @@ -1206,7 +1206,7 @@ impl WorkflowCtx { // Poll and wait for a wake at the same time tokio::select! 
{ - _ = wake_sub.next() => {}, + _ = bump_sub.next() => {}, _ = interval.tick() => {}, res = self.wait_stop() => res?, } @@ -1229,7 +1229,7 @@ impl WorkflowCtx { else { tracing::debug!("listening for signal with timeout"); - let mut wake_sub = self.db.wake_sub().await?; + let mut bump_sub = self.db.bump_sub().await?; let mut retries = self.db.max_signal_poll_retries(); let mut interval = tokio::time::interval(self.db.signal_poll_interval()); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); @@ -1257,7 +1257,7 @@ impl WorkflowCtx { // Poll and wait for a wake at the same time tokio::select! { - _ = wake_sub.next() => {}, + _ = bump_sub.next() => {}, _ = interval.tick() => {}, res = self.wait_stop() => res?, } diff --git a/engine/packages/gasoline/src/db/kv/debug.rs b/engine/packages/gasoline/src/db/kv/debug.rs index d68601483c..aff9967940 100644 --- a/engine/packages/gasoline/src/db/kv/debug.rs +++ b/engine/packages/gasoline/src/db/kv/debug.rs @@ -718,7 +718,7 @@ impl DatabaseDebug for DatabaseKv { .instrument(tracing::info_span!("wake_workflows_tx")) .await?; - self.wake_worker(); + self.bump_workers(); Ok(()) } diff --git a/engine/packages/gasoline/src/db/kv/mod.rs b/engine/packages/gasoline/src/db/kv/mod.rs index debb197da8..a53fdb85ad 100644 --- a/engine/packages/gasoline/src/db/kv/mod.rs +++ b/engine/packages/gasoline/src/db/kv/mod.rs @@ -43,8 +43,8 @@ mod keys; const WORKER_INSTANCE_LOST_THRESHOLD_MS: i64 = rivet_util::duration::seconds(30); /// How long before overwriting an existing metrics lock. const METRICS_LOCK_TIMEOUT_MS: i64 = rivet_util::duration::seconds(30); -/// For pubsub wake mechanism. -const WORKER_WAKE_SUBJECT: &str = "gasoline.worker.wake"; +/// For pubsub bump mechanism. +const WORKER_BUMP_SUBJECT: &str = "gasoline.worker.bump"; pub struct DatabaseKv { pools: rivet_pools::Pools, @@ -52,31 +52,31 @@ pub struct DatabaseKv { } impl DatabaseKv { - /// Spawns a new thread and publishes a worker wake message to pubsub. 
- fn wake_worker(&self) { + /// Spawns a new thread and publishes a worker bump message to pubsub. + fn bump_workers(&self) { let Ok(pubsub) = self.pools.ups() else { tracing::debug!("failed to acquire pubsub pool"); return; }; - let spawn_res = tokio::task::Builder::new().name("wake").spawn( + let spawn_res = tokio::task::Builder::new().name("bump").spawn( async move { // Fail gracefully if let Err(err) = pubsub .publish( - WORKER_WAKE_SUBJECT, + WORKER_BUMP_SUBJECT, &Vec::new(), universalpubsub::PublishOpts::broadcast(), ) .await { - tracing::warn!(?err, "failed to publish wake message"); + tracing::warn!(?err, "failed to publish bump message"); } } - .instrument(tracing::info_span!("wake_worker_publish")), + .instrument(tracing::info_span!("bump_worker_publish")), ); if let Err(err) = spawn_res { - tracing::error!(?err, "failed to spawn wake task"); + tracing::error!(?err, "failed to spawn bump task"); } } } @@ -424,12 +424,12 @@ impl Database for DatabaseKv { } #[tracing::instrument(skip_all)] - async fn wake_sub<'a, 'b>(&'a self) -> WorkflowResult> { + async fn bump_sub<'a, 'b>(&'a self) -> WorkflowResult> { let mut subscriber = self .pools .ups() .map_err(WorkflowError::PoolsGeneric)? 
- .subscribe(WORKER_WAKE_SUBJECT) + .subscribe(WORKER_BUMP_SUBJECT) .await .map_err(|x| WorkflowError::CreateSubscription(x.into()))?; @@ -586,7 +586,7 @@ impl Database for DatabaseKv { "handled failover", ); - self.wake_worker(); + self.bump_workers(); } Ok(()) @@ -815,7 +815,7 @@ impl Database for DatabaseKv { .await .map_err(WorkflowError::Udb)?; - self.wake_worker(); + self.bump_workers(); Ok(workflow_id) } @@ -1028,7 +1028,7 @@ impl Database for DatabaseKv { { let wake_deadline_ts = key.condition.deadline_ts(); - // Update wake deadline ts + // Update wake deadline ts if earlier if last_wake_deadline_ts.is_none() || wake_deadline_ts < *last_wake_deadline_ts { @@ -1633,7 +1633,7 @@ impl Database for DatabaseKv { // Wake worker again in case some other workflow was waiting for this one to complete if wrote_to_wake_idx { - self.wake_worker(); + self.bump_workers(); } let dt = start_instant.elapsed().as_secs_f64(); @@ -1794,7 +1794,7 @@ impl Database for DatabaseKv { // // This will result in the workflow sleeping instead of immediately running again. // - // Adding this wake_worker call ensures that if the workflow has a valid wake condition before commit + // Adding this bump_workers call ensures that if the workflow has a valid wake condition before commit // then it will immediately wake up again. 
// // This is simpler than having this commit_workflow fn read wake conditions because: @@ -1802,7 +1802,7 @@ impl Database for DatabaseKv { // - would involve informing the worker to restart the workflow in memory instead of the usual // workflow lifecycle // - the worker is already designed to pull wake conditions frequently - self.wake_worker(); + self.bump_workers(); let dt = start_instant.elapsed().as_secs_f64(); metrics::COMMIT_WORKFLOW_DURATION.record( @@ -2111,7 +2111,7 @@ impl Database for DatabaseKv { .await .map_err(WorkflowError::Udb)?; - self.wake_worker(); + self.bump_workers(); Ok(()) } @@ -2163,7 +2163,7 @@ impl Database for DatabaseKv { .await .map_err(WorkflowError::Udb)?; - self.wake_worker(); + self.bump_workers(); Ok(()) } @@ -2219,7 +2219,7 @@ impl Database for DatabaseKv { .await .map_err(WorkflowError::Udb)?; - self.wake_worker(); + self.bump_workers(); Ok(sub_workflow_id) } diff --git a/engine/packages/gasoline/src/db/mod.rs b/engine/packages/gasoline/src/db/mod.rs index 81af4f6b24..7b5d2a11cb 100644 --- a/engine/packages/gasoline/src/db/mod.rs +++ b/engine/packages/gasoline/src/db/mod.rs @@ -30,12 +30,12 @@ pub trait Database: Send { // MARK: Const fns - /// How often to pull workflows when polling. This runs alongside a wake sub. + /// How often to pull workflows when polling. This runs alongside a bump sub. fn worker_poll_interval(&self) -> Duration { Duration::from_secs(90) } - /// Poll interval when polling for signals in-process. This runs alongside a wake sub. + /// Poll interval when polling for signals in-process. This runs alongside a bump sub. fn signal_poll_interval(&self) -> Duration { Duration::from_millis(500) } @@ -45,7 +45,7 @@ pub trait Database: Send { 4 } - /// Poll interval when polling for a sub workflow in-process. This runs alongside a wake sub. + /// Poll interval when polling for a sub workflow in-process. This runs alongside a bump sub. 
fn sub_workflow_poll_interval(&self) -> Duration { Duration::from_millis(500) } @@ -59,7 +59,7 @@ pub trait Database: Send { /// This function returns a subscription which should resolve once the worker should fetch the database /// again. - async fn wake_sub<'a, 'b>(&'a self) -> WorkflowResult>; + async fn bump_sub<'a, 'b>(&'a self) -> WorkflowResult>; /// Updates the last ping ts for this worker. async fn update_worker_ping(&self, worker_instance_id: Id) -> WorkflowResult<()>; diff --git a/engine/packages/gasoline/src/worker.rs b/engine/packages/gasoline/src/worker.rs index 43b1680040..5321656a28 100644 --- a/engine/packages/gasoline/src/worker.rs +++ b/engine/packages/gasoline/src/worker.rs @@ -14,7 +14,7 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; use crate::{ctx::WorkflowCtx, db::DatabaseHandle, error::WorkflowError, registry::RegistryHandle}; /// How often to run gc and update ping. -const PING_INTERVAL: Duration = Duration::from_secs(20); +const PING_INTERVAL: Duration = Duration::from_secs(10); /// How often to publish metrics. const METRICS_INTERVAL: Duration = Duration::from_secs(20); /// Time to allow running workflows to shutdown after receiving a SIGINT or SIGTERM. @@ -57,7 +57,7 @@ impl Worker { } } - /// Polls the database periodically or wakes immediately when `Database::wake` finishes + /// Polls the database periodically or wakes immediately when `Database::bump_sub` finishes #[tracing::instrument(skip_all, fields(worker_instance_id=%self.worker_instance_id))] pub async fn start(mut self, mut shutdown_rx: Option>) -> Result<()> { tracing::debug!( @@ -67,7 +67,7 @@ impl Worker { let cache = rivet_cache::CacheInner::from_env(&self.config, self.pools.clone())?; - let mut wake_sub = { self.db.wake_sub().await? }; + let mut bump_sub = { self.db.bump_sub().await? 
}; let mut tick_interval = tokio::time::interval(self.db.worker_poll_interval()); tick_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); @@ -89,7 +89,7 @@ impl Worker { tokio::select! { _ = tick_interval.tick() => {}, - res = wake_sub.next() => { + res = bump_sub.next() => { if res.is_none() { break Err(WorkflowError::SubscriptionUnsubscribed.into()); } diff --git a/engine/packages/guard-core/src/custom_serve.rs b/engine/packages/guard-core/src/custom_serve.rs index 351747e96d..7a7f4a2e2c 100644 --- a/engine/packages/guard-core/src/custom_serve.rs +++ b/engine/packages/guard-core/src/custom_serve.rs @@ -1,4 +1,4 @@ -use anyhow::*; +use anyhow::{Result, bail}; use async_trait::async_trait; use bytes::Bytes; use http_body_util::Full; @@ -10,6 +10,11 @@ use crate::WebSocketHandle; use crate::proxy_service::ResponseBody; use crate::request_context::RequestContext; +pub enum HibernationResult { + Continue, + Close, +} + /// Trait for custom request serving logic that can handle both HTTP and WebSocket requests #[async_trait] pub trait CustomServeTrait: Send + Sync { @@ -23,11 +28,21 @@ pub trait CustomServeTrait: Send + Sync { /// Handle a WebSocket connection after upgrade. Supports connection retries. async fn handle_websocket( &self, - websocket: WebSocketHandle, - headers: &hyper::HeaderMap, - path: &str, - request_context: &mut RequestContext, + _websocket: WebSocketHandle, + _headers: &hyper::HeaderMap, + _path: &str, + _request_context: &mut RequestContext, // Identifies the websocket across retries. - unique_request_id: Uuid, - ) -> Result>; + _unique_request_id: Uuid, + ) -> Result> { + bail!("service does not support websockets"); + } + + /// Returns `HibernationResult::Close` if the websocket should close.
+ async fn handle_websocket_hibernation( + &self, + _websocket: WebSocketHandle, + ) -> Result { + bail!("service does not support websocket hibernation"); + } } diff --git a/engine/packages/guard-core/src/errors.rs b/engine/packages/guard-core/src/errors.rs index ac88f0a439..ebfa809da5 100644 --- a/engine/packages/guard-core/src/errors.rs +++ b/engine/packages/guard-core/src/errors.rs @@ -93,8 +93,12 @@ pub struct ServiceUnavailable; pub struct WebSocketServiceUnavailable; #[derive(RivetError, Serialize, Deserialize)] -#[error("guard", "websocket_service_retry", "WebSocket service retry.")] -pub struct WebSocketServiceRetry; +#[error( + "guard", + "websocket_service_hibernate", + "Initiate WebSocket service hibernation." +)] +pub struct WebSocketServiceHibernate; #[derive(RivetError, Serialize, Deserialize)] #[error("guard", "websocket_service_timeout", "WebSocket service timed out.")] diff --git a/engine/packages/guard-core/src/proxy_service.rs b/engine/packages/guard-core/src/proxy_service.rs index eac786680c..6a03749033 100644 --- a/engine/packages/guard-core/src/proxy_service.rs +++ b/engine/packages/guard-core/src/proxy_service.rs @@ -1,4 +1,4 @@ -use anyhow::{Context, Result, bail}; +use anyhow::{Context, Result, bail, ensure}; use bytes::Bytes; use futures_util::{SinkExt, StreamExt}; use http_body_util::{BodyExt, Full}; @@ -31,7 +31,9 @@ use url::Url; use uuid::Uuid; use crate::{ - WebSocketHandle, custom_serve::CustomServeTrait, errors, metrics, + WebSocketHandle, + custom_serve::{CustomServeTrait, HibernationResult}, + errors, metrics, request_context::RequestContext, }; @@ -1828,7 +1830,7 @@ impl ProxyService { ); } ResolveRouteOutput::Response(_) => unreachable!(), - ResolveRouteOutput::CustomServe(mut handlers) => { + ResolveRouteOutput::CustomServe(mut handler) => { tracing::debug!(%req_path, "Spawning task to handle WebSocket communication"); let mut request_context = request_context.clone(); let req_headers = req_headers.clone(); @@ -1838,6 +1840,7 
@@ impl ProxyService { tokio::spawn( async move { let request_id = Uuid::new_v4(); + let mut ws_hibernation_close = false; let mut attempts = 0u32; let ws_handle = WebSocketHandle::new(client_ws) @@ -1845,7 +1848,7 @@ impl ProxyService { .context("failed initiating websocket handle")?; loop { - match handlers + match handler .handle_websocket( ws_handle.clone(), &req_headers, @@ -1895,18 +1898,43 @@ impl ProxyService { Err(err) => { tracing::debug!(?err, "websocket handler error"); - // Denotes that the connection did not fail, but needs to be retried to - // resole a new target - let ws_retry = is_ws_retry(&err); + // Denotes that the connection did not fail, but the downstream has closed + let ws_hibernate = is_ws_hibernate(&err); - if ws_retry { + if ws_hibernate { attempts = 0; } else { attempts += 1; } - if attempts > max_attempts - || (!is_retryable_ws_error(&err) && !ws_retry) + if ws_hibernate { + // This should be unreachable because as soon as the actor is + // reconnected to after hibernation the gateway will consume the close + // frame from the client ws stream + ensure!( + !ws_hibernation_close, + "should not be hibernating again after receiving a close frame during hibernation" + ); + + // After this function returns: + // - the route will be resolved again + // - the websocket will connect to the new downstream target + // - the gateway will continue reading messages from the client ws + // (starting with the message that caused the hibernation to end) + let res = handler + .handle_websocket_hibernation(ws_handle.clone()) + .await?; + + // Despite receiving a close frame from the client during hibernation + // we are going to reconnect to the actor so that it knows the + // connection has closed + if let HibernationResult::Close = res { + tracing::debug!("starting hibernating websocket close"); + + ws_hibernation_close = true; + } + } else if attempts > max_attempts + || !is_retryable_ws_error(&err) { tracing::debug!( ?attempts, @@ -1929,79 +1957,79 
@@ impl ProxyService { break; } else { - if !ws_retry { - let backoff = ProxyService::calculate_backoff( - attempts, - initial_interval, - ); + let backoff = ProxyService::calculate_backoff( + attempts, + initial_interval, + ); - tracing::debug!( - ?backoff, - "WebSocket attempt {attempts} failed (service unavailable)" - ); + tracing::debug!( + ?backoff, + "WebSocket attempt {attempts} failed (service unavailable)" + ); - tokio::time::sleep(backoff).await; - } + // Apply backoff for retryable error + tokio::time::sleep(backoff).await; + } - match state - .resolve_route( - &req_host, - &req_path, - &req_method, - state.port_type.clone(), - &req_headers, - true, - ) - .await - { - Ok(ResolveRouteOutput::CustomServe(new_handlers)) => { - handlers = new_handlers; - continue; - } - Ok(ResolveRouteOutput::Response(response)) => { - ws_handle - .send(to_hyper_close(Some(str_to_close_frame( - response.message.as_ref(), - )))) - .await?; - - // Flush to ensure close frame is sent - ws_handle.flush().await?; - - // Keep TCP connection open briefly to allow client to process close - tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await; - } - Ok(ResolveRouteOutput::Target(_)) => { - ws_handle - .send(to_hyper_close(Some(err_to_close_frame( - errors::WebSocketTargetChanged.build(), - ray_id, - )))) - .await?; + // Retry route resolution + match state + .resolve_route( + &req_host, + &req_path, + &req_method, + state.port_type.clone(), + &req_headers, + true, + ) + .await + { + Ok(ResolveRouteOutput::CustomServe(new_handler)) => { + handler = new_handler; + continue; + } + Ok(ResolveRouteOutput::Response(response)) => { + ws_handle + .send(to_hyper_close(Some(str_to_close_frame( + response.message.as_ref(), + )))) + .await?; + + // Flush to ensure close frame is sent + ws_handle.flush().await?; + + // Keep TCP connection open briefly to allow client to process close + tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await; + } + Ok(ResolveRouteOutput::Target(_)) => { + ws_handle + 
.send(to_hyper_close(Some(err_to_close_frame( + errors::WebSocketTargetChanged.build(), + ray_id, + )))) + .await?; - // Flush to ensure close frame is sent - ws_handle.flush().await?; + // Flush to ensure close frame is sent + ws_handle.flush().await?; - // Keep TCP connection open briefly to allow client to process close - tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await; + // Keep TCP connection open briefly to allow client to process close + tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await; - break; - } - Err(err) => { - ws_handle - .send(to_hyper_close(Some(err_to_close_frame( - err, ray_id, - )))) - .await?; + break; + } + Err(err) => { + ws_handle + .send(to_hyper_close(Some(err_to_close_frame( + err, ray_id, + )))) + .await?; - // Flush to ensure close frame is sent - ws_handle.flush().await?; + // Flush to ensure close frame is sent + ws_handle.flush().await?; - // Keep TCP connection open briefly to allow client to process close - tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await; + // Keep TCP connection open briefly to allow client to process close + tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await; - break; - } + break; } } } @@ -2509,9 +2537,9 @@ fn is_retryable_ws_error(err: &anyhow::Error) -> bool { } } -fn is_ws_retry(err: &anyhow::Error) -> bool { +fn is_ws_hibernate(err: &anyhow::Error) -> bool { if let Some(rivet_err) = err.chain().find_map(|x| x.downcast_ref::()) { - rivet_err.group() == "guard" && rivet_err.code() == "websocket_service_retry" + rivet_err.group() == "guard" && rivet_err.code() == "websocket_service_hibernate" } else { false } diff --git a/engine/packages/guard-core/src/websocket_handle.rs b/engine/packages/guard-core/src/websocket_handle.rs index 763f337b20..2a3c50a4b3 100644 --- a/engine/packages/guard-core/src/websocket_handle.rs +++ b/engine/packages/guard-core/src/websocket_handle.rs @@ -1,5 +1,5 @@ use anyhow::*; -use futures_util::{SinkExt, StreamExt}; +use futures_util::{SinkExt, StreamExt, stream::Peekable}; use 
hyper::upgrade::Upgraded; use hyper_tungstenite::HyperWebsocket; use hyper_tungstenite::tungstenite::Message as WsMessage; @@ -8,7 +8,8 @@ use std::sync::Arc; use tokio::sync::Mutex; use tokio_tungstenite::WebSocketStream; -pub type WebSocketReceiver = futures_util::stream::SplitStream>>; +pub type WebSocketReceiver = + Peekable>>>; pub type WebSocketSender = futures_util::stream::SplitSink>, WsMessage>; @@ -26,7 +27,7 @@ impl WebSocketHandle { Ok(Self { ws_tx: Arc::new(Mutex::new(ws_tx)), - ws_rx: Arc::new(Mutex::new(ws_rx)), + ws_rx: Arc::new(Mutex::new(ws_rx.peekable())), }) } diff --git a/engine/packages/guard-core/tests/custom_serve.rs b/engine/packages/guard-core/tests/custom_serve.rs index a6ce98a49c..0d8196373d 100644 --- a/engine/packages/guard-core/tests/custom_serve.rs +++ b/engine/packages/guard-core/tests/custom_serve.rs @@ -12,17 +12,21 @@ use tokio_tungstenite::{connect_async, tungstenite::protocol::Message}; use common::{create_test_config, init_tracing, start_guard}; use rivet_guard_core::WebSocketHandle; -use rivet_guard_core::custom_serve::CustomServeTrait; +use rivet_guard_core::custom_serve::{CustomServeTrait, HibernationResult}; +use rivet_guard_core::errors::WebSocketServiceHibernate; use rivet_guard_core::proxy_service::{ResponseBody, RoutingFn, RoutingOutput}; use rivet_guard_core::request_context::RequestContext; use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame; use uuid::Uuid; +const HIBERNATION_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2); + // Track what was called for testing #[derive(Clone, Debug, Default)] struct CallTracker { http_calls: Arc>>, websocket_calls: Arc>>, + websocket_hibernation_calls: Arc>>, } // Test implementation of CustomServeTrait @@ -83,6 +87,11 @@ impl CustomServeTrait for TestCustomServe { match msg_result { std::result::Result::Ok(msg) if msg.is_text() => { let text = msg.to_text().unwrap_or(""); + + if text == "hibernate" { + return Err(WebSocketServiceHibernate.build()); + 
} + let response = format!("Custom: {}", text); if let std::result::Result::Err(e) = websocket.send(response.into()).await { eprintln!("Failed to send WebSocket message: {}", e); @@ -102,6 +111,22 @@ impl CustomServeTrait for TestCustomServe { Ok(None) } + + async fn handle_websocket_hibernation( + &self, + _websocket: WebSocketHandle, + ) -> Result { + // Track this WebSocket call + self.tracker + .websocket_hibernation_calls + .lock() + .unwrap() + .push("hibernation".to_string()); + + tokio::time::sleep(HIBERNATION_TIMEOUT).await; + + Ok(HibernationResult::Continue) + } } // Create routing function that returns CustomServe @@ -232,6 +257,68 @@ async fn test_custom_serve_websocket() { assert_eq!(http_calls.len(), 0); } +#[tokio::test] +async fn test_custom_serve_websocket_hibernation() { + init_tracing(); + + // Create tracker to verify calls + let tracker = CallTracker::default(); + + // Create routing function that returns CustomServe + let routing_fn = create_custom_serve_routing_fn(tracker.clone()); + + // Start guard with custom routing + let config = create_test_config(|_| {}); + let (guard_addr, _shutdown) = start_guard(config, routing_fn).await; + + // Connect to WebSocket through guard + let ws_url = format!("ws://{}/ws/custom", guard_addr); + let (mut ws_stream, response) = connect_async(&ws_url) + .await + .expect("Failed to connect to WebSocket"); + + // Verify upgrade was successful + assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS); + + // Send hibernation + ws_stream + .send(Message::Text("hibernate".to_string().into())) + .await + .expect("Failed to send WebSocket message"); + + // Send a test message + let test_message = "Hello Custom Hibernating WebSocket"; + ws_stream + .send(Message::Text(test_message.to_string().into())) + .await + .expect("Failed to send WebSocket message"); + + // Give some time for async operations to complete + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Verify the WebSocket 
handler hibernated + let ws_hibernation_calls = tracker.websocket_hibernation_calls.lock().unwrap(); + assert_eq!(ws_hibernation_calls.len(), 1); + assert_eq!(ws_hibernation_calls[0], "hibernation"); + + // Receive the echoed message with custom prefix + let response = tokio::time::timeout(HIBERNATION_TIMEOUT * 2, ws_stream.next()) + .await + .expect("timed out waiting for message from hibernating WebSocket"); + match response { + Some(Result::Ok(Message::Text(text))) => { + assert_eq!(text, format!("Custom: {}", test_message)); + } + other => panic!("Expected text message, got: {:?}", other), + } + + // Close the connection + ws_stream + .close(None) + .await + .expect("Failed to close WebSocket"); +} + #[tokio::test] async fn test_custom_serve_multiple_requests() { init_tracing(); diff --git a/engine/packages/guard/src/routing/api_public.rs b/engine/packages/guard/src/routing/api_public.rs index 143db070da..f6fee1840c 100644 --- a/engine/packages/guard/src/routing/api_public.rs +++ b/engine/packages/guard/src/routing/api_public.rs @@ -1,15 +1,13 @@ use std::sync::Arc; -use anyhow::*; +use anyhow::{Context, Result}; use async_trait::async_trait; use bytes::Bytes; use gas::prelude::*; use http_body_util::{BodyExt, Full}; use hyper::{Request, Response}; -use rivet_guard_core::WebSocketHandle; use rivet_guard_core::proxy_service::{ResponseBody, RoutingOutput}; use rivet_guard_core::{CustomServeTrait, request_context::RequestContext}; -use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame; use tower::Service; struct ApiPublicService { @@ -30,31 +28,20 @@ impl CustomServeTrait for ApiPublicService { let response = service .call(req) .await - .map_err(|e| anyhow::anyhow!("Failed to call api-public service: {}", e))?; + .context("failed to call api-public service")?; // Collect the body and convert to ResponseBody let (parts, body) = response.into_parts(); let collected = body .collect() .await - .map_err(|e| anyhow::anyhow!("Failed to collect response body: 
{}", e))?; + .context("failed to collect response body")?; let bytes = collected.to_bytes(); let response_body = ResponseBody::Full(Full::new(bytes)); let response = Response::from_parts(parts, response_body); Ok(response) } - - async fn handle_websocket( - &self, - _client_ws: WebSocketHandle, - _headers: &hyper::HeaderMap, - _path: &str, - _request_context: &mut RequestContext, - _unique_request_id: Uuid, - ) -> Result> { - bail!("api-public does not support WebSocket connections") - } } /// Route requests to the api-public service diff --git a/engine/packages/guard/src/routing/pegboard_gateway.rs b/engine/packages/guard/src/routing/pegboard_gateway.rs index f8c961f34c..4d460fa0c1 100644 --- a/engine/packages/guard/src/routing/pegboard_gateway.rs +++ b/engine/packages/guard/src/routing/pegboard_gateway.rs @@ -235,6 +235,7 @@ async fn route_request_inner( // Return pegboard-gateway instance with path let gateway = pegboard_gateway::PegboardGateway::new( + ctx.clone(), shared_state.pegboard_gateway.clone(), runner_id, actor_id, diff --git a/engine/packages/pegboard-gateway/src/lib.rs b/engine/packages/pegboard-gateway/src/lib.rs index 7ea4685fb1..81b7b9dd09 100644 --- a/engine/packages/pegboard-gateway/src/lib.rs +++ b/engine/packages/pegboard-gateway/src/lib.rs @@ -8,18 +8,19 @@ use hyper::{Request, Response, StatusCode}; use rivet_error::*; use rivet_guard_core::{ WebSocketHandle, - custom_serve::CustomServeTrait, + custom_serve::{CustomServeTrait, HibernationResult}, errors::{ - ServiceUnavailable, WebSocketServiceRetry, WebSocketServiceTimeout, + ServiceUnavailable, WebSocketServiceHibernate, WebSocketServiceTimeout, WebSocketServiceUnavailable, }, proxy_service::ResponseBody, request_context::RequestContext, + websocket_handle::WebSocketReceiver, }; use rivet_runner_protocol as protocol; use rivet_util::serde::HashableMap; -use std::time::Duration; -use tokio::sync::watch; +use std::{sync::Arc, time::Duration}; +use tokio::sync::{Mutex, watch}; use 
tokio_tungstenite::tungstenite::{ Message, protocol::frame::{CloseFrame, coding::CloseCode}, @@ -47,6 +48,7 @@ enum LifecycleResult { } pub struct PegboardGateway { + ctx: StandaloneCtx, shared_state: SharedState, runner_id: Id, actor_id: Id, @@ -55,8 +57,15 @@ pub struct PegboardGateway { impl PegboardGateway { #[tracing::instrument(skip_all, fields(?actor_id, ?runner_id, ?path))] - pub fn new(shared_state: SharedState, runner_id: Id, actor_id: Id, path: String) -> Self { + pub fn new( + ctx: StandaloneCtx, + shared_state: SharedState, + runner_id: Id, + actor_id: Id, + path: String, + ) -> Self { Self { + ctx, shared_state, runner_id, actor_id, @@ -372,7 +381,7 @@ impl CustomServeTrait for PegboardGateway { if open_msg.can_hibernate && close.retry { // Successful closure - return Err(WebSocketServiceRetry.build()); + return Err(WebSocketServiceHibernate.build()); } else { return Ok(LifecycleResult::ServerClose(close)); } @@ -385,7 +394,7 @@ impl CustomServeTrait for PegboardGateway { } } else { tracing::debug!("tunnel sub closed"); - return Err(WebSocketServiceRetry.build()); + return Err(WebSocketServiceHibernate.build()); } } _ = tunnel_to_ws_abort_rx.changed() => { @@ -541,4 +550,63 @@ impl CustomServeTrait for PegboardGateway { Err(err) => Err(err), } } + + #[tracing::instrument(skip_all, fields(actor_id=?self.actor_id))] + async fn handle_websocket_hibernation( + &self, + client_ws: WebSocketHandle, + ) -> Result { + let mut ready_sub = self + .ctx + .subscribe::(("actor_id", self.actor_id)) + .await?; + + let close = tokio::select! 
{ + _ = ready_sub.next() => { + tracing::debug!("actor became ready during hibernation"); + + HibernationResult::Continue + } + hibernation_res = hibernate_ws(client_ws.recv()) => { + let res = hibernation_res?; + + match &res { + HibernationResult::Continue => { + tracing::debug!("received message during hibernation"); + } + HibernationResult::Close => { + tracing::debug!("websocket stream closed during hibernation"); + } + } + + res + } + }; + + Ok(close) + } +} + +async fn hibernate_ws(ws_rx: Arc>) -> Result { + let mut guard = ws_rx.lock().await; + let mut pinned = std::pin::Pin::new(&mut *guard); + + loop { + if let Some(msg) = pinned.as_mut().peek().await { + match msg { + Ok(Message::Binary(_)) | Ok(Message::Text(_)) => { + return Ok(HibernationResult::Continue); + } + // We don't care about the close frame because we're currently hibernating; there is no + // downstream to send the close frame to. + Ok(Message::Close(_)) => return Ok(HibernationResult::Close), + // Ignore rest + _ => { + pinned.try_next().await?; + } + } + } else { + return Ok(HibernationResult::Close); + } + } } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/config.ts b/rivetkit-typescript/packages/rivetkit/src/actor/config.ts index a8df7b4057..1d1eab1401 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/config.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/config.ts @@ -78,7 +78,7 @@ export const ActorConfigSchema = z noSleep: z.boolean().default(false), sleepTimeout: z.number().positive().default(30_000), /** @experimental */ - canHibernatWebSocket: z + canHibernateWebSocket: z .union([ z.boolean(), z diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index f97148e901..d18e55da79 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ 
-219,14 +219,16 @@ export class EngineActorDriver implements ActorDriver { ); // Check if can hibernate - const canHibernatWebSocket = - definition.config.options?.canHibernatWebSocket; - if (canHibernatWebSocket === true) { + const canHibernateWebSocket = + definition.config.options?.canHibernateWebSocket; + if (canHibernateWebSocket === true) { hibernationConfig = { enabled: true, lastMsgIndex: undefined, }; - } else if (typeof canHibernatWebSocket === "function") { + } else if ( + typeof canHibernateWebSocket === "function" + ) { try { // Truncate the path to match the behavior on onRawWebSocket const newPath = truncateRawWebSocketPathPrefix( @@ -238,14 +240,14 @@ export class EngineActorDriver implements ActorDriver { ); const canHibernate = - canHibernatWebSocket(truncatedRequest); + canHibernateWebSocket(truncatedRequest); hibernationConfig = { enabled: canHibernate, lastMsgIndex: undefined, }; } catch (error) { logger().error({ - msg: "error calling canHibernatWebSocket", + msg: "error calling canHibernateWebSocket", error, }); hibernationConfig = {