diff --git a/engine/sdks/typescript/runner-protocol/src/index.ts b/engine/sdks/typescript/runner-protocol/src/index.ts index 69bc8d5e0b..ff759d7b4e 100644 --- a/engine/sdks/typescript/runner-protocol/src/index.ts +++ b/engine/sdks/typescript/runner-protocol/src/index.ts @@ -1,4 +1,4 @@ - +import assert from "node:assert" import * as bare from "@bare-ts/lib" const DEFAULT_CONFIG = /* @__PURE__ */ bare.Config({}) @@ -1925,9 +1925,3 @@ export function decodeToServerlessServer(bytes: Uint8Array): ToServerlessServer } return result } - - -function assert(condition: boolean, message?: string): asserts condition { - if (!condition) throw new Error(message ?? "Assertion failed") -} - diff --git a/engine/sdks/typescript/runner/src/mod.ts b/engine/sdks/typescript/runner/src/mod.ts index 9d30b830d1..c972d04828 100644 --- a/engine/sdks/typescript/runner/src/mod.ts +++ b/engine/sdks/typescript/runner/src/mod.ts @@ -3,14 +3,15 @@ import type { Logger } from "pino"; import type WebSocket from "ws"; import { logger, setLogger } from "./log.js"; import { stringifyCommandWrapper, stringifyEvent } from "./stringify"; -import { Tunnel } from "./tunnel"; +import { HibernatingWebSocketMetadata, Tunnel } from "./tunnel"; import { calculateBackoff, parseWebSocketCloseReason, unreachable, } from "./utils"; import { importWebSocket } from "./websocket.js"; -import type { WebSocketTunnelAdapter } from "./websocket-tunnel-adapter"; + +export { HibernatingWebSocketMetadata } from "./tunnel"; const KV_EXPIRE: number = 30_000; const PROTOCOL_VERSION: number = 3; @@ -51,38 +52,151 @@ export interface RunnerConfig { onConnected: () => void; onDisconnected: (code: number, reason: string) => void; onShutdown: () => void; + + /** Called when receiving a network request. */ fetch: ( runner: Runner, actorId: string, requestId: protocol.RequestId, request: Request, ) => Promise; - websocket?: ( + + /** + * Called when receiving a WebSocket connection. + * + * All event listeners must be added synchronously inside this function or + * else events may be missed. The open event will fire immediately after + * this function finishes. + * + * Any errors thrown here will disconnect the WebSocket immediately. + * + * ## Hibernating Web Sockets + * + * ### Implementation Requirements + * + * **Requirement 1: Persist HWS Immediately** + * + * This is responsible for persisting hibernatable WebSockets immediately + * (do not wait for open event). It is not time sensitive to flush the + * connection state. If this fails to persist the HWS, the client's + * WebSocket will be disconnected on next wake in + * `Tunnel::restoreHibernatingRequests` since the connection entry will not + * exist. + * + * **Requirement 2: Persist Message Index On `message`** + * + * In the `message` event listener, this handler must persist the message + * index from the event. The request ID is available at + * `event.rivetRequestId` and message index at `event.rivetMessageIndex`. + * + * The message index should not be flushed immediately. Instead, this + * should: + * + * - Debounce calls to persist the message index + * - After each persist, call + * `Runner::sendHibernatableWebSocketMessageAck` to acknowledge the + * message + * + * This mechanism allows us to buffer messages on the gateway so we can + * batch-persist events on our end on a given interval. + * + * If this fails to persist, then the gateway will replay unacked + * messages when the actor starts again. + * + * **Requirement 3: Remove HWS From Storage On `close`** + * + * This handler should add an event listener for `close` to remove the + * connection from storage. + * + * If the connection remove fails to persist, the close event will be + * called again on the next actor start in + * `Tunnel::restoreHibernatingRequests` since there will be no request for + * the given connection. + * + * ### Restoring Connections + * + * `loadAll` will be called from `Tunnel::restoreHibernatingRequests` to + * restore this connection on the next actor wake. + * + * `restoreHibernatingRequests` is responsible for both making sure that + * new connections are registered with the actor and zombie connections are + * appropriately cleaned up. + * + * ### No Open Event On Restoration + * + * When restoring a HWS, the open event will not be called again. It will + * go straight to the message or close event. + */ + websocket: ( runner: Runner, actorId: string, ws: any, requestId: protocol.RequestId, request: Request, - ) => Promise; + ) => void; + + hibernatableWebSocket: { + /** + * Determines if a WebSocket can continue to live while an actor goes to + * sleep. + */ + canHibernate: ( + actorId: string, + requestId: ArrayBuffer, + request: Request, + ) => boolean; + + /** + * Returns all hibernatable WebSockets that are stored for this actor. + * + * This is called on actor start. + * + * This list will be diffed with the list of hibernating requests in + * the ActorStart message. + * + * This that are connected but not loaded (i.e. were not successfully + * persisted to this actor) will be disconnected. + * + * This that are not connected but were loaded (i.e. disconnected but + * this actor has not received the event yet) will also be + * disconnected. + */ + loadAll(actorId: string): HibernatingWebSocketMetadata[]; + + /** + * Notify the HWS message index needs to be persisted in the background. + * + * The message index should not be flushed immediately. Instead, this + * should: + * + * - Debounce calls to persist the message index + * - After each persist, call + * `Runner::sendHibernatableWebSocketMessageAck` to acknowledge the + * message + * + * This mechanism allows us to buffer messages on the gateway so we can + * batch-persist events on our end on a given interval. + * + * If this fails to persist, then the gateway will replay unacked + * messages when the actor starts again. + */ + persistMessageIndex: ( + actorId: string, + requestId: protocol.RequestId, + messageIndex: number, + ) => void; + }; + onActorStart: ( actorId: string, generation: number, config: ActorConfig, ) => Promise; + onActorStop: (actorId: string, generation: number) => Promise; - getActorHibernationConfig: ( - actorId: string, - requestId: ArrayBuffer, - request: Request, - ) => HibernationConfig; noAutoShutdown?: boolean; } -export interface HibernationConfig { - enabled: boolean; - lastMsgIndex: number | undefined; -} - export interface KvListOptions { reverse?: boolean; limit?: number; @@ -105,7 +219,6 @@ export class Runner { } #actors: Map = new Map(); - #actorWebSockets: Map> = new Map(); // WebSocket #pegboardWebSocket?: WebSocket; @@ -809,6 +922,8 @@ export class Runner { } #handleCommandStartActor(commandWrapper: protocol.CommandWrapper) { + if (!this.#tunnel) throw new Error("missing tunnel on actor start"); + const startCommand = commandWrapper.inner .val as protocol.CommandStartActor; @@ -850,6 +965,11 @@ export class Runner { // Send stopped state update if start failed this.forceStopActor(actorId, generation); }); + + this.#tunnel.restoreHibernatingRequests( + actorId, + startCommand.hibernatingRequestIds, + ); } #handleCommandStopActor(commandWrapper: protocol.CommandWrapper) { @@ -1427,8 +1547,10 @@ export class Runner { } } - sendWebsocketMessageAck(requestId: ArrayBuffer, index: number) { - this.#tunnel?.__ackWebsocketMessage(requestId, index); + sendHibernatableWebSocketMessageAck(requestId: ArrayBuffer, index: number) { + if (!this.#tunnel) + throw new Error("missing tunnel to send message ack"); + this.#tunnel.sendHibernatableWebSocketMessageAck(requestId, index); } getServerlessInitPacket(): string | undefined { diff --git a/engine/sdks/typescript/runner/src/stringify.ts b/engine/sdks/typescript/runner/src/stringify.ts index 699b2745c1..b8c87d0f2c 100644 --- a/engine/sdks/typescript/runner/src/stringify.ts +++ b/engine/sdks/typescript/runner/src/stringify.ts @@ -46,8 +46,8 @@ export function stringifyToServerTunnelMessageKind( case "ToServerResponseAbort": return "ToServerResponseAbort"; case "ToServerWebSocketOpen": { - const { canHibernate, lastMsgIndex } = kind.val; - return `ToServerWebSocketOpen{canHibernate: ${canHibernate}, lastMsgIndex: ${stringifyBigInt(lastMsgIndex)}}`; + const { canHibernate } = kind.val; + return `ToServerWebSocketOpen{canHibernate: ${canHibernate}}`; } case "ToServerWebSocketMessage": { const { data, binary } = kind.val; diff --git a/engine/sdks/typescript/runner/src/tunnel.ts b/engine/sdks/typescript/runner/src/tunnel.ts index d0fa360e84..479f08b793 100644 --- a/engine/sdks/typescript/runner/src/tunnel.ts +++ b/engine/sdks/typescript/runner/src/tunnel.ts @@ -9,11 +9,13 @@ import { stringifyToServerTunnelMessageKind, } from "./stringify"; import { unreachable } from "./utils"; -import { WebSocketTunnelAdapter } from "./websocket-tunnel-adapter"; +import { + WebSocketTunnelAdapter, + HIBERNATABLE_SYMBOL, +} from "./websocket-tunnel-adapter"; const GC_INTERVAL = 60000; // 60 seconds const MESSAGE_ACK_TIMEOUT = 5000; // 5 seconds -const WEBSOCKET_STATE_PERSIST_TIMEOUT = 30000; // 30 seconds interface PendingRequest { resolve: (response: Response) => void; @@ -22,6 +24,13 @@ interface PendingRequest { actorId?: string; } +export interface HibernatingWebSocketMetadata { + requestId: RequestId; + path: string; + headers: Record; + messageIndex: number; +} + interface PendingTunnelMessage { sentAt: number; requestIdStr: string; @@ -84,13 +93,202 @@ export class Tunnel { for (const [_, ws] of this.#actorWebSockets) { // Only close non-hibernatable websockets to prevent sending // unnecessary close messages for websockets that will be hibernated - if (!ws.canHibernate) { - ws.__closeWithoutCallback(1000, "ws.tunnel_shutdown"); + if (!ws[HIBERNATABLE_SYMBOL]) { + ws._closeWithoutCallback(1000, "ws.tunnel_shutdown"); } } this.#actorWebSockets.clear(); } + restoreHibernatingRequests( + actorId: string, + requestIds: readonly RequestId[], + ) { + // Load all persisted metadata + const metaEntries = + this.#runner.config.hibernatableWebSocket.loadAll(actorId); + + // Create maps for efficient lookup + const requestIdMap = new Map(); + for (const requestId of requestIds) { + requestIdMap.set(idToStr(requestId), requestId); + } + + const metaMap = new Map(); + for (const meta of metaEntries) { + metaMap.set(idToStr(meta.requestId), meta); + } + + // Process connected WebSockets + let connectedButNotLoadedCount = 0; + let restoredCount = 0; + for (const [requestIdStr, requestId] of requestIdMap) { + const meta = metaMap.get(requestIdStr); + + if (!meta) { + // Connected but not loaded (not persisted) - close it + // + // This may happen if + this.log?.warn({ + msg: "closing websocket that is not persisted", + requestId: requestIdStr, + }); + + this.#sendMessage(requestId, { + tag: "ToServerWebSocketClose", + val: { + code: 1000, + reason: "ws.meta_not_found_during_restore", + retry: false, + }, + }); + + connectedButNotLoadedCount++; + } else { + // Both connected and persisted - restore it + const request = buildRequestForWebSocket( + meta.path, + meta.headers, + ); + + // This will call `runner.config.websocket` under the hood to + // attach the event listeners to the WebSocket. + this.#createWebSocket( + actorId, + requestId, + requestIdStr, + true, + meta.messageIndex, + request, + false, + ); + + restoredCount++; + } + } + + // Process loaded but not connected (stale) - remove them + let loadedButNotConnectedCount = 0; + for (const [requestIdStr, meta] of metaMap) { + if (!requestIdMap.has(requestIdStr)) { + this.log?.warn({ + msg: "removing stale persisted websocket", + requestId: requestIdStr, + }); + + const request = buildRequestForWebSocket(meta.path, meta.headers); + + // Create adapter to register user's event listeners. + // Pass engineAlreadyClosed=true so close callback won't send tunnel message. + const adapter = this.#createWebSocket( + actorId, + meta.requestId, + requestIdStr, + true, + meta.messageIndex, + request, + true, + ); + + // Close the adapter normally - this will fire user's close event handler + // (which should clean up persistence) and trigger the close callback + // (which will clean up maps but skip sending tunnel message) + adapter.close(1000, "ws.stale_metadata"); + + loadedButNotConnectedCount++; + } + } + + this.log?.info({ + msg: "restored hibernatable websockets", + actorId, + restoredCount, + connectedButNotLoadedCount, + loadedButNotConnectedCount, + }); + } + + /** + * Called from WebSocketOpen message and when restoring hibernatable WebSockets. + */ + #createWebSocket( + actorId: string, + requestId: RequestId, + requestIdStr: string, + hibernatable: boolean, + messageIndex: number, + request: Request, + engineAlreadyClosed: boolean, + ): WebSocketTunnelAdapter { + // Create WebSocket adapter + const adapter = new WebSocketTunnelAdapter( + this, + actorId, + requestIdStr, + hibernatable, + messageIndex, + request, + (messageIndex: number) => { + this.#runner.config.hibernatableWebSocket.persistMessageIndex( + actorId, + requestId, + messageIndex, + ); + }, + (data: ArrayBuffer | string, isBinary: boolean) => { + // Send message through tunnel + const dataBuffer = + typeof data === "string" + ? (new TextEncoder().encode(data).buffer as ArrayBuffer) + : data; + + this.#sendMessage(requestId, { + tag: "ToServerWebSocketMessage", + val: { + data: dataBuffer, + binary: isBinary, + }, + }); + }, + (code?: number, reason?: string, retry: boolean = false) => { + // Send close through tunnel if engine doesn't already know it's closed + if (!engineAlreadyClosed) { + this.#sendMessage(requestId, { + tag: "ToServerWebSocketClose", + val: { + code: code || null, + reason: reason || null, + retry, + }, + }); + } + + // Remove from map + this.#actorWebSockets.delete(requestIdStr); + + // Clean up actor tracking + const actor = this.#runner.getActor(actorId); + if (actor) { + actor.webSockets.delete(requestIdStr); + } + }, + ); + + this.#actorWebSockets.set(requestIdStr, adapter); + + // Call WebSocket handler. This handler will add event listeners + // for `open`, etc. + this.#runner.config.websocket( + this.#runner, + actorId, + adapter, + requestId, + request, + ); + + return adapter; + } + #sendMessage( requestId: RequestId, messageKind: protocol.ToServerTunnelMessageKind, @@ -202,7 +400,7 @@ export class Tunnel { const webSocket = this.#actorWebSockets.get(requestIdStr); if (webSocket) { // Close the WebSocket connection - webSocket.__closeWithRetry( + webSocket._closeWithRetry( 1000, "Message acknowledgment timeout", ); @@ -242,7 +440,7 @@ export class Tunnel { for (const requestIdStr of actor.webSockets) { const ws = this.#actorWebSockets.get(requestIdStr); if (ws) { - ws.__closeWithRetry(1000, "Actor stopped"); + ws._closeWithRetry(1000, "Actor stopped"); this.#actorWebSockets.delete(requestIdStr); } } @@ -343,7 +541,7 @@ export class Tunnel { case "ToClientWebSocketMessage": { this.#sendAck(message.requestId, message.messageId); - const _unhandled = await this.#handleWebSocketMessage( + this.#handleWebSocketMessage( message.requestId, message.messageKind.val, ); @@ -532,6 +730,12 @@ export class Tunnel { requestId: protocol.RequestId, open: protocol.ToClientWebSocketOpen, ) { + // NOTE: This method is safe to be async since we will not receive any + // further WebSocket events until we send a ToServerWebSocketOpen + // tunnel message. We can do any async logic we need to between thoes two events. + // + // Sedning a ToServerWebSocketClose will terminate the WebSocket early. + const requestIdStr = idToStr(requestId); // Validate actor exists @@ -558,31 +762,9 @@ export class Tunnel { return; } - const websocketHandler = this.#runner.config.websocket; - - if (!websocketHandler) { - this.log?.error({ - msg: "no websocket handler configured for tunnel", - }); - // Send close immediately - this.#sendMessage(requestId, { - tag: "ToServerWebSocketClose", - val: { - code: 1011, - reason: "Not Implemented", - retry: false, - }, - }); - return; - } - // Close existing WebSocket if one already exists for this request ID. - // There should always be a close message sent before another open - // message for the same message ID. - // - // This should never occur if all is functioning correctly, but this - // prevents any edge case that would result in duplicate WebSockets for - // the same request. + // This should never happen, but prevents any potential duplicate + // WebSockets from retransmits. const existingAdapter = this.#actorWebSockets.get(requestIdStr); if (existingAdapter) { this.log?.warn({ @@ -591,109 +773,56 @@ export class Tunnel { }); // Close without sending a message through the tunnel since the server // already knows about the new connection - existingAdapter.__closeWithoutCallback(1000, "ws.duplicate_open"); - } - - // Track this WebSocket for the actor - if (actor) { - actor.webSockets.add(requestIdStr); + existingAdapter._closeWithoutCallback(1000, "ws.duplicate_open"); } + // Create WebSocket try { - // Create WebSocket adapter - const adapter = new WebSocketTunnelAdapter( - requestIdStr, - (data: ArrayBuffer | string, isBinary: boolean) => { - // Send message through tunnel - const dataBuffer = - typeof data === "string" - ? (new TextEncoder().encode(data) - .buffer as ArrayBuffer) - : data; - - this.#sendMessage(requestId, { - tag: "ToServerWebSocketMessage", - val: { - data: dataBuffer, - binary: isBinary, - }, - }); - }, - (code?: number, reason?: string, retry: boolean = false) => { - // Send close through tunnel - this.#sendMessage(requestId, { - tag: "ToServerWebSocketClose", - val: { - code: code || null, - reason: reason || null, - retry, - }, - }); - - // Remove from map - this.#actorWebSockets.delete(requestIdStr); + actor.webSockets.add(requestIdStr); - // Clean up actor tracking - if (actor) { - actor.webSockets.delete(requestIdStr); - } - }, + const request = buildRequestForWebSocket( + open.path, + Object.fromEntries(open.headers), ); - // Store adapter - this.#actorWebSockets.set(requestIdStr, adapter); - - // Convert headers to map - // - // We need to manually ensure the original Upgrade/Connection WS - // headers are present - const headerInit: Record = {}; - if (open.headers) { - for (const [k, v] of open.headers as ReadonlyMap< - string, - string - >) { - headerInit[k] = v; - } - } - headerInit["Upgrade"] = "websocket"; - headerInit["Connection"] = "Upgrade"; - - const request = new Request(`http://localhost${open.path}`, { - method: "GET", - headers: headerInit, - }); - - // Send open confirmation - const hibernationConfig = - this.#runner.config.getActorHibernationConfig( + const canHibernate = + this.#runner.config.hibernatableWebSocket.canHibernate( actor.actorId, requestId, request, ); - adapter.canHibernate = hibernationConfig.enabled; + // #createWebSocket will call `runner.config.websocket` under the + // hood to add the event listeners for open, etc. If this handler + // throws, then the WebSocket will be closed before sending the + // open event. + const adapter = this.#createWebSocket( + actor.actorId, + requestId, + requestIdStr, + canHibernate, + -1, + request, + false, + ); + + // Open the WebSocket after `config.socket` so (a) the event + // handlers can be added and (b) any errors in `config.websocket` + // will cause the WebSocket to terminate before the open event. this.#sendMessage(requestId, { tag: "ToServerWebSocketOpen", val: { - canHibernate: hibernationConfig.enabled, - lastMsgIndex: BigInt(hibernationConfig.lastMsgIndex ?? -1), + canHibernate, }, }); - // Notify adapter that connection is open + // Dispatch open event adapter._handleOpen(requestId); - - // Call websocket handler - await websocketHandler( - this.#runner, - open.actorId, - adapter, - requestId, - request, - ); } catch (error) { this.log?.error({ msg: "error handling websocket open", error }); + + // TODO: Call close event on adapter if needed + // Send close on error this.#sendMessage(requestId, { tag: "ToServerWebSocketClose", @@ -713,11 +842,13 @@ export class Tunnel { } } - /// Returns false if the message was sent off - async #handleWebSocketMessage( + #handleWebSocketMessage( requestId: ArrayBuffer, msg: protocol.ToClientWebSocketMessage, - ): Promise { + ) { + // NOTE: This method cannot be async in order to ensure in-order + // message processing. + const requestIdStr = idToStr(requestId); const adapter = this.#actorWebSockets.get(requestIdStr); if (adapter) { @@ -725,18 +856,16 @@ export class Tunnel { ? new Uint8Array(msg.data) : new TextDecoder().decode(new Uint8Array(msg.data)); - return adapter._handleMessage( - requestId, - data, - msg.index, - msg.binary, - ); + adapter._handleMessage(requestId, data, msg.index, msg.binary); } else { - return true; + this.log?.warn({ + msg: "missing websocket for incoming websocket message", + requestId, + }); } } - __ackWebsocketMessage(requestId: ArrayBuffer, index: number) { + sendHibernatableWebSocketMessageAck(requestId: ArrayBuffer, index: number) { this.log?.debug({ msg: "ack ws msg", requestId: idToStr(requestId), @@ -782,3 +911,32 @@ function generateUuidBuffer(): ArrayBuffer { function idToStr(id: ArrayBuffer): string { return uuidstringify(new Uint8Array(id)); } + +/** + * Builds a request that represents the incoming request for a given WebSocket. + * + * This request is not a real request and will never be sent. It's used to be passed to the actor to behave like a real incoming request. + */ +function buildRequestForWebSocket( + path: string, + headers: Record, +): Request { + // We need to manually ensure the original Upgrade/Connection WS + // headers are present + const fullHeaders = { + ...headers, + Upgrade: "websocket", + Connection: "Upgrade", + }; + + if (!path.startsWith("/")) { + throw new Error("path must start with leading slash"); + } + + const request = new Request(`http://actor${path}`, { + method: "GET", + headers: fullHeaders, + }); + + return request; +} diff --git a/engine/sdks/typescript/runner/src/utils.ts b/engine/sdks/typescript/runner/src/utils.ts index 4bf6693d26..c6a9c5e7b3 100644 --- a/engine/sdks/typescript/runner/src/utils.ts +++ b/engine/sdks/typescript/runner/src/utils.ts @@ -64,3 +64,60 @@ export function parseWebSocketCloseReason( rayId, }; } + +const U16_MAX = 65535; + +/** + * Wrapping greater than comparison for u16 values. + * Based on shared_state.rs wrapping_gt implementation. + */ +export function wrappingGtU16(a: number, b: number): boolean { + return a !== b && wrappingSub(a, b, U16_MAX) < U16_MAX / 2; +} + +/** + * Wrapping less than comparison for u16 values. + * Based on shared_state.rs wrapping_lt implementation. + */ +export function wrappingLtU16(a: number, b: number): boolean { + return a !== b && wrappingSub(b, a, U16_MAX) < U16_MAX / 2; +} + +/** + * Wrapping greater than or equal comparison for u16 values. + */ +export function wrappingGteU16(a: number, b: number): boolean { + return a === b || wrappingGtU16(a, b); +} + +/** + * Wrapping less than or equal comparison for u16 values. + */ +export function wrappingLteU16(a: number, b: number): boolean { + return a === b || wrappingLtU16(a, b); +} + +/** + * Performs wrapping addition for u16 values. + */ +export function wrappingAddU16(a: number, b: number): number { + return (a + b) % (U16_MAX + 1); +} + +/** + * Performs wrapping subtraction for u16 values. + */ +export function wrappingSubU16(a: number, b: number): number { + return wrappingSub(a, b, U16_MAX); +} + +/** + * Performs wrapping subtraction for unsigned integers. + */ +function wrappingSub(a: number, b: number, max: number): number { + const result = a - b; + if (result < 0) { + return result + max + 1; + } + return result; +} diff --git a/engine/sdks/typescript/runner/src/websocket-tunnel-adapter.ts b/engine/sdks/typescript/runner/src/websocket-tunnel-adapter.ts index ddebb72bc8..13a1480732 100644 --- a/engine/sdks/typescript/runner/src/websocket-tunnel-adapter.ts +++ b/engine/sdks/typescript/runner/src/websocket-tunnel-adapter.ts @@ -1,10 +1,15 @@ // WebSocket-like adapter for tunneled connections // Implements a subset of the WebSocket interface for compatibility with runner code +import type { Logger } from "pino"; import { logger } from "./log"; +import type { Tunnel } from "./tunnel"; +import { wrappingLteU16, wrappingAddU16, wrappingSubU16 } from "./utils"; + +export const HIBERNATABLE_SYMBOL = Symbol("hibernatable"); export class WebSocketTunnelAdapter { - #webSocketId: string; + // MARK: - WebSocket Compat Variables #readyState: number = 0; // CONNECTING #eventListeners: Map void>> = new Map(); #onopen: ((this: any, ev: any) => any) | null = null; @@ -16,18 +21,55 @@ export class WebSocketTunnelAdapter { #extensions = ""; #protocol = ""; #url = ""; + + // mARK: - Internal State + #tunnel: Tunnel; + #actorId: string; + #requestId: string; + #hibernatable: boolean; + #messageIndex: number; + + get [HIBERNATABLE_SYMBOL](): boolean { + return this.#hibernatable; + } + + /** + * Called when a new message index is received for a HWS. This should + * persist the message index somewhere in order to restore the WebSocket. + * + * The receiver of this is in charge of sending Runner::sendHibernatableWebSocketMessageAck. + */ + #persistHibernatableWebSocketMessageIndex: (messageIndex: number) => void; + + /** + * Called when sending a message from this WebSocket. + * + * Used to send a tunnel message from Tunnel. + */ #sendCallback: (data: ArrayBuffer | string, isBinary: boolean) => void; + + /** + * Called when closing this WebSocket. + * + * Used to send a tunnel message from Tunnel + */ #closeCallback: (code?: number, reason?: string, retry?: boolean) => void; - #canHibernate: boolean = false; - // Event buffering for events fired before listeners are attached - #bufferedEvents: Array<{ - type: string; - event: any; - }> = []; + get #log(): Logger | undefined { + return this.#tunnel.log; + } constructor( - webSocketId: string, + tunnel: Tunnel, + actorId: string, + requestId: string, + hibernatable: boolean, + messageIndex: number, + /** @experimental */ + public readonly request: Request, + persistHibernatableWebSocketMessageIndex: ( + messageIndex: number, + ) => void, sendCallback: (data: ArrayBuffer | string, isBinary: boolean) => void, closeCallback: ( code?: number, @@ -35,19 +77,292 @@ export class WebSocketTunnelAdapter { retry?: boolean, ) => void, ) { - this.#webSocketId = webSocketId; + this.#tunnel = tunnel; + this.#actorId = actorId; + this.#requestId = requestId; + this.#hibernatable = hibernatable; + this.#messageIndex = messageIndex; + this.#persistHibernatableWebSocketMessageIndex = + persistHibernatableWebSocketMessageIndex; this.#sendCallback = sendCallback; this.#closeCallback = closeCallback; } - get readyState(): number { - return this.#readyState; - } - + // MARK: - Lifecycle get bufferedAmount(): number { return this.#bufferedAmount; } + _handleOpen(requestId: ArrayBuffer): void { + if (this.#readyState !== 0) { + // CONNECTING + return; + } + + this.#readyState = 1; // OPEN + + const event = { + type: "open", + rivetRequestId: requestId, + target: this, + }; + + this.#fireEvent("open", event); + } + + _handleMessage( + requestId: ArrayBuffer, + data: string | Uint8Array, + messageIndex: number, + isBinary: boolean, + ): boolean { + if (this.#readyState !== 1) { + return true; + } + + const previousIndex = this.#messageIndex; + + // Ignore duplicate old messages + // + // This should only happen if something goes wrong + // between persisting the previous index and acking the + // message index to the gateway. If the ack is never + // received by the gateway (due to a crash or network + // issue), the gateway will resend all messages from + // the last ack on reconnect. + if (wrappingLteU16(messageIndex, previousIndex)) { + this.#log?.info({ + msg: "received duplicate hibernating websocket message, this indicates the actor failed to ack the message index before restarting", + requestId, + actorId: this.#actorId, + previousIndex, + expectedIndex: wrappingAddU16(previousIndex, 1), + receivedIndex: messageIndex, + }); + + return true; + } + + // Close message if skipped message in sequence + // + // There is no scenario where this should ever happen + const expectedIndex = wrappingAddU16(previousIndex, 1); + if (messageIndex !== expectedIndex) { + const closeReason = "ws.message_index_skip"; + + this.#log?.warn({ + msg: "hibernatable websocket message index out of sequence, closing connection", + requestId, + actorId: this.#actorId, + previousIndex, + expectedIndex, + receivedIndex: messageIndex, + closeReason, + gap: wrappingSubU16( + wrappingSubU16(messageIndex, previousIndex), + 1, + ), + }); + + // Close the WebSocket and skip processing + this.close(1008, closeReason); + + return true; + } + + // Update to the next index + this.#messageIndex = messageIndex; + if (this.#hibernatable) { + this.#persistHibernatableWebSocketMessageIndex(messageIndex); + } + + // Dispatch event + let messageData: any; + if (isBinary) { + // Handle binary data based on binaryType + if (this.#binaryType === "nodebuffer") { + // Convert to Buffer for Node.js compatibility + messageData = Buffer.from(data as Uint8Array); + } else if (this.#binaryType === "arraybuffer") { + // Convert to ArrayBuffer + if (data instanceof Uint8Array) { + messageData = data.buffer.slice( + data.byteOffset, + data.byteOffset + data.byteLength, + ); + } else { + messageData = data; + } + } else { + // Blob type - not commonly used in Node.js + throw new Error( + "Blob binaryType not supported in tunnel adapter", + ); + } + } else { + messageData = data; + } + + const event = { + type: "message", + data: messageData, + rivetRequestId: requestId, + rivetMessageIndex: messageIndex, + target: this, + }; + + this.#fireEvent("message", event); + + return false; + } + + _handleClose(requestId: ArrayBuffer, code?: number, reason?: string): void { + if (this.#readyState === 3) { + return; + } + + this.#readyState = 3; + + const event = { + type: "close", + wasClean: true, + code: code || 1000, + reason: reason || "", + rivetRequestId: requestId, + target: this, + }; + + this.#fireEvent("close", event); + } + + _handleError(error: Error): void { + const event = { + type: "error", + target: this, + error, + }; + + this.#fireEvent("error", event); + } + + _closeWithRetry(code?: number, reason?: string): void { + this.#closeInner(code, reason, true, true); + } + + _closeWithoutCallback(code?: number, reason?: string): void { + this.#closeInner(code, reason, false, false); + } + + #fireEvent(type: string, event: any): void { + // Call all registered event listeners + const listeners = this.#eventListeners.get(type); + + if (listeners && listeners.size > 0) { + for (const listener of listeners) { + try { + listener.call(this, event); + } catch (error) { + logger()?.error({ + msg: "error in websocket event listener", + error, + type, + }); + } + } + } + + // Call the onX property if set + switch (type) { + case "open": + if (this.#onopen) { + try { + this.#onopen.call(this, event); + } catch (error) { + logger()?.error({ + msg: "error in onopen handler", + error, + }); + } + } + break; + case "close": + if (this.#onclose) { + try { + this.#onclose.call(this, event); + } catch (error) { + logger()?.error({ + msg: "error in onclose handler", + error, + }); + } + } + break; + case "error": + if (this.#onerror) { + try { + this.#onerror.call(this, event); + } catch (error) { + logger()?.error({ + msg: "error in onerror handler", + error, + }); + } + } + break; + case "message": + if (this.#onmessage) { + try { + this.#onmessage.call(this, event); + } catch (error) { + logger()?.error({ + msg: "error in onmessage handler", + error, + }); + } + } + break; + } + } + + #closeInner( + code: number | undefined, + reason: string | undefined, + retry: boolean, + callback: boolean, + ): void { + if ( + this.#readyState === 2 || // CLOSING + this.#readyState === 3 // CLOSED + ) { + return; + } + + this.#readyState = 2; // CLOSING + + // Send close through tunnel + if (callback) { + this.#closeCallback(code, reason, retry); + } + + // Update state and fire event + this.#readyState = 3; // CLOSED + + const closeEvent = { + wasClean: true, + code: code || 1000, + reason: reason || "", + type: "close", + target: this, + }; + + this.#fireEvent("close", closeEvent); + } + + // MARK: - WebSocket Compatible API + get readyState(): number { + return this.#readyState; + } + get binaryType(): string { return this.#binaryType; } @@ -74,26 +389,12 @@ export class WebSocketTunnelAdapter { return this.#url; } - /** @experimental */ - get canHibernate(): boolean { - return this.#canHibernate; - } - - /** @experimental */ - set canHibernate(value: boolean) { - this.#canHibernate = value; - } - get onopen(): ((this: any, ev: any) => any) | null { return this.#onopen; } set onopen(value: ((this: any, ev: any) => any) | null) { this.#onopen = value; - // Flush any buffered open events when onopen is set - if (value) { - this.#flushBufferedEvents("open"); - } } get onclose(): ((this: any, ev: any) => any) | null { @@ -102,10 +403,6 @@ export class WebSocketTunnelAdapter { set onclose(value: ((this: any, ev: any) => any) | null) { this.#onclose = value; - // Flush any buffered close events when onclose is set - if (value) { - this.#flushBufferedEvents("close"); - } } get onerror(): ((this: any, ev: any) => any) | null { @@ -114,10 +411,6 @@ export class WebSocketTunnelAdapter { set onerror(value: ((this: any, ev: any) => any) | null) { this.#onerror = value; - // Flush any buffered error events when onerror is set - if (value) { - this.#flushBufferedEvents("error"); - } } get onmessage(): ((this: any, ev: any) => any) | null { @@ -126,10 +419,6 @@ export class WebSocketTunnelAdapter { set onmessage(value: ((this: any, ev: any) => any) | null) { this.#onmessage = value; - // Flush any buffered message events when onmessage is set - if (value) { - this.#flushBufferedEvents("message"); - } } send(data: string | ArrayBuffer | ArrayBufferView | Blob | Buffer): void { @@ -201,49 +490,7 @@ export class WebSocketTunnelAdapter { } close(code?: number, reason?: string): void { - this.closeInner(code, reason, false, true); - } - - __closeWithRetry(code?: number, reason?: string): void { - this.closeInner(code, reason, true, true); - } - - __closeWithoutCallback(code?: number, reason?: string): void { - this.closeInner(code, reason, false, false); - } - - closeInner( - code: number | undefined, - reason: string | undefined, - retry: boolean, - callback: boolean, - ): void { - if ( - this.#readyState === 2 || // CLOSING - this.#readyState === 3 // CLOSED - ) { - return; - } - - this.#readyState = 2; // CLOSING - - // Send close through tunnel - if (callback) { - this.#closeCallback(code, reason, retry); - } - - // Update state and fire event - this.#readyState = 3; // CLOSED - - const closeEvent = { - wasClean: true, - code: code || 1000, - reason: reason || "", - type: "close", - target: this, - }; - - this.#fireEvent("close", closeEvent); + this.#closeInner(code, reason, false, true); } addEventListener( @@ -258,9 +505,6 @@ export class WebSocketTunnelAdapter { this.#eventListeners.set(type, listeners); } listeners.add(listener); - - // Flush any buffered events for this type - this.#flushBufferedEvents(type); } } @@ -278,278 +522,15 @@ export class WebSocketTunnelAdapter { } dispatchEvent(event: any): boolean { - // Simple implementation + // TODO: return true; } - #fireEvent(type: string, event: any): void { - // Call all registered event listeners - const listeners = this.#eventListeners.get(type); - let hasListeners = false; - - if (listeners && listeners.size > 0) { - hasListeners = true; - for (const listener of listeners) { - try { - listener.call(this, event); - } catch (error) { - logger()?.error({ - msg: "error in websocket event listener", - error, - type, - }); - } - } - } - - // Call the onX property if set - switch (type) { - case "open": - if (this.#onopen) { - hasListeners = true; - try { - this.#onopen.call(this, event); - } catch (error) { - logger()?.error({ - msg: "error in onopen handler", - error, - }); - } - } - break; - case "close": - if (this.#onclose) { - hasListeners = true; - try { - this.#onclose.call(this, event); - } catch (error) { - logger()?.error({ - msg: "error in onclose handler", - error, - }); - } - } - break; - case "error": - if (this.#onerror) { - hasListeners = true; - try { - this.#onerror.call(this, event); - } catch (error) { - logger()?.error({ - msg: "error in onerror handler", - error, - }); - } - } - break; - case "message": - if (this.#onmessage) { - hasListeners = true; - try { - this.#onmessage.call(this, event); - } catch (error) { - logger()?.error({ - msg: "error in onmessage handler", - error, - }); - } - } - break; - } - - // Buffer the event if no listeners are registered - if (!hasListeners) { - this.#bufferedEvents.push({ type, event }); - } - } - - #flushBufferedEvents(type: string): void { - const eventsToFlush = this.#bufferedEvents.filter( - (buffered) => buffered.type === type, - ); - this.#bufferedEvents = this.#bufferedEvents.filter( - (buffered) => buffered.type !== type, - ); - - for (const { event } of eventsToFlush) { - // Re-fire the event, which will now have listeners - const listeners = this.#eventListeners.get(type); - if (listeners) { - for (const listener of listeners) { - try { - listener.call(this, event); - } catch (error) { - logger()?.error({ - msg: "error in websocket event listener", - error, - type, - }); - } - } - } - - // Also call the onX handler if it exists - switch (type) { - case "open": - if (this.#onopen) { - try { - this.#onopen.call(this, event); - } catch (error) { - logger()?.error({ - msg: "error in onopen handler", - error, - }); - } - } - break; - case "close": - if (this.#onclose) { - try { - this.#onclose.call(this, event); - } catch (error) { - logger()?.error({ - msg: "error in onclose handler", - error, - }); - } - } - break; - case "error": - if (this.#onerror) { - try { - this.#onerror.call(this, event); - } catch (error) { - logger()?.error({ - msg: "error in onerror handler", - error, - }); - } - } - break; - case "message": - if (this.#onmessage) { - try { - this.#onmessage.call(this, event); - } catch (error) { - logger()?.error({ - msg: "error in onmessage handler", - error, - }); - } - } - break; - } - } - } - - // Internal methods called by the Tunnel class - _handleOpen(requestId: ArrayBuffer): void { - if (this.#readyState !== 0) { - // CONNECTING - return; - } - - this.#readyState = 1; // OPEN - - const event = { - type: "open", - rivetRequestId: requestId, - target: this, - }; - - this.#fireEvent("open", event); - } - - /// Returns false if the message was sent off. - _handleMessage( - requestId: ArrayBuffer, - data: string | Uint8Array, - index: number, - isBinary: boolean, - ): boolean { - if (this.#readyState !== 1) { - // OPEN - return true; - } - - let messageData: any; - - if (isBinary) { - // Handle binary data based on binaryType - if (this.#binaryType === "nodebuffer") { - // Convert to Buffer for Node.js compatibility - messageData = Buffer.from(data as Uint8Array); - } else if (this.#binaryType === "arraybuffer") { - // Convert to ArrayBuffer - if (data instanceof Uint8Array) { - messageData = data.buffer.slice( - data.byteOffset, - data.byteOffset + data.byteLength, - ); - } else { - messageData = data; - } - } else { - // Blob type - not commonly used in Node.js - throw new Error( - "Blob binaryType not supported in tunnel adapter", - ); - } - } else { - messageData = data; - } - - const event = { - type: "message", - data: messageData, - rivetRequestId: requestId, - rivetMessageIndex: index, - target: this, - }; - - this.#fireEvent("message", event); - - return false; - } - - _handleClose(requestId: ArrayBuffer, code?: number, reason?: string): void { - if (this.#readyState === 3) { - // CLOSED - return; - } - - this.#readyState = 3; // CLOSED - - const event = { - type: "close", - wasClean: true, - code: code || 1000, - reason: reason || "", - rivetRequestId: requestId, - target: this, - }; - - this.#fireEvent("close", event); - } - - _handleError(error: Error): void { - const event = { - type: "error", - target: this, - error, - }; - - this.#fireEvent("error", event); - } - - // WebSocket constants for compatibility static readonly CONNECTING = 0; static readonly OPEN = 1; static readonly CLOSING = 2; static readonly CLOSED = 3; - // Instance constants readonly CONNECTING = 0; readonly OPEN = 1; readonly CLOSING = 2; @@ -566,6 +547,7 @@ export class WebSocketTunnelAdapter { if (cb) cb(new Error("Pong not supported in tunnel adapter")); } + /** @experimental */ terminate(): void { // Immediate close without close frame this.#readyState = 3; // CLOSED diff --git a/engine/sdks/typescript/runner/tests/utils.test.ts b/engine/sdks/typescript/runner/tests/utils.test.ts new file mode 100644 index 0000000000..6259921683 --- /dev/null +++ b/engine/sdks/typescript/runner/tests/utils.test.ts @@ -0,0 +1,194 @@ +import { describe, expect, it } from "vitest"; +import { + wrappingGteU16, + wrappingGtU16, + wrappingLteU16, + wrappingLtU16, +} from "../src/utils"; + +describe("wrappingGtU16", () => { + it("should return true when a > b in normal case", () => { + expect(wrappingGtU16(100, 50)).toBe(true); + expect(wrappingGtU16(1000, 999)).toBe(true); + }); + + it("should return false when a < b in normal case", () => { + expect(wrappingGtU16(50, 100)).toBe(false); + expect(wrappingGtU16(999, 1000)).toBe(false); + }); + + it("should return false when a == b", () => { + expect(wrappingGtU16(100, 100)).toBe(false); + expect(wrappingGtU16(0, 0)).toBe(false); + expect(wrappingGtU16(65535, 65535)).toBe(false); + }); + + it("should handle wrapping around u16 max", () => { + // When values wrap around, 1 is "greater than" 65535 + expect(wrappingGtU16(1, 65535)).toBe(true); + expect(wrappingGtU16(100, 65500)).toBe(true); + }); + + it("should handle edge cases near u16 boundaries", () => { + // 65535 is not greater than 0 (wrapped) + expect(wrappingGtU16(65535, 0)).toBe(false); + // But 0 is greater than 65535 if we consider wrapping + expect(wrappingGtU16(0, 65535)).toBe(true); + }); + + it("should handle values at exactly half the range", () => { + // U16_MAX / 2 = 32767.5, so values with distance <= 32767 return true + const lessThanHalf = 32766; + expect(wrappingGtU16(lessThanHalf, 0)).toBe(true); + expect(wrappingGtU16(0, lessThanHalf)).toBe(false); + + // At distance 32767, still less than 32767.5, so comparison returns true + const atHalfDistance = 32767; + expect(wrappingGtU16(atHalfDistance, 0)).toBe(true); + expect(wrappingGtU16(0, atHalfDistance)).toBe(false); + + // At distance 32768, greater than 32767.5, so comparison returns false + const overHalfDistance = 32768; + expect(wrappingGtU16(overHalfDistance, 0)).toBe(false); + expect(wrappingGtU16(0, overHalfDistance)).toBe(false); + }); +}); + +describe("wrappingLtU16", () => { + it("should return true when a < b in normal case", () => { + expect(wrappingLtU16(50, 100)).toBe(true); + expect(wrappingLtU16(999, 1000)).toBe(true); + }); + + it("should return false when a > b in normal case", () => { + expect(wrappingLtU16(100, 50)).toBe(false); + expect(wrappingLtU16(1000, 999)).toBe(false); + }); + + it("should return false when a == b", () => { + expect(wrappingLtU16(100, 100)).toBe(false); + expect(wrappingLtU16(0, 0)).toBe(false); + expect(wrappingLtU16(65535, 65535)).toBe(false); + }); + + it("should handle wrapping around u16 max", () => { + // When values wrap around, 65535 is "less than" 1 + expect(wrappingLtU16(65535, 1)).toBe(true); + expect(wrappingLtU16(65500, 100)).toBe(true); + }); + + it("should handle edge cases near u16 boundaries", () => { + // 0 is not less than 65535 (wrapped) + expect(wrappingLtU16(0, 65535)).toBe(false); + // But 65535 is less than 0 if we consider wrapping + expect(wrappingLtU16(65535, 0)).toBe(true); + }); + + it("should handle values at exactly half the range", () => { + // U16_MAX / 2 = 32767.5, so values with distance <= 32767 return true + const lessThanHalf = 32766; + expect(wrappingLtU16(0, lessThanHalf)).toBe(true); + expect(wrappingLtU16(lessThanHalf, 0)).toBe(false); + + // At distance 32767, still less than 32767.5, so comparison returns true + const atHalfDistance = 32767; + expect(wrappingLtU16(0, atHalfDistance)).toBe(true); + expect(wrappingLtU16(atHalfDistance, 0)).toBe(false); + + // At distance 32768, greater than 32767.5, so comparison returns false + const overHalfDistance = 32768; + expect(wrappingLtU16(0, overHalfDistance)).toBe(false); + expect(wrappingLtU16(overHalfDistance, 0)).toBe(false); + }); +}); + +describe("wrappingGtU16 and wrappingLtU16 consistency", () => { + it("should be inverse of each other for different values", () => { + const testCases: [number, number][] = [ + [100, 200], + [200, 100], + [0, 65535], + [65535, 0], + [1, 65534], + [32767, 32768], + ]; + + for (const [a, b] of testCases) { + const gt = wrappingGtU16(a, b); + const lt = wrappingLtU16(a, b); + const eq = a === b; + + // For any pair, exactly one of gt, lt, or eq should be true + expect(Number(gt) + Number(lt) + Number(eq)).toBe(1); + } + }); + + it("should satisfy transitivity for sequential values", () => { + // If we have sequential indices, a < b < c should hold + const a = 100; + const b = 101; + const c = 102; + + expect(wrappingLtU16(a, b)).toBe(true); + expect(wrappingLtU16(b, c)).toBe(true); + expect(wrappingLtU16(a, c)).toBe(true); + }); + + it("should handle sequence across wrap boundary", () => { + // Test a sequence that wraps: 65534, 65535, 0, 1 + const values = [65534, 65535, 0, 1]; + + for (let i = 0; i < values.length - 1; i++) { + expect(wrappingLtU16(values[i], values[i + 1])).toBe(true); + expect(wrappingGtU16(values[i + 1], values[i])).toBe(true); + } + }); +}); + +describe("wrappingGteU16", () => { + it("should return true when a > b", () => { + expect(wrappingGteU16(100, 50)).toBe(true); + expect(wrappingGteU16(1000, 999)).toBe(true); + }); + + it("should return true when a == b", () => { + expect(wrappingGteU16(100, 100)).toBe(true); + expect(wrappingGteU16(0, 0)).toBe(true); + expect(wrappingGteU16(65535, 65535)).toBe(true); + }); + + it("should return false when a < b", () => { + expect(wrappingGteU16(50, 100)).toBe(false); + expect(wrappingGteU16(999, 1000)).toBe(false); + }); + + it("should handle wrapping around u16 max", () => { + expect(wrappingGteU16(1, 65535)).toBe(true); + expect(wrappingGteU16(100, 65500)).toBe(true); + expect(wrappingGteU16(0, 65535)).toBe(true); + }); +}); + +describe("wrappingLteU16", () => { + it("should return true when a < b", () => { + expect(wrappingLteU16(50, 100)).toBe(true); + expect(wrappingLteU16(999, 1000)).toBe(true); + }); + + it("should return true when a == b", () => { + expect(wrappingLteU16(100, 100)).toBe(true); + expect(wrappingLteU16(0, 0)).toBe(true); + expect(wrappingLteU16(65535, 65535)).toBe(true); + }); + + it("should return false when a > b", () => { + expect(wrappingLteU16(100, 50)).toBe(false); + expect(wrappingLteU16(1000, 999)).toBe(false); + }); + + it("should handle wrapping around u16 max", () => { + expect(wrappingLteU16(65535, 1)).toBe(true); + expect(wrappingLteU16(65500, 100)).toBe(true); + expect(wrappingLteU16(65535, 0)).toBe(true); + }); +}); diff --git a/rivetkit-asyncapi/asyncapi.json b/rivetkit-asyncapi/asyncapi.json index 0d9c7c7f98..e6074c7754 100644 --- a/rivetkit-asyncapi/asyncapi.json +++ b/rivetkit-asyncapi/asyncapi.json @@ -1,436 +1,489 @@ { - "asyncapi": "3.0.0", - "info": { - "title": "RivetKit WebSocket Protocol", - "version": "2.0.24-rc.1", - "description": "WebSocket protocol for bidirectional communication between RivetKit clients and actors" - }, - "channels": { - "/gateway/{actorId}/connect": { - "address": "/gateway/{actorId}/connect", - "parameters": { - "actorId": { - "description": "The unique identifier for the actor instance" - } - }, - "messages": { - "toClient": { - "$ref": "#/components/messages/ToClient" - }, - "toServer": { - "$ref": "#/components/messages/ToServer" - } - } - } - }, - "operations": { - "sendToClient": { - "action": "send", - "channel": { - "$ref": "#/channels/~1gateway~1{actorId}~1connect" - }, - "messages": [ - { - "$ref": "#/channels/~1gateway~1{actorId}~1connect/messages/toClient" - } - ], - "summary": "Send messages from server to client", - "description": "Messages sent from the RivetKit actor to connected clients" - }, - "receiveFromClient": { - "action": "receive", - "channel": { - "$ref": "#/channels/~1gateway~1{actorId}~1connect" - }, - "messages": [ - { - "$ref": "#/channels/~1gateway~1{actorId}~1connect/messages/toServer" - } - ], - "summary": "Receive messages from client", - "description": "Messages received by the RivetKit actor from connected clients" - } - }, - "components": { - "messages": { - "ToClient": { - "name": "ToClient", - "title": "Message To Client", - "summary": "A message sent from the server to the client", - "contentType": "application/json", - "payload": { - "type": "object", - "properties": { - "body": { - "anyOf": [ - { - "type": "object", - "properties": { - "tag": { - "type": "string", - "const": "Init" - }, - "val": { - "type": "object", - "properties": { - "actorId": { - "type": "string" - }, - "connectionId": { - "type": "string" - } - }, - "required": [ - "actorId", - "connectionId" - ], - "additionalProperties": false - } - }, - "required": ["tag", "val"], - "additionalProperties": false - }, - { - "type": "object", - "properties": { - "tag": { - "type": "string", - "const": "Error" - }, - "val": { - "type": "object", - "properties": { - "group": { - "type": "string" - }, - "code": { - "type": "string" - }, - "message": { - "type": "string" - }, - "metadata": {}, - "actionId": { - "type": ["integer", "null"] - } - }, - "required": [ - "group", - "code", - "message", - "actionId" - ], - "additionalProperties": false - } - }, - "required": ["tag", "val"], - "additionalProperties": false - }, - { - "type": "object", - "properties": { - "tag": { - "type": "string", - "const": "ActionResponse" - }, - "val": { - "type": "object", - "properties": { - "id": { - "type": "integer", - "format": "int64" - }, - "output": {} - }, - "required": ["id"], - "additionalProperties": false - } - }, - "required": ["tag", "val"], - "additionalProperties": false - }, - { - "type": "object", - "properties": { - "tag": { - "type": "string", - "const": "Event" - }, - "val": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "args": {} - }, - "required": ["name"], - "additionalProperties": false - } - }, - "required": ["tag", "val"], - "additionalProperties": false - } - ] - } - }, - "required": ["body"], - "additionalProperties": false - }, - "examples": [ - { - "name": "Init message", - "summary": "Initial connection message", - "payload": { - "body": { - "tag": "Init", - "val": { - "actorId": "actor_123", - "connectionId": "conn_456" - } - } - } - }, - { - "name": "Error message", - "summary": "Error response", - "payload": { - "body": { - "tag": "Error", - "val": { - "group": "auth", - "code": "unauthorized", - "message": "Authentication failed", - "actionId": null - } - } - } - }, - { - "name": "Action response", - "summary": "Response to an action request", - "payload": { - "body": { - "tag": "ActionResponse", - "val": { - "id": "123", - "output": { - "result": "success" - } - } - } - } - }, - { - "name": "Event", - "summary": "Event broadcast to subscribed clients", - "payload": { - "body": { - "tag": "Event", - "val": { - "name": "stateChanged", - "args": { - "newState": "active" - } - } - } - } - } - ] - }, - "ToServer": { - "name": "ToServer", - "title": "Message To Server", - "summary": "A message sent from the client to the server", - "contentType": "application/json", - "payload": { - "type": "object", - "properties": { - "body": { - "anyOf": [ - { - "type": "object", - "properties": { - "tag": { - "type": "string", - "const": "ActionRequest" - }, - "val": { - "type": "object", - "properties": { - "id": { - "type": "integer", - "format": "int64" - }, - "name": { - "type": "string" - }, - "args": {} - }, - "required": ["id", "name"], - "additionalProperties": false - } - }, - "required": ["tag", "val"], - "additionalProperties": false - }, - { - "type": "object", - "properties": { - "tag": { - "type": "string", - "const": "SubscriptionRequest" - }, - "val": { - "type": "object", - "properties": { - "eventName": { - "type": "string" - }, - "subscribe": { - "type": "boolean" - } - }, - "required": [ - "eventName", - "subscribe" - ], - "additionalProperties": false - } - }, - "required": ["tag", "val"], - "additionalProperties": false - } - ] - } - }, - "required": ["body"], - "additionalProperties": false - }, - "examples": [ - { - "name": "Action request", - "summary": "Request to execute an action", - "payload": { - "body": { - "tag": "ActionRequest", - "val": { - "id": "123", - "name": "updateState", - "args": { - "key": "value" - } - } - } - } - }, - { - "name": "Subscription request", - "summary": "Request to subscribe/unsubscribe from an event", - "payload": { - "body": { - "tag": "SubscriptionRequest", - "val": { - "eventName": "stateChanged", - "subscribe": true - } - } - } - } - ] - } - }, - "schemas": { - "Init": { - "type": "object", - "properties": { - "actorId": { - "type": "string" - }, - "connectionId": { - "type": "string" - } - }, - "required": ["actorId", "connectionId"], - "additionalProperties": false, - "description": "Initial connection message sent from server to client" - }, - "Error": { - "type": "object", - "properties": { - "group": { - "type": "string" - }, - "code": { - "type": "string" - }, - "message": { - "type": "string" - }, - "metadata": {}, - "actionId": { - "type": ["integer", "null"] - } - }, - "required": ["group", "code", "message", "actionId"], - "additionalProperties": false, - "description": "Error message sent from server to client" - }, - "ActionResponse": { - "type": "object", - "properties": { - "id": { - "type": "integer", - "format": "int64" - }, - "output": {} - }, - "required": ["id"], - "additionalProperties": false, - "description": "Response to an action request" - }, - "Event": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "args": {} - }, - "required": ["name"], - "additionalProperties": false, - "description": "Event broadcast to subscribed clients" - }, - "ActionRequest": { - "type": "object", - "properties": { - "id": { - "type": "integer", - "format": "int64" - }, - "name": { - "type": "string" - }, - "args": {} - }, - "required": ["id", "name"], - "additionalProperties": false, - "description": "Request to execute an action on the actor" - }, - "SubscriptionRequest": { - "type": "object", - "properties": { - "eventName": { - "type": "string" - }, - "subscribe": { - "type": "boolean" - } - }, - "required": ["eventName", "subscribe"], - "additionalProperties": false, - "description": "Request to subscribe or unsubscribe from an event" - } - } - } -} + "asyncapi": "3.0.0", + "info": { + "title": "RivetKit WebSocket Protocol", + "version": "2.0.24-rc.1", + "description": "WebSocket protocol for bidirectional communication between RivetKit clients and actors" + }, + "channels": { + "/gateway/{actorId}/connect": { + "address": "/gateway/{actorId}/connect", + "parameters": { + "actorId": { + "description": "The unique identifier for the actor instance" + } + }, + "messages": { + "toClient": { + "$ref": "#/components/messages/ToClient" + }, + "toServer": { + "$ref": "#/components/messages/ToServer" + } + } + } + }, + "operations": { + "sendToClient": { + "action": "send", + "channel": { + "$ref": "#/channels/~1gateway~1{actorId}~1connect" + }, + "messages": [ + { + "$ref": "#/channels/~1gateway~1{actorId}~1connect/messages/toClient" + } + ], + "summary": "Send messages from server to client", + "description": "Messages sent from the RivetKit actor to connected clients" + }, + "receiveFromClient": { + "action": "receive", + "channel": { + "$ref": "#/channels/~1gateway~1{actorId}~1connect" + }, + "messages": [ + { + "$ref": "#/channels/~1gateway~1{actorId}~1connect/messages/toServer" + } + ], + "summary": "Receive messages from client", + "description": "Messages received by the RivetKit actor from connected clients" + } + }, + "components": { + "messages": { + "ToClient": { + "name": "ToClient", + "title": "Message To Client", + "summary": "A message sent from the server to the client", + "contentType": "application/json", + "payload": { + "type": "object", + "properties": { + "body": { + "anyOf": [ + { + "type": "object", + "properties": { + "tag": { + "type": "string", + "const": "Init" + }, + "val": { + "type": "object", + "properties": { + "actorId": { + "type": "string" + }, + "connectionId": { + "type": "string" + } + }, + "required": [ + "actorId", + "connectionId" + ], + "additionalProperties": false + } + }, + "required": [ + "tag", + "val" + ], + "additionalProperties": false + }, + { + "type": "object", + "properties": { + "tag": { + "type": "string", + "const": "Error" + }, + "val": { + "type": "object", + "properties": { + "group": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "metadata": {}, + "actionId": { + "type": [ + "integer", + "null" + ] + } + }, + "required": [ + "group", + "code", + "message", + "actionId" + ], + "additionalProperties": false + } + }, + "required": [ + "tag", + "val" + ], + "additionalProperties": false + }, + { + "type": "object", + "properties": { + "tag": { + "type": "string", + "const": "ActionResponse" + }, + "val": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format": "int64" + }, + "output": {} + }, + "required": [ + "id" + ], + "additionalProperties": false + } + }, + "required": [ + "tag", + "val" + ], + "additionalProperties": false + }, + { + "type": "object", + "properties": { + "tag": { + "type": "string", + "const": "Event" + }, + "val": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "args": {} + }, + "required": [ + "name" + ], + "additionalProperties": false + } + }, + "required": [ + "tag", + "val" + ], + "additionalProperties": false + } + ] + } + }, + "required": [ + "body" + ], + "additionalProperties": false + }, + "examples": [ + { + "name": "Init message", + "summary": "Initial connection message", + "payload": { + "body": { + "tag": "Init", + "val": { + "actorId": "actor_123", + "connectionId": "conn_456" + } + } + } + }, + { + "name": "Error message", + "summary": "Error response", + "payload": { + "body": { + "tag": "Error", + "val": { + "group": "auth", + "code": "unauthorized", + "message": "Authentication failed", + "actionId": null + } + } + } + }, + { + "name": "Action response", + "summary": "Response to an action request", + "payload": { + "body": { + "tag": "ActionResponse", + "val": { + "id": "123", + "output": { + "result": "success" + } + } + } + } + }, + { + "name": "Event", + "summary": "Event broadcast to subscribed clients", + "payload": { + "body": { + "tag": "Event", + "val": { + "name": "stateChanged", + "args": { + "newState": "active" + } + } + } + } + } + ] + }, + "ToServer": { + "name": "ToServer", + "title": "Message To Server", + "summary": "A message sent from the client to the server", + "contentType": "application/json", + "payload": { + "type": "object", + "properties": { + "body": { + "anyOf": [ + { + "type": "object", + "properties": { + "tag": { + "type": "string", + "const": "ActionRequest" + }, + "val": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format": "int64" + }, + "name": { + "type": "string" + }, + "args": {} + }, + "required": [ + "id", + "name" + ], + "additionalProperties": false + } + }, + "required": [ + "tag", + "val" + ], + "additionalProperties": false + }, + { + "type": "object", + "properties": { + "tag": { + "type": "string", + "const": "SubscriptionRequest" + }, + "val": { + "type": "object", + "properties": { + "eventName": { + "type": "string" + }, + "subscribe": { + "type": "boolean" + } + }, + "required": [ + "eventName", + "subscribe" + ], + "additionalProperties": false + } + }, + "required": [ + "tag", + "val" + ], + "additionalProperties": false + } + ] + } + }, + "required": [ + "body" + ], + "additionalProperties": false + }, + "examples": [ + { + "name": "Action request", + "summary": "Request to execute an action", + "payload": { + "body": { + "tag": "ActionRequest", + "val": { + "id": "123", + "name": "updateState", + "args": { + "key": "value" + } + } + } + } + }, + { + "name": "Subscription request", + "summary": "Request to subscribe/unsubscribe from an event", + "payload": { + "body": { + "tag": "SubscriptionRequest", + "val": { + "eventName": "stateChanged", + "subscribe": true + } + } + } + } + ] + } + }, + "schemas": { + "Init": { + "type": "object", + "properties": { + "actorId": { + "type": "string" + }, + "connectionId": { + "type": "string" + } + }, + "required": [ + "actorId", + "connectionId" + ], + "additionalProperties": false, + "description": "Initial connection message sent from server to client" + }, + "Error": { + "type": "object", + "properties": { + "group": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "metadata": {}, + "actionId": { + "type": [ + "integer", + "null" + ] + } + }, + "required": [ + "group", + "code", + "message", + "actionId" + ], + "additionalProperties": false, + "description": "Error message sent from server to client" + }, + "ActionResponse": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format": "int64" + }, + "output": {} + }, + "required": [ + "id" + ], + "additionalProperties": false, + "description": "Response to an action request" + }, + "Event": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "args": {} + }, + "required": [ + "name" + ], + "additionalProperties": false, + "description": "Event broadcast to subscribed clients" + }, + "ActionRequest": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format": "int64" + }, + "name": { + "type": "string" + }, + "args": {} + }, + "required": [ + "id", + "name" + ], + "additionalProperties": false, + "description": "Request to execute an action on the actor" + }, + "SubscriptionRequest": { + "type": "object", + "properties": { + "eventName": { + "type": "string" + }, + "subscribe": { + "type": "boolean" + } + }, + "required": [ + "eventName", + "subscribe" + ], + "additionalProperties": false, + "description": "Request to subscribe or unsubscribe from an event" + } + } + } +} \ No newline at end of file diff --git a/rivetkit-openapi/openapi.json b/rivetkit-openapi/openapi.json index 90803b4757..4ab454fa07 100644 --- a/rivetkit-openapi/openapi.json +++ b/rivetkit-openapi/openapi.json @@ -113,6 +113,7 @@ }, "put": { "requestBody": { + "required": true, "content": { "application/json": { "schema": { @@ -225,6 +226,7 @@ }, "post": { "requestBody": { + "required": true, "content": { "application/json": { "schema": { @@ -385,283 +387,6 @@ } } } - }, - "/gateway/{actorId}/health": { - "get": { - "parameters": [ - { - "name": "actorId", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The ID of the actor to target" - } - ], - "responses": { - "200": { - "description": "Health check", - "content": { - "text/plain": { - "schema": { - "type": "string" - } - } - } - } - } - } - }, - "/gateway/{actorId}/action/{action}": { - "post": { - "parameters": [ - { - "name": "actorId", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The ID of the actor to target" - }, - { - "name": "action", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The name of the action to execute" - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "type": "object", - "properties": { - "args": {} - }, - "additionalProperties": false - } - } - } - }, - "responses": { - "200": { - "description": "Action executed successfully", - "content": { - "application/json": { - "schema": { - "type": "object", - "properties": { - "output": {} - }, - "additionalProperties": false - } - } - } - }, - "400": { - "description": "Invalid action" - }, - "500": { - "description": "Internal error" - } - } - } - }, - "/gateway/{actorId}/request/{path}": { - "get": { - "parameters": [ - { - "name": "actorId", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The ID of the actor to target" - }, - { - "name": "path", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The HTTP path to forward to the actor" - } - ], - "responses": { - "200": { - "description": "Response from actor's raw HTTP handler" - } - } - }, - "post": { - "parameters": [ - { - "name": "actorId", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The ID of the actor to target" - }, - { - "name": "path", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The HTTP path to forward to the actor" - } - ], - "responses": { - "200": { - "description": "Response from actor's raw HTTP handler" - } - } - }, - "put": { - "parameters": [ - { - "name": "actorId", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The ID of the actor to target" - }, - { - "name": "path", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The HTTP path to forward to the actor" - } - ], - "responses": { - "200": { - "description": "Response from actor's raw HTTP handler" - } - } - }, - "delete": { - "parameters": [ - { - "name": "actorId", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The ID of the actor to target" - }, - { - "name": "path", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The HTTP path to forward to the actor" - } - ], - "responses": { - "200": { - "description": "Response from actor's raw HTTP handler" - } - } - }, - "patch": { - "parameters": [ - { - "name": "actorId", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The ID of the actor to target" - }, - { - "name": "path", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The HTTP path to forward to the actor" - } - ], - "responses": { - "200": { - "description": "Response from actor's raw HTTP handler" - } - } - }, - "head": { - "parameters": [ - { - "name": "actorId", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The ID of the actor to target" - }, - { - "name": "path", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The HTTP path to forward to the actor" - } - ], - "responses": { - "200": { - "description": "Response from actor's raw HTTP handler" - } - } - }, - "options": { - "parameters": [ - { - "name": "actorId", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The ID of the actor to target" - }, - { - "name": "path", - "in": "path", - "required": true, - "schema": { - "type": "string" - }, - "description": "The HTTP path to forward to the actor" - } - ], - "responses": { - "200": { - "description": "Response from actor's raw HTTP handler" - } - } - } } } } \ No newline at end of file diff --git a/rivetkit-typescript/contrib-docs/HIBERNATABLE_CONNECTIONS.md b/rivetkit-typescript/contrib-docs/HIBERNATABLE_CONNECTIONS.md new file mode 100644 index 0000000000..7ee0757de5 --- /dev/null +++ b/rivetkit-typescript/contrib-docs/HIBERNATABLE_CONNECTIONS.md @@ -0,0 +1,86 @@ +# Hibernatable Connections + +## Lifecycle + +### New Connection + +```mermaid +sequenceDiagram + participant P as Pegboard + participant R as Runner + participant A as ActorDriver + participant I as Instance + participant D as ActorDefinition + + P->>R: ToClientWebSocketOpen + R->>A: Runner.config.websocket + A->>A: handleWebSocketConnection + A->>I: ConnectionManager.prepareConn + A->>D: ActorDefinition.onBeforeConnect + A->>I: ActorDefinition.createConnState + R->>R: WebSocketAdapter._handleOpen + R->>A: open event + A->>A: ConnectionManager.connectConn + A->>I: ActorDefinition.onConnect + note over A: TODO: persist + R->>P: ToServerWebSocketOpen +``` + +### Restore Connection + + +```mermaid +sequenceDiagram + participant P as Pegboard + participant R as Runner + participant A as ActorDriver + participant I as Instance + + note over P,I: Actor start + P->>R: ToClientCommands (CommandStartActor) + R->>A: Runner.config.restoreHibernatingRequests + note over R,A: TODO: This may be problematic + R->>P: ToServerEvents (ActorStateRunning) + + note over P,I: Actor Start + R->>A: Runner.config.onActorStart + A->>I: Instance.#restoreExistingActor + A->>A: ConnectionManager.restoreConnections + note over A: Restores connections in to memory + + note over P,I: Conn Restoration + R->>R: Tunnel.restoreHibernatingRequests + note over R: Returns existing connections from actor state + R->>A: Runner.config.websocket + A->>A: handleWebSocketConnection + A->>A: ConnectionManager.prepareConn + A->>A: ConnectionManager.#reconnectHibernatableConn +``` + +TODO: Disconnecting stale conns +TODO: Disconnecting zombie conns + +### Persisting Message Index + +```mermaid +sequenceDiagram + participant P as Pegboard + participant R as Runner + participant A as ActorDriver + + R->>R: _handleMessage + R->>A: message event + A->>A: update storage + A->>A: saveAfter(TODO) + note over A: ...after persist... + A->>A: persist + A->>A: afterPersist + A->>R: TODO: ack callback +``` + +### Close Connection + +``` +TODO +``` + diff --git a/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v3.bare b/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v3.bare index 9bbd047387..b339745baf 100644 --- a/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v3.bare +++ b/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v3.bare @@ -1,22 +1,28 @@ +type RequestId data +type Cbor data + # MARK: Connection type Subscription struct { eventName: str } # Connection associated with hibernatable WebSocket that should persist across lifecycles. -type HibernatableConn struct { +type Conn struct { # Connection ID generated by RivetKit id: str - parameters: data - state: data + parameters: Cbor + state: Cbor subscriptions: list # Request ID of the hibernatable WebSocket - hibernatableRequestId: data + hibernatableRequestId: RequestId # Last seen message from this WebSocket lastSeenTimestamp: i64 # Last seem message index for this WebSocket msgIndex: i64 + + requestPath: str + requestHeaders: map } # MARK: Schedule Event @@ -24,15 +30,14 @@ type ScheduleEvent struct { eventId: str timestamp: i64 action: str - args: optional + args: optional } # MARK: Actor type Actor struct { # Input data passed to the actor on initialization - input: optional + input: optional hasInitialized: bool - state: data - hibernatableConns: list + state: Cbor scheduledEvents: list } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts index 14dc4dbf62..63354ad802 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts @@ -2,7 +2,6 @@ import type { WSContext } from "hono/ws"; import type { AnyConn } from "@/actor/conn/mod"; import type { AnyActorInstance } from "@/actor/instance/mod"; import type { CachedSerializer, Encoding } from "@/actor/protocol/serde"; -import type * as protocol from "@/schemas/client-protocol/mod"; import { loggerWithoutContext } from "../../log"; import { type ConnDriver, DriverReadyState } from "../driver"; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts index f81483bb51..6214a23871 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts @@ -9,10 +9,11 @@ import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils"; import type { AnyDatabaseProvider } from "../database"; import { InternalError } from "../errors"; import type { ActorInstance } from "../instance/mod"; -import type { PersistedConn } from "../instance/persisted"; +import type { PersistedConn, RequestId } from "./persisted"; import { CachedSerializer } from "../protocol/serde"; import type { ConnDriver } from "./driver"; -import { StateManager } from "./state-manager"; +import { ConnData, ConnDataInput, StateManager } from "./state-manager"; +import { HibernatingWebSocketMetadata } from "@rivetkit/engine-runner"; export function generateConnRequestId(): string { return crypto.randomUUID(); @@ -24,13 +25,9 @@ export type AnyConn = Conn; export const CONN_CONNECTED_SYMBOL = Symbol("connected"); export const CONN_SPEAKS_RIVETKIT_SYMBOL = Symbol("speaksRivetKit"); -export const CONN_PERSIST_SYMBOL = Symbol("persist"); export const CONN_DRIVER_SYMBOL = Symbol("driver"); export const CONN_ACTOR_SYMBOL = Symbol("actor"); -export const CONN_STATE_ENABLED_SYMBOL = Symbol("stateEnabled"); -export const CONN_PERSIST_RAW_SYMBOL = Symbol("persistRaw"); -export const CONN_HAS_CHANGES_SYMBOL = Symbol("hasChanges"); -export const CONN_MARK_SAVED_SYMBOL = Symbol("markSaved"); +export const CONN_STATE_MANAGER_SYMBOL = Symbol("stateManager"); export const CONN_SEND_MESSAGE_SYMBOL = Symbol("sendMessage"); /** @@ -41,32 +38,39 @@ export const CONN_SEND_MESSAGE_SYMBOL = Symbol("sendMessage"); * @see {@link https://rivet.dev/docs/connections|Connection Documentation} */ export class Conn { - subscriptions: Set = new Set(); - - // TODO: Remove this cyclical reference #actor: ActorInstance; - // MARK: - Managers - #stateManager!: StateManager; - - /** - * If undefined, then nothing is connected to this. - */ - [CONN_DRIVER_SYMBOL]?: ConnDriver; - - // MARK: - Public Getters - get [CONN_ACTOR_SYMBOL](): ActorInstance { return this.#actor; } - /** Connections exist before being connected to an actor. If true, this connection has been connected. */ + #stateManager!: StateManager; + + get [CONN_STATE_MANAGER_SYMBOL]() { + return this.#stateManager; + } + + /** + * Connections exist before being connected to an actor. If true, this + * connection has been connected. + **/ [CONN_CONNECTED_SYMBOL] = false; + /** + * If undefined, then no socket is connected to this conn + */ + [CONN_DRIVER_SYMBOL]?: ConnDriver; + + /** + * If this connection is speaking the RivetKit protocol. If false, this is + * a raw connection for WebSocket or fetch or inspector. + **/ [CONN_SPEAKS_RIVETKIT_SYMBOL](): boolean { return this[CONN_DRIVER_SYMBOL]?.rivetKitProtocol !== undefined; } + subscriptions: Set = new Set(); + #assertConnected() { if (!this[CONN_CONNECTED_SYMBOL]) throw new InternalError( @@ -74,16 +78,9 @@ export class Conn { ); } - get [CONN_PERSIST_SYMBOL](): PersistedConn { - return this.#stateManager.persist; - } - + // MARK: - Public Getters get params(): CP { - return this.#stateManager.params; - } - - get [CONN_STATE_ENABLED_SYMBOL](): boolean { - return this.#stateManager.stateEnabled; + return this.#stateManager.ephemeralData.parameters; } /** @@ -108,7 +105,7 @@ export class Conn { * Unique identifier for the connection. */ get id(): ConnId { - return this.#stateManager.persist.connId; + return this.#stateManager.ephemeralData.id; } /** @@ -117,26 +114,7 @@ export class Conn { * If the underlying connection can hibernate. */ get isHibernatable(): boolean { - const hibernatableRequestId = - this.#stateManager.persist.hibernatableRequestId; - if (!hibernatableRequestId) { - return false; - } - return ( - this.#actor.persist.hibernatableConns.findIndex((conn: any) => - arrayBuffersEqual( - conn.hibernatableRequestId, - hibernatableRequestId, - ), - ) > -1 - ); - } - - /** - * Timestamp of the last time the connection was seen, i.e. the last time the connection was active and checked for liveness. - */ - get lastSeen(): number { - return this.#stateManager.persist.lastSeen; + return this[CONN_DRIVER_SYMBOL]?.hibernatable ?? false; } /** @@ -148,34 +126,15 @@ export class Conn { */ constructor( actor: ActorInstance, - persist: PersistedConn, + data: ConnDataInput, ) { this.#actor = actor; - this.#stateManager = new StateManager(this); - this.#stateManager.initPersistProxy(persist); + this.#stateManager = new StateManager(this, data); } /** - * Returns whether this connection has unsaved changes + * Sends a raw message to the underlying connection. */ - [CONN_HAS_CHANGES_SYMBOL](): boolean { - return this.#stateManager.hasChanges(); - } - - /** - * Marks changes as saved - */ - [CONN_MARK_SAVED_SYMBOL]() { - this.#stateManager.markSaved(); - } - - /** - * Gets the raw persist data for serialization - */ - get [CONN_PERSIST_RAW_SYMBOL](): PersistedConn { - return this.#stateManager.persistRaw; - } - [CONN_SEND_MESSAGE_SYMBOL](message: CachedSerializer) { if (this[CONN_DRIVER_SYMBOL]) { const driver = this[CONN_DRIVER_SYMBOL]; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/persisted.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/persisted.ts new file mode 100644 index 0000000000..4a94a95720 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/persisted.ts @@ -0,0 +1,80 @@ +/** + * Persisted data structures for connections. + * + * Keep this file in sync with the Connection section of rivetkit-typescript/packages/rivetkit/schemas/actor-persist/ + */ + +import type * as persistSchema from "@/schemas/actor-persist/mod"; +import { bufferToArrayBuffer } from "@/utils"; +import * as cbor from "cbor-x"; + +export type RequestId = ArrayBuffer; + +export type Cbor = ArrayBuffer; + +// MARK: Connection +/** Event subscription for connection */ +export interface PersistedSubscription { + eventName: string; +} + +/** Connection associated with hibernatable WebSocket that should persist across lifecycles */ +export interface PersistedConn { + /** Connection ID generated by RivetKit */ + id: string; + parameters: CP; + state: CS; + subscriptions: PersistedSubscription[]; + /** Request ID of the hibernatable WebSocket */ + hibernatableRequestId: RequestId; + /** Last seen message from this WebSocket */ + lastSeenTimestamp: number; + /** Last seen message index for this WebSocket */ + msgIndex: number; + requestPath: string; + requestHeaders: Record; +} + +/** + * Converts persisted connection data to BARE schema format for serialization. + * @throws {Error} If the connection is ephemeral (not hibernatable) + */ +export function convertConnToBarePersistedConn( + persist: PersistedConn, +): persistSchema.Conn { + return { + id: persist.id, + parameters: bufferToArrayBuffer(cbor.encode(persist.parameters)), + state: bufferToArrayBuffer(cbor.encode(persist.state)), + subscriptions: persist.subscriptions.map((sub) => ({ + eventName: sub.eventName, + })), + hibernatableRequestId: persist.hibernatableRequestId, + lastSeenTimestamp: BigInt(persist.lastSeenTimestamp), + msgIndex: BigInt(persist.msgIndex), + requestPath: persist.requestPath, + requestHeaders: new Map(Object.entries(persist.requestHeaders)), + }; +} + +/** + * Converts BARE schema format to persisted connection data. + * @throws {Error} If the connection is ephemeral (not hibernatable) + */ +export function convertConnFromBarePersistedConn( + bareData: persistSchema.Conn, +): PersistedConn { + return { + id: bareData.id, + parameters: cbor.decode(new Uint8Array(bareData.parameters)), + state: cbor.decode(new Uint8Array(bareData.state)), + subscriptions: bareData.subscriptions.map((sub) => ({ + eventName: sub.eventName, + })), + hibernatableRequestId: bareData.hibernatableRequestId, + lastSeenTimestamp: Number(bareData.lastSeenTimestamp), + msgIndex: Number(bareData.msgIndex), + requestPath: bareData.requestPath, + requestHeaders: Object.fromEntries(bareData.requestHeaders), + }; +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts index a79895dea4..07f1aa91b4 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts @@ -1,8 +1,40 @@ import onChange from "on-change"; +import * as cbor from "cbor-x"; import { isCborSerializable } from "@/common/utils"; import * as errors from "../errors"; -import type { PersistedConn } from "../instance/persisted"; -import { CONN_ACTOR_SYMBOL, CONN_STATE_ENABLED_SYMBOL, type Conn } from "./mod"; +import { CONN_ACTOR_SYMBOL, type Conn } from "./mod"; +import type { PersistedConn } from "./persisted"; +import { assertUnreachable } from "../utils"; +import { HibernatingWebSocketMetadata } from "@rivetkit/engine-runner"; +import invariant from "invariant"; +import type * as persistSchema from "@/schemas/actor-persist/mod"; +import { bufferToArrayBuffer } from "@/utils"; + +/** Pick a subset of persisted data used to represent ephemeral connections */ +export type EphemeralConn = Pick< + PersistedConn, + "id" | "parameters" | "state" +>; + +export type ConnDataInput = + | { ephemeral: EphemeralConn } + | { hibernatable: PersistedConn }; + +export type ConnData = + | { + ephemeral: { + /** In-memory data representing this connection */ + data: EphemeralConn; + }; + } + | { + hibernatable: { + /** Persisted data with on-change proxy */ + data: PersistedConn; + /** Raw persisted data without proxy */ + dataRaw: PersistedConn; + }; + }; /** * Manages connection state persistence, proxying, and change tracking. @@ -11,27 +43,84 @@ import { CONN_ACTOR_SYMBOL, CONN_STATE_ENABLED_SYMBOL, type Conn } from "./mod"; export class StateManager { #conn: Conn; - // State tracking - #persist!: PersistedConn; - #persistRaw!: PersistedConn; - #changed = false; + /** + * Data representing this connection. + * + * This is stored as a struct for both ephemeral and hibernatable conns in + * order to keep the separation clear between the two. + */ + #data!: ConnData; + + /** Flagged by on-change if persisted data changes */ + hibernatableDataChanged = false; - constructor(conn: Conn) { + constructor( + conn: Conn, + data: ConnDataInput, + ) { this.#conn = conn; + + if ("ephemeral" in data) { + this.#data = { ephemeral: { data: data.ephemeral } }; + } else if ("hibernatable" in data) { + // Listen for changes to the object + const persistRaw = data.hibernatable; + const persist = onChange( + persistRaw, + ( + path: string, + value: any, + _previousValue: any, + _applyData: any, + ) => { + this.#handleChange(path, value); + }, + { ignoreDetached: true }, + ); + this.#data = { + hibernatable: { data: persist, dataRaw: persistRaw }, + }; + } else { + assertUnreachable(data); + } } - // MARK: - Public API + /** + * Returns the ephemeral or persisted data for this connectioned. + * + * This property is used to be able to treat both memory & persist conns + * identical by looking up the correct underlying data structure. + */ + get ephemeralData(): EphemeralConn { + if ("hibernatable" in this.#data) { + return this.#data.hibernatable.data; + } else if ("ephemeral" in this.#data) { + return this.#data.ephemeral.data; + } else { + return assertUnreachable(this.#data); + } + } - get persist(): PersistedConn { - return this.#persist; + get hibernatableData(): PersistedConn | undefined { + if ("hibernatable" in this.#data) { + return this.#data.hibernatable.data; + } else { + return undefined; + } } - get persistRaw(): PersistedConn { - return this.#persistRaw; + hibernatableDataOrError(): PersistedConn { + const hibernatable = this.hibernatableData; + invariant(hibernatable, "missing hibernatable data"); + return hibernatable; } - get changed(): boolean { - return this.#changed; + get hibernatableDataRaw(): PersistedConn | undefined { + if ("hibernatable" in this.#data) { + return this.#data.hibernatable.dataRaw; + } else { + return undefined; + } } get stateEnabled(): boolean { @@ -40,69 +129,22 @@ export class StateManager { get state(): CS { this.#validateStateEnabled(); - if (!this.#persist.state) throw new Error("state should exists"); - return this.#persist.state; + const state = this.ephemeralData.state; + if (!state) throw new Error("state should exists"); + return state; } set state(value: CS) { this.#validateStateEnabled(); - this.#persist.state = value; - } - - get params(): CP { - return this.#persist.params; - } - - // MARK: - Initialization - - /** - * Creates proxy for persist object that handles automatic state change detection. - */ - initPersistProxy(target: PersistedConn) { - // Set raw persist object - this.#persistRaw = target; - - // If this can't be proxied, return raw value - if (target === null || typeof target !== "object") { - this.#persist = target; - return; - } - - // Listen for changes to the object - this.#persist = onChange( - target, - ( - path: string, - value: any, - _previousValue: any, - _applyData: any, - ) => { - this.#handleChange(path, value); - }, - { ignoreDetached: true }, - ); - } - - // MARK: - Change Management - - /** - * Returns whether this connection has unsaved changes - */ - hasChanges(): boolean { - return this.#changed; + this.ephemeralData.state = value; } - /** - * Marks changes as saved - */ markSaved() { - this.#changed = false; + this.hibernatableDataChanged = false; } - // MARK: - Private Helpers - #validateStateEnabled() { - if (!this.stateEnabled) { + if (!this.#conn[CONN_ACTOR_SYMBOL].connStateEnabled) { throw new errors.ConnStateNotEnabled(); } } @@ -126,7 +168,7 @@ export class StateManager { } } - this.#changed = true; + this.hibernatableDataChanged = true; this.#conn[CONN_ACTOR_SYMBOL].rLog.debug({ msg: "conn onChange triggered", connId: this.#conn.id, @@ -138,4 +180,34 @@ export class StateManager { this.#conn, ); } + + addSubscription({ eventName }: { eventName: string }) { + const hibernatable = this.hibernatableData; + if (!hibernatable) return; + hibernatable.subscriptions.push({ + eventName, + }); + } + + removeSubscription({ eventName }: { eventName: string }) { + const hibernatable = this.hibernatableData; + if (!hibernatable) return; + const subIdx = hibernatable.subscriptions.findIndex( + (s) => s.eventName === eventName, + ); + if (subIdx !== -1) { + hibernatable.subscriptions.splice(subIdx, 1); + } + return subIdx !== -1; + } + + buildHwsMeta(): HibernatingWebSocketMetadata { + const hibernatable = this.hibernatableDataOrError(); + return { + requestId: hibernatable.hibernatableRequestId, + path: hibernatable.requestPath, + headers: hibernatable.requestHeaders, + messageIndex: hibernatable.msgIndex, + }; + } } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts b/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts index c2fa43726b..c134e03797 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts @@ -4,6 +4,7 @@ import type { ManagerDriver } from "@/manager/driver"; import type { RegistryConfig } from "@/registry/config"; import type { RunnerConfig } from "@/registry/run-config"; import type { AnyActorInstance } from "./instance/mod"; +import { AnyConn, Conn } from "./conn/mod"; export type ActorDriverBuilder = ( registryConfig: RegistryConfig, @@ -77,4 +78,7 @@ export interface ActorDriver { /** Extra properties to add to logs for each actor. */ getExtraActorLogParams?(): Record; + + onAfterPersistActor?(actor: AnyActorInstance): void; + onAfterPersistConn?(actor: AnyConn): void; } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts index d6aa418a86..599cdca16d 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts @@ -1,17 +1,16 @@ import * as cbor from "cbor-x"; import invariant from "invariant"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; +import { CONN_VERSIONED } from "@/schemas/actor-persist/versioned"; import { ToClientSchema } from "@/schemas/client-protocol-zod/mod"; import { arrayBuffersEqual, stringifyError } from "@/utils"; import type { ConnDriver } from "../conn/driver"; import { CONN_CONNECTED_SYMBOL, CONN_DRIVER_SYMBOL, - CONN_MARK_SAVED_SYMBOL, - CONN_PERSIST_RAW_SYMBOL, - CONN_PERSIST_SYMBOL, CONN_SEND_MESSAGE_SYMBOL, CONN_SPEAKS_RIVETKIT_SYMBOL, + CONN_STATE_MANAGER_SYMBOL, Conn, type ConnId, } from "../conn/mod"; @@ -23,8 +22,10 @@ import { CachedSerializer } from "../protocol/serde"; import { deadline } from "../utils"; import { makeConnKey } from "./kv"; import type { ActorInstance } from "./mod"; -import type { PersistedConn } from "./persisted"; - +import { + convertConnToBarePersistedConn, + PersistedConn, +} from "../conn/persisted"; /** * Manages all connection-related operations for an actor instance. * Handles connection creation, tracking, hibernation, and cleanup. @@ -80,22 +81,26 @@ export class ConnectionManager< driver: ConnDriver, params: CP, request: Request | undefined, + isRestoringHibernatable: boolean, ): Promise> { this.#actor.assertReady(); + invariant( + request?.url.startsWith("http://actor/") ?? true, + "request must start with `http://actor/`", + ); + // Check for hibernatable websocket reconnection - if (driver.requestIdBuf && driver.hibernatable) { - const existingConn = this.#findHibernatableConn( - driver.requestIdBuf, + if (isRestoringHibernatable) { + const existingConn = this.findHibernatableConn(driver.requestIdBuf); + invariant( + existingConn, + "cannot find connection for restoring connection", ); - - if (existingConn) { - return this.#reconnectHibernatableConn(existingConn, driver); - } + return this.#reconnectHibernatableConn(existingConn, driver); } // Create new connection - const persist = this.#actor.persist; if (this.#actor.config.onBeforeConnect) { const ctx = new OnBeforeConnectContext(this.#actor, request); await this.#actor.config.onBeforeConnect(ctx, params); @@ -108,29 +113,34 @@ export class ConnectionManager< } // Create connection persist data - const connPersist: PersistedConn = { - connId: crypto.randomUUID(), - params: params, - state: connState as CS, - lastSeen: Date.now(), - subscriptions: [], - }; - - // Check if hibernatable - if (driver.requestIdBuf) { - const isHibernatable = this.#isHibernatableRequest( - driver.requestIdBuf, - ); - if (isHibernatable) { - connPersist.hibernatableRequestId = driver.requestIdBuf; - } - } - - // Create connection instance - const conn = new Conn(this.#actor, connPersist); - conn[CONN_DRIVER_SYMBOL] = driver; - - return conn; + const hibernatable = driver.hibernatable; + invariant( + hibernatable && driver.requestIdBuf, + "must have requestIdBuf if hibernatable", + ); + throw "TODO"; + // TODO: + // const connPersist: PersistedConn = { + // id: crypto.randomUUID(), + // parameters: params, + // state: connState as CS, + // subscriptions: [], + // // Fallback to empty buf if not provided since we don't use this value + // hibernatableRequestId: driver.hibernatable + // ? driver.requestIdBuf + // : new ArrayBuffer(), + // lastSeenTimestamp: Date.now(), + // // First message index will be 1, so we start at 0 + // msgIndex: 0, + // requestPath: "", + // requestHeaders: undefined + // }; + + // // Create connection instance + // const conn = new Conn(this.#actor, connPersist); + // conn[CONN_DRIVER_SYMBOL] = driver; + // + // return conn; } /** @@ -183,6 +193,52 @@ export class ConnectionManager< } } + #reconnectHibernatableConn( + existingConn: Conn, + driver: ConnDriver, + ): Conn { + this.#actor.rLog.debug({ + msg: "reconnecting hibernatable websocket connection", + connectionId: existingConn.id, + requestId: driver.requestId, + }); + + // Clean up existing driver state if present + if (existingConn[CONN_DRIVER_SYMBOL]) { + this.#disconnectExistingDriver(existingConn); + } + + // Update connection with new socket + existingConn[CONN_DRIVER_SYMBOL] = driver; + existingConn[ + CONN_STATE_MANAGER_SYMBOL + ].hibernatableDataOrError().lastSeenTimestamp = Date.now(); + + // Mark as changed for persistence + this.#changedConnections.add(existingConn.id); + + // Reset sleep timer since we have an active connection + this.#actor.resetSleepTimer(); + + // Mark connection as connected + existingConn[CONN_CONNECTED_SYMBOL] = true; + + this.#actor.inspector.emitter.emit("connectionUpdated"); + + return existingConn; + } + + #disconnectExistingDriver(conn: Conn) { + const driver = conn[CONN_DRIVER_SYMBOL]; + if (driver?.disconnect) { + driver.disconnect( + this.#actor, + conn, + "Reconnecting hibernatable websocket with new driver state", + ); + } + } + /** * Handle connection disconnection. * @@ -242,7 +298,7 @@ export class ConnectionManager< } /** - * Utilify funtion for call sites that don't need a separate prepare and connect phase. + * Utilify function for call sites that don't need a separate prepare and connect phase. */ async prepareAndConnectConn( driver: ConnDriver, @@ -262,10 +318,9 @@ export class ConnectionManager< restoreConnections(connections: PersistedConn[]) { for (const connPersist of connections) { // Create connection instance - const conn = new Conn( - this.#actor, - connPersist, - ); + const conn = new Conn(this.#actor, { + hibernatable: connPersist, + }); this.#connections.set(conn.id, conn); // Restore subscriptions @@ -282,15 +337,29 @@ export class ConnectionManager< /** * Gets persistence data for all changed connections. */ - getChangedConnectionsData(): Array<[Uint8Array, Uint8Array]> { + getChangedConnectionsKvEntries(): Array<[Uint8Array, Uint8Array]> { const entries: Array<[Uint8Array, Uint8Array]> = []; for (const connId of this.#changedConnections) { const conn = this.#connections.get(connId); if (conn) { - const connData = cbor.encode(conn[CONN_PERSIST_RAW_SYMBOL]); - entries.push([makeConnKey(connId), connData]); - conn[CONN_MARK_SAVED_SYMBOL](); + const connStateManager = conn[CONN_STATE_MANAGER_SYMBOL]; + const hibernatableDataRaw = + connStateManager.hibernatableDataRaw; + if (hibernatableDataRaw) { + const bareData = convertConnToBarePersistedConn( + hibernatableDataRaw, + ); + const connData = + CONN_VERSIONED.serializeWithEmbeddedVersion(bareData); + entries.push([makeConnKey(connId), connData]); + connStateManager.markSaved(); + } else { + this.#actor.log.warn({ + msg: "missing raw hibernatable data for conn in getChangedConnectionsData", + connId: conn.id, + }); + } } } @@ -299,53 +368,19 @@ export class ConnectionManager< // MARK: - Private Helpers - #findHibernatableConn( + findHibernatableConn( requestIdBuf: ArrayBuffer, ): Conn | undefined { - return Array.from(this.#connections.values()).find( - (conn) => - conn[CONN_PERSIST_SYMBOL].hibernatableRequestId && - arrayBuffersEqual( - conn[CONN_PERSIST_SYMBOL].hibernatableRequestId, - requestIdBuf, - ), - ); - } - - #reconnectHibernatableConn( - existingConn: Conn, - driver: ConnDriver, - ): Conn { - this.#actor.rLog.debug({ - msg: "reconnecting hibernatable websocket connection", - connectionId: existingConn.id, - requestId: driver.requestId, + return Array.from(this.#connections.values()).find((conn) => { + const connStateManager = conn[CONN_STATE_MANAGER_SYMBOL]; + const connRequestId = + connStateManager.hibernatableDataRaw?.hibernatableRequestId; + return ( + connRequestId && arrayBuffersEqual(connRequestId, requestIdBuf) + ); }); - - // Clean up existing driver state if present - if (existingConn[CONN_DRIVER_SYMBOL]) { - this.#cleanupDriverState(existingConn); - } - - // Update connection with new socket - existingConn[CONN_DRIVER_SYMBOL] = driver; - existingConn[CONN_PERSIST_SYMBOL].lastSeen = Date.now(); - - this.#actor.inspector.emitter.emit("connectionUpdated"); - - return existingConn; } - #cleanupDriverState(conn: Conn) { - const driver = conn[CONN_DRIVER_SYMBOL]; - if (driver?.disconnect) { - driver.disconnect( - this.#actor, - conn, - "Reconnecting hibernatable websocket with new driver state", - ); - } - } async #createConnState( params: CP, @@ -373,14 +408,6 @@ export class ConnectionManager< ); } - #isHibernatableRequest(requestIdBuf: ArrayBuffer): boolean { - return ( - this.#actor.persist.hibernatableConns.findIndex((conn) => - arrayBuffersEqual(conn.hibernatableRequestId, requestIdBuf), - ) !== -1 - ); - } - #callOnConnect(conn: Conn) { if (this.#actor.config.onConnect) { try { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts index 344314ab79..26f92af1e3 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts @@ -7,9 +7,9 @@ import { } from "@/schemas/client-protocol-zod/mod"; import { bufferToArrayBuffer } from "@/utils"; import { - CONN_PERSIST_SYMBOL, CONN_SEND_MESSAGE_SYMBOL, CONN_SPEAKS_RIVETKIT_SYMBOL, + CONN_STATE_MANAGER_SYMBOL, type Conn, } from "../conn/mod"; import type { AnyDatabaseProvider } from "../database"; @@ -65,7 +65,9 @@ export class EventManager { // Persist subscription if not restoring from persistence if (!fromPersist) { - connection[CONN_PERSIST_SYMBOL].subscriptions.push({ eventName }); + connection[CONN_STATE_MANAGER_SYMBOL].addSubscription({ + eventName, + }); // Mark connection as changed for persistence const connectionManager = (this.#actor as any).connectionManager; @@ -125,12 +127,10 @@ export class EventManager { // Update persistence if not part of connection removal if (!fromRemoveConn) { // Remove from persisted subscriptions - const subIdx = connection[ - CONN_PERSIST_SYMBOL - ].subscriptions.findIndex((s) => s.eventName === eventName); - if (subIdx !== -1) { - connection[CONN_PERSIST_SYMBOL].subscriptions.splice(subIdx, 1); - } else { + const removed = connection[ + CONN_STATE_MANAGER_SYMBOL + ].removeSubscription({ eventName }); + if (!removed) { this.#actor.rLog.warn({ msg: "subscription does not exist in persist", eventName, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts index 5f825075d0..e94fb83211 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts @@ -7,7 +7,10 @@ import { stringifyError } from "@/common/utils"; import type { UniversalWebSocket } from "@/common/websocket-interface"; import { ActorInspector } from "@/inspector/actor"; import type { Registry } from "@/mod"; -import { ACTOR_VERSIONED } from "@/schemas/actor-persist/versioned"; +import { + ACTOR_VERSIONED, + CONN_VERSIONED, +} from "@/schemas/actor-persist/versioned"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; import { ToClientSchema } from "@/schemas/client-protocol-zod/mod"; @@ -17,9 +20,7 @@ import type { ConnDriver } from "../conn/driver"; import { createHttpSocket } from "../conn/drivers/http"; import { CONN_DRIVER_SYMBOL, - CONN_PERSIST_SYMBOL, - CONN_SEND_MESSAGE_SYMBOL, - CONN_STATE_ENABLED_SYMBOL, + CONN_STATE_MANAGER_SYMBOL, type Conn, type ConnId, } from "../conn/mod"; @@ -43,9 +44,16 @@ import { import { ConnectionManager } from "./connection-manager"; import { EventManager } from "./event-manager"; import { KEYS } from "./kv"; -import type { PersistedActor, PersistedConn } from "./persisted"; +import { + convertActorFromBarePersisted, + type PersistedActor, +} from "./persisted"; import { ScheduleManager } from "./schedule-manager"; import { type SaveStateOptions, StateManager } from "./state-manager"; +import { + convertConnFromBarePersistedConn, + PersistedConn, +} from "../conn/persisted"; export type { SaveStateOptions }; @@ -148,24 +156,31 @@ export class ActorInstance { getConnections: async () => { return Array.from( this.connectionManager.connections.entries(), - ).map(([id, conn]) => ({ - type: conn[CONN_DRIVER_SYMBOL]?.type, - id, - params: conn.params as any, - state: conn[CONN_STATE_ENABLED_SYMBOL] - ? conn.state - : undefined, - subscriptions: conn.subscriptions.size, - lastSeen: conn.lastSeen, - stateEnabled: conn[CONN_STATE_ENABLED_SYMBOL], - isHibernatable: conn.isHibernatable, - hibernatableRequestId: conn[CONN_PERSIST_SYMBOL] - .hibernatableRequestId - ? idToStr( - conn[CONN_PERSIST_SYMBOL].hibernatableRequestId, - ) - : undefined, - })); + ).map(([id, conn]) => { + const connStateManager = conn[CONN_STATE_MANAGER_SYMBOL]; + return { + type: conn[CONN_DRIVER_SYMBOL]?.type, + id, + params: conn.params as any, + stateEnabled: connStateManager.stateEnabled, + state: connStateManager.stateEnabled + ? connStateManager.state + : undefined, + subscriptions: conn.subscriptions.size, + lastSeen: + connStateManager.hibernatableDataRaw + ?.lastSeenTimestamp, + isHibernatable: conn.isHibernatable, + hibernatableRequestId: connStateManager + .hibernatableDataRaw?.hibernatableRequestId + ? idToStr( + connStateManager.hibernatableDataRaw + .hibernatableRequestId, + ) + : undefined, + // TODO: Include the underlying request for path & headers? + }; + }); }, setState: async (state: unknown) => { if (!this.stateEnabled) { @@ -265,7 +280,7 @@ export class ActorInstance { } // MARK: - State Access - get persist(): PersistedActor { + get persist(): PersistedActor { return this.#stateManager.persist; } @@ -506,7 +521,8 @@ export class ActorInstance { // Save connection changes if (this.connectionManager.changedConnections.size > 0) { - const entries = this.connectionManager.getChangedConnectionsData(); + const entries = + this.connectionManager.getChangedConnectionsKvEntries(); if (entries.length > 0) { await this.driver.kvBatchPut(this.#actorId, entries); } @@ -780,8 +796,7 @@ export class ActorInstance { const bareData = ACTOR_VERSIONED.deserializeWithEmbeddedVersion(persistDataBuffer); - const persistData = - this.#stateManager.convertFromBarePersisted(bareData); + const persistData = convertActorFromBarePersisted(bareData); if (persistData.hasInitialized) { // Restore existing actor @@ -795,7 +810,7 @@ export class ActorInstance { this.#scheduleManager.setPersist(this.#stateManager.persist); } - async #restoreExistingActor(persistData: PersistedActor) { + async #restoreExistingActor(persistData: PersistedActor) { // List all connection keys const connEntries = await this.driver.kvListPrefix( this.#actorId, @@ -806,7 +821,10 @@ export class ActorInstance { const connections: PersistedConn[] = []; for (const [_key, value] of connEntries) { try { - const conn = cbor.decode(value) as PersistedConn; + const bareData = CONN_VERSIONED.deserializeWithEmbeddedVersion( + new Uint8Array(value), + ); + const conn = convertConnFromBarePersistedConn(bareData); connections.push(conn); } catch (error) { this.#rLog.error({ @@ -819,7 +837,6 @@ export class ActorInstance { this.#rLog.info({ msg: "actor restoring", connections: connections.length, - hibernatableWebSockets: persistData.hibernatableConns.length, }); // Initialize state @@ -829,7 +846,7 @@ export class ActorInstance { this.connectionManager.restoreConnections(connections); } - async #createNewActor(persistData: PersistedActor) { + async #createNewActor(persistData: PersistedActor) { this.#rLog.info({ msg: "actor creating" }); // Initialize state diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/persisted.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/persisted.ts index fee27efda2..d431d2bb55 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/persisted.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/persisted.ts @@ -1,54 +1,67 @@ /** - * Persisted data structures matching actor-persist/v3.bare schema + * Persisted data structures for actors. + * + * Keep this file in sync with the Connection section of rivetkit-typescript/packages/rivetkit/schemas/actor-persist/ */ +import type * as persistSchema from "@/schemas/actor-persist/mod"; +import { bufferToArrayBuffer } from "@/utils"; +import * as cbor from "cbor-x"; + +export type Cbor = ArrayBuffer; + +// MARK: Schedule Event /** Scheduled event to be executed at a specific timestamp */ export interface PersistedScheduleEvent { eventId: string; timestamp: number; action: string; - args?: ArrayBuffer; -} - -/** Connection associated with hibernatable WebSocket that should persist across lifecycles */ -export interface PersistedHibernatableConn { - /** Connection ID generated by RivetKit */ - id: string; - parameters: CP; - state: CS; - subscriptions: PersistedSubscription[]; - /** Request ID of the hibernatable WebSocket */ - hibernatableRequestId: ArrayBuffer; - /** Last seen message from this WebSocket */ - lastSeenTimestamp: number; - /** Last seen message index for this WebSocket */ - msgIndex: number; + args?: Cbor; } -/** State object that gets automatically persisted to storage. */ -export interface PersistedActor { +// MARK: Actor +/** State object that gets automatically persisted to storage */ +export interface PersistedActor { /** Input data passed to the actor on initialization */ input?: I; hasInitialized: boolean; state: S; - hibernatableConns: PersistedHibernatableConn[]; scheduledEvents: PersistedScheduleEvent[]; } -/** Object representing connection that gets persisted to storage separately via KV. */ -export interface PersistedConn { - connId: string; - params: CP; - state: CS; - subscriptions: PersistedSubscription[]; - - /** Last time the socket was seen. This is set when disconnected so we can determine when we need to clean this up. */ - lastSeen: number; - - /** Request ID of the hibernatable WebSocket. See PersistedActor.hibernatableConns */ - hibernatableRequestId?: ArrayBuffer; +export function convertActorToBarePersisted( + persist: PersistedActor, +): persistSchema.Actor { + return { + input: + persist.input !== undefined + ? bufferToArrayBuffer(cbor.encode(persist.input)) + : null, + hasInitialized: persist.hasInitialized, + state: bufferToArrayBuffer(cbor.encode(persist.state)), + scheduledEvents: persist.scheduledEvents.map((event) => ({ + eventId: event.eventId, + timestamp: BigInt(event.timestamp), + action: event.action, + args: event.args ?? null, + })), + }; } -export interface PersistedSubscription { - eventName: string; +export function convertActorFromBarePersisted( + bareData: persistSchema.Actor, +): PersistedActor { + return { + input: bareData.input + ? cbor.decode(new Uint8Array(bareData.input)) + : undefined, + hasInitialized: bareData.hasInitialized, + state: cbor.decode(new Uint8Array(bareData.state)), + scheduledEvents: bareData.scheduledEvents.map((event) => ({ + eventId: event.eventId, + timestamp: Number(event.timestamp), + action: event.action, + args: event.args ?? undefined, + })), + }; } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts index 91fcdf9c83..2141fd78a8 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts @@ -13,7 +13,7 @@ import * as errors from "../errors"; import { isConnStatePath, isStatePath } from "../utils"; import { KEYS } from "./kv"; import type { ActorInstance } from "./mod"; -import type { PersistedActor } from "./persisted"; +import { convertActorToBarePersisted, type PersistedActor } from "./persisted"; export interface SaveStateOptions { /** @@ -33,8 +33,8 @@ export class StateManager { #actorDriver: ActorDriver; // State tracking - #persist!: PersistedActor; - #persistRaw!: PersistedActor; + #persist!: PersistedActor; + #persistRaw!: PersistedActor; #persistChanged = false; #isInOnStateChange = false; @@ -61,11 +61,11 @@ export class StateManager { // MARK: - Public API - get persist(): PersistedActor { + get persist(): PersistedActor { return this.#persist; } - get persistRaw(): PersistedActor { + get persistRaw(): PersistedActor { return this.#persistRaw; } @@ -93,7 +93,7 @@ export class StateManager { * Initializes state from persisted data or creates new state. */ async initializeState( - persistData: PersistedActor, + persistData: PersistedActor, ): Promise { if (!persistData.hasInitialized) { // Create initial state @@ -132,7 +132,7 @@ export class StateManager { /** * Creates proxy for persist object that handles automatic state change detection. */ - initPersistProxy(target: PersistedActor) { + initPersistProxy(target: PersistedActor) { // Set raw persist object this.#persistRaw = target; @@ -252,81 +252,13 @@ export class StateManager { this.#persistChanged = false; - const bareData = this.convertToBarePersisted(this.#persistRaw); + const bareData = convertActorToBarePersisted(this.#persistRaw); return [ KEYS.PERSIST_DATA, ACTOR_VERSIONED.serializeWithEmbeddedVersion(bareData), ]; } - // MARK: - BARE Conversion - - convertToBarePersisted( - persist: PersistedActor, - ): persistSchema.Actor { - const hibernatableConns: persistSchema.HibernatableConn[] = - persist.hibernatableConns.map((conn) => ({ - id: conn.id, - parameters: bufferToArrayBuffer( - cbor.encode(conn.parameters || {}), - ), - state: bufferToArrayBuffer(cbor.encode(conn.state || {})), - subscriptions: conn.subscriptions.map((sub) => ({ - eventName: sub.eventName, - })), - hibernatableRequestId: conn.hibernatableRequestId, - lastSeenTimestamp: BigInt(conn.lastSeenTimestamp), - msgIndex: BigInt(conn.msgIndex), - })); - - return { - input: - persist.input !== undefined - ? bufferToArrayBuffer(cbor.encode(persist.input)) - : null, - hasInitialized: persist.hasInitialized, - state: bufferToArrayBuffer(cbor.encode(persist.state)), - hibernatableConns, - scheduledEvents: persist.scheduledEvents.map((event) => ({ - eventId: event.eventId, - timestamp: BigInt(event.timestamp), - action: event.action, - args: event.args ?? null, - })), - }; - } - - convertFromBarePersisted( - bareData: persistSchema.Actor, - ): PersistedActor { - const hibernatableConns = bareData.hibernatableConns.map((conn) => ({ - id: conn.id, - parameters: cbor.decode(new Uint8Array(conn.parameters)), - state: cbor.decode(new Uint8Array(conn.state)), - subscriptions: conn.subscriptions.map((sub) => ({ - eventName: sub.eventName, - })), - hibernatableRequestId: conn.hibernatableRequestId, - lastSeenTimestamp: Number(conn.lastSeenTimestamp), - msgIndex: Number(conn.msgIndex), - })); - - return { - input: bareData.input - ? cbor.decode(new Uint8Array(bareData.input)) - : undefined, - hasInitialized: bareData.hasInitialized, - state: cbor.decode(new Uint8Array(bareData.state)), - hibernatableConns, - scheduledEvents: bareData.scheduledEvents.map((event) => ({ - eventId: event.eventId, - timestamp: Number(event.timestamp), - action: event.action, - args: event.args ?? undefined, - })), - }; - } - // MARK: - Private Helpers #validateStateEnabled() { @@ -428,8 +360,8 @@ export class StateManager { } } - async #writePersistedDataDirect(persistData: PersistedActor) { - const bareData = this.convertToBarePersisted(persistData); + async #writePersistedDataDirect(persistData: PersistedActor) { + const bareData = convertActorToBarePersisted(persistData); await this.#actorDriver.kvBatchPut(this.#actor.id, [ [ KEYS.PERSIST_DATA, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts index 686a1a275e..dbfc945db8 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts @@ -73,7 +73,7 @@ export type { } from "@/common/websocket-interface"; export type { ActorKey } from "@/manager/protocol/query"; export type * from "./config"; -export type { Conn } from "./conn/mod"; +export type { Conn, AnyConn } from "./conn/mod"; export type { ActionContext } from "./contexts/action"; export type { ActorContext } from "./contexts/actor"; export type { ConnInitContext } from "./contexts/conn-init"; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts index 89abe6d4fc..9a029f6f56 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts @@ -1,7 +1,7 @@ import * as cbor from "cbor-x"; import type { Context as HonoContext, HonoRequest } from "hono"; import type { WSContext } from "hono/ws"; -import type { AnyConn } from "@/actor/conn/mod"; +import { type AnyConn } from "@/actor/conn/mod"; import { ActionContext } from "@/actor/contexts/action"; import * as errors from "@/actor/errors"; import type { AnyActorInstance } from "@/actor/instance/mod"; @@ -104,6 +104,8 @@ export async function handleWebSocketConnect( parameters: unknown, requestId: string, requestIdBuf: ArrayBuffer | undefined, + isHibernatable: boolean, + isRestoringHibernatable: boolean, ): Promise { const exposeInternalError = req ? getRequestExposeInternalError(req) @@ -122,12 +124,6 @@ export async function handleWebSocketConnect( }); // Check if this is a hibernatable websocket - const isHibernatable = - !!requestIdBuf && - actor.persist.hibernatableConns.findIndex((conn) => - arrayBuffersEqual(conn.hibernatableRequestId, requestIdBuf), - ) !== -1; - const { driver, setWebSocket } = createWebSocketSocket( requestId, requestIdBuf, @@ -139,6 +135,7 @@ export async function handleWebSocketConnect( driver, parameters, req, + isRestoringHibernatable, ); createdConn = conn; @@ -149,6 +146,8 @@ export async function handleWebSocketConnect( setWebSocket(ws); + // This will not be called by restoring hibernatable + // connections. All restoratino is done in prepareConn. actor.connectionManager.connectConn(conn); }, onMessage: (evt: { data: any }, ws: WSContext) => { @@ -328,7 +327,7 @@ export async function handleAction( }), ); - // TODO: Remvoe any, Hono is being a dumbass + // TODO: Remove any, Hono is being a dumbass return c.body(serialized as Uint8Array as any, 200, { "Content-Type": contentTypeForEncoding(encoding), }); @@ -383,12 +382,12 @@ export async function handleRawWebSocket( // Promise used to wait for the websocket close in `disconnect` const closePromiseResolvers = promiseWithResolvers(); + // TODO: Is there a better way to determine this? // Extract rivetRequestId provided by engine runner const isHibernatable = !!requestIdBuf && - actor.persist.hibernatableConns.findIndex((conn) => - arrayBuffersEqual(conn.hibernatableRequestId, requestIdBuf), - ) !== -1; + actor.connectionManager.findHibernatableConn(requestIdBuf) !== + undefined; const newPath = truncateRawWebSocketPathPrefix(path); let newRequest: Request; diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-helpers/utils.ts b/rivetkit-typescript/packages/rivetkit/src/driver-helpers/utils.ts index de71ea202e..4d48b86949 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-helpers/utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-helpers/utils.ts @@ -13,7 +13,6 @@ function serializeEmptyPersistData(input: unknown | undefined): Uint8Array { : null, hasInitialized: false, state: bufferToArrayBuffer(cbor.encode(undefined)), - hibernatableConns: [], scheduledEvents: [], }; return ACTOR_VERSIONED.serializeWithEmbeddedVersion(persistData); diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index d18e55da79..848b849d0f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -1,7 +1,7 @@ import type { ActorConfig as EngineActorConfig, RunnerConfig as EngineRunnerConfig, - HibernationConfig, + HibernatingWebSocketMetadata, } from "@rivetkit/engine-runner"; import { Runner } from "@rivetkit/engine-runner"; import * as cbor from "cbor-x"; @@ -36,6 +36,7 @@ import { getInitialActorKvState, type ManagerDriver, } from "@/driver-helpers/mod"; +import { CONN_STATE_MANAGER_SYMBOL, type AnyConn } from "@/actor/conn/mod"; import { buildActorNames, type RegistryConfig } from "@/registry/config"; import type { RunnerConfig } from "@/registry/run-config"; import { getEndpoint } from "@/remote-manager-driver/api-utils"; @@ -48,6 +49,7 @@ import { stringifyError, } from "@/utils"; import { logger } from "./log"; +import { RequestId } from "@/schemas/actor-persist/mod"; const RUNNER_SSE_PING_INTERVAL = 1000; @@ -78,12 +80,14 @@ export class EngineActorDriver implements ActorDriver { // protocol is updated to send the intent directly (see RVT-5284) #actorStopIntent: Map = new Map(); - // WebSocket message acknowledgment debouncing for hibernatable websockets - #hibernatableWebSocketAckQueue: Map< + // Request IDs that are waiting to be acknowledged after the next persist + // + // We store the RequestId since it's the array buffer version we need to + // pass back to the runner. + #hibernatableWebSocketAckQueue = new Map< string, - { requestIdBuf: ArrayBuffer; messageIndex: number } - > = new Map(); - #wsAckFlushInterval?: NodeJS.Timeout; + { actorId: string; requestId: RequestId; messageIndex: number } + >(); constructor( registryConfig: RegistryConfig, @@ -132,168 +136,15 @@ export class EngineActorDriver implements ActorDriver { }, fetch: this.#runnerFetch.bind(this), websocket: this.#runnerWebSocket.bind(this), + hibernatableWebSocket: { + canHibernate: this.#hwsCanHibernate.bind(this), + loadAll: this.#hwsLoadAll.bind(this), + persistMessageIndex: this.#hwsPersistMessageIndex.bind(this), + removePersisted: this.#hwsRemovePersisted.bind(this), + }, onActorStart: this.#runnerOnActorStart.bind(this), onActorStop: this.#runnerOnActorStop.bind(this), logger: getLogger("engine-runner"), - getActorHibernationConfig: ( - actorId: string, - requestId: ArrayBuffer, - request: Request, - ): HibernationConfig => { - const url = new URL(request.url); - const path = url.pathname; - - // Get actor instance from runner to access actor name - const actorInstance = this.#runner.getActor(actorId); - if (!actorInstance) { - logger().warn({ - msg: "actor not found in getActorHibernationConfig", - actorId, - }); - return { enabled: false, lastMsgIndex: undefined }; - } - - // Load actor handler to access persisted data - const handler = this.#actors.get(actorId); - if (!handler) { - logger().warn({ - msg: "actor handler not found in getActorHibernationConfig", - actorId, - }); - return { enabled: false, lastMsgIndex: undefined }; - } - if (!handler.actor) { - logger().warn({ - msg: "actor not found in getActorHibernationConfig", - actorId, - }); - return { enabled: false, lastMsgIndex: undefined }; - } - - // Check for existing WS - const hibernatableArray = - handler.actor.persist.hibernatableConns; - logger().debug({ - msg: "checking hibernatable websockets", - requestId: idToStr(requestId), - existingHibernatableWebSockets: hibernatableArray.length, - actorId, - }); - - const existingWs = hibernatableArray.find((conn) => - arrayBuffersEqual(conn.hibernatableRequestId, requestId), - ); - - // Determine configuration for new WS - let hibernationConfig: HibernationConfig; - if (existingWs) { - // Convert msgIndex to number, treating -1 as undefined (no messages processed yet) - const lastMsgIndex = - existingWs.msgIndex >= 0n - ? Number(existingWs.msgIndex) - : undefined; - logger().debug({ - msg: "found existing hibernatable websocket", - requestId: idToStr(requestId), - lastMsgIndex: lastMsgIndex ?? -1, - }); - hibernationConfig = { - enabled: true, - lastMsgIndex, - }; - } else { - logger().debug({ - msg: "no existing hibernatable websocket found", - requestId: idToStr(requestId), - }); - if (path === PATH_CONNECT) { - hibernationConfig = { - enabled: true, - lastMsgIndex: undefined, - }; - } else if (path.startsWith(PATH_WEBSOCKET_PREFIX)) { - // Find actor config - const definition = lookupInRegistry( - this.#registryConfig, - actorInstance.config.name, - ); - - // Check if can hibernate - const canHibernateWebSocket = - definition.config.options?.canHibernateWebSocket; - if (canHibernateWebSocket === true) { - hibernationConfig = { - enabled: true, - lastMsgIndex: undefined, - }; - } else if ( - typeof canHibernateWebSocket === "function" - ) { - try { - // Truncate the path to match the behavior on onRawWebSocket - const newPath = truncateRawWebSocketPathPrefix( - url.pathname, - ); - const truncatedRequest = new Request( - `http://actor${newPath}`, - request, - ); - - const canHibernate = - canHibernateWebSocket(truncatedRequest); - hibernationConfig = { - enabled: canHibernate, - lastMsgIndex: undefined, - }; - } catch (error) { - logger().error({ - msg: "error calling canHibernateWebSocket", - error, - }); - hibernationConfig = { - enabled: false, - lastMsgIndex: undefined, - }; - } - } else { - hibernationConfig = { - enabled: false, - lastMsgIndex: undefined, - }; - } - } else { - logger().warn({ - msg: "unexpected path for getActorHibernationConfig", - path, - }); - hibernationConfig = { - enabled: false, - lastMsgIndex: undefined, - }; - } - } - - // Save or update hibernatable WebSocket - if (existingWs) { - logger().debug({ - msg: "updated existing hibernatable websocket timestamp", - requestId: idToStr(requestId), - currentMsgIndex: existingWs.msgIndex, - }); - existingWs.lastSeenTimestamp = Date.now(); - } else if (path === PATH_CONNECT) { - // For new hibernatable connections, we'll create a placeholder entry - // The actual connection data will be populated when the connection is created - logger().debug({ - msg: "will create hibernatable conn when connection is created", - requestId: idToStr(requestId), - }); - // Note: The actual hibernatable connection is created in connection-manager.ts - // when createConn is called with a hibernatable requestId - } - - return hibernationConfig; - }, }; // Create and start runner @@ -305,18 +156,10 @@ export class EngineActorDriver implements ActorDriver { namespace: runConfig.namespace, runnerName: runConfig.runnerName, }); + } - // Start WebSocket ack flush interval - // - // Decreasing this reduces the amount of buffered messages on the - // gateway - // - // Gateway timeout configured to 30s - // https://github.com/rivet-dev/rivet/blob/222dae87e3efccaffa2b503de40ecf8afd4e31eb/engine/packages/pegboard-gateway/src/shared_state.rs#L17 - this.#wsAckFlushInterval = setInterval( - () => this.#flushHibernatableWebSocketAcks(), - 1000, - ); + getExtraActorLogParams(): Record { + return { runnerId: this.#runner.runnerId ?? "-" }; } async #loadActorHandler(actorId: string): Promise { @@ -329,25 +172,6 @@ export class EngineActorDriver implements ActorDriver { return handler; } - async loadActor(actorId: string): Promise { - const handler = await this.#loadActorHandler(actorId); - if (!handler.actor) throw new Error(`Actor ${actorId} failed to load`); - return handler.actor; - } - - #flushHibernatableWebSocketAcks(): void { - if (this.#hibernatableWebSocketAckQueue.size === 0) return; - - for (const { - requestIdBuf: requestId, - messageIndex: index, - } of this.#hibernatableWebSocketAckQueue.values()) { - this.#runner.sendWebsocketMessageAck(requestId, index); - } - - this.#hibernatableWebSocketAckQueue.clear(); - } - getContext(actorId: string): DriverContext { return {}; } @@ -382,17 +206,11 @@ export class EngineActorDriver implements ActorDriver { return undefined; } - // Batch KV operations + // MARK: - Batch KV operations async kvBatchPut( actorId: string, entries: [Uint8Array, Uint8Array][], ): Promise { - logger().debug({ - msg: "batch writing KV entries", - actorId, - entryCount: entries.length, - }); - await this.#runner.kvPut(actorId, entries); } @@ -400,22 +218,10 @@ export class EngineActorDriver implements ActorDriver { actorId: string, keys: Uint8Array[], ): Promise<(Uint8Array | null)[]> { - logger().debug({ - msg: "batch reading KV entries", - actorId, - keyCount: keys.length, - }); - return await this.#runner.kvGet(actorId, keys); } async kvBatchDelete(actorId: string, keys: Uint8Array[]): Promise { - logger().debug({ - msg: "batch deleting KV entries", - actorId, - keyCount: keys.length, - }); - await this.#runner.kvDelete(actorId, keys); } @@ -423,16 +229,126 @@ export class EngineActorDriver implements ActorDriver { actorId: string, prefix: Uint8Array, ): Promise<[Uint8Array, Uint8Array][]> { + return await this.#runner.kvListPrefix(actorId, prefix); + } + + // MARK: - Actor Lifecycle + async loadActor(actorId: string): Promise { + const handler = await this.#loadActorHandler(actorId); + if (!handler.actor) throw new Error(`Actor ${actorId} failed to load`); + return handler.actor; + } + + startSleep(actorId: string) { + // HACK: Track intent for onActorStop (see RVT-5284) + this.#actorStopIntent.set(actorId, "sleep"); + this.#runner.sleepActor(actorId); + } + + startDestroy(actorId: string) { + // HACK: Track intent for onActorStop (see RVT-5284) + this.#actorStopIntent.set(actorId, "destroy"); + this.#runner.stopActor(actorId); + } + + async shutdownRunner(immediate: boolean): Promise { + logger().info({ msg: "stopping engine actor driver", immediate }); + + // TODO: We need to update the runner to have a draining state so: + // 1. Send ToServerDraining + // - This causes Pegboard to stop allocating actors to this runner + // 2. Pegboard sends ToClientStopActor for all actors on this runner which handles the graceful migration of each actor independently + // 3. Send ToServerStopping once all actors have successfully stopped + // + // What's happening right now is: + // 1. All actors enter stopped state + // 2. Actors still respond to requests because only RivetKit knows it's + // stopping, this causes all requests to issue errors that the actor is + // stopping. (This will NOT return a 503 bc the runner has no idea the + // actors are stopping.) + // 3. Once the last actor stops, then the runner finally stops + actors + // reschedule + // + // This means that: + // - All actors on this runner are bricked until the slowest onStop finishes + // - Guard will not gracefully handle requests bc it's not receiving a 503 + // - Actors can still be scheduled to this runner while the other + // actors are stopping, meaning that those actors will NOT get onStop + // and will potentiall corrupt their state + // + // HACK: Stop all actors to allow state to be saved + // NOTE: onStop is only supposed to be called by the runner, we're + // abusing it here logger().debug({ - msg: "listing KV entries with prefix", - actorId, - prefixLength: prefix.length, + msg: "stopping all actors before shutdown", + actorCount: this.#actors.size, }); + const stopPromises: Promise[] = []; + for (const [_actorId, handler] of this.#actors.entries()) { + if (handler.actor) { + stopPromises.push( + handler.actor.onStop("sleep").catch((err) => { + handler.actor?.rLog.error({ + msg: "onStop errored", + error: stringifyError(err), + }); + }), + ); + } + } + await Promise.all(stopPromises); + logger().debug({ msg: "all actors stopped" }); - return await this.#runner.kvListPrefix(actorId, prefix); + await this.#runner.shutdown(immediate); + } + + async serverlessHandleStart(c: HonoContext): Promise { + return streamSSE(c, async (stream) => { + // NOTE: onAbort does not work reliably + stream.onAbort(() => {}); + c.req.raw.signal.addEventListener("abort", () => { + logger().debug("SSE aborted, shutting down runner"); + + // We cannot assume that the request will always be closed gracefully by Rivet. We always proceed with a graceful shutdown in case the request was terminated for any other reason. + // + // If we did not use a graceful shutdown, the runner would + this.shutdownRunner(false); + }); + + await this.#runnerStarted.promise; + + // Runner id should be set if the runner started + const payload = this.#runner.getServerlessInitPacket(); + invariant(payload, "runnerId not set"); + await stream.writeSSE({ data: payload }); + + // Send ping every second to keep the connection alive + while (true) { + if (this.#isRunnerStopped) { + logger().debug({ + msg: "runner is stopped", + }); + break; + } + + if (stream.closed || stream.aborted) { + logger().debug({ + msg: "runner sse stream closed", + closed: stream.closed, + aborted: stream.aborted, + }); + break; + } + + await stream.writeSSE({ event: "ping", data: "" }); + await stream.sleep(RUNNER_SSE_PING_INTERVAL); + } + + // Wait for the runner to stop if the SSE stream aborted early for any reason + await this.#runnerStopped.promise; + }); } - // Runner lifecycle callbacks async #runnerOnActorStart( actorId: string, generation: number, @@ -543,6 +459,7 @@ export class EngineActorDriver implements ActorDriver { logger().debug({ msg: "runner actor stopped", actorId, reason }); } + // MARK: - Runner Networking async #runnerFetch( _runner: Runner, actorId: string, @@ -558,13 +475,13 @@ export class EngineActorDriver implements ActorDriver { return await this.#actorRouter.fetch(request, { actorId }); } - async #runnerWebSocket( + #runnerWebSocket( _runner: Runner, actorId: string, websocketRaw: any, requestIdBuf: ArrayBuffer, request: Request, - ): Promise { + ): void { const websocket = websocketRaw as UniversalWebSocket; const requestId = idToStr(requestIdBuf); @@ -604,8 +521,6 @@ export class EngineActorDriver implements ActorDriver { throw new Error(`Unreachable path: ${url.pathname}`); } - // TODO: Add close - // Connect the Hono WS hook to the adapter const wsContext = new WSContext(websocket); @@ -625,313 +540,192 @@ export class EngineActorDriver implements ActorDriver { } websocket.addEventListener("message", (event: RivetMessageEvent) => { - invariant(event.rivetRequestId, "missing rivetRequestId"); - invariant(event.rivetMessageIndex, "missing rivetMessageIndex"); - - // Handle hibernatable WebSockets: - // - Check for out of sequence messages - // - Save msgIndex for WS restoration - // - Queue WS acks - const actorHandler = this.#actors.get(actorId); - if (actorHandler?.actor) { - const hibernatableWs = - actorHandler.actor.persist.hibernatableConns.find( - (conn: any) => - arrayBuffersEqual( - conn.hibernatableRequestId, - requestIdBuf, - ), - ); - - if (hibernatableWs) { - // Track msgIndex for sending acks - const currentEntry = - this.#hibernatableWebSocketAckQueue.get(requestId); - if (currentEntry) { - const previousIndex = currentEntry.messageIndex; - - // Check for out-of-sequence messages - if (event.rivetMessageIndex !== previousIndex + 1) { - let closeReason: string; - let sequenceType: string; - - if (event.rivetMessageIndex < previousIndex) { - closeReason = "ws.message_index_regressed"; - sequenceType = "regressed"; - } else if ( - event.rivetMessageIndex === previousIndex - ) { - closeReason = "ws.message_index_duplicate"; - sequenceType = "duplicate"; - } else { - closeReason = "ws.message_index_skip"; - sequenceType = "gap/skipped"; - } - - logger().warn({ - msg: "hibernatable websocket message index out of sequence, closing connection", - requestId, - actorId, - previousIndex, - expectedIndex: previousIndex + 1, - receivedIndex: event.rivetMessageIndex, - sequenceType, - closeReason, - gap: - event.rivetMessageIndex > previousIndex - ? event.rivetMessageIndex - - previousIndex - - 1 - : 0, - }); - - // Close the WebSocket and skip processing - wsContext.close(1008, closeReason); - return; - } - - // Update to the next index - currentEntry.messageIndex = event.rivetMessageIndex; - } else { - this.#hibernatableWebSocketAckQueue.set(requestId, { - requestIdBuf, - messageIndex: event.rivetMessageIndex, - }); - } - - // Update msgIndex for next WebSocket open msgIndex restoration - const oldMsgIndex = hibernatableWs.msgIndex; - hibernatableWs.msgIndex = event.rivetMessageIndex; - hibernatableWs.lastSeenTimestamp = Date.now(); - - logger().debug({ - msg: "updated hibernatable websocket msgIndex in engine driver", - requestId, - oldMsgIndex: oldMsgIndex.toString(), - newMsgIndex: event.rivetMessageIndex, - actorId, - }); - } - } else { - // Warn if we receive a message for a hibernatable websocket but can't find the actor - logger().warn({ - msg: "received websocket message but actor not found for hibernatable tracking", - actorId, - requestId, - messageIndex: event.rivetMessageIndex, - hasHandler: !!actorHandler, - hasActor: !!actorHandler?.actor, - }); - } - // Process the message after all hibernation logic and validation in case the message is out of order wsHandlerPromise.then((x) => x.onMessage?.(event, wsContext)); }); websocket.addEventListener("close", (event) => { - // Flush any pending acks before closing - this.#flushHibernatableWebSocketAcks(); - - // Clean up hibernatable WebSocket - this.#cleanupHibernatableWebSocket( - actorId, - requestIdBuf, - requestId, - "close", - event, - ); - wsHandlerPromise.then((x) => x.onClose?.(event, wsContext)); }); websocket.addEventListener("error", (event) => { - // Clean up hibernatable WebSocket on error - this.#cleanupHibernatableWebSocket( - actorId, - requestIdBuf, - requestId, - "error", - event, - ); - wsHandlerPromise.then((x) => x.onError?.(event, wsContext)); }); } - /** - * Helper method to clean up hibernatable WebSocket entries - * Eliminates duplication between close and error handlers - */ - #cleanupHibernatableWebSocket( + // MARK: - Hibernating WebSockets + #hwsCanHibernate( actorId: string, - requestIdBuf: ArrayBuffer, - requestId: string, - eventType: "close" | "error", - event?: any, - ) { - const actorHandler = this.#actors.get(actorId); - if (actorHandler?.actor) { - const hibernatableArray = - actorHandler.actor.persist.hibernatableConns; - const wsIndex = hibernatableArray.findIndex((conn: any) => - arrayBuffersEqual(conn.hibernatableRequestId, requestIdBuf), - ); + requestId: ArrayBuffer, + request: Request, + ): boolean { + const url = new URL(request.url); + const path = url.pathname; - if (wsIndex !== -1) { - const removed = hibernatableArray.splice(wsIndex, 1); - const logData: any = { - msg: `removed hibernatable websocket on ${eventType}`, - requestId, - actorId, - removedMsgIndex: - removed[0]?.msgIndex?.toString() ?? "unknown", - }; - // Add error context if this is an error event - if (eventType === "error" && event) { - logData.error = event; - } - logger().debug(logData); - } - } else { - // Warn if actor not found during cleanup - const warnData: any = { - msg: `websocket ${eventType === "close" ? "closed" : "error"} but actor not found for hibernatable cleanup`, + // Get actor instance from runner to access actor name + const actorInstance = this.#runner.getActor(actorId); + if (!actorInstance) { + logger().warn({ + msg: "actor not found in #hwsCanHibernate", actorId, - requestId, - hasHandler: !!actorHandler, - hasActor: !!actorHandler?.actor, - }; - // Add error context if this is an error event - if (eventType === "error" && event) { - warnData.error = event; - } - logger().warn(warnData); + }); + return false; } - // Also remove from ack queue - this.#hibernatableWebSocketAckQueue.delete(requestId); - } - - startSleep(actorId: string) { - // HACK: Track intent for onActorStop (see RVT-5284) - this.#actorStopIntent.set(actorId, "sleep"); - this.#runner.sleepActor(actorId); - } - - startDestroy(actorId: string) { - // HACK: Track intent for onActorStop (see RVT-5284) - this.#actorStopIntent.set(actorId, "destroy"); - this.#runner.stopActor(actorId); - } - - async shutdownRunner(immediate: boolean): Promise { - logger().info({ msg: "stopping engine actor driver", immediate }); + // Load actor handler to access persisted data + const handler = this.#actors.get(actorId); + if (!handler) { + logger().warn({ + msg: "actor handler not found in #hwsCanHibernate", + actorId, + }); + return false; + } + if (!handler.actor) { + logger().warn({ + msg: "actor not found in #hwsCanHibernate", + actorId, + }); + return false; + } - // TODO: We need to update the runner to have a draining state so: - // 1. Send ToServerDraining - // - This causes Pegboard to stop allocating actors to this runner - // 2. Pegboard sends ToClientStopActor for all actors on this runner which handles the graceful migration of each actor independently - // 3. Send ToServerStopping once all actors have successfully stopped - // - // What's happening right now is: - // 1. All actors enter stopped state - // 2. Actors still respond to requests because only RivetKit knows it's - // stopping, this causes all requests to issue errors that the actor is - // stopping. (This will NOT return a 503 bc the runner has no idea the - // actors are stopping.) - // 3. Once the last actor stops, then the runner finally stops + actors - // reschedule - // - // This means that: - // - All actors on this runner are bricked until the slowest onStop finishes - // - Guard will not gracefully handle requests bc it's not receiving a 503 - // - Actors can still be scheduled to this runner while the other - // actors are stopping, meaning that those actors will NOT get onStop - // and will potentiall corrupt their state - // - // HACK: Stop all actors to allow state to be saved - // NOTE: onStop is only supposed to be called by the runner, we're - // abusing it here + // Determine configuration for new WS logger().debug({ - msg: "stopping all actors before shutdown", - actorCount: this.#actors.size, + msg: "no existing hibernatable websocket found", + requestId: idToStr(requestId), }); - const stopPromises: Promise[] = []; - for (const [_actorId, handler] of this.#actors.entries()) { - if (handler.actor) { - stopPromises.push( - handler.actor.onStop("sleep").catch((err) => { - handler.actor?.rLog.error({ - msg: "onStop errored", - error: stringifyError(err), - }); - }), - ); - } - } - await Promise.all(stopPromises); - logger().debug({ msg: "all actors stopped" }); - - // Clear the ack flush interval - if (this.#wsAckFlushInterval) { - clearInterval(this.#wsAckFlushInterval); - this.#wsAckFlushInterval = undefined; - } + if (path === PATH_CONNECT) { + return true; + } else if (path.startsWith(PATH_WEBSOCKET_PREFIX)) { + // Find actor config + const definition = lookupInRegistry( + this.#registryConfig, + actorInstance.config.name, + ); - // Flush any remaining acks - this.#flushHibernatableWebSocketAcks(); + // Check if can hibernate + const canHibernateWebSocket = + definition.config.options?.canHibernateWebSocket; + if (canHibernateWebSocket === true) { + return true; + } else if (typeof canHibernateWebSocket === "function") { + try { + // Truncate the path to match the behavior on onRawWebSocket + const newPath = truncateRawWebSocketPathPrefix( + url.pathname, + ); + const truncatedRequest = new Request( + `http://actor${newPath}`, + request, + ); - await this.#runner.shutdown(immediate); + const canHibernate = + canHibernateWebSocket(truncatedRequest); + return canHibernate; + } catch (error) { + logger().error({ + msg: "error calling canHibernateWebSocket", + error, + }); + return false; + } + } else { + return false; + } + } else { + logger().warn({ + msg: "unexpected path for getActorHibernationConfig", + path, + }); + return false; + } } - async serverlessHandleStart(c: HonoContext): Promise { - return streamSSE(c, async (stream) => { - // NOTE: onAbort does not work reliably - stream.onAbort(() => {}); - c.req.raw.signal.addEventListener("abort", () => { - logger().debug("SSE aborted, shutting down runner"); - - // We cannot assume that the request will always be closed gracefully by Rivet. We always proceed with a graceful shutdown in case the request was terminated for any other reason. - // - // If we did not use a graceful shutdown, the runner would - this.shutdownRunner(false); - }); + #hwsLoadAll(actorId: string): HibernatingWebSocketMetadata[] { + // TODO: Load actor in a better way + const actor = this.#actors.get(actorId); + invariant(actor?.actor, "actor not loaded"); + + return actor.actor.conns + .values() + .map((conn) => { + const connStateManager = conn[CONN_STATE_MANAGER_SYMBOL]; + const hibernatable = connStateManager.hibernatableData; + if (!hibernatable) return undefined; + return { + requestId: hibernatable.hibernatableRequestId, + path: hibernatable.requestPath, + headers: hibernatable.requestHeaders, + messageIndex: hibernatable.msgIndex, + } satisfies HibernatingWebSocketMetadata; + }) + .filter((x) => x !== undefined) + .toArray(); + } - await this.#runnerStarted.promise; + #hwsPersistMessageIndex(actorId: string, requestId: RequestId) { + // TODO: is this the right way of getting the actor - // Runner id should be set if the runner started - const payload = this.#runner.getServerlessInitPacket(); - invariant(payload, "runnerId not set"); - await stream.writeSSE({ data: payload }); + const actor = this.#actors.get(actorId); + const conn = actor?.actor?.connectionManager.findHibernatableConn(requestId); - // Send ping every second to keep the connection alive - while (true) { - if (this.#isRunnerStopped) { - logger().debug({ - msg: "runner is stopped", - }); - break; - } + if (!conn) { + logger().warn({ + msg: "cannot find conn to persist message index to", + actorId, + requestId: idToStr(requestId), + }); + return; + } - if (stream.closed || stream.aborted) { - logger().debug({ - msg: "runner sse stream closed", - closed: stream.closed, - aborted: stream.aborted, - }); - break; - } + this.#hibernatableWebSocketAckQueue.set(conn.id, {}); - await stream.writeSSE({ event: "ping", data: "" }); - await stream.sleep(RUNNER_SSE_PING_INTERVAL); - } + // TODO: Find conn with request ID + // TODO: Add conn to persist queue + // TODO: Start timer to force save + } - // Wait for the runner to stop if the SSE stream aborted early for any reason - await this.#runnerStopped.promise; - }); + #hwsRemovePersisted(actorId: string, requestId: RequestId) { + // // TODO: persist immediately + // const actorHandler = this.#actors.get(actorId); + // if (actorHandler?.actor) { + // const hibernatableArray = + // actorHandler.actor.persist.hibernatableConns; + // const wsIndex = hibernatableArray.findIndex((conn: any) => + // arrayBuffersEqual(conn.hibernatableRequestId, requestIdBuf), + // ); + // + // if (wsIndex !== -1) { + // const removed = hibernatableArray.splice(wsIndex, 1); + // logger().debug({ + // msg: "removed hibernatable websocket", + // requestId, + // actorId, + // removedMsgIndex: + // removed[0]?.msgIndex?.toString() ?? "unknown", + // }); + // } + // } else { + // // Warn if actor not found during cleanup + // logger().warn({ + // msg: "websocket but actor not found for hibernatable cleanup", + // actorId, + // requestId, + // hasHandler: !!actorHandler, + // hasActor: !!actorHandler?.actor, + // }); + // } + // + // // Also remove from ack queue + // this.#hibernatableWebSocketAckQueue.delete(requestId); } - getExtraActorLogParams(): Record { - return { runnerId: this.#runner.runnerId ?? "-" }; + onAfterPersistConn(conn: AnyConn) { + // TODO: + // this.#runner.sendHibernatableWebSocketMessageAck( + // requestId, + // messageIndex, + // ); + // this.#hibernatableWebSocketAckQueue.delete(conn.id); } } diff --git a/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts b/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts index d6fa27cefb..57ad841e30 100644 --- a/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts +++ b/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts @@ -23,35 +23,6 @@ const migrations = new Map>([ [ 2, (v2Data: v2.PersistedActor): v3.Actor => { - // Merge connections and hibernatableWebSocket into hibernatableConns - const hibernatableConns: v3.HibernatableConn[] = []; - - // Convert connections with hibernatable request IDs to hibernatable conns - for (const conn of v2Data.connections) { - if (conn.hibernatableRequestId) { - // Find the matching hibernatable WebSocket - const ws = v2Data.hibernatableWebSockets.find((ws) => - Buffer.from(ws.requestId).equals( - Buffer.from(conn.hibernatableRequestId!), - ), - ); - - if (ws) { - hibernatableConns.push({ - id: conn.id, - parameters: conn.parameters, - state: conn.state, - subscriptions: conn.subscriptions.map((sub) => ({ - eventName: sub.eventName, - })), - hibernatableRequestId: conn.hibernatableRequestId, - lastSeenTimestamp: ws.lastSeenTimestamp, - msgIndex: ws.msgIndex, - }); - } - } - } - // Transform scheduled events from nested structure to flat structure const scheduledEvents: v3.ScheduleEvent[] = v2Data.scheduledEvents.map((event) => { @@ -74,7 +45,6 @@ const migrations = new Map>([ input: v2Data.input, hasInitialized: v2Data.hasInitialized, state: v2Data.state, - hibernatableConns, scheduledEvents, }; }, @@ -87,3 +57,10 @@ export const ACTOR_VERSIONED = createVersionedDataHandler({ serializeVersion: (data) => v3.encodeActor(data), deserializeVersion: (bytes) => v3.decodeActor(bytes), }); + +export const CONN_VERSIONED = createVersionedDataHandler({ + currentVersion: CURRENT_VERSION, + migrations: new Map(), + serializeVersion: (data) => v3.encodeConn(data), + deserializeVersion: (bytes) => v3.decodeConn(bytes), +});