diff --git a/packages/core/examples/agent_stream_example.ts b/packages/core/examples/agent_stream_example.ts
new file mode 100644
index 000000000..3b99c1477
--- /dev/null
+++ b/packages/core/examples/agent_stream_example.ts
@@ -0,0 +1,47 @@
+import { Stagehand } from "../lib/v3";
+import dotenv from "dotenv";
+import chalk from "chalk";
+
+// Load environment variables
+dotenv.config();
+async function main() {
+  console.log(`\n${chalk.bold("Stagehand 🤘 Agent Streaming Example")}\n`);
+  // Initialize Stagehand
+  const stagehand = new Stagehand({
+    env: "LOCAL",
+    verbose: 2,
+    cacheDir: "stagehand-agent-cache",
+    logInferenceToFile: false,
+    experimental: true,
+  });
+
+  await stagehand.init();
+
+  try {
+    const page = stagehand.context.pages()[0];
+    await page.goto("https://amazon.com");
+    const agent = stagehand.agent({
+      model: "anthropic/claude-sonnet-4-5-20250929",
+      executionModel: "google/gemini-2.5-flash",
+    });
+
+    const result = await agent.stream({
+      instruction: "go to amazon, and search for shampoo, stop after searching",
+      maxSteps: 20,
+    });
+    // stream the text
+    for await (const delta of result.textStream) {
+      process.stdout.write(delta);
+    }
+    // stream everything ( toolcalls, messages, etc.)
+    // for await (const delta of result.fullStream) {
+    //   console.log(delta);
+    // }
+
+    const finalResult = await result.result;
+    console.log("Final Result:", finalResult);
+  } catch (error) {
+    console.log(`${chalk.red("✗")} Error: ${error}`);
+  }
+}
+main();
diff --git a/packages/core/lib/v3/handlers/v3AgentHandler.ts b/packages/core/lib/v3/handlers/v3AgentHandler.ts
index 503fc78e1..47906ca84 100644
--- a/packages/core/lib/v3/handlers/v3AgentHandler.ts
+++ b/packages/core/lib/v3/handlers/v3AgentHandler.ts
@@ -1,13 +1,22 @@
 import { createAgentTools } from "../agent/tools";
 import { LogLine } from "../types/public/logs";
 import { V3 } from "../v3";
-import { ModelMessage, ToolSet, wrapLanguageModel, stepCountIs } from "ai";
+import {
+  ModelMessage,
+  ToolSet,
+  wrapLanguageModel,
+  stepCountIs,
+  type LanguageModelUsage,
+  type StepResult,
+} from "ai";
 import { processMessages } from "../agent/utils/messageProcessing";
 import { LLMClient } from "../llm/LLMClient";
 import {
-  AgentAction,
   AgentExecuteOptions,
   AgentResult,
+  AgentContext,
+  AgentState,
+  AgentStreamResult,
 } from "../types/public/agent";
 import { V3FunctionName } from "../types/public/methods";
 import { mapToolResultToActions } from "../agent/utils/actionMapping";
@@ -37,138 +46,137 @@ export class V3AgentHandler {
     this.mcpTools = mcpTools;
   }
 
-  public async execute(
+  private async prepareAgent(
     instructionOrOptions: string | AgentExecuteOptions,
-  ): Promise<AgentResult> {
-    const startTime = Date.now();
+  ): Promise<AgentContext> {
     const options =
       typeof instructionOrOptions === "string"
        ? { instruction: instructionOrOptions }
        : instructionOrOptions;
-    const maxSteps = options.maxSteps || 10;
-    const actions: AgentAction[] = [];
-    let finalMessage = "";
-    let completed = false;
-    const collectedReasoning: string[] = [];
+    const maxSteps = options.maxSteps || 20;
 
-    let currentPageUrl = (await this.v3.context.awaitActivePage()).url();
+    const systemPrompt = this.buildSystemPrompt(
+      options.instruction,
+      this.systemInstructions,
+    );
+    const tools = this.createTools();
+    const allTools: ToolSet = { ...tools, ...this.mcpTools };
+    const messages: ModelMessage[] = [
+      { role: "user", content: options.instruction },
+    ];
 
-    try {
-      const systemPrompt = this.buildSystemPrompt(
-        options.instruction,
-        this.systemInstructions,
-      );
-      const tools = this.createTools();
-      const allTools = { ...tools, ...this.mcpTools };
-      const messages: ModelMessage[] = [
-        { role: "user", content: options.instruction },
-      ];
-
-      if (!this.llmClient?.getLanguageModel) {
-        throw new MissingLLMConfigurationError();
-      }
-      const baseModel = this.llmClient.getLanguageModel();
-      const wrappedModel = wrapLanguageModel({
-        model: baseModel,
-        middleware: {
-          transformParams: async ({ params }) => {
-            const { processedPrompt } = processMessages(params);
-            return { ...params, prompt: processedPrompt } as typeof params;
-          },
-        },
-      });
+    if (!this.llmClient?.getLanguageModel) {
+      throw new MissingLLMConfigurationError();
+    }
+    const baseModel = this.llmClient.getLanguageModel();
+    const wrappedModel = wrapLanguageModel({
+      model: baseModel,
+      middleware: {
+        transformParams: async ({ params }) => {
+          const { processedPrompt } = processMessages(params);
+          return { ...params, prompt: processedPrompt } as typeof params;
+        },
+      },
+    });
+
+    const initialPageUrl = (await this.v3.context.awaitActivePage()).url();
+
+    return {
+      options,
+      maxSteps,
+      systemPrompt,
+      allTools,
+      messages,
+      wrappedModel,
+      initialPageUrl,
+    };
+  }
+
+  private createStepHandler(state: AgentState) {
+    return async (event: StepResult<ToolSet>) => {
+      this.logger({
+        category: "agent",
+        message: `Step finished: ${event.finishReason}`,
+        level: 2,
+      });
+
+      if (event.toolCalls && event.toolCalls.length > 0) {
+        for (let i = 0; i < event.toolCalls.length; i++) {
+          const toolCall = event.toolCalls[i];
+          const args = toolCall.input;
+          const toolResult = event.toolResults?.[i];
+
+          if (event.text && event.text.length > 0) {
+            state.collectedReasoning.push(event.text);
+            this.logger({
+              category: "agent",
+              message: `reasoning: ${event.text}`,
+              level: 1,
+            });
+          }
+
+          if (toolCall.toolName === "close") {
+            state.completed = true;
+            if (args?.taskComplete) {
+              const closeReasoning = args.reasoning;
+              const allReasoning = state.collectedReasoning.join(" ");
+              state.finalMessage = closeReasoning
+                ? `${allReasoning} ${closeReasoning}`.trim()
+                : allReasoning || "Task completed successfully";
+            }
+          }
+          const mappedActions = mapToolResultToActions({
+            toolCallName: toolCall.toolName,
+            toolResult,
+            args,
+            reasoning: event.text || undefined,
+          });
+
+          for (const action of mappedActions) {
+            action.pageUrl = state.currentPageUrl;
+            action.timestamp = Date.now();
+            state.actions.push(action);
+          }
+        }
+        state.currentPageUrl = (await this.v3.context.awaitActivePage()).url();
+      }
+    };
+  }
+
+  public async execute(
+    instructionOrOptions: string | AgentExecuteOptions,
+  ): Promise<AgentResult> {
+    const startTime = Date.now();
+    const {
+      maxSteps,
+      systemPrompt,
+      allTools,
+      messages,
+      wrappedModel,
+      initialPageUrl,
+    } = await this.prepareAgent(instructionOrOptions);
+
+    const state: AgentState = {
+      collectedReasoning: [],
+      actions: [],
+      finalMessage: "",
+      completed: false,
+      currentPageUrl: initialPageUrl,
+    };
 
+    try {
       const result = await this.llmClient.generateText({
         model: wrappedModel,
         system: systemPrompt,
         messages,
         tools: allTools,
-        stopWhen: stepCountIs(maxSteps),
+        stopWhen: (result) => this.handleStop(result, maxSteps),
         temperature: 1,
         toolChoice: "auto",
-        onStepFinish: async (event) => {
-          this.logger({
-            category: "agent",
-            message: `Step finished: ${event.finishReason}`,
-            level: 2,
-          });
-
-          if (event.toolCalls && event.toolCalls.length > 0) {
-            for (let i = 0; i < event.toolCalls.length; i++) {
-              const toolCall = event.toolCalls[i];
-              const args = toolCall.input as Record<string, unknown>;
-              const toolResult = event.toolResults?.[i];
-
-              if (event.text.length > 0) {
-                collectedReasoning.push(event.text);
-                this.logger({
-                  category: "agent",
-                  message: `reasoning: ${event.text}`,
-                  level: 1,
-                });
-              }
-
-              if (toolCall.toolName === "close") {
-                completed = true;
-                if (args?.taskComplete) {
-                  const closeReasoning = args.reasoning;
-                  const allReasoning = collectedReasoning.join(" ");
-                  finalMessage = closeReasoning
-                    ? `${allReasoning} ${closeReasoning}`.trim()
-                    : allReasoning || "Task completed successfully";
-                }
-              }
-              const mappedActions = mapToolResultToActions({
-                toolCallName: toolCall.toolName,
-                toolResult,
-                args,
-                reasoning: event.text || undefined,
-              });
-
-              for (const action of mappedActions) {
-                action.pageUrl = currentPageUrl;
-                action.timestamp = Date.now();
-                actions.push(action);
-              }
-            }
-            currentPageUrl = (await this.v3.context.awaitActivePage()).url();
-          }
-        },
+        onStepFinish: this.createStepHandler(state),
       });
 
-      if (!finalMessage) {
-        const allReasoning = collectedReasoning.join(" ").trim();
-        finalMessage = allReasoning || result.text;
-      }
-
-      const endTime = Date.now();
-      const inferenceTimeMs = endTime - startTime;
-      if (result.usage) {
-        this.v3.updateMetrics(
-          V3FunctionName.AGENT,
-          result.usage.inputTokens || 0,
-          result.usage.outputTokens || 0,
-          result.usage.reasoningTokens || 0,
-          result.usage.cachedInputTokens || 0,
-          inferenceTimeMs,
-        );
-      }
-
-      return {
-        success: completed,
-        message: finalMessage || "Task execution completed",
-        actions,
-        completed,
-        usage: result.usage
-          ? {
-              input_tokens: result.usage.inputTokens || 0,
-              output_tokens: result.usage.outputTokens || 0,
-              reasoning_tokens: result.usage.reasoningTokens || 0,
-              cached_input_tokens: result.usage.cachedInputTokens || 0,
-              inference_time_ms: inferenceTimeMs,
-            }
-          : undefined,
-      };
+      return this.consolidateMetricsAndResult(startTime, state, result);
     } catch (error) {
       const errorMessage = error?.message ?? String(error);
       this.logger({
@@ -179,13 +187,109 @@ export class V3AgentHandler {
       });
       return {
         success: false,
-        actions,
+        actions: state.actions,
         message: `Failed to execute task: ${errorMessage}`,
         completed: false,
       };
     }
   }
 
+  public async stream(
+    instructionOrOptions: string | AgentExecuteOptions,
+  ): Promise<AgentStreamResult> {
+    const {
+      maxSteps,
+      systemPrompt,
+      allTools,
+      messages,
+      wrappedModel,
+      initialPageUrl,
+    } = await this.prepareAgent(instructionOrOptions);
+
+    const state: AgentState = {
+      collectedReasoning: [],
+      actions: [],
+      finalMessage: "",
+      completed: false,
+      currentPageUrl: initialPageUrl,
+    };
+    const startTime = Date.now();
+
+    let resolveResult: (value: AgentResult | PromiseLike<AgentResult>) => void;
+    let rejectResult: (reason?: unknown) => void;
+    const resultPromise = new Promise<AgentResult>((resolve, reject) => {
+      resolveResult = resolve;
+      rejectResult = reject;
+    });
+
+    const streamResult = this.llmClient.streamText({
+      model: wrappedModel,
+      system: systemPrompt,
+      messages,
+      tools: allTools,
+      stopWhen: (result) => this.handleStop(result, maxSteps),
+      temperature: 1,
+      toolChoice: "auto",
+      onStepFinish: this.createStepHandler(state),
+      onFinish: (event) => {
+        try {
+          const result = this.consolidateMetricsAndResult(
+            startTime,
+            state,
+            event,
+          );
+          resolveResult(result);
+        } catch (error) {
+          rejectResult(error);
+        }
+      },
+    });
+
+    const agentStreamResult = streamResult as AgentStreamResult;
+    agentStreamResult.result = resultPromise;
+    return agentStreamResult;
+  }
+
+  private consolidateMetricsAndResult(
+    startTime: number,
+    state: AgentState,
+    result: { text?: string; usage?: LanguageModelUsage },
+  ): AgentResult {
+    if (!state.finalMessage) {
+      const allReasoning = state.collectedReasoning.join(" ").trim();
+      state.finalMessage = allReasoning || result.text || "";
+    }
+
+    const endTime = Date.now();
+    const inferenceTimeMs = endTime - startTime;
+    if (result.usage) {
+      this.v3.updateMetrics(
+        V3FunctionName.AGENT,
+        result.usage.inputTokens || 0,
+        result.usage.outputTokens || 0,
+        result.usage.reasoningTokens || 0,
+        result.usage.cachedInputTokens || 0,
+        inferenceTimeMs,
+      );
+    }
+
+    return {
+      success: state.completed,
+      message: state.finalMessage || "Task execution completed",
+      actions: state.actions,
+      completed: state.completed,
+      usage: result.usage
+        ? {
+            input_tokens: result.usage.inputTokens || 0,
+            output_tokens: result.usage.outputTokens || 0,
+            reasoning_tokens: result.usage.reasoningTokens || 0,
+            cached_input_tokens: result.usage.cachedInputTokens || 0,
+            inference_time_ms: inferenceTimeMs,
+          }
+        : undefined,
    };
+  }
+
   private buildSystemPrompt(
     executionInstruction: string,
     systemInstructions?: string,
@@ -202,4 +306,15 @@ export class V3AgentHandler {
       logger: this.logger,
     });
   }
+
+  private handleStop(
+    result: Parameters<ReturnType<typeof stepCountIs>>[0],
+    maxSteps: number,
+  ): boolean | PromiseLike<boolean> {
+    const lastStep = result.steps[result.steps.length - 1];
+    if (lastStep?.toolCalls?.some((tc) => tc.toolName === "close")) {
+      return true;
+    }
+    return stepCountIs(maxSteps)(result);
+  }
 }
diff --git a/packages/core/lib/v3/types/public/agent.ts b/packages/core/lib/v3/types/public/agent.ts
index 9ce9bcb21..8ed405cc5 100644
--- a/packages/core/lib/v3/types/public/agent.ts
+++ b/packages/core/lib/v3/types/public/agent.ts
@@ -1,11 +1,29 @@
 import type { Client } from "@modelcontextprotocol/sdk/client/index.js";
-import { ToolSet } from "ai";
+import { ToolSet, ModelMessage, wrapLanguageModel, StreamTextResult } from "ai";
 import { LogLine } from "./logs";
 import { Page as PlaywrightPage } from "playwright-core";
 import { Page as PuppeteerPage } from "puppeteer-core";
 import { Page as PatchrightPage } from "patchright-core";
 import { Page } from "../../understudy/page";
 
+export interface AgentContext {
+  options: AgentExecuteOptions;
+  maxSteps: number;
+  systemPrompt: string;
+  allTools: ToolSet;
+  messages: ModelMessage[];
+  wrappedModel: ReturnType<typeof wrapLanguageModel>;
+  initialPageUrl: string;
+}
+
+export interface AgentState {
+  collectedReasoning: string[];
+  actions: AgentAction[];
+  finalMessage: string;
+  completed: boolean;
+  currentPageUrl: string;
+}
+
 export interface AgentAction {
   type: string;
   reasoning?: string;
@@ -34,6 +52,10 @@ export interface AgentResult {
   };
 }
 
+export type AgentStreamResult = StreamTextResult<ToolSet, never> & {
+  result: Promise<AgentResult>;
+};
+
 export interface AgentExecuteOptions {
   instruction: string;
   maxSteps?: number;
diff --git a/packages/core/lib/v3/v3.ts b/packages/core/lib/v3/v3.ts
index 3c795d8ac..da25efeed 100644
--- a/packages/core/lib/v3/v3.ts
+++ b/packages/core/lib/v3/v3.ts
@@ -65,6 +65,7 @@ import {
   StagehandNotInitializedError,
   MissingEnvironmentVariableError,
   StagehandInitError,
+  AgentStreamResult,
 } from "./types/public";
 import { V3Context } from "./understudy/context";
 import { Page } from "./understudy/page";
@@ -1498,6 +1499,9 @@
     execute: (
       instructionOrOptions: string | AgentExecuteOptions,
     ) => Promise<AgentResult>;
+    stream?: (
+      instructionOrOptions: string | AgentExecuteOptions,
+    ) => Promise<AgentStreamResult>;
   } {
     this.logger({
       category: "agent",
@@ -1735,6 +1739,38 @@
          }
        }
      }),
+      stream: async (instructionOrOptions: string | AgentExecuteOptions) =>
+        withInstanceLogContext(this.instanceId, async () => {
+          if (!this.experimental) {
+            throw new ExperimentalNotConfiguredError("Agent streaming");
+          }
+          if ((options?.integrations || options?.tools) && !this.experimental) {
+            throw new ExperimentalNotConfiguredError(
+              "MCP integrations and custom tools",
+            );
+          }
+
+          const tools = options?.integrations
+            ? await resolveTools(options.integrations, options.tools)
+            : (options?.tools ?? {});
+
+          const agentLlmClient = options?.model
+            ? this.resolveLlmClient(options.model)
+            : this.llmClient;
+
+          const handler = new V3AgentHandler(
+            this,
+            this.logger,
+            agentLlmClient,
+            typeof options?.executionModel === "string"
+              ? options.executionModel
+              : options?.executionModel?.modelName,
+            options?.systemPrompt,
+            tools,
+          );
+
+          return handler.stream(instructionOrOptions);
+        }),
     };
   }
 }