@@ -79,30 +79,61 @@ void RemoveId(std::vector<int>& vec, int id) {
7979 vec.erase (std::remove (vec.begin (), vec.end (), id), vec.end ());
8080}
8181
82- bool HandleMatch (std::string const & rew_text, std::shared_ptr<InferenceState> infer_state, bool is_openhermes) {
82+ bool HandleMatch (std::string const & rew_text,
83+ std::shared_ptr<InferenceState> infer_state,
84+ std::function<void (Json::Value&&, Json::Value&&)> cb,
85+ bool is_openhermes) {
8386 if (infer_state->IsComplete (is_openhermes)) {
87+ infer_state->rewind_strs .clear ();
8488 return false ;
8589 }
8690 if (infer_state->stop_word_match_len == 0 ) {
8791 if ((is_openhermes && rew_text.find (' <' ) != std::string::npos) ||
8892 (!is_openhermes && rew_text.find (' [' ) != std::string::npos)) {
8993 infer_state->stop_word_match_len ++; // Move to next state
90- infer_state->prev_text = rew_text;
94+ infer_state->rewind_strs . push_back ( rew_text) ;
9195 return true ;
9296 }
93- }
94- else if (rew_text == infer_state->GetSequence (is_openhermes, infer_state->stop_word_match_len )) {
97+ } else if (rew_text == infer_state->GetSequence (is_openhermes, infer_state->stop_word_match_len )) {
9598 infer_state->stop_word_match_len ++; // Move to next state
96- infer_state->prev_text = rew_text;
99+ infer_state->rewind_strs . push_back ( rew_text) ;
97100 return true ;
98- }
99- else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence (is_openhermes, 0u )) {
101+ } else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence (is_openhermes, 0u )) {
100102 infer_state->stop_word_match_len = 1 ; // Restart from first match if sequence breaks but matches start
101- infer_state->prev_text = rew_text;
103+ // response cache data
104+ for (auto const & s: infer_state->rewind_strs ) {
105+ // std::cout << s;
106+ const std::string text_to_stream
107+ = " data: " + tensorrtllm_utils::CreateReturnJson (tensorrtllm_utils::GenerateRandomString (20 ), " _" , s) + " \n\n " ;
108+ Json::Value resp_data;
109+ resp_data[" data" ] = text_to_stream;
110+ Json::Value status;
111+ status[" is_done" ] = false ;
112+ status[" has_error" ] = false ;
113+ status[" is_stream" ] = true ;
114+ status[" status_code" ] = k200OK;
115+ cb (std::move (status), std::move (resp_data));
116+ }
117+ infer_state->rewind_strs .clear ();
118+ infer_state->rewind_strs .push_back (rew_text);
102119 return true ;
103- }
104- else {
120+ } else {
105121 infer_state->Reset ();
122+ // response cache data
123+ for (auto const & s: infer_state->rewind_strs ) {
124+ // std::cout << s;
125+ const std::string text_to_stream
126+ = " data: " + tensorrtllm_utils::CreateReturnJson (tensorrtllm_utils::GenerateRandomString (20 ), " _" , s) + " \n\n " ;
127+ Json::Value resp_data;
128+ resp_data[" data" ] = text_to_stream;
129+ Json::Value status;
130+ status[" is_done" ] = false ;
131+ status[" has_error" ] = false ;
132+ status[" is_stream" ] = true ;
133+ status[" status_code" ] = k200OK;
134+ cb (std::move (status), std::move (resp_data));
135+ }
136+ infer_state->rewind_strs .clear ();
106137 return false ; // Reset to start if sequence breaks
107138 }
108139 return false ;
@@ -313,7 +344,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
313344 std::string rew_text = infer_state->texts_to_stream .front ();
314345 // res_str += rew_text;
315346 infer_state->texts_to_stream .pop ();
316- if (HandleMatch (rew_text, infer_state, is_openhermes_) && rew_text != " [DONE]" ) {
347+ if (HandleMatch (rew_text, infer_state, cb, is_openhermes_) && rew_text != " [DONE]" ) {
317348 continue ;
318349 };
319350
@@ -338,7 +369,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
338369 = " data: " + tensorrtllm_utils::CreateReturnJson (tensorrtllm_utils::GenerateRandomString (20 ), model_id_, rew_text) + " \n\n " ;
339370
340371 lock.unlock (); // Unlock as soon as possible
341- infer_state-> prev_text = rew_text;
372+ // std::cout << rew_text;
342373
343374 Json::Value resp_data;
344375 resp_data[" data" ] = text_to_stream;
0 commit comments