This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 9303456

fix: template issue for tokenizer v3

1 parent: c7c8516

File tree: 2 files changed (+41, -67 lines)


cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc
41 additions, 66 deletions
```diff
@@ -39,10 +39,13 @@ namespace {
 // '[', 'INST', ']', '[INST]', ''[, '/' , 'INST',']', '[/INST]', '</s>'
 const std::vector<int32_t> kMistral_V0_3_StopWords
     = {29560, 17057, 29561, 3, 29560, 29516, 17057, 29561, 4, 2, 3, 4, 8, 9, 10, -1, -1, -1, -1, -1};
-const std::string kMistralUserPrompt = "[INST] ";
-const std::string kMistralAiPrompt = "[/INST] ";
-const std::string kMistralSystemPrompt = "<s>";
-const std::unordered_map<std::string, int> kMistralTemplate = {{"[INST]", 3} , {"[/INST]", 4}};
+
+enum class MistralTemplate: int32_t {
+    kBos = 1,
+    kEos = 2,
+    kBeginInst = 3,
+    kEndInst = 4
+};
 
 // TODO(sang) This is fragile, just a temporary solution. Maybe can use a config file or model architect, etc...
 bool IsOpenhermes(const std::string& s) {
```
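The new enum replaces the deleted `kMistralTemplate` map with named special-token ids. Below is a minimal sketch of how these ids are meant to wrap an encoded user turn; the id-to-token mapping in the comments (`<s>`=1, `</s>`=2, `[INST]`=3, `[/INST]`=4) is inferred from the deleted map and the enum values, and `WrapUserTurn` is a hypothetical helper, not part of the engine.

```cpp
// Sketch only: how the MistralTemplate ids surround an already-encoded user
// turn. The id-to-token mapping (<s>=1, </s>=2, [INST]=3, [/INST]=4) is an
// inference from the deleted kMistralTemplate map and the enum values.
#include <cstdint>
#include <vector>

enum class MistralTemplate : int32_t {
    kBos = 1,       // <s>
    kEos = 2,       // </s>
    kBeginInst = 3, // [INST]
    kEndInst = 4    // [/INST]
};

// Hypothetical helper: wrap an encoded user message in [INST] ... [/INST].
std::vector<int32_t> WrapUserTurn(std::vector<int32_t> content_ids) {
    content_ids.insert(content_ids.begin(),
                       static_cast<int32_t>(MistralTemplate::kBeginInst));
    content_ids.push_back(static_cast<int32_t>(MistralTemplate::kEndInst));
    return content_ids;
}
```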
```diff
@@ -51,27 +54,6 @@ namespace {
     }
     return true;
 }
-
-std::string GetUserPrompt(bool is_openhermes) {
-    if(is_openhermes) {
-        return kOhUserPrompt;
-    }
-    return kMistralUserPrompt;
-}
-
-std::string GetAiPrompt(bool is_openhermes) {
-    if(is_openhermes) {
-        return kOhAiPrompt;
-    }
-    return kMistralAiPrompt;
-}
-
-std::string GetSystemPrompt(bool is_openhermes) {
-    if(is_openhermes) {
-        return kOhSystemPrompt;
-    }
-    return kMistralSystemPrompt;
-}
 }
 TensorrtllmEngine::~TensorrtllmEngine() {}
 
```
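Only the tail of `IsOpenhermes` is visible in these hunks. Judging by the TODO above it ("this is fragile"), a plausible reconstruction is a substring test on the model path; the `"openhermes"` marker string is an assumption, not confirmed by the diff.

```cpp
// Hypothetical reconstruction of IsOpenhermes (the full body is not in this
// diff). The TODO calls the check fragile, which suggests a simple substring
// test on the model path; the "openhermes" marker is an assumption.
#include <string>

bool IsOpenhermes(const std::string& s) {
    if (s.find("openhermes") == std::string::npos) {
        return false;
    }
    return true;
}
```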

```diff
@@ -84,56 +66,22 @@ bool HandleMatch(std::string const& rew_text,
                  std::function<void(Json::Value&&, Json::Value&&)> cb,
                  bool is_openhermes) {
     if (infer_state->IsComplete(is_openhermes)) {
-        infer_state->rewind_strs.clear();
         return false;
     }
     if (infer_state->stop_word_match_len == 0) {
         if ((is_openhermes && rew_text.find('<') != std::string::npos) ||
             (!is_openhermes && rew_text.find('[') != std::string::npos)) {
             infer_state->stop_word_match_len++; // Move to next state
-            infer_state->rewind_strs.push_back(rew_text);
             return true;
         }
     } else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
         infer_state->stop_word_match_len++; // Move to next state
-        infer_state->rewind_strs.push_back(rew_text);
         return true;
     } else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
         infer_state->stop_word_match_len = 1; // Restart from first match if sequence breaks but matches start
-        // response cache data
-        for(auto const& s: infer_state->rewind_strs) {
-            // std::cout << s;
-            const std::string text_to_stream
-                = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n";
-            Json::Value resp_data;
-            resp_data["data"] = text_to_stream;
-            Json::Value status;
-            status["is_done"] = false;
-            status["has_error"] = false;
-            status["is_stream"] = true;
-            status["status_code"] = k200OK;
-            cb(std::move(status), std::move(resp_data));
-        }
-        infer_state->rewind_strs.clear();
-        infer_state->rewind_strs.push_back(rew_text);
         return true;
     } else {
         infer_state->Reset();
-        // response cache data
-        for(auto const& s: infer_state->rewind_strs) {
-            // std::cout << s;
-            const std::string text_to_stream
-                = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n";
-            Json::Value resp_data;
-            resp_data["data"] = text_to_stream;
-            Json::Value status;
-            status["is_done"] = false;
-            status["has_error"] = false;
-            status["is_stream"] = true;
-            status["status_code"] = k200OK;
-            cb(std::move(status), std::move(resp_data));
-        }
-        infer_state->rewind_strs.clear();
         return false; // Reset to start if sequence breaks
     }
     return false;
```
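Stripped of the rewind logic, `HandleMatch` is a small state machine that withholds streamed fragments while they might still form a stop sequence. The following is a self-contained sketch of the same idea, with simplified names that are not the engine's API; it also compares whole fragments, whereas the first state of `HandleMatch` only looks for a `<` or `[` anywhere in the fragment.

```cpp
// Self-contained sketch of the stop-sequence state machine HandleMatch now
// implements (names simplified; exact fragment matches only). Feed() returns
// true when a fragment should be withheld from the client because it may
// belong to the stop sequence.
#include <string>
#include <vector>

struct StopMatcher {
    std::vector<std::string> sequence; // e.g. {"[", "INST", "]"} for Mistral
    size_t match_len = 0;              // how many fragments matched so far

    bool Complete() const { return match_len == sequence.size(); }

    bool Feed(const std::string& fragment) {
        if (Complete()) {
            return false;                    // stop sequence already seen
        }
        if (fragment == sequence[match_len]) {
            ++match_len;                     // advance to the next state
            return true;
        }
        if (match_len > 0 && fragment == sequence[0]) {
            match_len = 1;                   // sequence broke but re-opened
            return true;
        }
        match_len = 0;                       // mismatch: reset to the start
        return false;
    }
};
```

One behavioral consequence of removing the rewind buffer: fragments withheld during a partial match are no longer replayed through `cb` when the match breaks, so they appear to be dropped rather than streamed late.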
```diff
@@ -207,9 +155,8 @@ void InferenceThread(
             RemoveId(output_idsHostDecode, v);
         }
     } else {
-        for(auto const& [_, v]: kMistralTemplate) {
-            RemoveId(output_idsHostDecode, v);
-        }
+        RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kBeginInst));
+        RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kEndInst));
     }
     std::string text = self->cortex_tokenizer->Decode(output_idsHostDecode);
```
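`RemoveId` itself is not part of this diff. Assuming it strips every occurrence of a single token id before decoding, a plausible implementation is the classic erase-remove idiom:

```cpp
// Plausible implementation of RemoveId (not shown in this diff), assuming it
// erases every occurrence of one token id from the output before decoding.
#include <algorithm>
#include <cstdint>
#include <vector>

void RemoveId(std::vector<int32_t>& ids, int32_t id) {
    ids.erase(std::remove(ids.begin(), ids.end(), id), ids.end());
}
```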

```diff
@@ -287,19 +234,37 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
     data["presence_penalty"] = request.presence_penalty;
     Json::Value const& messages = request.messages;
 
+    // tokens for Mistral v0.3
+    // TODO(sang): too much hard code here, need to refactor it soon
+    std::vector<int32_t> tokens = {static_cast<int32_t>(MistralTemplate::kBos)};
+
     // Format the input from user
+    int msg_count = 0;
     for (auto const& message : messages) {
         std::string input_role = message["role"].asString();
         std::string role;
         if (input_role == "user") {
             role = user_prompt_;
             std::string content = message["content"].asString();
             formatted_input += role + content;
+            if(!is_openhermes_) {
+                auto new_tokens = cortex_tokenizer->Encode(content);
+                new_tokens.insert(new_tokens.begin(), static_cast<int32_t>(MistralTemplate::kBeginInst));
+                new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEndInst));
+                tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
+            }
         }
         else if (input_role == "assistant") {
             role = ai_prompt_;
             std::string content = message["content"].asString();
             formatted_input += role + content;
+            if(!is_openhermes_) {
+                auto new_tokens = cortex_tokenizer->Encode(content);
+                if(msg_count == messages.size() - 1) {
+                    new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEos));
+                }
+                tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
+            }
         }
         else if (input_role == "system") {
             role = system_prompt_;
```
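Under the id mapping assumed earlier, a two-message history produces a token stream like the sketch below; the content ids are invented stand-ins for what `cortex_tokenizer->Encode` would return.

```cpp
// Illustration only: ids 22177 and 16102 are invented stand-ins for
// cortex_tokenizer->Encode("Hello") and Encode("Hi").
#include <cstdint>
#include <vector>

// messages = [ {role: "user",      content: "Hello"},
//              {role: "assistant", content: "Hi"} ]
const std::vector<int32_t> expected_tokens = {
    1,      // kBos:       <s>
    3,      // kBeginInst: [INST]
    22177,  //             "Hello" (hypothetical id)
    4,      // kEndInst:   [/INST]
    16102,  //             "Hi" (hypothetical id)
    2       // kEos:       </s>, appended only because the assistant
            //             message is the last one in the history
};
```

Note that, as far as these hunks show, system messages are appended to `formatted_input` only; on the Mistral path they never reach `tokens`.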
```diff
@@ -311,14 +276,21 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
             std::string content = message["content"].asString();
             formatted_input += role + content;
         }
+        msg_count++;
     }
     formatted_input += ai_prompt_;
     // LOG_INFO << formatted_input;
     // Format the input from user
 
     std::shared_ptr<InferenceState> infer_state = std::make_shared<InferenceState>();
 
-    std::vector<int32_t> input_ids_host = cortex_tokenizer->Encode(formatted_input);
+    std::vector<int32_t> input_ids_host;
+    if(is_openhermes_) {
+        input_ids_host = cortex_tokenizer->Encode(formatted_input);
+    } else {
+        input_ids_host = tokens;
+    }
+
     int const input_len = input_ids_host.size();
     int const outputLen = request.max_tokens - input_len;
 
```
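One detail in the unchanged tail of this hunk: `request.max_tokens` appears to act as a total budget rather than a generation cap, so with `max_tokens = 2048` and a 96-token prompt, `outputLen = 2048 - 96 = 1952`; a prompt longer than `max_tokens` would drive `outputLen` negative.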

```diff
@@ -397,9 +369,12 @@ void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::f
     is_openhermes_ = IsOpenhermes(request.model_path);
 
     int ctx_len = request.ctx_len;
-    user_prompt_ = request.user_prompt.empty() ? GetUserPrompt(is_openhermes_) : request.user_prompt;
-    ai_prompt_ = request.ai_prompt.empty() ? GetAiPrompt(is_openhermes_) : request.ai_prompt;
-    system_prompt_ = request.system_prompt.empty() ? GetSystemPrompt(is_openhermes_) : request.system_prompt;
+    // We only support 2 models for now, it is ugly but it works :(
+    if(is_openhermes_) {
+        user_prompt_ = request.user_prompt.empty() ? kOhUserPrompt : request.user_prompt;
+        ai_prompt_ = request.ai_prompt.empty() ? kOhAiPrompt : request.ai_prompt;
+        system_prompt_ = request.system_prompt.empty() ? kOhSystemPrompt : request.system_prompt;
+    }
     model_id_ = GetModelId(*json_body);
 
     logger_ = std::make_shared<TllmLogger>();
```

cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h
0 additions, 1 deletion
```diff
@@ -75,7 +75,6 @@ struct InferenceState {
     std::vector<std::string> sequence_openhermes = {"<", "|", "im", "_", "end", "|", ">"};
     std::vector<std::string> sequence_mistral = {"[", "INST", "]"};
     int token_gen_count = 0;
-    std::vector<std::string> rewind_strs;
 
     void Reset() {
         stop_word_match_len = 0;
```
