Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 764b41b

Browse files
automaticcatjan-service-account
authored andcommitted
fix im_end words filtering
1 parent b8b9500 commit 764b41b

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

cpp/tensorrt_llm/nitro/controllers/tensorrtllm.cc

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,18 @@ bool handleMatch(const std::string& rawText, std::shared_ptr<inferenceState> inf
5050
{
5151
if (inferState->isComplete())
5252
{
53-
return true;
53+
return false;
5454
}
55-
56-
if (rawText == inferState->sequence[inferState->stopWordMatchLen])
55+
if (inferState->stopWordMatchLen == 0)
56+
{
57+
if (rawText.find('<') != std::string::npos) // Found "<" anywhere in the text
58+
{
59+
inferState->stopWordMatchLen++; // Move to next state
60+
inferState->prevText = rawText;
61+
return true;
62+
}
63+
}
64+
else if (rawText == inferState->sequence[inferState->stopWordMatchLen])
5765
{
5866
inferState->stopWordMatchLen++; // Move to next state
5967
inferState->prevText = rawText;
@@ -110,9 +118,9 @@ GenerationInput::TensorPtr tensorrtllm::getTensorSingleStopWordList(int stopToke
110118

111119
GenerationInput::TensorPtr tensorrtllm::getTensorChatMLStopWordList()
112120
{
113-
std::vector<int32_t> stopWordsTokens = {28789, 28766, 321, 28730, 416, 28766, 28767, 32000, 6, 8, -1, -1, -1, -1,
114-
-1, -1}; // Extend with -1 for increased length
115-
return gptSession->getBufferManager().copyFrom(stopWordsTokens, ITensor::makeShape({1, 2, 8}), MemoryType::kGPU);
121+
std::vector<int32_t> stopWordsTokens = {28789, 28766, 321, 28730, 416, 28766, 28767, 2, 32000, 7, 8, 9, -1, -1, -1,
122+
-1, -1, -1}; // Extend with -1 for increased length
123+
return gptSession->getBufferManager().copyFrom(stopWordsTokens, ITensor::makeShape({1, 2, 9}), MemoryType::kGPU);
116124
}
117125

118126
GenerationInput tensorrtllm::createGenerationInput(std::vector<int32_t> inputIdsHost)

0 commit comments

Comments
 (0)