Skip to content

Commit 43b09b2

Browse files
committed
fix(rivetkit): properly handle msgIndex for hibernatable websocket reconnection (#3401)
1 parent 5b69257 commit 43b09b2

File tree

6 files changed

+272
-54
lines changed

6 files changed

+272
-54
lines changed

rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v2.bare

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,5 @@ type PersistedActor struct {
5151
state: data
5252
connections: list<PersistedConnection>
5353
scheduledEvents: list<PersistedScheduleEvent>
54-
hibernatableWebSocket: list<PersistedHibernatableWebSocket>
54+
hibernatableWebSockets: list<PersistedHibernatableWebSocket>
5555
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import type { AnyConn } from "@/actor/conn/mod";
2+
import type { AnyActorInstance } from "@/actor/instance/mod";
3+
import type { UniversalWebSocket } from "@/common/websocket-interface";
4+
import type { ConnDriver, DriverReadyState } from "../driver";
5+
6+
/**
7+
* Creates a raw WebSocket connection driver.
8+
*
9+
* This driver is used for raw WebSocket connections that don't use the RivetKit protocol.
10+
* Unlike the standard WebSocket driver, this doesn't have sendMessage since raw WebSockets
11+
* don't handle messages from the RivetKit protocol - they handle messages directly in the
12+
* actor's onWebSocket handler.
13+
*/
14+
export function createRawWebSocketSocket(
15+
requestId: string,
16+
requestIdBuf: ArrayBuffer | undefined,
17+
hibernatable: boolean,
18+
websocket: UniversalWebSocket,
19+
closePromise: Promise<void>,
20+
): ConnDriver {
21+
return {
22+
requestId,
23+
requestIdBuf,
24+
hibernatable,
25+
26+
// No sendMessage implementation since this is a raw WebSocket that doesn't
27+
// handle messages from the RivetKit protocol
28+
29+
disconnect: async (
30+
_actor: AnyActorInstance,
31+
_conn: AnyConn,
32+
reason?: string,
33+
) => {
34+
// Close socket
35+
websocket.close(1000, reason);
36+
37+
// Wait for socket to close gracefully
38+
await closePromise;
39+
},
40+
41+
terminate: () => {
42+
(websocket as any).terminate?.();
43+
},
44+
45+
getConnectionReadyState: (
46+
_actor: AnyActorInstance,
47+
_conn: AnyConn,
48+
): DriverReadyState | undefined => {
49+
return websocket.readyState;
50+
},
51+
};
52+
}

rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ enum CanSleep {
4141
NotReady,
4242
ActiveConns,
4343
ActiveHonoHttpRequests,
44-
ActiveRawWebSockets,
4544
}
4645

4746
/** Actor type alias with all `any` types. Used for `extends` in classes referencing this actor. */
@@ -100,7 +99,6 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
10099

101100
// MARK: - HTTP/WebSocket Tracking
102101
#activeHonoHttpRequests = 0;
103-
#activeRawWebSockets = new Set<UniversalWebSocket>();
104102

105103
// MARK: - Deprecated (kept for compatibility)
106104
#schedule!: Schedule;
@@ -673,13 +671,9 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
673671
try {
674672
const stateBeforeHandler = this.#stateManager.persistChanged;
675673

676-
// Track active websocket
677-
this.#activeRawWebSockets.add(websocket);
674+
// Reset sleep timer when handling WebSocket
678675
this.#resetSleepTimer();
679676

680-
// Setup WebSocket event handlers (simplified for brevity)
681-
this.#setupWebSocketHandlers(websocket);
682-
683677
// Handle WebSocket
684678
await this.#config.onWebSocket(this.actorContext, websocket, opts);
685679

@@ -958,18 +952,6 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
958952
}
959953
}
960954

961-
#setupWebSocketHandlers(websocket: UniversalWebSocket) {
962-
// Simplified WebSocket handler setup
963-
// Full implementation would track hibernatable websockets
964-
const onSocketClosed = () => {
965-
this.#activeRawWebSockets.delete(websocket);
966-
this.#resetSleepTimer();
967-
};
968-
969-
websocket.addEventListener("close", onSocketClosed);
970-
websocket.addEventListener("error", onSocketClosed);
971-
}
972-
973955
#resetSleepTimer() {
974956
if (this.#config.options.noSleep || !this.#sleepingSupported) return;
975957
if (this.#stopCalled) return;
@@ -1001,8 +983,6 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
1001983
if (!this.#ready) return CanSleep.NotReady;
1002984
if (this.#activeHonoHttpRequests > 0)
1003985
return CanSleep.ActiveHonoHttpRequests;
1004-
if (this.#activeRawWebSockets.size > 0)
1005-
return CanSleep.ActiveRawWebSockets;
1006986

1007987
for (const _conn of this.#connectionManager.connections.values()) {
1008988
return CanSleep.ActiveConns;

rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import {
3636
promiseWithResolvers,
3737
} from "@/utils";
3838
import { createHttpSocket } from "./conn/drivers/http";
39+
import { createRawWebSocketSocket } from "./conn/drivers/raw-websocket";
3940
import { createWebSocketSocket } from "./conn/drivers/websocket";
4041
import type { ActorDriver } from "./driver";
4142
import { loggerWithoutContext } from "./log";
@@ -383,12 +384,19 @@ export async function handleRawWebSocketHandler(
383384
): Promise<UpgradeWebSocketArgs> {
384385
const actor = await actorDriver.loadActor(actorId);
385386

387+
// Promise used to wait for the websocket close in `disconnect`
388+
const closePromiseResolvers = promiseWithResolvers<void>();
389+
390+
// Track connection outside of scope for cleanup
391+
let createdConn: AnyConn | undefined;
392+
386393
// Return WebSocket event handlers
387394
return {
388-
onOpen: (evt: any, ws: any) => {
395+
onOpen: async (evt: any, ws: any) => {
389396
// Extract rivetRequestId provided by engine runner
390397
const rivetRequestId = evt?.rivetRequestId;
391398
const isHibernatable =
399+
!!rivetRequestId &&
392400
actor[
393401
ACTOR_INSTANCE_PERSIST_SYMBOL
394402
].hibernatableConns.findIndex((conn) =>
@@ -424,10 +432,36 @@ export async function handleRawWebSocketHandler(
424432
toUrl: newRequest.url,
425433
});
426434

427-
// Call the actor's onWebSocket handler with the adapted WebSocket
428-
actor.handleWebSocket(adapter, {
429-
request: newRequest,
430-
});
435+
try {
436+
// Create connection using actor.createConn - this handles deduplication for hibernatable connections
437+
const requestId = rivetRequestId
438+
? String(rivetRequestId)
439+
: crypto.randomUUID();
440+
const conn = await actor.createConn(
441+
createRawWebSocketSocket(
442+
requestId,
443+
rivetRequestId,
444+
isHibernatable,
445+
adapter,
446+
closePromiseResolvers.promise,
447+
),
448+
{}, // No parameters for raw WebSocket
449+
newRequest,
450+
);
451+
452+
createdConn = conn;
453+
454+
// Call the actor's onWebSocket handler with the adapted WebSocket
455+
actor.handleWebSocket(adapter, {
456+
request: newRequest,
457+
});
458+
} catch (error) {
459+
actor.rLog.error({
460+
msg: "failed to create raw WebSocket connection",
461+
error: String(error),
462+
});
463+
ws.close(1011, "Failed to create connection");
464+
}
431465
},
432466
onMessage: (event: any, ws: any) => {
433467
// Find the adapter for this WebSocket
@@ -442,6 +476,15 @@ export async function handleRawWebSocketHandler(
442476
if (adapter) {
443477
adapter._handleClose(evt?.code || 1006, evt?.reason || "");
444478
}
479+
480+
// Resolve the close promise
481+
closePromiseResolvers.resolve();
482+
483+
// Clean up the connection
484+
if (createdConn) {
485+
const wasClean = evt?.wasClean || evt?.code === 1000;
486+
actor.connDisconnected(createdConn, wasClean);
487+
}
445488
},
446489
onError: (error: any, ws: any) => {
447490
// Find the adapter for this WebSocket

0 commit comments

Comments
 (0)