  using json = nlohmann::json;
  using namespace tensorrtllm;

+ namespace {
+ constexpr const int k200OK = 200;
+ constexpr const int k400BadRequest = 400;
+ constexpr const int k409Conflict = 409;
+ constexpr const int k500InternalServerError = 500;
+
+ // https://nvidia.github.io/TensorRT-LLM/_cpp_gen/runtime.html#generationinput-h
+ // stopWordsList
+ // 'im', '_', 'end', '</s>', '<|im_end|>'
+ const std::vector<int32_t> kOpenhermesStopWords = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1};
+ const std::string kOhUserPrompt = "<|im_end|>\n<|im_start|>user\n";
+ const std::string kOhAiPrompt = "<|im_end|>\n<|im_start|>assistant\n";
+ const std::string kOhSystemPrompt = "<|im_start|>system\n";
+ const std::unordered_map<std::string, int> kOpenhermesTemplate = {{"<|im_end|>", 32000}, {"<|im_start|>", 32001}};
+
+ // '[', 'INST', ']', '[INST]', '[', '/', 'INST', ']', '[/INST]', '</s>'
+ const std::vector<int32_t> kMistral_V0_3_StopWords
+     = {29560, 17057, 29561, 3, 29560, 29516, 17057, 29561, 4, 2, 3, 4, 8, 9, 10, -1, -1, -1, -1, -1};
+
+ enum class MistralTemplate : int32_t {
+   kBos = 1,
+   kEos = 2,
+   kBeginInst = 3,
+   kEndInst = 4
+ };

- constexpr const int k200OK = 200;
- constexpr const int k400BadRequest = 400;
- constexpr const int k409Conflict = 409;
- constexpr const int k500InternalServerError = 500;
-
+ // TODO(sang): This is fragile, just a temporary solution. Maybe we can use a config file or the model architecture instead.
+ bool IsOpenhermes(const std::string& s) {
+   if (s.find("mistral") != std::string::npos || s.find("Mistral") != std::string::npos) {
+     return false;
+   }
+   return true;
+ }
+ }  // namespace
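Aside (not part of the commit): TensorRT-LLM's stopWordsList tensor uses a flattened {1, 2, N} layout, where row 0 concatenates all stop-token sequences and row 1 holds their cumulative end offsets, with both rows padded by -1 to the same length N. The hard-coded arrays above follow that layout. A minimal sketch of how such an array could be derived, assuming only <vector> and a made-up helper name:

std::vector<int32_t> FlattenStopWords(const std::vector<std::vector<int32_t>>& stop_seqs) {
  std::vector<int32_t> tokens;   // row 0: all stop sequences back to back
  std::vector<int32_t> offsets;  // row 1: cumulative end offset of each sequence
  for (const auto& seq : stop_seqs) {
    tokens.insert(tokens.end(), seq.begin(), seq.end());
    offsets.push_back(static_cast<int32_t>(tokens.size()));
  }
  // Pad the shorter row with -1 so both rows share length N; the caller reshapes to {1, 2, N}.
  while (offsets.size() < tokens.size()) offsets.push_back(-1);
  while (tokens.size() < offsets.size()) tokens.push_back(-1);
  std::vector<int32_t> flat = tokens;
  flat.insert(flat.end(), offsets.begin(), offsets.end());
  return flat;
}
// FlattenStopWords({{321, 28730, 416}, {2}, {32000}}) reproduces kOpenhermesStopWords above.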
  TensorrtllmEngine::~TensorrtllmEngine() {}

  void RemoveId(std::vector<int>& vec, int id) {
    vec.erase(std::remove(vec.begin(), vec.end(), id), vec.end());
  }
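RemoveId is the classic erase-remove idiom. For instance (illustrative, with ids standing for any token-id vector), stripping the ChatML special ids from the template map above before decoding would look like:

RemoveId(ids, 32000);  // drop every <|im_end|>
RemoveId(ids, 32001);  // drop every <|im_start|>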

- bool HandleMatch(std::string const& rew_text, std::shared_ptr<InferenceState> infer_state) {
-   if (infer_state->IsComplete()) {
+ bool HandleMatch(std::string const& rew_text,
+                  std::shared_ptr<InferenceState> infer_state,
+                  std::function<void(Json::Value&&, Json::Value&&)> cb,
+                  bool is_openhermes) {
+   if (infer_state->IsComplete(is_openhermes)) {
      return false;
    }
    if (infer_state->stop_word_match_len == 0) {
-     if (rew_text.find('<') != std::string::npos) {  // Found "<" anywhere in the text
+     if ((is_openhermes && rew_text.find('<') != std::string::npos) ||
+         (!is_openhermes && rew_text.find('[') != std::string::npos)) {
        infer_state->stop_word_match_len++;  // Move to next state
-       infer_state->prev_text = rew_text;
        return true;
      }
-   }
-   else if (rew_text == infer_state->sequence[infer_state->stop_word_match_len]) {
+   } else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
      infer_state->stop_word_match_len++;  // Move to next state
-     infer_state->prev_text = rew_text;
      return true;
-   }
-   else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->sequence[0]) {
+   } else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
      infer_state->stop_word_match_len = 1;  // Restart from first match if sequence breaks but matches start
-     infer_state->prev_text = rew_text;
      return true;
-   }
-   else {
+   } else {
      infer_state->Reset();
      return false;  // Reset to start if sequence breaks
    }
@@ -67,19 +93,21 @@ GenerationInput::TensorPtr TensorrtllmEngine::GetTensorSingleStopWordList(int st
  }

  GenerationInput::TensorPtr TensorrtllmEngine::GetTensorChatMLStopWordList() {
-   std::vector<int32_t> stop_words_tokens
-       = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1};  // Extend with -1 for increased length
-   return gpt_session->getBufferManager().copyFrom(stop_words_tokens, ITensor::makeShape({1, 2, 5}), MemoryType::kGPU);
+   if (is_openhermes_) {
+     return gpt_session->getBufferManager().copyFrom(kOpenhermesStopWords, ITensor::makeShape({1, 2, static_cast<int>(kOpenhermesStopWords.size() / 2)}), MemoryType::kGPU);
+   } else {
+     return gpt_session->getBufferManager().copyFrom(kMistral_V0_3_StopWords, ITensor::makeShape({1, 2, static_cast<int>(kMistral_V0_3_StopWords.size() / 2)}), MemoryType::kGPU);
+   }
  }

  GenerationInput TensorrtllmEngine::CreateGenerationInput(std::vector<int32_t> input_ids_host) {
    int input_len = input_ids_host.size();
-   std::vector<int32_t> input_lengths_host(batchSize, input_len);
+   std::vector<int32_t> input_lengths_host(batch_size_, input_len);
    GenerationInput::TensorPtr input_lengths
-       = gpt_session->getBufferManager().copyFrom(input_lengths_host, ITensor::makeShape({batchSize}), MemoryType::kGPU);
+       = gpt_session->getBufferManager().copyFrom(input_lengths_host, ITensor::makeShape({batch_size_}), MemoryType::kGPU);
    GenerationInput::TensorPtr input_ids = gpt_session->getBufferManager().copyFrom(
-       input_ids_host, ITensor::makeShape({batchSize, input_len}), MemoryType::kGPU);
-   GenerationInput generation_input{0, 0, input_ids, input_lengths, model_config->usePackedInput()};
+       input_ids_host, ITensor::makeShape({batch_size_, input_len}), MemoryType::kGPU);
+   GenerationInput generation_input{0, 0, input_ids, input_lengths, model_config_->usePackedInput()};
    generation_input.stopWordsList = GetTensorChatMLStopWordList();

    LOG_INFO << "Create generation input successfully";
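For orientation (a sketch, not code from this commit): the GenerationInput built here is presumably consumed by GptSession::generate inside InferenceThread, roughly along these lines; the sampling setup and callback body below are placeholders.

GenerationInput input = self->CreateGenerationInput(input_ids_host);
GenerationOutput output = self->CreateGenerationOutput();
SamplingConfig sampling{1};  // beam width 1, matching session_config_.maxBeamWidth
output.onTokenGenerated = [](GenerationOutput::TensorPtr const& ids, SizeType step, bool finished) {
  // stream newly decoded tokens (see the real callback in InferenceThread below)
};
self->gpt_session->generate(output, input, sampling);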
@@ -102,27 +130,34 @@ void InferenceThread(
    TensorrtllmEngine* self,
    SamplingConfig sampling_config,
    int input_len,
-   int outputLen) {
+   int outputLen, bool is_openhermes) {

    // Input preparation
    LOG_INFO << "Inference thread started";
    GenerationInput generation_input = self->CreateGenerationInput(input_ids_host);
    GenerationOutput generation_output = self->CreateGenerationOutput();

    // Define the callback to stream each generated token
-   generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output](
+   generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output, is_openhermes](
        GenerationOutput::TensorPtr const& output_ids, SizeType step, bool finished) {
-     LOG_INFO << "Generating tokenizer in thread";
+     // LOG_INFO << "Generating tokenizer in thread";
      // Assuming the shape of output_ids tensor is (1, 1, 160), where 160 is the number of tokens
      int output_length = output_ids->getShape().d[2];  // Get the length of output IDs based on the tensor shape
      // Copy output IDs from GPU to host for printing
      std::vector<int32_t> output_idsHost(output_length);
      self->gpt_session->getBufferManager().copy(*output_ids, output_idsHost.data(), MemoryType::kCPU);
      // Find the last non-zero value in the output IDs starting from the end of the input sequence
      std::vector<int> output_idsHostDecode(output_idsHost.begin() + input_len, output_idsHost.end());
+
      RemoveId(output_idsHostDecode, 0);
-     RemoveId(output_idsHostDecode, 32000);
-     RemoveId(output_idsHostDecode, 32001);
+     if (is_openhermes) {
+       for (auto const& [_, v] : kOpenhermesTemplate) {
+         RemoveId(output_idsHostDecode, v);
+       }
+     } else {
+       RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kBeginInst));
+       RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kEndInst));
+     }
      std::string text = self->cortex_tokenizer->Decode(output_idsHostDecode);

      if (infer_state->prev_pos >= 0 && infer_state->prev_pos < text.size()) {
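The rest of this callback is cut off by the hunk; conceptually, each step decodes the whole generated suffix and pushes only the characters past prev_pos onto the streaming queue. A simplified sketch of that idea (not the literal elided code):

if (infer_state->prev_pos >= 0 && infer_state->prev_pos < text.size()) {
  std::string delta = text.substr(infer_state->prev_pos);  // only the newly generated piece
  std::lock_guard<std::mutex> guard(infer_state->queue_mutex);
  infer_state->texts_to_stream.push(delta);
}
infer_state->prev_pos = text.size();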
@@ -192,29 +227,47 @@ bool TensorrtllmEngine::CheckModelLoaded(std::function<void(Json::Value&&, Json:

  void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_body, std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
    inferences::ChatCompletionRequest request = inferences::fromJson(json_body);
-   std::string formatted_input = pre_prompt;
+   std::string formatted_input = pre_prompt_;
    nlohmann::json data;
    // data["stream"] = completion.stream;
    // data["n_predict"] = completion.max_tokens;
    data["presence_penalty"] = request.presence_penalty;
    Json::Value const& messages = request.messages;

+   // tokens for Mistral v0.3
+   // TODO(sang): too much hard-coding here, needs a refactor soon
+   std::vector<int32_t> tokens = {static_cast<int32_t>(MistralTemplate::kBos)};
+
    // Format the input from user
+   int msg_count = 0;
    for (auto const& message : messages) {
      std::string input_role = message["role"].asString();
      std::string role;
      if (input_role == "user") {
-       role = user_prompt;
+       role = user_prompt_;
        std::string content = message["content"].asString();
        formatted_input += role + content;
+       if (!is_openhermes_) {
+         auto new_tokens = cortex_tokenizer->Encode(content);
+         new_tokens.insert(new_tokens.begin(), static_cast<int32_t>(MistralTemplate::kBeginInst));
+         new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEndInst));
+         tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
+       }
      }
      else if (input_role == "assistant") {
-       role = ai_prompt;
+       role = ai_prompt_;
        std::string content = message["content"].asString();
        formatted_input += role + content;
+       if (!is_openhermes_) {
+         auto new_tokens = cortex_tokenizer->Encode(content);
+         if (msg_count == messages.size() - 1) {
+           new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEos));
+         }
+         tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
+       }
      }
      else if (input_role == "system") {
-       role = system_prompt;
+       role = system_prompt_;
        std::string content = message["content"].asString();
        formatted_input = role + content + formatted_input;
      }
@@ -223,13 +276,21 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
        std::string content = message["content"].asString();
        formatted_input += role + content;
      }
+     msg_count++;
    }
-   formatted_input += ai_prompt;
+   formatted_input += ai_prompt_;
+   // LOG_INFO << formatted_input;
    // Format the input from user

    std::shared_ptr<InferenceState> infer_state = std::make_shared<InferenceState>();

-   std::vector<int32_t> input_ids_host = cortex_tokenizer->Encode(formatted_input);
+   std::vector<int32_t> input_ids_host;
+   if (is_openhermes_) {
+     input_ids_host = cortex_tokenizer->Encode(formatted_input);
+   } else {
+     input_ids_host = tokens;
+   }
+
    int const input_len = input_ids_host.size();
    int const outputLen = request.max_tokens - input_len;
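To make the two code paths concrete (illustrative only, using the defaults defined at the top of the file and ignoring pre_prompt_): an OpenHermes request is rendered as ChatML text and then tokenized, while a Mistral v0.3 request is assembled directly as token ids.

// OpenHermes / ChatML (formatted_input, later passed through cortex_tokenizer->Encode):
//   <|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n
//
// Mistral v0.3 (token ids, no text round-trip):
//   {kBos, kBeginInst, ...Encode(user content)..., kEndInst, ...Encode(assistant content)..., kEos on the last assistant turn}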
@@ -243,23 +304,25 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
    sampling_config.repetitionPenalty = std::vector{request.frequency_penalty};
    // Input preparation

-   std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen);
+   std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen, is_openhermes_);
    inference_thread.detach();  // Detach the thread to allow it to run independently

-   q_->runTaskInQueue([cb = std::move(callback), infer_state]() {
+   q_->runTaskInQueue([this, cb = std::move(callback), infer_state]() {
+     // std::string res_str;
      LOG_INFO << "Preparing to run inference task queue...";
      while (true) {  // Continuously check if the queue is not empty
        std::unique_lock<std::mutex> lock(infer_state->queue_mutex);  // Lock the queue for exclusive access
        if (!infer_state->texts_to_stream.empty()) {
          std::string rew_text = infer_state->texts_to_stream.front();
+         // res_str += rew_text;
          infer_state->texts_to_stream.pop();
-         if (HandleMatch(rew_text, infer_state) && rew_text != "[DONE]") {
+         if (HandleMatch(rew_text, infer_state, cb, is_openhermes_) && rew_text != "[DONE]") {
            continue;
          };

          if (rew_text == "[DONE]") {
            const std::string str
-               = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", "", "stop")
+               = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, "", "stop")
                + "\n\n" + "data: [DONE]" + "\n\n";

            infer_state->is_finished = true;
@@ -275,10 +338,10 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
            break;
          }
          const std::string text_to_stream
-             = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", rew_text) + "\n\n";
+             = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, rew_text) + "\n\n";

          lock.unlock();  // Unlock as soon as possible
-         infer_state->prev_text = rew_text;
+         // std::cout << rew_text;

          Json::Value resp_data;
          resp_data["data"] = text_to_stream;
@@ -293,6 +356,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
          lock.unlock();
        }
      }
+     // LOG_INFO << res_str;
    });

    LOG_INFO << "Inference completed";
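For reference (the exact JSON body depends on tensorrtllm_utils::CreateReturnJson, which is not shown here): each queued piece is wrapped as a server-sent-event frame before being handed to the callback, and the stream ends with a stop chunk plus a [DONE] sentinel, roughly:

// data: { ...chunk JSON carrying model_id_ and the new rew_text... }\n\n
// data: { ...final chunk with finish reason "stop"... }\n\n
// data: [DONE]\n\n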
@@ -302,16 +366,20 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
  void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
    model::LoadModelRequest request = model::fromJson(json_body);
    std::filesystem::path model_dir = request.model_path;
+   is_openhermes_ = IsOpenhermes(request.model_path);

    int ctx_len = request.ctx_len;
-   this->user_prompt = request.user_prompt;
-   this->ai_prompt = request.ai_prompt;
-   this->system_prompt = request.system_prompt;
-   this->model_id_ = GetModelId(*json_body);
+   // We only support 2 models for now, it is ugly but it works :(
+   if (is_openhermes_) {
+     user_prompt_ = request.user_prompt.empty() ? kOhUserPrompt : request.user_prompt;
+     ai_prompt_ = request.ai_prompt.empty() ? kOhAiPrompt : request.ai_prompt;
+     system_prompt_ = request.system_prompt.empty() ? kOhSystemPrompt : request.system_prompt;
+   }
+   model_id_ = GetModelId(*json_body);

-   logger = std::make_shared<TllmLogger>();
-   logger->setLevel(nvinfer1::ILogger::Severity::kINFO);
-   initTrtLlmPlugins(logger.get());
+   logger_ = std::make_shared<TllmLogger>();
+   logger_->setLevel(nvinfer1::ILogger::Severity::kINFO);
+   initTrtLlmPlugins(logger_.get());

    std::filesystem::path tokenizer_model_name = model_dir / "tokenizer.model";
    cortex_tokenizer = std::make_unique<Tokenizer>(tokenizer_model_name.string());
@@ -320,20 +388,20 @@ void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::f
    std::filesystem::path json_file_name = model_dir / "config.json";
    auto json = GptJsonConfig::parse(json_file_name);
    auto config = json.getModelConfig();
-   model_config = std::make_unique<GptModelConfig>(config);
+   model_config_ = std::make_unique<GptModelConfig>(config);
    auto world_config = WorldConfig::mpi(1, json.getTensorParallelism(), json.getPipelineParallelism());
    LOG_INFO << "Loaded config from " << json_file_name.string();
    // auto dtype = model_config->getDataType();

    // Currently doing fixed session config
-   session_config.maxBatchSize = batchSize;
-   session_config.maxBeamWidth = 1;  // Fixed for simplicity
-   session_config.maxSequenceLength = ctx_len;
-   session_config.cudaGraphMode = true;  // Fixed for simplicity
+   session_config_.maxBatchSize = batch_size_;
+   session_config_.maxBeamWidth = 1;  // Fixed for simplicity
+   session_config_.maxSequenceLength = ctx_len;
+   session_config_.cudaGraphMode = true;  // Fixed for simplicity

    // Init gpt_session
    auto model_path = model_dir / json.engineFilename(world_config, model_id_);
-   gpt_session = std::make_unique<GptSession>(session_config, *model_config, world_config, model_path.string(), logger);
+   gpt_session = std::make_unique<GptSession>(session_config_, *model_config_, world_config, model_path.string(), logger_);

    model_loaded_ = true;
    if (q_ == nullptr) {
@@ -365,8 +433,8 @@ void TensorrtllmEngine::UnloadModel(std::shared_ptr<Json::Value> json_body, std:
    gpt_session.reset();
    cortex_tokenizer.reset();
    q_.reset();
-   model_config.reset();
-   logger.reset();
+   model_config_.reset();
+   logger_.reset();
    model_loaded_ = false;

    Json::Value json_resp;