
Commit fa24744

change logic to get initial input_seq_len
Signed-off-by: John Calderon <jcalderon@nvidia.com>
1 parent de73b55 commit fa24744


2 files changed: +9 -12 lines changed


tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 8 additions & 9 deletions
@@ -160,13 +160,10 @@ def _create_dummy_mm_context_request(
 
         max_num_tokens = len(prompt_token_ids)
         assert max_num_tokens > 0, "the length of the prompt of the dummy mm req is less than or equal to 0"
-        remaining_tokens = max(max_num_tokens, input_seq_len)
+        remaining_tokens = min(max_num_tokens, input_seq_len)
         if remaining_tokens > input_seq_len:
             logger.warning(f"Profiling with multimedia prompt which contains more tokens than the allowed input_seq_len. " \
                 f"Multimodal prompt has {remaining_tokens} while the input_seq_len is: {input_seq_len}")
-        ## add + 1 to avoid error: RuntimeError: The max KV cache length of input sequences (X + 1) exceeds the KV cache manager's maximum supported length X.
-        ## at line "/code/tensorrt_llm/tensorrt_llm/_torch/attention_backend/trtllm.py", line 837
-        self._max_seq_len = remaining_tokens + 1
         while remaining_tokens > 0:
             req_mm_input = trtllm.MultimodalInput(
                 multimodal_hashes=multimodal_input.multimodal_hashes,
@@ -181,6 +178,9 @@ def _create_dummy_mm_context_request(
                 output_config=trtllm.OutputConfig(),
                 end_id=-1,
                 multimodal_input=req_mm_input)
+            # TODO:
+            # create_input_processor_with_hash shouldn't be required during profiling,
+            # but is temporarily needed due to the multimodal input dependency for chunked prefill
             request.py_multimodal_data = multimodal_data
             remaining_tokens -= max_num_tokens
             requests.append(request)
@@ -193,11 +193,10 @@ def _create_dummy_mm_context_request(
     def _create_dummy_context_requests(
             self, input_seq_len: int) -> List[trtllm.Request]:
         requests = []
-        if hasattr(
-                self._model_engine.model,
-                "original_arch") and MODEL_CLASS_VISION_ENCODER_MAPPING.get(
-                    self._model_engine.model.original_arch, None
-                ) and self._model_engine.attn_runtime_features.chunked_prefill:
+        if hasattr(self._model_engine.model,
+                   "original_arch") and MODEL_CLASS_VISION_ENCODER_MAPPING.get(
+                       self._model_engine.model.original_arch, None):
+            input_seq_len = min(self._max_num_tokens, input_seq_len)
             requests = self._create_dummy_mm_context_request(input_seq_len)
             # if succeed profiling with multimodal requests then return, otherwise profile
             # with default case
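
For context, a minimal self-contained sketch of the token arithmetic this diff changes. It is not the real _create_dummy_mm_context_request: the class, trtllm.Request construction, and logging are stubbed out, and plan_dummy_mm_requests is a made-up name used only for illustration.

# Sketch of the profiling loop touched by this commit; only the arithmetic
# around remaining_tokens is kept, everything else is stubbed out.

def plan_dummy_mm_requests(prompt_token_ids: list, input_seq_len: int) -> int:
    """Return how many dummy context requests the profiling loop would build."""
    max_num_tokens = len(prompt_token_ids)
    assert max_num_tokens > 0

    # Before this commit: max(max_num_tokens, input_seq_len), which could exceed
    # input_seq_len. After: clamp to the smaller of the two.
    remaining_tokens = min(max_num_tokens, input_seq_len)

    num_requests = 0
    while remaining_tokens > 0:
        num_requests += 1                  # one dummy request per iteration
        remaining_tokens -= max_num_tokens
    return num_requests


print(plan_dummy_mm_requests(list(range(512)), input_seq_len=8192))  # -> 1
print(plan_dummy_mm_requests(list(range(512)), input_seq_len=256))   # -> 1

With the min() clamp, the sketched loop finishes after a single iteration for any non-empty prompt, whereas the old max() could make remaining_tokens exceed input_seq_len, which appears to be what the removed `self._max_seq_len = remaining_tokens + 1` workaround was compensating for.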

tests/unittest/llmapi/test_memory_profiling.py

Lines changed: 1 addition & 3 deletions
@@ -22,9 +22,7 @@ def test_profile_kvcache():
     VLM_MODEL = "Qwen2.5-VL-7B-Instruct"
     VLM_MODEL_PATH = get_model_path(VLM_MODEL)
 
-    build_config = BuildConfig(max_batch_size=2048,
-                               max_beam_width=1,
-                               max_seq_len=8192)
+    build_config = BuildConfig(max_beam_width=1, max_num_tokens=16384)
     dynamic_batch_config = DynamicBatchConfig(
         enable_batch_size_tuning=True,
         enable_max_num_tokens_tuning=False,
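
As a quick cross-reference between the two files, here is a hedged sketch of how the new max_num_tokens=16384 test budget interacts with the clamp added in _create_dummy_context_requests. clamp_profiling_seq_len is an illustrative helper, not part of the TensorRT-LLM API; the 16384 value comes from the BuildConfig line above.

# Illustrative only: mirrors `input_seq_len = min(self._max_num_tokens, input_seq_len)`
# from the _util.py hunk, with the budget taken from the updated test config.

MAX_NUM_TOKENS = 16384  # BuildConfig(max_beam_width=1, max_num_tokens=16384)

def clamp_profiling_seq_len(requested_seq_len: int,
                            max_num_tokens: int = MAX_NUM_TOKENS) -> int:
    # The dummy multimodal request never asks for more tokens than the engine budget.
    return min(max_num_tokens, requested_seq_len)

assert clamp_profiling_seq_len(8192) == 8192      # already within budget
assert clamp_profiling_seq_len(65536) == 16384    # capped at max_num_tokens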
