2 changes: 2 additions & 0 deletions packages/inference/README.md
@@ -65,6 +65,7 @@ Currently, we support the following providers:
- [Cohere](https://cohere.com)
- [Cerebras](https://cerebras.ai/)
- [Groq](https://groq.com)
- [ZAI](https://z.ai/)

To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. The default value of the `provider` parameter is "auto", which will select the first of the providers available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.
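For example, here is a minimal sketch of pinning a request to ZAI (the token is a placeholder; the model/provider pair mirrors the test file later in this diff):

```ts
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("hf_xxx"); // placeholder HF access token

// Pass provider explicitly instead of letting "auto" pick the first available one
const out = await client.chatCompletion({
  model: "zai-org/GLM-4.5",
  provider: "zai",
  messages: [{ role: "user", content: "One plus one is equal to" }],
});
console.log(out.choices[0]?.message?.content);
```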

@@ -100,6 +101,7 @@ Only a subset of models are supported when requesting third-party providers. You
- [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
- [Groq supported models](https://console.groq.com/docs/models)
- [Novita AI supported models](https://huggingface.co/api/partners/novita/models)
- [ZAI supported models](https://huggingface.co/api/partners/zai/models)

❗**Important note:** To be compatible, the third-party API must adhere to the "standard" API shape we expect on HF model pages for each pipeline task type.
This is not an issue for LLMs, since everyone has converged on the OpenAI API anyway, but it can be trickier for tasks like "text-to-image" or "automatic-speech-recognition", where no standard API exists. Let us know if any help is needed or if we can make things easier for you!
4 changes: 4 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
@@ -50,6 +50,7 @@ import * as Replicate from "../providers/replicate.js";
import * as Sambanova from "../providers/sambanova.js";
import * as Scaleway from "../providers/scaleway.js";
import * as Together from "../providers/together.js";
import * as Zai from "../providers/zai.js";
import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types.js";
import { InferenceClientInputError } from "../errors.js";

@@ -164,6 +165,9 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
    conversational: new Together.TogetherConversationalTask(),
    "text-generation": new Together.TogetherTextGenerationTask(),
  },
  zai: {
    conversational: new Zai.ZaiConversationalTask(),
  },
};

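For context, a hypothetical sketch of the lookup a helper-resolution function would perform against the `PROVIDERS` record above (the actual implementation in this file is not shown in the diff and may differ; the error constructor signature is assumed):

```ts
// Sketch only: resolve the task helper for a provider/task pair
function resolveHelper(provider: InferenceProvider, task: InferenceTask) {
  const helper = PROVIDERS[provider]?.[task];
  if (!helper) {
    // Error message wording is illustrative
    throw new InferenceClientInputError(`Task "${task}" is not supported for provider "${provider}"`);
  }
  return helper; // e.g. PROVIDERS["zai"]["conversational"] -> ZaiConversationalTask
}
```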
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
@@ -37,4 +37,5 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
  sambanova: {},
  scaleway: {},
  together: {},
  zai: {},
};
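As the provider module below notes, this mapping can be filled in locally to test a model before it is registered on huggingface.co. A dev-only sketch, mirroring the entry shape used in the test file at the end of this diff (`providerId` is the ZAI-side model name):

```ts
// Dev-only: map an HF model ID to its ZAI counterpart ahead of registration
HARDCODED_MODEL_INFERENCE_MAPPING["zai"] = {
  "zai-org/GLM-4.5": {
    provider: "zai",
    hfModelId: "zai-org/GLM-4.5",
    providerId: "glm-4.5",
    status: "live",
    task: "conversational",
  },
};
```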
36 changes: 36 additions & 0 deletions packages/inference/src/providers/zai.ts
@@ -0,0 +1,36 @@
/**
 * See the registered mapping of HF model ID => ZAI model ID here:
 *
 * https://huggingface.co/api/partners/zai/models
 *
 * This is a publicly available mapping.
 *
 * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
 * you can add it to the dictionary "HARDCODED_MODEL_INFERENCE_MAPPING" in consts.ts, for dev purposes.
 *
 * - If you work at ZAI and want to update this mapping, please use the model mapping API we provide on huggingface.co
 * - If you're a community member and want to add a new supported HF model to ZAI, please open an issue on the present repo
 *   and we will tag ZAI team members.
 *
 * Thanks!
 */
import { BaseConversationalTask } from "./providerHelper.js";
import type { HeaderParams } from "../types.js";

const ZAI_API_BASE_URL = "https://api.z.ai/api/paas/v4";

export class ZaiConversationalTask extends BaseConversationalTask {
  constructor() {
    super("zai", ZAI_API_BASE_URL);
  }

  override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
    const headers = super.prepareHeaders(params, binary);
    // Identify Hugging Face as the traffic source on every request to the ZAI API
    headers["x-source-channel"] = "hugging_face";
    return headers;
  }

  override makeRoute(): string {
    return "/chat/completions";
  }
}
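Putting the pieces together, the helper targets `https://api.z.ai/api/paas/v4/chat/completions` and adds an `x-source-channel` header on top of the usual auth header. A hand-rolled sketch of the roughly equivalent request, assuming the OpenAI-compatible chat payload the README mentions (the real body construction and auth handling live in `BaseConversationalTask`):

```ts
// Sketch only: approximates the request ZaiConversationalTask sets up
const response = await fetch("https://api.z.ai/api/paas/v4/chat/completions", {
  method: "POST",
  headers: {
    Authorization: `Bearer ${process.env.HF_ZAI_KEY ?? ""}`, // placeholder credential
    "x-source-channel": "hugging_face", // added by prepareHeaders() above
    "Content-Type": "application/json",
  },
  body: JSON.stringify({
    model: "glm-4.5", // ZAI-side model ID from the mapping
    messages: [{ role: "user", content: "Hello!" }],
  }),
});
console.log(await response.json());
```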
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
@@ -64,6 +64,7 @@ export const INFERENCE_PROVIDERS = [
"sambanova",
"scaleway",
"together",
"zai",
] as const;

export const PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"] as const;
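The `as const` array doubles as a runtime list and a type source. The type aliases themselves are not shown in this hunk, but `InferenceProvider` and `InferenceProviderOrPolicy` are imported elsewhere in the diff; a sketch of the standard derivation pattern they presumably follow:

```ts
// Assumed derivation (not shown in the diff):
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
// => "cohere" | "cerebras" | ... | "zai"
export type InferenceProviderOrPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
// => InferenceProvider | "auto"
```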
52 changes: 52 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
@@ -2134,6 +2134,58 @@ describe.skip("InferenceClient", () => {
    },
    TIMEOUT
  );
  describe.concurrent(
    "ZAI",
    () => {
      const client = new InferenceClient(env.HF_ZAI_KEY ?? "dummy");

      // Hardcode the HF -> ZAI model mapping so the test does not depend on
      // the live mapping registered on huggingface.co
      HARDCODED_MODEL_INFERENCE_MAPPING["zai"] = {
        "zai-org/GLM-4.5": {
          provider: "zai",
          hfModelId: "zai-org/GLM-4.5",
          providerId: "glm-4.5",
          status: "live",
          task: "conversational",
        },
      };

      it("chatCompletion", async () => {
        const res = await client.chatCompletion({
          model: "zai-org/GLM-4.5",
          provider: "zai",
          messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
        });
        if (res.choices && res.choices.length > 0) {
          const completion = res.choices[0].message?.content;
          expect(completion).toContain("two");
        }
      });

      it("chatCompletion stream", async () => {
        const stream = client.chatCompletionStream({
          model: "zai-org/GLM-4.5",
          provider: "zai",
          messages: [{ role: "user", content: "Say 'this is a test'" }],
          stream: true,
        }) as AsyncGenerator<ChatCompletionStreamOutput>;

        let fullResponse = "";
        for await (const chunk of stream) {
          if (chunk.choices && chunk.choices.length > 0) {
            const content = chunk.choices[0].delta?.content;
            if (content) {
              fullResponse += content;
            }
          }
        }

        // Verify we got a meaningful response
        expect(fullResponse).toBeTruthy();
        expect(fullResponse.length).toBeGreaterThan(0);
      });
    },
    TIMEOUT
  );
  describe.concurrent(
    "OVHcloud",
    () => {