From 23f3c07773a22f2a4376c0f8e7f3737016a0c23c Mon Sep 17 00:00:00 2001 From: Markus Ecker Date: Tue, 23 Sep 2025 23:32:07 +0200 Subject: [PATCH 1/5] Add middleware API --- .../integrations/mastra/src/mastra.ts | 2 +- .../middleware-starter/src/index.ts | 2 +- .../integrations/vercel-ai-sdk/src/index.ts | 2 +- .../packages/client/src/agent/agent.ts | 46 +- .../__tests__/filter-tool-calls.test.ts | 531 ++++++++++++++++++ .../__tests__/function-middleware.test.ts | 294 ++++++++++ .../__tests__/middleware-live-events.test.ts | 323 +++++++++++ .../__tests__/middleware-usage-example.ts | 78 +++ .../middleware/__tests__/middleware.test.ts | 306 ++++++++++ .../src/middleware/filter-tool-calls.ts | 98 ++++ .../packages/client/src/middleware/index.ts | 3 + .../client/src/middleware/middleware.ts | 20 + 12 files changed, 1699 insertions(+), 6 deletions(-) create mode 100644 typescript-sdk/packages/client/src/middleware/__tests__/filter-tool-calls.test.ts create mode 100644 typescript-sdk/packages/client/src/middleware/__tests__/function-middleware.test.ts create mode 100644 typescript-sdk/packages/client/src/middleware/__tests__/middleware-live-events.test.ts create mode 100644 typescript-sdk/packages/client/src/middleware/__tests__/middleware-usage-example.ts create mode 100644 typescript-sdk/packages/client/src/middleware/__tests__/middleware.test.ts create mode 100644 typescript-sdk/packages/client/src/middleware/filter-tool-calls.ts create mode 100644 typescript-sdk/packages/client/src/middleware/index.ts create mode 100644 typescript-sdk/packages/client/src/middleware/middleware.ts diff --git a/typescript-sdk/integrations/mastra/src/mastra.ts b/typescript-sdk/integrations/mastra/src/mastra.ts index f8310767b..675c618d7 100644 --- a/typescript-sdk/integrations/mastra/src/mastra.ts +++ b/typescript-sdk/integrations/mastra/src/mastra.ts @@ -59,7 +59,7 @@ export class MastraAgent extends AbstractAgent { this.runtimeContext = runtimeContext ?? new RuntimeContext(); } - protected run(input: RunAgentInput): Observable { + public run(input: RunAgentInput): Observable { let messageId = randomUUID(); return new Observable((subscriber) => { diff --git a/typescript-sdk/integrations/middleware-starter/src/index.ts b/typescript-sdk/integrations/middleware-starter/src/index.ts index b4e43e696..b0d66539a 100644 --- a/typescript-sdk/integrations/middleware-starter/src/index.ts +++ b/typescript-sdk/integrations/middleware-starter/src/index.ts @@ -2,7 +2,7 @@ import { AbstractAgent, BaseEvent, EventType, RunAgentInput } from "@ag-ui/clien import { Observable } from "rxjs"; export class MiddlewareStarterAgent extends AbstractAgent { - protected run(input: RunAgentInput): Observable { + public run(input: RunAgentInput): Observable { const messageId = Date.now().toString(); return new Observable((observer) => { observer.next({ diff --git a/typescript-sdk/integrations/vercel-ai-sdk/src/index.ts b/typescript-sdk/integrations/vercel-ai-sdk/src/index.ts index b06664f34..9dc76a936 100644 --- a/typescript-sdk/integrations/vercel-ai-sdk/src/index.ts +++ b/typescript-sdk/integrations/vercel-ai-sdk/src/index.ts @@ -55,7 +55,7 @@ export class VercelAISDKAgent extends AbstractAgent { this.toolChoice = toolChoice ?? "auto"; } - protected run(input: RunAgentInput): Observable { + public run(input: RunAgentInput): Observable { const finalMessages: Message[] = input.messages; return new Observable((subscriber) => { diff --git a/typescript-sdk/packages/client/src/agent/agent.ts b/typescript-sdk/packages/client/src/agent/agent.ts index f596a625c..2a5cfbd75 100644 --- a/typescript-sdk/packages/client/src/agent/agent.ts +++ b/typescript-sdk/packages/client/src/agent/agent.ts @@ -13,6 +13,7 @@ import { LegacyRuntimeProtocolEvent } from "@/legacy/types"; import { lastValueFrom } from "rxjs"; import { transformChunks } from "@/chunks"; import { AgentStateMutation, AgentSubscriber, runSubscribersWithMutation } from "./subscriber"; +import { Middleware, MiddlewareFunction, FunctionMiddleware } from "@/middleware"; export interface RunAgentResult { result: any; @@ -27,6 +28,7 @@ export abstract class AbstractAgent { public state: State; public debug: boolean = false; public subscribers: AgentSubscriber[] = []; + private middlewares: Middleware[] = []; constructor({ agentId, @@ -53,7 +55,15 @@ export abstract class AbstractAgent { }; } - protected abstract run(input: RunAgentInput): Observable; + public use(...middlewares: (Middleware | MiddlewareFunction)[]): this { + const normalizedMiddlewares = middlewares.map(m => + typeof m === 'function' ? new FunctionMiddleware(m) : m + ); + this.middlewares.push(...normalizedMiddlewares); + return this; + } + + public abstract run(input: RunAgentInput): Observable; public async runAgent( parameters?: RunAgentParameters, @@ -77,7 +87,21 @@ export abstract class AbstractAgent { await this.onInitialize(input, subscribers); const pipeline = pipe( - () => this.run(input), + () => { + // Build middleware chain using reduceRight + if (this.middlewares.length === 0) { + return this.run(input); + } + + const chainedAgent = this.middlewares.reduceRight( + (nextAgent: AbstractAgent, middleware) => ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent) + } as AbstractAgent), + this // Original agent is the final 'next' + ); + + return chainedAgent.run(input); + }, transformChunks(this.debug), verifyEvents(this.debug), (source$) => this.apply(input, source$, subscribers), @@ -416,7 +440,23 @@ export abstract class AbstractAgent { this.agentId = this.agentId ?? uuidv4(); const input = this.prepareRunAgentInput(config); - return this.run(input).pipe( + // Build middleware chain for legacy bridge + const runObservable = (() => { + if (this.middlewares.length === 0) { + return this.run(input); + } + + const chainedAgent = this.middlewares.reduceRight( + (nextAgent: AbstractAgent, middleware) => ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent) + } as AbstractAgent), + this + ); + + return chainedAgent.run(input); + })(); + + return runObservable.pipe( transformChunks(this.debug), verifyEvents(this.debug), convertToLegacyEvents(this.threadId, input.runId, this.agentId), diff --git a/typescript-sdk/packages/client/src/middleware/__tests__/filter-tool-calls.test.ts b/typescript-sdk/packages/client/src/middleware/__tests__/filter-tool-calls.test.ts new file mode 100644 index 000000000..45da1f806 --- /dev/null +++ b/typescript-sdk/packages/client/src/middleware/__tests__/filter-tool-calls.test.ts @@ -0,0 +1,531 @@ +import { AbstractAgent } from "@/agent"; +import { FilterToolCallsMiddleware } from "@/middleware/filter-tool-calls"; +import { Middleware } from "@/middleware"; +import { + BaseEvent, + EventType, + RunAgentInput, + ToolCallStartEvent, + ToolCallArgsEvent, + ToolCallEndEvent, + ToolCallResultEvent, + ToolCallChunkEvent +} from "@ag-ui/core"; +import { Observable } from "rxjs"; + +describe("FilterToolCallsMiddleware", () => { + class ToolCallingAgent extends AbstractAgent { + public run(input: RunAgentInput): Observable { + return new Observable((subscriber) => { + // Emit RUN_STARTED + subscriber.next({ + type: EventType.RUN_STARTED, + threadId: input.threadId, + runId: input.runId, + }); + + // Emit first tool call (calculator) + const toolCall1Id = "tool-call-1"; + subscriber.next({ + type: EventType.TOOL_CALL_START, + toolCallId: toolCall1Id, + toolCallName: "calculator", + parentMessageId: "message-1", + } as ToolCallStartEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_ARGS, + toolCallId: toolCall1Id, + delta: '{"operation": "add", "a": 5, "b": 3}', + } as ToolCallArgsEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_END, + toolCallId: toolCall1Id, + } as ToolCallEndEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_RESULT, + messageId: "tool-message-1", + toolCallId: toolCall1Id, + content: "8", + } as ToolCallResultEvent); + + // Emit second tool call (weather) + const toolCall2Id = "tool-call-2"; + subscriber.next({ + type: EventType.TOOL_CALL_START, + toolCallId: toolCall2Id, + toolCallName: "weather", + parentMessageId: "message-2", + } as ToolCallStartEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_ARGS, + toolCallId: toolCall2Id, + delta: '{"city": "New York"}', + } as ToolCallArgsEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_END, + toolCallId: toolCall2Id, + } as ToolCallEndEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_RESULT, + messageId: "tool-message-2", + toolCallId: toolCall2Id, + content: "Sunny, 72°F", + } as ToolCallResultEvent); + + // Emit third tool call (search) + const toolCall3Id = "tool-call-3"; + subscriber.next({ + type: EventType.TOOL_CALL_START, + toolCallId: toolCall3Id, + toolCallName: "search", + parentMessageId: "message-3", + } as ToolCallStartEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_ARGS, + toolCallId: toolCall3Id, + delta: '{"query": "TypeScript middleware"}', + } as ToolCallArgsEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_END, + toolCallId: toolCall3Id, + } as ToolCallEndEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_RESULT, + messageId: "tool-message-3", + toolCallId: toolCall3Id, + content: "Results found...", + } as ToolCallResultEvent); + + // Emit RUN_FINISHED + subscriber.next({ + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + }); + + subscriber.complete(); + }); + } + } + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + it("should filter out disallowed tool calls", async () => { + const agent = new ToolCallingAgent(); + const middleware = new FilterToolCallsMiddleware({ + disallowedToolCalls: ["calculator", "search"], + }); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + middleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + // Should have RUN_STARTED, weather tool events (4), and RUN_FINISHED + expect(events.length).toBe(6); + + // Check that we have RUN_STARTED + expect(events[0].type).toBe(EventType.RUN_STARTED); + + // Check that only weather tool calls are present + const toolCallStarts = events.filter(e => e.type === EventType.TOOL_CALL_START) as ToolCallStartEvent[]; + expect(toolCallStarts.length).toBe(1); + expect(toolCallStarts[0].toolCallName).toBe("weather"); + + // Check that calculator and search are filtered out + const allToolNames = toolCallStarts.map(e => e.toolCallName); + expect(allToolNames).not.toContain("calculator"); + expect(allToolNames).not.toContain("search"); + + // Check that we have RUN_FINISHED + expect(events[events.length - 1].type).toBe(EventType.RUN_FINISHED); + }); + + it("should only allow specified tool calls", async () => { + const agent = new ToolCallingAgent(); + const middleware = new FilterToolCallsMiddleware({ + allowedToolCalls: ["weather"], + }); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + middleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + // Should have RUN_STARTED, weather tool events (4), and RUN_FINISHED + expect(events.length).toBe(6); + + // Check that only weather tool calls are present + const toolCallStarts = events.filter(e => e.type === EventType.TOOL_CALL_START) as ToolCallStartEvent[]; + expect(toolCallStarts.length).toBe(1); + expect(toolCallStarts[0].toolCallName).toBe("weather"); + + // Verify all weather-related events are present + const weatherToolCallId = toolCallStarts[0].toolCallId; + const weatherArgs = events.find(e => + e.type === EventType.TOOL_CALL_ARGS && + (e as ToolCallArgsEvent).toolCallId === weatherToolCallId + ); + expect(weatherArgs).toBeDefined(); + + const weatherEnd = events.find(e => + e.type === EventType.TOOL_CALL_END && + (e as ToolCallEndEvent).toolCallId === weatherToolCallId + ); + expect(weatherEnd).toBeDefined(); + + const weatherResult = events.find(e => + e.type === EventType.TOOL_CALL_RESULT && + (e as ToolCallResultEvent).toolCallId === weatherToolCallId + ); + expect(weatherResult).toBeDefined(); + }); + + it("should filter all events for a blocked tool call", async () => { + const agent = new ToolCallingAgent(); + const middleware = new FilterToolCallsMiddleware({ + disallowedToolCalls: ["calculator"], + }); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + middleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + // Should not have any calculator-related events + const calculatorEvents = events.filter(e => { + if (e.type === EventType.TOOL_CALL_START) { + return (e as ToolCallStartEvent).toolCallName === "calculator"; + } + if (e.type === EventType.TOOL_CALL_ARGS || + e.type === EventType.TOOL_CALL_END || + e.type === EventType.TOOL_CALL_RESULT) { + return (e as any).toolCallId === "tool-call-1"; + } + return false; + }); + + expect(calculatorEvents.length).toBe(0); + + // But should have weather and search events + const weatherStart = events.find(e => + e.type === EventType.TOOL_CALL_START && + (e as ToolCallStartEvent).toolCallName === "weather" + ); + expect(weatherStart).toBeDefined(); + + const searchStart = events.find(e => + e.type === EventType.TOOL_CALL_START && + (e as ToolCallStartEvent).toolCallName === "search" + ); + expect(searchStart).toBeDefined(); + }); + + it("should allow all tool calls when allowed list is empty", async () => { + const agent = new ToolCallingAgent(); + const middleware = new FilterToolCallsMiddleware({ + allowedToolCalls: [], + }); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + middleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + // No tool calls should pass through with empty allowed list + const toolCallStarts = events.filter(e => e.type === EventType.TOOL_CALL_START); + expect(toolCallStarts.length).toBe(0); + }); + + it("should allow all tool calls when disallowed list is empty", async () => { + const agent = new ToolCallingAgent(); + const middleware = new FilterToolCallsMiddleware({ + disallowedToolCalls: [], + }); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + middleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + // All tool calls should pass through with empty disallowed list + const toolCallStarts = events.filter(e => e.type === EventType.TOOL_CALL_START); + expect(toolCallStarts.length).toBe(3); + }); + + it("should throw error when both allowed and disallowed are specified", () => { + expect(() => { + new FilterToolCallsMiddleware({ + allowedToolCalls: ["calculator"], + disallowedToolCalls: ["weather"], + } as any); + }).toThrow("Cannot specify both allowedToolCalls and disallowedToolCalls"); + }); + + it("should throw error when neither allowed nor disallowed are specified", () => { + expect(() => { + new FilterToolCallsMiddleware({} as any); + }).toThrow("Must specify either allowedToolCalls or disallowedToolCalls"); + }); + + // Test removed - middleware now requires next parameter + + it("should work in a middleware chain", async () => { + const agent = new ToolCallingAgent(); + + // First middleware filters out calculator + const filterMiddleware = new FilterToolCallsMiddleware({ + disallowedToolCalls: ["calculator"], + }); + + // Second middleware could be any other middleware + class EventCounterMiddleware extends Middleware { + public eventCount = 0; + + public run(input: RunAgentInput, next: AbstractAgent): Observable { + + return new Observable((subscriber) => { + const subscription = next.run(input).subscribe({ + next: (event) => { + this.eventCount++; + subscriber.next(event); + }, + error: (err) => subscriber.error(err), + complete: () => subscriber.complete(), + }); + + return () => subscription.unsubscribe(); + }); + } + } + + const counterMiddleware = new EventCounterMiddleware(); + + agent.use(counterMiddleware, filterMiddleware); + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + const chainedAgent = agent["middlewares"].reduceRight( + (nextAgent: AbstractAgent, middleware) => ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent), + } as AbstractAgent), + agent + ); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + chainedAgent.run(input).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + // Counter should have seen the filtered events + expect(counterMiddleware.eventCount).toBe(10); // 2 run events + 8 tool events (2 tools * 4 events) + + // Final output should not have calculator events + const toolCallStarts = events.filter(e => e.type === EventType.TOOL_CALL_START) as ToolCallStartEvent[]; + expect(toolCallStarts.map(e => e.toolCallName)).toEqual(["weather", "search"]); + }); + + it("should filter TOOL_CALL_CHUNK events that are disallowed", async () => { + class ChunkEmittingAgent extends AbstractAgent { + public run(input: RunAgentInput): Observable { + return new Observable((subscriber) => { + // Emit RUN_STARTED + subscriber.next({ + type: EventType.RUN_STARTED, + threadId: input.threadId, + runId: input.runId, + }); + + // Emit calculator tool as chunks (should be filtered) + subscriber.next({ + type: EventType.TOOL_CALL_CHUNK, + toolCallId: "tool-1", + toolCallName: "calculator", + parentMessageId: "msg-1", + delta: '{"operation": "add",', + } as ToolCallChunkEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_CHUNK, + toolCallId: "tool-1", + delta: '"a": 5, "b": 3}', + } as ToolCallChunkEvent); + + // Emit weather tool as chunks (should pass through) + subscriber.next({ + type: EventType.TOOL_CALL_CHUNK, + toolCallId: "tool-2", + toolCallName: "weather", + parentMessageId: "msg-2", + delta: '{"city": "Paris"}', + } as ToolCallChunkEvent); + + // Emit a close event to trigger chunk transformation + subscriber.next({ + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + }); + + subscriber.complete(); + }); + } + } + + const agent = new ChunkEmittingAgent(); + const middleware = new FilterToolCallsMiddleware({ + disallowedToolCalls: ["calculator"], + }); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + middleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + // Should have RUN_STARTED, weather tool events (START, ARGS, END), and RUN_FINISHED + const toolCallStarts = events.filter(e => e.type === EventType.TOOL_CALL_START) as ToolCallStartEvent[]; + expect(toolCallStarts.length).toBe(1); + expect(toolCallStarts[0].toolCallName).toBe("weather"); + + // Calculator chunks should have been transformed and then filtered + const calculatorEvents = events.filter(e => { + if (e.type === EventType.TOOL_CALL_START) { + return (e as ToolCallStartEvent).toolCallName === "calculator"; + } + if (e.type === EventType.TOOL_CALL_ARGS || + e.type === EventType.TOOL_CALL_END) { + return (e as any).toolCallId === "tool-1"; + } + return false; + }); + expect(calculatorEvents.length).toBe(0); + + // No TOOL_CALL_CHUNK events should remain (all transformed) + const chunkEvents = events.filter(e => e.type === EventType.TOOL_CALL_CHUNK); + expect(chunkEvents.length).toBe(0); + }); + + it("should only allow specified tool calls from chunks", async () => { + class ChunkEmittingAgent extends AbstractAgent { + public run(input: RunAgentInput): Observable { + return new Observable((subscriber) => { + // Emit RUN_STARTED + subscriber.next({ + type: EventType.RUN_STARTED, + threadId: input.threadId, + runId: input.runId, + }); + + // Emit three different tools as chunks + subscriber.next({ + type: EventType.TOOL_CALL_CHUNK, + toolCallId: "tool-1", + toolCallName: "calculator", + parentMessageId: "msg-1", + delta: '{"test": "data"}', + } as ToolCallChunkEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_CHUNK, + toolCallId: "tool-2", + toolCallName: "weather", + parentMessageId: "msg-2", + delta: '{"city": "London"}', + } as ToolCallChunkEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_CHUNK, + toolCallId: "tool-3", + toolCallName: "search", + parentMessageId: "msg-3", + delta: '{"query": "test"}', + } as ToolCallChunkEvent); + + // Close event + subscriber.next({ + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + }); + + subscriber.complete(); + }); + } + } + + const agent = new ChunkEmittingAgent(); + const middleware = new FilterToolCallsMiddleware({ + allowedToolCalls: ["weather"], + }); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + middleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + // Should only have weather tool events + const toolCallStarts = events.filter(e => e.type === EventType.TOOL_CALL_START) as ToolCallStartEvent[]; + expect(toolCallStarts.length).toBe(1); + expect(toolCallStarts[0].toolCallName).toBe("weather"); + + // Verify weather tool has all its events + const weatherEvents = events.filter(e => { + if (e.type === EventType.TOOL_CALL_START || + e.type === EventType.TOOL_CALL_ARGS || + e.type === EventType.TOOL_CALL_END) { + return (e as any).toolCallId === "tool-2"; + } + return false; + }); + expect(weatherEvents.length).toBe(3); // START, ARGS, END + }); +}); \ No newline at end of file diff --git a/typescript-sdk/packages/client/src/middleware/__tests__/function-middleware.test.ts b/typescript-sdk/packages/client/src/middleware/__tests__/function-middleware.test.ts new file mode 100644 index 000000000..8be5f6298 --- /dev/null +++ b/typescript-sdk/packages/client/src/middleware/__tests__/function-middleware.test.ts @@ -0,0 +1,294 @@ +import { AbstractAgent } from "@/agent"; +import { MiddlewareFunction } from "@/middleware"; +import { BaseEvent, EventType, RunAgentInput, TextMessageChunkEvent } from "@ag-ui/core"; +import { Observable } from "rxjs"; +import { map, tap } from "rxjs/operators"; + +describe("Function-based Middleware", () => { + class SimpleAgent extends AbstractAgent { + public run(input: RunAgentInput): Observable { + return new Observable((subscriber) => { + subscriber.next({ + type: EventType.RUN_STARTED, + threadId: input.threadId, + runId: input.runId, + }); + + subscriber.next({ + type: EventType.TEXT_MESSAGE_CHUNK, + role: "assistant", + messageId: "msg-1", + delta: "Hello from agent", + } as TextMessageChunkEvent); + + subscriber.next({ + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + }); + + subscriber.complete(); + }); + } + } + + it("should accept a function as middleware", async () => { + const agent = new SimpleAgent(); + + // Define a simple function middleware that adds a prefix to text chunks + const prefixMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + map((event) => { + if (event.type === EventType.TEXT_MESSAGE_CHUNK) { + const textEvent = event as TextMessageChunkEvent; + return { + ...textEvent, + delta: `[PREFIX] ${textEvent.delta}`, + } as TextMessageChunkEvent; + } + return event; + }) + ); + }; + + agent.use(prefixMiddleware); + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + const events: BaseEvent[] = []; + const chainedAgent = agent["middlewares"].reduceRight( + (nextAgent: AbstractAgent, middleware) => ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent), + } as AbstractAgent), + agent + ); + + await new Promise((resolve) => { + chainedAgent.run(input).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + expect(events.length).toBe(3); + expect(events[0].type).toBe(EventType.RUN_STARTED); + + const textEvent = events[1] as TextMessageChunkEvent; + expect(textEvent.type).toBe(EventType.TEXT_MESSAGE_CHUNK); + expect(textEvent.delta).toBe("[PREFIX] Hello from agent"); + + expect(events[2].type).toBe(EventType.RUN_FINISHED); + }); + + it("should chain multiple function middlewares", async () => { + const agent = new SimpleAgent(); + + // First middleware adds a prefix + const prefixMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + map((event) => { + if (event.type === EventType.TEXT_MESSAGE_CHUNK) { + const textEvent = event as TextMessageChunkEvent; + return { + ...textEvent, + delta: `[PREFIX] ${textEvent.delta}`, + } as TextMessageChunkEvent; + } + return event; + }) + ); + }; + + // Second middleware adds a suffix + const suffixMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + map((event) => { + if (event.type === EventType.TEXT_MESSAGE_CHUNK) { + const textEvent = event as TextMessageChunkEvent; + return { + ...textEvent, + delta: `${textEvent.delta} [SUFFIX]`, + } as TextMessageChunkEvent; + } + return event; + }) + ); + }; + + agent.use(prefixMiddleware, suffixMiddleware); + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + const events: BaseEvent[] = []; + const chainedAgent = agent["middlewares"].reduceRight( + (nextAgent: AbstractAgent, middleware) => ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent), + } as AbstractAgent), + agent + ); + + await new Promise((resolve) => { + chainedAgent.run(input).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + const textEvent = events[1] as TextMessageChunkEvent; + expect(textEvent.delta).toBe("[PREFIX] Hello from agent [SUFFIX]"); + }); + + it("should mix function and class middleware", async () => { + const agent = new SimpleAgent(); + + // Function middleware that adds a counter + let counter = 0; + const countingMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + tap(() => counter++) + ); + }; + + // Class middleware that adds a prefix + class PrefixMiddleware extends Middleware { + constructor(private prefix: string) { + super(); + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + return next.run(input).pipe( + map((event) => { + if (event.type === EventType.TEXT_MESSAGE_CHUNK) { + const textEvent = event as TextMessageChunkEvent; + return { + ...textEvent, + delta: `${this.prefix} ${textEvent.delta}`, + } as TextMessageChunkEvent; + } + return event; + }) + ); + } + } + + const prefixMiddleware = new PrefixMiddleware("[CLASS]"); + + // Mix both types + agent.use(countingMiddleware, prefixMiddleware); + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + const events: BaseEvent[] = []; + const chainedAgent = agent["middlewares"].reduceRight( + (nextAgent: AbstractAgent, middleware) => ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent), + } as AbstractAgent), + agent + ); + + await new Promise((resolve) => { + chainedAgent.run(input).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + // Check that counting middleware ran + expect(counter).toBe(3); // 3 events total + + // Check that class middleware transformed the text + const textEvent = events[1] as TextMessageChunkEvent; + expect(textEvent.delta).toBe("[CLASS] Hello from agent"); + }); + + it("should handle event transformation in function middleware", async () => { + const agent = new SimpleAgent(); + + // Function middleware that counts events + let eventCount = 0; + const countingMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + tap(() => { + eventCount++; + }) + ); + }; + + // Function middleware that filters events + const filterMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + map((event) => { + // Add metadata to all events + return { + ...event, + metadata: { processed: true } + } as BaseEvent & { metadata: { processed: boolean } }; + }) + ); + }; + + agent.use(countingMiddleware, filterMiddleware); + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + const events: BaseEvent[] = []; + + const chainedAgent = agent["middlewares"].reduceRight( + (nextAgent: AbstractAgent, middleware) => ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent), + } as AbstractAgent), + agent + ); + + await new Promise((resolve) => { + chainedAgent.run(input).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + // Check that counting middleware counted all events + expect(eventCount).toBe(3); + + // Check that filter middleware added metadata + expect(events.length).toBe(3); + events.forEach(event => { + expect((event as any).metadata?.processed).toBe(true); + }); + }); +}); + +// Import Middleware here to avoid circular dependency issues +import { Middleware } from "@/middleware"; \ No newline at end of file diff --git a/typescript-sdk/packages/client/src/middleware/__tests__/middleware-live-events.test.ts b/typescript-sdk/packages/client/src/middleware/__tests__/middleware-live-events.test.ts new file mode 100644 index 000000000..72c70a104 --- /dev/null +++ b/typescript-sdk/packages/client/src/middleware/__tests__/middleware-live-events.test.ts @@ -0,0 +1,323 @@ +import { AbstractAgent } from "@/agent"; +import { Middleware } from "@/middleware"; +import { BaseEvent, EventType, RunAgentInput, TextMessageChunkEvent } from "@ag-ui/core"; +import { Observable, interval } from "rxjs"; +import { map, take, concatMap } from "rxjs/operators"; + +describe("Middleware Live Event Streaming", () => { + class StreamingAgent extends AbstractAgent { + public run(input: RunAgentInput): Observable { + return new Observable((subscriber) => { + // Emit RUN_STARTED immediately + subscriber.next({ + type: EventType.RUN_STARTED, + threadId: input.threadId, + runId: input.runId, + }); + + // Simulate streaming text chunks over time + const streamingText = ["Hello", " ", "world", "!"]; + let index = 0; + + const intervalSub = interval(50).pipe(take(streamingText.length)).subscribe({ + next: () => { + const chunk: TextMessageChunkEvent = { + type: EventType.TEXT_MESSAGE_CHUNK, + role: "assistant", + messageId: "streaming-message", + delta: streamingText[index++], + }; + subscriber.next(chunk); + }, + complete: () => { + // Emit RUN_FINISHED after all chunks + subscriber.next({ + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + }); + subscriber.complete(); + } + }); + + return () => intervalSub.unsubscribe(); + }); + } + } + + class TimestampMiddleware extends Middleware { + public timestamps: Map = new Map(); + + public run(input: RunAgentInput, next: AbstractAgent): Observable { + + return next.run(input).pipe( + map((event) => { + const timestamp = Date.now(); + this.timestamps.set(event.type, timestamp); + return { + ...event, + timestamp, + } as BaseEvent & { timestamp: number }; + }) + ); + } + } + + class BufferingMiddleware extends Middleware { + private buffer: string = ""; + + public run(input: RunAgentInput, next: AbstractAgent): Observable { + + return new Observable((subscriber) => { + const subscription = next.run(input).subscribe({ + next: (event) => { + if (event.type === EventType.TEXT_MESSAGE_CHUNK) { + const chunkEvent = event as TextMessageChunkEvent; + this.buffer += chunkEvent.delta; + + // Only emit when we have a complete word or punctuation + if (chunkEvent.delta === " " || chunkEvent.delta === "!") { + const bufferedEvent: TextMessageChunkEvent = { + ...chunkEvent, + delta: this.buffer, + }; + this.buffer = ""; + subscriber.next(bufferedEvent); + } + } else { + // Pass through non-text events immediately + subscriber.next(event); + } + }, + error: (err) => subscriber.error(err), + complete: () => subscriber.complete(), + }); + + return () => subscription.unsubscribe(); + }); + } + } + + class DelayMiddleware extends Middleware { + constructor(private delayMs: number) { + super(); + } + + public run(input: RunAgentInput, next: AbstractAgent): Observable { + + return next.run(input).pipe( + concatMap((event) => + new Observable((subscriber) => { + setTimeout(() => { + subscriber.next(event); + subscriber.complete(); + }, this.delayMs); + }) + ) + ); + } + } + + it("should stream events live through middleware chain", async () => { + const agent = new StreamingAgent(); + const timestampMiddleware = new TimestampMiddleware(); + + agent.use(timestampMiddleware); + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + const events: BaseEvent[] = []; + const eventTimes: number[] = []; + + const chainedAgent = agent["middlewares"].reduceRight( + (nextAgent: AbstractAgent, middleware) => ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent), + } as AbstractAgent), + agent + ); + + const startTime = Date.now(); + + await new Promise((resolve) => { + chainedAgent.run(input).subscribe({ + next: (event) => { + events.push(event); + eventTimes.push(Date.now() - startTime); + }, + complete: () => resolve(), + }); + }); + + // Should receive events over time, not all at once + expect(events.length).toBe(6); // RUN_STARTED, 4 chunks, RUN_FINISHED + expect(events[0].type).toBe(EventType.RUN_STARTED); + expect(events[5].type).toBe(EventType.RUN_FINISHED); + + // Check that chunks arrived over time (with ~50ms intervals) + expect(eventTimes[2] - eventTimes[1]).toBeGreaterThanOrEqual(40); + expect(eventTimes[3] - eventTimes[2]).toBeGreaterThanOrEqual(40); + }); + + it("should buffer and transform events in real-time", async () => { + const agent = new StreamingAgent(); + const bufferMiddleware = new BufferingMiddleware(); + + agent.use(bufferMiddleware); + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + const events: BaseEvent[] = []; + + await new Promise((resolve) => { + bufferMiddleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + // BufferingMiddleware should have combined chunks + const textEvents = events.filter(e => e.type === EventType.TEXT_MESSAGE_CHUNK); + expect(textEvents.length).toBe(2); // "Hello " and "world!" + expect((textEvents[0] as TextMessageChunkEvent).delta).toBe("Hello "); + expect((textEvents[1] as TextMessageChunkEvent).delta).toBe("world!"); + }); + + it("should process events through multiple middleware in order", async () => { + const agent = new StreamingAgent(); + const timestampMiddleware = new TimestampMiddleware(); + const delayMiddleware = new DelayMiddleware(10); + + agent.use(timestampMiddleware, delayMiddleware); + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + const events: BaseEvent[] = []; + const startTime = Date.now(); + + const chainedAgent = agent["middlewares"].reduceRight( + (nextAgent: AbstractAgent, middleware) => ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent), + } as AbstractAgent), + agent + ); + + await new Promise((resolve) => { + chainedAgent.run(input).subscribe({ + next: (event) => { + events.push(event); + }, + complete: () => resolve(), + }); + }); + + const totalTime = Date.now() - startTime; + + // Each event should have a timestamp from the first middleware + events.forEach(event => { + expect((event as any).timestamp).toBeDefined(); + }); + + // The delay middleware should have added delay to each event + expect(totalTime).toBeGreaterThanOrEqual(60); // 6 events * 10ms delay + }); + + it("should handle backpressure correctly", async () => { + class FastProducerAgent extends AbstractAgent { + public run(input: RunAgentInput): Observable { + return new Observable((subscriber) => { + subscriber.next({ + type: EventType.RUN_STARTED, + threadId: input.threadId, + runId: input.runId, + }); + + // Emit many events quickly + for (let i = 0; i < 100; i++) { + subscriber.next({ + type: EventType.TEXT_MESSAGE_CHUNK, + role: "assistant", + messageId: "fast-message", + delta: i.toString(), + } as TextMessageChunkEvent); + } + + subscriber.next({ + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + }); + + subscriber.complete(); + }); + } + } + + class SlowConsumerMiddleware extends Middleware { + public processedCount = 0; + + public run(input: RunAgentInput, next: AbstractAgent): Observable { + + return next.run(input).pipe( + concatMap((event) => + new Observable((subscriber) => { + // Simulate slow processing + setTimeout(() => { + this.processedCount++; + subscriber.next(event); + subscriber.complete(); + }, 1); + }) + ) + ); + } + } + + const agent = new FastProducerAgent(); + const slowMiddleware = new SlowConsumerMiddleware(); + + agent.use(slowMiddleware); + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + await new Promise((resolve) => { + slowMiddleware.run(input, agent).subscribe({ + complete: () => resolve(), + }); + }); + + // All events should be processed despite the speed difference + expect(slowMiddleware.processedCount).toBe(102); // RUN_STARTED + 100 chunks + RUN_FINISHED + }); +}); \ No newline at end of file diff --git a/typescript-sdk/packages/client/src/middleware/__tests__/middleware-usage-example.ts b/typescript-sdk/packages/client/src/middleware/__tests__/middleware-usage-example.ts new file mode 100644 index 000000000..5b7b577ab --- /dev/null +++ b/typescript-sdk/packages/client/src/middleware/__tests__/middleware-usage-example.ts @@ -0,0 +1,78 @@ +/** + * Example usage of middleware with AbstractAgent + * This file demonstrates both class-based and function-based middleware + */ + +import { AbstractAgent } from "@/agent"; +import { Middleware, MiddlewareFunction } from "@/middleware"; +import { RunAgentInput, BaseEvent, EventType } from "@ag-ui/core"; +import { Observable } from "rxjs"; +import { map, tap } from "rxjs/operators"; + +// Example agent +class MyAgent extends AbstractAgent { + run(input: RunAgentInput): Observable { + return new Observable(subscriber => { + subscriber.next({ type: EventType.RUN_STARTED, threadId: input.threadId, runId: input.runId }); + // ... agent logic ... + subscriber.next({ type: EventType.RUN_FINISHED, threadId: input.threadId, runId: input.runId }); + subscriber.complete(); + }); + } +} + +// 1. Function-based middleware (simple and concise) +const loggingMiddleware: MiddlewareFunction = (input, next) => { + console.log('Request:', input); + return next.run(input).pipe( + tap(event => console.log('Event:', event)) + ); +}; + +// 2. Another function middleware +const timingMiddleware: MiddlewareFunction = (input, next) => { + const start = Date.now(); + return next.run(input).pipe( + tap({ + complete: () => console.log(`Execution took ${Date.now() - start}ms`) + }) + ); +}; + +// 3. Class-based middleware (when you need state or complex logic) +class AuthMiddleware extends Middleware { + constructor(private apiKey: string) { + super(); + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + // Add auth to context + const authenticatedInput = { + ...input, + context: [...input.context, { apiKey: this.apiKey }] + }; + return next.run(authenticatedInput); + } +} + +// Usage +async function example() { + const agent = new MyAgent(); + + // Can use function middleware directly + agent.use(loggingMiddleware); + + // Can chain multiple middleware (functions and classes) + agent.use( + timingMiddleware, + new AuthMiddleware('my-api-key'), + (input, next) => { + // Inline function middleware + console.log('Processing request...'); + return next.run(input); + } + ); + + // Run the agent - middleware will be applied automatically + await agent.runAgent(); +} \ No newline at end of file diff --git a/typescript-sdk/packages/client/src/middleware/__tests__/middleware.test.ts b/typescript-sdk/packages/client/src/middleware/__tests__/middleware.test.ts new file mode 100644 index 000000000..b9422b416 --- /dev/null +++ b/typescript-sdk/packages/client/src/middleware/__tests__/middleware.test.ts @@ -0,0 +1,306 @@ +import { AbstractAgent } from "@/agent"; +import { Middleware } from "@/middleware"; +import { BaseEvent, EventType, RunAgentInput, TextMessageChunkEvent } from "@ag-ui/core"; +import { Observable, of } from "rxjs"; +import { map, tap } from "rxjs/operators"; + +describe("Middleware", () => { + class SimpleAgent extends AbstractAgent { + public run(input: RunAgentInput): Observable { + return new Observable((subscriber) => { + subscriber.next({ + type: EventType.RUN_STARTED, + threadId: input.threadId, + runId: input.runId, + }); + + subscriber.next({ + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + }); + + subscriber.complete(); + }); + } + } + + class TextInjectionMiddleware extends Middleware { + constructor(private text: string) { + super(); + } + + public run(input: RunAgentInput, next: AbstractAgent): Observable { + + return new Observable((subscriber) => { + const subscription = next.run(input).subscribe({ + next: (event) => { + subscriber.next(event); + + // Inject text message chunk after RUN_STARTED + if (event.type === EventType.RUN_STARTED) { + const textEvent: TextMessageChunkEvent = { + type: EventType.TEXT_MESSAGE_CHUNK, + role: "assistant", + messageId: "test-message-id", + delta: this.text, + }; + subscriber.next(textEvent); + } + }, + error: (err) => subscriber.error(err), + complete: () => subscriber.complete(), + }); + + return () => subscription.unsubscribe(); + }); + } + } + + class EventCounterMiddleware extends Middleware { + public eventCount = 0; + public eventTypes: EventType[] = []; + + public run(input: RunAgentInput, next: AbstractAgent): Observable { + + return next.run(input).pipe( + tap((event) => { + this.eventCount++; + this.eventTypes.push(event.type); + }) + ); + } + } + + class EventTransformMiddleware extends Middleware { + public run(input: RunAgentInput, next: AbstractAgent): Observable { + + return next.run(input).pipe( + map((event) => { + // Add metadata to all events + return { + ...event, + metadata: { transformed: true }, + } as BaseEvent; + }) + ); + } + } + + it("should inject text message chunk between RUN_STARTED and RUN_FINISHED", async () => { + const agent = new SimpleAgent(); + const middleware = new TextInjectionMiddleware("Hello from middleware!"); + + agent.use(middleware); + + const events: BaseEvent[] = []; + const result = await agent.runAgent({}, (params) => { + if (params.onEvent) { + params.onEvent({ event: params as any, messages: [], state: {}, agent, input: {} as any }); + } + }); + + // Collect events through the pipeline + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + const subscription = agent["middlewares"][0].run(input, agent).subscribe({ + next: (event) => events.push(event), + }); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(events.length).toBe(3); + expect(events[0].type).toBe(EventType.RUN_STARTED); + expect(events[1].type).toBe(EventType.TEXT_MESSAGE_CHUNK); + expect((events[1] as TextMessageChunkEvent).delta).toBe("Hello from middleware!"); + expect(events[2].type).toBe(EventType.RUN_FINISHED); + }); + + it("should chain multiple middleware correctly", async () => { + const agent = new SimpleAgent(); + const textMiddleware1 = new TextInjectionMiddleware("First"); + const textMiddleware2 = new TextInjectionMiddleware("Second"); + + agent.use(textMiddleware1, textMiddleware2); + + const events: BaseEvent[] = []; + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + // Build the chain as the agent does + const chainedAgent = agent["middlewares"].reduceRight( + (nextAgent: AbstractAgent, middleware) => ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent), + } as AbstractAgent), + agent + ); + + const subscription = chainedAgent.run(input).subscribe({ + next: (event) => events.push(event), + }); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(events.length).toBe(4); + expect(events[0].type).toBe(EventType.RUN_STARTED); + expect(events[1].type).toBe(EventType.TEXT_MESSAGE_CHUNK); + expect((events[1] as TextMessageChunkEvent).delta).toBe("First"); + expect(events[2].type).toBe(EventType.TEXT_MESSAGE_CHUNK); + expect((events[2] as TextMessageChunkEvent).delta).toBe("Second"); + expect(events[3].type).toBe(EventType.RUN_FINISHED); + }); + + it("should allow middleware to observe events", async () => { + const agent = new SimpleAgent(); + const counterMiddleware = new EventCounterMiddleware(); + + agent.use(counterMiddleware); + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + const events: BaseEvent[] = []; + const subscription = counterMiddleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + }); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(counterMiddleware.eventCount).toBe(2); + expect(counterMiddleware.eventTypes).toEqual([ + EventType.RUN_STARTED, + EventType.RUN_FINISHED, + ]); + }); + + it("should allow middleware to transform events", async () => { + const agent = new SimpleAgent(); + const transformMiddleware = new EventTransformMiddleware(); + + agent.use(transformMiddleware); + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + const events: BaseEvent[] = []; + const subscription = transformMiddleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + }); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(events.length).toBe(2); + expect((events[0] as any).metadata?.transformed).toBe(true); + expect((events[1] as any).metadata?.transformed).toBe(true); + }); + + + it("should work with 2 middleware and 1 actual agent in a chain", async () => { + // The actual agent that sends RUN_STARTED and RUN_FINISHED + const agent = new SimpleAgent(); + + // First middleware: modifies any text message chunks to have fixed text + class TextModifierMiddleware extends Middleware { + constructor(private replacementText: string) { + super(); + } + + public run(input: RunAgentInput, next: AbstractAgent): Observable { + + return next.run(input).pipe( + map((event) => { + // If it's a text message chunk, replace the delta + if (event.type === EventType.TEXT_MESSAGE_CHUNK) { + const textEvent = event as TextMessageChunkEvent; + return { + ...textEvent, + delta: this.replacementText, + } as TextMessageChunkEvent; + } + // Pass through other events unchanged + return event; + }) + ); + } + } + + // Second middleware: injects a text message chunk after RUN_STARTED + const textInjectionMiddleware = new TextInjectionMiddleware("Original text from middleware"); + + const textModifierMiddleware = new TextModifierMiddleware("Modified text!"); + + // Add middleware in order: modifier first (outermost), then injection (innermost) + // This way: modifier -> injection -> agent + // And events flow back: agent -> injection (adds text) -> modifier (modifies text) + agent.use(textModifierMiddleware, textInjectionMiddleware); + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + // Build the chain as the agent does internally + const chainedAgent = agent["middlewares"].reduceRight( + (nextAgent: AbstractAgent, middleware) => ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent), + } as AbstractAgent), + agent + ); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + chainedAgent.run(input).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + // Verify the event sequence + expect(events.length).toBe(3); + + // First event: RUN_STARTED from the agent + expect(events[0].type).toBe(EventType.RUN_STARTED); + + // Second event: TEXT_MESSAGE_CHUNK injected by first middleware, + // but with text modified by second middleware + expect(events[1].type).toBe(EventType.TEXT_MESSAGE_CHUNK); + expect((events[1] as TextMessageChunkEvent).delta).toBe("Modified text!"); + + // Third event: RUN_FINISHED from the agent + expect(events[2].type).toBe(EventType.RUN_FINISHED); + }); +}); \ No newline at end of file diff --git a/typescript-sdk/packages/client/src/middleware/filter-tool-calls.ts b/typescript-sdk/packages/client/src/middleware/filter-tool-calls.ts new file mode 100644 index 000000000..86ade3bfd --- /dev/null +++ b/typescript-sdk/packages/client/src/middleware/filter-tool-calls.ts @@ -0,0 +1,98 @@ +import { Middleware } from "./middleware"; +import { AbstractAgent } from "@/agent"; +import { RunAgentInput, BaseEvent, EventType, ToolCallStartEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallResultEvent } from "@ag-ui/core"; +import { Observable } from "rxjs"; +import { filter } from "rxjs/operators"; +import { transformChunks } from "@/chunks"; + +type FilterToolCallsConfig = + | { allowedToolCalls: string[]; disallowedToolCalls?: never } + | { disallowedToolCalls: string[]; allowedToolCalls?: never }; + +export class FilterToolCallsMiddleware extends Middleware { + private blockedToolCallIds = new Set(); + private readonly allowedTools?: Set; + private readonly disallowedTools?: Set; + + constructor(config: FilterToolCallsConfig) { + super(); + + // Runtime validation (belt and suspenders approach) + if (config.allowedToolCalls && config.disallowedToolCalls) { + throw new Error("Cannot specify both allowedToolCalls and disallowedToolCalls"); + } + + if (!config.allowedToolCalls && !config.disallowedToolCalls) { + throw new Error("Must specify either allowedToolCalls or disallowedToolCalls"); + } + + if (config.allowedToolCalls) { + this.allowedTools = new Set(config.allowedToolCalls); + } else if (config.disallowedToolCalls) { + this.disallowedTools = new Set(config.disallowedToolCalls); + } + } + + public run(input: RunAgentInput, next: AbstractAgent): Observable { + // Apply transformChunks first to convert TOOL_CALL_CHUNK events + return next.run(input).pipe( + transformChunks(false), + filter((event) => { + // Handle TOOL_CALL_START events + if (event.type === EventType.TOOL_CALL_START) { + const toolCallStartEvent = event as ToolCallStartEvent; + const shouldFilter = this.shouldFilterTool(toolCallStartEvent.toolCallName); + + if (shouldFilter) { + // Track this tool call ID as blocked + this.blockedToolCallIds.add(toolCallStartEvent.toolCallId); + return false; // Filter out this event + } + + return true; // Allow this event + } + + // Handle TOOL_CALL_ARGS events + if (event.type === EventType.TOOL_CALL_ARGS) { + const toolCallArgsEvent = event as ToolCallArgsEvent; + return !this.blockedToolCallIds.has(toolCallArgsEvent.toolCallId); + } + + // Handle TOOL_CALL_END events + if (event.type === EventType.TOOL_CALL_END) { + const toolCallEndEvent = event as ToolCallEndEvent; + return !this.blockedToolCallIds.has(toolCallEndEvent.toolCallId); + } + + // Handle TOOL_CALL_RESULT events + if (event.type === EventType.TOOL_CALL_RESULT) { + const toolCallResultEvent = event as ToolCallResultEvent; + const isBlocked = this.blockedToolCallIds.has(toolCallResultEvent.toolCallId); + + if (isBlocked) { + // Clean up the blocked ID after the last event + this.blockedToolCallIds.delete(toolCallResultEvent.toolCallId); + return false; + } + + return true; + } + + // Allow all other events through + return true; + }) + ); + } + + private shouldFilterTool(toolName: string): boolean { + if (this.allowedTools) { + // If using allowed list, filter out tools NOT in the list + return !this.allowedTools.has(toolName); + } else if (this.disallowedTools) { + // If using disallowed list, filter out tools IN the list + return this.disallowedTools.has(toolName); + } + + return false; + } +} \ No newline at end of file diff --git a/typescript-sdk/packages/client/src/middleware/index.ts b/typescript-sdk/packages/client/src/middleware/index.ts new file mode 100644 index 000000000..e24b580f7 --- /dev/null +++ b/typescript-sdk/packages/client/src/middleware/index.ts @@ -0,0 +1,3 @@ +export { Middleware, FunctionMiddleware } from "./middleware"; +export type { MiddlewareFunction } from "./middleware"; +export { FilterToolCallsMiddleware } from "./filter-tool-calls"; \ No newline at end of file diff --git a/typescript-sdk/packages/client/src/middleware/middleware.ts b/typescript-sdk/packages/client/src/middleware/middleware.ts new file mode 100644 index 000000000..0fa7f789f --- /dev/null +++ b/typescript-sdk/packages/client/src/middleware/middleware.ts @@ -0,0 +1,20 @@ +import { AbstractAgent } from "@/agent"; +import { RunAgentInput, BaseEvent } from "@ag-ui/core"; +import { Observable } from "rxjs"; + +export type MiddlewareFunction = (input: RunAgentInput, next: AbstractAgent) => Observable; + +export abstract class Middleware { + abstract run(input: RunAgentInput, next: AbstractAgent): Observable; +} + +// Wrapper class to convert a function into a Middleware instance +export class FunctionMiddleware extends Middleware { + constructor(private fn: MiddlewareFunction) { + super(); + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + return this.fn(input, next); + } +} From ce29ad85f284e11494aec0dc17f1b201e23bedda Mon Sep 17 00:00:00 2001 From: Markus Ecker Date: Tue, 23 Sep 2025 23:42:41 +0200 Subject: [PATCH 2/5] Middleware docs --- docs/concepts/middleware.mdx | 307 +++++++++++++++++ docs/docs.json | 2 + docs/sdk/js/client/abstract-agent.mdx | 30 ++ docs/sdk/js/client/middleware.mdx | 408 +++++++++++++++++++++++ docs/sdk/js/client/overview.mdx | 24 ++ typescript-sdk/packages/client/README.md | 27 ++ 6 files changed, 798 insertions(+) create mode 100644 docs/concepts/middleware.mdx create mode 100644 docs/sdk/js/client/middleware.mdx diff --git a/docs/concepts/middleware.mdx b/docs/concepts/middleware.mdx new file mode 100644 index 000000000..835d10566 --- /dev/null +++ b/docs/concepts/middleware.mdx @@ -0,0 +1,307 @@ +--- +title: "Middleware" +description: "Transform and intercept events in AG-UI agents" +--- + +# Middleware + +Middleware in AG-UI provides a powerful way to transform, filter, and augment the event streams that flow through agents. It enables you to add cross-cutting concerns like logging, authentication, rate limiting, and event filtering without modifying the core agent logic. + +## What is Middleware? + +Middleware sits between the agent execution and the event consumer, allowing you to: + +1. **Transform events** – Modify or enhance events as they flow through the pipeline +2. **Filter events** – Selectively allow or block certain events +3. **Add metadata** – Inject additional context or tracking information +4. **Handle errors** – Implement custom error recovery strategies +5. **Monitor execution** – Add logging, metrics, or debugging capabilities + +## How Middleware Works + +Middleware forms a chain where each middleware wraps the next, creating layers of functionality. When an agent runs, the event stream flows through each middleware in sequence. + +```typescript +import { AbstractAgent } from "@ag-ui/client" + +const agent = new MyAgent() + +// Middleware chain: logging -> auth -> filter -> agent +agent.use(loggingMiddleware, authMiddleware, filterMiddleware) + +// When agent runs, events flow through all middleware +await agent.runAgent() +``` + +## Function-Based Middleware + +For simple transformations, you can use function-based middleware. This is the most concise way to add middleware: + +```typescript +import { MiddlewareFunction } from "@ag-ui/client" +import { EventType } from "@ag-ui/core" + +const prefixMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + map(event => { + if (event.type === EventType.TEXT_MESSAGE_CHUNK) { + return { + ...event, + delta: `[AI]: ${event.delta}` + } + } + return event + }) + ) +} + +agent.use(prefixMiddleware) +``` + +## Class-Based Middleware + +For more complex scenarios requiring state or configuration, use class-based middleware: + +```typescript +import { Middleware } from "@ag-ui/client" +import { Observable } from "rxjs" +import { tap } from "rxjs/operators" + +class MetricsMiddleware extends Middleware { + private eventCount = 0 + + constructor(private metricsService: MetricsService) { + super() + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + const startTime = Date.now() + + return next.run(input).pipe( + tap(event => { + this.eventCount++ + this.metricsService.recordEvent(event.type) + }), + finalize(() => { + const duration = Date.now() - startTime + this.metricsService.recordDuration(duration) + this.metricsService.recordEventCount(this.eventCount) + }) + ) + } +} + +agent.use(new MetricsMiddleware(metricsService)) +``` + +## Built-in Middleware + +AG-UI provides several built-in middleware components for common use cases: + +### FilterToolCallsMiddleware + +Filter tool calls based on allowed or disallowed lists: + +```typescript +import { FilterToolCallsMiddleware } from "@ag-ui/client" + +// Only allow specific tools +const allowedFilter = new FilterToolCallsMiddleware({ + allowedToolCalls: ["search", "calculate"] +}) + +// Or block specific tools +const blockedFilter = new FilterToolCallsMiddleware({ + disallowedToolCalls: ["delete", "modify"] +}) + +agent.use(allowedFilter) +``` + +## Middleware Patterns + +### Logging Middleware + +```typescript +const loggingMiddleware: MiddlewareFunction = (input, next) => { + console.log("Request:", input.messages) + + return next.run(input).pipe( + tap(event => console.log("Event:", event.type)), + catchError(error => { + console.error("Error:", error) + throw error + }) + ) +} +``` + +### Authentication Middleware + +```typescript +class AuthMiddleware extends Middleware { + constructor(private apiKey: string) { + super() + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + // Add authentication to the context + const authenticatedInput = { + ...input, + context: [ + ...input.context, + { type: "auth", apiKey: this.apiKey } + ] + } + + return next.run(authenticatedInput) + } +} +``` + +### Rate Limiting Middleware + +```typescript +class RateLimitMiddleware extends Middleware { + private lastCall = 0 + + constructor(private minInterval: number) { + super() + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + const now = Date.now() + const timeSinceLastCall = now - this.lastCall + + if (timeSinceLastCall < this.minInterval) { + const delay = this.minInterval - timeSinceLastCall + return timer(delay).pipe( + switchMap(() => { + this.lastCall = Date.now() + return next.run(input) + }) + ) + } + + this.lastCall = now + return next.run(input) + } +} +``` + +## Combining Middleware + +You can combine multiple middleware to create sophisticated processing pipelines: + +```typescript +// Function middleware for simple logging +const logMiddleware: MiddlewareFunction = (input, next) => { + console.log(`Starting run ${input.runId}`) + return next.run(input) +} + +// Class middleware for authentication +const authMiddleware = new AuthMiddleware(apiKey) + +// Built-in middleware for filtering +const filterMiddleware = new FilterToolCallsMiddleware({ + allowedToolCalls: ["search", "summarize"] +}) + +// Apply all middleware in order +agent.use( + logMiddleware, // First: log the request + authMiddleware, // Second: add authentication + filterMiddleware // Third: filter tool calls +) +``` + +## Execution Order + +Middleware executes in the order it's added, with each middleware wrapping the next: + +1. First middleware receives the original input +2. It can modify the input before passing to the next middleware +3. Each middleware processes events from the next in the chain +4. The final middleware calls the actual agent + +```typescript +agent.use(middleware1, middleware2, middleware3) + +// Execution flow: +// → middleware1 +// → middleware2 +// → middleware3 +// → agent.run() +// ← events flow back through middleware3 +// ← events flow back through middleware2 +// ← events flow back through middleware1 +``` + +## Best Practices + +1. **Keep middleware focused** – Each middleware should have a single responsibility +2. **Handle errors gracefully** – Use RxJS error handling operators +3. **Avoid blocking operations** – Use async patterns for I/O operations +4. **Document side effects** – Clearly indicate if middleware modifies state +5. **Test middleware independently** – Write unit tests for each middleware +6. **Consider performance** – Be mindful of processing overhead in the event stream + +## Advanced Use Cases + +### Conditional Middleware + +Apply middleware based on runtime conditions: + +```typescript +const conditionalMiddleware: MiddlewareFunction = (input, next) => { + if (input.context.some(c => c.type === "debug")) { + // Apply debug logging + return next.run(input).pipe( + tap(event => console.debug(event)) + ) + } + return next.run(input) +} +``` + +### Event Transformation + +Transform specific event types: + +```typescript +const transformMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + map(event => { + if (event.type === EventType.TOOL_CALL_START) { + // Add timestamp to tool calls + return { + ...event, + metadata: { + ...event.metadata, + timestamp: Date.now() + } + } + } + return event + }) + ) +} +``` + +### Stream Control + +Control the flow of events: + +```typescript +const throttleMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + // Throttle text message chunks to prevent overwhelming the UI + throttleTime(50, undefined, { leading: true, trailing: true }) + ) +} +``` + +## Conclusion + +Middleware provides a flexible and powerful way to extend AG-UI agents without modifying their core logic. Whether you need simple event transformation or complex stateful processing, the middleware system offers the tools to build robust, maintainable agent applications. \ No newline at end of file diff --git a/docs/docs.json b/docs/docs.json index e6c6a9c4d..fd721bb53 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -40,6 +40,7 @@ "concepts/architecture", "concepts/events", "concepts/agents", + "concepts/middleware", "concepts/messages", "concepts/state", "concepts/tools" @@ -85,6 +86,7 @@ "sdk/js/client/overview", "sdk/js/client/abstract-agent", "sdk/js/client/http-agent", + "sdk/js/client/middleware", "sdk/js/client/subscriber" ] }, diff --git a/docs/sdk/js/client/abstract-agent.mdx b/docs/sdk/js/client/abstract-agent.mdx index 3738b0aae..7c6acafb5 100644 --- a/docs/sdk/js/client/abstract-agent.mdx +++ b/docs/sdk/js/client/abstract-agent.mdx @@ -95,6 +95,36 @@ subscribe(subscriber: AgentSubscriber): { unsubscribe: () => void } Returns an object with an `unsubscribe()` method to remove the subscriber when no longer needed. +### use() + +Adds middleware to the agent's event processing pipeline. + +```typescript +use(...middlewares: (Middleware | MiddlewareFunction)[]): this +``` + +Middleware can be either: +- **Function middleware**: Simple functions that transform the event stream +- **Class middleware**: Instances of the `Middleware` class for stateful operations + +```typescript +// Function middleware +agent.use((input, next) => { + console.log("Processing:", input.runId); + return next.run(input); +}); + +// Class middleware +agent.use(new FilterToolCallsMiddleware({ + allowedToolCalls: ["search"] +})); + +// Chain multiple middleware +agent.use(loggingMiddleware, authMiddleware, filterMiddleware); +``` + +Middleware executes in the order added, with each wrapping the next. See the [Middleware documentation](/sdk/js/client/middleware) for more details. + ### abortRun() Cancels the current agent execution. diff --git a/docs/sdk/js/client/middleware.mdx b/docs/sdk/js/client/middleware.mdx new file mode 100644 index 000000000..31462257b --- /dev/null +++ b/docs/sdk/js/client/middleware.mdx @@ -0,0 +1,408 @@ +--- +title: "Middleware" +description: "Event stream transformation and filtering for AG-UI agents" +--- + +# Middleware + +The middleware system in `@ag-ui/client` provides a powerful way to transform, filter, and augment event streams flowing through agents. Middleware can intercept and modify events, add logging, implement authentication, filter tool calls, and more. + +```typescript +import { Middleware, MiddlewareFunction, FilterToolCallsMiddleware } from "@ag-ui/client" +``` + +## Types + +### MiddlewareFunction + +A function that transforms the event stream. + +```typescript +type MiddlewareFunction = ( + input: RunAgentInput, + next: AbstractAgent +) => Observable +``` + +### Middleware + +Abstract base class for creating middleware. + +```typescript +abstract class Middleware { + abstract run( + input: RunAgentInput, + next: AbstractAgent + ): Observable +} +``` + +## Function-Based Middleware + +The simplest way to create middleware is with a function. Function middleware is ideal for stateless transformations. + +### Basic Example + +```typescript +const loggingMiddleware: MiddlewareFunction = (input, next) => { + console.log(`[${new Date().toISOString()}] Starting run ${input.runId}`) + + return next.run(input).pipe( + tap(event => console.log(`Event: ${event.type}`)), + finalize(() => console.log(`Run ${input.runId} completed`)) + ) +} + +agent.use(loggingMiddleware) +``` + +### Transforming Events + +```typescript +const prefixMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + map(event => { + if (event.type === EventType.TEXT_MESSAGE_CHUNK) { + return { + ...event, + delta: `[Assistant]: ${event.delta}` + } + } + return event + }) + ) +} +``` + +### Error Handling + +```typescript +const errorMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + catchError(error => { + console.error("Agent error:", error) + + // Return error event + return of({ + type: EventType.RUN_ERROR, + message: error.message + } as BaseEvent) + }) + ) +} +``` + +## Class-Based Middleware + +For stateful operations or complex logic, extend the `Middleware` class. + +### Basic Implementation + +```typescript +class CounterMiddleware extends Middleware { + private totalEvents = 0 + + run(input: RunAgentInput, next: AbstractAgent): Observable { + let runEvents = 0 + + return next.run(input).pipe( + tap(() => { + runEvents++ + this.totalEvents++ + }), + finalize(() => { + console.log(`Run events: ${runEvents}, Total: ${this.totalEvents}`) + }) + ) + } +} + +agent.use(new CounterMiddleware()) +``` + +### Configuration-Based Middleware + +```typescript +class AuthMiddleware extends Middleware { + constructor( + private apiKey: string, + private headerName: string = "Authorization" + ) { + super() + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + // Add authentication to context + const authenticatedInput = { + ...input, + context: [ + ...input.context, + { + type: "auth", + [this.headerName]: `Bearer ${this.apiKey}` + } + ] + } + + return next.run(authenticatedInput) + } +} + +agent.use(new AuthMiddleware(process.env.API_KEY)) +``` + +## Built-in Middleware + +### FilterToolCallsMiddleware + +Filters tool calls based on allowed or disallowed lists. + +```typescript +import { FilterToolCallsMiddleware } from "@ag-ui/client" +``` + +#### Configuration + +```typescript +type FilterToolCallsConfig = + | { allowedToolCalls: string[]; disallowedToolCalls?: never } + | { disallowedToolCalls: string[]; allowedToolCalls?: never } +``` + +#### Allow Specific Tools + +```typescript +const allowFilter = new FilterToolCallsMiddleware({ + allowedToolCalls: ["search", "calculate", "summarize"] +}) + +agent.use(allowFilter) +``` + +#### Block Specific Tools + +```typescript +const blockFilter = new FilterToolCallsMiddleware({ + disallowedToolCalls: ["delete", "modify", "execute"] +}) + +agent.use(blockFilter) +``` + +## Middleware Patterns + +### Timing Middleware + +```typescript +const timingMiddleware: MiddlewareFunction = (input, next) => { + const startTime = performance.now() + + return next.run(input).pipe( + finalize(() => { + const duration = performance.now() - startTime + console.log(`Execution time: ${duration.toFixed(2)}ms`) + }) + ) +} +``` + +### Rate Limiting + +```typescript +class RateLimitMiddleware extends Middleware { + private lastCall = 0 + + constructor(private minInterval: number) { + super() + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + const now = Date.now() + const elapsed = now - this.lastCall + + if (elapsed < this.minInterval) { + // Delay the execution + return timer(this.minInterval - elapsed).pipe( + switchMap(() => { + this.lastCall = Date.now() + return next.run(input) + }) + ) + } + + this.lastCall = now + return next.run(input) + } +} + +// Limit to one request per second +agent.use(new RateLimitMiddleware(1000)) +``` + +### Retry Logic + +```typescript +const retryMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + retry({ + count: 3, + delay: (error, retryCount) => { + console.log(`Retry attempt ${retryCount}`) + return timer(1000 * retryCount) // Exponential backoff + } + }) + ) +} +``` + +### Caching + +```typescript +class CacheMiddleware extends Middleware { + private cache = new Map() + + run(input: RunAgentInput, next: AbstractAgent): Observable { + const cacheKey = this.getCacheKey(input) + + if (this.cache.has(cacheKey)) { + console.log("Cache hit") + return from(this.cache.get(cacheKey)!) + } + + const events: BaseEvent[] = [] + + return next.run(input).pipe( + tap(event => events.push(event)), + finalize(() => { + this.cache.set(cacheKey, events) + }) + ) + } + + private getCacheKey(input: RunAgentInput): string { + // Create a cache key from the input + return JSON.stringify({ + messages: input.messages, + tools: input.tools.map(t => t.name) + }) + } +} +``` + +## Chaining Middleware + +Multiple middleware can be combined to create sophisticated processing pipelines. + +```typescript +// Create middleware instances +const logger = loggingMiddleware +const auth = new AuthMiddleware(apiKey) +const filter = new FilterToolCallsMiddleware({ + allowedToolCalls: ["search"] +}) +const rateLimit = new RateLimitMiddleware(1000) + +// Apply middleware in order +agent.use( + logger, // First: Log all events + auth, // Second: Add authentication + rateLimit, // Third: Apply rate limiting + filter // Fourth: Filter tool calls +) + +// Execution flow: +// logger → auth → rateLimit → filter → agent → filter → rateLimit → auth → logger +``` + +## Advanced Usage + +### Conditional Middleware + +```typescript +const debugMiddleware: MiddlewareFunction = (input, next) => { + const isDebug = input.context.some(c => c.type === "debug") + + if (!isDebug) { + return next.run(input) + } + + return next.run(input).pipe( + tap(event => { + console.debug("[DEBUG]", JSON.stringify(event, null, 2)) + }) + ) +} +``` + +### Event Filtering + +```typescript +const filterEventsMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + filter(event => { + // Only allow specific event types + return [ + EventType.RUN_STARTED, + EventType.TEXT_MESSAGE_CHUNK, + EventType.RUN_FINISHED + ].includes(event.type) + }) + ) +} +``` + +### Stream Manipulation + +```typescript +const bufferMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + // Buffer text chunks and emit them in batches + bufferWhen(() => + interval(100).pipe( + filter(() => true) + ) + ), + map(events => events.flat()) + ) +} +``` + +## Best Practices + +1. **Single Responsibility**: Each middleware should focus on one concern +2. **Error Handling**: Always handle errors gracefully and consider recovery strategies +3. **Performance**: Be mindful of processing overhead in high-throughput scenarios +4. **State Management**: Use class-based middleware when state is required +5. **Testing**: Write unit tests for each middleware independently +6. **Documentation**: Document middleware behavior and side effects + +## TypeScript Support + +The middleware system is fully typed for excellent IDE support: + +```typescript +import { + Middleware, + MiddlewareFunction, + FilterToolCallsMiddleware +} from "@ag-ui/client" +import { RunAgentInput, BaseEvent, EventType } from "@ag-ui/core" + +// Type-safe middleware function +const typedMiddleware: MiddlewareFunction = ( + input: RunAgentInput, + next: AbstractAgent +): Observable => { + return next.run(input) +} + +// Type-safe middleware class +class TypedMiddleware extends Middleware { + run( + input: RunAgentInput, + next: AbstractAgent + ): Observable { + return next.run(input) + } +} +``` \ No newline at end of file diff --git a/docs/sdk/js/client/overview.mdx b/docs/sdk/js/client/overview.mdx index 0bfcf085f..3902715c5 100644 --- a/docs/sdk/js/client/overview.mdx +++ b/docs/sdk/js/client/overview.mdx @@ -61,6 +61,30 @@ Concrete implementation for HTTP-based agent connectivity: efficient event encoding format +## Middleware + +Transform and intercept event streams flowing through agents with a flexible +middleware system: + +- [Function Middleware](/sdk/js/client/middleware#function-based-middleware) - Simple + transformations with plain functions +- [Class Middleware](/sdk/js/client/middleware#class-based-middleware) - Stateful + middleware with configuration +- [Built-in Middleware](/sdk/js/client/middleware#built-in-middleware) - + FilterToolCallsMiddleware and more +- [Middleware Patterns](/sdk/js/client/middleware#middleware-patterns) - Common + use cases and examples + + + Powerful event stream transformation and filtering for AG-UI agents + + ## AgentSubscriber Event-driven subscriber system for handling agent lifecycle events and state diff --git a/typescript-sdk/packages/client/README.md b/typescript-sdk/packages/client/README.md index 1be36135a..fbc9c41db 100644 --- a/typescript-sdk/packages/client/README.md +++ b/typescript-sdk/packages/client/README.md @@ -19,6 +19,7 @@ yarn add @ag-ui/client - 📡 **Event streaming** – Full AG-UI event processing with validation and transformation - 🔄 **State management** – Automatic message/state tracking with reactive updates - 🪝 **Subscriber system** – Middleware-style hooks for logging, persistence, and custom logic +- 🎯 **Middleware support** – Transform and filter events with function or class-based middleware ## Quick example @@ -37,6 +38,32 @@ const result = await agent.runAgent({ console.log(result.newMessages); ``` +## Using Middleware + +```ts +import { HttpAgent, FilterToolCallsMiddleware } from "@ag-ui/client"; + +const agent = new HttpAgent({ + url: "https://api.example.com/agent", +}); + +// Add middleware to transform or filter events +agent.use( + // Function middleware for logging + (input, next) => { + console.log("Starting run:", input.runId); + return next.run(input); + }, + + // Class middleware for filtering tool calls + new FilterToolCallsMiddleware({ + allowedToolCalls: ["search", "calculate"] + }) +); + +await agent.runAgent(); +``` + ## Documentation - Concepts & architecture: [`docs/concepts`](https://docs.ag-ui.com/concepts/architecture) From edb3a1f9e1947db552491a5c674e85a3b0e0462a Mon Sep 17 00:00:00 2001 From: Markus Ecker Date: Thu, 25 Sep 2025 23:54:05 +0200 Subject: [PATCH 3/5] wip --- .../__tests__/middleware-with-state.test.ts | 302 ++++++++++++++++++ .../client/src/middleware/middleware.ts | 68 +++- 2 files changed, 368 insertions(+), 2 deletions(-) create mode 100644 typescript-sdk/packages/client/src/middleware/__tests__/middleware-with-state.test.ts diff --git a/typescript-sdk/packages/client/src/middleware/__tests__/middleware-with-state.test.ts b/typescript-sdk/packages/client/src/middleware/__tests__/middleware-with-state.test.ts new file mode 100644 index 000000000..4aebc5da2 --- /dev/null +++ b/typescript-sdk/packages/client/src/middleware/__tests__/middleware-with-state.test.ts @@ -0,0 +1,302 @@ +import { Middleware, EventWithState } from "../middleware"; +import { AbstractAgent } from "@/agent"; +import { + RunAgentInput, + BaseEvent, + EventType, + TextMessageStartEvent, + TextMessageContentEvent, + StateSnapshotEvent, + StateDeltaEvent, + MessagesSnapshotEvent, + ToolCallStartEvent, + ToolCallArgsEvent, + ToolCallEndEvent, +} from "@ag-ui/core"; +import { Observable, from } from "rxjs"; +import { map, toArray } from "rxjs/operators"; + +// Mock agent for testing +class MockAgent extends AbstractAgent { + constructor(private events: BaseEvent[]) { + super(); + } + + run(input: RunAgentInput): Observable { + return from(this.events); + } +} + +// Test middleware that uses runNextWithState +class TestMiddleware extends Middleware { + run(input: RunAgentInput, next: AbstractAgent): Observable { + return this.runNextWithState(input, next).pipe( + map(({ event }) => event) + ); + } + + // Expose for testing + testRunNextWithState( + input: RunAgentInput, + next: AbstractAgent + ): Observable { + return this.runNextWithState(input, next); + } +} + +describe("Middleware.runNextWithState", () => { + it("should track messages as they are built", async () => { + const events: BaseEvent[] = [ + { + type: EventType.TEXT_MESSAGE_START, + messageId: "msg1", + role: "assistant", + } as TextMessageStartEvent, + { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: "msg1", + delta: "Hello", + } as TextMessageContentEvent, + { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: "msg1", + delta: " world", + } as TextMessageContentEvent, + ]; + + const agent = new MockAgent(events); + const middleware = new TestMiddleware(); + const input: RunAgentInput = { messages: [], state: {} }; + + const results = await middleware + .testRunNextWithState(input, agent) + .pipe(toArray()) + .toPromise(); + + expect(results).toHaveLength(3); + + // After TEXT_MESSAGE_START, should have one empty message + expect(results![0].messages).toHaveLength(1); + expect(results![0].messages[0].id).toBe("msg1"); + expect(results![0].messages[0].role).toBe("assistant"); + expect(results![0].messages[0].content).toBe(""); + + // After first content chunk + expect(results![1].messages).toHaveLength(1); + expect(results![1].messages[0].content).toBe("Hello"); + + // After second content chunk + expect(results![2].messages).toHaveLength(1); + expect(results![2].messages[0].content).toBe("Hello world"); + }); + + it("should track state changes", async () => { + const events: BaseEvent[] = [ + { + type: EventType.STATE_SNAPSHOT, + snapshot: { counter: 0, name: "test" }, + } as StateSnapshotEvent, + { + type: EventType.STATE_DELTA, + delta: [{ op: "replace", path: "/counter", value: 1 }], + } as StateDeltaEvent, + { + type: EventType.STATE_DELTA, + delta: [{ op: "add", path: "/newField", value: "added" }], + } as StateDeltaEvent, + ]; + + const agent = new MockAgent(events); + const middleware = new TestMiddleware(); + const input: RunAgentInput = { messages: [], state: {} }; + + const results = await middleware + .testRunNextWithState(input, agent) + .pipe(toArray()) + .toPromise(); + + expect(results).toHaveLength(3); + + // After STATE_SNAPSHOT + expect(results![0].state).toEqual({ counter: 0, name: "test" }); + + // After first STATE_DELTA + expect(results![1].state).toEqual({ counter: 1, name: "test" }); + + // After second STATE_DELTA + expect(results![2].state).toEqual({ + counter: 1, + name: "test", + newField: "added", + }); + }); + + it("should handle MESSAGES_SNAPSHOT", async () => { + const events: BaseEvent[] = [ + { + type: EventType.TEXT_MESSAGE_START, + messageId: "msg1", + role: "user", + } as TextMessageStartEvent, + { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: "msg1", + delta: "First", + } as TextMessageContentEvent, + { + type: EventType.MESSAGES_SNAPSHOT, + messages: [ + { id: "old1", role: "assistant", content: "Previous message" }, + { id: "old2", role: "user", content: "Another message" }, + ], + } as MessagesSnapshotEvent, + ]; + + const agent = new MockAgent(events); + const middleware = new TestMiddleware(); + const input: RunAgentInput = { messages: [], state: {} }; + + const results = await middleware + .testRunNextWithState(input, agent) + .pipe(toArray()) + .toPromise(); + + expect(results).toHaveLength(3); + + // After building a message + expect(results![1].messages).toHaveLength(1); + expect(results![1].messages[0].content).toBe("First"); + + // After MESSAGES_SNAPSHOT - replaces all messages + expect(results![2].messages).toHaveLength(2); + expect(results![2].messages[0].id).toBe("old1"); + expect(results![2].messages[1].id).toBe("old2"); + }); + + it("should track tool calls", async () => { + const events: BaseEvent[] = [ + { + type: EventType.TOOL_CALL_START, + toolCallId: "tool1", + toolCallName: "calculator", + parentMessageId: "msg1", + } as ToolCallStartEvent, + { + type: EventType.TOOL_CALL_ARGS, + toolCallId: "tool1", + delta: '{"operation": "add"', + } as ToolCallArgsEvent, + { + type: EventType.TOOL_CALL_ARGS, + toolCallId: "tool1", + delta: ', "values": [1, 2]}', + } as ToolCallArgsEvent, + { + type: EventType.TOOL_CALL_END, + toolCallId: "tool1", + } as ToolCallEndEvent, + ]; + + const agent = new MockAgent(events); + const middleware = new TestMiddleware(); + const input: RunAgentInput = { messages: [], state: {} }; + + const results = await middleware + .testRunNextWithState(input, agent) + .pipe(toArray()) + .toPromise(); + + expect(results).toHaveLength(4); + + // After TOOL_CALL_START + expect(results![0].messages).toHaveLength(1); + expect(results![0].messages[0].role).toBe("assistant"); + const msg1 = results![0].messages[0] as any; + expect(msg1.toolCalls).toHaveLength(1); + expect(msg1.toolCalls[0].id).toBe("tool1"); + expect(msg1.toolCalls[0].type).toBe("function"); + expect(msg1.toolCalls[0].function.name).toBe("calculator"); + + // After args accumulation + const msg3 = results![2].messages[0] as any; + expect(msg3.toolCalls[0].function.arguments).toBe('{"operation": "add", "values": [1, 2]}'); + + // After TOOL_CALL_END - args remain as string (defaultApplyEvents doesn't parse them) + const msg4 = results![3].messages[0] as any; + expect(msg4.toolCalls[0].function.arguments).toBe('{"operation": "add", "values": [1, 2]}'); + }); + + it("should preserve initial state and messages", async () => { + const events: BaseEvent[] = [ + { + type: EventType.TEXT_MESSAGE_START, + messageId: "new1", + role: "assistant", + } as TextMessageStartEvent, + { + type: EventType.STATE_DELTA, + delta: [{ op: "add", path: "/newField", value: 42 }], + } as StateDeltaEvent, + ]; + + const agent = new MockAgent(events); + const middleware = new TestMiddleware(); + + const input: RunAgentInput = { + messages: [ + { id: "existing1", role: "user", content: "Existing message" }, + ], + state: { existingField: "hello" }, + }; + + const results = await middleware + .testRunNextWithState(input, agent) + .pipe(toArray()) + .toPromise(); + + expect(results).toHaveLength(2); + + // Should preserve existing message and add new one + expect(results![0].messages).toHaveLength(2); + expect(results![0].messages[0].id).toBe("existing1"); + expect(results![0].messages[1].id).toBe("new1"); + + // Should preserve existing state and add new field + expect(results![1].state).toEqual({ + existingField: "hello", + newField: 42, + }); + }); + + it("should provide immutable snapshots", async () => { + const events: BaseEvent[] = [ + { + type: EventType.TEXT_MESSAGE_START, + messageId: "msg1", + role: "assistant", + } as TextMessageStartEvent, + { + type: EventType.STATE_SNAPSHOT, + snapshot: { value: 1 }, + } as StateSnapshotEvent, + ]; + + const agent = new MockAgent(events); + const middleware = new TestMiddleware(); + const input: RunAgentInput = { messages: [], state: {} }; + + const results = await middleware + .testRunNextWithState(input, agent) + .pipe(toArray()) + .toPromise(); + + // Modify returned state/messages - should not affect next results + results![0].messages[0].content = "MODIFIED"; + results![0].state.hacked = true; + + // Second result should not be affected + expect(results![1].messages[0].content).toBe(""); + expect(results![1].state).toEqual({ value: 1 }); + expect(results![1].state.hacked).toBeUndefined(); + }); +}); \ No newline at end of file diff --git a/typescript-sdk/packages/client/src/middleware/middleware.ts b/typescript-sdk/packages/client/src/middleware/middleware.ts index 0fa7f789f..bc713c54e 100644 --- a/typescript-sdk/packages/client/src/middleware/middleware.ts +++ b/typescript-sdk/packages/client/src/middleware/middleware.ts @@ -1,11 +1,75 @@ import { AbstractAgent } from "@/agent"; -import { RunAgentInput, BaseEvent } from "@ag-ui/core"; -import { Observable } from "rxjs"; +import { RunAgentInput, BaseEvent, Message } from "@ag-ui/core"; +import { Observable, ReplaySubject } from "rxjs"; +import { concatMap } from "rxjs/operators"; +import { transformChunks } from "@/chunks"; +import { defaultApplyEvents } from "@/apply"; +import { structuredClone_ } from "@/utils"; export type MiddlewareFunction = (input: RunAgentInput, next: AbstractAgent) => Observable; +export interface EventWithState { + event: BaseEvent; + messages: Message[]; + state: any; +} + export abstract class Middleware { abstract run(input: RunAgentInput, next: AbstractAgent): Observable; + + /** + * Runs the next agent in the chain with automatic chunk transformation. + */ + protected runNext(input: RunAgentInput, next: AbstractAgent): Observable { + return next.run(input).pipe( + transformChunks(false) // Always transform chunks to full events + ); + } + + /** + * Runs the next agent and tracks state, providing current messages and state with each event. + * The messages and state represent the state AFTER the event has been applied. + */ + protected runNextWithState( + input: RunAgentInput, + next: AbstractAgent + ): Observable { + let currentMessages = structuredClone_(input.messages || []); + let currentState = structuredClone_(input.state || {}); + + // Use a ReplaySubject to feed events one by one + const eventSubject = new ReplaySubject(); + + // Set up defaultApplyEvents to process events + const mutations$ = defaultApplyEvents(input, eventSubject, next, []); + + // Subscribe to track state changes + mutations$.subscribe(mutation => { + if (mutation.messages !== undefined) { + currentMessages = mutation.messages; + } + if (mutation.state !== undefined) { + currentState = mutation.state; + } + }); + + return this.runNext(input, next).pipe( + concatMap(async event => { + // Feed the event to defaultApplyEvents and wait for it to process + eventSubject.next(event); + + // Give defaultApplyEvents a chance to process + await new Promise(resolve => setTimeout(resolve, 0)); + + // Return event with current state + return { + event, + messages: structuredClone_(currentMessages), + state: structuredClone_(currentState) + }; + }) + ); + } } // Wrapper class to convert a function into a Middleware instance From e765961a468f7c8bb2f11e3be73f0ed545dbde97 Mon Sep 17 00:00:00 2001 From: Markus Ecker Date: Thu, 25 Sep 2025 23:56:58 +0200 Subject: [PATCH 4/5] update tests --- .../packages/client/src/middleware/filter-tool-calls.ts | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/typescript-sdk/packages/client/src/middleware/filter-tool-calls.ts b/typescript-sdk/packages/client/src/middleware/filter-tool-calls.ts index 86ade3bfd..cbb32786d 100644 --- a/typescript-sdk/packages/client/src/middleware/filter-tool-calls.ts +++ b/typescript-sdk/packages/client/src/middleware/filter-tool-calls.ts @@ -3,7 +3,6 @@ import { AbstractAgent } from "@/agent"; import { RunAgentInput, BaseEvent, EventType, ToolCallStartEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallResultEvent } from "@ag-ui/core"; import { Observable } from "rxjs"; import { filter } from "rxjs/operators"; -import { transformChunks } from "@/chunks"; type FilterToolCallsConfig = | { allowedToolCalls: string[]; disallowedToolCalls?: never } @@ -34,9 +33,8 @@ export class FilterToolCallsMiddleware extends Middleware { } public run(input: RunAgentInput, next: AbstractAgent): Observable { - // Apply transformChunks first to convert TOOL_CALL_CHUNK events - return next.run(input).pipe( - transformChunks(false), + // Use runNext which already includes transformChunks + return this.runNext(input, next).pipe( filter((event) => { // Handle TOOL_CALL_START events if (event.type === EventType.TOOL_CALL_START) { From 1ed65d61e87cab9bb65f14f29304fe610a162ce3 Mon Sep 17 00:00:00 2001 From: Markus Ecker Date: Fri, 31 Oct 2025 15:28:25 +0100 Subject: [PATCH 5/5] fix unit tests --- .../packages/client/src/agent/agent.ts | 20 ++++++++----- .../__tests__/function-middleware.test.ts | 25 ++++++++++------ .../__tests__/middleware-live-events.test.ts | 30 ++++++++++++------- .../__tests__/middleware-with-state.test.ts | 8 +++-- .../middleware/__tests__/middleware.test.ts | 29 ++++++++++-------- 5 files changed, 70 insertions(+), 42 deletions(-) diff --git a/sdks/typescript/packages/client/src/agent/agent.ts b/sdks/typescript/packages/client/src/agent/agent.ts index 1198bf2d5..acfabc64c 100644 --- a/sdks/typescript/packages/client/src/agent/agent.ts +++ b/sdks/typescript/packages/client/src/agent/agent.ts @@ -30,6 +30,8 @@ export abstract class AbstractAgent { public subscribers: AgentSubscriber[] = []; private middlewares: Middleware[] = []; + public readonly maxVersion: string = "*"; + constructor({ agentId, description, @@ -94,10 +96,11 @@ export abstract class AbstractAgent { } const chainedAgent = this.middlewares.reduceRight( - (nextAgent: AbstractAgent, middleware) => ({ - run: (i: RunAgentInput) => middleware.run(i, nextAgent), - } as AbstractAgent), - this // Original agent is the final 'next' + (nextAgent: AbstractAgent, middleware) => + ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent), + }) as AbstractAgent, + this, // Original agent is the final 'next' ); return chainedAgent.run(input); @@ -447,10 +450,11 @@ export abstract class AbstractAgent { } const chainedAgent = this.middlewares.reduceRight( - (nextAgent: AbstractAgent, middleware) => ({ - run: (i: RunAgentInput) => middleware.run(i, nextAgent) - } as AbstractAgent), - this + (nextAgent: AbstractAgent, middleware) => + ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent), + }) as AbstractAgent, + this, ); return chainedAgent.run(input); diff --git a/sdks/typescript/packages/client/src/middleware/__tests__/function-middleware.test.ts b/sdks/typescript/packages/client/src/middleware/__tests__/function-middleware.test.ts index 9252030e9..11f94035e 100644 --- a/sdks/typescript/packages/client/src/middleware/__tests__/function-middleware.test.ts +++ b/sdks/typescript/packages/client/src/middleware/__tests__/function-middleware.test.ts @@ -39,25 +39,31 @@ describe("FunctionMiddleware", () => { const middlewareFn: MiddlewareFunction = (middlewareInput, next) => { return new Observable((subscriber) => { - subscriber.next({ - type: EventType.RUN_STARTED, - threadId: middlewareInput.threadId, - runId: middlewareInput.runId, - }); - - next.run(middlewareInput).subscribe({ + const subscription = next.run(middlewareInput).subscribe({ next: (event) => { + if (event.type === EventType.RUN_STARTED) { + subscriber.next({ + ...event, + metadata: { ...(event as any).metadata, fromMiddleware: true }, + }); + return; + } + if (event.type === EventType.RUN_FINISHED) { subscriber.next({ ...event, result: { success: true }, }); - } else { - subscriber.next(event); + return; } + + subscriber.next(event); }, + error: (error) => subscriber.error(error), complete: () => subscriber.complete(), }); + + return () => subscription.unsubscribe(); }); }; @@ -73,6 +79,7 @@ describe("FunctionMiddleware", () => { expect(events.length).toBe(2); expect(events[0].type).toBe(EventType.RUN_STARTED); + expect((events[0] as any).metadata).toEqual({ fromMiddleware: true }); expect(events[1].type).toBe(EventType.RUN_FINISHED); expect((events[1] as any).result).toEqual({ success: true }); }); diff --git a/sdks/typescript/packages/client/src/middleware/__tests__/middleware-live-events.test.ts b/sdks/typescript/packages/client/src/middleware/__tests__/middleware-live-events.test.ts index d37f16b9c..bedb8f8c3 100644 --- a/sdks/typescript/packages/client/src/middleware/__tests__/middleware-live-events.test.ts +++ b/sdks/typescript/packages/client/src/middleware/__tests__/middleware-live-events.test.ts @@ -42,17 +42,27 @@ describe("Middleware live events", () => { class CustomMiddleware extends Middleware { run(input: RunAgentInput, next: AbstractAgent): Observable { return new Observable((subscriber) => { - subscriber.next({ - type: EventType.RUN_STARTED, - threadId: input.threadId, - runId: input.runId, - metadata: { custom: true }, - } as RunStartedEvent); + const subscription = next.run(input).subscribe({ + next: (event) => { + if (event.type === EventType.RUN_STARTED) { + const started = event as RunStartedEvent; + subscriber.next({ + ...started, + metadata: { + ...(started.metadata ?? {}), + custom: true, + }, + }); + return; + } - next.run(input).subscribe({ - next: (event) => subscriber.next(event), + subscriber.next(event); + }, + error: (error) => subscriber.error(error), complete: () => subscriber.complete(), }); + + return () => subscription.unsubscribe(); }); } } @@ -82,7 +92,7 @@ describe("Middleware live events", () => { expect(events.length).toBe(3); expect(events[0].type).toBe(EventType.RUN_STARTED); expect((events[0] as RunStartedEvent).metadata).toEqual({ custom: true }); - expect(events[1].type).toBe(EventType.RUN_STARTED); - expect(events[2].type).toBe(EventType.TEXT_MESSAGE_CHUNK); + expect(events[1].type).toBe(EventType.TEXT_MESSAGE_CHUNK); + expect(events[2].type).toBe(EventType.RUN_FINISHED); }); }); diff --git a/sdks/typescript/packages/client/src/middleware/__tests__/middleware-with-state.test.ts b/sdks/typescript/packages/client/src/middleware/__tests__/middleware-with-state.test.ts index 03e2e782c..6ac9573c8 100644 --- a/sdks/typescript/packages/client/src/middleware/__tests__/middleware-with-state.test.ts +++ b/sdks/typescript/packages/client/src/middleware/__tests__/middleware-with-state.test.ts @@ -73,9 +73,11 @@ describe("Middleware runNextWithState", () => { }); }); - expect(events.length).toBe(3); + expect(events.length).toBe(5); expect(events[0].type).toBe(EventType.RUN_STARTED); - expect(events[1].type).toBe(EventType.TEXT_MESSAGE_CHUNK); - expect(events[2].type).toBe(EventType.RUN_FINISHED); + expect(events[1].type).toBe(EventType.TEXT_MESSAGE_START); + expect(events[2].type).toBe(EventType.TEXT_MESSAGE_CONTENT); + expect(events[3].type).toBe(EventType.TEXT_MESSAGE_END); + expect(events[4].type).toBe(EventType.RUN_FINISHED); }); }); diff --git a/sdks/typescript/packages/client/src/middleware/__tests__/middleware.test.ts b/sdks/typescript/packages/client/src/middleware/__tests__/middleware.test.ts index d25e217ae..a07cfb870 100644 --- a/sdks/typescript/packages/client/src/middleware/__tests__/middleware.test.ts +++ b/sdks/typescript/packages/client/src/middleware/__tests__/middleware.test.ts @@ -28,17 +28,23 @@ describe("Middleware", () => { class TestMiddleware extends Middleware { run(input: RunAgentInput, next: AbstractAgent): Observable { return new Observable((subscriber) => { - subscriber.next({ - type: EventType.RUN_STARTED, - threadId: input.threadId, - runId: input.runId, - metadata: { middleware: true }, - }); + const subscription = next.run(input).subscribe({ + next: (event) => { + if (event.type === EventType.RUN_STARTED) { + subscriber.next({ + ...event, + metadata: { ...(event as any).metadata, middleware: true }, + }); + return; + } - next.run(input).subscribe({ - next: (event) => subscriber.next(event), + subscriber.next(event); + }, + error: (error) => subscriber.error(error), complete: () => subscriber.complete(), }); + + return () => subscription.unsubscribe(); }); } } @@ -65,11 +71,10 @@ describe("Middleware", () => { }); }); - expect(events.length).toBe(3); + expect(events.length).toBe(2); expect(events[0].type).toBe(EventType.RUN_STARTED); expect((events[0] as any).metadata).toEqual({ middleware: true }); - expect(events[1].type).toBe(EventType.RUN_STARTED); - expect(events[2].type).toBe(EventType.RUN_FINISHED); - expect((events[2] as any).result).toEqual({ success: true }); + expect(events[1].type).toBe(EventType.RUN_FINISHED); + expect((events[1] as any).result).toEqual({ success: true }); }); });