diff --git a/.changeset/long-keys-watch.md b/.changeset/long-keys-watch.md
new file mode 100644
index 00000000000..fb13ed74987
--- /dev/null
+++ b/.changeset/long-keys-watch.md
@@ -0,0 +1,23 @@
+---
+'firebase': minor
+'@firebase/ai': minor
+---
+
+Add support for `AbortSignal`, allowing in-flight requests to be aborted.
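+
+For example (an illustrative sketch; assumes an initialized `ai` instance, and the model name is a placeholder):
+
+```js
+const controller = new AbortController();
+const model = getGenerativeModel(ai, { model: 'gemini-2.0-flash' });
+
+try {
+  const pending = model.generateContent('Write a haiku about rivers.', {
+    signal: controller.signal
+  });
+  controller.abort(); // The pending request rejects with an `AbortError`.
+  await pending;
+} catch (e) {
+  // e.name === 'AbortError'
+}
+```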
diff --git a/common/api-review/ai.api.md b/common/api-review/ai.api.md
index 2bf194fbaf2..c5a180e0824 100644
--- a/common/api-review/ai.api.md
+++ b/common/api-review/ai.api.md
@@ -150,8 +150,8 @@ export class ChatSession {
    params?: StartChatParams | undefined;
    // (undocumented)
    requestOptions?: RequestOptions | undefined;
-    sendMessage(request: string | Array<string | Part>): Promise<GenerateContentResult>;
-    sendMessageStream(request: string | Array<string | Part>): Promise<GenerateContentStreamResult>;
+    sendMessage(request: string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
+    sendMessageStream(request: string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
 }

 // @beta
@@ -539,9 +539,9 @@ export interface GenerativeContentBlob {
 export class GenerativeModel extends AIModel {
    // Warning: (ae-incompatible-release-tags) The symbol "__constructor" is marked as @public, but its signature references "ChromeAdapter" which is marked as @beta
    constructor(ai: AI, modelParams: ModelParams, requestOptions?: RequestOptions, chromeAdapter?: ChromeAdapter | undefined);
-    countTokens(request: CountTokensRequest | string | Array<string | Part>): Promise<CountTokensResponse>;
-    generateContent(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentResult>;
-    generateContentStream(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentStreamResult>;
+    countTokens(request: CountTokensRequest | string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<CountTokensResponse>;
+    generateContent(request: GenerateContentRequest | string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
+    generateContentStream(request: GenerateContentRequest | string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
    // (undocumented)
    generationConfig: GenerationConfig;
    // (undocumented)
@@ -784,9 +784,9 @@ export interface ImagenInlineImage {
 // @public
 export class ImagenModel extends AIModel {
    constructor(ai: AI, modelParams: ImagenModelParams, requestOptions?: RequestOptions | undefined);
-    generateImages(prompt: string): Promise<ImagenGenerationResponse<ImagenInlineImage>>;
+    generateImages(prompt: string, singleRequestOptions?: SingleRequestOptions): Promise<ImagenGenerationResponse<ImagenInlineImage>>;
    // @internal
-    generateImagesGCS(prompt: string, gcsURI: string): Promise<ImagenGenerationResponse<ImagenGCSImage>>;
+    generateImagesGCS(prompt: string, gcsURI: string, singleRequestOptions?: SingleRequestOptions): Promise<ImagenGenerationResponse<ImagenGCSImage>>;
    generationConfig?: ImagenGenerationConfig;
    // (undocumented)
    requestOptions?: RequestOptions | undefined;
@@ -1294,6 +1294,11 @@ export interface Segment {
    text: string;
 }

+// @public
+export interface SingleRequestOptions extends RequestOptions {
+    signal?: AbortSignal;
+}
+
 // @beta
 export interface SpeechConfig {
    voiceConfig?: VoiceConfig;
@@ -1333,8 +1338,8 @@ export class TemplateGenerativeModel {
    constructor(ai: AI, requestOptions?: RequestOptions);
    // @internal (undocumented)
    _apiSettings: ApiSettings;
-    generateContent(templateId: string, templateVariables: object): Promise<GenerateContentResult>;
-    generateContentStream(templateId: string, templateVariables: object): Promise<GenerateContentStreamResult>;
+    generateContent(templateId: string, templateVariables: object, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
+    generateContentStream(templateId: string, templateVariables: object, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
    requestOptions?: RequestOptions;
 }

@@ -1343,7 +1348,7 @@ export class TemplateImagenModel {
    constructor(ai: AI, requestOptions?: RequestOptions);
    // @internal (undocumented)
    _apiSettings: ApiSettings;
-    generateImages(templateId: string, templateVariables: object): Promise<ImagenGenerationResponse<ImagenInlineImage>>;
+    generateImages(templateId: string, templateVariables: object, singleRequestOptions?: SingleRequestOptions): Promise<ImagenGenerationResponse<ImagenInlineImage>>;
    requestOptions?: RequestOptions;
 }

diff --git a/docs-devsite/_toc.yaml b/docs-devsite/_toc.yaml
index 92633c553a3..06a976686f9 100644
--- a/docs-devsite/_toc.yaml
+++ b/docs-devsite/_toc.yaml
@@ -190,6 +190,8 @@ toc:
       path: /docs/reference/js/ai.searchentrypoint.md
     - title: Segment
       path: /docs/reference/js/ai.segment.md
+    - title: SingleRequestOptions
+      path: /docs/reference/js/ai.singlerequestoptions.md
     - title: SpeechConfig
       path: /docs/reference/js/ai.speechconfig.md
     - title: StartAudioConversationOptions
diff --git a/docs-devsite/ai.chatsession.md b/docs-devsite/ai.chatsession.md
index 4e4358898a5..2062f9868f1 100644
--- a/docs-devsite/ai.chatsession.md
+++ b/docs-devsite/ai.chatsession.md
@@ -37,8 +37,8 @@ export declare class ChatSession

| Method | Modifiers | Description |
| --- | --- | --- |
| [getHistory()](./ai.chatsession.md#chatsessiongethistory) | | Gets the chat history so far. Blocked prompts are not added to history. Neither blocked candidates nor the prompts that generated them are added to history. |
-| [sendMessage(request)](./ai.chatsession.md#chatsessionsendmessage) | | Sends a chat message and receives a non-streaming [GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface) |
-| [sendMessageStream(request)](./ai.chatsession.md#chatsessionsendmessagestream) | | Sends a chat message and receives the response as a [GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface) containing an iterable stream and a response promise. |
+| [sendMessage(request, singleRequestOptions)](./ai.chatsession.md#chatsessionsendmessage) | | Sends a chat message and receives a non-streaming [GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface) |
+| [sendMessageStream(request, singleRequestOptions)](./ai.chatsession.md#chatsessionsendmessagestream) | | Sends a chat message and receives the response as a [GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface) containing an iterable stream and a response promise. |

## ChatSession.(constructor)

@@ -104,7 +104,7 @@ Sends a chat message and receives a non-streaming [GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface)

Signature:

```typescript
-sendMessage(request: string | Array<string | Part>): Promise<GenerateContentResult>;
+sendMessage(request: string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
```

#### Parameters

@@ -112,6 +112,7 @@ sendMessage(request: string | Array<string | Part>): Promise<GenerateContentResult>;
| Parameter | Type | Description |
| --- | --- | --- |
| request | string \| Array<string \| [Part](./ai.md#part)> | |
+| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | |

Returns:

@@ -124,7 +125,7 @@ Sends a chat message and receives the response as a [GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface) containing an iterable stream and a response promise.

Signature:

```typescript
-sendMessageStream(request: string | Array<string | Part>): Promise<GenerateContentStreamResult>;
+sendMessageStream(request: string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
```

#### Parameters

@@ -132,6 +133,7 @@ sendMessageStream(request: string | Array<string | Part>): Promise<GenerateContentStreamResult>;
| Parameter | Type | Description |
| --- | --- | --- |
| request | string \| Array<string \| [Part](./ai.md#part)> | |
+| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | |

Returns:
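
+### Example
+
+As an illustrative sketch (assuming `model` is a [GenerativeModel](./ai.generativemodel.md#generativemodel_class) instance), the new `signal` option can cancel a streaming reply mid-flight:
+
+```javascript
+const chat = model.startChat();
+const controller = new AbortController();
+const { stream } = await chat.sendMessageStream('Tell me a long story.', {
+  signal: controller.signal
+});
+
+// Aborting cancels the underlying HTTP request; iteration then
+// rejects with an `AbortError`.
+setTimeout(() => controller.abort(), 2000);
+
+try {
+  for await (const chunk of stream) {
+    console.log(chunk.text());
+  }
+} catch (e) {
+  if (e.name === 'AbortError') {
+    console.log('Stream cancelled.');
+  }
+}
+```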
diff --git a/docs-devsite/ai.generativemodel.md b/docs-devsite/ai.generativemodel.md
index 323fcfe9d76..4b1c71b8d2c 100644
--- a/docs-devsite/ai.generativemodel.md
+++ b/docs-devsite/ai.generativemodel.md
@@ -40,9 +40,9 @@ export declare class GenerativeModel extends AIModel

| Method | Modifiers | Description |
| --- | --- | --- |
-| [countTokens(request)](./ai.generativemodel.md#generativemodelcounttokens) | | Counts the tokens in the provided request. |
-| [generateContent(request)](./ai.generativemodel.md#generativemodelgeneratecontent) | | Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). |
-| [generateContentStream(request)](./ai.generativemodel.md#generativemodelgeneratecontentstream) | | Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. |
+| [countTokens(request, singleRequestOptions)](./ai.generativemodel.md#generativemodelcounttokens) | | Counts the tokens in the provided request. |
+| [generateContent(request, singleRequestOptions)](./ai.generativemodel.md#generativemodelgeneratecontent) | | Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). |
+| [generateContentStream(request, singleRequestOptions)](./ai.generativemodel.md#generativemodelgeneratecontentstream) | | Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. |
| [startChat(startChatParams)](./ai.generativemodel.md#generativemodelstartchat) | | Gets a new [ChatSession](./ai.chatsession.md#chatsession_class) instance which can be used for multi-turn chats. |

## GenerativeModel.(constructor)

@@ -119,7 +119,7 @@ Counts the tokens in the provided request.

Signature:

```typescript
-countTokens(request: CountTokensRequest | string | Array<string | Part>): Promise<CountTokensResponse>;
+countTokens(request: CountTokensRequest | string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<CountTokensResponse>;
```

#### Parameters

@@ -127,6 +127,7 @@ countTokens(request: CountTokensRequest | string | Array<string | Part>): Promise<CountTokensResponse>;
| Parameter | Type | Description |
| --- | --- | --- |
| request | [CountTokensRequest](./ai.counttokensrequest.md#counttokensrequest_interface) \| string \| Array<string \| [Part](./ai.md#part)> | |
+| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | |

Returns:

@@ -139,7 +140,7 @@ Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface).

Signature:

```typescript
-generateContent(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentResult>;
+generateContent(request: GenerateContentRequest | string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
```

#### Parameters

@@ -147,6 +148,7 @@ generateContent(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentResult>;
| Parameter | Type | Description |
| --- | --- | --- |
| request | [GenerateContentRequest](./ai.generatecontentrequest.md#generatecontentrequest_interface) \| string \| Array<string \| [Part](./ai.md#part)> | |
+| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | |

Returns:

@@ -159,7 +161,7 @@ Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response.

Signature:

```typescript
-generateContentStream(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentStreamResult>;
+generateContentStream(request: GenerateContentRequest | string | Array<string | Part>, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
```

#### Parameters

@@ -167,6 +169,7 @@ generateContentStream(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentStreamResult>;
| Parameter | Type | Description |
| --- | --- | --- |
| request | [GenerateContentRequest](./ai.generatecontentrequest.md#generatecontentrequest_interface) \| string \| Array<string \| [Part](./ai.md#part)> | |
+| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | |

Returns:

diff --git a/docs-devsite/ai.imagenmodel.md b/docs-devsite/ai.imagenmodel.md
index 68375972cbb..6559723878a 100644
--- a/docs-devsite/ai.imagenmodel.md
+++ b/docs-devsite/ai.imagenmodel.md
@@ -39,7 +39,7 @@ export declare class ImagenModel extends AIModel

| Method | Modifiers | Description |
| --- | --- | --- |
-| [generateImages(prompt)](./ai.imagenmodel.md#imagenmodelgenerateimages) | | Generates images using the Imagen model and returns them as base64-encoded strings. |
+| [generateImages(prompt, singleRequestOptions)](./ai.imagenmodel.md#imagenmodelgenerateimages) | | Generates images using the Imagen model and returns them as base64-encoded strings. |

## ImagenModel.(constructor)

@@ -100,7 +100,7 @@ If the prompt was not blocked, but one or more of the generated images were filtered

Signature:

```typescript
-generateImages(prompt: string): Promise<ImagenGenerationResponse<ImagenInlineImage>>;
+generateImages(prompt: string, singleRequestOptions?: SingleRequestOptions): Promise<ImagenGenerationResponse<ImagenInlineImage>>;
```

#### Parameters

@@ -108,6 +108,7 @@ generateImages(prompt: string): Promise<ImagenGenerationResponse<ImagenInlineImage>>;
| Parameter | Type | Description |
| --- | --- | --- |
| prompt | string | |
+| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | |

Returns:

diff --git a/docs-devsite/ai.md b/docs-devsite/ai.md
index 53e4057cade..482c49c3cdd 100644
--- a/docs-devsite/ai.md
+++ b/docs-devsite/ai.md
@@ -133,6 +133,7 @@ The Firebase AI Web SDK.
| [SchemaShared](./ai.schemashared.md#schemashared_interface) | Basic [Schema](./ai.schema.md#schema_class) properties shared across several Schema-related types. |
| [SearchEntrypoint](./ai.searchentrypoint.md#searchentrypoint_interface) | Google search entry point.
| | [Segment](./ai.segment.md#segment_interface) | Represents a specific segment within a [Content](./ai.content.md#content_interface) object, often used to pinpoint the exact location of text or data that grounding information refers to. | +| [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | Options that can be provided per-request. Extends the base [RequestOptions](./ai.requestoptions.md#requestoptions_interface) (like timeout and baseUrl) with request-specific controls like cancellation via AbortSignal.Options specified here will override any default [RequestOptions](./ai.requestoptions.md#requestoptions_interface) configured on a model (for example, [GenerativeModel](./ai.generativemodel.md#generativemodel_class)). | | [SpeechConfig](./ai.speechconfig.md#speechconfig_interface) | (Public Preview) Configures speech synthesis. | | [StartAudioConversationOptions](./ai.startaudioconversationoptions.md#startaudioconversationoptions_interface) | (Public Preview) Options for [startAudioConversation()](./ai.md#startaudioconversation_01c8e7f). | | [StartChatParams](./ai.startchatparams.md#startchatparams_interface) | Params for [GenerativeModel.startChat()](./ai.generativemodel.md#generativemodelstartchat). | diff --git a/docs-devsite/ai.singlerequestoptions.md b/docs-devsite/ai.singlerequestoptions.md new file mode 100644 index 00000000000..a55bd3c2f3c --- /dev/null +++ b/docs-devsite/ai.singlerequestoptions.md @@ -0,0 +1,61 @@ +Project: /docs/reference/js/_project.yaml +Book: /docs/reference/_book.yaml +page_type: reference + +{% comment %} +DO NOT EDIT THIS FILE! +This is generated by the JS SDK team, and any local changes will be +overwritten. Changes should be made in the source code at +https://github.com/firebase/firebase-js-sdk +{% endcomment %} + +# SingleRequestOptions interface +Options that can be provided per-request. Extends the base [RequestOptions](./ai.requestoptions.md#requestoptions_interface) (like `timeout` and `baseUrl`) with request-specific controls like cancellation via `AbortSignal`. + +Options specified here will override any default [RequestOptions](./ai.requestoptions.md#requestoptions_interface) configured on a model (for example, [GenerativeModel](./ai.generativemodel.md#generativemodel_class)). + +Signature: + +```typescript +export interface SingleRequestOptions extends RequestOptions +``` +Extends: [RequestOptions](./ai.requestoptions.md#requestoptions_interface) + +## Properties + +| Property | Type | Description | +| --- | --- | --- | +| [signal](./ai.singlerequestoptions.md#singlerequestoptionssignal) | AbortSignal | An AbortSignal instance that allows cancelling ongoing requests (like generateContent or generateImages).If provided, calling abort() on the corresponding AbortController will attempt to cancel the underlying HTTP request. An AbortError will be thrown if cancellation is successful.Note that this will not cancel the request in the backend, so any applicable billing charges will still be applied despite cancellation. | + +## SingleRequestOptions.signal + +An `AbortSignal` instance that allows cancelling ongoing requests (like `generateContent` or `generateImages`). + +If provided, calling `abort()` on the corresponding `AbortController` will attempt to cancel the underlying HTTP request. An `AbortError` will be thrown if cancellation is successful. + +Note that this will not cancel the request in the backend, so any applicable billing charges will still be applied despite cancellation. 

+Signature:
+
+```typescript
+signal?: AbortSignal;
+```
+
+### Example
+
+```javascript
+const controller = new AbortController();
+const model = getGenerativeModel({
+  // ...
+});
+model.generateContent(
+  "Write a story about a magic backpack.",
+  { signal: controller.signal }
+);
+
+// To cancel request:
+controller.abort();
+```
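+
+A signal does not have to come from a manually managed controller. As a further illustrative sketch (using the standard `AbortSignal.timeout()` web API, not an SDK-specific helper), a hard per-request deadline can be set like this:
+
+```javascript
+// Abort automatically if the request is still pending after 10 seconds.
+const result = await model.generateContent(
+  "Write a story about a magic backpack.",
+  { signal: AbortSignal.timeout(10000) }
+);
+```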

diff --git a/docs-devsite/ai.templategenerativemodel.md b/docs-devsite/ai.templategenerativemodel.md
index c115af62b1e..a9ed568fa19 100644
--- a/docs-devsite/ai.templategenerativemodel.md
+++ b/docs-devsite/ai.templategenerativemodel.md
@@ -39,8 +39,8 @@ export declare class TemplateGenerativeModel

| Method | Modifiers | Description |
| --- | --- | --- |
-| [generateContent(templateId, templateVariables)](./ai.templategenerativemodel.md#templategenerativemodelgeneratecontent) | | (Public Preview) Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). |
-| [generateContentStream(templateId, templateVariables)](./ai.templategenerativemodel.md#templategenerativemodelgeneratecontentstream) | | (Public Preview) Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. |
+| [generateContent(templateId, templateVariables, singleRequestOptions)](./ai.templategenerativemodel.md#templategenerativemodelgeneratecontent) | | (Public Preview) Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). |
+| [generateContentStream(templateId, templateVariables, singleRequestOptions)](./ai.templategenerativemodel.md#templategenerativemodelgeneratecontentstream) | | (Public Preview) Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. |

## TemplateGenerativeModel.(constructor)

@@ -85,7 +85,7 @@ Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface).

Signature:

```typescript
-generateContent(templateId: string, templateVariables: object): Promise<GenerateContentResult>;
+generateContent(templateId: string, templateVariables: object, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
```

#### Parameters

@@ -94,6 +94,7 @@ generateContent(templateId: string, templateVariables: object): Promise<GenerateContentResult>;
| Parameter | Type | Description |
| --- | --- | --- |
| templateId | string | |
| templateVariables | object | |
+| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | |

Returns:

@@ -109,7 +110,7 @@ Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response.

Signature:

```typescript
-generateContentStream(templateId: string, templateVariables: object): Promise<GenerateContentStreamResult>;
+generateContentStream(templateId: string, templateVariables: object, singleRequestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
```

#### Parameters

@@ -118,6 +119,7 @@ generateContentStream(templateId: string, templateVariables: object): Promise<GenerateContentStreamResult>;
| Parameter | Type | Description |
| --- | --- | --- |
| templateId | string | |
| templateVariables | object | |
+| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | |

Returns:

diff --git a/docs-devsite/ai.templateimagenmodel.md b/docs-devsite/ai.templateimagenmodel.md
index 2d86071993f..3b33d94f71f 100644
--- a/docs-devsite/ai.templateimagenmodel.md
+++ b/docs-devsite/ai.templateimagenmodel.md
@@ -39,7 +39,7 @@ export declare class TemplateImagenModel

| Method | Modifiers | Description |
| --- | --- | --- |
-| [generateImages(templateId, templateVariables)](./ai.templateimagenmodel.md#templateimagenmodelgenerateimages) | | (Public Preview) Makes a single call to the model and returns an object containing a single [ImagenGenerationResponse](./ai.imagengenerationresponse.md#imagengenerationresponse_interface). |
+| [generateImages(templateId, templateVariables, singleRequestOptions)](./ai.templateimagenmodel.md#templateimagenmodelgenerateimages) | | (Public Preview) Makes a single call to the model and returns an object containing a single [ImagenGenerationResponse](./ai.imagengenerationresponse.md#imagengenerationresponse_interface). |

## TemplateImagenModel.(constructor)

@@ -84,7 +84,7 @@ Makes a single call to the model and returns an object containing a single [ImagenGenerationResponse](./ai.imagengenerationresponse.md#imagengenerationresponse_interface).

Signature:

```typescript
-generateImages(templateId: string, templateVariables: object): Promise<ImagenGenerationResponse<ImagenInlineImage>>;
+generateImages(templateId: string, templateVariables: object, singleRequestOptions?: SingleRequestOptions): Promise<ImagenGenerationResponse<ImagenInlineImage>>;
```

#### Parameters

@@ -93,6 +93,7 @@ generateImages(templateId: string, templateVariables: object): Promise<ImagenGenerationResponse<ImagenInlineImage>>;
| Parameter | Type | Description |
| --- | --- | --- |
| templateId | string | |
| templateVariables | object | |
+| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | |

Returns:

diff --git a/packages/ai/integration/prompt-templates.test.ts b/packages/ai/integration/prompt-templates.test.ts
index 3a7f9038561..34424427b8e 100644
--- a/packages/ai/integration/prompt-templates.test.ts
+++ b/packages/ai/integration/prompt-templates.test.ts
@@ -35,16 +35,36 @@ describe('Prompt templates', function () {
   describe(`${testConfig.toString()}`, () => {
     describe('Generative Model', () => {
       it('successfully generates content', async () => {
         const model = getTemplateGenerativeModel(testConfig.ai, {
           baseUrl: STAGING_URL
         });
         const { response } = await model.generateContent(
           `sassy-greeting-${templateBackendSuffix(
             testConfig.ai.backend.backendType
           )}`,
           { name: 'John' }
         );
         expect(response.text()).to.contain('John'); // Template asks to address directly by name
       });
+      it('rejects with an AbortError when the signal is aborted', async () => {
+        const controller = new AbortController();
+        const model = getTemplateGenerativeModel(testConfig.ai, {
+          baseUrl: STAGING_URL
+        });
+        controller.abort();
+        let error: unknown;
+        try {
+          await model.generateContent(
+            `sassy-greeting-${templateBackendSuffix(
+              testConfig.ai.backend.backendType
+            )}`,
+            { name: 'John' },
+            { signal: controller.signal }
+          );
+        } catch (e) {
+          error = e;
+        }
+        // An aborted signal surfaces as an AbortError to the caller.
+        expect((error as DOMException).name).to.equal('AbortError');
+      });
     });
diff --git a/packages/ai/src/constants.ts b/packages/ai/src/constants.ts
index 0a6f7e91436..0282edb2e13 100644
--- a/packages/ai/src/constants.ts
+++ b/packages/ai/src/constants.ts
@@ -32,7 +32,7 @@ export const PACKAGE_VERSION = version;

export const LANGUAGE_TAG = 'gl-js';

-export const DEFAULT_FETCH_TIMEOUT_MS = 180 * 1000;
+export const DEFAULT_FETCH_TIMEOUT_MS = 180 * 1000; // TODO: Extend the default timeout to accommodate longer generation requests with pro models.

/**
 * Defines the name of the default in-cloud model to use for hybrid inference.
diff --git a/packages/ai/src/methods/chat-session.test.ts b/packages/ai/src/methods/chat-session.test.ts
index 1273d02876c..a7efd0162bb 100644
--- a/packages/ai/src/methods/chat-session.test.ts
+++ b/packages/ai/src/methods/chat-session.test.ts
@@ -59,6 +59,68 @@ describe('ChatSession', () => {
        match.any
      );
    });
+    it('singleRequestOptions overrides requestOptions', async () => {
+      const generateContentStub = stub(
+        generateContentMethods,
+        'generateContent'
+      ).rejects('generateContent failed'); // the rejection reason is irrelevant here
+      const requestOptions = {
+        timeout: 1000
+      };
+      const singleRequestOptions = {
+        timeout: 2000
+      };
+      const chatSession = new ChatSession(
+        fakeApiSettings,
+        'a-model',
+        undefined,
+        undefined,
+        requestOptions
+      );
+      await expect(chatSession.sendMessage('hello', singleRequestOptions)).to.be
+        .rejected;
+      expect(generateContentStub).to.be.calledWith(
+        fakeApiSettings,
+        'a-model',
+        match.any,
+        match.any,
+        match({
+          timeout: singleRequestOptions.timeout
+        })
+      );
+    });
+    it('singleRequestOptions is merged with requestOptions', async () => {
+      const generateContentStub = stub(
+        generateContentMethods,
+        'generateContent'
+      ).rejects('generateContent failed'); // the rejection reason is irrelevant here
+      const abortController = new AbortController();
+      const requestOptions = {
+        timeout: 1000
+      };
+      const singleRequestOptions = {
+        signal: abortController.signal
+      };
+      const chatSession = new ChatSession(
+        fakeApiSettings,
+        'a-model',
+        undefined,
+        undefined,
+        requestOptions
+      );
+      await expect(chatSession.sendMessage('hello', singleRequestOptions)).to.be
+        .rejected;
+      expect(generateContentStub).to.be.calledWith(
+        fakeApiSettings,
+        'a-model',
+        match.any,
+        match.any,
+        match({
+          timeout: requestOptions.timeout,
+          signal: singleRequestOptions.signal
+        })
+      );
+    });
    it('adds message and response to history', async () => {
      const fakeContent: Content = {
        role: 'model',
@@ -124,6 +186,7 @@ describe('ChatSession', () => {
      expect(generateContentStreamStub).to.be.calledWith(
fakeApiSettings, 'a-model', + match.any, match.any ); await clock.runAllAsync(); @@ -147,6 +210,7 @@ describe('ChatSession', () => { expect(generateContentStreamStub).to.be.calledWith( fakeApiSettings, 'a-model', + match.any, match.any ); await clock.runAllAsync(); @@ -156,5 +220,97 @@ describe('ChatSession', () => { ); clock.restore(); }); + it('error from stream promise should not be logged', async () => { + const consoleStub = stub(console, 'error'); + stub(generateContentMethods, 'generateContentStream').rejects('foo'); + const chatSession = new ChatSession( + fakeApiSettings, + 'a-model', + fakeChromeAdapter + ); + try { + // This will throw since generateContentStream will reject immediately. + await chatSession.sendMessageStream('hello'); + } catch (e) { + expect((e as unknown as any).name).to.equal('foo'); + } + + expect(consoleStub).to.not.have.been.called; + }); + it('error from final response promise should not be logged', async () => { + const consoleStub = stub(console, 'error'); + stub(generateContentMethods, 'generateContentStream').resolves({ + response: new Promise((_, reject) => reject(new Error())) + } as unknown as GenerateContentStreamResult); + const chatSession = new ChatSession( + fakeApiSettings, + 'a-model', + fakeChromeAdapter + ); + await chatSession.sendMessageStream('hello'); + expect(consoleStub).to.not.have.been.called; + }); + it('singleRequestOptions overrides requestOptions', async () => { + const generateContentStreamStub = stub( + generateContentMethods, + 'generateContentStream' + ).rejects('generateContentStream failed'); // not important + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + timeout: 2000 + }; + const chatSession = new ChatSession( + fakeApiSettings, + 'a-model', + undefined, + undefined, + requestOptions + ); + await expect(chatSession.sendMessageStream('hello', singleRequestOptions)) + .to.be.rejected; + expect(generateContentStreamStub).to.be.calledWith( + fakeApiSettings, + 'a-model', + match.any, + match.any, + match({ + timeout: singleRequestOptions.timeout + }) + ); + }); + it('singleRequestOptions is merged with requestOptions', async () => { + const generateContentStreamStub = stub( + generateContentMethods, + 'generateContentStream' + ).rejects('generateContentStream failed'); // not important + const abortController = new AbortController(); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + signal: abortController.signal + }; + const chatSession = new ChatSession( + fakeApiSettings, + 'a-model', + undefined, + undefined, + requestOptions + ); + await expect(chatSession.sendMessageStream('hello', singleRequestOptions)) + .to.be.rejected; + expect(generateContentStreamStub).to.be.calledWith( + fakeApiSettings, + 'a-model', + match.any, + match.any, + match({ + timeout: requestOptions.timeout, + signal: singleRequestOptions.signal + }) + ); + }); }); }); diff --git a/packages/ai/src/methods/chat-session.ts b/packages/ai/src/methods/chat-session.ts index dac16430b7a..e020fa57aef 100644 --- a/packages/ai/src/methods/chat-session.ts +++ b/packages/ai/src/methods/chat-session.ts @@ -22,6 +22,7 @@ import { GenerateContentStreamResult, Part, RequestOptions, + SingleRequestOptions, StartChatParams } from '../types'; import { formatNewContent } from '../requests/request-helpers'; @@ -77,7 +78,8 @@ export class ChatSession { * {@link GenerateContentResult} */ async sendMessage( - request: string | Array + request: string | Array, + singleRequestOptions?: SingleRequestOptions 
): Promise { await this._sendPromise; const newContent = formatNewContent(request); @@ -98,7 +100,11 @@ export class ChatSession { this.model, generateContentRequest, this.chromeAdapter, - this.requestOptions + // Merge requestOptions + { + ...this.requestOptions, + ...singleRequestOptions + } ) ) .then(result => { @@ -133,7 +139,8 @@ export class ChatSession { * and a response promise. */ async sendMessageStream( - request: string | Array + request: string | Array, + singleRequestOptions?: SingleRequestOptions ): Promise { await this._sendPromise; const newContent = formatNewContent(request); @@ -150,18 +157,30 @@ export class ChatSession { this.model, generateContentRequest, this.chromeAdapter, - this.requestOptions + // Merge requestOptions + { + ...this.requestOptions, + ...singleRequestOptions + } ); // Add onto the chain. this._sendPromise = this._sendPromise .then(() => streamPromise) - // This must be handled to avoid unhandled rejection, but jump - // to the final catch block with a label to not log this error. + .then(streamResult => streamResult.response) .catch(_ignored => { throw new Error(SILENT_ERROR); }) - .then(streamResult => streamResult.response) + // We want to log errors that the user cannot catch. + // The user can catch all errors that are thrown from the `streamPromise` and the + // `streamResult.response`, since these are returned to the user in the `GenerateContentResult`. + // The user cannot catch errors that are thrown in the following `then` block, which appends + // the model's response to the chat history. + // + // To prevent us from logging errors that the user *can* catch, we re-throw them as + // SILENT_ERROR, then in the final `catch` block below, we only log errors that are not + // SILENT_ERROR. There is currently no way for these errors to be propagated to the user, + // so we log them to try to make up for this. .then(response => { if (response.candidates && response.candidates.length > 0) { this._history.push(newContent); @@ -181,12 +200,7 @@ export class ChatSession { } }) .catch(e => { - // Errors in streamPromise are already catchable by the user as - // streamPromise is returned. - // Avoid duplicating the error message in logs. if (e.message !== SILENT_ERROR) { - // Users do not have access to _sendPromise to catch errors - // downstream from streamPromise, so they should not throw. 
logger.error(e); } }); diff --git a/packages/ai/src/methods/count-tokens.test.ts b/packages/ai/src/methods/count-tokens.test.ts index b3ed7f7fa4d..67eed84ea13 100644 --- a/packages/ai/src/methods/count-tokens.test.ts +++ b/packages/ai/src/methods/count-tokens.test.ts @@ -77,7 +77,7 @@ describe('countTokens()', () => { task: Task.COUNT_TOKENS, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, match((value: string) => { return value.includes('contents'); @@ -108,7 +108,7 @@ describe('countTokens()', () => { task: Task.COUNT_TOKENS, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, match((value: string) => { return value.includes('contents'); @@ -137,7 +137,7 @@ describe('countTokens()', () => { task: Task.COUNT_TOKENS, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, match((value: string) => { return value.includes('contents'); @@ -191,7 +191,7 @@ describe('countTokens()', () => { task: Task.COUNT_TOKENS, apiSettings: fakeGoogleAIApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(mapCountTokensRequest(fakeRequestParams, 'model')) ); diff --git a/packages/ai/src/methods/count-tokens.ts b/packages/ai/src/methods/count-tokens.ts index 20c633ee703..1731592a0e2 100644 --- a/packages/ai/src/methods/count-tokens.ts +++ b/packages/ai/src/methods/count-tokens.ts @@ -19,6 +19,7 @@ import { AIError } from '../errors'; import { CountTokensRequest, CountTokensResponse, + SingleRequestOptions, InferenceMode, RequestOptions, AIErrorCode @@ -33,7 +34,7 @@ export async function countTokensOnCloud( apiSettings: ApiSettings, model: string, params: CountTokensRequest, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { let body: string = ''; if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { @@ -48,7 +49,7 @@ export async function countTokensOnCloud( task: Task.COUNT_TOKENS, apiSettings, stream: false, - requestOptions + singleRequestOptions }, body ); diff --git a/packages/ai/src/methods/generate-content.test.ts b/packages/ai/src/methods/generate-content.test.ts index 8a274c24417..82858844266 100644 --- a/packages/ai/src/methods/generate-content.test.ts +++ b/packages/ai/src/methods/generate-content.test.ts @@ -115,7 +115,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -141,7 +141,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -179,7 +179,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -209,7 +209,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -259,7 +259,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, 
JSON.stringify(fakeRequestParams) ); @@ -356,7 +356,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -381,7 +381,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -406,7 +406,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -447,7 +447,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -542,7 +542,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeGoogleAIApiSettings, stream: false, - requestOptions: match.any + singleRequestOptions: match.any }, JSON.stringify(mapGenerateContentRequest(fakeGoogleAIRequestParams)) ); @@ -585,13 +585,13 @@ describe('templateGenerateContent', () => { ); const templateId = 'my-template'; const templateParams = { name: 'world' }; - const requestOptions = { timeout: 5000 }; + const singleRequestOptions = { timeout: 5000 }; const result = await templateGenerateContent( fakeApiSettings, templateId, templateParams, - requestOptions + singleRequestOptions ); expect(makeRequestStub).to.have.been.calledOnceWith( @@ -600,7 +600,7 @@ describe('templateGenerateContent', () => { templateId, apiSettings: fakeApiSettings, stream: false, - requestOptions + singleRequestOptions }, JSON.stringify(templateParams) ); @@ -622,13 +622,13 @@ describe('templateGenerateContentStream', () => { ); const templateId = 'my-stream-template'; const templateParams = { name: 'streaming world' }; - const requestOptions = { timeout: 10000 }; + const singleRequestOptions = { timeout: 10000 }; const result = await templateGenerateContentStream( fakeApiSettings, templateId, templateParams, - requestOptions + singleRequestOptions ); expect(makeRequestStub).to.have.been.calledOnceWith( @@ -637,7 +637,7 @@ describe('templateGenerateContentStream', () => { templateId, apiSettings: fakeApiSettings, stream: true, - requestOptions + singleRequestOptions }, JSON.stringify(templateParams) ); diff --git a/packages/ai/src/methods/generate-content.ts b/packages/ai/src/methods/generate-content.ts index fc6eac15c74..ce15e7c7f7c 100644 --- a/packages/ai/src/methods/generate-content.ts +++ b/packages/ai/src/methods/generate-content.ts @@ -20,7 +20,7 @@ import { GenerateContentResponse, GenerateContentResult, GenerateContentStreamResult, - RequestOptions + SingleRequestOptions } from '../types'; import { makeRequest, @@ -39,7 +39,7 @@ async function generateContentStreamOnCloud( apiSettings: ApiSettings, model: string, params: GenerateContentRequest, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { params = GoogleAIMapper.mapGenerateContentRequest(params); @@ -50,7 +50,7 @@ async function generateContentStreamOnCloud( model, apiSettings, stream: true, - requestOptions + singleRequestOptions }, JSON.stringify(params) ); @@ -61,14 +61,19 @@ export async function generateContentStream( model: 
string, params: GenerateContentRequest, chromeAdapter?: ChromeAdapter, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { const callResult = await callCloudOrDevice( params, chromeAdapter, () => chromeAdapter!.generateContentStream(params), () => - generateContentStreamOnCloud(apiSettings, model, params, requestOptions) + generateContentStreamOnCloud( + apiSettings, + model, + params, + singleRequestOptions + ) ); return processStream(callResult.response, apiSettings); // TODO: Map streaming responses } @@ -77,7 +82,7 @@ async function generateContentOnCloud( apiSettings: ApiSettings, model: string, params: GenerateContentRequest, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { params = GoogleAIMapper.mapGenerateContentRequest(params); @@ -88,7 +93,7 @@ async function generateContentOnCloud( task: Task.GENERATE_CONTENT, apiSettings, stream: false, - requestOptions + singleRequestOptions }, JSON.stringify(params) ); @@ -98,7 +103,7 @@ export async function templateGenerateContent( apiSettings: ApiSettings, templateId: string, templateParams: object, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { const response = await makeRequest( { @@ -106,7 +111,7 @@ export async function templateGenerateContent( templateId, apiSettings, stream: false, - requestOptions + singleRequestOptions }, JSON.stringify(templateParams) ); @@ -126,7 +131,7 @@ export async function templateGenerateContentStream( apiSettings: ApiSettings, templateId: string, templateParams: object, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { const response = await makeRequest( { @@ -134,7 +139,7 @@ export async function templateGenerateContentStream( templateId, apiSettings, stream: true, - requestOptions + singleRequestOptions }, JSON.stringify(templateParams) ); @@ -146,13 +151,14 @@ export async function generateContent( model: string, params: GenerateContentRequest, chromeAdapter?: ChromeAdapter, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { const callResult = await callCloudOrDevice( params, chromeAdapter, () => chromeAdapter!.generateContent(params), - () => generateContentOnCloud(apiSettings, model, params, requestOptions) + () => + generateContentOnCloud(apiSettings, model, params, singleRequestOptions) ); const generateContentResponse = await processGenerateContentResponse( callResult.response, diff --git a/packages/ai/src/models/generative-model.test.ts b/packages/ai/src/models/generative-model.test.ts index 45430cb5f59..8d8bfc7c544 100644 --- a/packages/ai/src/models/generative-model.test.ts +++ b/packages/ai/src/models/generative-model.test.ts @@ -30,6 +30,8 @@ import { getMockResponseStreaming } from '../../test-utils/mock-response'; import sinonChai from 'sinon-chai'; +import * as generateContentMethods from '../methods/generate-content'; +import * as countTokens from '../methods/count-tokens'; import { VertexAIBackend } from '../backend'; import { AIError } from '../errors'; import chaiAsPromised from 'chai-as-promised'; @@ -53,6 +55,9 @@ const fakeAI: AI = { }; describe('GenerativeModel', () => { + afterEach(() => { + restore(); + }); it('passes params through to generateContent', async () => { const genModel = new GenerativeModel( fakeAI, @@ -97,7 +102,7 @@ describe('GenerativeModel', () => { task: request.Task.GENERATE_CONTENT, 
apiSettings: match.any, stream: false, - requestOptions: {} + singleRequestOptions: {} }, match((value: string) => { return ( @@ -136,7 +141,7 @@ describe('GenerativeModel', () => { task: request.Task.GENERATE_CONTENT, apiSettings: match.any, stream: false, - requestOptions: {} + singleRequestOptions: {} }, match((value: string) => { return value.includes('be friendly'); @@ -199,7 +204,7 @@ describe('GenerativeModel', () => { task: request.Task.GENERATE_CONTENT, apiSettings: match.any, stream: false, - requestOptions: {} + singleRequestOptions: {} }, match((value: string) => { return ( @@ -213,6 +218,34 @@ describe('GenerativeModel', () => { ); restore(); }); + it('generateContent singleRequestOptions overrides requestOptions', async () => { + const generateContentStub = stub( + generateContentMethods, + 'generateContent' + ).rejects('generateContent failed'); // not important + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + timeout: 2000 + }; + const genModel = new GenerativeModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + await expect(genModel.generateContent('hello', singleRequestOptions)).to.be + .rejected; + expect(generateContentStub).to.be.calledWith( + match.any, + match.any, + match.any, + match.any, + match({ + timeout: singleRequestOptions.timeout + }) + ); + }); it('passes base model params through to ChatSession when there are no startChatParams', async () => { const genModel = new GenerativeModel( fakeAI, @@ -231,18 +264,56 @@ describe('GenerativeModel', () => { }); restore(); }); - it('overrides base model params with startChatParams', () => { + it('generateContent singleRequestOptions is merged with requestOptions', async () => { + const generateContentStub = stub( + generateContentMethods, + 'generateContent' + ).rejects('generateContent failed'); // not important + const abortController = new AbortController(); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + signal: abortController.signal + }; const genModel = new GenerativeModel( fakeAI, - { - model: 'my-model', - generationConfig: { - topK: 1 - } - }, - {}, - fakeChromeAdapter + { model: 'my-model' }, + requestOptions + ); + await expect(genModel.generateContent('hello', singleRequestOptions)).to.be + .rejected; + expect(generateContentStub).to.be.calledWith( + match.any, + match.any, + match.any, + match.any, + match({ + timeout: requestOptions.timeout, + signal: singleRequestOptions.signal + }) ); + }); + it('passes base model params through to ChatSession when there are no startChatParams', async () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + generationConfig: { + topK: 1 + } + }); + const chatSession = genModel.startChat(); + expect(chatSession.params?.generationConfig).to.deep.equal({ + topK: 1 + }); + restore(); + }); + it('overrides base model params with startChatParams', () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + generationConfig: { + topK: 1 + } + }); const chatSession = genModel.startChat({ generationConfig: { topK: 2 @@ -292,7 +363,7 @@ describe('GenerativeModel', () => { task: request.Task.GENERATE_CONTENT, apiSettings: match.any, stream: false, - requestOptions: {} + singleRequestOptions: {} }, match((value: string) => { return ( @@ -332,7 +403,7 @@ describe('GenerativeModel', () => { task: request.Task.GENERATE_CONTENT, apiSettings: match.any, stream: false, - requestOptions: {} + singleRequestOptions: {} }, match((value: string) => { return 
value.includes('be friendly'); @@ -346,7 +417,9 @@ describe('GenerativeModel', () => { { model: 'my-model', tools: [ - { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] } + { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] }, + { googleSearch: {} }, + { urlContext: {} } ], toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } @@ -359,6 +432,80 @@ describe('GenerativeModel', () => { {}, fakeChromeAdapter ); + expect(genModel.tools?.length).to.equal(3); + expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal( + FunctionCallingMode.NONE + ); + expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly'); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-basic-reply-short.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + await genModel.startChat().sendMessage('hello'); + expect(makeRequestStub).to.be.calledWith( + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + singleRequestOptions: {} + }, + match((value: string) => { + return ( + value.includes('myfunc') && + value.includes(FunctionCallingMode.NONE) && + value.includes('be friendly') + // value.includes('topK') + ); + }) + ); + restore(); + }); + it('passes text-only systemInstruction through to chat.sendMessage', async () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + systemInstruction: 'be friendly' + }); + expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly'); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-basic-reply-short.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + await genModel.startChat().sendMessage('hello'); + expect(makeRequestStub).to.be.calledWith( + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + singleRequestOptions: {} + }, + match((value: string) => { + return value.includes('be friendly'); + }) + ); + restore(); + }); + it('startChat overrides model values', async () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + tools: [ + { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] } + ], + toolConfig: { + functionCallingConfig: { mode: FunctionCallingMode.NONE } + }, + systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, + generationConfig: { + responseMimeType: 'image/jpeg' + } + }); expect(genModel.tools?.length).to.equal(1); expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal( FunctionCallingMode.NONE @@ -378,9 +525,7 @@ describe('GenerativeModel', () => { functionDeclarations: [ { name: 'otherfunc', description: 'otherdesc' } ] - }, - { googleSearch: {} }, - { codeExecution: {} } + } ], toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.AUTO } @@ -397,13 +542,11 @@ describe('GenerativeModel', () => { task: request.Task.GENERATE_CONTENT, apiSettings: match.any, stream: false, - requestOptions: {} + singleRequestOptions: {} }, match((value: string) => { return ( value.includes('otherfunc') && - value.includes('googleSearch') && - value.includes('codeExecution') && value.includes(FunctionCallingMode.AUTO) && value.includes('be formal') && value.includes('image/png') && @@ -434,7 +577,7 @@ describe('GenerativeModel', () => { task: request.Task.COUNT_TOKENS, apiSettings: 
match.any, stream: false, - requestOptions: undefined + singleRequestOptions: {} }, match((value: string) => { return value.includes('hello'); @@ -442,6 +585,62 @@ describe('GenerativeModel', () => { ); restore(); }); + it('countTokens singleRequestOptions overrides requestOptions', async () => { + const countTokensStub = stub(countTokens, 'countTokens').rejects( + 'countTokens failed' + ); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + timeout: 2000 + }; + const genModel = new GenerativeModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + await expect(genModel.countTokens('hello', singleRequestOptions)).to.be + .rejected; + expect(countTokensStub).to.be.calledWith( + match.any, + match.any, + match.any, + match.any, + match({ + timeout: singleRequestOptions.timeout + }) + ); + }); + it('countTokens singleRequestOptions is merged with requestOptions', async () => { + const countTokensStub = stub(countTokens, 'countTokens').rejects( + 'countTokens failed' + ); + const abortController = new AbortController(); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + signal: abortController.signal + }; + const genModel = new GenerativeModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + await expect(genModel.countTokens('hello', singleRequestOptions)).to.be + .rejected; + expect(countTokensStub).to.be.calledWith( + match.any, + match.any, + match.any, + match.any, + match({ + timeout: requestOptions.timeout, + signal: singleRequestOptions.signal + }) + ); + }); }); describe('GenerativeModel dispatch logic', () => { diff --git a/packages/ai/src/models/generative-model.ts b/packages/ai/src/models/generative-model.ts index ffce645eeb1..8defedd33bd 100644 --- a/packages/ai/src/models/generative-model.ts +++ b/packages/ai/src/models/generative-model.ts @@ -29,11 +29,12 @@ import { GenerationConfig, ModelParams, Part, - RequestOptions, SafetySetting, + RequestOptions, StartChatParams, Tool, - ToolConfig + ToolConfig, + SingleRequestOptions } from '../types'; import { ChatSession } from '../methods/chat-session'; import { countTokens } from '../methods/count-tokens'; @@ -79,7 +80,8 @@ export class GenerativeModel extends AIModel { * and returns an object containing a single {@link GenerateContentResponse}. */ async generateContent( - request: GenerateContentRequest | string | Array + request: GenerateContentRequest | string | Array, + singleRequestOptions?: SingleRequestOptions ): Promise { const formattedParams = formatGenerateContentInput(request); return generateContent( @@ -94,7 +96,11 @@ export class GenerativeModel extends AIModel { ...formattedParams }, this.chromeAdapter, - this.requestOptions + // Merge request options + { + ...this.requestOptions, + ...singleRequestOptions + } ); } @@ -105,7 +111,8 @@ export class GenerativeModel extends AIModel { * a promise that returns the final aggregated response. 
*/ async generateContentStream( - request: GenerateContentRequest | string | Array + request: GenerateContentRequest | string | Array, + singleRequestOptions?: SingleRequestOptions ): Promise { const formattedParams = formatGenerateContentInput(request); return generateContentStream( @@ -120,7 +127,11 @@ export class GenerativeModel extends AIModel { ...formattedParams }, this.chromeAdapter, - this.requestOptions + // Merge request options + { + ...this.requestOptions, + ...singleRequestOptions + } ); } @@ -154,14 +165,20 @@ export class GenerativeModel extends AIModel { * Counts the tokens in the provided request. */ async countTokens( - request: CountTokensRequest | string | Array + request: CountTokensRequest | string | Array, + singleRequestOptions?: SingleRequestOptions ): Promise { const formattedParams = formatGenerateContentInput(request); return countTokens( this._apiSettings, this.model, formattedParams, - this.chromeAdapter + this.chromeAdapter, + // Merge request options + { + ...this.requestOptions, + ...singleRequestOptions + } ); } } diff --git a/packages/ai/src/models/imagen-model.test.ts b/packages/ai/src/models/imagen-model.test.ts index 68b6caca098..34470cc1d14 100644 --- a/packages/ai/src/models/imagen-model.test.ts +++ b/packages/ai/src/models/imagen-model.test.ts @@ -47,6 +47,9 @@ const fakeAI: AI = { }; describe('ImagenModel', () => { + afterEach(() => { + restore(); + }); it('generateImages makes a request to predict with default parameters', async () => { const mockResponse = getMockResponse( 'vertexAI', @@ -67,7 +70,7 @@ describe('ImagenModel', () => { task: request.Task.PREDICT, apiSettings: match.any, stream: false, - requestOptions: undefined + singleRequestOptions: {} }, match((value: string) => { return ( @@ -76,7 +79,6 @@ describe('ImagenModel', () => { ); }) ); - restore(); }); it('generateImages makes a request to predict with generation config and safety settings', async () => { const imagenModel = new ImagenModel(fakeAI, { @@ -109,7 +111,7 @@ describe('ImagenModel', () => { task: request.Task.PREDICT, apiSettings: match.any, stream: false, - requestOptions: undefined + singleRequestOptions: {} }, match((value: string) => { return ( @@ -137,7 +139,76 @@ describe('ImagenModel', () => { ); }) ); - restore(); + }); + it('generateImages singleRequestOptions overrides requestOptions', async () => { + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + timeout: 2000 + }; + const imagenModel = new ImagenModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-generate-images-base64.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + const prompt = 'A photorealistic image of a toy boat at sea.'; + await imagenModel.generateImages(prompt, singleRequestOptions); + expect(makeRequestStub).to.be.calledWith( + { + model: match.any, + task: request.Task.PREDICT, + apiSettings: match.any, + stream: false, + singleRequestOptions: { + timeout: singleRequestOptions.timeout + } + }, + match.any + ); + }); + it('generateImages singleRequestOptions is merged with requestOptions', async () => { + const abortController = new AbortController(); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + signal: abortController.signal + }; + const imagenModel = new ImagenModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + const mockResponse = getMockResponse( + 'vertexAI', + 
'unary-success-generate-images-base64.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + const prompt = 'A photorealistic image of a toy boat at sea.'; + await imagenModel.generateImages(prompt, singleRequestOptions); + expect(makeRequestStub).to.be.calledWith( + { + model: match.any, + task: request.Task.PREDICT, + apiSettings: match.any, + stream: false, + singleRequestOptions: { + timeout: requestOptions.timeout, + signal: singleRequestOptions.signal + } + }, + match.any + ); }); it('throws if prompt blocked', async () => { const mockResponse = getMockResponse( @@ -163,8 +234,76 @@ describe('ImagenModel', () => { expect((e as AIError).message).to.include( "Image generation failed with the following error: The prompt could not be submitted. This prompt contains sensitive words that violate Google's Responsible AI practices. Try rephrasing the prompt. If you think this was an error, send feedback." ); - } finally { - restore(); } }); + it('generateImagesGCS singleRequestOptions overrides requestOptions', async () => { + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + timeout: 2000 + }; + const imagenModel = new ImagenModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-generate-images-gcs.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + const prompt = 'A photorealistic image of a toy boat at sea.'; + await imagenModel.generateImagesGCS(prompt, '', singleRequestOptions); + expect(makeRequestStub).to.be.calledWith( + { + model: match.any, + task: request.Task.PREDICT, + apiSettings: match.any, + stream: false, + singleRequestOptions: { + timeout: singleRequestOptions.timeout + } + }, + match.any + ); + }); + it('generateImages singleRequestOptions is merged with requestOptions', async () => { + const abortController = new AbortController(); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + signal: abortController.signal + }; + const imagenModel = new ImagenModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-generate-images-gcs.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + const prompt = 'A photorealistic image of a toy boat at sea.'; + await imagenModel.generateImagesGCS(prompt, '', singleRequestOptions); + expect(makeRequestStub).to.be.calledWith( + { + model: match.any, + task: request.Task.PREDICT, + apiSettings: match.any, + stream: false, + singleRequestOptions: { + timeout: requestOptions.timeout, + signal: singleRequestOptions.signal + } + }, + match.any + ); + }); }); diff --git a/packages/ai/src/models/imagen-model.ts b/packages/ai/src/models/imagen-model.ts index 567333ee64f..beeb01ac12c 100644 --- a/packages/ai/src/models/imagen-model.ts +++ b/packages/ai/src/models/imagen-model.ts @@ -26,7 +26,8 @@ import { RequestOptions, ImagenModelParams, ImagenGenerationResponse, - ImagenSafetySettings + ImagenSafetySettings, + SingleRequestOptions } from '../types'; import { AIModel } from './ai-model'; @@ -102,7 +103,8 @@ export class ImagenModel extends AIModel { * @public */ async generateImages( - prompt: string + prompt: string, + singleRequestOptions?: SingleRequestOptions ): Promise> { const body = createPredictRequestBody(prompt, { ...this.generationConfig, @@ -114,7 +116,11 
@@ export class ImagenModel extends AIModel {
model: this.model,
apiSettings: this._apiSettings,
stream: false,
- requestOptions: this.requestOptions
+ // Merge request options. Single request options overwrite the model's request options.
+ singleRequestOptions: {
+ ...this.requestOptions,
+ ...singleRequestOptions
+ }
},
JSON.stringify(body)
);
@@ -142,7 +148,8 @@ export class ImagenModel extends AIModel {
*/
async generateImagesGCS(
prompt: string,
- gcsURI: string
+ gcsURI: string,
+ singleRequestOptions?: SingleRequestOptions
): Promise<ImagenGenerationResponse<ImagenGCSImage>> {
const body = createPredictRequestBody(prompt, {
gcsURI,
@@ -155,7 +162,11 @@ export class ImagenModel extends AIModel {
model: this.model,
apiSettings: this._apiSettings,
stream: false,
- requestOptions: this.requestOptions
+ // Merge request options. Single request options overwrite the model's request options.
+ singleRequestOptions: {
+ ...this.requestOptions,
+ ...singleRequestOptions
+ }
},
JSON.stringify(body)
);
diff --git a/packages/ai/src/models/template-generative-model.test.ts b/packages/ai/src/models/template-generative-model.test.ts
index c3eb43af491..d3f7ec28ffa 100644
--- a/packages/ai/src/models/template-generative-model.test.ts
+++ b/packages/ai/src/models/template-generative-model.test.ts
@@ -73,6 +73,51 @@ describe('TemplateGenerativeModel', () => {
{ timeout: 5000 }
);
});
+
+ it('singleRequestOptions overrides requestOptions', async () => {
+ const templateGenerateContentStub = stub(
+ generateContentMethods,
+ 'templateGenerateContent'
+ ).resolves({} as any);
+ const model = new TemplateGenerativeModel(fakeAI, { timeout: 1000 });
+ const singleRequestOptions = { timeout: 2000 };
+
+ await model.generateContent(
+ TEMPLATE_ID,
+ TEMPLATE_VARS,
+ singleRequestOptions
+ );
+
+ expect(templateGenerateContentStub).to.have.been.calledOnceWith(
+ model._apiSettings,
+ TEMPLATE_ID,
+ { inputs: TEMPLATE_VARS },
+ { timeout: 2000 }
+ );
+ });
+
+ it('singleRequestOptions is merged with requestOptions', async () => {
+ const templateGenerateContentStub = stub(
+ generateContentMethods,
+ 'templateGenerateContent'
+ ).resolves({} as any);
+ const abortController = new AbortController();
+ const model = new TemplateGenerativeModel(fakeAI, { timeout: 1000 });
+ const singleRequestOptions = { signal: abortController.signal };
+
+ await model.generateContent(
+ TEMPLATE_ID,
+ TEMPLATE_VARS,
+ singleRequestOptions
+ );
+
+ expect(templateGenerateContentStub).to.have.been.calledOnceWith(
+ model._apiSettings,
+ TEMPLATE_ID,
+ { inputs: TEMPLATE_VARS },
+ { timeout: 1000, signal: abortController.signal }
+ );
+ });
});
describe('generateContentStream', () => {
@@ -92,5 +137,50 @@ describe('TemplateGenerativeModel', () => {
{ timeout: 5000 }
);
});
+
+ it('singleRequestOptions overrides requestOptions', async () => {
+ const templateGenerateContentStreamStub = stub(
+ generateContentMethods,
+ 'templateGenerateContentStream'
+ ).resolves({} as any);
+ const model = new TemplateGenerativeModel(fakeAI, { timeout: 1000 });
+ const singleRequestOptions = { timeout: 2000 };
+
+ await model.generateContentStream(
+ TEMPLATE_ID,
+ TEMPLATE_VARS,
+ singleRequestOptions
+ );
+
+ expect(templateGenerateContentStreamStub).to.have.been.calledOnceWith(
+ model._apiSettings,
+ TEMPLATE_ID,
+ { inputs: TEMPLATE_VARS },
+ { timeout: 2000 }
+ );
+ });
+
+ it('singleRequestOptions is merged with requestOptions', async () => {
+ const templateGenerateContentStreamStub = stub(
+ generateContentMethods,
+ 'templateGenerateContentStream'
+ ).resolves({} as any);
+
const abortController = new AbortController();
+ const model = new TemplateGenerativeModel(fakeAI, { timeout: 1000 });
+ const singleRequestOptions = { signal: abortController.signal };
+
+ await model.generateContentStream(
+ TEMPLATE_ID,
+ TEMPLATE_VARS,
+ singleRequestOptions
+ );
+
+ expect(templateGenerateContentStreamStub).to.have.been.calledOnceWith(
+ model._apiSettings,
+ TEMPLATE_ID,
+ { inputs: TEMPLATE_VARS },
+ { timeout: 1000, signal: abortController.signal }
+ );
+ });
});
});
diff --git a/packages/ai/src/models/template-generative-model.ts b/packages/ai/src/models/template-generative-model.ts
index ec9e653618d..ccc61253ed9 100644
--- a/packages/ai/src/models/template-generative-model.ts
+++ b/packages/ai/src/models/template-generative-model.ts
@@ -20,7 +20,11 @@ import {
templateGenerateContentStream
} from '../methods/generate-content';
import { GenerateContentResult, RequestOptions } from '../types';
-import { AI, GenerateContentStreamResult } from '../public-types';
+import {
+ AI,
+ GenerateContentStreamResult,
+ SingleRequestOptions
+} from '../public-types';
import { ApiSettings } from '../types/internal';
import { initApiSettings } from './utils';
@@ -62,13 +66,17 @@ export class TemplateGenerativeModel {
*/
async generateContent(
templateId: string,
- templateVariables: object // anything!
+ templateVariables: object, // anything!
+ singleRequestOptions?: SingleRequestOptions
): Promise<GenerateContentResult> {
return templateGenerateContent(
this._apiSettings,
templateId,
{ inputs: templateVariables },
- this.requestOptions
+ // Merge request options. Single request options overwrite the model's request options.
+ {
+ ...this.requestOptions,
+ ...singleRequestOptions
+ }
);
}
@@ -86,13 +94,17 @@ export class TemplateGenerativeModel {
*/
async generateContentStream(
templateId: string,
- templateVariables: object
+ templateVariables: object,
+ singleRequestOptions?: SingleRequestOptions
): Promise<GenerateContentStreamResult> {
return templateGenerateContentStream(
this._apiSettings,
templateId,
{ inputs: templateVariables },
- this.requestOptions
+ // Merge request options. Single request options overwrite the model's request options.
+ {
+ ...this.requestOptions,
+ ...singleRequestOptions
+ }
);
}
}
diff --git a/packages/ai/src/models/template-imagen-model.test.ts b/packages/ai/src/models/template-imagen-model.test.ts
index c053753ea0f..9451981f83d 100644
--- a/packages/ai/src/models/template-imagen-model.test.ts
+++ b/packages/ai/src/models/template-imagen-model.test.ts
@@ -18,7 +18,7 @@ import { use, expect } from 'chai';
import sinonChai from 'sinon-chai';
import chaiAsPromised from 'chai-as-promised';
-import { restore, stub } from 'sinon';
+import { restore, stub, match } from 'sinon';
import { AI } from '../public-types';
import { VertexAIBackend } from '../backend';
import { TemplateImagenModel } from './template-imagen-model';
@@ -83,12 +83,68 @@ describe('TemplateImagenModel', () => {
templateId: TEMPLATE_ID,
apiSettings: model._apiSettings,
stream: false,
- requestOptions: { timeout: 5000 }
+ singleRequestOptions: { timeout: 5000 }
},
JSON.stringify({ inputs: TEMPLATE_VARS })
);
});
+ it('singleRequestOptions overrides requestOptions', async () => {
+ const mockPrediction = {
+ 'bytesBase64Encoded':
+ 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==',
+ 'mimeType': 'image/png'
+ };
+ const makeRequestStub = stub(request, 'makeRequest').resolves({
+ json: () => Promise.resolve({ predictions: [mockPrediction] })
+ } as Response);
+ const model = new TemplateImagenModel(fakeAI, { timeout: 1000 });
+ const singleRequestOptions = { timeout: 2000 };
+
+ await model.generateImages(
+ TEMPLATE_ID,
+ TEMPLATE_VARS,
+ singleRequestOptions
+ );
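+ // The per-call timeout (2000) should replace the model-level timeout (1000):
+ // single request options take precedence when merged with the model's options.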
+
+ expect(makeRequestStub).to.have.been.calledOnceWith(
+ match({
+ singleRequestOptions: { timeout: 2000 }
+ }),
+ match.any
+ );
+ });
+
+ it('singleRequestOptions is merged with requestOptions', async () => {
+ const mockPrediction = {
+ 'bytesBase64Encoded':
+ 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==',
+ 'mimeType': 'image/png'
+ };
+ const makeRequestStub = stub(request, 'makeRequest').resolves({
+ json: () => Promise.resolve({ predictions: [mockPrediction] })
+ } as Response);
+ const abortController = new AbortController();
+ const model = new TemplateImagenModel(fakeAI, { timeout: 1000 });
+ const singleRequestOptions = { signal: abortController.signal };
+
+ await model.generateImages(
+ TEMPLATE_ID,
+ TEMPLATE_VARS,
+ singleRequestOptions
+ );
+
+ expect(makeRequestStub).to.have.been.calledOnceWith(
+ match({
+ singleRequestOptions: {
+ timeout: 1000,
+ signal: abortController.signal
+ }
+ }),
+ match.any
+ );
+ });
+
it('should return the result of handlePredictResponse', async () => {
const mockPrediction = {
'bytesBase64Encoded':
diff --git a/packages/ai/src/models/template-imagen-model.ts b/packages/ai/src/models/template-imagen-model.ts
index 34325c711b3..be4d10f72d0 100644
--- a/packages/ai/src/models/template-imagen-model.ts
+++ b/packages/ai/src/models/template-imagen-model.ts
@@ -19,7 +19,8 @@ import { RequestOptions } from '../types';
import {
AI,
ImagenGenerationResponse,
- ImagenInlineImage
+ ImagenInlineImage,
+ SingleRequestOptions
} from '../public-types';
import { ApiSettings } from '../types/internal';
import { makeRequest, ServerPromptTemplateTask } from '../requests/request';
@@ -64,7 +65,8 @@ export class TemplateImagenModel {
*/
async generateImages(
templateId: string,
- templateVariables: object
+ templateVariables: object,
+ singleRequestOptions?: SingleRequestOptions
): Promise<ImagenGenerationResponse<ImagenInlineImage>> {
const response = await makeRequest(
{
@@ -72,7 +74,10 @@ export class TemplateImagenModel {
templateId,
apiSettings: this._apiSettings,
stream: false,
- requestOptions: this.requestOptions
+ // Merge request options. Single request options overwrite the model's request options.
+ singleRequestOptions: {
+ ...this.requestOptions,
+ ...singleRequestOptions
+ }
},
JSON.stringify({ inputs: templateVariables })
);
diff --git a/packages/ai/src/requests/request.test.ts b/packages/ai/src/requests/request.test.ts
index a54ff521bea..a1e15c2623b 100644
--- a/packages/ai/src/requests/request.test.ts
+++ b/packages/ai/src/requests/request.test.ts
@@ -16,7 +16,7 @@
*/
import { expect, use } from 'chai';
-import { match, restore, stub } from 'sinon';
+import Sinon, { match, restore, stub, useFakeTimers } from 'sinon';
import sinonChai from 'sinon-chai';
import chaiAsPromised from 'chai-as-promised';
import {
@@ -55,7 +55,7 @@ describe('request methods', () => {
task: Task.GENERATE_CONTENT,
apiSettings: fakeApiSettings,
stream: true,
- requestOptions: {}
+ singleRequestOptions: undefined
});
expect(url.toString()).to.include('models/model-name:generateContent');
expect(url.toString()).to.include('alt=sse');
@@ -66,7 +66,7 @@ describe('request methods', () => {
task: Task.GENERATE_CONTENT,
apiSettings: fakeApiSettings,
stream: false,
- requestOptions: {}
+ singleRequestOptions: undefined
});
expect(url.toString()).to.include('models/model-name:generateContent');
expect(url.toString()).to.not.include(fakeApiSettings);
@@ -78,7 +78,7 @@ describe('request methods', () => {
task: Task.GENERATE_CONTENT,
apiSettings: fakeApiSettings,
stream: false,
- requestOptions: {}
+ singleRequestOptions: undefined
});
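+ // With no per-request options provided, the URL should fall back to the default API version.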
expect(url.toString()).to.include(DEFAULT_API_VERSION);
});
@@ -88,7 +88,7 @@ describe('request methods', () => {
task: Task.GENERATE_CONTENT,
apiSettings: fakeApiSettings,
stream: false,
- requestOptions: { baseUrl: 'https://my.special.endpoint' }
+ singleRequestOptions: { baseUrl: 'https://my.special.endpoint' }
});
expect(url.toString()).to.include('https://my.special.endpoint');
});
@@ -98,7 +98,7 @@ describe('request methods', () => {
task: Task.GENERATE_CONTENT,
apiSettings: fakeApiSettings,
stream: false,
- requestOptions: {}
+ singleRequestOptions: undefined
});
expect(url.toString()).to.include(
'tunedModels/model-name:generateContent'
@@ -112,7 +112,7 @@ describe('request methods', () => {
task: ServerPromptTemplateTask.TEMPLATE_GENERATE_CONTENT,
apiSettings: fakeApiSettings,
stream: false,
- requestOptions: {}
+ singleRequestOptions: undefined
});
expect(url.toString()).to.include(
'templates/my-template:templateGenerateContent'
@@ -135,7 +135,7 @@ describe('request methods', () => {
task: Task.GENERATE_CONTENT,
apiSettings: fakeApiSettings,
stream: true,
- requestOptions: {}
+ singleRequestOptions: undefined
});
it('adds client headers', async () => {
const headers = await getHeaders(fakeUrl);
@@ -163,7 +163,7 @@ describe('request methods', () => {
task: Task.GENERATE_CONTENT,
apiSettings: fakeApiSettings,
stream: true,
- requestOptions: {}
+ singleRequestOptions: undefined
});
const headers = await getHeaders(fakeUrl);
expect(headers.get('X-Firebase-Appid')).to.equal('my-appid');
@@ -188,7 +188,7 @@ describe('request methods', () => {
task: Task.GENERATE_CONTENT,
apiSettings: fakeApiSettings,
stream: true,
- requestOptions: {}
+ singleRequestOptions: undefined
});
const headers = await getHeaders(fakeUrl);
expect(headers.get('X-Firebase-Appid')).to.be.null;
@@ -209,7 +209,7 @@ describe('request methods', () => {
backend: new VertexAIBackend()
},
stream: true,
- requestOptions: {}
+ singleRequestOptions: undefined
});
const headers = await getHeaders(fakeUrl);
expect(headers.has('X-Firebase-AppCheck')).to.be.false;
@@ -226,7 +226,7 @@ describe('request methods', () => {
getAppCheckToken: () => Promise.resolve()
},
stream: true,
- requestOptions: {}
+ singleRequestOptions: undefined
});
const headers = await getHeaders(fakeUrl);
expect(headers.has('X-Firebase-AppCheck')).to.be.false;
@@ -245,7 +245,7 @@ describe('request methods', () => {
Promise.resolve({ token: 'dummytoken', error: Error('oops') })
},
stream: true,
- requestOptions: {}
+ singleRequestOptions: undefined
});
const warnStub = stub(console, 'warn');
const headers = await getHeaders(fakeUrl);
@@ -271,7 +271,7 @@ describe('request methods', () => {
backend: new VertexAIBackend()
},
stream: true,
- requestOptions: {}
+ singleRequestOptions: undefined
});
const headers = await getHeaders(fakeUrl);
expect(headers.has('Authorization')).to.be.false;
@@ -288,15 +288,43 @@ describe('request methods', () => {
getAppCheckToken: () => Promise.resolve()
},
stream: true,
- requestOptions: {}
+ singleRequestOptions: undefined
});
const headers = await getHeaders(fakeUrl);
expect(headers.has('Authorization')).to.be.false;
});
});
describe('makeRequest', () => {
+ let fetchStub: Sinon.SinonStub;
+ let clock: Sinon.SinonFakeTimers;
+ const fetchAborter = (
+ _url: string,
+ options?: RequestInit
+ ): Promise<Response> => {
+ expect(options).to.not.be.undefined;
+ expect(options!.signal).to.not.be.undefined;
+ const signal = options!.signal;
+ return new Promise<Response>((_resolve, reject): void => {
+ const abortListener = (): void => {
+ reject(new
DOMException(signal?.reason || 'Aborted', 'AbortError')); + }; + + signal?.addEventListener('abort', abortListener, { once: true }); + }); + }; + + beforeEach(() => { + fetchStub = stub(globalThis, 'fetch'); + clock = useFakeTimers(); + }); + + afterEach(() => { + restore(); + clock.restore(); + }); + it('no error', async () => { - const fetchStub = stub(globalThis, 'fetch').resolves({ + fetchStub.resolves({ ok: true } as Response); const response = await makeRequest( @@ -312,7 +340,7 @@ describe('request methods', () => { expect(response.ok).to.be.true; }); it('error with timeout', async () => { - const fetchStub = stub(globalThis, 'fetch').resolves({ + fetchStub.resolves({ ok: false, status: 500, statusText: 'AbortError' @@ -325,7 +353,7 @@ describe('request methods', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: { + singleRequestOptions: { timeout: 180000 } }, @@ -343,7 +371,7 @@ describe('request methods', () => { expect(fetchStub).to.be.calledOnce; }); it('Network error, no response.json()', async () => { - const fetchStub = stub(globalThis, 'fetch').resolves({ + fetchStub.resolves({ ok: false, status: 500, statusText: 'Server Error' @@ -369,7 +397,7 @@ describe('request methods', () => { expect(fetchStub).to.be.calledOnce; }); it('Network error, includes response.json()', async () => { - const fetchStub = stub(globalThis, 'fetch').resolves({ + fetchStub.resolves({ ok: false, status: 500, statusText: 'Server Error', @@ -397,7 +425,7 @@ describe('request methods', () => { expect(fetchStub).to.be.calledOnce; }); it('Network error, includes response.json() and details', async () => { - const fetchStub = stub(globalThis, 'fetch').resolves({ + fetchStub.resolves({ ok: false, status: 500, statusText: 'Server Error', @@ -437,16 +465,234 @@ describe('request methods', () => { } expect(fetchStub).to.be.calledOnce; }); - }); - it('Network error, API not enabled', async () => { - const mockResponse = getMockResponse( - 'vertexAI', - 'unary-failure-firebasevertexai-api-not-enabled.json' - ); - const fetchStub = stub(globalThis, 'fetch').resolves( - mockResponse as Response - ); - try { + it('Network error, API not enabled', async () => { + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-failure-firebasevertexai-api-not-enabled.json' + ); + fetchStub.resolves(mockResponse as Response); + try { + await makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false + }, + '' + ); + } catch (e) { + expect((e as AIError).code).to.equal(AIErrorCode.API_NOT_ENABLED); + expect((e as AIError).message).to.include('my-project'); + expect((e as AIError).message).to.include('googleapis.com'); + } + expect(fetchStub).to.be.calledOnce; + }); + + it('should throw DOMException if external signal is already aborted', async () => { + const controller = new AbortController(); + const abortReason = 'Aborted before request'; + controller.abort(abortReason); + + const requestPromise = makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + singleRequestOptions: { signal: controller.signal } + }, + '{}' + ); + + await expect(requestPromise).to.be.rejectedWith( + DOMException, + abortReason + ); + + expect(fetchStub).not.to.have.been.called; + }); + it('should abort fetch if external signal aborts during request', async () => { + fetchStub.callsFake(fetchAborter); + const controller = new AbortController(); + const 
abortReason = 'Aborted during request';
+
+ const requestPromise = makeRequest(
+ {
+ model: 'models/model-name',
+ task: Task.GENERATE_CONTENT,
+ apiSettings: fakeApiSettings,
+ stream: false,
+ singleRequestOptions: { signal: controller.signal }
+ },
+ '{}'
+ );
+
+ await clock.tickAsync(0);
+ controller.abort(abortReason);
+
+ await expect(requestPromise).to.be.rejectedWith('AbortError');
+ });
+
+ it('should abort fetch if timeout expires during request', async () => {
+ const timeoutDuration = 100;
+ fetchStub.callsFake(fetchAborter);
+
+ const requestPromise = makeRequest(
+ {
+ model: 'models/model-name',
+ task: Task.GENERATE_CONTENT,
+ apiSettings: fakeApiSettings,
+ stream: false,
+ singleRequestOptions: { timeout: timeoutDuration }
+ },
+ '{}'
+ );
+
+ await clock.tickAsync(timeoutDuration + 100);
+
+ await expect(requestPromise).to.be.rejectedWith(
+ 'AbortError',
+ 'Timeout has expired'
+ );
+
+ expect(fetchStub).to.have.been.calledOnce;
+ const fetchOptions = fetchStub.firstCall.args[1] as RequestInit;
+ const internalSignal = fetchOptions.signal;
+
+ expect(internalSignal?.aborted).to.be.true;
+ expect((internalSignal?.reason as Error).name).to.equal('AbortError');
+ expect((internalSignal?.reason as Error).message).to.equal(
+ 'Timeout has expired.'
+ );
+ });
+
+ it('should succeed and clear timeout if fetch completes before timeout', async () => {
+ const mockResponse = new Response('{}', {
+ status: 200,
+ statusText: 'OK'
+ });
+ const fetchPromise = Promise.resolve(mockResponse);
+ fetchStub.resolves(fetchPromise);
+ const clearTimeoutStub = stub(globalThis, 'clearTimeout');
+
+ const requestPromise = makeRequest(
+ {
+ model: 'models/model-name',
+ task: Task.GENERATE_CONTENT,
+ apiSettings: fakeApiSettings,
+ stream: false,
+ singleRequestOptions: { timeout: 5000 } // Generous timeout
+ },
+ '{}'
+ );
+
+ // Advance time slightly, well within timeout
+ await clock.tickAsync(10);
+
+ const response = await requestPromise;
+ expect(response.ok).to.be.true;
+ expect(clearTimeoutStub).to.have.been.calledOnce;
+ expect(fetchStub).to.have.been.calledOnce;
+ });
+
+ it('should succeed and clear timeout/listener if fetch completes with signal provided but not aborted', async () => {
+ const controller = new AbortController();
+ const mockResponse = new Response('{}', {
+ status: 200,
+ statusText: 'OK'
+ });
+ const fetchPromise = Promise.resolve(mockResponse);
+ fetchStub.resolves(fetchPromise);
+
+ const requestPromise = makeRequest(
+ {
+ model: 'models/model-name',
+ task: Task.GENERATE_CONTENT,
+ apiSettings: fakeApiSettings,
+ stream: false,
+ singleRequestOptions: { signal: controller.signal } // Signal provided, never aborted
+ },
+ '{}'
+ );
+
+ // Advance time slightly
+ await clock.tickAsync(10);
+
+ const response = await requestPromise;
+ expect(response.ok).to.be.true;
+ expect(fetchStub).to.have.been.calledOnce;
+ });
+
+ it('should use external signal abort reason if it occurs before timeout', async () => {
+ const controller = new AbortController();
+ const abortReason = 'External Abort Wins';
+ const timeoutDuration = 500;
+ fetchStub.callsFake(fetchAborter);
+
+ const requestPromise = makeRequest(
+ {
+ model: 'models/model-name',
+ task: Task.GENERATE_CONTENT,
+ apiSettings: fakeApiSettings,
+ stream: false,
+ singleRequestOptions: {
+ signal: controller.signal,
+ timeout: timeoutDuration
+ }
+ },
+ '{}'
+ );
+
+ // Advance time, but less than the timeout
+ await clock.tickAsync(timeoutDuration / 2);
+ controller.abort(abortReason);
+
+ await
expect(requestPromise).to.be.rejectedWith( + 'AbortError', + abortReason + ); + }); + + it('should use timeout reason if it occurs before external signal abort', async () => { + const controller = new AbortController(); + const abortReason = 'External Abort Loses'; + const timeoutDuration = 100; + fetchStub.callsFake(fetchAborter); + + const requestPromise = makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + singleRequestOptions: { + signal: controller.signal, + timeout: timeoutDuration + } + }, + '{}' + ); + + // Schedule external abort after timeout + setTimeout(() => controller.abort(abortReason), timeoutDuration * 2); + + // Advance time past the timeout + await clock.tickAsync(timeoutDuration + 1); + + await expect(requestPromise).to.be.rejectedWith( + 'AbortError', + 'Timeout has expired' + ); + }); + + it('should pass internal signal to fetch options', async () => { + const mockResponse = new Response('{}', { + status: 200, + statusText: 'OK' + }); + fetchStub.resolves(mockResponse); + await makeRequest( { model: 'models/model-name', @@ -456,11 +702,131 @@ describe('request methods', () => { }, '' ); - } catch (e) { - expect((e as AIError).code).to.equal(AIErrorCode.API_NOT_ENABLED); - expect((e as AIError).message).to.include('my-project'); - expect((e as AIError).message).to.include('googleapis.com'); - } - expect(fetchStub).to.be.calledOnce; + + expect(fetchStub).to.have.been.calledOnce; + const fetchOptions = fetchStub.firstCall.args[1] as RequestInit; + expect(fetchOptions.signal).to.exist; + expect(fetchOptions.signal).to.be.instanceOf(AbortSignal); + expect(fetchOptions.signal?.aborted).to.be.false; + }); + + it('should remove abort listener on successful completion to prevent memory leaks', async () => { + const controller = new AbortController(); + const addSpy = Sinon.spy(controller.signal, 'addEventListener'); + const removeSpy = Sinon.spy(controller.signal, 'removeEventListener'); + + const mockResponse = new Response('{}', { + status: 200, + statusText: 'OK' + }); + fetchStub.resolves(mockResponse); + + await makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + singleRequestOptions: { signal: controller.signal } + }, + '{}' + ); + + expect(addSpy).to.have.been.calledOnceWith('abort'); + expect(removeSpy).to.have.been.calledOnceWith('abort'); + }); + + it('should remove listener if fetch itself rejects', async () => { + const controller = new AbortController(); + const removeSpy = Sinon.spy(controller.signal, 'removeEventListener'); + const error = new Error('Network failure'); + fetchStub.rejects(error); + + const requestPromise = makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + singleRequestOptions: { signal: controller.signal } + }, + '{}' + ); + + await expect(requestPromise).to.be.rejectedWith( + AIError, + /Network failure/ + ); + expect(removeSpy).to.have.been.calledOnce; + }); + + it('should remove listener if response is not ok', async () => { + const controller = new AbortController(); + const removeSpy = Sinon.spy(controller.signal, 'removeEventListener'); + const mockResponse = new Response('{}', { + status: 500, + statusText: 'Internal Server Error' + }); + fetchStub.resolves(mockResponse); + + const requestPromise = makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + 
stream: false,
+ singleRequestOptions: { signal: controller.signal }
+ },
+ '{}'
+ );
+
+ await expect(requestPromise).to.be.rejectedWith(AIError, /500/);
+ expect(removeSpy).to.have.been.calledOnce;
+ });
+
+ it('should abort immediately if timeout is 0', async () => {
+ fetchStub.callsFake(fetchAborter);
+ const requestPromise = makeRequest(
+ {
+ model: 'models/model-name',
+ task: Task.GENERATE_CONTENT,
+ apiSettings: fakeApiSettings,
+ stream: false,
+ singleRequestOptions: { timeout: 0 }
+ },
+ '{}'
+ );
+
+ // Tick the clock just enough to trigger a timeout(0)
+ await clock.tickAsync(1);
+
+ await expect(requestPromise).to.be.rejectedWith('AbortError');
+ });
+
+ it('should not error if signal is aborted after completion', async () => {
+ const controller = new AbortController();
+ const removeSpy = Sinon.spy(controller.signal, 'removeEventListener');
+ const mockResponse = new Response('{}', {
+ status: 200,
+ statusText: 'OK'
+ });
+ fetchStub.resolves(mockResponse);
+
+ await makeRequest(
+ {
+ model: 'models/model-name',
+ task: Task.GENERATE_CONTENT,
+ apiSettings: fakeApiSettings,
+ stream: false,
+ singleRequestOptions: { signal: controller.signal }
+ },
+ '{}'
+ );
+
+ // Listener should be removed, so this abort should do nothing.
+ controller.abort('Too late');
+
+ expect(removeSpy).to.have.been.calledOnce;
+ });
});
});
diff --git a/packages/ai/src/requests/request.ts b/packages/ai/src/requests/request.ts
index 7664765ab03..5ea1c3287c4 100644
--- a/packages/ai/src/requests/request.ts
+++ b/packages/ai/src/requests/request.ts
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-import { ErrorDetails, RequestOptions, AIErrorCode } from '../types';
+import { SingleRequestOptions, AIErrorCode, ErrorDetails } from '../types';
import { AIError } from '../errors';
import { ApiSettings } from '../types/internal';
import {
@@ -27,6 +27,9 @@ import {
import { logger } from '../logger';
import { BackendType } from '../public-types';
+const TIMEOUT_EXPIRED_MESSAGE = 'Timeout has expired.';
+const ABORT_ERROR_NAME = 'AbortError';
+
export const enum Task {
GENERATE_CONTENT = 'generateContent',
STREAM_GENERATE_CONTENT = 'streamGenerateContent',
@@ -43,7 +46,7 @@ export const enum ServerPromptTemplateTask {
interface BaseRequestURLParams {
apiSettings: ApiSettings;
stream: boolean;
- requestOptions?: RequestOptions;
+ singleRequestOptions?: SingleRequestOptions;
}
/**
@@ -94,7 +97,9 @@ export class RequestURL {
}
private get baseUrl(): string {
- return this.params.requestOptions?.baseUrl ?? `https://${DEFAULT_DOMAIN}`;
+ return (
+ this.params.singleRequestOptions?.baseUrl ?? `https://${DEFAULT_DOMAIN}`
+ );
}
private get queryParams(): URLSearchParams {
@@ -175,24 +180,46 @@ export async function makeRequest(
): Promise<Response> {
const url = new RequestURL(requestUrlParams);
let response;
- let fetchTimeoutId: string | number | NodeJS.Timeout | undefined;
+
+ const externalSignal = requestUrlParams.singleRequestOptions?.signal;
+ const timeoutMillis =
+ requestUrlParams.singleRequestOptions?.timeout != null &&
+ requestUrlParams.singleRequestOptions.timeout >= 0
+ ?
requestUrlParams.singleRequestOptions.timeout + : DEFAULT_FETCH_TIMEOUT_MS; + + const internalAbortController = new AbortController(); + const fetchTimeoutId = setTimeout(() => { + internalAbortController.abort( + new DOMException(TIMEOUT_EXPIRED_MESSAGE, ABORT_ERROR_NAME) + ); + logger.debug( + `Aborting request to ${url} due to timeout (${timeoutMillis}ms)` + ); + }, timeoutMillis); + + const combinedSignal = AbortSignal.any( + externalSignal + ? [externalSignal, internalAbortController.signal] + : [internalAbortController.signal] + ); + + if (externalSignal && externalSignal.aborted) { + clearTimeout(fetchTimeoutId); + throw new DOMException( + externalSignal.reason ?? 'Aborted externally before fetch', + ABORT_ERROR_NAME + ); + } + try { const fetchOptions: RequestInit = { method: 'POST', headers: await getHeaders(url), + signal: combinedSignal, body }; - // Timeout is 180s by default. - const timeoutMillis = - requestUrlParams.requestOptions?.timeout != null && - requestUrlParams.requestOptions.timeout >= 0 - ? requestUrlParams.requestOptions.timeout - : DEFAULT_FETCH_TIMEOUT_MS; - const abortController = new AbortController(); - fetchTimeoutId = setTimeout(() => abortController.abort(), timeoutMillis); - fetchOptions.signal = abortController.signal; - response = await fetch(url.toString(), fetchOptions); if (!response.ok) { let message = ''; @@ -252,7 +279,8 @@ export async function makeRequest( if ( (e as AIError).code !== AIErrorCode.FETCH_ERROR && (e as AIError).code !== AIErrorCode.API_NOT_ENABLED && - e instanceof Error + e instanceof Error && + (e as DOMException).name !== ABORT_ERROR_NAME ) { err = new AIError( AIErrorCode.ERROR, @@ -263,9 +291,7 @@ export async function makeRequest( throw err; } finally { - if (fetchTimeoutId) { - clearTimeout(fetchTimeoutId); - } + clearTimeout(fetchTimeoutId); } return response; } diff --git a/packages/ai/src/types/requests.ts b/packages/ai/src/types/requests.ts index 6e5d2147686..991453c53f3 100644 --- a/packages/ai/src/types/requests.ts +++ b/packages/ai/src/types/requests.ts @@ -253,6 +253,47 @@ export interface RequestOptions { baseUrl?: string; } +/** + * Options that can be provided per-request. + * Extends the base {@link RequestOptions} (like `timeout` and `baseUrl`) + * with request-specific controls like cancellation via `AbortSignal`. + * + * Options specified here will override any default {@link RequestOptions} + * configured on a model (for example, {@link GenerativeModel}). + * + * @public + */ +export interface SingleRequestOptions extends RequestOptions { + /** + * An `AbortSignal` instance that allows cancelling ongoing requests (like `generateContent` or + * `generateImages`). + * + * If provided, calling `abort()` on the corresponding `AbortController` + * will attempt to cancel the underlying HTTP request. An `AbortError` will be thrown + * if cancellation is successful. + * + * Note that this will not cancel the request in the backend, so any applicable billing charges + * will still be applied despite cancellation. + * + * @example + * ```javascript + * const controller = new AbortController(); + * const model = getGenerativeModel({ + * // ... + * }); + * model.generateContent( + * "Write a story about a magic backpack.", + * { signal: controller.signal } + * ); + * + * // To cancel request: + * controller.abort(); + * ``` + * @see https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal + */ + signal?: AbortSignal; +} + /** * Defines a tool that model can call to access external knowledge. * @public