From a012257766df8b468d498ed83fece20387e14bc8 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sat, 16 Aug 2025 05:31:31 +0000 Subject: [PATCH 01/21] renaming to piecewise_backend.py Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/compilation/backends.py | 2 +- .../{cuda_piecewise_backend.py => piecewise_backend.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename vllm/compilation/{cuda_piecewise_backend.py => piecewise_backend.py} (100%) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 059e7a3b2976..faac01fd1f2f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -338,7 +338,7 @@ def call_module(self, target: torch.fx.node.Target, runtime_shape=None) # Lazy import here to avoid circular import from .cuda_graph import CUDAGraphOptions - from .cuda_piecewise_backend import PiecewiseBackend + from .piecewise_backend import PiecewiseBackend piecewise_backend = PiecewiseBackend( submod, self.vllm_config, index, diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/piecewise_backend.py similarity index 100% rename from vllm/compilation/cuda_piecewise_backend.py rename to vllm/compilation/piecewise_backend.py From c37583ca21c4eb64323acf7f1455200ea735fd4a Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sun, 17 Aug 2025 04:47:19 +0000 Subject: [PATCH 02/21] dispatch cascade attention to NONE or PIECEWISE runtime mode;clean up comments; default empty splitting_ops when enable_attn_fusion Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config/__init__.py | 18 ++++++++---- vllm/config/compilation.py | 33 +++++++++++++++++++--- vllm/forward_context.py | 3 +- vllm/v1/cudagraph_dispatcher.py | 44 ++++++++++++++++-------------- vllm/v1/worker/gpu_model_runner.py | 33 ++++++++++++++-------- 5 files changed, 88 insertions(+), 43 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 280ae60c91ff..b211f86737c0 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -3619,13 +3619,21 @@ def __post_init__(self): # final check of cudagraph mode after platform-specific update if envs.VLLM_USE_V1: - if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \ + if self.compilation_config.cudagraph_mode.has_full_cudagraphs()\ and self.model_config is not None and \ not self.model_config.disable_cascade_attn: - logger.info("CUDAGraphMode.FULL is not supported with " - "cascade attention currently. Disabling cascade" - "attention.") - self.model_config.disable_cascade_attn = True + warn_msg = ("Cascade attention is not supported with full " + "cudagraphs currently. 
") + if self.compilation_config.cudagraph_mode.\ + has_piecewise_cudagraphs(): + logger.warning_once( + warn_msg + "It will dispatched to " + "piecewise cudagraphs if a batch runs into cascade " + "attentions") + else: + logger.warning_once( + warn_msg + "It will fallback to eager execution if a " + "batch runs into cascade attentions") if self.compilation_config.cudagraph_mode\ .requires_piecewise_compilation(): diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 56a2183f8e2c..e7ca5a3e5d52 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -62,9 +62,17 @@ def max_cudagraph_mode(self) -> 'CUDAGraphMode': def has_full_cudagraphs(self) -> bool: return self.max_cudagraph_mode() == CUDAGraphMode.FULL + def has_piecewise_cudagraphs(self) -> bool: + return self.requires_piecewise_compilation() + def separate_routine(self) -> bool: return isinstance(self.value, tuple) + def vaild_runtime_modes(self) -> bool: + return self in [ + CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL + ] + @config @dataclass @@ -544,20 +552,37 @@ def set_splitting_ops_for_v1(self): # full cudagraph outside the fx graph. This reduces some cpu # overhead when the runtime batch_size is not cudagraph captured. # see https://github.com/vllm-project/vllm/pull/20059 for details. - self.splitting_ops = self._attention_ops + if self.pass_config.enable_attn_fusion: + self.splitting_ops = [] + if self.cudagraph_mode.has_piecewise_cudagraphs(): + logger.warning_once( + "When enable_attn_fusion, splitting_ops will be set " + "to empty list, and cudagraph_mode containing " + "PIECEWISE will be treated as FULL cudagraph_mode. " + "Please ensure you are using attention backends that " + "support cudagraph or set cudagraph_mode to NONE " + "explicitly if encountering any problems.") + self.cudagraph_mode = CUDAGraphMode.FULL + else: + self.splitting_ops = self._attention_ops elif len(self.splitting_ops) == 0: logger.warning_once("Using piecewise compilation with empty " "splitting_ops.") - if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: + if self.cudagraph_mode.has_piecewise_cudagraphs(): logger.warning_once( "When compilation level is piecewise with empty " - "splitting_ops, PIECEWISE cudagraph_mode will be " - "treated as FULL cudagraph_mode. Please ensure you are " + "splitting_ops, cudagraph_mode containing PIECEWISE will " + "be treated as FULL cudagraph_mode. 
Please ensure you are " "using attention backends that support cudagraph or set " "cudagraph_mode to NONE explicitly if encountering " "any problems.") self.cudagraph_mode = CUDAGraphMode.FULL self.splitting_ops = [] + else: # len(self.splitting_ops) > 0: + assert not self.pass_config.enable_attn_fusion or \ + not self.splitting_ops_contain_attention(), ( + "attention ops should not be in splitting_ops " + "when enable_attn_fusion is True") def splitting_ops_contain_attention(self) -> bool: return self.splitting_ops is not None and all( diff --git a/vllm/forward_context.py b/vllm/forward_context.py index c57c51d289ac..bcc54fcdf687 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -179,8 +179,7 @@ class ForwardContext: batch_descriptor: Optional[BatchDescriptor] = None def __post_init__(self): - assert self.cudagraph_runtime_mode in [ - CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \ + assert self.cudagraph_runtime_mode.vaild_runtime_modes(), \ f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}" diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 02e65820b7c0..7affc1f3af58 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -11,7 +11,8 @@ class CudagraphDispatcher: """ - Runtime cudagraph dispatcher to dispach keys for multiple set of cudagraphs. + Runtime cudagraph dispatcher to dispatch keys for multiple sets of + cudagraphs. The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one for FULL cudagraph runtime mode. The keys are initialized depending on @@ -21,10 +22,10 @@ class CudagraphDispatcher: At runtime, the dispatch method generates the runtime cudagraph mode (FULL, PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor) - based on the input key. After dispatching (commuicate via forward context), - the cudagraph wrappers will trust the dispatch key to do either capturing - or replaying (if mode matched), or pass through to the underlying runnable - without cudagraph (if mode no match or mode is NONE). + based on the input key. After dispatching (communicated via forward + context), the cudagraph wrappers will trust the dispatch key to either + capture or replay (if the mode matches), or pass through to the underlying + runnable without cudagraph (if the mode does not match or mode is NONE). """ def __init__(self, vllm_config: VllmConfig): @@ -52,19 +53,15 @@ def __init__(self, vllm_config: VllmConfig): def add_cudagraph_key(self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor): assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \ - f"Invalid cudagraph runtime mode: {runtime_mode}" + f"Invalid cudagraph runtime mode for keys: {runtime_mode}" self.cudagraph_keys[runtime_mode].add(batch_descriptor) def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int): # This should be called only after attention backend is initialized. - # Note: we create all valid keys possible for cudagraph but do not - # guarantee all keys would be used. For example, we create keys for - # piecewise cudagraphs when it is piecewise compilation, which is always - # valid, but for attention backend support unified routine, we may not - # trigger capturing/replaying the piecewise cudagraphs depending on - # CompilationConfig.cudagraph_mode. In addition, if we allow lazy + # Note: we create all valid keys for cudagraph here but do not + # guarantee all keys would be used. 
For example, if we allow lazy # capturing in future PR, some keys may never be triggered. if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: for bs in self.compilation_config.cudagraph_capture_sizes: @@ -89,10 +86,13 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, self.keys_initialized = True def dispatch( - self, batch_descriptor: BatchDescriptor + self, + batch_descriptor: BatchDescriptor, + use_cascade_attn: bool = False ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]: """ - Given a batch descriptor, dispatch to a cudagraph mode. + Given conditions(e.g.,batch descriptor and if using cascade attention), + dispatch to a cudagraph runtime mode and the valid batch descriptor. A new batch descriptor is returned as we might dispatch a uniform batch to a graph that supports a more general batch (uniform to non-uniform). """ @@ -102,14 +102,16 @@ def dispatch( "initialized. No cudagraph will be used.") return CUDAGraphMode.NONE, None - # check if key exists for full cudagraph - if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]: - return CUDAGraphMode.FULL, batch_descriptor - - # otherwise, check if non-uniform key exists non_uniform_key = batch_descriptor.non_uniform - if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]: - return CUDAGraphMode.FULL, non_uniform_key + # if a batch use cascade attention, bypass checking full cudagraphs + if not use_cascade_attn: + # check if key exists for full cudagraph + if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]: + return CUDAGraphMode.FULL, batch_descriptor + + # otherwise, check if non-uniform key exists + if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]: + return CUDAGraphMode.FULL, non_uniform_key # also check if non-uniform key exists for more "general" # piecewise cudagraph diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4c919b392fbd..8b0c82f2a6fe 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -684,11 +684,13 @@ def _prepare_inputs( self, scheduler_output: "SchedulerOutput", ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata], - np.ndarray, Optional[CommonAttentionMetadata], int]: + np.ndarray, Optional[CommonAttentionMetadata], int, bool]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, - logits_indices, spec_decode_metadata + logits_indices, spec_decode_metadata, + num_scheduled_tokens, spec_decode_common_attn_metadata, + max_num_scheduled_tokens, use_cascade_attn ] """ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -835,6 +837,7 @@ def _prepare_inputs( ) attn_metadata: dict[str, Any] = {} + use_cascade_attn = False # Prepare encoder attention metadata separately # (encoder layers are not in KV cache groups) @@ -903,6 +906,8 @@ def _prepare_inputs( common_attn_metadata=common_attn_metadata, )) + use_cascade_attn |= attn_metadata_i.use_cascade + fast_prefill_metadata = attn_metadata_i if (self.cache_config.kv_sharing_fast_prefill and self.kv_sharing_fast_prefill_eligible_layers): @@ -933,7 +938,7 @@ def _prepare_inputs( return (attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens, spec_decode_common_attn_metadata, - max_num_scheduled_tokens) + max_num_scheduled_tokens, use_cascade_attn) def _compute_cascade_attn_prefix_len( self, @@ -1512,7 +1517,8 @@ def execute_model( # Prepare the decoder inputs. 
(attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len) = (self._prepare_inputs(scheduler_output)) + max_query_len, + use_cascade_attn) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE @@ -1588,7 +1594,8 @@ def execute_model( batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=uniform_decode) cudagraph_runtime_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch(batch_descriptor) + self.cudagraph_dispatcher.dispatch(batch_descriptor, + use_cascade_attn) # Run the model. # Use persistent buffers for CUDA graphs. @@ -2248,9 +2255,7 @@ def _dummy_run( skip_eplb: If True, skip EPLB state update. is_profile: If True, this is a profile run. """ - assert cudagraph_runtime_mode in { - CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL - } + assert cudagraph_runtime_mode.vaild_runtime_modes() # Padding for DP num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) @@ -2704,9 +2709,9 @@ def freeze_gc(): def _capture_cudagraphs(self, compilation_cases: list[int], cudagraph_runtime_mode: CUDAGraphMode, uniform_decode: bool): - assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ - cudagraph_runtime_mode in [CUDAGraphMode.FULL, - CUDAGraphMode.PIECEWISE] + assert cudagraph_runtime_mode in [CUDAGraphMode.FULL, + CUDAGraphMode.PIECEWISE],\ + f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" # Only rank 0 should print progress bar during capture if is_global_first_rank(): @@ -2848,6 +2853,12 @@ def create_attn_groups( self.is_encoder_only_model = True def initialize_cudagraph_capture(self) -> None: + """ + Resolve the cudagraph_mode when there are multiple + attention backends with conflicting CUDA graph support. + Initialize the cudagraph_dispatcher based on the resolved + cudagraph_mode. 
+ """ min_cg_support = AttentionCGSupport.ALWAYS min_cg_builder_name = None From 648fbb3b48717f9a6e1073e9488c4676756bf7e7 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sun, 17 Aug 2025 08:08:37 +0000 Subject: [PATCH 03/21] apply suggestion from bot Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config/compilation.py | 2 +- vllm/forward_context.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index e7ca5a3e5d52..360c71477b51 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -68,7 +68,7 @@ def has_piecewise_cudagraphs(self) -> bool: def separate_routine(self) -> bool: return isinstance(self.value, tuple) - def vaild_runtime_modes(self) -> bool: + def valid_runtime_modes(self) -> bool: return self in [ CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL ] diff --git a/vllm/forward_context.py b/vllm/forward_context.py index bcc54fcdf687..6812e6460cb4 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -179,7 +179,7 @@ class ForwardContext: batch_descriptor: Optional[BatchDescriptor] = None def __post_init__(self): - assert self.cudagraph_runtime_mode.vaild_runtime_modes(), \ + assert self.cudagraph_runtime_mode.valid_runtime_modes(), \ f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c1b0697de1e5..cb211971c3e6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2260,7 +2260,7 @@ def _dummy_run( skip_eplb: If True, skip EPLB state update. is_profile: If True, this is a profile run. 
""" - assert cudagraph_runtime_mode.vaild_runtime_modes() + assert cudagraph_runtime_mode.valid_runtime_modes() # Padding for DP num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) From 868c85fcb436e1693aad2bbf13479ed6a3819b01 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sun, 17 Aug 2025 11:18:26 +0000 Subject: [PATCH 04/21] fix bug when attn_metadata have no use_cascade Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/worker/gpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cb211971c3e6..a8ad25e45c80 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -911,7 +911,8 @@ def _prepare_inputs( common_attn_metadata=common_attn_metadata, )) - use_cascade_attn |= attn_metadata_i.use_cascade + use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", + False) fast_prefill_metadata = attn_metadata_i if (self.cache_config.kv_sharing_fast_prefill From 3bef9e4b7bef8705e3d81b5f3130f53621b2a749 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sun, 17 Aug 2025 12:39:34 +0000 Subject: [PATCH 05/21] simple dispatching test Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- tests/v1/cudagraph/test_cudagraph_dispatch.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index 64f2fa462802..baa4a3df3d12 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -131,6 +131,16 @@ def test_dispatcher(self, params): assert rt_mode == CUDAGraphMode.NONE assert key is None + # 4. Cascade attention should have a fall back mode + desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False) + rt_mode, key = dispatcher.dispatch(desc_full_exact, + use_cascade_attn=True) + if "PIECEWISE" in params["cudagraph_mode"]: # string contains check + assert rt_mode == CUDAGraphMode.PIECEWISE + assert key == desc_full_exact.non_uniform + else: + assert rt_mode == CUDAGraphMode.NONE + @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") class TestCUDAGraphWrapper: From 05cf0122d48029557b846674eb33a2eb216ae714 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sun, 17 Aug 2025 12:57:06 +0000 Subject: [PATCH 06/21] minor comment tweak Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/worker/gpu_model_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a8ad25e45c80..d4cce16f03f8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2860,9 +2860,9 @@ def create_attn_groups( def initialize_cudagraph_capture(self) -> None: """ - Resolve the cudagraph_mode when there are multiple - attention backends with conflicting CUDA graph support. - Initialize the cudagraph_dispatcher based on the resolved + Resolve the cudagraph_mode when there are multiple attention + backends with potential conflicting CUDA graph support. + Then initialize the cudagraph_dispatcher based on the resolved cudagraph_mode. 
""" min_cg_support = AttentionCGSupport.ALWAYS From 20d8afb01087ab46b7dca158e6aba3a7acc774e3 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Fri, 22 Aug 2025 16:51:01 +0000 Subject: [PATCH 07/21] address comments part1 Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config/__init__.py | 32 +++++++++++++------------------- vllm/config/compilation.py | 24 +++++++++++++++--------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 52c1545743e5..8f53d45f7c5d 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -3624,23 +3624,22 @@ def __post_init__(self): "to True to enable.") current_platform.check_and_update_config(self) - # final check of cudagraph mode after platform-specific update + # Do this after all the updates to compilation_config.level + if envs.VLLM_USE_V1 and \ + self.compilation_config.level == CompilationLevel.PIECEWISE: + self.compilation_config.set_splitting_ops_for_v1() + + # final check of cudagraph mode after all possible updates if envs.VLLM_USE_V1 and current_platform.is_cuda_alike(): if self.compilation_config.cudagraph_mode.has_full_cudagraphs()\ and self.model_config is not None and \ - not self.model_config.disable_cascade_attn: - warn_msg = ("Cascade attention is not supported with full " - "cudagraphs currently. ") - if self.compilation_config.cudagraph_mode.\ - has_piecewise_cudagraphs(): - logger.warning_once( - warn_msg + "It will dispatched to " - "piecewise cudagraphs if a batch runs into cascade " - "attentions") - else: - logger.warning_once( - warn_msg + "It will fallback to eager execution if a " - "batch runs into cascade attentions") + not self.model_config.disable_cascade_attn and\ + not self.compilation_config.cudagraph_mode.\ + has_piecewise_cudagraphs(): + logger.warning_once( + "No piecewise cudagraph for executing cascade attention." + " Will fall back to eager execution if a batch runs " + "into cascade attentions") if self.compilation_config.cudagraph_mode\ .requires_piecewise_compilation(): @@ -3653,11 +3652,6 @@ def __post_init__(self): if not self.instance_id: self.instance_id = random_uuid()[:5] - # Do this after all the updates to compilation_config.level - if envs.VLLM_USE_V1 and \ - self.compilation_config.level == CompilationLevel.PIECEWISE: - self.compilation_config.set_splitting_ops_for_v1() - if (envs.VLLM_USE_V1 and not self.scheduler_config.disable_hybrid_kv_cache_manager): # logger should only print warning message for hybrid models. As we diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 360c71477b51..c4889054abe9 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -566,16 +566,22 @@ def set_splitting_ops_for_v1(self): else: self.splitting_ops = self._attention_ops elif len(self.splitting_ops) == 0: - logger.warning_once("Using piecewise compilation with empty " - "splitting_ops.") - if self.cudagraph_mode.has_piecewise_cudagraphs(): + logger.info_once("Using piecewise compilation with empty " + "splitting_ops.") + if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: logger.warning_once( - "When compilation level is piecewise with empty " - "splitting_ops, cudagraph_mode containing PIECEWISE will " - "be treated as FULL cudagraph_mode. 
Please ensure you are " - "using attention backends that support cudagraph or set " - "cudagraph_mode to NONE explicitly if encountering " - "any problems.") + "Piecewise compilation with empty splitting_ops do not" \ + "contains piecewise cudagraph. Setting cudagraph_" + "mode to NONE. Hint: If you are using attention backends " + "that support cudagraph, consider manually setting " + "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable " + "full cudagraphs.") + self.cudagraph_mode = CUDAGraphMode.NONE + elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE: + logger.warning_once( + "Piecewise compilation with empty splitting_ops do not " + "contains piecewise cudagraph. Setting cudagraph_mode " + "to FULL.") self.cudagraph_mode = CUDAGraphMode.FULL self.splitting_ops = [] else: # len(self.splitting_ops) > 0: From da7949400734b63b86868ed78d4ce0e57b233e81 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sat, 23 Aug 2025 05:00:17 +0000 Subject: [PATCH 08/21] fix comments part2 Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- .../compile/piecewise/test_full_cudagraph.py | 69 +---------------- tests/v1/attention/utils.py | 74 ++++++++++++++++++- tests/v1/cudagraph/test_cudagraph_dispatch.py | 53 +++++-------- tests/v1/cudagraph/test_cudagraph_mode.py | 67 +++-------------- vllm/config/__init__.py | 6 ++ vllm/config/compilation.py | 8 +- vllm/v1/worker/gpu_model_runner.py | 9 +-- 7 files changed, 115 insertions(+), 171 deletions(-) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index 97140a9db7af..a08bf5a44482 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -3,12 +3,11 @@ import contextlib import os import weakref -from dataclasses import dataclass -from typing import Optional import pytest from tests.utils import wait_for_gpu_memory_to_clear +from tests.v1.attention.utils import full_cg_backend_configs as backend_configs from vllm import LLM, SamplingParams from vllm.config import CompilationConfig from vllm.platforms import current_platform @@ -33,72 +32,6 @@ def temporary_environ(env_vars): os.environ[k] = v -@dataclass -class BackendConfig: - name: str - env_vars: dict - comp_config: dict - specific_gpu_arch: Optional[tuple] = None - - -# Define all backend configurations of full cudagraph to be tested -backend_configs = { - # FA3 on Hopper - "FA3": - BackendConfig(name="FA3", - env_vars={"VLLM_FLASH_ATTN_VERSION": "3"}, - comp_config={ - "cudagraph_mode": "FULL", - }, - specific_gpu_arch=(9, 0)), - # FlashMLA on Hopper - "FlashMLA": - BackendConfig(name="FlashMLA", - env_vars={ - "VLLM_ATTENTION_BACKEND": "FLASHMLA", - }, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }, - specific_gpu_arch=(9, 0)), - # Cutlass MLA on Blackwell - "CutlassMLA": - BackendConfig( - name="CutlassMLA", - env_vars={ - "VLLM_USE_V1": "1", - "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA", - "FORCE_NUM_KV_SPLITS": - "1", # TODO: remove this when hang issue is fixed - }, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512], - }, - specific_gpu_arch=(10, 0)), - # FA2 - "FA2": - BackendConfig(name="FA2", - env_vars={"VLLM_FLASH_ATTN_VERSION": "2"}, - comp_config={ - "cudagraph_mode": "FULL", - }), - # Triton Attention - "TritonAttn": - BackendConfig(name="TritonAttn", - env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"}, - comp_config={ - 
"cudagraph_mode": "FULL", - }), - # FlashInfer - "FlashInfer": - BackendConfig(name="FlashInfer", - env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }), -} - test_params_full_cudagraph = [] # deepseek-ai/DeepSeek-V2-Lite with MLA diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 6a08cdc56f73..6732828b3164 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -3,7 +3,7 @@ """Utility functions for attention-related v1 tests.""" from dataclasses import dataclass -from typing import Union +from typing import Optional, Union import pytest import torch @@ -256,3 +256,75 @@ def create_dummy_kv_cache(block_size: int, dtype=dtype, device=device) return kv_cache + + +@dataclass +class BackendConfig: + name: str + env_vars: dict + comp_config: dict # compilation config + specific_gpu_arch: Optional[tuple] = None + + +# Define all backend configurations of full cudagraph to be tested +full_cg_backend_configs = { + # FA3 on Hopper + "FA3": + BackendConfig(name="FA3", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN", + "VLLM_FLASH_ATTN_VERSION": "3" + }, + comp_config={ + "cudagraph_mode": "FULL", + }, + specific_gpu_arch=(9, 0)), + # FlashMLA on Hopper + "FlashMLA": + BackendConfig(name="FlashMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASHMLA", + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + specific_gpu_arch=(9, 0)), + # Cutlass MLA on Blackwell + "CutlassMLA": + BackendConfig( + name="CutlassMLA", + env_vars={ + "VLLM_USE_V1": "1", + "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA", + "FORCE_NUM_KV_SPLITS": + "1", # TODO: remove this when hang issue is fixed + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + specific_gpu_arch=(10, 0)), + # FA2 + "FA2": + BackendConfig(name="FA2", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN", + "VLLM_FLASH_ATTN_VERSION": "2" + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }), + # Triton Attention + "TritonAttn": + BackendConfig(name="TritonAttn", + env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }), + # FlashInfer + "FlashInfer": + BackendConfig(name="FlashInfer", + env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }), +} diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index baa4a3df3d12..b6b85e4440d0 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -45,39 +45,22 @@ def _create_vllm_config(compilation_config: CompilationConfig, class TestCudagraphDispatcher: @pytest.mark.parametrize( - "params", + "case_id,cudagraph_mode_str,compilation_level", [ # Test case 0: Full CG for mixed batches, no separate routine - { - "case_id": 0, - "cudagraph_mode": "FULL", - "compilation_level": CompilationLevel.NO_COMPILATION, - }, + (0, "FULL", CompilationLevel.NO_COMPILATION), # Test case 1: Full CG for uniform batches, piecewise for mixed - { - "case_id": 1, - "cudagraph_mode": "FULL_AND_PIECEWISE", - "compilation_level": CompilationLevel.PIECEWISE, - }, + (1, "FULL_AND_PIECEWISE", CompilationLevel.NO_COMPILATION), # Test case 2: Full CG for uniform batches, no CG for mixed - { - "case_id": 2, - "cudagraph_mode": "FULL_DECODE_ONLY", - "compilation_level": CompilationLevel.NO_COMPILATION, - }, + (2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION), 
# Test case 3: Piecewise for all - { - "case_id": 3, - "cudagraph_mode": "PIECEWISE", - "compilation_level": CompilationLevel.PIECEWISE, - }, + (3, "PIECEWISE", CompilationLevel.PIECEWISE), ]) - def test_dispatcher(self, params): + def test_dispatcher(self, cudagraph_mode_str, compilation_level): # Setup dispatcher - comp_config = CompilationConfig( - cudagraph_mode=params["cudagraph_mode"], - level=params["compilation_level"], - cudagraph_capture_sizes=[1, 8]) + comp_config = CompilationConfig(cudagraph_mode=cudagraph_mode_str, + level=compilation_level, + cudagraph_capture_sizes=[1, 8]) config = _create_vllm_config(comp_config, max_num_seqs=8) dispatcher = CudagraphDispatcher(config) @@ -86,11 +69,11 @@ def test_dispatcher(self, params): uniform_decode_query_len=1) # Verify the key is initialized correctly - if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]: + if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]: assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2 else: assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0 - if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]: + if cudagraph_mode_str not in ["NONE", "PIECEWISE"]: assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2 else: assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0 @@ -99,10 +82,10 @@ def test_dispatcher(self, params): # 1. non-uniform batch, size in cudagraph size list desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False) rt_mode, key = dispatcher.dispatch(desc_full_exact) - if params["cudagraph_mode"] == "FULL": + if cudagraph_mode_str == "FULL": assert rt_mode == CUDAGraphMode.FULL assert key == desc_full_exact - elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]: + elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]: assert rt_mode == CUDAGraphMode.PIECEWISE assert key == desc_full_exact else: @@ -111,15 +94,13 @@ def test_dispatcher(self, params): # 2. 
uniform decode batch, size in cudagraph size list desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True) rt_mode, key = dispatcher.dispatch(desc_uniform_exact) - if params["cudagraph_mode"] == "FULL": + if cudagraph_mode_str == "FULL": assert rt_mode == CUDAGraphMode.FULL assert key == desc_uniform_exact.non_uniform - elif params["cudagraph_mode"] in [ - "FULL_DECODE_ONLY", "FULL_AND_PIECEWISE" - ]: + elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]: assert rt_mode == CUDAGraphMode.FULL assert key == desc_uniform_exact - elif params["cudagraph_mode"] == "PIECEWISE": + elif cudagraph_mode_str == "PIECEWISE": assert rt_mode == CUDAGraphMode.PIECEWISE assert key == desc_uniform_exact.non_uniform else: @@ -135,7 +116,7 @@ def test_dispatcher(self, params): desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False) rt_mode, key = dispatcher.dispatch(desc_full_exact, use_cascade_attn=True) - if "PIECEWISE" in params["cudagraph_mode"]: # string contains check + if "PIECEWISE" in cudagraph_mode_str: # string contains check assert rt_mode == CUDAGraphMode.PIECEWISE assert key == desc_full_exact.non_uniform else: diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index 81655e417500..c4116247bb7c 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -4,12 +4,11 @@ import os import weakref from contextlib import ExitStack -from dataclasses import dataclass -from typing import Optional import pytest from tests.utils import wait_for_gpu_memory_to_clear +from tests.v1.attention.utils import full_cg_backend_configs as backend_configs from vllm import LLM from vllm.config import CompilationConfig from vllm.platforms import current_platform @@ -34,57 +33,6 @@ def temporary_environ(env_vars): os.environ[k] = v -@dataclass -class BackendConfig: - name: str - env_vars: dict - comp_config: dict - specific_gpu_arch: Optional[tuple] = None - - -# Define all backend configurations of full cudagraph to be tested -backend_configs = { - # FA3 on Hopper - "FA3": - BackendConfig(name="FA3", - env_vars={"VLLM_FLASH_ATTN_VERSION": "3"}, - comp_config={ - "cudagraph_mode": "FULL", - }, - specific_gpu_arch=(9, 0)), - # FlashMLA on Hopper - "FlashMLA": - BackendConfig(name="FlashMLA", - env_vars={ - "VLLM_ATTENTION_BACKEND": "FLASHMLA", - }, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }, - specific_gpu_arch=(9, 0)), - # FA2 - "FA2": - BackendConfig(name="FA2", - env_vars={"VLLM_FLASH_ATTN_VERSION": "2"}, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }), - # Triton Attention - "TritonAttn": - BackendConfig(name="TritonAttn", - env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"}, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }), - # FlashInfer - "FlashInfer": - BackendConfig(name="FlashInfer", - env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }), -} - # test attention backend and cudagraph_mode combo # (backend_name, cudagraph_mode, supported) combo_cases_1 = [ @@ -97,9 +45,10 @@ class BackendConfig: ] -@pytest.mark.parametrize("combo_case", combo_cases_1) -def test_backend_and_cudagraph_mode_combo(combo_case): - backend_name, cudagraph_mode, supported = combo_case +@pytest.mark.parametrize("backend_name, cudagraph_mode, supported", + combo_cases_1) +def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, + supported): if backend_name == "FlashInfer": 
try: import flashinfer # noqa: F401 @@ -125,7 +74,7 @@ def test_backend_and_cudagraph_mode_combo(combo_case): compilation_config=CompilationConfig( level=3, cudagraph_mode=cudagraph_mode)) llm.generate(["Hello, my name is"] * 10) - + # when above code raises, `llm` may be undefined, so we need to catch that try: llm = weakref.proxy(llm) del llm @@ -156,7 +105,8 @@ def test_backend_and_cudagraph_mode_combo(combo_case): ] -@pytest.mark.parametrize("combo_case", combo_cases_2) +@pytest.mark.parametrize("backend_name,cudagraph_mode,compilation_level,"\ + "supported", combo_cases_2) def test_cudagraph_compilation_combo(combo_case): backend_name, cudagraph_mode, compilation_level, supported\ = combo_case @@ -175,6 +125,7 @@ def test_cudagraph_compilation_combo(combo_case): compilation_config=CompilationConfig( level=compilation_level, cudagraph_mode=cudagraph_mode)) llm.generate(["Hello, my name is"] * 10) + # when above code raises, `llm` may be undefined, so we need to catch that try: llm = weakref.proxy(llm) del llm diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 39d7aa9e1291..45d8916628cc 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -3721,6 +3721,12 @@ def __post_init__(self): "when cudagraph_mode piecewise cudagraphs is used, "\ f"cudagraph_mode={self.compilation_config.cudagraph_mode}" + # final migrate the deprecated flags + self.compilation_config.use_cudagraph = self.compilation_config.\ + cudagraph_mode!= CUDAGraphMode.NONE + self.compilation_config.full_cuda_graph = self.compilation_config.\ + cudagraph_mode.has_full_cudagraphs() + if not self.instance_id: self.instance_id = random_uuid()[:5] diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 9f9bac1cb437..462ce0315683 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -256,7 +256,7 @@ class CompilationConfig: FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and piecewise cudagraph for prefill and mixed prefill-decode batches. - This is like the most performant mode for most models. + This is generally the most performant mode for most models. Currently, the cudagraph mode is only used for the v1 engine. Note that the cudagraph logic is generally orthogonal to the @@ -278,7 +278,8 @@ class CompilationConfig: Note that this is orthogonal to the cudagraph capture logic outside of compilation. Warning: This flag is deprecated and will be removed in the next major or - minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead. + minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE + instead. """ cudagraph_num_of_warmups: int = 0 """Number of warmup runs for cudagraph. @@ -303,7 +304,8 @@ class CompilationConfig: flag cannot be used together with splitting_ops. This may provide performance benefits for smaller models. Warning: This flag is deprecated and will be removed in the next major or - minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead. + minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=FULL + instead. 
""" pass_config: PassConfig = field(default_factory=PassConfig) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 88684a3bb69f..57c9ac8a7c94 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -829,7 +829,7 @@ def _prepare_inputs( logits_indices[-1].item()) if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and num_logits <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. + # Use CUDA graphs. # Add padding to the batch size. num_logits_padded = self.vllm_config.pad_for_cudagraph( num_logits) @@ -2066,16 +2066,15 @@ def reload_weights(self) -> None: "Cannot reload weights before model is loaded." model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") - model = self.get_model() - model_loader.load_weights(model, model_config=self.model_config) + model_loader.load_weights(self.get_model(), + model_config=self.model_config) def save_tensorized_model( self, tensorizer_config: "TensorizerConfig", ) -> None: - model = self.get_model() TensorizerLoader.save_model( - model, + self.get_model(), tensorizer_config=tensorizer_config, model_config=self.model_config, ) From 01083c393359dd104c0d1502a2acdef0d53d9976 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Wed, 27 Aug 2025 09:51:19 +0000 Subject: [PATCH 09/21] fix double validation of deprecated flag Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config/compilation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 462ce0315683..747b1adac4cb 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -455,7 +455,8 @@ def __post_init__(self) -> None: if not self.use_cudagraph: logger.warning("use_cudagraph is deprecated, use " "cudagraph_mode=NONE instead.") - if self.cudagraph_mode is not None: + if self.cudagraph_mode is not None and \ + self.cudagraph_mode != CUDAGraphMode.NONE: raise ValueError( "use_cudagraph and cudagraph_mode are mutually" " exclusive, prefer cudagraph_mode since " @@ -464,7 +465,8 @@ def __post_init__(self) -> None: if self.full_cuda_graph: logger.warning("full_cuda_graph is deprecated, use " "cudagraph_mode=FULL instead.") - if self.cudagraph_mode is not None: + if self.cudagraph_mode is not None and \ + not self.cudagraph_mode.has_full_cudagraphs(): raise ValueError("full_cuda_graph and cudagraph_mode are " "mutually exclusive, prefer cudagraph_mode " "since full_cuda_graph is deprecated.") From 2bf556956548ab3dc60c498b6782da3d4f08db7b Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Thu, 4 Sep 2025 06:38:12 +0000 Subject: [PATCH 10/21] fix pre-commit Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 816ebbcfcead..6474af445b74 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -862,7 +862,7 @@ def _prepare_inputs( )) use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", False) - + for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i From 27eecc29c199000a736a66ea3e507670bd61b544 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Fri, 19 Sep 2025 17:22:25 
+0000 Subject: [PATCH 11/21] disable cascade_attn when DBO Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 28d7067db6dc..e828af62862e 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -660,6 +660,11 @@ def __post_init__(self): "the VLLM_ALL2ALL_BACKEND environment variable to "\ "deepep_low_latency and install the DeepEP kerenls." + if not self.model_config.disable_cascade_attn: + self.model_config.disable_cascade_attn = True + logger.warning_once( + "Disabling cascade attention when DBO is enabled.") + if not self.instance_id: self.instance_id = random_uuid()[:5] From 48a8c7f2cd3a5c6aeca3af52162a9e0c2564073c Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sat, 20 Sep 2025 16:37:20 +0000 Subject: [PATCH 12/21] pre-commit Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config/compilation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index ccd733020d7c..1a64e53e0d37 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -603,7 +603,7 @@ def set_splitting_ops_for_v1(self): "support cudagraph or set cudagraph_mode to NONE " "explicitly if encountering any problems.") self.cudagraph_mode = CUDAGraphMode.FULL - else: + else: # NOTE: When using full cudagraph, instead of setting an empty # list and capture the full cudagraph inside the flattened fx # graph, we keep the piecewise fx graph structure but capture From f3e08f347fca44497c33ec26bc4425112de40df1 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Tue, 23 Sep 2025 16:19:55 +0000 Subject: [PATCH 13/21] remove piecewise cudagraph wrapper when no needed Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/compilation/backends.py | 4 ++-- vllm/compilation/decorators.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 97fa0b230c8d..335bbda5e4eb 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -347,8 +347,8 @@ def call_module(self, target: torch.fx.node.Target, len(self.compile_submod_names), sym_shape_indices, compiled_graph_for_dynamic_shape, self.vllm_backend) - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and + if (self.compilation_config.cudagraph_mode.\ + has_piecewise_cudagraphs() and not self.compilation_config.use_inductor_graph_partition): # We're using Dynamo-based piecewise splitting, so we wrap # the whole subgraph with a static graph wrapper. 
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index b7a6e23c1aa7..af99b78fbeb0 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -335,7 +335,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig): from vllm.config import CUDAGraphMode compilation_config = vllm_config.compilation_config - if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE + if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs() and compilation_config.use_inductor_graph_partition): from torch._inductor.utils import CUDAGraphWrapperMetadata @@ -364,6 +364,6 @@ def customized_cudagraph_wrapper(f, yield - if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE + if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs() and compilation_config.use_inductor_graph_partition): torch._inductor.utils.set_customized_partition_wrappers(None) From 3faff9752deece558826ed0d27e19b232d9976e9 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Tue, 23 Sep 2025 16:58:58 +0000 Subject: [PATCH 14/21] simplify set_splitting_ops_for_v1 Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config/compilation.py | 95 +++++++++++++++--------------- vllm/v1/worker/gpu_model_runner.py | 4 +- 2 files changed, 51 insertions(+), 48 deletions(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 1a64e53e0d37..a7a5b524cba1 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -578,48 +578,29 @@ def set_splitting_ops_for_v1(self): "set_splitting_ops_for_v1 should only be called when " "level is CompilationLevel.PIECEWISE") - use_inductor_graph_partition_msg = ( - "When use_inductor_graph_partition=True, splitting_ops " - "are ignored and set to an empty list. Instead, " - "\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is " - "used to annotate custom ops for graph partition.") + if self.use_inductor_graph_partition: + self.set_splitting_ops_for_inductor_graph_partition() + return + + if self.pass_config.enable_attn_fusion: + # here use_inductor_graph_partition is False + self.set_splitting_ops_for_attn_fusion() + return if self.splitting_ops is None: - if self.use_inductor_graph_partition: - # When using inductor graph partition, we set splitting_ops - # to be empty and rely on torch._C.Tag.cudagraph_unsafe to - # annotate custom ops as splitting ops. - logger.warning_once(use_inductor_graph_partition_msg) - self.splitting_ops = [] - elif self.pass_config.enable_attn_fusion: - self.splitting_ops = [] - if self.cudagraph_mode.has_piecewise_cudagraphs(): - logger.warning_once( - "enable_attn_fusion is incompatible with piecewise " - "cudagraph when use_inductor_graph_partition is off." - "In this case, splitting_ops will be set to empty " - "list, and cudagraph_mode will be set to FULL. " - "Please ensure you are using attention backends that " - "support cudagraph or set cudagraph_mode to NONE " - "explicitly if encountering any problems.") - self.cudagraph_mode = CUDAGraphMode.FULL - else: - # NOTE: When using full cudagraph, instead of setting an empty - # list and capture the full cudagraph inside the flattened fx - # graph, we keep the piecewise fx graph structure but capture - # the full cudagraph outside the fx graph. This reduces some - # cpu overhead when the runtime batch_size is not cudagraph - # captured. see https://github.com/vllm-project/vllm/pull/20059 - # for details. 
Make a copy to avoid mutating the class-level - # list via reference. - self.splitting_ops = list(self._attention_ops) + # NOTE: When using full cudagraph, instead of setting an empty + # list and capture the full cudagraph inside the flattened fx + # graph, we keep the piecewise fx graph structure but capture + # the full cudagraph outside the fx graph. This reduces some + # cpu overhead when the runtime batch_size is not cudagraph + # captured. see https://github.com/vllm-project/vllm/pull/20059 + # for details. Make a copy to avoid mutating the class-level + # list via reference. + self.splitting_ops = list(self._attention_ops) elif len(self.splitting_ops) == 0: logger.warning_once( - "Using piecewise compilation with empty " - "splitting_ops and use_inductor_graph_partition" - f"={self.use_inductor_graph_partition}.") - if (self.cudagraph_mode == CUDAGraphMode.PIECEWISE - and not self.use_inductor_graph_partition): + "Using piecewise compilation with empty splitting_ops") + if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: logger.warning_once( "Piecewise compilation with empty splitting_ops do not" \ "contains piecewise cudagraph. Setting cudagraph_" @@ -635,14 +616,36 @@ def set_splitting_ops_for_v1(self): "to FULL.") self.cudagraph_mode = CUDAGraphMode.FULL self.splitting_ops = [] - else: # len(self.splitting_ops) > 0: - if self.use_inductor_graph_partition: - logger.warning_once(use_inductor_graph_partition_msg) - self.splitting_ops = [] - assert not self.pass_config.enable_attn_fusion or \ - not self.splitting_ops_contain_attention(), ( - "attention ops should not be in splitting_ops " - "when enable_attn_fusion is True") + + def set_splitting_ops_for_inductor_graph_partition(self): + use_inductor_graph_partition_msg = ( + "When use_inductor_graph_partition=True, splitting_ops " + "are ignored and set to an empty list. Instead, " + "\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is " + "used to annotate custom ops for graph partition.") + if self.splitting_ops is not None and \ + len(self.splitting_ops) > 0: + logger.warning_once(use_inductor_graph_partition_msg) + self.splitting_ops = [] + + def set_splitting_ops_for_attn_fusion(self): + assert self.pass_config.enable_attn_fusion + if self.splitting_ops is None: + self.splitting_ops = [] + if self.cudagraph_mode.has_piecewise_cudagraphs(): + logger.warning_once( + "enable_attn_fusion is incompatible with piecewise " + "cudagraph when use_inductor_graph_partition is off." + "In this case, splitting_ops will be set to empty " + "list, and cudagraph_mode will be set to FULL. 
" + "Please ensure you are using attention backends that " + "support cudagraph or set cudagraph_mode to NONE " + "explicitly if encountering any problems.") + self.cudagraph_mode = CUDAGraphMode.FULL + + assert not self.splitting_ops_contain_attention(), ( + "attention ops should not be in splitting_ops " + "when enable_attn_fusion is True") def splitting_ops_contain_attention(self) -> bool: return self.splitting_ops is not None and all( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 84e110820e67..9ddaf456703a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3390,8 +3390,8 @@ def freeze_gc(): def _capture_cudagraphs(self, compilation_cases: list[int], cudagraph_runtime_mode: CUDAGraphMode, uniform_decode: bool): - assert cudagraph_runtime_mode in [CUDAGraphMode.FULL, - CUDAGraphMode.PIECEWISE],\ + assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ + cudagraph_runtime_mode.valid_runtime_modes(), \ f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" # Only rank 0 should print progress bar during capture From f09e47f5653399fcb46aa8d300c81f2552da026a Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Tue, 23 Sep 2025 17:42:44 +0000 Subject: [PATCH 15/21] resolve merged from main Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- tests/v1/attention/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index e4af58830196..f7eda1e7467e 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -277,7 +277,8 @@ class BackendConfig: BackendConfig(name="FA3", env_vars={ "VLLM_ATTENTION_BACKEND": "FLASH_ATTN", - "VLLM_FLASH_ATTN_VERSION": "3" + "VLLM_FLASH_ATTN_VERSION": "3", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", }, comp_config={ "cudagraph_mode": "FULL", @@ -312,6 +313,7 @@ class BackendConfig: BackendConfig(name="FlashAttentionMLA", env_vars={ "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", }, comp_config={ "cudagraph_mode": "FULL_DECODE_ONLY", @@ -322,7 +324,8 @@ class BackendConfig: BackendConfig(name="FA2", env_vars={ "VLLM_ATTENTION_BACKEND": "FLASH_ATTN", - "VLLM_FLASH_ATTN_VERSION": "2" + "VLLM_FLASH_ATTN_VERSION": "2", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", }, comp_config={ "cudagraph_mode": "FULL_AND_PIECEWISE", From 92cbd4ffe2dc7c512a4ca07e5b07dc13f57afaf4 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Wed, 24 Sep 2025 01:26:49 +0000 Subject: [PATCH 16/21] pre-commit Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/worker/gpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1df07891b93d..81a57fd428ee 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2868,7 +2868,8 @@ def _dummy_run( (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run """ - assert cudagraph_runtime_mode.valid_runtime_modes() + assert cudagraph_runtime_mode is None or \ + cudagraph_runtime_mode.valid_runtime_modes() # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). 
This means that we are using From df905769546a18e635a7fe5d98f326dcc85ab2d4 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Wed, 24 Sep 2025 01:32:49 +0000 Subject: [PATCH 17/21] modify comments for full_cuda_graph Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config/compilation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index da8fa2164ea2..6e53975d9570 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -304,8 +304,8 @@ class CompilationConfig: flag cannot be used together with splitting_ops. This may provide performance benefits for smaller models. Warning: This flag is deprecated and will be removed in the next major or - minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=FULL - instead. + minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode= + FULL_AND_PIECEWISE instead. """ use_inductor_graph_partition: bool = False From 4679802e45bb512e97c6b11278df007781ae86fd Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Wed, 24 Sep 2025 15:56:46 +0000 Subject: [PATCH 18/21] address comment;add test for splitting_ops Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- tests/compile/test_config.py | 67 ++++++++++++++++++++++++++++++++++-- vllm/config/compilation.py | 1 + 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 7afd6251bbbd..17d3f0b37768 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -4,7 +4,7 @@ import vllm from vllm.compilation.counter import compilation_counter -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.utils import _is_torch_equal_or_newer @@ -106,7 +106,6 @@ def test_dynamo_as_is(vllm_runner, monkeypatch): def test_no_compilation(vllm_runner, monkeypatch): # Disable multiprocessing so that the counter is in the same process monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') - with ( compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0), @@ -131,3 +130,67 @@ def test_enforce_eager(vllm_runner, monkeypatch): enforce_eager=True, gpu_memory_utilization=0.4) as _): pass + + +def test_splitting_ops_dynamic(): + # Default config + config = VllmConfig() + assert config.compilation_config.cudagraph_mode == \ + CUDAGraphMode.FULL_AND_PIECEWISE + assert config.compilation_config.splitting_ops_contain_attention() + + # When use_inductor_graph_partition=True + if _is_torch_equal_or_newer('2.9.0.dev'): + # inductor graph partition is only available in PyTorch 2.9+. + # this is a fast config check so we are not using pytest.skip. + config = VllmConfig(compilation_config=CompilationConfig( + use_inductor_graph_partition=True, + splitting_ops=["silly_attention"])) + # should ignore splitting_ops + assert config.compilation_config.splitting_ops == [] + + # When attn_fusion pass enabled. 
+ config = VllmConfig(compilation_config=CompilationConfig( + pass_config={ + "enable_attn_fusion": True, + "enable_noop": True + }, + custom_ops=["+quant_fp8"], + cudagraph_mode=CUDAGraphMode.PIECEWISE, + )) + assert config.compilation_config.splitting_ops == [] + # cudagraph mode also fall back to FULL + assert config.compilation_config.cudagraph_mode == \ + CUDAGraphMode.FULL + + # splitting_ops can not contain attention ops when attn_fusion + # pass enabled. + with pytest.raises(AssertionError): + config = VllmConfig(compilation_config=CompilationConfig( + pass_config={ + "enable_attn_fusion": True, + "enable_noop": True + }, + custom_ops=["+quant_fp8"], + cudagraph_mode=CUDAGraphMode.PIECEWISE, + # work around for accessing all attntion ops + splitting_ops=CompilationConfig()._attention_ops, + )) + + # When both use_inductor_graph_partition and attn_fusion pass enabled. + if _is_torch_equal_or_newer('2.9.0.dev'): + config = VllmConfig(compilation_config=CompilationConfig( + use_inductor_graph_partition=True, + pass_config={ + "enable_attn_fusion": True, + "enable_noop": True + }, + custom_ops=["+quant_fp8"], + cudagraph_mode=CUDAGraphMode.PIECEWISE, + )) + assert config.compilation_config.splitting_ops == [] + # enable_attn_fusion is directly support under + # use_inductor_graph_partition=True, and cudagraph_mode + # is unchanged. + assert config.compilation_config.cudagraph_mode == \ + CUDAGraphMode.PIECEWISE diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 6e53975d9570..cc099a72b4d4 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -623,6 +623,7 @@ def set_splitting_ops_for_v1(self): self.splitting_ops = [] def set_splitting_ops_for_inductor_graph_partition(self): + assert self.use_inductor_graph_partition use_inductor_graph_partition_msg = ( "When use_inductor_graph_partition=True, splitting_ops " "are ignored and set to an empty list. 
Instead, " From 5475e9e5dee8aca4db62bebc538f194f4e380798 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Thu, 25 Sep 2025 05:16:11 +0000 Subject: [PATCH 19/21] fix profile_run log Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/worker/gpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e9e5f1d57ea6..38ecc088742e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3053,7 +3053,8 @@ def _dummy_run( # filter out the valid batch descriptor _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( BatchDescriptor(num_tokens=num_tokens, - uniform_decode=uniform_decode)) + uniform_decode=uniform_decode)) \ + if not is_profile else (CUDAGraphMode.NONE, None) if cudagraph_runtime_mode is not None: # we allow forcing NONE when the dispatcher disagrees to support # warm ups for cudagraph capture From 413079bb09f62a0aa9669de788a38243198646e8 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Thu, 25 Sep 2025 05:18:06 +0000 Subject: [PATCH 20/21] temporary disable cascade attention Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index df11f7541b20..1d662f4198aa 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -493,6 +493,10 @@ def __post_init__(self): " Will fall back to eager execution if a batch runs " "into cascade attentions") + # Temporarily disable cascade attention to eval CI + # TODO: remove this line later. + self.model_config.disable_cascade_attn = True + if self.compilation_config.cudagraph_mode\ .requires_piecewise_compilation(): assert self.compilation_config.level == \ From d8a1ad7139d560260e9218cd537f477194e1d6c6 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Thu, 25 Sep 2025 10:23:48 +0000 Subject: [PATCH 21/21] recover Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 1d662f4198aa..df11f7541b20 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -493,10 +493,6 @@ def __post_init__(self): " Will fall back to eager execution if a batch runs " "into cascade attentions") - # Temporarily disable cascade attention to eval CI - # TODO: remove this line later. - self.model_config.disable_cascade_attn = True - if self.compilation_config.cudagraph_mode\ .requires_piecewise_compilation(): assert self.compilation_config.level == \
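
For readers tracing the dispatcher changes in PATCH 02 and PATCH 05 above, the snippet below is a minimal, self-contained sketch of the dispatch rule they introduce: a batch that uses cascade attention bypasses the FULL-cudagraph keys and falls back to a piecewise cudagraph when piecewise keys exist, otherwise to eager execution. The enum, dataclass, and key sets are simplified stand-ins for vLLM's CUDAGraphMode, BatchDescriptor, and CudagraphDispatcher, not the real classes; only the decision order mirrors the patched dispatch() above.

# Illustrative sketch only -- not part of the patch series.
from dataclasses import dataclass, replace
from enum import Enum
from typing import Optional


class CUDAGraphMode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


@dataclass(frozen=True)
class BatchDescriptor:
    num_tokens: int
    uniform_decode: bool = False

    @property
    def non_uniform(self) -> "BatchDescriptor":
        # Relax a uniform-decode batch to the more general non-uniform form.
        return replace(self, uniform_decode=False)


def dispatch(
    keys: dict[CUDAGraphMode, set[BatchDescriptor]],
    batch: BatchDescriptor,
    use_cascade_attn: bool = False,
) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
    non_uniform = batch.non_uniform
    # Cascade attention is not supported by full cudagraphs, so such
    # batches skip the FULL keys entirely.
    if not use_cascade_attn:
        if batch in keys[CUDAGraphMode.FULL]:
            return CUDAGraphMode.FULL, batch
        if non_uniform in keys[CUDAGraphMode.FULL]:
            return CUDAGraphMode.FULL, non_uniform
    # Otherwise try the more general piecewise cudagraph ...
    if non_uniform in keys[CUDAGraphMode.PIECEWISE]:
        return CUDAGraphMode.PIECEWISE, non_uniform
    # ... or give up and run eagerly.
    return CUDAGraphMode.NONE, None


if __name__ == "__main__":
    keys = {
        CUDAGraphMode.FULL: {BatchDescriptor(8, uniform_decode=True)},
        CUDAGraphMode.PIECEWISE: {BatchDescriptor(8), BatchDescriptor(16)},
    }
    # A captured uniform-decode batch replays the full graph ...
    print(dispatch(keys, BatchDescriptor(8, uniform_decode=True)))
    # ... but the same batch with cascade attention drops to piecewise,
    # matching the new case in tests/v1/cudagraph/test_cudagraph_dispatch.py.
    print(dispatch(keys, BatchDescriptor(8, uniform_decode=True),
                   use_cascade_attn=True))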
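
The splitting-ops handling that PATCH 07, PATCH 14, and PATCH 18 converge on can likewise be summarized by the hedged decision sketch below. It is a simplification of CompilationConfig.set_splitting_ops_for_v1: the string mode names and the placeholder attention-op list are illustrative only, and the real method asserts (rather than silently rewriting) when a user-provided splitting_ops conflicts with enable_attn_fusion.

# Illustrative sketch only -- not part of the patch series.
from typing import Optional


def resolve_cudagraph_mode(
        cudagraph_mode: str,
        splitting_ops: Optional[list[str]],
        enable_attn_fusion: bool = False,
        use_inductor_graph_partition: bool = False) -> tuple[str, list[str]]:
    """Return (resolved cudagraph_mode, resolved splitting_ops)."""
    # Placeholder for CompilationConfig._attention_ops.
    attention_ops = ["vllm.unified_attention"]
    if use_inductor_graph_partition:
        # splitting_ops are ignored; custom ops are tagged
        # cudagraph_unsafe for the inductor partition instead.
        return cudagraph_mode, []
    if enable_attn_fusion:
        # Attention cannot remain a splitting op once it is fused, so any
        # piecewise-capable mode is promoted to FULL.
        if "PIECEWISE" in cudagraph_mode:
            cudagraph_mode = "FULL"
        return cudagraph_mode, []
    if splitting_ops is None:
        # Default: split at attention so piecewise cudagraphs remain usable.
        return cudagraph_mode, list(attention_ops)
    if not splitting_ops:
        # Explicitly empty list: no piecewise graphs exist, so downgrade.
        if cudagraph_mode == "PIECEWISE":
            cudagraph_mode = "NONE"
        elif cudagraph_mode == "FULL_AND_PIECEWISE":
            cudagraph_mode = "FULL"
        return cudagraph_mode, []
    return cudagraph_mode, splitting_ops


if __name__ == "__main__":
    # Mirrors the expectations added in tests/compile/test_config.py above.
    print(resolve_cudagraph_mode("PIECEWISE", None, enable_attn_fusion=True))
    # -> ('FULL', [])
    print(resolve_cudagraph_mode("FULL_AND_PIECEWISE", []))
    # -> ('FULL', [])

Under these rules, PIECEWISE with an empty splitting_ops degrades all the way to NONE, while FULL_AND_PIECEWISE keeps its full graphs and only loses the piecewise half, which is exactly the asymmetry PATCH 07 introduces.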