@@ -50,10 +50,18 @@ bool handleMatch(const std::string& rawText, std::shared_ptr<inferenceState> inf
 {
     if (inferState->isComplete())
     {
-        return true;
+        return false;
     }
-
-    if (rawText == inferState->sequence[inferState->stopWordMatchLen])
+    if (inferState->stopWordMatchLen == 0)
+    {
+        if (rawText.find('<') != std::string::npos) // Found "<" anywhere in the text
+        {
+            inferState->stopWordMatchLen++; // Move to next state
+            inferState->prevText = rawText;
+            return true;
+        }
+    }
+    else if (rawText == inferState->sequence[inferState->stopWordMatchLen])
     {
         inferState->stopWordMatchLen++; // Move to next state
         inferState->prevText = rawText;
@@ -110,9 +118,9 @@ GenerationInput::TensorPtr tensorrtllm::getTensorSingleStopWordList(int stopToke
 
 GenerationInput::TensorPtr tensorrtllm::getTensorChatMLStopWordList()
 {
-    std::vector<int32_t> stopWordsTokens = {28789, 28766, 321, 28730, 416, 28766, 28767, 32000, 6, 8, -1, -1, -1, -1,
-        -1, -1}; // Extend with -1 for increased length
-    return gptSession->getBufferManager().copyFrom(stopWordsTokens, ITensor::makeShape({1, 2, 8}), MemoryType::kGPU);
+    std::vector<int32_t> stopWordsTokens = {28789, 28766, 321, 28730, 416, 28766, 28767, 2, 32000, 7, 8, 9, -1, -1, -1,
+        -1, -1, -1}; // Extend with -1 for increased length
+    return gptSession->getBufferManager().copyFrom(stopWordsTokens, ITensor::makeShape({1, 2, 9}), MemoryType::kGPU);
 }
 
 GenerationInput tensorrtllm::createGenerationInput(std::vector<int32_t> inputIdsHost)
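
Note on the stop-word layout used by getTensorChatMLStopWordList above: TensorRT-LLM's stop-words tensor has shape [batchSize, 2, maxStopWordsLen], where row 0 holds the concatenated token IDs of all stop sequences and row 1 holds the exclusive end offset of each sequence, padded with -1. The helper below is a minimal standalone sketch of that flattening (buildStopWordsFlat is a hypothetical name, not part of this commit), shown only to explain where the 7, 8, 9 offsets and the {1, 2, 9} shape come from.

    #include <cstdint>
    #include <vector>

    // Hypothetical helper (not part of this commit) showing the assumed
    // [1, 2, N] stop-word layout: row 0 = concatenated token IDs of every
    // stop sequence, row 1 = exclusive end offset of each sequence, -1 padded.
    std::vector<int32_t> buildStopWordsFlat(const std::vector<std::vector<int32_t>>& stopSequences)
    {
        std::vector<int32_t> tokens;  // row 0
        std::vector<int32_t> offsets; // row 1
        for (const auto& seq : stopSequences)
        {
            tokens.insert(tokens.end(), seq.begin(), seq.end());
            offsets.push_back(static_cast<int32_t>(tokens.size())); // running end offset
        }
        offsets.resize(tokens.size(), -1); // pad row 1 so both rows have length N

        std::vector<int32_t> flat = tokens; // flattened shape: {1, 2, N}
        flat.insert(flat.end(), offsets.begin(), offsets.end());
        return flat;
    }

With the three stop sequences implied by the new list, {28789, 28766, 321, 28730, 416, 28766, 28767} (presumably the tokenized pieces of "<|im_end|>"), {2}, and {32000}, this produces the 18 values passed to copyFrom with ITensor::makeShape({1, 2, 9}).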