Commit cbb27fe

yiz-liu and lilinsiman authored
[Test] Add ACL graph capture/replay DP test (#4259)
### What this PR does / why we need it?
Add an ACL graph capture/replay DP test. This is an improved version of #3886 that restructures the multi-card ACL graph test for clarity, robustness, and accuracy. Key improvements include:
- Replaces fragile `sys.settrace` and manual patching with a clean, reusable spy installer built on `unittest.mock.patch`.
- Introduces more precise metrics by tracking `NPUModelRunner.execute_model` and `_dummy_run` calls directly.
- Rewrites the assertions to be more accurate and documents the expected counts of graph captures, replays, model executions, and dummy runs.
- Simplifies the overall test structure by moving the worker logic into a dedicated function.
- Removes a long, unnecessary sleep at the end of the test.
- Expands test coverage by adding a larger `max_tokens` parameter.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
None.

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@2918c1b

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: lilinsiman <lilinsiman@gmail.com>
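The "spy installer" mentioned in the description boils down to patching a method with a wrapper that bumps a shared, lock-protected counter and then delegates to the original. A minimal, self-contained sketch of that pattern follows; the `Runner` class and `install_spy` helper here are illustrative stand-ins, not names from this repository (the real test patches `torch.npu.NPUGraph` and `NPUModelRunner` methods, as shown in the diff below):

```python
# Minimal sketch of the counting-spy pattern used by this test.
# `Runner` and `install_spy` are illustrative stand-ins only.
import contextlib
import multiprocessing
from unittest.mock import patch


class Runner:
    def step(self) -> str:
        return "ran"


def install_spy(cls, method_name, counter) -> contextlib.ExitStack:
    """Wrap cls.method_name so every call bumps a shared counter."""
    original = getattr(cls, method_name)

    def spy(self, *args, **kwargs):
        with counter.get_lock():  # multiprocessing.Value carries its own lock
            counter.value += 1
        return original(self, *args, **kwargs)  # behavior is unchanged

    stack = contextlib.ExitStack()
    stack.enter_context(patch.object(cls, method_name, spy))
    return stack


if __name__ == "__main__":
    calls = multiprocessing.Value("i", 0)
    with install_spy(Runner, "step", calls):
        Runner().step()
        Runner().step()
    assert calls.value == 2           # both calls were observed
    assert Runner().step() == "ran"   # patch is removed once the stack exits
```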
1 parent d96d5fa commit cbb27fe

File tree

- .github/workflows/_e2e_test.yaml
- tests/e2e/multicard/test_aclgraph_capture_replay.py

2 files changed (+238, -0 lines)

.github/workflows/_e2e_test.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -182,6 +182,7 @@ jobs:
           VLLM_USE_MODELSCOPE: True
         if: ${{ inputs.type == 'full' }}
         run: |
+          pytest -sv tests/e2e/multicard/test_aclgraph_capture_replay.py
           pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py
           pytest -sv tests/e2e/multicard/test_full_graph_mode.py
           pytest -sv tests/e2e/multicard/test_data_parallel.py
```
tests/e2e/multicard/test_aclgraph_capture_replay.py

Lines changed: 237 additions & 0 deletions
```python
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import gc
import math
import multiprocessing
import os
from typing import Any
from unittest.mock import patch

import pytest
import torch

from vllm_ascend.utils import vllm_version_is

if vllm_version_is("0.11.0"):
    from vllm.utils import get_open_port
else:
    from vllm.utils.network_utils import get_open_port

MODELS = [
    "Qwen/Qwen3-0.6B",
    "vllm-ascend/DeepSeek-V2-Lite-W8A8",
]


def _install_spies(counters: dict[str, Any]) -> contextlib.ExitStack:
    """Installs thread-safe spies on NPU methods to track invocation counts."""
    from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

    def make_spy(cls, method_name, counter):
        original = getattr(cls, method_name)

        def spy(self, *args, **kwargs):
            with counter.get_lock():
                counter.value += 1
            return original(self, *args, **kwargs)

        return spy

    stack = contextlib.ExitStack()
    hooks = [
        (torch.npu.NPUGraph, "replay", counters["replay"]),
        (torch.npu.NPUGraph, "__init__", counters["capture"]),
        (NPUModelRunner, "execute_model", counters["exec_model"]),
        (NPUModelRunner, "_dummy_run", counters["dummy_run"]),
    ]

    for cls, method, counter in hooks:
        stack.enter_context(
            patch.object(cls, method, make_spy(cls, method, counter)))

    return stack


def _run_worker_process(
    rank: int,
    local_rank: int,
    world_size: int,
    master_ip: str,
    master_port: int,
    counters: dict[str, Any],
    model_path: str,
    max_tokens: int,
):
    """Main entry point for the worker process."""
    os.environ.update({
        "VLLM_DP_RANK": str(rank),
        "VLLM_DP_RANK_LOCAL": str(local_rank),
        "VLLM_DP_SIZE": str(world_size),
        "VLLM_DP_MASTER_IP": master_ip,
        "VLLM_DP_MASTER_PORT": str(master_port),
    })

    # Import vLLM only after environment setup
    from vllm import LLM, SamplingParams
    from vllm.distributed.parallel_state import (
        destroy_distributed_environment, destroy_model_parallel)

    # Apply hooks and run inference
    with _install_spies(counters):
        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]

        # Simple data sharding
        chunk_size = len(prompts) // world_size
        start_idx = rank * chunk_size
        end_idx = start_idx + chunk_size if rank < world_size - 1 else len(
            prompts)
        local_prompts = prompts[start_idx:end_idx]

        llm = LLM(
            model=model_path,
            quantization="ascend" if "W8A8" in model_path else None,
            # enable_expert_parallel=True if "DeepSeek" in model_path else False,
            trust_remote_code=True,
        )

        # Expose model config to the main test process
        counters["hidden_layers"].value = (
            llm.llm_engine.model_config.hf_config.num_hidden_layers)

        llm.generate(local_prompts,
                     SamplingParams(max_tokens=max_tokens, temperature=0.0))

        # Explicit cleanup is mandatory in multi-process vLLM tests
        del llm

    destroy_model_parallel()
    destroy_distributed_environment()

    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()


# @patch.dict(os.environ, clear=["HCCL_OP_EXPANSION_MODE","VLLM_WORKER_MULTIPROC_METHOD"])
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [4, 36])
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
def test_aclgraph_capture_replay_dp2(
    model: str,
    max_tokens: int,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Counter doesn't work in default "spawn" mode
    monkeypatch.delenv("VLLM_WORKER_MULTIPROC_METHOD", raising=False)

    # Shared counters for cross-process assertion
    counters = {
        "replay": multiprocessing.Value("i", 0),
        "capture": multiprocessing.Value("i", 0),
        "exec_model": multiprocessing.Value("i", 0),
        "dummy_run": multiprocessing.Value("i", 0),
        "hidden_layers": multiprocessing.Value("i", -1),
    }

    dp_size = 2
    port = get_open_port()

    # Launch workers
    workers = []
    for rank in range(dp_size):
        p = multiprocessing.Process(
            target=_run_worker_process,
            args=(rank, rank, dp_size, "127.0.0.1", port, counters, model,
                  max_tokens),
        )
        p.start()
        workers.append(p)

    # Supervision loop
    for p in workers:
        p.join(timeout=900)
        if p.exitcode != 0:
            for k in workers:
                if k.is_alive():
                    k.kill()
            raise RuntimeError(
                f"Worker {p.pid} failed with exit code {p.exitcode}")

    actual_capture = counters["capture"].value
    actual_replay = counters["replay"].value
    num_execute_model = counters["exec_model"].value
    num_dummy_run = counters["dummy_run"].value
    num_layers = counters["hidden_layers"].value

    num_acl_graphs = num_layers + 1
    num_comm_groups = sum(1 for s in [dp_size, 1]
                          if s > 1)  # dp_size=2, tp_size=1

    # Metric 1: Graph Capture (ACL Graph Construction)
    # Ref: vllm_ascend.utils.update_aclgraph_sizes
    max_batch_sizes = math.floor((1800 - num_comm_groups * 40) /
                                 num_acl_graphs / (1 + num_comm_groups * 2))

    expected_capture = max_batch_sizes * num_acl_graphs * dp_size
    assert (
        actual_capture == expected_capture
    ), f"Capture count mismatch. Expected: {expected_capture}, Got: {actual_capture}"

    # Metric 2: Model Execution (NPUModelRunner.execute_model)
    # vLLM Step Breakdown:
    # 1. First step (prefill, 1 prompt)
    # 2. Generation steps (max_tokens)
    # 3. Final step (likely EOS/idle step), no replay here
    total_steps = max_tokens + 1  # this includes the 1 and 2 above
    expected_exec_model = (total_steps + 1) * dp_size

    assert (
        num_execute_model == expected_exec_model
    ), f"Model execution count mismatch. Expected: {expected_exec_model}, Got: {num_execute_model}"

    # Metric 3: Dummy Runs (Warmup & Alignment)
    # vLLM synchronizes globally every 32 steps.
    # Ref: vllm.v1.engine.core.DPEngineCoreProc._has_global_unfinished_reqs
    aligned_steps = (total_steps + 31) // 32 * 32

    # Part A: Warmup runs (Profile run + 2 runs per captured graph)
    warmup_runs = 1 + (2 * max_batch_sizes)

    # Part B: Alignment padding (Empty runs to hit the 32-step boundary)
    padding_runs = aligned_steps - total_steps

    expected_dummy_run = (warmup_runs + padding_runs) * dp_size

    assert (
        num_dummy_run == expected_dummy_run
    ), f"Dummy run count mismatch. Expected: {expected_dummy_run}, Got: {num_dummy_run}"

    # Metric 4: Graph Replay (Inference Execution)
    # Replays happen for every aligned step across all graphs.
    expected_replay = num_acl_graphs * aligned_steps * dp_size

    assert (
        actual_replay == expected_replay
    ), f"Replay count mismatch. Expected: {expected_replay}, Got: {actual_replay}"
```
