From de216fed4a43378512a022c07927bbd5cd06c037 Mon Sep 17 00:00:00 2001 From: Luis Catacora Date: Mon, 20 Oct 2025 02:07:13 -0400 Subject: [PATCH 1/2] Remove temporary replicate gpt-5 mapping --- .../inference/src/lib/getProviderHelper.ts | 13 ++-- packages/inference/src/providers/consts.ts | 2 +- packages/inference/src/providers/replicate.ts | 72 +++++++++++++++++-- .../inference/test/InferenceClient.spec.ts | 26 +++++-- 4 files changed, 94 insertions(+), 19 deletions(-) diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index 5f5f16b044..9021079813 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -153,12 +153,13 @@ export const PROVIDERS: Record 0) { + return text; + } + } + return undefined; + } + if (typeof value === "object") { + const record = value as Record; + const directTextKeys = ["output_text", "generated_text", "text", "content"] as const; + for (const key of directTextKeys) { + const maybeText = record[key]; + if (typeof maybeText === "string" && maybeText.length > 0) { + return maybeText; + } + } + const nestedKeys = ["output", "choices", "message", "delta", "content", "data"] as const; + for (const key of nestedKeys) { + if (key in record) { + const text = extractTextFromReplicateResponse(record[key]); + if (typeof text === "string" && text.length > 0) { + return text; + } + } + } + } + return undefined; } abstract class ReplicateTask extends TaskProviderHelper { @@ -116,6 +155,25 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma } } +export class ReplicateTextGenerationTask extends ReplicateTask implements TextGenerationTaskHelper { + override async getResponse(response: ReplicateOutput): Promise { + if (response instanceof Blob) { + throw new InferenceClientProviderOutputError( + "Received malformed response from Replicate text-generation API" + ); + } + + const text = extractTextFromReplicateResponse(response); + if (typeof text === "string") { + return { generated_text: text }; + } + + throw new InferenceClientProviderOutputError( + "Received malformed response from Replicate text-generation API" + ); + } +} + export class ReplicateTextToSpeechTask extends ReplicateTask { override preparePayload(params: BodyParams): Record { const payload = super.preparePayload(params); diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index fcb6e55cb3..029658db45 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -1193,11 +1193,27 @@ describe.skip("InferenceClient", () => { describe.concurrent( "Replicate", () => { - const client = new InferenceClient(env.HF_REPLICATE_KEY ?? "dummy"); - - it("textToImage canonical - black-forest-labs/FLUX.1-schnell", async () => { - const res = await client.textToImage({ - model: "black-forest-labs/FLUX.1-schnell", + const client = new InferenceClient(env.HF_REPLICATE_KEY ?? "dummy"); + + it("textGeneration - akhaliq/gpt-5", async () => { + const res = await client.textGeneration({ + model: "akhaliq/gpt-5", + provider: "replicate", + inputs: "The capital city of France is", + parameters: { + max_new_tokens: 20, + temperature: 0.2, + }, + }); + + expect(res).toBeDefined(); + expect(typeof res.generated_text).toBe("string"); + expect(res.generated_text.length).toBeGreaterThan(0); + }); + + it("textToImage canonical - black-forest-labs/FLUX.1-schnell", async () => { + const res = await client.textToImage({ + model: "black-forest-labs/FLUX.1-schnell", provider: "replicate", inputs: "black forest gateau cake spelling out the words FLUX SCHNELL, tasty, food photography, dynamic shot", }); From 7eb74812be64f385c49d03656ba8e0f24e35cab6 Mon Sep 17 00:00:00 2001 From: Luis Catacora Date: Mon, 20 Oct 2025 02:14:33 -0400 Subject: [PATCH 2/2] Fix format --- .../inference/src/lib/getProviderHelper.ts | 14 +-- packages/inference/src/providers/consts.ts | 2 +- packages/inference/src/providers/replicate.ts | 108 +++++++++--------- .../inference/test/InferenceClient.spec.ts | 42 +++---- 4 files changed, 81 insertions(+), 85 deletions(-) diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index 9021079813..efc205c9ff 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -153,13 +153,13 @@ export const PROVIDERS: Record 0) { - return text; - } - } - return undefined; - } - if (typeof value === "object") { - const record = value as Record; - const directTextKeys = ["output_text", "generated_text", "text", "content"] as const; - for (const key of directTextKeys) { - const maybeText = record[key]; - if (typeof maybeText === "string" && maybeText.length > 0) { - return maybeText; - } - } - const nestedKeys = ["output", "choices", "message", "delta", "content", "data"] as const; - for (const key of nestedKeys) { - if (key in record) { - const text = extractTextFromReplicateResponse(record[key]); - if (typeof text === "string" && text.length > 0) { - return text; - } - } - } - } - return undefined; + if (value == null) { + return undefined; + } + if (typeof value === "string") { + return value; + } + if (Array.isArray(value)) { + for (const item of value) { + const text = extractTextFromReplicateResponse(item); + if (typeof text === "string" && text.length > 0) { + return text; + } + } + return undefined; + } + if (typeof value === "object") { + const record = value as Record; + const directTextKeys = ["output_text", "generated_text", "text", "content"] as const; + for (const key of directTextKeys) { + const maybeText = record[key]; + if (typeof maybeText === "string" && maybeText.length > 0) { + return maybeText; + } + } + const nestedKeys = ["output", "choices", "message", "delta", "content", "data"] as const; + for (const key of nestedKeys) { + if (key in record) { + const text = extractTextFromReplicateResponse(record[key]); + if (typeof text === "string" && text.length > 0) { + return text; + } + } + } + } + return undefined; } abstract class ReplicateTask extends TaskProviderHelper { @@ -156,22 +156,18 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma } export class ReplicateTextGenerationTask extends ReplicateTask implements TextGenerationTaskHelper { - override async getResponse(response: ReplicateOutput): Promise { - if (response instanceof Blob) { - throw new InferenceClientProviderOutputError( - "Received malformed response from Replicate text-generation API" - ); - } + override async getResponse(response: ReplicateOutput): Promise { + if (response instanceof Blob) { + throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-generation API"); + } - const text = extractTextFromReplicateResponse(response); - if (typeof text === "string") { - return { generated_text: text }; - } + const text = extractTextFromReplicateResponse(response); + if (typeof text === "string") { + return { generated_text: text }; + } - throw new InferenceClientProviderOutputError( - "Received malformed response from Replicate text-generation API" - ); - } + throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-generation API"); + } } export class ReplicateTextToSpeechTask extends ReplicateTask { diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index 029658db45..3fe4aa4b8f 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -1193,27 +1193,27 @@ describe.skip("InferenceClient", () => { describe.concurrent( "Replicate", () => { - const client = new InferenceClient(env.HF_REPLICATE_KEY ?? "dummy"); - - it("textGeneration - akhaliq/gpt-5", async () => { - const res = await client.textGeneration({ - model: "akhaliq/gpt-5", - provider: "replicate", - inputs: "The capital city of France is", - parameters: { - max_new_tokens: 20, - temperature: 0.2, - }, - }); - - expect(res).toBeDefined(); - expect(typeof res.generated_text).toBe("string"); - expect(res.generated_text.length).toBeGreaterThan(0); - }); - - it("textToImage canonical - black-forest-labs/FLUX.1-schnell", async () => { - const res = await client.textToImage({ - model: "black-forest-labs/FLUX.1-schnell", + const client = new InferenceClient(env.HF_REPLICATE_KEY ?? "dummy"); + + it("textGeneration - akhaliq/gpt-5", async () => { + const res = await client.textGeneration({ + model: "akhaliq/gpt-5", + provider: "replicate", + inputs: "The capital city of France is", + parameters: { + max_new_tokens: 20, + temperature: 0.2, + }, + }); + + expect(res).toBeDefined(); + expect(typeof res.generated_text).toBe("string"); + expect(res.generated_text.length).toBeGreaterThan(0); + }); + + it("textToImage canonical - black-forest-labs/FLUX.1-schnell", async () => { + const res = await client.textToImage({ + model: "black-forest-labs/FLUX.1-schnell", provider: "replicate", inputs: "black forest gateau cake spelling out the words FLUX SCHNELL, tasty, food photography, dynamic shot", });