Commit 6c094eb

rebase and change logic
Signed-off-by: John Calderon <jcalderon@nvidia.com>
1 parent 617ca5b commit 6c094eb

File tree: 2 files changed, +37 -14 lines changed

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 23 additions & 3 deletions
@@ -7,6 +7,7 @@
 from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
                           PreTrainedModel, Qwen2_5_VLForConditionalGeneration,
                           Qwen2VLForConditionalGeneration)
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize

 from tensorrt_llm.inputs.multimodal import MultimodalParams

@@ -222,6 +223,27 @@ def get_rope_index(
             mrope_position_deltas, device=input_ids.device).unsqueeze(1)
         return position_ids, mrope_position_deltas

+    def get_prompt_for_profiling(self):
+        "Send prompt with largest image resolution for profiling the worst case"
+        max_width = 9999999
+        max_height = 9999999
+        resized_height, resized_width = smart_resize(
+            height=max_height,
+            width=max_width,
+            factor=self.model_config.vision_config.patch_size *
+            self.model_config.vision_config.spatial_merge_size,
+            min_pixels=self.processor.image_processor.min_pixels,
+            max_pixels=self.processor.image_processor.max_pixels,
+        )
+        img_tensor = torch.rand([3, resized_width, resized_height],
+                                device="cpu")
+        mm_data = {"image": [img_tensor]}
+
+        text_prompt = TextPrompt(
+            prompt="<|vision_start|><|image_pad|><|vision_end|>",
+            multi_modal_data=mm_data)
+        return text_prompt
+
     def _preprocess(self, text: dict[str, any], mm_data: dict[str, any],
                     mm_processor_kwargs: Dict[str, Any]):
         images = mm_data.get("image")
@@ -438,7 +460,7 @@ def __init__(
     ) -> None:
         model_config.pretrained_config.rope_scaling['type'] = 'mrope'
         config = model_config.pretrained_config
-
+        self.original_arch = model_config.pretrained_config.architectures[0]
         assert model_config.attn_backend == 'TRTLLM', "Qwen2/2.5-VL only supports TRTLLM backend now"
         super().__init__(config)

@@ -643,7 +665,6 @@ class Qwen2VLModel(Qwen2VLModelBase):

     def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
                  **kwargs):
-        self.is_multimodal = True  # variable used during profiling
         super().__init__(model_config, *args, **kwargs)
         if not DISAGG:
             self.mm_encoder = Qwen2VisionModelBase(
@@ -665,7 +686,6 @@ class Qwen2_5_VLModel(Qwen2VLModelBase):

     def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
                  **kwargs):
-        self.is_multimodal = True  # variable used during profiling
         super().__init__(model_config, *args, **kwargs)
         if not DISAGG:
             self.mm_encoder = Qwen2VisionModelBase(
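
The profiling prompt added above relies on smart_resize to clamp the requested 9999999 x 9999999 image down to the image processor's pixel budget. A minimal standalone sketch of that clamping behavior follows; the patch size, spatial merge size, and pixel limits are assumed example values, not values read from the commit:

# Sketch only: how smart_resize clamps an oversized request, which is why the
# dummy image built in get_prompt_for_profiling is the largest the ViT encoder
# will ever be asked to process.
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize

patch_size = 14               # assumed Qwen2-VL vision patch size
spatial_merge_size = 2        # assumed spatial merge factor
min_pixels = 56 * 56          # assumed image-processor default
max_pixels = 1280 * 28 * 28   # assumed image-processor default

height, width = smart_resize(
    height=9999999,
    width=9999999,
    factor=patch_size * spatial_merge_size,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)
# Both dimensions come back as multiples of the factor, and their product
# stays within max_pixels.
print(height, width, height * width <= max_pixels)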

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 14 additions & 11 deletions
@@ -9,6 +9,8 @@
 import tensorrt_llm
 import tensorrt_llm.bindings.executor as trtllm
 from tensorrt_llm._torch.model_config import ModelConfig
+from tensorrt_llm._torch.models.modeling_utils import \
+    MODEL_CLASS_VISION_ENCODER_MAPPING
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
 from tensorrt_llm.inputs.registry import (create_input_processor,
@@ -56,14 +58,8 @@ def __init__(self, *, executor_config: ExecutorConfig,
         self._draft_model_engine = draft_model_engine
         self._mapping = mapping
         self._max_kv_tokens_in = self._executor_config.kv_cache_config.max_tokens
-        self._is_multimodal = getattr(self._model_engine.model, "is_multimodal",
-                                      False)
-        if self._is_multimodal:
-            self._dummy_reqs = self._create_dummy_mm_context_request(
-                net_max_seq_len - 1)
-        else:
-            self._dummy_reqs = self._create_dummy_context_requests(
-                net_max_seq_len - 1)
+        self._dummy_reqs = self._create_dummy_context_requests(net_max_seq_len -
+                                                               1)
         self._kv_connector_manager = kv_connector_manager

     @staticmethod
@@ -142,6 +138,7 @@ def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,

     def _create_dummy_mm_context_request(
             self, input_seq_len: int) -> List[trtllm.Request]:
+        requests = []
         self._model_name_or_path = getattr(self._model_engine.model,
                                            "name_or_path", None)
         self._tokenizer = AutoTokenizer.from_pretrained(
@@ -152,7 +149,7 @@ def _create_dummy_mm_context_request(
             logger.warning("The input processor of the model does not have the method [get_prompt_for_profiling] implemented." \
                 "Profiling with the default input dummy context request. This may not take into account the memory consumption of " \
                 "ViT's encoder")
-            return self._create_dummy_context_requests(input_seq_len)
+            return requests
         text_prompt = input_processor.get_prompt_for_profiling()
         max_beam_width = self._executor_config.max_beam_width
         input_processor_with_hash = create_input_processor_with_hash(
@@ -162,7 +159,6 @@
         multimodal_input = extra_processed_inputs.get('multimodal_input')
         multimodal_data = extra_processed_inputs.get('multimodal_data')

-        requests = []
         max_num_tokens = len(prompt_token_ids)
         remaining_tokens = max(max_num_tokens, input_seq_len)
         # add +1 to max_num_tokens to avoid assert in line 772 of tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -195,11 +191,18 @@

     def _create_dummy_context_requests(
             self, input_seq_len: int) -> List[trtllm.Request]:
+        requests = []
+        if MODEL_CLASS_VISION_ENCODER_MAPPING.get(
+                self._model_engine.model.original_arch, None):
+            requests = self._create_dummy_mm_context_request(input_seq_len)
+            # if succeed profiling with multimodal requests then return, otherwise profile
+            # with default case
+            if requests:
+                return requests
         vocab_size = self._model_engine.model.model_config.pretrained_config.vocab_size
         max_num_tokens = self._executor_config.max_num_tokens
         max_beam_width = self._executor_config.max_beam_width

-        requests = []
         input_seq_len = min(max_num_tokens, input_seq_len)
         remaining_tokens = max_num_tokens
         while remaining_tokens > 0:
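
With this change, _create_dummy_context_requests first tries to build a worst-case multimodal dummy request and falls back to the plain text-only path when that is not possible. A minimal sketch of that dispatch order follows; the function and parameter names are illustrative stand-ins, not the executor's actual API:

# Sketch only: the fallback order introduced by this commit, with hypothetical
# helpers standing in for the executor's private methods.
def build_dummy_requests(model, vision_encoder_mapping, input_seq_len,
                         build_mm_requests, build_text_requests):
    # original_arch is set on Qwen2/2.5-VL models in this commit; other models
    # may not define it, hence the getattr fallback used here.
    arch = getattr(model, "original_arch", None)
    if vision_encoder_mapping.get(arch) is not None:
        requests = build_mm_requests(input_seq_len)
        # An empty list means the input processor lacks get_prompt_for_profiling,
        # so profiling falls back to the default text-only requests.
        if requests:
            return requests
    return build_text_requests(input_seq_len)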
