 import tensorrt_llm
 import tensorrt_llm.bindings.executor as trtllm
 from tensorrt_llm._torch.model_config import ModelConfig
+from tensorrt_llm._torch.models.modeling_utils import \
+    MODEL_CLASS_VISION_ENCODER_MAPPING
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
 from tensorrt_llm.inputs.registry import (create_input_processor,
@@ -56,14 +58,8 @@ def __init__(self, *, executor_config: ExecutorConfig,
         self._draft_model_engine = draft_model_engine
         self._mapping = mapping
         self._max_kv_tokens_in = self._executor_config.kv_cache_config.max_tokens
-        self._is_multimodal = getattr(self._model_engine.model, "is_multimodal",
-                                      False)
-        if self._is_multimodal:
-            self._dummy_reqs = self._create_dummy_mm_context_request(
-                net_max_seq_len - 1)
-        else:
-            self._dummy_reqs = self._create_dummy_context_requests(
-                net_max_seq_len - 1)
+        self._dummy_reqs = self._create_dummy_context_requests(net_max_seq_len -
+                                                               1)
         self._kv_connector_manager = kv_connector_manager

     @staticmethod
@@ -142,6 +138,7 @@ def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,

     def _create_dummy_mm_context_request(
             self, input_seq_len: int) -> List[trtllm.Request]:
+        requests = []
         self._model_name_or_path = getattr(self._model_engine.model,
                                            "name_or_path", None)
         self._tokenizer = AutoTokenizer.from_pretrained(
@@ -152,7 +149,7 @@ def _create_dummy_mm_context_request(
             logger.warning("The input processor of the model does not have the method [get_prompt_for_profiling] implemented." \
                 "Profiling with the default input dummy context request. This may not take into account the memory consumption of " \
                 "ViT's encoder")
-            return self._create_dummy_context_requests(input_seq_len)
+            return requests
         text_prompt = input_processor.get_prompt_for_profiling()
         max_beam_width = self._executor_config.max_beam_width
         input_processor_with_hash = create_input_processor_with_hash(
@@ -162,7 +159,6 @@ def _create_dummy_mm_context_request(
         multimodal_input = extra_processed_inputs.get('multimodal_input')
         multimodal_data = extra_processed_inputs.get('multimodal_data')

-        requests = []
         max_num_tokens = len(prompt_token_ids)
         remaining_tokens = max(max_num_tokens, input_seq_len)
         # add +1 to max_num_tokens to avoid assert in line 772 of tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -195,11 +191,18 @@ def _create_dummy_mm_context_request(

     def _create_dummy_context_requests(
             self, input_seq_len: int) -> List[trtllm.Request]:
+        requests = []
+        if MODEL_CLASS_VISION_ENCODER_MAPPING.get(
+                self._model_engine.model.original_arch, None):
+            requests = self._create_dummy_mm_context_request(input_seq_len)
+            # If profiling with multimodal requests succeeded, return them;
+            # otherwise fall back to the default text-only profiling below.
+            if requests:
+                return requests
         vocab_size = self._model_engine.model.model_config.pretrained_config.vocab_size
         max_num_tokens = self._executor_config.max_num_tokens
         max_beam_width = self._executor_config.max_beam_width

-        requests = []
         input_seq_len = min(max_num_tokens, input_seq_len)
         remaining_tokens = max_num_tokens
         while remaining_tokens > 0:
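
For reviewers who want the control flow at a glance, here is a minimal, self-contained sketch of the fallback this change introduces: when the model architecture has a vision encoder registered, multimodal dummy requests are tried first, and profiling falls back to plain text-only dummy requests if they cannot be built. VISION_ENCODER_REGISTRY, DummyRequest, create_dummy_mm_requests and create_dummy_context_requests below are hypothetical stand-ins, not the real TensorRT-LLM types or signatures.

from typing import Dict, List, Optional


# Hypothetical stand-in for MODEL_CLASS_VISION_ENCODER_MAPPING: maps an
# architecture name to its vision-encoder class (a bare `object` here).
VISION_ENCODER_REGISTRY: Dict[str, type] = {
    "LlavaForConditionalGeneration": object,
}


class DummyRequest:
    """Hypothetical placeholder for trtllm.Request, used only for profiling."""

    def __init__(self, token_ids: List[int], multimodal: bool = False):
        self.token_ids = token_ids
        self.multimodal = multimodal


def create_dummy_mm_requests(input_seq_len: int) -> List[DummyRequest]:
    """Build multimodal profiling requests, or return [] when the model's
    input processor offers no profiling prompt (mirrors the early
    `return requests` added in the diff)."""
    profiling_prompt: Optional[List[int]] = None  # pretend the lookup failed
    if profiling_prompt is None:
        return []
    return [DummyRequest(profiling_prompt, multimodal=True)]


def create_dummy_context_requests(arch: str,
                                  input_seq_len: int) -> List[DummyRequest]:
    """Prefer multimodal dummy requests when the architecture has a vision
    encoder registered; otherwise fall back to text-only dummy requests."""
    if VISION_ENCODER_REGISTRY.get(arch) is not None:
        requests = create_dummy_mm_requests(input_seq_len)
        if requests:  # multimodal profiling succeeded
            return requests
    # Default case: a single text-only request of the requested length.
    return [DummyRequest(list(range(input_seq_len)))]


if __name__ == "__main__":
    reqs = create_dummy_context_requests("LlavaForConditionalGeneration", 16)
    print(len(reqs), reqs[0].multimodal)  # -> 1 False (fell back to text-only)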