
Commit 8a7a0b4

johncalesp authored and dominicshanshan committed
[TRTLLM-6780][fix] Add multimodal data to dummy requests during memory profiling (NVIDIA#7539)
Signed-off-by: John Calderon <johncalesp@gmail.com>
Signed-off-by: John Calderon <jcalderon@nvidia.com>
Signed-off-by: john calderon <jcalderon@nvidia.com>
Signed-off-by: John Calderon <jcalderon@nvidia>
1 parent 6df6518 commit 8a7a0b4

File tree

9 files changed: +284 -47 lines changed
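Taken together, the diffs below teach the memory profiler to exercise a model's vision encoder: when the loaded model is registered in MODEL_CLASS_VISION_ENCODER_MAPPING, the dummy context requests used to size the KV cache now carry an image-bearing prompt instead of random text alone. The following minimal sketch mirrors that dispatch; every name in it is a hypothetical stand-in for the real helpers added in _util.py, not the library API.

# Illustrative sketch only: how profiling picks between multimodal and text-only
# dummy requests. All names here are hypothetical stand-ins, not TensorRT-LLM APIs.
from typing import Callable, List, Optional


def choose_dummy_requests(
        original_arch: Optional[str],
        vision_encoder_archs: set,
        input_seq_len: int,
        max_num_tokens: int,
        build_mm_requests: Callable[[int], List],
        build_text_requests: Callable[[int], List]) -> List:
    seq_len = min(max_num_tokens, input_seq_len)
    if original_arch in vision_encoder_archs:
        # Try an image-bearing dummy request so the vision encoder's activation
        # memory is part of the profiled peak.
        requests = build_mm_requests(seq_len)
        if requests:
            return requests
    # Otherwise (or if the multimodal path declines), fall back to the classic
    # text-only dummy request.
    return build_text_requests(seq_len)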

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 87 additions & 5 deletions
@@ -2,8 +2,10 @@
 import os
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
+from PIL import Image
 from torch.nn import functional as F
 from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
                           PreTrainedModel)
@@ -25,9 +27,11 @@
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
 from ..._utils import nvtx_range, nvtx_range_debug
-from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
-                       InputProcessor, MultimodalPlaceholderMetadata,
+from ...inputs import (BaseDummyInputsBuilder, BaseMultimodalInputProcessor,
+                       ExtraProcessedInputs, InputProcessor,
+                       MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
+                       default_multimodal_input_loader,
                        register_input_processor)
 from ...logger import logger
 from ...sampling_params import SamplingParams
@@ -83,16 +87,19 @@ def process_weights(weights: Dict,
     return filtered_weights
 
 
-class Qwen2VLInputProcessorBase(BaseMultimodalInputProcessor, InputProcessor):
+class Qwen2VLInputProcessorBase(BaseDummyInputsBuilder,
+                                BaseMultimodalInputProcessor, InputProcessor):
 
     def __init__(self,
                  model_path: str,
                  model_config: PretrainedConfig,
                  tokenizer: AutoTokenizer,
                  trust_remote_code: bool = True):
         self.model_config = model_config
-        self.tokenizer = tokenizer
+        self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
+            model_path)
         self.use_fast = True
+        self.model_path = model_path
         self.processor = AutoProcessor.from_pretrained(
             model_path,
             use_fast=self.use_fast,
@@ -277,6 +284,81 @@ def get_rope_index(
             mrope_position_deltas, device=input_ids.device).unsqueeze(1)
         return position_ids, mrope_position_deltas
 
+    def get_dummy_text(self, input_seq_len: int) -> str:
+        ids = np.random.randint(
+            low=0,
+            high=int(
+                self.model_config.vocab_size),  # high is exclusive in NumPy
+            size=input_seq_len,
+        ).tolist()
+        return self.tokenizer.decode(ids, skip_special_tokens=True)
+
+    def get_dummy_image(self, max_width: int, max_height: int):
+        image = Image.new("RGB", (max_width, max_height), color=255)
+        return image
+
+    def get_dummy_prompt(self, input_seq_len: int):
+        text = ""
+        # we use the max resolution as the starting point
+        img_max_dim = 3584
+        image = self.get_dummy_image(max_width=img_max_dim,
+                                     max_height=img_max_dim)
+
+        test_mm_prompt = default_multimodal_input_loader(
+            tokenizer=self.tokenizer,
+            model_dir=self.model_path,
+            model_type=self.model_config.model_type,
+            modality="image",
+            prompts=[text],
+            media=[[image]],
+            image_data_format="pt")[0]
+
+        prompt_token_ids_single_img, _ = self(test_mm_prompt, None)
+
+        # If the max image resolution results in more tokens than
+        # input_seq_len, keep lowering the resolution to find the largest
+        # resolution that does not exceed input_seq_len.
+        while len(prompt_token_ids_single_img) > input_seq_len:
+            # reduce image resolution
+            img_max_dim = img_max_dim >> 1
+
+            image = self.get_dummy_image(max_width=img_max_dim,
+                                         max_height=img_max_dim)
+
+            test_mm_prompt = default_multimodal_input_loader(
+                tokenizer=self.tokenizer,
+                model_dir=self.model_path,
+                model_type=self.model_config.model_type,
+                modality="image",
+                prompts=[text],
+                media=[[image]],
+                image_data_format="pt")[0]
+
+            prompt_token_ids_single_img, _ = self(test_mm_prompt, None)
+
+        len_prompt_tokens_ids = len(prompt_token_ids_single_img)
+        # There are corner cases where, if we strictly generate text for
+        # exactly the number of tokens needed to fill input_seq_len, the
+        # output of default_multimodal_input_loader may contain more tokens
+        # than input_seq_len, which can lead to errors.
+        # That is why we clip text_token_left to a threshold slightly below,
+        # but close to, the actual input_seq_len.
+        text_generation_perc_threshold = 0.95
+        text_token_left = int((input_seq_len - len_prompt_tokens_ids) *
+                              text_generation_perc_threshold)
+
+        if text_token_left > 0:
+            text = self.get_dummy_text(text_token_left)
+
+        return default_multimodal_input_loader(
+            tokenizer=self.tokenizer,
+            model_dir=self.model_path,
+            model_type=self.model_config.model_type,
+            modality="image",
+            prompts=[text],
+            media=[[image]],
+            image_data_format="pt")[0]
+
     def _preprocess(self, text: dict[str, any], mm_data: dict[str, any],
                     mm_processor_kwargs: Dict[str, Any]):
         images = mm_data.get("image")
@@ -790,7 +872,7 @@ def __init__(
         **kwargs,
     ) -> None:
         model_config.pretrained_config.rope_scaling['type'] = 'mrope'
-
+        self.original_arch = model_config.pretrained_config.architectures[0]
         # NOTE: Setting disable_fuse_rope to True to do mrope fusion in the model engine by pre-computing rotary_cos_sin in the model engine
         disabble_fuse_rope = kwargs.get('disable_fuse_rope', False)
         model_config.pretrained_config.text_config.disable_fuse_rope = disabble_fuse_rope
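The heart of get_dummy_prompt above is a halving search over the dummy image resolution: starting from 3584x3584, the side length is halved until the tokenized multimodal prompt fits in input_seq_len, and the leftover budget is then filled with random text, clipped to 95% to avoid overshooting. A standalone sketch of that search, assuming a hypothetical count_tokens(side) callable in place of actually running the image through the processor:

# Minimal sketch of the resolution search in get_dummy_prompt; count_tokens is a
# hypothetical stand-in for tokenizing a (side x side) dummy image prompt.
from typing import Callable, Tuple


def largest_fitting_resolution(count_tokens: Callable[[int], int],
                               input_seq_len: int,
                               start_side: int = 3584) -> Tuple[int, int]:
    side = start_side
    tokens = count_tokens(side)
    while tokens > input_seq_len and side > 1:
        side >>= 1  # halve the resolution, as the loop above does
        tokens = count_tokens(side)
    # Leave ~5% headroom before padding with random text, mirroring
    # text_generation_perc_threshold = 0.95.
    text_budget = max(0, int((input_seq_len - tokens) * 0.95))
    return side, text_budget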

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 70 additions & 1 deletion
@@ -7,6 +7,8 @@
 import tensorrt_llm
 import tensorrt_llm.bindings.executor as trtllm
 from tensorrt_llm._torch.model_config import ModelConfig
+from tensorrt_llm._torch.models.modeling_utils import \
+    MODEL_CLASS_VISION_ENCODER_MAPPING
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode
 from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig,
@@ -76,6 +78,7 @@ def __init__(
         pytorch_backend_config: PyTorchConfig,
         speculative_config: SpeculativeConfig,
         sparse_attention_config: SparseAttentionConfig,
+        profiling_stage_data: Optional[dict],
     ):
         self._model_engine = model_engine
         self._draft_model_engine = draft_model_engine
@@ -93,6 +96,7 @@ def __init__(
         self._max_batch_size = max_batch_size
         self._net_max_seq_len = net_max_seq_len
         self._dummy_reqs = None
+        self._profiling_stage_data = profiling_stage_data
         self._kv_cache_manager_cls = get_kv_cache_manager_cls(
             model_engine.model.model_config)
 
@@ -133,13 +137,76 @@ def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
             f", tmp kv_mem { (allocated_bytes) / (GB):.2f} GiB")
         return int(available_kv_mem)
 
+    def _create_dummy_mm_context_request(
+            self, input_seq_len: int) -> List[trtllm.Request]:
+        requests = []
+        if isinstance(
+                self._profiling_stage_data,
+                dict) and not self._profiling_stage_data.get("enable_mm_reqs"):
+            return requests
+
+        input_processor = self._model_engine.input_processor
+        if not hasattr(input_processor, "get_dummy_prompt"):
+            logger.warning("The input processor of the model does not implement [get_dummy_prompt]. " \
+                "Profiling with the default dummy context request. This may not account for the memory consumption of " \
+                "the image encoder.")
+            return requests
+        prompt = input_processor.get_dummy_prompt(input_seq_len)
+
+        prompt_token_ids, extra_processed_inputs = self._model_engine.input_processor_with_hash(
+            prompt, None)
+
+        multimodal_input = extra_processed_inputs.get('multimodal_input')
+        multimodal_data = extra_processed_inputs.get('multimodal_data')
+
+        max_num_tokens = len(prompt_token_ids)
+        assert max_num_tokens > 0, "the length of the prompt of the dummy mm request is less than or equal to 0"
+        remaining_tokens = min(max_num_tokens, input_seq_len)
+        if remaining_tokens > input_seq_len:
+            logger.warning(f"Profiling with a multimodal prompt that contains more tokens than the allowed input_seq_len. " \
+                f"The multimodal prompt has {remaining_tokens} tokens while input_seq_len is {input_seq_len}.")
+        while remaining_tokens > 0:
+            req_mm_input = trtllm.MultimodalInput(
+                multimodal_hashes=multimodal_input.multimodal_hashes,
+                multimodal_positions=multimodal_input.multimodal_positions,
+                multimodal_lengths=multimodal_input.multimodal_lengths
+            ) if multimodal_input else None
+            request = trtllm.Request(prompt_token_ids,
+                                     max_tokens=1,
+                                     streaming=False,
+                                     sampling_config=trtllm.SamplingConfig(
+                                         beam_width=self._max_beam_width, ),
+                                     output_config=trtllm.OutputConfig(),
+                                     end_id=-1,
+                                     multimodal_input=req_mm_input)
+            # TODO:
+            # create_input_processor_with_hash shouldn't be required during profiling,
+            # but is temporarily needed due to the multimodal input dependency for chunked prefill
+            request.py_multimodal_data = multimodal_data
+            remaining_tokens -= max_num_tokens
+            requests.append(request)
+
+        if self._mapping.enable_attention_dp:
+            requests = requests * self._mapping.tp_size
+
+        return requests
+
     def _create_dummy_context_requests(
             self, input_seq_len: int) -> List[trtllm.Request]:
+        requests = []
+        if hasattr(self._model_engine.model,
+                   "original_arch") and MODEL_CLASS_VISION_ENCODER_MAPPING.get(
+                       self._model_engine.model.original_arch, None):
+            input_seq_len = min(self._max_num_tokens, input_seq_len)
+            requests = self._create_dummy_mm_context_request(input_seq_len)
+            # If profiling with multimodal requests succeeds, return them;
+            # otherwise fall back to the default case.
+            if requests:
+                return requests
         vocab_size = self._model_engine.model.model_config.pretrained_config.vocab_size
         max_num_tokens = self._max_num_tokens
         max_beam_width = self._max_beam_width
 
-        requests = []
         input_seq_len = min(max_num_tokens, input_seq_len)
         remaining_tokens = max_num_tokens
         while remaining_tokens > 0:
@@ -362,6 +429,8 @@ def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None:
         )
         # set max_gpu_total_bytes
         self._kv_cache_config.max_gpu_total_bytes = kv_cache_max_memory
+        if isinstance(self._profiling_stage_data, dict):
+            self._profiling_stage_data["activation_bytes"] = activation_bytes
         # ---------------------------handle max_gpu_total_bytes---------------------------------
 
     def _create_kv_cache_manager(
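profiling_stage_data is a plain dict threaded from create_py_executor (see the next two files) into the KV-cache estimator: the enable_mm_reqs key gates whether multimodal dummy requests are built, and activation_bytes is written back once capacity is configured. A hedged sketch of how a caller might populate and read it; the surrounding executor-creation call is elided:

# Illustrative only: the dict contract used by the estimator above.
profiling_stage_data = {"enable_mm_reqs": True}

# ... pass profiling_stage_data=profiling_stage_data to create_py_executor(...),
# which forwards it to the KV-cache creator shown in this diff ...

# After configure_kv_cache_capacity() has run, the measured activation footprint
# is available in the same dict:
activation_bytes = profiling_stage_data.get("activation_bytes")
if activation_bytes is not None:
    print(f"profiled activation memory: {activation_bytes / (1 << 30):.2f} GiB")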

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 1 deletion
@@ -18,6 +18,8 @@
                               torch_dtype_to_str, trace_func)
 from tensorrt_llm.inputs.multimodal import (MultimodalParams,
                                             MultimodalRuntimeData)
+from tensorrt_llm.inputs.registry import (create_input_processor,
+                                          create_input_processor_with_hash)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraModelConfig
@@ -172,7 +174,9 @@ def __init__(
 
         self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures(
         )
-
+        self.input_processor = create_input_processor(model_path, None)
+        self.input_processor_with_hash = create_input_processor_with_hash(
+            self.input_processor)
         if model is None:
             loader = ModelLoader(
                 pytorch_backend_config=pytorch_backend_config,

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 2 additions & 0 deletions
@@ -207,6 +207,7 @@ def create_py_executor(
     tokenizer: Optional[TokenizerBase] = None,
     lora_config: Optional[LoraConfig] = None,
     kv_connector_config: Optional[KvCacheConnectorConfig] = None,
+    profiling_stage_data: Optional[dict] = None,
 ) -> PyExecutor:
 
     garbage_collection_gen0_threshold = llm_args.garbage_collection_gen0_threshold
@@ -570,6 +571,7 @@ def drafting_loop_wrapper(model):
         kv_cache_config=kv_cache_config,
         pytorch_backend_config=pytorch_backend_config,
         speculative_config=spec_config,
+        profiling_stage_data=profiling_stage_data,
         sparse_attention_config=sparse_attention_config,
     )
     estimating_kv_cache = kv_cache_creator.try_prepare_estimation()

tensorrt_llm/commands/serve.py

Lines changed: 28 additions & 38 deletions
@@ -87,6 +87,7 @@ def get_llm_args(model: str,
                  trust_remote_code: bool = False,
                  reasoning_parser: Optional[str] = None,
                  fail_fast_on_attention_window_too_large: bool = False,
+                 enable_chunked_prefill: bool = False,
                  **llm_args_extra_dict: Any):
 
     if gpus_per_node is None:
@@ -109,44 +110,27 @@
         dynamic_batch_config=dynamic_batch_config,
     )
     llm_args = {
-        "model":
-        model,
-        "scheduler_config":
-        scheduler_config,
-        "tokenizer":
-        tokenizer,
-        "tensor_parallel_size":
-        tensor_parallel_size,
-        "pipeline_parallel_size":
-        pipeline_parallel_size,
-        "moe_expert_parallel_size":
-        moe_expert_parallel_size,
-        "gpus_per_node":
-        gpus_per_node,
-        "trust_remote_code":
-        trust_remote_code,
-        "build_config":
-        build_config,
-        "max_batch_size":
-        max_batch_size,
-        "max_num_tokens":
-        max_num_tokens,
-        "max_beam_width":
-        max_beam_width,
-        "max_seq_len":
-        max_seq_len,
-        "kv_cache_config":
-        kv_cache_config,
-        "backend":
-        backend,
-        "num_postprocess_workers":
-        num_postprocess_workers,
-        "postprocess_tokenizer_dir":
-        tokenizer or model,
-        "reasoning_parser":
-        reasoning_parser,
+        "model": model,
+        "scheduler_config": scheduler_config,
+        "tokenizer": tokenizer,
+        "tensor_parallel_size": tensor_parallel_size,
+        "pipeline_parallel_size": pipeline_parallel_size,
+        "moe_expert_parallel_size": moe_expert_parallel_size,
+        "gpus_per_node": gpus_per_node,
+        "trust_remote_code": trust_remote_code,
+        "build_config": build_config,
+        "max_batch_size": max_batch_size,
+        "max_num_tokens": max_num_tokens,
+        "max_beam_width": max_beam_width,
+        "max_seq_len": max_seq_len,
+        "kv_cache_config": kv_cache_config,
+        "backend": backend,
+        "num_postprocess_workers": num_postprocess_workers,
+        "postprocess_tokenizer_dir": tokenizer or model,
+        "reasoning_parser": reasoning_parser,
         "fail_fast_on_attention_window_too_large":
         fail_fast_on_attention_window_too_large,
+        "enable_chunked_prefill": enable_chunked_prefill,
     }
 
     return llm_args, llm_args_extra_dict
@@ -329,6 +313,10 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
    help=
    "Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache."
 )
+@click.option("--enable_chunked_prefill",
+              is_flag=True,
+              default=False,
+              help="Enable chunked prefill")
 def serve(
     model: str, tokenizer: Optional[str], host: str, port: int,
     log_level: str, backend: str, max_beam_width: int, max_batch_size: int,
@@ -338,7 +326,8 @@ def serve(
     num_postprocess_workers: int, trust_remote_code: bool,
     extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
     metadata_server_config_file: Optional[str], server_role: Optional[str],
-    fail_fast_on_attention_window_too_large: bool):
+    fail_fast_on_attention_window_too_large: bool,
+    enable_chunked_prefill: bool):
     """Running an OpenAI API compatible server
 
     MODEL: model name | HF checkpoint path | TensorRT engine path
@@ -363,7 +352,8 @@
         trust_remote_code=trust_remote_code,
         reasoning_parser=reasoning_parser,
         fail_fast_on_attention_window_too_large=
-        fail_fast_on_attention_window_too_large)
+        fail_fast_on_attention_window_too_large,
+        enable_chunked_prefill=enable_chunked_prefill)
 
     llm_args_extra_dict = {}
     if extra_llm_api_options is not None:
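The new --enable_chunked_prefill option is a standard Click boolean flag: off by default, flipped to True when passed on the command line, and forwarded into get_llm_args. A self-contained sketch of the same pattern (the demo command below is hypothetical, not part of the serve command above):

# Minimal sketch of the flag pattern used by `serve`; `demo` is a hypothetical command.
import click


@click.command()
@click.option("--enable_chunked_prefill",
              is_flag=True,
              default=False,
              help="Enable chunked prefill")
def demo(enable_chunked_prefill: bool):
    # Mirrors how `serve` receives the flag and passes it on to get_llm_args.
    click.echo(f"enable_chunked_prefill={enable_chunked_prefill}")


if __name__ == "__main__":
    demo()  # `python demo.py --enable_chunked_prefill` prints enable_chunked_prefill=True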
