From 517b67262ca179a6dbe81e2d03d8b96564fec428 Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Tue, 26 Aug 2025 00:37:18 +0800 Subject: [PATCH 01/19] fixes and refactors spec-decode cudagraph Signed-off-by: fhl <2410591650@qq.com> --- vllm/config/__init__.py | 82 +++++++++++++--- vllm/config/compilation.py | 46 ++++++++- vllm/v1/cudagraph_dispatcher.py | 88 ++++++++++++++++- vllm/v1/spec_decode/eagle.py | 149 ++++++++++++++++++++++------- vllm/v1/worker/gpu_model_runner.py | 133 +++++++++++++++---------- 5 files changed, 393 insertions(+), 105 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index cd0e17977ede..e592afda04a1 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -3501,12 +3501,26 @@ def compute_hash(self) -> str: usedforsecurity=False).hexdigest()[:10] return hash_str - def pad_for_cudagraph(self, batch_size: int) -> int: - # if batch_size > self.compilation_config.max_capture_size, + def pad_for_cudagraph(self, + batch_size: int, + uniform_aligned: bool = False) -> int: + """ Get the padded graph size for the batch size. + uniform_aligned: if True, means the padding batch size would be + divisible by the uniform_decode_len for the main model. + For drafter, caller should make sure uniform_aligned is False because + drafter's uniform_decode_len is 1. + """ + + # if batch_size > self.compilation_config.max_capture_size when + # uniform_aligned is False, or batch_size > self.compilation_config. + # max_uniform_capture_size when uniform_aligned is True, # it should raise an IndexError. - # the caller should make sure the batch_size is within the range, - # i.e., batch_size <= self.compilation_config.max_capture_size - return self.compilation_config.bs_to_padded_graph_size[batch_size] + # the caller should make sure the batch_size is within the range + if not uniform_aligned: + return self.compilation_config.bs_to_padded_graph_size[batch_size] + else: + return self.compilation_config.\ + bs_to_padded_graph_size_uniform[batch_size] @staticmethod def _get_quantization_config( @@ -3756,14 +3770,24 @@ def __post_init__(self): # local attention. 
self.scheduler_config.disable_hybrid_kv_cache_manager = True - def update_sizes_for_sequence_parallelism(self, - possible_sizes: list) -> list: + def update_sizes_for_sequence_parallelism( + self, + possible_sizes: list, + uniform_possible_sizes: Optional[list] = None + ) -> tuple[list, Optional[list]]: # remove the sizes that not multiple of tp_size when # enable sequence parallelism removed_sizes = [ size for size in possible_sizes if size % self.parallel_config.tensor_parallel_size != 0 ] + removed_uniform_sizes = [] + if uniform_possible_sizes is not None: + removed_uniform_sizes = [ + size for size in uniform_possible_sizes + if size % self.parallel_config.tensor_parallel_size != 0 + ] + removed_sizes = list(set(removed_sizes + removed_uniform_sizes)) if removed_sizes: logger.warning( "Batch sizes %s are removed because they are not " @@ -3774,7 +3798,10 @@ def update_sizes_for_sequence_parallelism(self, return [ size for size in possible_sizes if size % self.parallel_config.tensor_parallel_size == 0 - ] + ], [ + size for size in uniform_possible_sizes + if size % self.parallel_config.tensor_parallel_size == 0 + ] if uniform_possible_sizes else None def _set_cudagraph_sizes(self): """ @@ -3805,8 +3832,11 @@ def _set_cudagraph_sizes(self): """ # calculate the default `batch_size_capture_list` + batch_size_capture_list = [] + uniform_batch_size_capture_list = [] + uniform_decode_len = 1 if not self.speculative_config else \ + 1 + self.speculative_config.num_speculative_tokens if not envs.VLLM_USE_V1: - batch_size_capture_list = [] if self.scheduler_config is not None and \ self.model_config is not None and \ not self.model_config.enforce_eager: @@ -3814,8 +3844,8 @@ def _set_cudagraph_sizes(self): possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)] if self.parallel_config.tensor_parallel_size > 1 and \ self.compilation_config.pass_config.enable_sequence_parallelism: - possible_sizes = self.update_sizes_for_sequence_parallelism( - possible_sizes) + possible_sizes, _ = \ + self.update_sizes_for_sequence_parallelism(possible_sizes) # find the minimum size that is larger than max_num_seqs, # which then becomes the max_batchsize_to_capture @@ -3835,7 +3865,6 @@ def _set_cudagraph_sizes(self): if size <= max_batchsize_to_capture ] else: - batch_size_capture_list = [] if self.model_config is not None and \ not self.model_config.enforce_eager: cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes @@ -3847,18 +3876,41 @@ def _set_cudagraph_sizes(self): batch_size_capture_list = sorted(cuda_graph_sizes) else: raise TypeError(f"Invalid value for {cuda_graph_sizes=}.") + + # we maintain a separate list of uniform-decode capture sizes, + # since for spec-decode, we may need capture sizes being + # divisible by uniform_decode_len(>1). + + # Derive uniform-decode capture sizes via projection: for each + # non-uniform capture size i, take the max multiple of + # uniform_decode_len that is not greater than i. 
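+            # Example: with num_speculative_tokens=2 (uniform_decode_len=3) and
+            # capture sizes [1, 2, 4, 8, 16], the projected uniform sizes are
+            # {3, 6, 15}: 1 and 2 are dropped, 4 -> 3, 8 -> 6, 16 -> 15.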
+ projected_sizes: set[int] = set() + for size in batch_size_capture_list: + proj = (size // uniform_decode_len) * uniform_decode_len + if proj >= uniform_decode_len: + projected_sizes.add(proj) + uniform_batch_size_capture_list = sorted(projected_sizes) if self.parallel_config.tensor_parallel_size > 1 and \ self.compilation_config.pass_config.enable_sequence_parallelism: - batch_size_capture_list = \ - self.update_sizes_for_sequence_parallelism(batch_size_capture_list) + batch_size_capture_list, uniform_batch_size_capture_list = \ + self.update_sizes_for_sequence_parallelism( + batch_size_capture_list, + uniform_batch_size_capture_list) max_num_tokens = self.scheduler_config.max_num_batched_tokens batch_size_capture_list = [ size for size in batch_size_capture_list if size <= max_num_tokens ] + max_num_decode_tokens = self.scheduler_config.max_num_seqs * \ + uniform_decode_len + uniform_batch_size_capture_list = [ + size for size in uniform_batch_size_capture_list + if size <= max_num_decode_tokens + ] self.compilation_config.init_with_cudagraph_sizes( - batch_size_capture_list) + batch_size_capture_list, uniform_batch_size_capture_list, + uniform_decode_len) def recalculate_max_model_len(self, max_model_len: int): # Can only be called in try_verify_and_update_config diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index e2785e7602e4..1b77f9496643 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -303,6 +303,16 @@ class CompilationConfig: max_capture_size: int = field(default=None, init=False) # type: ignore """not configurable, computed after init""" + uniform_cudagraph_capture_sizes: Optional[list[int]] = None + """ + List for capture sizes for uniform decode for the main model. Its elements + should be multiples of uniform_decode_len(1 for common pure decode, or + 1+num_speculative_tokens for speculative decode). + Not configurable, computed after init + """ + max_uniform_capture_size: int = field(default=None, + init=False) # type: ignore + """not configurable, computed after init""" local_cache_dir: str = field(default=None, init=False) # type: ignore """local cache dir for each rank""" bs_to_padded_graph_size: list[int] = field( @@ -312,6 +322,10 @@ class CompilationConfig: Intuitively, bs_to_padded_graph_size should be dict[int, int]. 
since we know all keys are in a range [0, max_capture_size], we can optimize it to list[int] for better lookup performance.""" + bs_to_padded_graph_size_uniform: list[int] = field( + default=None, # type: ignore + init=False) + """same as bs_to_padded_graph_size, but for uniform capture sizes""" # keep track of enabled and disabled custom ops enabled_custom_ops: Counter[str] = field(default_factory=Counter, @@ -370,6 +384,7 @@ def __repr__(self) -> str: "disabled_custom_ops": True, "compilation_time": True, "bs_to_padded_graph_size": True, + "bs_to_padded_graph_size_uniform": True, "traced_files": True, "inductor_compile_config": { "post_grad_custom_post_pass": True, @@ -482,8 +497,9 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: from vllm.compilation.backends import VllmBackend return VllmBackend(vllm_config) - def init_with_cudagraph_sizes(self, - cudagraph_capture_sizes: list[int]) -> None: + def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int], + uniform_cudagraph_capture_sizes: list[int], + uniform_decode_len: int) -> None: """To complete the initialization of config, we need to know the cudagraph sizes.""" @@ -497,6 +513,12 @@ def init_with_cudagraph_sizes(self, " %s is overridden by config %s"), cudagraph_capture_sizes, dedup_sizes) self.cudagraph_capture_sizes = dedup_sizes + if envs.VLLM_USE_V1: + # recompute uniform_cudagraph_capture_sizes based on the + # dedup_sizes(computed from config) and uniform_decode_len + uniform_cudagraph_capture_sizes = sorted( + set((size // uniform_decode_len) * uniform_decode_len + for size in dedup_sizes if size >= uniform_decode_len)) computed_compile_sizes = [] if self.compile_sizes is not None: @@ -518,6 +540,11 @@ def init_with_cudagraph_sizes(self, self.max_capture_size = self.cudagraph_capture_sizes[ 0] if self.cudagraph_capture_sizes else 0 + self.uniform_cudagraph_capture_sizes = sorted( + uniform_cudagraph_capture_sizes, reverse=True) + self.max_uniform_capture_size = self.uniform_cudagraph_capture_sizes[ + 0] if self.uniform_cudagraph_capture_sizes else 0 + # pre-compute the mapping from batch size to padded graph size self.bs_to_padded_graph_size = [ 0 for i in range(self.max_capture_size + 1) @@ -532,6 +559,21 @@ def init_with_cudagraph_sizes(self, self.bs_to_padded_graph_size[ self.max_capture_size] = self.max_capture_size + # pre-compute the mapping for uniform decode padding. 
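+        # Example: with uniform capture sizes [3, 6, 15], the resulting table
+        # maps bs 0 -> 0, 1..3 -> 3, 4..6 -> 6, and 7..15 -> 15.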
+ self.bs_to_padded_graph_size_uniform = [ + 0 for i in range(self.max_uniform_capture_size + 1) + ] + + for end, start in zip(self.uniform_cudagraph_capture_sizes, + self.uniform_cudagraph_capture_sizes[1:] + [0]): + for bs in range(start, end): + if bs == start: + self.bs_to_padded_graph_size_uniform[bs] = start + else: + self.bs_to_padded_graph_size_uniform[bs] = end + self.bs_to_padded_graph_size_uniform[self.max_uniform_capture_size] =\ + self.max_uniform_capture_size + def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called only when level is # CompilationLevel.PIECEWISE diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 02e65820b7c0..048c32d71af3 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -5,6 +5,7 @@ from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor from vllm.logger import init_logger +from vllm.utils import round_up logger = init_logger(__name__) @@ -27,16 +28,22 @@ class CudagraphDispatcher: without cudagraph (if mode no match or mode is NONE). """ - def __init__(self, vllm_config: VllmConfig): + def __init__(self, vllm_config: VllmConfig, is_drafter: bool = False): self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config self.cudagraph_mode = self.compilation_config.cudagraph_mode + self.is_drafter = is_drafter # Dict to store valid cudagraph dispatching keys. self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = { CUDAGraphMode.PIECEWISE: set(), CUDAGraphMode.FULL: set(), } + # Placeholder for capture sizes. Should be initialized in + # self.initialize_cudagraph_keys. + self.cudagraph_capture_sizes: list[int] = [] + self.uniform_cudagraph_capture_sizes: list[int] = [] + self.uniform_decode_query_len: int = 0 assert not self.cudagraph_mode.requires_piecewise_compilation() or \ (self.compilation_config.level == CompilationLevel.PIECEWISE and @@ -71,6 +78,8 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, self.add_cudagraph_key( cudagraph_mode.mixed_mode(), BatchDescriptor(num_tokens=bs, uniform_decode=False)) + self.cudagraph_capture_sizes = \ + self.compilation_config.cudagraph_capture_sizes # if decode cudagraph mode is FULL, and we don't already have mixed # mode full cudagraphs then add them here. @@ -78,16 +87,91 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, and cudagraph_mode.separate_routine(): max_num_tokens = uniform_decode_query_len * \ self.vllm_config.scheduler_config.max_num_seqs + # for uniform_decode_query_len==1, we use the non-uniform + # capture sizes, this can be for main model without spec-decode or + # for the drafter. Otherwise, we use the uniform-projected sizes. 
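+            # (e.g. with num_speculative_tokens=2, the main model has
+            # uniform_decode_query_len == 3 and uses the projected sizes.)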
+ candidate_sizes = self.compilation_config.cudagraph_capture_sizes\ + if uniform_decode_query_len == 1 else \ + self.compilation_config.cudagraph_capture_sizes_uniform cudagraph_capture_sizes_for_decode = [ - x for x in self.compilation_config.cudagraph_capture_sizes + x for x in candidate_sizes if x <= max_num_tokens and x >= uniform_decode_query_len ] for bs in cudagraph_capture_sizes_for_decode: self.add_cudagraph_key( CUDAGraphMode.FULL, BatchDescriptor(num_tokens=bs, uniform_decode=True)) + self.uniform_cudagraph_capture_sizes = \ + cudagraph_capture_sizes_for_decode + + self.uniform_decode_query_len = uniform_decode_query_len self.keys_initialized = True + def get_capture_cases(self, uniform_decode: bool) -> list[int]: + """Return capture sizes for a given whether it is uniform-decode.""" + if not uniform_decode: + return list(self.cudagraph_capture_sizes) + else: + return list(self.uniform_cudagraph_capture_sizes) + + def padded_num_tokens(self, num_tokens: int, + uniform_decode: bool) -> tuple[int, bool]: + """Return num_tokens after padded and whether it is cudagraph padded. + """ + if self.uniform_decode_query_len == 1 and num_tokens <= \ + self.compilation_config.max_capture_size: + # common situation within the range of max_capture_size for main + # model without spec-decode or it is for a drafter. + # we ignore whether it is uniform-decode since it is always safe + # to pad. + return self.vllm_config.pad_for_cudagraph( + num_tokens, uniform_aligned=False), True + + if self.uniform_decode_query_len > 1 and uniform_decode and \ + num_tokens <= self.compilation_config.max_capture_size_uniform: + # this is particular for uniform-decode alignment for vaildation + # phase of spec-decode. + return self.vllm_config.pad_for_cudagraph( + num_tokens, uniform_aligned=True), True + + return num_tokens, False + + def plan( + self, + num_scheduled_tokens: int, + num_reqs: int, + max_query_len: int, + ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor], int]: + """Plan cudagraph execution in a single call. + + Returns (runtime_mode, batch_descriptor, num_input_tokens_padded). + """ + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + num_scheduled_tokens == num_reqs * max_query_len) + + # Compute padded tokens + cudagraph_padded = False + if self.cudagraph_mode != CUDAGraphMode.NONE: + num_input_tokens, cudagraph_padded = self.padded_num_tokens( + num_scheduled_tokens, uniform_decode) + else: + num_input_tokens = num_scheduled_tokens + + if not cudagraph_padded and not self.is_drafter: + # Eager mode + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if self.compilation_config.pass_config. 
\ + enable_sequence_parallelism and tp_size > 1: + num_input_tokens = round_up(num_scheduled_tokens, tp_size) + + # Build initial descriptor and dispatch + descriptor = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=uniform_decode) + runtime_mode, descriptor = self.dispatch(descriptor) + return runtime_mode, descriptor, num_input_tokens + def dispatch( self, batch_descriptor: BatchDescriptor ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]: diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 0a0e9fed725c..4dbabf359367 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -3,17 +3,16 @@ import ast from dataclasses import replace from importlib.util import find_spec -from typing import Optional, Protocol +from typing import Any, Optional, Protocol import numpy as np import torch import torch.nn as nn from vllm.attention.layer import Attention -from vllm.config import (CompilationLevel, VllmConfig, - get_layers_from_vllm_config) +from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import set_forward_context +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal @@ -25,6 +24,7 @@ TreeAttentionMetadataBuilder) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata @@ -74,12 +74,9 @@ def __init__( self.is_multimodal_model = vllm_config.model_config \ .is_multimodal_model - self.use_cuda_graph = (self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE and - not self.vllm_config.model_config.enforce_eager) - self.cudagraph_batch_sizes = list( - reversed( - self.vllm_config.compilation_config.cudagraph_capture_sizes)) + self.use_cuda_graph = ( + self.vllm_config.compilation_config.cudagraph_mode + != CUDAGraphMode.NONE) # persistent buffers for cuda graph self.input_ids = torch.zeros(self.max_num_tokens, @@ -145,6 +142,11 @@ def __init__( dtype=torch.int32, ).repeat(max_batch_size, 1) + # Cudagraph dispatcher for runtime cudagraph dispatching of drafter, + # which is independent of the dispatcher of the model runner. 
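+        # The drafter's uniform decode query length is 1, so it always pads
+        # against the base (non-uniform) capture sizes.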
+ self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config, + is_drafter=True) + def propose( self, # [num_tokens] @@ -162,6 +164,7 @@ def propose( num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + max_query_len = common_attn_metadata.max_query_len if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) @@ -188,11 +191,12 @@ def propose( per_layer_attn_metadata = {} for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) - else: - num_input_tokens = num_tokens + # dispatcher planning for drafter + cudagraph_runtime_mode, batch_descriptor, num_input_tokens = \ + self.cudagraph_dispatcher.plan( + num_scheduled_tokens=num_tokens, + num_reqs=batch_size, + max_query_len=max_query_len) # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states @@ -211,7 +215,9 @@ def propose( with set_forward_context(per_layer_attn_metadata, self.vllm_config, - num_tokens=num_input_tokens): + num_tokens=num_input_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor): ret_hidden_states = self.model( input_ids=input_ids, positions=self.positions[:num_input_tokens], @@ -254,11 +260,13 @@ def propose( # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] - if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: - input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) - else: - input_batch_size = batch_size + # dispatcher plans only once for the remaining loop + cudagraph_runtime_mode, batch_descriptor, input_batch_size = \ + self.cudagraph_dispatcher.plan( + num_scheduled_tokens=batch_size, + num_reqs=batch_size, + max_query_len=1) + attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] @@ -318,9 +326,12 @@ def propose( input_ids = self.input_ids[:input_batch_size] # Run the model. - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=input_batch_size): + with set_forward_context( + per_layer_attn_metadata, + self.vllm_config, + num_tokens=input_batch_size, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor): last_hidden_states, hidden_states = self.model( input_ids=input_ids, positions=self.positions[:input_batch_size], @@ -460,16 +471,27 @@ def propose_tree( self.hidden_states[:num_tokens] = tree_hidden_states.view( num_tokens, -1) - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_tokens) - else: - num_input_tokens = num_tokens + # As decode phase of TreeAttentionBackend does not have a uniform + # decode query length (1 for the root level and total_num_drafts + # for subsequent levels), so here we expect runtime mode should be + # not FULL, until we find ways to support full cudagraph of this + # case. 
+ cudagraph_runtime_mode, batch_descriptor, num_input_tokens = \ + self.cudagraph_dispatcher.plan( + num_scheduled_tokens=num_tokens, + num_reqs=batch_size, + max_query_len=attn_metadata.max_query_len) + assert cudagraph_runtime_mode != CUDAGraphMode.FULL, \ + "TreeAttentionBackend does not support full cudagraphs at " \ + "this moment" + # Run the model. - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + with set_forward_context( + per_layer_attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], positions=self.positions[:num_input_tokens], @@ -653,9 +675,66 @@ def load_model(self, target_model: nn.Module) -> None: def dummy_run( self, num_tokens: int, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + force_attention: bool = False, + uniform_decode: bool = False, ) -> None: - with set_forward_context(None, self.vllm_config, - num_tokens=num_tokens): + assert cudagraph_runtime_mode != CUDAGraphMode.FULL, \ + "Eagle drafter doesn't support full cudagraphs at this moment" + + max_query_len = 1 if uniform_decode else num_tokens + max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs + num_reqs = min(num_tokens, max_num_reqs) + + per_layer_attn_metadata: Optional[dict[str, Any]] = None + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: + per_layer_attn_metadata = {} + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=self.runner.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.runner.query_start_loc_cpu[:num_reqs + + 1], + seq_lens=self.runner.seq_lens[:num_reqs], + seq_lens_cpu=self.runner.seq_lens_cpu[:num_reqs], + num_computed_tokens_cpu=self.runner.input_batch. + num_computed_tokens_cpu_tensor[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + max_seq_len=self.max_model_len, + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor()[:num_reqs], + slot_mapping=self.runner.input_batch.block_table[0]. + slot_mapping[:num_tokens], + causal=True) + # FIXME: need to consider multiple kv_cache_groups + attn_metadata = self.runner.attn_groups[0][0].metadata_builder\ + .build_for_drafting(common_attn_metadata=common_attn_metadata, + draft_index=0) + # At this moment, we assume all eagle layers belong to the same KV + # cache group, thus using the same attention metadata. + + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + + if cudagraph_runtime_mode == CUDAGraphMode.NONE: + batch_descriptor = None + else: + # filter out the valid batch descriptor + _cg_mode, batch_descriptor = \ + self.cudagraph_dispatcher.dispatch( + BatchDescriptor(num_tokens=num_tokens, + uniform_decode=uniform_decode)) + # sanity check + assert cudagraph_runtime_mode == _cg_mode, ( + f"Cudagraph runtime mode mismatch at dummy_run. 
" + f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") + + with set_forward_context(per_layer_attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor): if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d634cf280f7f..57cafbf7c8f7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -9,6 +9,7 @@ from collections.abc import Iterator from contextlib import contextmanager from copy import deepcopy +from functools import partial from typing import TYPE_CHECKING, Any, Optional, Union, cast import numpy as np @@ -53,7 +54,7 @@ from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, LazyLoader, cdiv, check_use_alibi, - get_dtype_size, is_pin_memory_available, round_up, + get_dtype_size, is_pin_memory_available, supports_dynamo) from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend from vllm.v1.attention.backends.utils import ( @@ -1515,22 +1516,12 @@ def execute_model( max_query_len) = self._prepare_inputs(scheduler_output) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): - # Use CUDA graphs. - # Add padding to the batch size. - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_scheduled_tokens) - else: - # Eager mode. - # Pad tokens to multiple of tensor_parallel_size when - # enabled collective fusion for SP - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if self.compilation_config.pass_config. \ - enable_sequence_parallelism and tp_size > 1: - num_input_tokens = round_up(num_scheduled_tokens, tp_size) - else: - num_input_tokens = num_scheduled_tokens + # dispatcher planning + cudagraph_runtime_mode, batch_descriptor, num_input_tokens = \ + self.cudagraph_dispatcher.plan( + num_scheduled_tokens=num_scheduled_tokens, + num_reqs=self.input_batch.num_reqs, + max_query_len=max_query_len) # Padding for DP num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) @@ -1583,13 +1574,6 @@ def execute_model( intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_input_tokens, intermediate_tensors, True) - uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( - num_scheduled_tokens == self.input_batch.num_reqs * max_query_len) - batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, - uniform_decode=uniform_decode) - cudagraph_runtime_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch(batch_descriptor) - # Run the model. # Use persistent buffers for CUDA graphs. with set_forward_context( @@ -2407,7 +2391,11 @@ def _dummy_run( else: hidden_states = outputs - if self.speculative_config and self.speculative_config.use_eagle(): + # Only trigger drafter's dummy run for profile run. Otherwise, the + # dummy run logic of drafter for warmup run and capturing cudagraph + # is separated from the main model's dummy run. 
+ if is_profile and self.speculative_config and \ + self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) self.drafter.dummy_run(num_tokens) @@ -2665,31 +2653,65 @@ def freeze_gc(): set_cudagraph_capturing_enabled(True) with freeze_gc(), graph_capture(device=self.device): cudagraph_mode = self.compilation_config.cudagraph_mode + logger.info("Start capturing cudagraphs for main model...") if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: cudagraph_runtime_mode = cudagraph_mode.mixed_mode() - - compilation_cases = list(reversed(self.cudagraph_batch_sizes)) - self._capture_cudagraphs( + mixed_cases = self.cudagraph_dispatcher.get_capture_cases( + uniform_decode=False) + compilation_cases = list(reversed(mixed_cases)) + self._capture_cudagraphs_with_callable( compilation_cases, cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=False) + uniform_decode=False, + dummy_run_callable=partial(self._dummy_run, + skip_eplb=True)) # Capture full cudagraph for uniform decode batches if we have # dont already have full mixed prefill-decode cudagraphs if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ cudagraph_mode.separate_routine(): - max_num_tokens = self.scheduler_config.max_num_seqs * \ - self.uniform_decode_query_len - decode_cudagraph_batch_sizes = [ - x for x in self.cudagraph_batch_sizes if - x <= max_num_tokens and x >= self.uniform_decode_query_len - ] - compilation_cases_decode = list( - reversed(decode_cudagraph_batch_sizes)) - self._capture_cudagraphs( + uniform_cases = self.cudagraph_dispatcher.get_capture_cases( + uniform_decode=True) + compilation_cases_decode = list(reversed(uniform_cases)) + self._capture_cudagraphs_with_callable( compilation_cases=compilation_cases_decode, cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True) + uniform_decode=True, + dummy_run_callable=partial(self._dummy_run, + skip_eplb=True)) + + # Capture drafter cudagraphs. + # Note: Currently only PIECEWISE mode is supported for eagle + # drafter. + # TODO: add full cudagraph support for drafter. + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + logger.info("Start capturing cudagraphs for drafter...") + if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: + mixed_mode_drafter = cudagraph_mode.mixed_mode() + drafter_cases = self.drafter.cudagraph_dispatcher.\ + get_capture_cases(uniform_decode=False) + drafter_cases = list(reversed(drafter_cases)) + self._capture_cudagraphs_with_callable( + compilation_cases=drafter_cases, + cudagraph_runtime_mode=mixed_mode_drafter, + uniform_decode=False, + dummy_run_callable=self.drafter.dummy_run, + ) + # the following code would not be triggered at present since + # only PIECEWISE mode is supported. But it is kept and prepared + # for full cudagraphs. + if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ + cudagraph_mode.separate_routine(): + uniform_cases = self.drafter.cudagraph_dispatcher.\ + get_capture_cases(uniform_decode=True) + drafter_uniform_cases = list(reversed(uniform_cases)) + self._capture_cudagraphs_with_callable( + compilation_cases=drafter_uniform_cases, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + uniform_decode=True, + dummy_run_callable=self.drafter.dummy_run, + ) # Disable cudagraph capturing globally, so any unexpected cudagraph # capturing will be detected and raise an error after here. 
@@ -2706,9 +2728,10 @@ def freeze_gc(): logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) - def _capture_cudagraphs(self, compilation_cases: list[int], - cudagraph_runtime_mode: CUDAGraphMode, - uniform_decode: bool): + def _capture_cudagraphs_with_callable( + self, compilation_cases: list[int], + cudagraph_runtime_mode: CUDAGraphMode, uniform_decode: bool, + dummy_run_callable: Any): assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ cudagraph_runtime_mode in [CUDAGraphMode.FULL, CUDAGraphMode.PIECEWISE] @@ -2731,15 +2754,13 @@ def _capture_cudagraphs(self, compilation_cases: list[int], # attention while `PIECEWISE` implies no attention. force_attention = ( cudagraph_runtime_mode == CUDAGraphMode.FULL) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, - uniform_decode=uniform_decode, - skip_eplb=True) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=uniform_decode, - skip_eplb=True) + dummy_run_callable(num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=uniform_decode) + dummy_run_callable(num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=uniform_decode) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -2878,6 +2899,16 @@ def initialize_cudagraph_capture(self) -> None: self.compilation_config.cudagraph_mode, self.uniform_decode_query_len) + # At this moment, we assume the drafter and main model shares the + # same cudagraph_mode + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + assert not cudagraph_mode.has_full_cudagraphs(), \ + "Eagle drafter does not support full cudagraphs yet" + # decode_query_len is 1 for drafter + self.drafter.cudagraph_dispatcher.initialize_cudagraph_keys( + self.compilation_config.cudagraph_mode, 1) + def calculate_reorder_batch_threshold(self) -> None: """ Check that if any backends reorder batches; that the reordering From 40e1ccb35d03798c5faa6e3dc90e1763ea9d1a65 Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Tue, 26 Aug 2025 12:30:18 +0800 Subject: [PATCH 02/19] remove build_for_cudagraph_capture Signed-off-by: fhl <2410591650@qq.com> --- vllm/v1/attention/backends/flashinfer.py | 16 ---------------- vllm/v1/attention/backends/mamba_attn.py | 19 +------------------ vllm/v1/attention/backends/mla/common.py | 15 --------------- vllm/v1/attention/backends/rocm_aiter_fa.py | 9 --------- vllm/v1/attention/backends/triton_attn.py | 10 ---------- vllm/v1/attention/backends/utils.py | 10 ---------- vllm/v1/worker/gpu_model_runner.py | 2 +- 7 files changed, 2 insertions(+), 79 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 50819bb2bb94..aa806ccacd85 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -577,22 +577,6 @@ def build(self, return attn_metadata - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata): - """ - This method builds the metadata for full cudagraph capture. - Currently, only decode is supported for full cudagraphs with FlashInfer. - """ - m = common_attn_metadata - - assert m.num_reqs == m.num_actual_tokens, \ - "FlashInfer only supports decode-only full CUDAGraph capture. 
" \ - "Make sure all cudagraph capture sizes <= max_num_seq." - - m.max_query_len = 1 # decode-only - - return self.build(0, m) - def use_cascade_attention(self, *args, **kwargs) -> bool: if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype: # TODO: The cascade wrapper currently does not support setting diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 07ef7cb69a16..50940252b4a7 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -8,8 +8,7 @@ from vllm.config import VllmConfig from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) + AttentionMetadataBuilder) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec M = TypeVar("M") @@ -37,19 +36,3 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], dtype=torch.int32, device=device, ) - - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: - """ - This method builds the metadata for full cudagraph capture. - Currently, only decode is supported for full cudagraphs with Mamba. - """ - m = common_attn_metadata - - assert m.num_reqs == m.num_actual_tokens, \ - "Mamba only supports decode-only full CUDAGraph capture. " \ - "Make sure all cudagraph capture sizes <= max_num_seq." - - m.max_query_len = 1 # decode-only - - return self.build(0, m) \ No newline at end of file diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index ce45b34f6435..b3b6ac3f22cd 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -561,21 +561,6 @@ def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens=seq_lens, ) - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: - """ - This method builds the metadata for full cudagraph capture. - Currently, only decode is supported for full cudagraphs with MLA. - """ - m = common_attn_metadata - assert m.num_reqs == m.num_actual_tokens, \ - "MLA only supports decode-only full CUDAGraph capture. " \ - "Make sure all cudagraph capture sizes <= max_num_seq." 
- - assert m.max_query_len == 1 # decode-only - - return self.build(0, m) - def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 403ad8e88a95..8f2b9d51d452 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -254,15 +254,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.aot_sliding_window: Optional[tuple[int, int]] = None self.total_tokens: int = 0 - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata): - self.total_tokens = self.model_config.max_model_len \ - * self.vllm_config.scheduler_config.max_num_partial_prefills - res = self.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata) - self.total_tokens = 0 - return res - def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index b12036c59979..44e7cabc911b 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -73,16 +73,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config.parallel_config) self.headdim = model_config.get_head_size() - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata - ) -> TritonAttentionMetadata: - attn_metadata = self.build(0, common_attn_metadata) - # When doing full graph capture, setting seq_lens to - # max_model_len will cause graph capture to be extremely - # slow, so here we set it to 1. - attn_metadata.seq_lens.fill_(1) - return attn_metadata - def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 39bdbe125635..00134bab27fb 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -205,16 +205,6 @@ def build(self, """ raise NotImplementedError - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: - """ - Build attention metadata for CUDA graph capture. Uses build by default. - Subclasses that override this method should call self.build or - super().build_for_cudagraph_capture. 
- """ - return self.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata) - def build_for_drafting( self, common_attn_metadata: CommonAttentionMetadata, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 57cafbf7c8f7..84943e12fb94 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2323,7 +2323,7 @@ def _dummy_run( for attn_group in self.attn_groups[kv_cache_group_id]: attn_metadata_i = attn_group.metadata_builder\ - .build_for_cudagraph_capture(common_attn_metadata) + .build(0, common_attn_metadata) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i From 35507171bb8d13755fe29b6af7bf2a6ba256246b Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Tue, 26 Aug 2025 20:37:30 +0800 Subject: [PATCH 03/19] support capturing mutiple uniform_query_len Signed-off-by: fhl <2410591650@qq.com> --- vllm/forward_context.py | 9 +- vllm/v1/cudagraph_dispatcher.py | 128 ++++++++++++++++++++--------- vllm/v1/spec_decode/eagle.py | 41 +++++---- vllm/v1/worker/gpu_model_runner.py | 114 ++++++++++++++----------- 4 files changed, 190 insertions(+), 102 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index c57c51d289ac..e39a770d0b92 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -38,13 +38,20 @@ class BatchDescriptor(NamedTuple): False can also be used for an uniform decode batch to dispatch to the cudagraph supporting non-uniform batches. """ + uniform_query_len: int = 0 + """ + For non-uniform batches, should set to 0 for uniquely identifying the batch. + For uniform batches, it is the max_query_len of a uniform batch. + """ @property def non_uniform(self) -> "BatchDescriptor": """ Return a non-uniform version of current batch descriptor. """ - return BatchDescriptor(self.num_tokens, uniform_decode=False) + return BatchDescriptor(self.num_tokens, + uniform_decode=False, + uniform_query_len=0) def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int], diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 048c32d71af3..2d88a08843cc 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Optional, Union from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor @@ -42,8 +42,9 @@ def __init__(self, vllm_config: VllmConfig, is_drafter: bool = False): # Placeholder for capture sizes. Should be initialized in # self.initialize_cudagraph_keys. self.cudagraph_capture_sizes: list[int] = [] - self.uniform_cudagraph_capture_sizes: list[int] = [] - self.uniform_decode_query_len: int = 0 + # map uniform_query_len to capture sizes + self.uniform_cudagraph_capture_sizes: dict[int, list[int]] = {} + self.uniform_query_lens: list[int] = [] assert not self.cudagraph_mode.requires_piecewise_compilation() or \ (self.compilation_config.level == CompilationLevel.PIECEWISE and @@ -63,7 +64,8 @@ def add_cudagraph_key(self, runtime_mode: CUDAGraphMode, self.cudagraph_keys[runtime_mode].add(batch_descriptor) def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, - uniform_decode_query_len: int): + uniform_query_lens: Union[int, list[int]]): + # This should be called only after attention backend is initialized. 
# Note: we create all valid keys possible for cudagraph but do not @@ -73,6 +75,26 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, # trigger capturing/replaying the piecewise cudagraphs depending on # CompilationConfig.cudagraph_mode. In addition, if we allow lazy # capturing in future PR, some keys may never be triggered. + + # support multiple uniform_decode_query_lens for spec-decode + if isinstance(uniform_query_lens, int): + uniform_query_lens = [uniform_query_lens] + assert len(uniform_query_lens) > 0 and all( + isinstance(x, int) and x > 0 for x in uniform_query_lens), \ + f"Invalid uniform_query_lens: {uniform_query_lens}" + self.uniform_query_lens = uniform_query_lens + + # we only have compilation_config.cudagraph_capture_sizes_uniform + # being aligned with one uniform_query_len that greater than 1, not + # multiple of them. Should verify this here. + for uniform_query_len in self.uniform_query_lens: + if uniform_query_len > 1 and \ + self.compilation_config.cudagraph_capture_sizes_uniform: + assert all(x % uniform_query_len == 0 for x in + self.compilation_config.\ + cudagraph_capture_sizes_uniform), \ + f"Invalid uniform_query_lens: {uniform_query_len}" + if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: for bs in self.compilation_config.cudagraph_capture_sizes: self.add_cudagraph_key( @@ -83,57 +105,85 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, # if decode cudagraph mode is FULL, and we don't already have mixed # mode full cudagraphs then add them here. - if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL \ - and cudagraph_mode.separate_routine(): - max_num_tokens = uniform_decode_query_len * \ - self.vllm_config.scheduler_config.max_num_seqs - # for uniform_decode_query_len==1, we use the non-uniform - # capture sizes, this can be for main model without spec-decode or - # for the drafter. Otherwise, we use the uniform-projected sizes. - candidate_sizes = self.compilation_config.cudagraph_capture_sizes\ - if uniform_decode_query_len == 1 else \ - self.compilation_config.cudagraph_capture_sizes_uniform - cudagraph_capture_sizes_for_decode = [ - x for x in candidate_sizes - if x <= max_num_tokens and x >= uniform_decode_query_len - ] - for bs in cudagraph_capture_sizes_for_decode: - self.add_cudagraph_key( - CUDAGraphMode.FULL, - BatchDescriptor(num_tokens=bs, uniform_decode=True)) - self.uniform_cudagraph_capture_sizes = \ - cudagraph_capture_sizes_for_decode + for uniform_query_len in self.uniform_query_lens: + if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL \ + and cudagraph_mode.separate_routine(): + max_num_tokens = uniform_query_len * \ + self.vllm_config.scheduler_config.max_num_seqs + # for uniform_query_len==1, we use the non-uniform + # capture sizes, this can be for main model without spec-decode + # or for the drafter. Otherwise, we use the uniform-aligned + # sizes. 
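+                # e.g. uniform_query_lens is [1] for the drafter and, with
+                # spec decode, [1 + num_speculative_tokens] for the main model.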
+ candidate_sizes = self.compilation_config.\ + cudagraph_capture_sizes \ + if uniform_query_len == 1 else \ + self.compilation_config.cudagraph_capture_sizes_uniform + cudagraph_capture_sizes_for_decode = [ + x for x in candidate_sizes + if x <= max_num_tokens and x >= uniform_query_len + ] + for bs in cudagraph_capture_sizes_for_decode: + self.add_cudagraph_key( + CUDAGraphMode.FULL, + BatchDescriptor(num_tokens=bs, + uniform_decode=True, + uniform_query_len=uniform_query_len)) + self.uniform_cudagraph_capture_sizes[uniform_query_len] = \ + cudagraph_capture_sizes_for_decode - self.uniform_decode_query_len = uniform_decode_query_len + # update the cudagraph mode resolved from runner + self.cudagraph_mode = cudagraph_mode self.keys_initialized = True - def get_capture_cases(self, uniform_decode: bool) -> list[int]: - """Return capture sizes for a given whether it is uniform-decode.""" + def get_capture_cases( + self, uniform_decode: bool, uniform_query_len: int + ) -> tuple[CUDAGraphMode, list[BatchDescriptor], list[int]]: + """Return capture sizes, keys, and runtime mode for a given case. + The capture sizes and keys are sorted in descending order. + """ if not uniform_decode: - return list(self.cudagraph_capture_sizes) + runtime_mode = self.cudagraph_mode.mixed_mode() + uniform_query_len = 0 + capture_sizes = sorted(self.cudagraph_capture_sizes, reverse=True) else: - return list(self.uniform_cudagraph_capture_sizes) + runtime_mode = self.cudagraph_mode.decode_mode() + assert uniform_query_len in self.uniform_cudagraph_capture_sizes + capture_sizes = sorted( + self.uniform_cudagraph_capture_sizes[uniform_query_len], + reverse=True) + keys = [ + BatchDescriptor(num_tokens=x, + uniform_decode=uniform_decode, + uniform_query_len=uniform_query_len) + for x in capture_sizes + ] + return capture_sizes, keys, runtime_mode - def padded_num_tokens(self, num_tokens: int, - uniform_decode: bool) -> tuple[int, bool]: + def padded_num_tokens(self, num_tokens: int, uniform_decode: bool, + uniform_query_len: int) -> tuple[int, bool]: """Return num_tokens after padded and whether it is cudagraph padded. """ - if self.uniform_decode_query_len == 1 and num_tokens <= \ + assert uniform_query_len == 0 or uniform_query_len in \ + self.uniform_query_lens, \ + f"Invalid uniform_query_len: {uniform_query_len}" + if uniform_query_len <= 1 and num_tokens <= \ self.compilation_config.max_capture_size: # common situation within the range of max_capture_size for main - # model without spec-decode or it is for a drafter. + # model or for a drafter. # we ignore whether it is uniform-decode since it is always safe # to pad. return self.vllm_config.pad_for_cudagraph( num_tokens, uniform_aligned=False), True - if self.uniform_decode_query_len > 1 and uniform_decode and \ + if uniform_decode and uniform_query_len > 1 and \ num_tokens <= self.compilation_config.max_capture_size_uniform: # this is particular for uniform-decode alignment for vaildation - # phase of spec-decode. + # phase of spec-decode, or for the first iteration of drafter when + # support padded speculation return self.vllm_config.pad_for_cudagraph( num_tokens, uniform_aligned=True), True + # otherwise, it is not cudagraph padded return num_tokens, False def plan( @@ -146,14 +196,15 @@ def plan( Returns (runtime_mode, batch_descriptor, num_input_tokens_padded). 
""" - uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + uniform_decode = (max_query_len in self.uniform_query_lens) and ( num_scheduled_tokens == num_reqs * max_query_len) + uniform_query_len = max_query_len if uniform_decode else 0 # Compute padded tokens cudagraph_padded = False if self.cudagraph_mode != CUDAGraphMode.NONE: num_input_tokens, cudagraph_padded = self.padded_num_tokens( - num_scheduled_tokens, uniform_decode) + num_scheduled_tokens, uniform_decode, uniform_query_len) else: num_input_tokens = num_scheduled_tokens @@ -168,7 +219,8 @@ def plan( # Build initial descriptor and dispatch descriptor = BatchDescriptor(num_tokens=num_input_tokens, - uniform_decode=uniform_decode) + uniform_decode=uniform_decode, + uniform_query_len=uniform_query_len) runtime_mode, descriptor = self.dispatch(descriptor) return runtime_mode, descriptor, num_input_tokens diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 4dbabf359367..5a514bab9555 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -471,19 +471,15 @@ def propose_tree( self.hidden_states[:num_tokens] = tree_hidden_states.view( num_tokens, -1) - # As decode phase of TreeAttentionBackend does not have a uniform - # decode query length (1 for the root level and total_num_drafts - # for subsequent levels), so here we expect runtime mode should be - # not FULL, until we find ways to support full cudagraph of this - # case. + # Note: decode phase of TreeAttentionBackend does not have an + # unique uniform decode query length (1 for the root level and + # total_num_drafts for subsequent levels). Here we may support + # this situation once full cudagraph of TreeAttention is supported. cudagraph_runtime_mode, batch_descriptor, num_input_tokens = \ self.cudagraph_dispatcher.plan( num_scheduled_tokens=num_tokens, num_reqs=batch_size, max_query_len=attn_metadata.max_query_len) - assert cudagraph_runtime_mode != CUDAGraphMode.FULL, \ - "TreeAttentionBackend does not support full cudagraphs at " \ - "this moment" # Run the model. 
with set_forward_context( @@ -677,14 +673,28 @@ def dummy_run( num_tokens: int, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, force_attention: bool = False, - uniform_decode: bool = False, + batch_descriptor: Optional[BatchDescriptor] = None, ) -> None: assert cudagraph_runtime_mode != CUDAGraphMode.FULL, \ "Eagle drafter doesn't support full cudagraphs at this moment" - max_query_len = 1 if uniform_decode else num_tokens + uniform_decode = False + if batch_descriptor is not None and batch_descriptor.uniform_decode: + assert batch_descriptor.num_tokens == num_tokens, \ + "num_tokens mismatch" + uniform_decode = True + max_query_len = batch_descriptor.uniform_query_len + else: + max_query_len = num_tokens + max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs - num_reqs = min(num_tokens, max_num_reqs) + if uniform_decode: + assert num_tokens % max_query_len == 0, \ + "num_tokens must be divisible by max_query_len for uniform " \ + "decode" + num_reqs = min(num_tokens // max_query_len, max_num_reqs) + else: + num_reqs = min(num_tokens, max_num_reqs) per_layer_attn_metadata: Optional[dict[str, Any]] = None if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: @@ -720,12 +730,11 @@ def dummy_run( if cudagraph_runtime_mode == CUDAGraphMode.NONE: batch_descriptor = None else: - # filter out the valid batch descriptor - _cg_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch( - BatchDescriptor(num_tokens=num_tokens, - uniform_decode=uniform_decode)) + assert batch_descriptor is not None, \ + "batch_descriptor should be provided for cudagraph capture" # sanity check + _cg_mode, batch_descriptor = \ + self.cudagraph_dispatcher.dispatch(batch_descriptor) assert cudagraph_runtime_mode == _cg_mode, ( f"Cudagraph runtime mode mismatch at dummy_run. " f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 84943e12fb94..c34178a24207 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2221,7 +2221,7 @@ def _dummy_run( num_tokens: int, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, force_attention: bool = False, - uniform_decode: bool = False, + batch_descriptor: Optional[BatchDescriptor] = None, skip_eplb: bool = False, is_profile: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -2238,13 +2238,16 @@ def _dummy_run( needed. force_attention: If True, always create attention metadata. Used to warm up attention backend when mode is NONE. - uniform_decode: If True, the batch is a uniform decode batch. + batch_descriptor: Batch descriptor for the cudagraph capture. + If None, it is a mixed prefill-decode batch of eager mode. skip_eplb: If True, skip EPLB state update. is_profile: If True, this is a profile run. """ assert cudagraph_runtime_mode in { CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL } + uniform_decode = batch_descriptor is not None and \ + batch_descriptor.uniform_decode # Padding for DP num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) @@ -2265,6 +2268,10 @@ def _dummy_run( # for GQA/MQA. 
max_query_len = self.uniform_decode_query_len if uniform_decode else \ num_tokens + if batch_descriptor is not None: + assert batch_descriptor.num_tokens == num_tokens + if batch_descriptor.uniform_decode: + assert max_query_len == batch_descriptor.uniform_query_len # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -2361,12 +2368,11 @@ def _dummy_run( if cudagraph_runtime_mode == CUDAGraphMode.NONE: batch_descriptor = None else: - # filter out the valid batch descriptor - _cg_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch( - BatchDescriptor(num_tokens=num_tokens, - uniform_decode=uniform_decode)) + assert batch_descriptor is not None, \ + "batch_descriptor should be provided for cudagraph capture" # sanity check + _cg_mode, batch_descriptor = \ + self.cudagraph_dispatcher.dispatch(batch_descriptor) assert cudagraph_runtime_mode == _cg_mode, ( f"Cudagraph runtime mode mismatch at dummy_run. " f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") @@ -2392,8 +2398,8 @@ def _dummy_run( hidden_states = outputs # Only trigger drafter's dummy run for profile run. Otherwise, the - # dummy run logic of drafter for warmup run and capturing cudagraph - # is separated from the main model's dummy run. + # dummy run logic of drafter is separated from the main model's + # dummy run. if is_profile and self.speculative_config and \ self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) @@ -2655,14 +2661,13 @@ def freeze_gc(): cudagraph_mode = self.compilation_config.cudagraph_mode logger.info("Start capturing cudagraphs for main model...") if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: - cudagraph_runtime_mode = cudagraph_mode.mixed_mode() - mixed_cases = self.cudagraph_dispatcher.get_capture_cases( - uniform_decode=False) - compilation_cases = list(reversed(mixed_cases)) + capture_sizes, keys, runtime_mode = \ + self.cudagraph_dispatcher.get_capture_cases( + uniform_decode=False, uniform_query_len=0) self._capture_cudagraphs_with_callable( - compilation_cases, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=False, + capture_sizes=capture_sizes, + keys=keys, + cudagraph_runtime_mode=runtime_mode, dummy_run_callable=partial(self._dummy_run, skip_eplb=True)) @@ -2670,13 +2675,15 @@ def freeze_gc(): # dont already have full mixed prefill-decode cudagraphs if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ cudagraph_mode.separate_routine(): - uniform_cases = self.cudagraph_dispatcher.get_capture_cases( - uniform_decode=True) - compilation_cases_decode = list(reversed(uniform_cases)) + capture_sizes, keys, runtime_mode = \ + self.cudagraph_dispatcher.get_capture_cases( + uniform_decode=True, + uniform_query_len=self.uniform_decode_query_len) + self._capture_cudagraphs_with_callable( - compilation_cases=compilation_cases_decode, - cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True, + capture_sizes=capture_sizes, + keys=keys, + cudagraph_runtime_mode=runtime_mode, dummy_run_callable=partial(self._dummy_run, skip_eplb=True)) @@ -2688,28 +2695,33 @@ def freeze_gc(): assert isinstance(self.drafter, EagleProposer) logger.info("Start capturing cudagraphs for drafter...") if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: - mixed_mode_drafter = cudagraph_mode.mixed_mode() - drafter_cases = self.drafter.cudagraph_dispatcher.\ - get_capture_cases(uniform_decode=False) - drafter_cases = list(reversed(drafter_cases)) + capture_sizes, keys, 
runtime_mode = \ + self.drafter.cudagraph_dispatcher.\ + get_capture_cases(uniform_decode=False, + uniform_query_len=0) self._capture_cudagraphs_with_callable( - compilation_cases=drafter_cases, - cudagraph_runtime_mode=mixed_mode_drafter, - uniform_decode=False, + capture_sizes=capture_sizes, + keys=keys, + cudagraph_runtime_mode=runtime_mode, dummy_run_callable=self.drafter.dummy_run, ) # the following code would not be triggered at present since # only PIECEWISE mode is supported. But it is kept and prepared # for full cudagraphs. + + #TODO: support multiple uniform_query_lens for drafter once + # Padded speculation is supported. i.e., + # [1, self.uniform_decode_query_len] for drafter. if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ cudagraph_mode.separate_routine(): - uniform_cases = self.drafter.cudagraph_dispatcher.\ - get_capture_cases(uniform_decode=True) - drafter_uniform_cases = list(reversed(uniform_cases)) + capture_sizes, keys, runtime_mode = \ + self.drafter.cudagraph_dispatcher.\ + get_capture_cases(uniform_decode=True, + uniform_query_len=1) self._capture_cudagraphs_with_callable( - compilation_cases=drafter_uniform_cases, - cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True, + capture_sizes=capture_sizes, + keys=keys, + cudagraph_runtime_mode=runtime_mode, dummy_run_callable=self.drafter.dummy_run, ) @@ -2729,23 +2741,27 @@ def freeze_gc(): elapsed_time, cuda_graph_size / (1 << 30)) def _capture_cudagraphs_with_callable( - self, compilation_cases: list[int], - cudagraph_runtime_mode: CUDAGraphMode, uniform_decode: bool, - dummy_run_callable: Any): + self, capture_sizes: list[int], keys: list[BatchDescriptor], + cudagraph_runtime_mode: CUDAGraphMode, dummy_run_callable: Any): assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ cudagraph_runtime_mode in [CUDAGraphMode.FULL, CUDAGraphMode.PIECEWISE] + assert len(keys) > 0, "keys must be non-empty" + assert len(capture_sizes) == len(keys), \ + "capture_sizes and keys must have the same length" + uniform_decode = keys[0].uniform_decode + uniform_query_len = keys[0].uniform_query_len # Only rank 0 should print progress bar during capture if is_global_first_rank(): - compilation_cases = tqdm( - compilation_cases, + capture_sizes = tqdm( + capture_sizes, disable=not self.load_config.use_tqdm_on_load, desc="Capturing CUDA graphs ({}, {})".format( - "decode" if uniform_decode else "mixed prefill-decode", - cudagraph_runtime_mode.name)) + f"decode(query_len={uniform_query_len})" if uniform_decode + else "mixed prefill-decode", cudagraph_runtime_mode.name)) # We skip EPLB here since we don't want to record dummy metrics - for num_tokens in compilation_cases: + for num_tokens, key in zip(capture_sizes, keys): for _ in range(self.compilation_config.cudagraph_num_of_warmups): # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. 
# But be careful, warm up with `NONE`is orthogonal to @@ -2757,10 +2773,10 @@ def _capture_cudagraphs_with_callable( dummy_run_callable(num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, force_attention=force_attention, - uniform_decode=uniform_decode) + batch_descriptor=key) dummy_run_callable(num_tokens, cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=uniform_decode) + batch_descriptor=key) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -2905,9 +2921,13 @@ def initialize_cudagraph_capture(self) -> None: assert isinstance(self.drafter, EagleProposer) assert not cudagraph_mode.has_full_cudagraphs(), \ "Eagle drafter does not support full cudagraphs yet" - # decode_query_len is 1 for drafter + # uniform_query_len is 1 for drafter + # TODO: let uniform_query_lens = [1, self.uniform_decode_query_len] + # for drafter once Padded speculation is supported. See: + # https://github.com/vllm-project/vllm/issues/21984 for details + # and an implementation in https://github.com/vllm-project/vllm/pull/22684 # noqa: E501 self.drafter.cudagraph_dispatcher.initialize_cudagraph_keys( - self.compilation_config.cudagraph_mode, 1) + self.compilation_config.cudagraph_mode, uniform_query_lens=1) def calculate_reorder_batch_threshold(self) -> None: """ From a142f140053ebca69201fb9243ed6f41c18e4684 Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Tue, 26 Aug 2025 21:46:15 +0800 Subject: [PATCH 04/19] fix typo Signed-off-by: fhl <2410591650@qq.com> --- vllm/v1/cudagraph_dispatcher.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 2d88a08843cc..d736ca42e14b 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -84,15 +84,15 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, f"Invalid uniform_query_lens: {uniform_query_lens}" self.uniform_query_lens = uniform_query_lens - # we only have compilation_config.cudagraph_capture_sizes_uniform + # we only have compilation_config.uniform_cudagraph_capture_sizes # being aligned with one uniform_query_len that greater than 1, not # multiple of them. Should verify this here. 
for uniform_query_len in self.uniform_query_lens: if uniform_query_len > 1 and \ - self.compilation_config.cudagraph_capture_sizes_uniform: + self.compilation_config.uniform_cudagraph_capture_sizes: assert all(x % uniform_query_len == 0 for x in self.compilation_config.\ - cudagraph_capture_sizes_uniform), \ + uniform_cudagraph_capture_sizes), \ f"Invalid uniform_query_lens: {uniform_query_len}" if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: @@ -117,7 +117,7 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, candidate_sizes = self.compilation_config.\ cudagraph_capture_sizes \ if uniform_query_len == 1 else \ - self.compilation_config.cudagraph_capture_sizes_uniform + self.compilation_config.uniform_cudagraph_capture_sizes cudagraph_capture_sizes_for_decode = [ x for x in candidate_sizes if x <= max_num_tokens and x >= uniform_query_len From f7d73f823bf3621a665166f74fea736b4c0c2388 Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Tue, 26 Aug 2025 22:41:56 +0800 Subject: [PATCH 05/19] fix typo Signed-off-by: fhl <2410591650@qq.com> --- vllm/v1/cudagraph_dispatcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index d736ca42e14b..5a8ea41885c1 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -176,7 +176,7 @@ def padded_num_tokens(self, num_tokens: int, uniform_decode: bool, num_tokens, uniform_aligned=False), True if uniform_decode and uniform_query_len > 1 and \ - num_tokens <= self.compilation_config.max_capture_size_uniform: + num_tokens <= self.compilation_config.max_uniform_capture_size: # this is particular for uniform-decode alignment for vaildation # phase of spec-decode, or for the first iteration of drafter when # support padded speculation From 02390fc1ad06e57ed2a2f21006520c45baaa8963 Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Tue, 26 Aug 2025 23:02:46 +0800 Subject: [PATCH 06/19] fix broken examples/offline_inference/spec_decode.py Signed-off-by: fhl <2410591650@qq.com> --- examples/offline_inference/spec_decode.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index c4972f02d0f8..8df2512a94be 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json + from transformers import AutoTokenizer from vllm import LLM, SamplingParams @@ -69,6 +71,8 @@ def parse_args(): parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) parser.add_argument("--custom-mm-prompts", action="store_true") + parser.add_argument("--request-id-prefix", type=str, default="") + parser.add_argument("--compilation-config", type=str, default="") return parser.parse_args() @@ -133,12 +137,15 @@ def main(): max_model_len=16384, limit_mm_per_prompt={"image": 5}, disable_chunked_mm_input=True, + compilation_config=( + json.loads(args.compilation_config) if args.compilation_config else None + ), ) sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) if not args.custom_mm_prompts: outputs = llm.generate( - TokensPrompt(prompt_token_ids=prompt_ids), + [TokensPrompt(prompt_token_ids=_prompt_ids) for _prompt_ids in prompt_ids], sampling_params=sampling_params, ) 
else: From ec027786d2021d390403678a78ebe37f71fcc80e Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sun, 7 Sep 2025 06:42:02 +0000 Subject: [PATCH 07/19] fix pre-commit Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/worker/gpu_model_runner.py | 40 +++++++++++++----------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a33d723d4ba7..d31b5ba7e2a9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1590,23 +1590,13 @@ def _pool( def _preprocess( self, + num_input_tokens: int, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], torch.Tensor, + ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor, Optional[IntermediateTensors], dict[str, Any]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - # dispatcher planning - cudagraph_runtime_mode, batch_descriptor, num_input_tokens = \ - self.cudagraph_dispatcher.plan( - num_scheduled_tokens=num_scheduled_tokens, - num_reqs=self.input_batch.num_reqs, - max_query_len=max_query_len) - - # Padding for DP - num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) - num_input_tokens += num_pad # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order @@ -1656,11 +1646,6 @@ def _preprocess( num_input_tokens, intermediate_tensors, True) return ( - cudagraph_runtime_mode, - batch_descriptor, - num_scheduled_tokens, - num_input_tokens, - num_tokens_across_dp, input_ids, inputs_embeds, positions, @@ -1863,18 +1848,27 @@ def execute_model( num_scheduled_tokens_np, spec_decode_common_attn_metadata, max_query_len) = self._prepare_inputs(scheduler_output) + # cudagraph dispatcher planing + padding + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + cudagraph_runtime_mode, batch_descriptor, num_input_tokens = \ + self.cudagraph_dispatcher.plan( + num_scheduled_tokens=num_scheduled_tokens, + num_reqs=self.input_batch.num_reqs, + max_query_len=max_query_len) + + # Padding for DP + num_pad, num_tokens_across_dp = self.get_dp_padding( + num_input_tokens) + num_input_tokens += num_pad + ( - cudagraph_runtime_mode, - batch_descriptor, - num_scheduled_tokens, - num_input_tokens, - num_tokens_across_dp, input_ids, inputs_embeds, positions, intermediate_tensors, model_kwargs, - ) = self._preprocess(scheduler_output, intermediate_tensors) + ) = self._preprocess(num_input_tokens, scheduler_output, + intermediate_tensors) # Run the model. # Use persistent buffers for CUDA graphs. 
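For context on the reordering in [PATCH 07/19] above: dispatcher planning and DP padding now happen in execute_model before _preprocess, so _preprocess only receives an already-padded token count. The short, standalone Python sketch below illustrates that ordering; it is not vLLM code, and CAPTURE_SIZES, pad_to_capture_size, plan_then_pad_for_dp, and the dp_pad argument are made-up names and values chosen purely for illustration.

    from bisect import bisect_left

    # Hypothetical cudagraph capture sizes, sorted ascending (illustration only).
    CAPTURE_SIZES = [1, 2, 4, 8, 16, 32, 64]

    def pad_to_capture_size(num_tokens: int) -> tuple[int, bool]:
        """Round num_tokens up to the nearest capture size, if one exists."""
        if num_tokens > CAPTURE_SIZES[-1]:
            # Too large for any captured graph: run eagerly, no cudagraph padding.
            return num_tokens, False
        idx = bisect_left(CAPTURE_SIZES, num_tokens)
        return CAPTURE_SIZES[idx], True

    def plan_then_pad_for_dp(num_scheduled_tokens: int, dp_pad: int) -> int:
        # 1) Cudagraph planning runs first on the local batch ...
        num_input_tokens, _ = pad_to_capture_size(num_scheduled_tokens)
        # 2) ... then DP padding is added on top, mirroring the order the patch
        #    above establishes in execute_model before _preprocess is called.
        return num_input_tokens + dp_pad

    print(plan_then_pad_for_dp(13, dp_pad=0))  # 16 (padded to a capture size)
    print(plan_then_pad_for_dp(13, dp_pad=4))  # 20 (capture padding, then DP padding)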
From 286677fb5b32f98c60f8987aa4a91fa53c8fb9f9 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Wed, 10 Sep 2025 17:13:43 +0000 Subject: [PATCH 08/19] revert spec_decode.py Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- examples/offline_inference/spec_decode.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index e24182e3d26c..d3f42d63f217 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -71,7 +71,6 @@ def parse_args(): parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) parser.add_argument("--custom-mm-prompts", action="store_true") - parser.add_argument("--request-id-prefix", type=str, default="") parser.add_argument("--compilation-config", type=str, default="") return parser.parse_args() From 0eda1116df7e656d495c120fa9023766db256bab Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Mon, 15 Sep 2025 03:26:30 +0000 Subject: [PATCH 09/19] address comments Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config/__init__.py | 14 ++++---------- vllm/config/compilation.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 7 +++---- 3 files changed, 9 insertions(+), 16 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 2e4982e8bb04..2d64ecfb84cd 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -3716,16 +3716,10 @@ def _set_cudagraph_sizes(self): # we maintain a separate list of uniform-decode capture sizes, # since for spec-decode, we may need capture sizes being # divisible by uniform_decode_len(>1). - - # Derive uniform-decode capture sizes via projection: for each - # non-uniform capture size i, take the max multiple of - # uniform_decode_len that is not greater than i. 
- projected_sizes: set[int] = set() - for size in batch_size_capture_list: - proj = (size // uniform_decode_len) * uniform_decode_len - if proj >= uniform_decode_len: - projected_sizes.add(proj) - uniform_batch_size_capture_list = sorted(projected_sizes) + uniform_batch_size_capture_list = sorted( + set(size * uniform_decode_len + for size in batch_size_capture_list + if size >= uniform_decode_len)) if self.parallel_config.tensor_parallel_size > 1 and \ self.compilation_config.pass_config.enable_sequence_parallelism: batch_size_capture_list, uniform_batch_size_capture_list = \ diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 2c1b2ccd0c95..6e659bec0ac8 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -518,8 +518,8 @@ def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int], # recompute uniform_cudagraph_capture_sizes based on the # dedup_sizes(computed from config) and uniform_decode_len uniform_cudagraph_capture_sizes = sorted( - set((size // uniform_decode_len) * uniform_decode_len - for size in dedup_sizes if size >= uniform_decode_len)) + set(size * uniform_decode_len for size in dedup_sizes + if size >= uniform_decode_len)) computed_compile_sizes = [] if self.compile_sizes is not None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 73eeb35668b3..5636062d8ef9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -54,9 +54,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, cdiv, check_use_alibi, - get_dtype_size, is_pin_memory_available, round_up, - supports_dynamo) + GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, + is_pin_memory_available, supports_dynamo) from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, @@ -3350,7 +3349,7 @@ def initialize_cudagraph_capture(self) -> None: # TODO: let uniform_query_lens = [1, self.uniform_decode_query_len] # for drafter once Padded speculation is supported. 
See: # https://github.com/vllm-project/vllm/issues/21984 for details - # and an implementation in https://github.com/vllm-project/vllm/pull/22684 # noqa: E501 + # and an implementation in https://github.com/vllm-project/vllm/pull/24539 # noqa: E501 self.drafter.cudagraph_dispatcher.initialize_cudagraph_keys( self.compilation_config.cudagraph_mode, uniform_query_lens=1) From 9c50e6eac37b8781b76a15c7d613f9d3f0519e64 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Mon, 15 Sep 2025 04:35:27 +0000 Subject: [PATCH 10/19] revert build_for_cudagraph_capturing Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/attention/backends/rocm_aiter_fa.py | 9 +++++++++ vllm/v1/attention/backends/utils.py | 10 ++++++++++ vllm/v1/worker/gpu_model_runner.py | 2 +- 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 768fbdea3231..a4e2758bd311 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -253,6 +253,15 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.aot_sliding_window: Optional[tuple[int, int]] = None self.total_tokens: int = 0 + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata): + self.total_tokens = self.model_config.max_model_len \ + * self.vllm_config.scheduler_config.max_num_partial_prefills + res = self.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata) + self.total_tokens = 0 + return res + def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 7ad0cf253dfb..009943fa743d 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -235,6 +235,16 @@ def reorder_batch(self, input_batch: "InputBatch", """ raise NotImplementedError + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata) -> M: + """ + Build attention metadata for CUDA graph capture. Uses build by default. + Subclasses that override this method should call self.build or + super().build_for_cudagraph_capture. 
+ """ + return self.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata) + def build_for_drafting( self, common_attn_metadata: CommonAttentionMetadata, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5636062d8ef9..4c21fe0f4240 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2752,7 +2752,7 @@ def _dummy_run( for attn_group in self.attn_groups[kv_cache_group_id]: attn_metadata_i = attn_group.metadata_builder\ - .build(0, common_attn_metadata) + .build_for_cudagraph_capture(common_attn_metadata) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i From e4a1a783fe7c3ec7b9c910a14a4aa6f945be457b Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Mon, 15 Sep 2025 04:48:04 +0000 Subject: [PATCH 11/19] remove unnecessary assertion Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/attention/backends/gdn_attn.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 74eb9ae9d325..1634986d57b2 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -293,12 +293,6 @@ def build_for_cudagraph_capture( """ m = common_attn_metadata - assert (m.num_reqs * (self.num_spec + 1) <= m.num_actual_tokens - and ((m.num_reqs + 1) * (self.num_spec + 1) - >= m.num_actual_tokens)), \ - "GDN only supports decode-only full CUDAGraph capture. " \ - "Make sure all cudagraph capture sizes <= max_num_seq." - num_accepted_tokens = torch.full((m.num_reqs, ), m.max_query_len, dtype=torch.int32, From 691c21e92d06993a99878e83e16871feab0070a2 Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Sun, 21 Sep 2025 18:06:57 +0800 Subject: [PATCH 12/19] fixes for ubatching Signed-off-by: fhl <2410591650@qq.com> --- vllm/v1/cudagraph_dispatcher.py | 12 ++-- vllm/v1/worker/gpu_model_runner.py | 97 +++++++++++------------------- 2 files changed, 44 insertions(+), 65 deletions(-) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index cc689170bb4c..5b1ea6f61718 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -209,8 +209,11 @@ def caculate_uniform_decode(self, num_scheduled_tokens: int, num_reqs: int, uniform_query_len = max_query_len if uniform_decode else 0 return uniform_decode, uniform_query_len - def get_num_input_tokens(self, num_scheduled_tokens: int, num_reqs: int, - max_query_len: int) -> int: + def get_num_input_tokens_local(self, num_scheduled_tokens: int, + num_reqs: int, max_query_len: int) -> int: + """ return num_input_tokens, acounting for cudagraph padding and + tp padding locally, but not across dp. + """ uniform_decode, uniform_query_len = self.caculate_uniform_decode( num_scheduled_tokens, num_reqs, max_query_len) @@ -240,6 +243,7 @@ def get_num_input_tokens(self, num_scheduled_tokens: int, num_reqs: int, def maybe_pad_for_dp( self, num_input_tokens: int) -> tuple[int, Optional[torch.Tensor]]: if self.runner and hasattr(self.runner, 'get_dp_padding'): + assert not self.is_drafter return self.runner.get_dp_padding(num_input_tokens) return 0, None @@ -255,8 +259,8 @@ def plan( Returns (runtime_mode, batch_descriptor, num_input_tokens, num_tokens_across_dp). 
""" - num_input_tokens = self.get_num_input_tokens(num_scheduled_tokens, - num_reqs, max_query_len) + num_input_tokens = self.get_num_input_tokens_local( + num_scheduled_tokens, num_reqs, max_query_len) # maybe pad for dp num_pad, num_tokens_across_dp = self.maybe_pad_for_dp(num_input_tokens) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4649d5151abc..38bef80728a0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -58,7 +58,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, check_use_alibi, get_dtype_size, is_pin_memory_available, - length_from_prompt_token_ids_or_embeds, round_up, + length_from_prompt_token_ids_or_embeds, supports_dynamo) from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder @@ -1032,8 +1032,8 @@ def _prepare_inputs( query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens - num_tokens_padded = num_tokens_unpadded + self.get_local_padding( - num_tokens_unpadded) + num_tokens_padded = self._get_num_input_tokens( + num_tokens_unpadded, num_reqs, max_num_scheduled_tokens) ubatch_slices, num_tokens_after_padding = \ ubatch_split(max_num_scheduled_tokens, num_tokens_unpadded, @@ -1771,28 +1771,6 @@ def get_dp_padding(self, dtype=torch.int32) return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding - def get_local_padding(self, num_tokens_unpadded: int) -> int: - - num_tokens_padded = num_tokens_unpadded - - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_tokens_padded = self.vllm_config.pad_for_cudagraph( - num_tokens_unpadded) - else: - # Eager mode. - # Pad tokens to multiple of tensor_parallel_size when - # enabled collective fusion for SP - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if self.vllm_config.compilation_config.pass_config. \ - enable_sequence_parallelism and tp_size > 1: - num_tokens_padded = round_up(num_tokens_unpadded, tp_size) - - num_pad_tokens = num_tokens_padded - num_tokens_unpadded - return num_pad_tokens - # This is where the second ubatch is adjusted to account for the padding. # Should be called after attention metadata creation. This just pads # the second ubatch slice out to the total number of tokens @@ -1843,7 +1821,7 @@ def _pool( def _get_num_input_tokens(self, num_scheduled_tokens: int, num_reqs: int, max_query_len: int) -> int: - return self.cudagraph_dispatcher.get_num_input_tokens( + return self.cudagraph_dispatcher.get_num_input_tokens_local( num_scheduled_tokens, num_reqs, max_query_len) def _preprocess( @@ -1851,25 +1829,9 @@ def _preprocess( num_input_tokens: int, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ubatch_slices: Optional[UBatchSlices] = None, - num_tokens_after_padding: Optional[torch.Tensor] = None, ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor, Optional[IntermediateTensors], dict[str, Any]]: - ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], torch.Tensor, - Optional[IntermediateTensors], dict[str, Any]]: - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - # TODO: refactor this. The padding is very confusing now. 
- if ubatch_slices: - assert num_tokens_after_padding is not None - num_input_tokens = int(num_tokens_after_padding[0].item() * 2) - self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) - elif ubatch_slices is None: - num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) - num_pad, num_tokens_after_padding = self.get_dp_padding( - num_input_tokens) - num_input_tokens += num_pad # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order @@ -2159,20 +2121,26 @@ def execute_model( num_scheduled_tokens_np, spec_decode_common_attn_metadata, max_query_len, ubatch_slices, num_tokens_after_padding ) = self._prepare_inputs(scheduler_output) - - #TODO: refactor this into plan - if ubatch_slices is not None: - num_input_tokens = num_input_tokens // 2 - # cudagraph dispatcher planing + padding + # Cudagraph dispatcher planing + padding + # For ubatching, we still use original total num_scheduled_tokens + # for planing, as gpu_ubatch_wrapper awares only the + # cudagraph_runtime_mode but not batch_descriptor. num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens (cudagraph_runtime_mode, batch_descriptor, num_input_tokens, num_tokens_across_dp)= \ self.cudagraph_dispatcher.plan( num_scheduled_tokens=num_scheduled_tokens, num_reqs=self.input_batch.num_reqs, - max_query_len=max_query_len, - ubatch_slices=ubatch_slices) + max_query_len=max_query_len) + + # may overwrite num_input_tokens and num_tokens_across_dp when + # ubatching enabled in this batch. + if ubatch_slices: + assert num_tokens_after_padding is not None + num_tokens_across_dp = num_tokens_after_padding + num_input_tokens = int(num_tokens_after_padding[0].item() * 2) + self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) ( input_ids, @@ -2181,8 +2149,10 @@ def execute_model( intermediate_tensors, model_kwargs, ) = self._preprocess(num_input_tokens, scheduler_output, - intermediate_tensors, ubatch_slices, - num_tokens_after_padding) + intermediate_tensors) + + if ubatch_slices is not None: + num_input_tokens = num_input_tokens // 2 # Run the model. # Use persistent buffers for CUDA graphs. @@ -2841,7 +2811,6 @@ def _dummy_run( """ ubatch_enabled = self.parallel_config.enable_dbo num_tokens_across_dp = None - num_pad = 0 should_ubatch = False if ubatch_enabled: should_ubatch = num_tokens >= \ @@ -2865,7 +2834,7 @@ def _dummy_run( if not should_ubatch: num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) - num_tokens += num_pad + num_tokens += num_pad # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). This means that we are using @@ -3053,6 +3022,8 @@ def _dummy_run( f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") if ubatch_slices is not None: + # we ignore the inconsistency of num_tokens in batch descriptor + # for ubatching since gpu_ubatch_wrapper is not aware of it. 
num_tokens = num_tokens // 2 with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, @@ -3427,8 +3398,11 @@ def freeze_gc(): return cuda_graph_size def _capture_cudagraphs_with_callable( - self, capture_sizes: list[int], keys: list[BatchDescriptor], - cudagraph_runtime_mode: CUDAGraphMode, dummy_run_callable: Any, + self, + capture_sizes: list[int], + keys: list[BatchDescriptor], + cudagraph_runtime_mode: CUDAGraphMode, + dummy_run_callable: Any, enable_dbo: bool = False): assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ cudagraph_runtime_mode in [CUDAGraphMode.FULL, @@ -3447,7 +3421,7 @@ def _capture_cudagraphs_with_callable( desc="Capturing CUDA graphs ({}, {})".format( f"decode(query_len={uniform_query_len})" if uniform_decode else "mixed prefill-decode", cudagraph_runtime_mode.name)) - + # DBO Only supports running Full cudagraphs with uniform # decode lengths if enable_dbo and uniform_decode: @@ -3463,11 +3437,12 @@ def _capture_cudagraphs_with_callable( self.compilation_config.cudagraph_num_of_warmups): force_attention = ( cudagraph_runtime_mode == CUDAGraphMode.FULL) - dummy_run_callable(num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, - batch_descriptor=key, - allow_microbatching=True) + dummy_run_callable( + num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + batch_descriptor=key, + allow_microbatching=True) # Graph Capture dummy_run_callable(num_tokens, From 43b2753a0a3306f8458bf53b107ceb19f3a5b4fe Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Tue, 23 Sep 2025 16:10:47 +0000 Subject: [PATCH 13/19] fix CI Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/worker/cpu_worker.py | 3 ++- vllm/v1/worker/gpu_worker.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index ac471845cdeb..80783f9fe847 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -110,7 +110,8 @@ def execute_model( intermediate_tensors = None num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_reqs = len(scheduler_output.num_scheduled_tokens) - max_query_len = max(scheduler_output.num_scheduled_tokens.values()) + max_query_len = max(scheduler_output.num_scheduled_tokens.values())\ + if num_scheduled_tokens > 0 else 0 num_input_tokens = self.model_runner._get_num_input_tokens( num_scheduled_tokens, num_reqs, max_query_len) all_gather_tensors = { diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 715e53a649c0..e6347096e44f 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -431,7 +431,8 @@ def execute_model( forward_pass = scheduler_output.total_num_scheduled_tokens > 0 num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_reqs = len(scheduler_output.num_scheduled_tokens) - max_query_len = max(scheduler_output.num_scheduled_tokens.values()) + max_query_len = max(scheduler_output.num_scheduled_tokens.values())\ + if num_scheduled_tokens > 0 else 0 num_input_tokens = self.model_runner._get_num_input_tokens( num_scheduled_tokens, num_reqs, max_query_len) all_gather_tensors = { From 0a3fe050218f09a3562b9c948ccbbaf8e40b9230 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sun, 28 Sep 2025 16:05:17 +0000 Subject: [PATCH 14/19] fix Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- 
vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 38be802c5ee0..b8005de7596a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2986,7 +2986,7 @@ def _dummy_run( # for GQA/MQA. max_query_len = self.uniform_decode_query_len if uniform_decode else \ num_tokens - if uniform_decode is not None: + if uniform_decode: assert max_query_len == uniform_query_len # Set num_scheduled_tokens based on num_tokens and max_num_seqs From 0ee4aef78c483c1108ed073da877fa44c97e492f Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Tue, 14 Oct 2025 19:11:25 +0000 Subject: [PATCH 15/19] WIP:address dp padding issue Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/cudagraph_dispatcher.py | 67 +++++++------------- vllm/v1/spec_decode/eagle.py | 12 ++-- vllm/v1/worker/dp_utils.py | 29 ++++++++- vllm/v1/worker/gpu_model_runner.py | 99 ++++++++++++++++++------------ vllm/v1/worker/gpu_worker.py | 2 +- 5 files changed, 113 insertions(+), 96 deletions(-) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index ccacb95866a8..f45234eadbb1 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -1,10 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any -import torch - -import vllm.envs as envs from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor from vllm.utils import round_up @@ -29,14 +25,11 @@ class CudagraphDispatcher: runnable without cudagraph (if the mode does not match or mode is NONE). """ - def __init__( - self, vllm_config: VllmConfig, is_drafter: bool = False, runner: Any = None - ): + def __init__(self, vllm_config: VllmConfig, is_drafter: bool = False): self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config self.cudagraph_mode = self.compilation_config.cudagraph_mode self.is_drafter = is_drafter - self.runner = runner # Dict to store valid cudagraph dispatching keys. self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = { @@ -188,10 +181,10 @@ def get_capture_cases( ] return capture_sizes, keys, runtime_mode - def padded_num_tokens( + def cudagraph_padded_num_tokens( self, num_tokens: int, uniform_decode: bool, uniform_query_len: int ) -> tuple[int, bool]: - """Return num_tokens after padded and whether it is cudagraph padded.""" + """Return Tuple[num_tokens_after_padded, is_cudagraph_padded].""" assert uniform_query_len == 0 or uniform_query_len in self.uniform_query_lens, ( f"Invalid uniform_query_len: {uniform_query_len}" ) @@ -231,27 +224,20 @@ def caculate_uniform_decode( uniform_query_len = max_query_len if uniform_decode else 0 return uniform_decode, uniform_query_len - def get_num_input_tokens_local( + def get_local_batch_description( self, num_scheduled_tokens: int, num_reqs: int, max_query_len: int - ) -> int: - """return num_input_tokens, acounting for cudagraph padding and - tp padding locally, but not across dp. 
+ ) -> tuple[int, bool, int]: + """ + return Tuple[num_tokens_after_padding, uniform_decode, uniform_query_len] """ uniform_decode, uniform_query_len = self.caculate_uniform_decode( num_scheduled_tokens, num_reqs, max_query_len ) - # store for later use in plan - self._uniform_decode = uniform_decode - self._uniform_query_len = uniform_query_len - # Compute padded tokens cudagraph_padded = False - if ( - self.cudagraph_mode != CUDAGraphMode.NONE - and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH - ): - num_input_tokens, cudagraph_padded = self.padded_num_tokens( + if self.cudagraph_mode != CUDAGraphMode.NONE: + num_input_tokens, cudagraph_padded = self.cudagraph_padded_num_tokens( num_scheduled_tokens, uniform_decode, uniform_query_len ) else: @@ -267,44 +253,33 @@ def get_num_input_tokens_local( and tp_size > 1 ): num_input_tokens = round_up(num_scheduled_tokens, tp_size) - return num_input_tokens - - def maybe_pad_for_dp( - self, num_input_tokens: int - ) -> tuple[int, torch.Tensor | None]: - if self.runner and hasattr(self.runner, "get_dp_padding"): - assert not self.is_drafter - return self.runner.get_dp_padding(num_input_tokens) - return 0, None + return num_input_tokens, uniform_decode, uniform_query_len - def plan( + def fast_plan( self, num_scheduled_tokens: int, num_reqs: int, max_query_len: int, use_cascade_attn: bool = False, - ) -> tuple[CUDAGraphMode, BatchDescriptor | None, int, torch.Tensor | None]: - """Plan cudagraph execution in a single call. + ) -> tuple[CUDAGraphMode, BatchDescriptor | None, int]: + """Plan cudagraph execution in a single call, without considering dp. - Returns (runtime_mode, batch_descriptor, num_input_tokens, - num_tokens_across_dp). + Returns (runtime_mode, batch_descriptor, num_input_tokens). """ - num_input_tokens = self.get_num_input_tokens_local( - num_scheduled_tokens, num_reqs, max_query_len + num_input_tokens, uniform_decode, uniform_query_len = ( + self.get_local_batch_description( + num_scheduled_tokens, num_reqs, max_query_len + ) ) - # maybe pad for dp - num_pad, num_tokens_across_dp = self.maybe_pad_for_dp(num_input_tokens) - num_input_tokens += num_pad - # Build initial descriptor and then dispatch descriptor = BatchDescriptor( num_tokens=num_input_tokens, - uniform_decode=self._uniform_decode, - uniform_query_len=self._uniform_query_len, + uniform_decode=uniform_decode, + uniform_query_len=uniform_query_len, ) runtime_mode, descriptor = self.dispatch(descriptor, use_cascade_attn) - return runtime_mode, descriptor, num_input_tokens, num_tokens_across_dp + return runtime_mode, descriptor, num_input_tokens def dispatch( self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 45db8e40ce50..6d0708cb2cf8 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -245,8 +245,8 @@ def propose( per_layer_attn_metadata[layer_name] = draft_indexer_metadata # dispatcher planning for drafter - cudagraph_runtime_mode, batch_descriptor, num_input_tokens, _ = ( - self.cudagraph_dispatcher.plan( + cudagraph_runtime_mode, batch_descriptor, num_input_tokens = ( + self.cudagraph_dispatcher.fast_plan( num_scheduled_tokens=num_tokens, num_reqs=batch_size, max_query_len=max_query_len, @@ -334,8 +334,8 @@ def propose( draft_token_ids_list = [draft_token_ids] # dispatcher plans only once for the remaining loop - cudagraph_runtime_mode, batch_descriptor, input_batch_size, _ = ( - self.cudagraph_dispatcher.plan( + cudagraph_runtime_mode, 
batch_descriptor, input_batch_size = (
+            self.cudagraph_dispatcher.fast_plan(
                 num_scheduled_tokens=batch_size, num_reqs=batch_size, max_query_len=1
             )
         )
@@ -748,8 +748,8 @@ def propose_tree(
         # unique uniform decode query length (1 for the root level and
         # total_num_drafts for subsequent levels). Here we may support
         # this situation once full cudagraph of TreeAttention is supported.
-        cudagraph_runtime_mode, batch_descriptor, num_input_tokens, _ = (
-            self.cudagraph_dispatcher.plan(
+        cudagraph_runtime_mode, batch_descriptor, num_input_tokens = (
+            self.cudagraph_dispatcher.fast_plan(
                 num_scheduled_tokens=num_tokens,
                 num_reqs=batch_size,
                 max_query_len=attn_metadata.max_query_len,
diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py
index 3f24ff0a09de..034f324b027d 100644
--- a/vllm/v1/worker/dp_utils.py
+++ b/vllm/v1/worker/dp_utils.py
@@ -39,16 +39,20 @@ def _run_ar(
     should_dp_pad: bool,
     orig_num_tokens_per_ubatch: int,
     padded_num_tokens_per_ubatch: int,
+    disable_padding_extend: bool,
+    num_tokens_padded_extended: int,
     parallel_config: ParallelConfig,
 ) -> torch.Tensor:
     dp_size = parallel_config.data_parallel_size
     dp_rank = parallel_config.data_parallel_rank
     device, group = _get_device_and_group(parallel_config)
-    tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
+    tensor = torch.zeros(6, dp_size, device=device, dtype=torch.int32)
     tensor[0][dp_rank] = orig_num_tokens_per_ubatch
     tensor[1][dp_rank] = padded_num_tokens_per_ubatch
     tensor[2][dp_rank] = 1 if should_ubatch else 0
     tensor[3][dp_rank] = 1 if should_dp_pad else 0
+    tensor[4][dp_rank] = 1 if disable_padding_extend else 0
+    tensor[5][dp_rank] = num_tokens_padded_extended
     dist.all_reduce(tensor, group=group)
     return tensor
@@ -76,6 +80,11 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool:
 def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch.Tensor:
     num_tokens_across_dp = tensor[1, :]
     if should_dp_pad:
+        # Replace num_tokens_across_dp with the extended version when at least
+        # one dp rank does not disable the padding extension.
+        disable_padding_extend = bool(torch.all(tensor[4] == 1).item())
+        if not disable_padding_extend:
+            num_tokens_across_dp = tensor[5, :]
         # If DP padding is enabled, ensure that each rank is processing the same number
         # of tokens
         max_num_tokens = int(num_tokens_across_dp.max().item())
@@ -93,6 +102,8 @@ def _synchronize_dp_ranks(
     num_tokens_padded: int,
     should_attempt_ubatching: bool,
     should_attempt_dp_padding: bool,
+    disable_padding_extend: bool,
+    num_tokens_padded_extended: int,
     parallel_config: ParallelConfig,
 ) -> tuple[bool, torch.Tensor | None]:
     """
@@ -120,6 +131,8 @@ def _synchronize_dp_ranks(
         should_dp_pad=should_attempt_dp_padding,
         orig_num_tokens_per_ubatch=num_tokens_unpadded,
         padded_num_tokens_per_ubatch=num_tokens_padded,
+        disable_padding_extend=disable_padding_extend,
+        num_tokens_padded_extended=num_tokens_padded_extended,
         parallel_config=parallel_config,
     )
@@ -157,6 +170,8 @@ def coordinate_batch_across_dp(
     parallel_config: ParallelConfig,
     num_tokens_padded: int | None = None,
     uniform_decode: bool | None = None,
+    disable_padding_extend: bool = True,
+    num_tokens_padded_extended: int | None = None,
     num_scheduled_tokens_per_request: np.ndarray | None = None,
 ) -> tuple[UBatchSlices | None, torch.Tensor | None]:
     """
@@ -170,8 +185,11 @@ def coordinate_batch_across_dp(
         parallel_config: The parallel config
         num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
             TP, etc)
-        uniform_decode: Only used if allow_microbatching is True.
True if the batch - only contains single token decodes + uniform_decode: Used when allow_microbatching is True and/or when it is uniform + decoding for spec-decode. + disable_padding_extend: If it is True across all dp rank, we do not extend the + padding from uniform-decode batch to non-uniform batch. + num_tokens_padded_extended: the number of tokens after extending the padding. num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The number of tokens per request. @@ -203,11 +221,16 @@ def coordinate_batch_across_dp( if num_tokens_padded is None: num_tokens_padded = num_tokens_unpadded + if num_tokens_padded_extended is None: + num_tokens_padded_extended = num_tokens_padded + (should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks( num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching, allow_dp_padding, + disable_padding_extend, + num_tokens_padded_extended, parallel_config, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index eaf9de3dae30..91439f1e809f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -483,7 +483,7 @@ def __init__( # Cudagraph dispatcher for runtime cudagraph dispatching. self.cudagraph_dispatcher = CudagraphDispatcher( - self.vllm_config, is_drafter=False, runner=self + self.vllm_config, is_drafter=False ) self.mm_budget = ( @@ -1046,7 +1046,7 @@ def _prepare_inputs( SpecDecodeMetadata | None, np.ndarray, CommonAttentionMetadata | None, - int, + BatchDescriptor, UBatchSlices | None, torch.Tensor | None, bool, @@ -1173,12 +1173,11 @@ def _prepare_inputs( query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens - num_tokens_padded = self._get_num_input_tokens( - num_tokens_unpadded, num_reqs, max_num_scheduled_tokens + num_tokens_padded, uniform_decode, uniform_query_len = ( + self._get_local_batch_description( + num_tokens_unpadded, num_reqs, max_num_scheduled_tokens + ) ) - uniform_decode = ( - max_num_scheduled_tokens == self.uniform_decode_query_len - ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) # Disable DP padding when running eager to avoid excessive padding when # running prefills. This lets us set enforce_eager on the prefiller in @@ -1186,6 +1185,21 @@ def _prepare_inputs( # decoder. allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + # For uniform decode batch with query length > 1, we may extend to non-uniform + # padding if there exists one dp rank that is non-uniform batch (i.e. can run + # into piecewise cudagraph), to resolve the conflicts of we may no have proper + # cudagraph for uniform batch after dp-padding. 
+ disable_padding_extend = not uniform_decode or uniform_query_len <= 1 + if ( + not disable_padding_extend + and num_tokens_padded < self.compilation_config.max_capture_size + ): + num_tokens_padded_extended = self.vllm_config.pad_for_cudagraph( + num_tokens_padded, uniform_aligned=False + ) + else: + num_tokens_padded_extended = num_tokens_padded + ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( num_tokens_unpadded=num_tokens_unpadded, parallel_config=self.parallel_config, @@ -1193,6 +1207,8 @@ def _prepare_inputs( allow_dp_padding=allow_dp_padding, num_tokens_padded=num_tokens_padded, uniform_decode=uniform_decode, + disable_padding_extend=disable_padding_extend, + num_tokens_padded_extended=num_tokens_padded_extended, num_scheduled_tokens_per_request=num_scheduled_tokens, ) @@ -1421,13 +1437,29 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) + dp_rank = self.parallel_config.data_parallel_rank + if ubatch_slices: + assert num_tokens_across_dp is not None + num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) + self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) + elif num_tokens_across_dp is not None: + num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) + else: + num_input_tokens = num_tokens_padded + + batch_descriptor = BatchDescriptor( + num_tokens=num_input_tokens, + uniform_decode=uniform_decode, + uniform_query_len=uniform_query_len, + ) + return ( attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens, spec_decode_common_attn_metadata, - max_num_scheduled_tokens, + batch_descriptor, ubatch_slices, num_tokens_across_dp, use_cascade_attn, @@ -2068,10 +2100,14 @@ def _pool( pooler_output=pooler_output, ) - def _get_num_input_tokens( + def _get_local_batch_description( self, num_scheduled_tokens: int, num_reqs: int, max_query_len: int - ) -> int: - return self.cudagraph_dispatcher.get_num_input_tokens_local( + ) -> tuple[int, bool, int]: + """ + Get local batch descriptions for before DP sync. 
+ returns (num_tokens_after_padding, uniform_decode, uniform_query_len) + """ + return self.cudagraph_dispatcher.get_local_batch_description( num_scheduled_tokens, num_reqs, max_query_len ) @@ -2081,6 +2117,7 @@ def _preprocess( num_input_tokens: int, # Padded intermediate_tensors: IntermediateTensors | None = None, ) -> tuple[ + int, torch.Tensor | None, torch.Tensor | None, torch.Tensor, @@ -2174,6 +2211,7 @@ def _preprocess( model_kwargs.update(encoder_inputs) return ( + num_scheduled_tokens, input_ids, inputs_embeds, positions, @@ -2425,44 +2463,23 @@ def execute_model( spec_decode_metadata, num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len, + batch_descriptor, ubatch_slices, num_tokens_across_dp, use_cascade_attn, ) = self._prepare_inputs(scheduler_output) - dp_rank = self.parallel_config.data_parallel_rank - if ubatch_slices: - assert num_tokens_across_dp is not None - num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) - self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) - elif num_tokens_across_dp is not None: - num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) - else: - num_input_tokens = self._get_num_input_tokens( - scheduler_output.total_num_scheduled_tokens, - self.input_batch.num_reqs, - max_query_len, - ) - - # Cudagraph dispatcher planing + padding - # For ubatching, we still use original total num_scheduled_tokens - # for planing, as gpu_ubatch_wrapper awares only the - # cudagraph_runtime_mode but not batch_descriptor. - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + # cudagraph dispatching ( cudagraph_runtime_mode, batch_descriptor, - num_input_tokens, - num_tokens_across_dp, - ) = self.cudagraph_dispatcher.plan( - num_scheduled_tokens=num_scheduled_tokens, - num_reqs=self.input_batch.num_reqs, - max_query_len=max_query_len, - use_cascade_attn=use_cascade_attn, + ) = self.cudagraph_dispatcher.dispatch( + batch_descriptor, + use_cascade_attn, ) - + num_input_tokens = batch_descriptor.num_tokens ( + num_scheduled_tokens, input_ids, inputs_embeds, positions, @@ -3286,6 +3303,8 @@ def _dummy_run( allow_dp_padding=allow_dp_padding, num_tokens_padded=total_num_scheduled_tokens, uniform_decode=uniform_decode, + disable_padding_extend=True, + num_tokens_padded_extended=total_num_scheduled_tokens, num_scheduled_tokens_per_request=num_scheduled_tokens, ) num_tokens_after_padding = num_tokens @@ -3413,7 +3432,7 @@ def _dummy_run( _cg_mode, batch_descriptor = ( self.cudagraph_dispatcher.dispatch( BatchDescriptor( - num_tokens=num_tokens, + num_tokens=num_tokens_after_padding, uniform_decode=uniform_decode, uniform_query_len=uniform_query_len, ) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 715abb81fefb..d6d3735bbe18 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -466,7 +466,7 @@ def execute_model( if num_scheduled_tokens > 0 else 0 ) - num_input_tokens = self.model_runner._get_num_input_tokens( + num_input_tokens, _, _ = self.model_runner._get_local_batch_description( num_scheduled_tokens, num_reqs, max_query_len ) all_gather_tensors = { From a4872bc6c7d8b7982f0fc1992f4c6d8673a11519 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Wed, 15 Oct 2025 05:38:37 +0000 Subject: [PATCH 16/19] clean up Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config/compilation.py | 48 +++++++++++++----------------- vllm/config/vllm.py | 6 +++- vllm/v1/cudagraph_dispatcher.py | 28 
++++++++++++++--------------
 vllm/v1/worker/dp_utils.py         | 12 ++++----
 vllm/v1/worker/gpu_model_runner.py | 34 +++++++++++++++------
 5 files changed, 71 insertions(+), 57 deletions(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index a5e5c420caca..195ba38b1f90 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -371,6 +371,27 @@ class CompilationConfig:
     max_capture_size: int = field(default=None, init=False)  # type: ignore
     """not configurable, computed after init"""

+    disable_cudagraph_uniform_alignment: bool = False
+    """Whether to disable uniform alignment of cudagraph capture sizes for
+    uniform-decode batches with query length > 1 (i.e., for spec-decode). This
+    flag only takes effect when cudagraph_mode is FULL_DECODE_ONLY or
+    FULL_AND_PIECEWISE.
+
+    Uniform alignment makes sure all capture sizes for uniform-decode batches
+    are multiples of 1+num_speculative_tokens. This alignment is typically
+    useful for padded speculation (see #21984 for details), and is needed by
+    some attention backends to reach their best performance, since they support
+    uniform-decode but not in a variable-length fashion. However, there is a
+    trade-off: while the alignment is good for the attention layers, it may
+    introduce slight regressions in other layers if the aligned sizes do not
+    hit a multiple of 8.
+
+    Note: for DP_size>1, the uniformity of sizes may be broken after the
+    dp_padding sync. Therefore, we only guarantee running the full cudagraph of
+    a uniform-decode batch on the current rank if all dp ranks have
+    uniform-decode batches. Otherwise, it falls back to piecewise cudagraphs,
+    where the uniformity of the batch before padding can still be exploited by
+    the attention layers under eager execution.
+    """
     uniform_cudagraph_capture_sizes: list[int] | None = None
     """
     List for capture sizes for uniform decode for the main model. Its elements
@@ -709,15 +730,6 @@ def init_with_cudagraph_sizes(
             else 0
         )

-        self.uniform_cudagraph_capture_sizes = sorted(
-            uniform_cudagraph_capture_sizes, reverse=True
-        )
-        self.max_uniform_capture_size = (
-            self.uniform_cudagraph_capture_sizes[0]
-            if self.uniform_cudagraph_capture_sizes
-            else 0
-        )
-
         # pre-compute the mapping from batch size to padded graph size
         self.bs_to_padded_graph_size = [0 for i in range(self.max_capture_size + 1)]
         for end, start in zip(
@@ -748,24 +760,6 @@ def init_with_cudagraph_sizes(
             self.max_uniform_capture_size
         )

-        # pre-compute the mapping for uniform decode padding.
-        self.bs_to_padded_graph_size_uniform = [
-            0 for i in range(self.max_uniform_capture_size + 1)
-        ]
-
-        for end, start in zip(
-            self.uniform_cudagraph_capture_sizes,
-            self.uniform_cudagraph_capture_sizes[1:] + [0],
-        ):
-            for bs in range(start, end):
-                if bs == start:
-                    self.bs_to_padded_graph_size_uniform[bs] = start
-                else:
-                    self.bs_to_padded_graph_size_uniform[bs] = end
-        self.bs_to_padded_graph_size_uniform[self.max_uniform_capture_size] = (
-            self.max_uniform_capture_size
-        )
-
     def set_splitting_ops_for_v1(self):
         # NOTE: this function needs to be called only when level is
         # CompilationLevel.PIECEWISE
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 5bff4644cbae..bab5b006fd7c 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -209,7 +209,8 @@ def pad_for_cudagraph(self, batch_size: int, uniform_aligned: bool = False) -> i
         For drafter, caller should make sure uniform_aligned is False because
         drafter's uniform_decode_len is 1.
""" - + if self.compilation_config.disable_cudagraph_uniform_alignment: + uniform_aligned = False # if batch_size > self.compilation_config.max_capture_size when # uniform_aligned is False, or batch_size > self.compilation_config. # max_uniform_capture_size when uniform_aligned is True, @@ -714,6 +715,9 @@ def _set_cudagraph_sizes(self): if size >= uniform_decode_len ) ) + if self.compilation_config.disable_cudagraph_uniform_alignment: + uniform_batch_size_capture_list = batch_size_capture_list + if ( self.parallel_config.tensor_parallel_size > 1 and self.compilation_config.pass_config.enable_sequence_parallelism diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index f45234eadbb1..ed6fe01ec4c9 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -62,10 +62,6 @@ def __init__(self, vllm_config: VllmConfig, is_drafter: bool = False): self.keys_initialized = False - # For storing temp variables - self._uniform_decode: bool = False - self._uniform_query_len: int = 0 - def add_cudagraph_key( self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor ): @@ -94,15 +90,16 @@ def initialize_cudagraph_keys( # we only have compilation_config.uniform_cudagraph_capture_sizes # being aligned with one uniform_query_len that greater than 1, not # multiple of them. Should verify this here. - for uniform_query_len in self.uniform_query_lens: - if ( - uniform_query_len > 1 - and self.compilation_config.uniform_cudagraph_capture_sizes - ): - assert all( - x % uniform_query_len == 0 - for x in self.compilation_config.uniform_cudagraph_capture_sizes - ), f"Invalid uniform_query_lens: {uniform_query_len}" + if not self.compilation_config.disable_cudagraph_uniform_alignment: + for uniform_query_len in self.uniform_query_lens: + if ( + uniform_query_len > 1 + and self.compilation_config.uniform_cudagraph_capture_sizes + ): + assert all( + x % uniform_query_len == 0 + for x in self.compilation_config.uniform_cudagraph_capture_sizes + ), f"Invalid uniform_query_lens: {uniform_query_len}" if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: for bs in self.compilation_config.cudagraph_capture_sizes: @@ -130,7 +127,10 @@ def initialize_cudagraph_keys( # sizes. 
candidate_sizes = ( self.compilation_config.cudagraph_capture_sizes - if uniform_query_len == 1 + if ( + uniform_query_len == 1 + or self.compilation_config.disable_cudagraph_uniform_alignment + ) else self.compilation_config.uniform_cudagraph_capture_sizes ) cudagraph_capture_sizes_for_decode = [ diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py index 034f324b027d..cee0ef648968 100644 --- a/vllm/v1/worker/dp_utils.py +++ b/vllm/v1/worker/dp_utils.py @@ -102,7 +102,7 @@ def _synchronize_dp_ranks( num_tokens_padded: int, should_attempt_ubatching: bool, should_attempt_dp_padding: bool, - disable_padding_extend: bool, + try_disable_padding_extend: bool, num_tokens_padded_extended: int, parallel_config: ParallelConfig, ) -> tuple[bool, torch.Tensor | None]: @@ -131,7 +131,7 @@ def _synchronize_dp_ranks( should_dp_pad=should_attempt_dp_padding, orig_num_tokens_per_ubatch=num_tokens_unpadded, padded_num_tokens_per_ubatch=num_tokens_padded, - disable_padding_extend=disable_padding_extend, + disable_padding_extend=try_disable_padding_extend, num_tokens_padded_extended=num_tokens_padded_extended, parallel_config=parallel_config, ) @@ -170,7 +170,7 @@ def coordinate_batch_across_dp( parallel_config: ParallelConfig, num_tokens_padded: int | None = None, uniform_decode: bool | None = None, - disable_padding_extend: bool = True, + try_disable_padding_extend: bool = True, num_tokens_padded_extended: int | None = None, num_scheduled_tokens_per_request: np.ndarray | None = None, ) -> tuple[UBatchSlices | None, torch.Tensor | None]: @@ -187,8 +187,8 @@ def coordinate_batch_across_dp( TP, etc) uniform_decode: Used when allow_microbatching is True and/or when it is uniform decoding for spec-decode. - disable_padding_extend: If it is True across all dp rank, we do not extend the - padding from uniform-decode batch to non-uniform batch. + try_disable_padding_extend: If it is True across all dp rank, we do not extend + the padding to the max value of num_tokens_padded_extended across dp ranks. num_tokens_padded_extended: the number of tokens after extending the padding. num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The number of tokens per request. @@ -229,7 +229,7 @@ def coordinate_batch_across_dp( num_tokens_padded, should_attempt_ubatching, allow_dp_padding, - disable_padding_extend, + try_disable_padding_extend, num_tokens_padded_extended, parallel_config, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 91439f1e809f..e1f4cbb4dc45 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1187,9 +1187,15 @@ def _prepare_inputs( # For uniform decode batch with query length > 1, we may extend to non-uniform # padding if there exists one dp rank that is non-uniform batch (i.e. can run - # into piecewise cudagraph), to resolve the conflicts of we may no have proper - # cudagraph for uniform batch after dp-padding. - disable_padding_extend = not uniform_decode or uniform_query_len <= 1 + # into piecewise cudagraph), to resolve the conflicts where we may no have + # cudagraph for uniform-batch after dp-padding. 
+        num_tokens_padded_extended = num_tokens_padded
+        disable_padding_extend = (
+            self.compilation_config.disable_cudagraph_uniform_alignment
+            or not self.compilation_config.cudagraph_mode.separate_routine()
+            or not uniform_decode
+            or uniform_query_len <= 1
+        )
         if (
             not disable_padding_extend
             and num_tokens_padded < self.compilation_config.max_capture_size
@@ -1197,8 +1203,6 @@ def _prepare_inputs(
             num_tokens_padded_extended = self.vllm_config.pad_for_cudagraph(
                 num_tokens_padded, uniform_aligned=False
             )
-        else:
-            num_tokens_padded_extended = num_tokens_padded
 
         ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
             num_tokens_unpadded=num_tokens_unpadded,
@@ -1207,7 +1211,7 @@ def _prepare_inputs(
             allow_dp_padding=allow_dp_padding,
             num_tokens_padded=num_tokens_padded,
             uniform_decode=uniform_decode,
-            disable_padding_extend=disable_padding_extend,
+            try_disable_padding_extend=disable_padding_extend,
             num_tokens_padded_extended=num_tokens_padded_extended,
             num_scheduled_tokens_per_request=num_scheduled_tokens,
         )
@@ -3254,7 +3258,8 @@ def _dummy_run(
         # routine of FA2 for pure decode, i.e., Flashdecode + an optimization
         # for GQA/MQA.
         max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens
-        if uniform_decode:
+        if uniform_decode and uniform_query_len:
+            # allow skipping this assertion for a dummy execution in a DP setup
             assert max_query_len == uniform_query_len
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
@@ -3294,6 +3299,17 @@ def _dummy_run(
         # Disable DP padding when running eager
         allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
 
+        # make sure a uniform-decode batch can safely hit a cudagraph when this is
+        # a dummy execution for DP size > 1.
+        num_tokens_padded_extended = total_num_scheduled_tokens
+        if (
+            total_num_scheduled_tokens
+            < self.vllm_config.compilation_config.max_capture_size
+        ):
+            num_tokens_padded_extended = self.vllm_config.pad_for_cudagraph(
+                total_num_scheduled_tokens, uniform_aligned=False
+            )
+
         # We currently only microbatch if the number of tokens is
         # over a certain threshold.
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( @@ -3303,8 +3319,8 @@ def _dummy_run( allow_dp_padding=allow_dp_padding, num_tokens_padded=total_num_scheduled_tokens, uniform_decode=uniform_decode, - disable_padding_extend=True, - num_tokens_padded_extended=total_num_scheduled_tokens, + try_disable_padding_extend=True, + num_tokens_padded_extended=num_tokens_padded_extended, num_scheduled_tokens_per_request=num_scheduled_tokens, ) num_tokens_after_padding = num_tokens From c18486a5d43b219a47ff6237f2b5e14e01cd8df5 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sat, 1 Nov 2025 12:57:45 +0000 Subject: [PATCH 17/19] refactor eagle dummy run Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/spec_decode/eagle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 331a916ab3b7..c2a9d868fde9 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1047,7 +1047,7 @@ def dummy_run( force_attention: bool = False, uniform_decode: bool = False, uniform_query_len: int = 0, - allow_microbatching: bool = False, # unused for drafter + **other_kwargs, # unused but may get passed from caller ) -> None: assert cudagraph_runtime_mode != CUDAGraphMode.FULL, ( "Eagle drafter doesn't support full cudagraphs at this moment" @@ -1108,6 +1108,7 @@ def dummy_run( num_tokens=num_tokens, uniform_decode=uniform_decode, uniform_query_len=uniform_query_len, + has_lora=False, ) # sanity check _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( From 299ce7dace094b24b6bb52ed4de29372dcd796d7 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Tue, 4 Nov 2025 03:15:26 +0000 Subject: [PATCH 18/19] fix drafter when enforce_eager Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/spec_decode/eagle.py | 4 ++- vllm/v1/worker/gpu_model_runner.py | 53 ++++++++---------------------- 2 files changed, 17 insertions(+), 40 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index dec677e711fb..55343bf3b748 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -86,7 +86,6 @@ def __init__( and not self.speculative_config.enforce_eager ) - self.use_cuda_graph = self.use_cuda_graph and bool(self.cudagraph_batch_sizes) # persistent buffers for cuda graph self.input_ids = torch.zeros( self.max_num_tokens, dtype=torch.int32, device=device @@ -1053,6 +1052,9 @@ def dummy_run( assert cudagraph_runtime_mode != CUDAGraphMode.FULL, ( "Eagle drafter doesn't support full cudagraphs at this moment" ) + # overwrite runtime mode to NONE if enforce_eager + if self.speculative_config.enforce_eager: + cudagraph_runtime_mode = CUDAGraphMode.NONE max_query_len = uniform_query_len if uniform_decode else num_tokens diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 665ab91c33c6..7002dd54ddbd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3601,11 +3601,8 @@ def _dummy_run( and self.speculative_config.use_eagle() ): assert isinstance(self.drafter, EagleProposer) - use_cudagraphs = ( - cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE - and not self.speculative_config.enforce_eager - ) - self.drafter.dummy_run(num_tokens, use_cudagraphs=use_cudagraphs) + # no cudagraph for profile run + self.drafter.dummy_run(num_tokens) # This is necessary to avoid 
blocking DP.
         # For dummy runs, we typically skip EPLB since we don't have any real
@@ -3950,9 +3947,14 @@ def freeze_gc():
             # Note: Currently only PIECEWISE mode is supported for eagle
             # drafter.
             # TODO: add full cudagraph support for drafter.
-            if self.speculative_config and self.speculative_config.use_eagle():
+            if (
+                self.speculative_config
+                and self.speculative_config.use_eagle()
+                and not self.speculative_config.enforce_eager
+            ):
                 assert isinstance(self.drafter, EagleProposer)
                 logger.info("Start capturing cudagraphs for drafter...")
+                # when not enforce_eager, the drafter shares the same cudagraph_mode
                 if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
                     capture_sizes, keys, runtime_mode = (
                         self.drafter.cudagraph_dispatcher.get_capture_cases(
@@ -4312,43 +4314,16 @@ def _check_and_update_cudagraph_mode(
         )
 
         # At this moment, we assume the drafter and main model shares the
-        # same cudagraph_mode
+        # same cudagraph_mode if not speculative_config.enforce_eager
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             assert not cudagraph_mode.has_full_cudagraphs(), (
                 "Eagle drafter does not support full cudagraphs yet"
             )
-            # uniform_query_len is 1 for drafter
-            # TODO: let uniform_query_lens = [1, self.uniform_decode_query_len]
-            # for drafter once Padded speculation is supported. See:
-            # https://github.com/vllm-project/vllm/issues/21984 for details
-            # and an implementation in https://github.com/vllm-project/vllm/pull/24539 # noqa: E501
-            self.drafter.cudagraph_dispatcher.initialize_cudagraph_keys(
-                self.compilation_config.cudagraph_mode, uniform_query_lens=1
-            )
-
-        # At this moment, we assume the drafter and main model shares the
-        # same cudagraph_mode
-        if self.speculative_config and self.speculative_config.use_eagle():
-            assert isinstance(self.drafter, EagleProposer)
-            assert not cudagraph_mode.has_full_cudagraphs(), (
-                "Eagle drafter does not support full cudagraphs yet"
-            )
-            # uniform_query_len is 1 for drafter
-            # TODO: let uniform_query_lens = [1, self.uniform_decode_query_len]
-            # for drafter once Padded speculation is supported.
See: - # https://github.com/vllm-project/vllm/issues/21984 for details - # and an implementation in https://github.com/vllm-project/vllm/pull/24539 # noqa: E501 - self.drafter.cudagraph_dispatcher.initialize_cudagraph_keys( - self.compilation_config.cudagraph_mode, uniform_query_lens=1 - ) - - # At this moment, we assume the drafter and main model shares the - # same cudagraph_mode - if self.speculative_config and self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) - assert not cudagraph_mode.has_full_cudagraphs(), ( - "Eagle drafter does not support full cudagraphs yet" + cudagraph_mode = ( + self.compilation_config.cudagraph_mode + if not self.speculative_config.enforce_eager + else CUDAGraphMode.NONE ) # uniform_query_len is 1 for drafter # TODO: let uniform_query_lens = [1, self.uniform_decode_query_len] @@ -4356,7 +4331,7 @@ def _check_and_update_cudagraph_mode( # https://github.com/vllm-project/vllm/issues/21984 for details # and an implementation in https://github.com/vllm-project/vllm/pull/24539 # noqa: E501 self.drafter.cudagraph_dispatcher.initialize_cudagraph_keys( - self.compilation_config.cudagraph_mode, uniform_query_lens=1 + cudagraph_mode, uniform_query_lens=1 ) def calculate_reorder_batch_threshold(self) -> None: From b5c315a7addf17a5bc9cd2360f6511a88b867fd0 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Tue, 4 Nov 2025 03:38:35 +0000 Subject: [PATCH 19/19] fix pre-commit Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/attention/backends/mamba_attn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index c92f1d0ab288..49d7d6c31b9a 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -11,6 +11,7 @@ from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, + CommonAttentionMetadata, ) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
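
As a rough illustration of the padding-selection behavior these patches rely on, here is a minimal standalone Python sketch of choosing a padded cudagraph size from either the general capture-size list or a uniform-decode-aligned list, with a flag that disables the alignment. CapturePlan and its fields are hypothetical stand-ins, not vLLM classes; the numbers are made up for the example.

# Minimal standalone sketch, not vLLM code: CapturePlan is a hypothetical
# stand-in for the compilation config discussed in the patches above.
from dataclasses import dataclass


@dataclass
class CapturePlan:
    # general capture sizes (analogous to cudagraph_capture_sizes)
    capture_sizes: list[int]
    # uniform-decode capture sizes; each entry stays divisible by
    # uniform_query_len, matching the dispatcher assertion above
    uniform_capture_sizes: list[int]
    uniform_query_len: int = 1
    disable_uniform_alignment: bool = False

    def pad_for_cudagraph(self, batch_size: int, uniform_aligned: bool) -> int:
        # when alignment is disabled, always fall back to the general list
        if self.disable_uniform_alignment:
            uniform_aligned = False
        sizes = self.uniform_capture_sizes if uniform_aligned else self.capture_sizes
        # smallest captured size that can hold the batch; callers must stay
        # within the largest capture size, otherwise this raises
        for size in sorted(sizes):
            if size >= batch_size:
                return size
        raise IndexError(f"batch_size {batch_size} exceeds max capture size")


plan = CapturePlan(
    capture_sizes=[1, 2, 4, 8, 16, 32],
    uniform_capture_sizes=[3, 6, 15, 30],  # multiples of uniform_query_len=3
    uniform_query_len=3,
)
assert plan.pad_for_cudagraph(5, uniform_aligned=True) == 6   # uniform-aligned pad
assert plan.pad_for_cudagraph(5, uniform_aligned=False) == 8  # general pad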