@@ -39,10 +39,13 @@ namespace {
 // '[', 'INST', ']', '[INST]', '[', '/', 'INST', ']', '[/INST]', '</s>'
 const std::vector<int32_t> kMistral_V0_3_StopWords
     = {29560, 17057, 29561, 3, 29560, 29516, 17057, 29561, 4, 2, 3, 4, 8, 9, 10, -1, -1, -1, -1, -1};
-const std::string kMistralUserPrompt = "[INST] ";
-const std::string kMistralAiPrompt = "[/INST] ";
-const std::string kMistralSystemPrompt = "<s>";
-const std::unordered_map<std::string, int> kMistralTemplate = {{"[INST]", 3}, {"[/INST]", 4}};
+
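+// Special token ids used by the Mistral v0.3 chat template (BOS, EOS, [INST], [/INST]).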
+enum class MistralTemplate : int32_t {
+  kBos = 1,
+  kEos = 2,
+  kBeginInst = 3,
+  kEndInst = 4
+};
 
4750 // TODO(sang) This is fragile, just a temporary solution. Maybe can use a config file or model architect, etc...
4851 bool IsOpenhermes (const std::string& s) {
@@ -51,27 +54,6 @@ namespace {
   }
   return true;
 }
-
-std::string GetUserPrompt(bool is_openhermes) {
-  if (is_openhermes) {
-    return kOhUserPrompt;
-  }
-  return kMistralUserPrompt;
-}
-
-std::string GetAiPrompt(bool is_openhermes) {
-  if (is_openhermes) {
-    return kOhAiPrompt;
-  }
-  return kMistralAiPrompt;
-}
-
-std::string GetSystemPrompt(bool is_openhermes) {
-  if (is_openhermes) {
-    return kOhSystemPrompt;
-  }
-  return kMistralSystemPrompt;
-}
 }
 TensorrtllmEngine::~TensorrtllmEngine() {}
 
@@ -84,56 +66,22 @@ bool HandleMatch(std::string const& rew_text,
                  std::function<void(Json::Value&&, Json::Value&&)> cb,
                  bool is_openhermes) {
   if (infer_state->IsComplete(is_openhermes)) {
-    infer_state->rewind_strs.clear();
     return false;
   }
   if (infer_state->stop_word_match_len == 0) {
     if ((is_openhermes && rew_text.find('<') != std::string::npos) ||
         (!is_openhermes && rew_text.find('[') != std::string::npos)) {
       infer_state->stop_word_match_len++;  // Move to next state
-      infer_state->rewind_strs.push_back(rew_text);
       return true;
     }
   } else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
     infer_state->stop_word_match_len++;  // Move to next state
-    infer_state->rewind_strs.push_back(rew_text);
     return true;
   } else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
     infer_state->stop_word_match_len = 1;  // Restart from first match if sequence breaks but matches start
-    // response cache data
-    for (auto const& s : infer_state->rewind_strs) {
-      // std::cout << s;
-      const std::string text_to_stream
-          = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n";
-      Json::Value resp_data;
-      resp_data["data"] = text_to_stream;
-      Json::Value status;
-      status["is_done"] = false;
-      status["has_error"] = false;
-      status["is_stream"] = true;
-      status["status_code"] = k200OK;
-      cb(std::move(status), std::move(resp_data));
-    }
-    infer_state->rewind_strs.clear();
-    infer_state->rewind_strs.push_back(rew_text);
     return true;
   } else {
     infer_state->Reset();
-    // response cache data
-    for (auto const& s : infer_state->rewind_strs) {
-      // std::cout << s;
-      const std::string text_to_stream
-          = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n";
-      Json::Value resp_data;
-      resp_data["data"] = text_to_stream;
-      Json::Value status;
-      status["is_done"] = false;
-      status["has_error"] = false;
-      status["is_stream"] = true;
-      status["status_code"] = k200OK;
-      cb(std::move(status), std::move(resp_data));
-    }
-    infer_state->rewind_strs.clear();
     return false;  // Reset to start if sequence breaks
   }
   return false;
@@ -207,9 +155,8 @@ void InferenceThread(
         RemoveId(output_idsHostDecode, v);
       }
     } else {
-      for (auto const& [_, v] : kMistralTemplate) {
-        RemoveId(output_idsHostDecode, v);
-      }
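+      // Strip the Mistral [INST] / [/INST] special token ids before decoding.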
+      RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kBeginInst));
+      RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kEndInst));
     }
     std::string text = self->cortex_tokenizer->Decode(output_idsHostDecode);
 
@@ -287,19 +234,37 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
   data["presence_penalty"] = request.presence_penalty;
   Json::Value const& messages = request.messages;
 
+  // tokens for Mistral v0.3
+  // TODO(sang): too much hard code here, need to refactor it soon
+  std::vector<int32_t> tokens = {static_cast<int32_t>(MistralTemplate::kBos)};
+
   // Format the input from user
+  int msg_count = 0;
   for (auto const& message : messages) {
     std::string input_role = message["role"].asString();
     std::string role;
     if (input_role == "user") {
       role = user_prompt_;
       std::string content = message["content"].asString();
       formatted_input += role + content;
+      if (!is_openhermes_) {
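+        // Mistral v0.3: wrap each user message in [INST] ... [/INST] token ids.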
+        auto new_tokens = cortex_tokenizer->Encode(content);
+        new_tokens.insert(new_tokens.begin(), static_cast<int32_t>(MistralTemplate::kBeginInst));
+        new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEndInst));
+        tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
+      }
     }
     else if (input_role == "assistant") {
       role = ai_prompt_;
       std::string content = message["content"].asString();
       formatted_input += role + content;
+      if (!is_openhermes_) {
+        auto new_tokens = cortex_tokenizer->Encode(content);
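+        // Append EOS only after the last message of the conversation.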
+        if (msg_count == messages.size() - 1) {
+          new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEos));
+        }
+        tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
+      }
     }
     else if (input_role == "system") {
       role = system_prompt_;
@@ -311,14 +276,21 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
       std::string content = message["content"].asString();
       formatted_input += role + content;
     }
+    msg_count++;
   }
   formatted_input += ai_prompt_;
   // LOG_INFO << formatted_input;
   // Format the input from user
 
   std::shared_ptr<InferenceState> infer_state = std::make_shared<InferenceState>();
 
-  std::vector<int32_t> input_ids_host = cortex_tokenizer->Encode(formatted_input);
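+  // Openhermes: tokenize the formatted prompt string; Mistral: use the manually built token ids.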
+  std::vector<int32_t> input_ids_host;
+  if (is_openhermes_) {
+    input_ids_host = cortex_tokenizer->Encode(formatted_input);
+  } else {
+    input_ids_host = tokens;
+  }
+
   int const input_len = input_ids_host.size();
   int const outputLen = request.max_tokens - input_len;
 
@@ -397,9 +369,12 @@ void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::f
   is_openhermes_ = IsOpenhermes(request.model_path);
 
   int ctx_len = request.ctx_len;
-  user_prompt_ = request.user_prompt.empty() ? GetUserPrompt(is_openhermes_) : request.user_prompt;
-  ai_prompt_ = request.ai_prompt.empty() ? GetAiPrompt(is_openhermes_) : request.ai_prompt;
-  system_prompt_ = request.system_prompt.empty() ? GetSystemPrompt(is_openhermes_) : request.system_prompt;
+  // We only support 2 models for now, it is ugly but it works :(
+  if (is_openhermes_) {
+    user_prompt_ = request.user_prompt.empty() ? kOhUserPrompt : request.user_prompt;
+    ai_prompt_ = request.ai_prompt.empty() ? kOhAiPrompt : request.ai_prompt;
+    system_prompt_ = request.system_prompt.empty() ? kOhSystemPrompt : request.system_prompt;
+  }
   model_id_ = GetModelId(*json_body);
 
   logger_ = std::make_shared<TllmLogger>();