Merged
Commits (24)
8ad3a36
add multimodal dummy request to profiling - drafting stage
johncalesp Sep 4, 2025
7f8e169
guard against multimodal_input being None
johncalesp Sep 4, 2025
d59fb38
rebase and change logic
johncalesp Sep 9, 2025
22b8d37
check attribute original_arch is present
johncalesp Sep 9, 2025
1101862
rebase from main and add fix for only-text request for qwen2.5 and ad…
johncalesp Sep 15, 2025
5449a63
fix max_seq_len is less than max_num_tokens during profiling
johncalesp Sep 16, 2025
f20c8fb
address yechang comments - p1
johncalesp Sep 18, 2025
0a1fd4a
fix rebase to 80dd8fe1973323eb8f01060788c0d5485a0ce0f8
johncalesp Sep 18, 2025
39e78ff
check for new function name and fix TextPrompt attribute for qwen
johncalesp Sep 18, 2025
38bb1d2
address comments - change design to use default_multimodal_input_loader
johncalesp Sep 19, 2025
62679a3
add additional arguments for mm data
johncalesp Sep 19, 2025
8b80b32
address comments: change unit test, and add more asserts
johncalesp Sep 22, 2025
1512044
fix rebase to commit b1738c3f189560a857ea1adcfdfb8e68c571c81d
johncalesp Sep 22, 2025
3658a44
address code rabbit comments && remove mrope_config.mrope_position_id…
johncalesp Sep 24, 2025
b7f0f7e
integrate latest feedback
johncalesp Oct 1, 2025
52afab2
add unit test to test-db
johncalesp Oct 1, 2025
187b80f
fix unit test by adding chunked prefill parameter
johncalesp Oct 3, 2025
0d470ec
fix test L40S-PyTorch-2.test_e2e.test_ptp_quickstart_multimodal[NVILA…
johncalesp Oct 3, 2025
bcc9004
fix tests A10-PyTorch-1.test_e2e.test_openai_chat_multimodal_example …
johncalesp Oct 4, 2025
345954c
include flag to check chunked prefill flag during profiling
johncalesp Oct 5, 2025
a546ae9
change logic to get initial input_seq_len
johncalesp Oct 9, 2025
f011861
fix latest rebase from 7291cdc42287297bf72015e7201fede7985edeae
johncalesp Oct 14, 2025
5410711
Fix rebase to 1cdb0b6
johncalesp Oct 14, 2025
9919b4c
fix format file _util.py
johncalesp Oct 15, 2025
98 changes: 90 additions & 8 deletions tensorrt_llm/_torch/models/modeling_qwen2vl.py
@@ -2,8 +2,10 @@
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.nn import functional as F
from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
PreTrainedModel)
@@ -25,9 +27,11 @@
from tensorrt_llm.inputs.multimodal import MultimodalParams

from ..._utils import nvtx_range, nvtx_range_debug
from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
InputProcessor, MultimodalPlaceholderMetadata,
from ...inputs import (BaseDummyInputsBuilder, BaseMultimodalInputProcessor,
ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
default_multimodal_input_loader,
register_input_processor)
from ...logger import logger
from ...sampling_params import SamplingParams
@@ -83,16 +87,19 @@ def process_weights(weights: Dict,
return filtered_weights


class Qwen2VLInputProcessorBase(BaseMultimodalInputProcessor, InputProcessor):
class Qwen2VLInputProcessorBase(BaseDummyInputsBuilder,
BaseMultimodalInputProcessor, InputProcessor):

def __init__(self,
model_path: str,
model_config: PretrainedConfig,
tokenizer: AutoTokenizer,
trust_remote_code: bool = True):
self.model_config = model_config
self.tokenizer = tokenizer
self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
model_path)
self.use_fast = True
self.model_path = model_path
self.processor = AutoProcessor.from_pretrained(
model_path,
use_fast=self.use_fast,
@@ -277,6 +284,81 @@ def get_rope_index(
mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas

def get_dummy_text(self, input_seq_len: int) -> str:
ids = np.random.randint(
low=0,
high=int(
self.model_config.vocab_size), # high is exclusive in NumPy
size=input_seq_len,
).tolist()
return self.tokenizer.decode(ids, skip_special_tokens=True)

def get_dummy_image(self, max_width: int, max_height: int):
image = Image.new("RGB", (max_width, max_height), color=255)
return image

def get_dummy_prompt(self, input_seq_len: int):
text = ""
        # we use the max resolution as a starting point
img_max_dim = 3584
image = self.get_dummy_image(max_width=img_max_dim,
max_height=img_max_dim)

test_mm_prompt = default_multimodal_input_loader(
tokenizer=self.tokenizer,
model_dir=self.model_path,
model_type=self.model_config.model_type,
modality="image",
prompts=[text],
media=[[image]],
image_data_format="pt")[0]

prompt_token_ids_single_img, _ = self(test_mm_prompt, None)

        # If the max image resolution yields more tokens than input_seq_len, keep
        # halving the resolution until we find the largest resolution whose token
        # count does not exceed input_seq_len.
while len(prompt_token_ids_single_img) > input_seq_len:
# reduce img resolution
img_max_dim = img_max_dim >> 1

image = self.get_dummy_image(max_width=img_max_dim,
max_height=img_max_dim)

test_mm_prompt = default_multimodal_input_loader(
tokenizer=self.tokenizer,
model_dir=self.model_path,
model_type=self.model_config.model_type,
modality="image",
prompts=[text],
media=[[image]],
image_data_format="pt")[0]

prompt_token_ids_single_img, _ = self(test_mm_prompt, None)

len_prompt_tokens_ids = len(prompt_token_ids_single_img)
        # There are corner cases where, if we strictly generate exactly as many text
        # tokens as are needed to complete input_seq_len, the output of
        # default_multimodal_input_loader may contain more tokens than input_seq_len,
        # which can lead to errors.
        # That is why we clip text_token_left to a threshold slightly below, but
        # close to, the actual input_seq_len.
text_generation_perc_threshold = 0.95
text_token_left = int((input_seq_len - len_prompt_tokens_ids) *
text_generation_perc_threshold)

if text_token_left > 0:
text = self.get_dummy_text(text_token_left)

return default_multimodal_input_loader(
tokenizer=self.tokenizer,
model_dir=self.model_path,
model_type=self.model_config.model_type,
modality="image",
prompts=[text],
media=[[image]],
image_data_format="pt")[0]

def _preprocess(self, text: dict[str, any], mm_data: dict[str, any],
mm_processor_kwargs: Dict[str, Any]):
images = mm_data.get("image")
@@ -790,7 +872,7 @@ def __init__(
**kwargs,
) -> None:
model_config.pretrained_config.rope_scaling['type'] = 'mrope'

self.original_arch = model_config.pretrained_config.architectures[0]
        # NOTE: Setting disable_fuse_rope to True does mrope fusion in the model engine by pre-computing rotary_cos_sin there
        disable_fuse_rope = kwargs.get('disable_fuse_rope', False)
        model_config.pretrained_config.text_config.disable_fuse_rope = disable_fuse_rope
@@ -979,7 +1061,7 @@ def multimodal_data_device_paths(self) -> List[str]:
return [
"image.pixel_values", "image.image_grid_thw",
"video.pixel_values_videos", "video.video_grid_thw",
"multimodal_embedding", "mrope_config.mrope_position_ids"
"multimodal_embedding"
]

def load_weights(self, weights, weight_mapper: BaseWeightMapper):
@@ -1032,12 +1114,12 @@ def multimodal_data_device_paths(self) -> List[str]:
return [
"image.pixel_values", "video.pixel_values_videos",
"image.image_grid_thw", "video.video_grid_thw",
"multimodal_embedding", "mrope_config.mrope_position_ids"
"multimodal_embedding"
]
else:
return [
"image.pixel_values", "video.pixel_values_videos",
"multimodal_embedding", "mrope_config.mrope_position_ids"
"multimodal_embedding"
]

def load_weights(self, weights, weight_mapper: BaseWeightMapper):
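For readers following the new get_dummy_prompt path above, here is a minimal, self-contained sketch (not part of the diff) of its sizing strategy: halve the dummy image's resolution until the image-only prompt fits input_seq_len, then fill roughly 95% of the leftover budget with random text. The count_image_tokens callback and the patch-based cost model below are illustrative assumptions; the real implementation measures token counts by round-tripping through default_multimodal_input_loader and the processor.

from typing import Callable, Tuple


def size_dummy_mm_prompt(
        count_image_tokens: Callable[[int], int],  # assumed callback: prompt tokens for a side x side image
        input_seq_len: int,
        start_side: int = 3584,  # maximum resolution used as the starting point
        text_budget_frac: float = 0.95,  # clip the text fill slightly below the remaining budget
) -> Tuple[int, int]:
    """Return (image_side, text_tokens_to_add) so that the dummy prompt fits input_seq_len."""
    side = start_side
    image_tokens = count_image_tokens(side)
    # Halve the image until the image-only prompt no longer exceeds the budget.
    while image_tokens > input_seq_len and side > 1:
        side >>= 1
        image_tokens = count_image_tokens(side)
    # Under-shoot the text fill so tokenizer round-trips cannot push past input_seq_len.
    text_tokens = max(0, int((input_seq_len - image_tokens) * text_budget_frac))
    return side, text_tokens


if __name__ == "__main__":
    # Toy cost model (assumption): one token per 28x28 patch -> (side // 28) ** 2 tokens.
    print(size_dummy_mm_prompt(lambda s: (s // 28) ** 2, input_seq_len=8192))  # (1792, 3891)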
71 changes: 70 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/_util.py
@@ -7,6 +7,8 @@
import tensorrt_llm
import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import \
MODEL_CLASS_VISION_ENCODER_MAPPING
from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
from tensorrt_llm.bindings.executor import DecodingMode
from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig,
@@ -76,6 +78,7 @@ def __init__(
pytorch_backend_config: PyTorchConfig,
speculative_config: SpeculativeConfig,
sparse_attention_config: SparseAttentionConfig,
profiling_stage_data: Optional[dict],
):
self._model_engine = model_engine
self._draft_model_engine = draft_model_engine
@@ -93,6 +96,7 @@
self._max_batch_size = max_batch_size
self._net_max_seq_len = net_max_seq_len
self._dummy_reqs = None
self._profiling_stage_data = profiling_stage_data
self._kv_cache_manager_cls = get_kv_cache_manager_cls(
model_engine.model.model_config)

@@ -133,13 +137,76 @@ def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
f", tmp kv_mem { (allocated_bytes) / (GB):.2f} GiB")
return int(available_kv_mem)

def _create_dummy_mm_context_request(
self, input_seq_len: int) -> List[trtllm.Request]:
requests = []
if isinstance(
self._profiling_stage_data,
dict) and not self._profiling_stage_data.get("enable_mm_reqs"):
return requests

input_processor = self._model_engine.input_processor
        if not hasattr(input_processor, "get_dummy_prompt"):
            logger.warning("The input processor of this model does not implement [get_dummy_prompt]. " \
                "Profiling with the default dummy context requests instead, which may not account for the " \
                "memory consumption of the image encoder.")
            return requests
prompt = input_processor.get_dummy_prompt(input_seq_len)

prompt_token_ids, extra_processed_inputs = self._model_engine.input_processor_with_hash(
prompt, None)

multimodal_input = extra_processed_inputs.get('multimodal_input')
multimodal_data = extra_processed_inputs.get('multimodal_data')

max_num_tokens = len(prompt_token_ids)
        assert max_num_tokens > 0, "the dummy multimodal request prompt must contain at least one token"
        remaining_tokens = min(max_num_tokens, input_seq_len)
        if max_num_tokens > input_seq_len:
            logger.warning(f"Profiling with a multimodal prompt that contains more tokens than the allowed input_seq_len. " \
                f"The multimodal prompt has {max_num_tokens} tokens while input_seq_len is {input_seq_len}.")
while remaining_tokens > 0:
req_mm_input = trtllm.MultimodalInput(
multimodal_hashes=multimodal_input.multimodal_hashes,
multimodal_positions=multimodal_input.multimodal_positions,
multimodal_lengths=multimodal_input.multimodal_lengths
) if multimodal_input else None
request = trtllm.Request(prompt_token_ids,
max_tokens=1,
streaming=False,
sampling_config=trtllm.SamplingConfig(
beam_width=self._max_beam_width, ),
output_config=trtllm.OutputConfig(),
end_id=-1,
multimodal_input=req_mm_input)
# TODO:
# create_input_processor_with_hash shouldn’t be required during profiling,
# but is temporarily needed due to the multimodal input dependency for chunked prefill
request.py_multimodal_data = multimodal_data
remaining_tokens -= max_num_tokens
requests.append(request)

if self._mapping.enable_attention_dp:
requests = requests * self._mapping.tp_size

return requests

def _create_dummy_context_requests(
self, input_seq_len: int) -> List[trtllm.Request]:
requests = []
if hasattr(self._model_engine.model,
"original_arch") and MODEL_CLASS_VISION_ENCODER_MAPPING.get(
self._model_engine.model.original_arch, None):
input_seq_len = min(self._max_num_tokens, input_seq_len)
requests = self._create_dummy_mm_context_request(input_seq_len)
            # If profiling with multimodal requests succeeded, return them; otherwise
            # fall back to the default text-only case.
if requests:
return requests
vocab_size = self._model_engine.model.model_config.pretrained_config.vocab_size
max_num_tokens = self._max_num_tokens
max_beam_width = self._max_beam_width

requests = []
input_seq_len = min(max_num_tokens, input_seq_len)
remaining_tokens = max_num_tokens
while remaining_tokens > 0:
@@ -349,6 +416,8 @@ def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None:
)
# set max_gpu_total_bytes
self._kv_cache_config.max_gpu_total_bytes = kv_cache_max_memory
if isinstance(self._profiling_stage_data, dict):
self._profiling_stage_data["activation_bytes"] = activation_bytes
# ---------------------------handle max_gpu_total_bytes---------------------------------

def _create_kv_cache_manager(
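A condensed, hedged sketch of the dispatch that KvCacheCreator._create_dummy_context_requests now performs during KV-cache profiling: vision-language models first try image-plus-text dummy requests and silently fall back to the pre-existing text-only path. is_vlm, build_mm_requests, and build_text_requests are stand-ins for the real checks and helpers shown in the diff, collapsed into one function for readability.

from typing import Callable, List, Optional


def choose_profiling_requests(
        is_vlm: bool,  # stand-in for the original_arch / MODEL_CLASS_VISION_ENCODER_MAPPING check
        profiling_stage_data: Optional[dict],
        has_get_dummy_prompt: bool,
        input_seq_len: int,
        max_num_tokens: int,
        build_mm_requests: Callable[[int], List[object]],  # stand-in for _create_dummy_mm_context_request
        build_text_requests: Callable[[int], List[object]],  # stand-in for the text-only request builder
) -> List[object]:
    if is_vlm:
        # Clamp so the multimodal prompt never asks for more tokens than one context pass allows.
        seq_len = min(max_num_tokens, input_seq_len)
        mm_enabled = isinstance(profiling_stage_data, dict) and profiling_stage_data.get("enable_mm_reqs")
        if mm_enabled and has_get_dummy_prompt:
            requests = build_mm_requests(seq_len)
            if requests:  # succeeded: profile with an image + text prompt
                return requests
    # Otherwise keep the original behaviour: random-token, text-only dummy requests.
    return build_text_requests(min(max_num_tokens, input_seq_len))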
6 changes: 5 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -18,6 +18,8 @@
torch_dtype_to_str, trace_func)
from tensorrt_llm.inputs.multimodal import (MultimodalParams,
MultimodalRuntimeData)
from tensorrt_llm.inputs.registry import (create_input_processor,
create_input_processor_with_hash)
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraModelConfig
@@ -171,7 +173,9 @@ def __init__(

self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures(
)

self.input_processor = create_input_processor(model_path, None)
self.input_processor_with_hash = create_input_processor_with_hash(
self.input_processor)
if model is None:
loader = ModelLoader(
pytorch_backend_config=pytorch_backend_config,
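The two attributes attached to the engine here (input_processor and input_processor_with_hash) are what the profiling code in _util.py consumes. A brief sketch of that consumption; probe_dummy_mm_inputs is a hypothetical helper name, and the return shapes follow the calls in _create_dummy_mm_context_request above.

def probe_dummy_mm_inputs(model_engine, input_seq_len: int):
    """Hypothetical helper showing how the engine-side processors are used during profiling."""
    processor = model_engine.input_processor
    if not hasattr(processor, "get_dummy_prompt"):
        # No way to synthesize an image + text prompt; profiling stays text-only.
        return None
    prompt = processor.get_dummy_prompt(input_seq_len)
    # The hashed variant also yields multimodal hashes/positions, which chunked prefill relies on.
    prompt_token_ids, extra = model_engine.input_processor_with_hash(prompt, None)
    return (prompt_token_ids,
            extra.get("multimodal_input"),   # -> trtllm.MultimodalInput fields
            extra.get("multimodal_data"))    # attached to the request as py_multimodal_data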
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -207,6 +207,7 @@ def create_py_executor(
tokenizer: Optional[TokenizerBase] = None,
lora_config: Optional[LoraConfig] = None,
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
profiling_stage_data: Optional[dict] = None,
) -> PyExecutor:

garbage_collection_gen0_threshold = llm_args.garbage_collection_gen0_threshold
@@ -570,6 +571,7 @@ def drafting_loop_wrapper(model):
kv_cache_config=kv_cache_config,
pytorch_backend_config=pytorch_backend_config,
speculative_config=spec_config,
profiling_stage_data=profiling_stage_data,
sparse_attention_config=sparse_attention_config,
)
estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
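A hedged usage sketch for callers of create_py_executor, based only on the keys this PR reads and writes: "enable_mm_reqs" is consumed by KvCacheCreator._create_dummy_mm_context_request, and "activation_bytes" is written back by configure_kv_cache_capacity. The surrounding call arguments are elided.

# Opt in to multimodal dummy requests during KV-cache profiling.
profiling_stage_data = {"enable_mm_reqs": True}

# executor = create_py_executor(..., profiling_stage_data=profiling_stage_data)

# After KV-cache capacity estimation, the same dict carries the measured activation footprint:
# activation_bytes = profiling_stage_data.get("activation_bytes")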