Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit c7c8516

Browse files
committed
fix: rewind text if it does not match pattern
1 parent 3ac131a commit c7c8516

File tree

2 files changed

+45
-15
lines changed

2 files changed

+45
-15
lines changed

cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,30 +79,61 @@ void RemoveId(std::vector<int>& vec, int id) {
7979
vec.erase(std::remove(vec.begin(), vec.end(), id), vec.end());
8080
}
8181

82-
bool HandleMatch(std::string const& rew_text, std::shared_ptr<InferenceState> infer_state, bool is_openhermes) {
82+
bool HandleMatch(std::string const& rew_text,
83+
std::shared_ptr<InferenceState> infer_state,
84+
std::function<void(Json::Value&&, Json::Value&&)> cb,
85+
bool is_openhermes) {
8386
if (infer_state->IsComplete(is_openhermes)) {
87+
infer_state->rewind_strs.clear();
8488
return false;
8589
}
8690
if (infer_state->stop_word_match_len == 0) {
8791
if ((is_openhermes && rew_text.find('<') != std::string::npos) ||
8892
(!is_openhermes && rew_text.find('[') != std::string::npos)) {
8993
infer_state->stop_word_match_len++; // Move to next state
90-
infer_state->prev_text = rew_text;
94+
infer_state->rewind_strs.push_back(rew_text);
9195
return true;
9296
}
93-
}
94-
else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
97+
} else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
9598
infer_state->stop_word_match_len++; // Move to next state
96-
infer_state->prev_text = rew_text;
99+
infer_state->rewind_strs.push_back(rew_text);
97100
return true;
98-
}
99-
else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
101+
} else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
100102
infer_state->stop_word_match_len = 1; // Restart from first match if sequence breaks but matches start
101-
infer_state->prev_text = rew_text;
103+
// response cache data
104+
for(auto const& s: infer_state->rewind_strs) {
105+
// std::cout << s;
106+
const std::string text_to_stream
107+
= "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n";
108+
Json::Value resp_data;
109+
resp_data["data"] = text_to_stream;
110+
Json::Value status;
111+
status["is_done"] = false;
112+
status["has_error"] = false;
113+
status["is_stream"] = true;
114+
status["status_code"] = k200OK;
115+
cb(std::move(status), std::move(resp_data));
116+
}
117+
infer_state->rewind_strs.clear();
118+
infer_state->rewind_strs.push_back(rew_text);
102119
return true;
103-
}
104-
else {
120+
} else {
105121
infer_state->Reset();
122+
// response cache data
123+
for(auto const& s: infer_state->rewind_strs) {
124+
// std::cout << s;
125+
const std::string text_to_stream
126+
= "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n";
127+
Json::Value resp_data;
128+
resp_data["data"] = text_to_stream;
129+
Json::Value status;
130+
status["is_done"] = false;
131+
status["has_error"] = false;
132+
status["is_stream"] = true;
133+
status["status_code"] = k200OK;
134+
cb(std::move(status), std::move(resp_data));
135+
}
136+
infer_state->rewind_strs.clear();
106137
return false; // Reset to start if sequence breaks
107138
}
108139
return false;
@@ -313,7 +344,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
313344
std::string rew_text = infer_state->texts_to_stream.front();
314345
// res_str += rew_text;
315346
infer_state->texts_to_stream.pop();
316-
if (HandleMatch(rew_text, infer_state, is_openhermes_) && rew_text != "[DONE]") {
347+
if (HandleMatch(rew_text, infer_state, cb, is_openhermes_) && rew_text != "[DONE]") {
317348
continue;
318349
};
319350

@@ -338,7 +369,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
338369
= "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, rew_text) + "\n\n";
339370

340371
lock.unlock(); // Unlock as soon as possible
341-
infer_state->prev_text = rew_text;
372+
// std::cout << rew_text;
342373

343374
Json::Value resp_data;
344375
resp_data["data"] = text_to_stream;

cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,17 @@ class Tokenizer {
6868

6969
struct InferenceState {
7070
int prev_pos{0};
71-
std::string prev_text;
7271
bool is_finished;
7372
std::queue<std::string> texts_to_stream;
7473
std::mutex queue_mutex; // Mutex to protect access to textsToStream
7574
size_t stop_word_match_len = 0;
7675
std::vector<std::string> sequence_openhermes = {"<", "|", "im", "_", "end", "|", ">"};
7776
std::vector<std::string> sequence_mistral = {"[", "INST", "]"};
7877
int token_gen_count = 0;
78+
std::vector<std::string> rewind_strs;
7979

8080
void Reset() {
81-
stop_word_match_len = 0;
82-
prev_text = "";
81+
stop_word_match_len = 0;
8382
}
8483

8584
bool IsComplete(bool is_openhermes) const {

0 commit comments

Comments
 (0)