diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 2dea9a25537..94167f61f0f 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -1559,17 +1559,26 @@ def _are_end_id(self, requests: list[LlmRequest], tokens: torch.Tensor) -> torch return tokens == end_ids_tensor def _are_max_length(self, requests: list[LlmRequest]) -> torch.Tensor: + # Check if meeting the below conditions: + # either num_tokens - orig_prompt_len >= max_new_tokens + # or num_tokens >= max_seq_len lengths_tensor = torch.tensor( [ [ - ((req.get_num_tokens(BEAM) + num_tokens) - req.py_orig_prompt_len) + (req.get_num_tokens(BEAM) + num_tokens) for num_tokens in range(1, self.max_tokens + 1) ] for req in requests ] ) max_lengths_tensor = torch.tensor( - [([min(req.py_max_new_tokens, self.max_seq_len)] * self.max_tokens) for req in requests] + [ + ( + [min(req.py_max_new_tokens + req.py_orig_prompt_len, self.max_seq_len)] + * self.max_tokens + ) + for req in requests + ] ) return ( (lengths_tensor >= max_lengths_tensor)