Skip to content

Commit 5620d29

Browse files
committed
refactor: wrap setRequestHandler for high-level servers
1 parent 98729f8 commit 5620d29

File tree

4 files changed

+572
-158
lines changed

4 files changed

+572
-158
lines changed

src/modules/tracingV2.ts

Lines changed: 135 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,9 @@ function addMCPcatToolsToServer(server: HighLevelMCPServerLike): void {
233233
function addTracingToToolCallback(
234234
tool: RegisteredTool,
235235
toolName: string,
236-
server: HighLevelMCPServerLike,
236+
_server: HighLevelMCPServerLike,
237237
): RegisteredTool {
238238
const originalCallback = tool.callback;
239-
const lowLevelServer = server.server as MCPServerLike;
240239

241240
// Check if this callback has already been wrapped
242241
if (wrappedCallbacks.has(originalCallback)) {
@@ -277,156 +276,168 @@ function addTracingToToolCallback(
277276
return args;
278277
};
279278

280-
try {
281-
const data = getServerTrackingData(lowLevelServer);
282-
if (!data) {
283-
writeToLog(
284-
"Warning: MCPCat is unable to find server tracking data. Please ensure you have called track(server, options) before using tool calls.",
285-
);
279+
// Remove context from args before calling original callback
280+
// BUT keep it for get_more_tools since it's a required parameter
281+
const cleanedArgs =
282+
toolName === "get_more_tools" ? args : removeContextFromArgs(args);
283+
284+
// Call original callback with cleaned args
285+
if (cleanedArgs === undefined) {
286+
return await (
287+
originalCallback as (
288+
extra: CompatibleRequestHandlerExtra,
289+
) => Promise<CallToolResult>
290+
)(extra);
291+
} else {
292+
return await (
293+
originalCallback as (
294+
args: any,
295+
extra: CompatibleRequestHandlerExtra,
296+
) => Promise<CallToolResult>
297+
)(cleanedArgs, extra);
298+
}
299+
};
286300

287-
// Remove context from args before calling original callback
288-
// BUT keep it for get_more_tools since it's a required parameter
289-
const cleanedArgs =
290-
toolName === "get_more_tools" ? args : removeContextFromArgs(args);
291-
292-
// Call with original params
293-
return await (cleanedArgs === undefined
294-
? (
295-
originalCallback as (
296-
extra: CompatibleRequestHandlerExtra,
297-
) => Promise<CallToolResult>
298-
)(extra)
299-
: (
300-
originalCallback as (
301-
args: any,
302-
extra: CompatibleRequestHandlerExtra,
303-
) => Promise<CallToolResult>
304-
)(cleanedArgs, extra));
305-
}
301+
// Mark the original callback as wrapped
302+
wrappedCallbacks.set(originalCallback, true);
306303

307-
const sessionId = getServerSessionId(lowLevelServer, extra);
304+
// Mark the wrapped callback as well (in case it gets re-wrapped)
305+
wrappedCallbacks.set(wrappedCallback, true);
308306

309-
// Create a request-like object for compatibility with existing code
310-
const request = {
311-
params: {
312-
name: toolName,
313-
arguments: args,
314-
},
315-
};
307+
// Create a new tool object with the wrapped callback
308+
const wrappedTool = {
309+
...tool,
310+
callback: wrappedCallback as RegisteredTool["callback"],
311+
};
316312

317-
let event: UnredactedEvent = {
318-
sessionId: sessionId,
319-
resourceName: toolName,
320-
parameters: {
321-
request: request,
322-
extra: extra,
323-
},
324-
eventType: PublishEventRequestEventTypeEnum.mcpToolsCall,
325-
timestamp: new Date(),
326-
redactionFn: data.options.redactSensitiveInformation,
327-
};
313+
// Mark the tool as processed
314+
(wrappedTool as any)[MCPCAT_PROCESSED] = true;
328315

329-
try {
330-
// Try to identify the session if identify function is provided
331-
await handleIdentify(lowLevelServer, data, request, extra);
316+
return wrappedTool;
317+
}
318+
319+
function setupToolsCallHandlerWrapping(server: HighLevelMCPServerLike): void {
320+
const lowLevelServer = server.server as MCPServerLike;
321+
322+
// Check if tools/call handler already exists
323+
const existingHandler = lowLevelServer._requestHandlers.get("tools/call");
324+
if (existingHandler) {
325+
const wrappedHandler = createToolsCallWrapper(
326+
existingHandler,
327+
lowLevelServer,
328+
);
329+
lowLevelServer._requestHandlers.set("tools/call", wrappedHandler);
330+
}
331+
332+
// Intercept future calls to setRequestHandler for tools registered after track()
333+
const originalSetRequestHandler =
334+
lowLevelServer.setRequestHandler.bind(lowLevelServer);
335+
336+
lowLevelServer.setRequestHandler = function (
337+
requestSchema: any,
338+
handler: any,
339+
) {
340+
const method = requestSchema?.shape?.method?.value;
341+
342+
// Only wrap tools/call handler
343+
if (method === "tools/call") {
344+
const wrappedHandler = createToolsCallWrapper(handler, lowLevelServer);
345+
return originalSetRequestHandler(requestSchema, wrappedHandler);
346+
}
347+
348+
// Pass through all other handlers unchanged
349+
return originalSetRequestHandler(requestSchema, handler);
350+
} as any;
351+
}
332352

333-
// Update event sessionId in case handleIdentify reconnected to a different session
353+
function createToolsCallWrapper(
354+
originalHandler: any,
355+
server: MCPServerLike,
356+
): any {
357+
return async (request: any, extra: any) => {
358+
const startTime = new Date();
359+
let shouldPublishEvent = false;
360+
let event: UnredactedEvent | null = null;
361+
362+
try {
363+
const data = getServerTrackingData(server);
364+
365+
if (!data) {
366+
writeToLog(
367+
"Warning: MCPCat is unable to find server tracking data. Please ensure you have called track(server, options) before using tool calls.",
368+
);
369+
} else {
370+
shouldPublishEvent = true;
371+
372+
const sessionId = getServerSessionId(server, extra);
373+
374+
event = {
375+
sessionId,
376+
resourceName: request.params?.name || "Unknown Tool",
377+
parameters: { request, extra },
378+
eventType: PublishEventRequestEventTypeEnum.mcpToolsCall,
379+
timestamp: startTime,
380+
redactionFn: data.options.redactSensitiveInformation,
381+
};
382+
383+
// Identify user session
384+
await handleIdentify(server, data, request, extra);
334385
event.sessionId = data.sessionId;
335386

336-
// Extract context for userIntent if present
337-
if (args && typeof args === "object" && "context" in args) {
338-
event.userIntent = args.context;
387+
// Extract context for userIntent
388+
if (
389+
data.options.enableToolCallContext &&
390+
request.params?.arguments?.context
391+
) {
392+
event.userIntent = request.params.arguments.context;
339393
}
394+
}
395+
} catch (error) {
396+
// If tracing setup fails, log it but continue with tool execution
397+
writeToLog(
398+
`Warning: MCPCat tracing failed for tool ${request.params?.name}, falling back to original handler - ${error}`,
399+
);
400+
}
340401

341-
// Remove context from args before calling original callback
342-
// BUT keep it for get_more_tools since it's a required parameter
343-
const cleanedArgs =
344-
toolName === "get_more_tools" ? args : removeContextFromArgs(args);
345-
346-
let result = await (cleanedArgs === undefined
347-
? (
348-
originalCallback as (
349-
extra: CompatibleRequestHandlerExtra,
350-
) => Promise<CallToolResult>
351-
)(extra)
352-
: (
353-
originalCallback as (
354-
args: any,
355-
extra: CompatibleRequestHandlerExtra,
356-
) => Promise<CallToolResult>
357-
)(cleanedArgs, extra));
358-
359-
// Check if the result indicates an error
402+
// Execute the tool (this should always happen, even if tracing setup failed)
403+
try {
404+
const result = await originalHandler(request, extra);
405+
406+
if (event && shouldPublishEvent) {
407+
// Check for execution errors (SDK converts them to CallToolResult)
360408
if (isToolResultError(result)) {
361409
event.isError = true;
362410
event.error = captureException(result);
363411
}
364412

365413
event.response = result;
366-
event.duration =
367-
(event.timestamp &&
368-
new Date().getTime() - event.timestamp.getTime()) ||
369-
undefined;
370-
publishEvent(lowLevelServer, event);
371-
return result;
372-
} catch (error) {
414+
event.duration = new Date().getTime() - startTime.getTime();
415+
publishEvent(server, event);
416+
}
417+
418+
return result;
419+
} catch (error) {
420+
// Validation errors, unknown tool, disabled tool
421+
if (event && shouldPublishEvent) {
373422
event.isError = true;
374423
event.error = captureException(error);
375-
event.duration =
376-
(event.timestamp &&
377-
new Date().getTime() - event.timestamp.getTime()) ||
378-
undefined;
379-
publishEvent(lowLevelServer, event);
380-
throw error;
424+
event.duration = new Date().getTime() - startTime.getTime();
425+
publishEvent(server, event);
381426
}
382-
} catch (error) {
383-
// If any error occurs in our tracing code, log it and call the original callback
384-
writeToLog(
385-
`Warning: MCPCat tracing failed for tool ${toolName}, falling back to original callback - ${error}`,
386-
);
387427

388-
// Remove context from args before calling original callback
389-
// BUT keep it for get_more_tools since it's a required parameter
390-
const cleanedArgs =
391-
toolName === "get_more_tools" ? args : removeContextFromArgs(args);
392-
393-
return await (cleanedArgs === undefined
394-
? (
395-
originalCallback as (
396-
extra: CompatibleRequestHandlerExtra,
397-
) => Promise<CallToolResult>
398-
)(extra)
399-
: (
400-
originalCallback as (
401-
args: any,
402-
extra: CompatibleRequestHandlerExtra,
403-
) => Promise<CallToolResult>
404-
)(cleanedArgs, extra));
428+
// Re-throw so Protocol converts to JSONRPC error response
429+
throw error;
405430
}
406431
};
407-
408-
// Mark the original callback as wrapped
409-
wrappedCallbacks.set(originalCallback, true);
410-
411-
// Mark the wrapped callback as well (in case it gets re-wrapped)
412-
wrappedCallbacks.set(wrappedCallback, true);
413-
414-
// Create a new tool object with the wrapped callback
415-
const wrappedTool = {
416-
...tool,
417-
callback: wrappedCallback as RegisteredTool["callback"],
418-
};
419-
420-
// Mark the tool as processed
421-
(wrappedTool as any)[MCPCAT_PROCESSED] = true;
422-
423-
return wrappedTool;
424432
}
425433

426434
export function setupTracking(server: HighLevelMCPServerLike): void {
427435
try {
428436
const mcpcatData = getServerTrackingData(server.server);
429437

438+
// Setup handler wrapping before any tools are registered
439+
setupToolsCallHandlerWrapping(server);
440+
430441
setupInitializeTracing(server);
431442
// Modify existing tools to include context parameters in their inputSchemas
432443
if (mcpcatData?.options.enableToolCallContext) {

src/tests/context-parameters.test.ts

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,8 @@ describe("Context Parameters", () => {
315315
const listEvent = events.find(
316316
(e) =>
317317
e.eventType === PublishEventRequestEventTypeEnum.mcpToolsCall &&
318-
e.resourceName === "list_todos",
318+
e.resourceName === "list_todos" &&
319+
!e.isError, // Find the successful event, not the validation errors
319320
);
320321

321322
expect(listEvent).toBeDefined();
@@ -396,5 +397,69 @@ describe("Context Parameters", () => {
396397
);
397398
});
398399
});
400+
401+
it("should remove context parameter before calling tool callback", async () => {
402+
// Variable to capture what arguments the tool callback actually receives
403+
let capturedCallbackArguments: any = null;
404+
405+
// Register a test tool that captures its arguments
406+
const { z } = await import("zod");
407+
server.tool(
408+
"test_context_removal",
409+
"Test tool that captures callback arguments",
410+
{
411+
testParam: z.string().describe("A test parameter"),
412+
},
413+
async (args: any) => {
414+
// Capture exactly what arguments this callback receives
415+
capturedCallbackArguments = { ...args };
416+
return {
417+
content: [
418+
{
419+
type: "text",
420+
text: "Arguments captured",
421+
},
422+
],
423+
};
424+
},
425+
);
426+
427+
// Enable tracking with context parameters
428+
await track(server, {
429+
projectId: "test-project",
430+
enableTracing: true,
431+
});
432+
433+
// Call the test tool WITH context parameter
434+
const result = await client.request(
435+
{
436+
method: "tools/call",
437+
params: {
438+
name: "test_context_removal",
439+
arguments: {
440+
testParam: "test-value",
441+
context: "This context should be removed before callback",
442+
},
443+
},
444+
},
445+
CallToolResultSchema,
446+
);
447+
448+
// Wait for processing
449+
await new Promise((resolve) => setTimeout(resolve, 100));
450+
451+
// The tool call should succeed (successful calls have undefined isError)
452+
expect(result).toBeDefined();
453+
expect(result.isError).not.toBe(true);
454+
455+
// Verify that the callback received the testParam
456+
expect(capturedCallbackArguments).not.toBeNull();
457+
expect(capturedCallbackArguments).toHaveProperty("testParam");
458+
expect(capturedCallbackArguments.testParam).toBe("test-value");
459+
460+
// This is the key assertion: context should NOT be in the arguments
461+
// that the tool callback received (it should have been removed by the wrapper)
462+
expect(capturedCallbackArguments).not.toHaveProperty("context");
463+
});
399464
});
400465
});

0 commit comments

Comments
 (0)