
Commit f60cf91

address comments: change unit test, and add more asserts
Signed-off-by: John Calderon <jcalderon@nvidia.com>
1 parent a887420 commit f60cf91

3 files changed: +28 -26 lines changed

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 7 additions & 5 deletions
@@ -1,8 +1,8 @@
 import copy
 import os
-import random
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 from PIL import Image
@@ -229,10 +229,12 @@ def get_rope_index(
         return position_ids, mrope_position_deltas
 
     def get_dummy_text(self, input_seq_len: int):
-        return self.tokenizer.decode([
-            random.randint(0, self.model_config.vocab_size - 1)
-            for _ in range(input_seq_len)
-        ])
+        return self.tokenizer.decode(
+            np.random.randint(
+                low=0,
+                high=self.model_config.
+                vocab_size,  # Note: high is exclusive in NumPy
+                size=input_seq_len))
 
     def get_dummy_images(self, max_width: int, max_height: int,
                          num_images: int):
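
Note on the change above: it swaps Python's random.randint (inclusive upper bound) for np.random.randint (exclusive upper bound), which is why the `- 1` disappears and the list comprehension becomes a single vectorized call. A minimal sketch of the equivalence, using a made-up vocab_size in place of model_config.vocab_size:

import numpy as np

vocab_size = 32000  # hypothetical stand-in for model_config.vocab_size

# high is exclusive in NumPy, so passing vocab_size directly samples ids in
# [0, vocab_size - 1], matching the old random.randint(0, vocab_size - 1) loop.
dummy_token_ids = np.random.randint(low=0, high=vocab_size, size=16)
assert dummy_token_ids.min() >= 0 and dummy_token_ids.max() < vocab_size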

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 14 additions & 9 deletions
@@ -66,14 +66,14 @@ def __init__(self, *, model_engine: PyTorchModelEngine,
         self._max_num_tokens = max_num_tokens
         self._max_beam_width = max_beam_width
         self._max_seq_len = max_seq_len
+        self._profiling_stage_data = profiling_stage_data
         self._dummy_reqs = self._create_dummy_context_requests(net_max_seq_len -
                                                                1)
         self._kv_connector_manager = kv_connector_manager
         self._pytorch_backend_config = pytorch_backend_config
         self._speculative_config = speculative_config
         self._tokens_per_block = tokens_per_block
         self._max_batch_size = max_batch_size
-        self._profiling_stage_data = profiling_stage_data
 
     @staticmethod
     def _get_cache_size_per_token(model_config: ModelConfig,
@@ -152,16 +152,20 @@ def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
     def _create_dummy_mm_context_request(
             self, input_seq_len: int) -> List[trtllm.Request]:
         requests = []
-        self._model_name_or_path = getattr(self._model_engine.model,
-                                           "name_or_path", None)
-        self._tokenizer = AutoTokenizer.from_pretrained(
-            self._model_name_or_path)
-        input_processor = create_input_processor(self._model_name_or_path,
-                                                 self._tokenizer)
+        if isinstance(
+                self._profiling_stage_data,
+                dict) and not self._profiling_stage_data.get("enable_mm_reqs"):
+            return requests
+
+        model_name_or_path = getattr(self._model_engine.model, "name_or_path",
+                                     None)
+        assert model_name_or_path is not None, "Could not determine model name or path"
+        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        input_processor = create_input_processor(model_name_or_path, tokenizer)
         if not (hasattr(input_processor, "get_dummy_prompt")):
-            logger.warning("The input processor of the model does not have the method [get_prompt_for_profiling] implemented." \
+            logger.warning("The input processor of the model does not have the method [get_dummy_prompt] implemented." \
                 "Profiling with the default input dummy context request. This may not take into account the memory consumption of " \
-                "ViT's encoder")
+                "the image encoder")
             return requests
         text_prompt = input_processor.get_dummy_prompt(input_seq_len,
                                                        {'image': 1})
@@ -174,6 +178,7 @@ def _create_dummy_mm_context_request(
             multimodal_data = extra_processed_inputs.get('multimodal_data')
 
         max_num_tokens = len(prompt_token_ids)
+        assert max_num_tokens > 0, "the length of the prompt of the dummy mm req is less than or equal to 0"
         remaining_tokens = max(max_num_tokens, input_seq_len)
         if remaining_tokens > input_seq_len:
             logger.warning(f"Profiling with multimedia prompt which contains more tokens than the allowed input_seq_len. " \

tests/unittest/llmapi/test_memory_profiling.py

Lines changed: 7 additions & 12 deletions
@@ -22,13 +22,10 @@ def test_profile_kvcache():
 
 VLM_MODEL = "Qwen2.5-VL-7B-Instruct"
 VLM_MODEL_PATH = get_model_path(VLM_MODEL)
-LLM_MODEL = "Qwen2.5-7B-Instruct"
-LLM_MODEL_PATH = get_model_path(LLM_MODEL)
 
 build_config = BuildConfig(max_batch_size=2048,
-                           max_num_tokens=8192,
                            max_beam_width=1,
-                           max_seq_len=None)
+                           max_seq_len=8192)
 kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9, )
 
 dynamic_batch_config = DynamicBatchConfig(
@@ -66,21 +63,19 @@ def test_profile_kvcache():
 
     torchllm_args = TorchLlmArgs(**llm_args)
 
-    profiling_data = dict()
+    profiling_data = {"enable_mm_reqs": True}
    py_executor = create_py_executor(llm_args=torchllm_args,
                                     checkpoint_dir=VLM_MODEL_PATH,
                                     profiling_stage_data=profiling_data)
-    vlm_max_gpu_total_bytes = profiling_data["max_gpu_total_bytes"]
+    vlm_max_gpu_total_bytes_with_mm_reqs = profiling_data["max_gpu_total_bytes"]
    py_executor.shutdown()
    torch.cuda.empty_cache()
 
-    profiling_data = dict()
-    llm_args["model"] = LLM_MODEL
-    llm_args["postprocess_tokenizer_dir"] = LLM_MODEL
+    profiling_data = {"enable_mm_reqs": False}
    torchllm_args = TorchLlmArgs(**llm_args)
    create_py_executor(llm_args=torchllm_args,
-                       checkpoint_dir=LLM_MODEL_PATH,
+                       checkpoint_dir=VLM_MODEL_PATH,
                       profiling_stage_data=profiling_data)
-    llm_max_gpu_total_bytes = profiling_data["max_gpu_total_bytes"]
+    vlm_max_gpu_total_bytes_no_mm_reqs = profiling_data["max_gpu_total_bytes"]
 
-    assert vlm_max_gpu_total_bytes < llm_max_gpu_total_bytes, f"available KVCache for VLMs is expected to be less than LLMs, but got {vlm_max_gpu_total_bytes} for VLM and {llm_max_gpu_total_bytes} for LLM"
+    assert vlm_max_gpu_total_bytes_with_mm_reqs < vlm_max_gpu_total_bytes_no_mm_reqs, f"available KVCache for VLMs is expected to be less when profiling with mm reqs, but got {vlm_max_gpu_total_bytes_with_mm_reqs} for mm reqs and {vlm_max_gpu_total_bytes_no_mm_reqs} without mm reqs"
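
Note on the reworked test: it now profiles the same VLM checkpoint twice, once with enable_mm_reqs on and once off, and reads the measured value back out of the same profiling_data dict it passed in. A toy sketch of that in/out contract (fake_create_py_executor and the byte values are invented stand-ins, not the real API):

def fake_create_py_executor(profiling_stage_data: dict) -> None:
    # Stand-in for create_py_executor: read the flag, write the result back
    # into the same dict, just as the test reads "max_gpu_total_bytes".
    enable_mm = profiling_stage_data.get("enable_mm_reqs", False)
    # Pretend the image encoder leaves less memory free when mm reqs are on.
    profiling_stage_data["max_gpu_total_bytes"] = 60 if enable_mm else 70


with_mm = {"enable_mm_reqs": True}
without_mm = {"enable_mm_reqs": False}
fake_create_py_executor(with_mm)
fake_create_py_executor(without_mm)
assert with_mm["max_gpu_total_bytes"] < without_mm["max_gpu_total_bytes"]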
