# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import gc
import math
import multiprocessing
import os
from typing import Any
from unittest.mock import patch

import pytest
import torch

from vllm_ascend.utils import vllm_version_is

if vllm_version_is("0.11.0"):
    from vllm.utils import get_open_port
else:
    from vllm.utils.network_utils import get_open_port

MODELS = [
    "Qwen/Qwen3-0.6B",
    "vllm-ascend/DeepSeek-V2-Lite-W8A8",
]


def _install_spies(counters: dict[str, Any]) -> contextlib.ExitStack:
    """Install process-safe spies on NPU methods to track invocation counts."""
    from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

    def make_spy(cls, method_name, counter):
        original = getattr(cls, method_name)

        def spy(self, *args, **kwargs):
            # Bump the shared counter under its lock, then delegate to the
            # original method so behaviour is unchanged.
            with counter.get_lock():
                counter.value += 1
            return original(self, *args, **kwargs)

        return spy

    stack = contextlib.ExitStack()
    hooks = [
        (torch.npu.NPUGraph, "replay", counters["replay"]),
        # Each captured ACL graph constructs exactly one NPUGraph object.
        (torch.npu.NPUGraph, "__init__", counters["capture"]),
        (NPUModelRunner, "execute_model", counters["exec_model"]),
        (NPUModelRunner, "_dummy_run", counters["dummy_run"]),
    ]

    for cls, method, counter in hooks:
        stack.enter_context(
            patch.object(cls, method, make_spy(cls, method, counter)))

    return stack


def _run_worker_process(
    rank: int,
    local_rank: int,
    world_size: int,
    master_ip: str,
    master_port: int,
    counters: dict[str, Any],
    model_path: str,
    max_tokens: int,
):
    """Main entry point for the worker process."""
    # Configure this process as one data-parallel rank via vLLM's DP
    # environment variables.
    os.environ.update({
        "VLLM_DP_RANK": str(rank),
        "VLLM_DP_RANK_LOCAL": str(local_rank),
        "VLLM_DP_SIZE": str(world_size),
        "VLLM_DP_MASTER_IP": master_ip,
        "VLLM_DP_MASTER_PORT": str(master_port),
    })

    # Import vLLM only after environment setup
    from vllm import LLM, SamplingParams
    from vllm.distributed.parallel_state import (
        destroy_distributed_environment, destroy_model_parallel)

    # Apply hooks and run inference
    with _install_spies(counters):
        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]

        # Simple data sharding
        chunk_size = len(prompts) // world_size
        start_idx = rank * chunk_size
        end_idx = (start_idx + chunk_size
                   if rank < world_size - 1 else len(prompts))
        local_prompts = prompts[start_idx:end_idx]
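        # For example, with the 4 prompts above and world_size=2, rank 0
        # serves prompts[0:2] and rank 1 serves prompts[2:4].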

        llm = LLM(
            model=model_path,
            quantization="ascend" if "W8A8" in model_path else None,
            # enable_expert_parallel=True if "DeepSeek" in model_path else False,
            trust_remote_code=True,
        )

        # Expose model config to the main test process
        counters["hidden_layers"].value = (
            llm.llm_engine.model_config.hf_config.num_hidden_layers)

        llm.generate(local_prompts,
                     SamplingParams(max_tokens=max_tokens, temperature=0.0))

    # Explicit cleanup is mandatory in multi-process vLLM tests
    del llm

    destroy_model_parallel()
    destroy_distributed_environment()

    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [4, 36])
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
def test_aclgraph_capture_replay_dp2(
    model: str,
    max_tokens: int,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # The shared counters do not work when workers are started with the
    # "spawn" method, so clear any VLLM_WORKER_MULTIPROC_METHOD override.
    monkeypatch.delenv("VLLM_WORKER_MULTIPROC_METHOD", raising=False)

    # Shared counters for cross-process assertion
    counters = {
        "replay": multiprocessing.Value("i", 0),
        "capture": multiprocessing.Value("i", 0),
        "exec_model": multiprocessing.Value("i", 0),
        "dummy_run": multiprocessing.Value("i", 0),
        "hidden_layers": multiprocessing.Value("i", -1),
    }

    dp_size = 2
    port = get_open_port()

    # Launch workers
    workers = []
    for rank in range(dp_size):
        p = multiprocessing.Process(
            target=_run_worker_process,
            args=(rank, rank, dp_size, "127.0.0.1", port, counters, model,
                  max_tokens),
        )
        p.start()
        workers.append(p)

    # Supervision loop: a timed-out join leaves exitcode as None, which is
    # also treated as a failure here.
    for p in workers:
        p.join(timeout=900)
        if p.exitcode != 0:
            for k in workers:
                if k.is_alive():
                    k.kill()
            raise RuntimeError(
                f"Worker {p.pid} failed with exit code {p.exitcode}")

    actual_capture = counters["capture"].value
    actual_replay = counters["replay"].value
    num_execute_model = counters["exec_model"].value
    num_dummy_run = counters["dummy_run"].value
    num_layers = counters["hidden_layers"].value

    num_acl_graphs = num_layers + 1
    num_comm_groups = sum(1 for s in [dp_size, 1]
                          if s > 1)  # dp_size=2, tp_size=1

    # Metric 1: Graph Capture (ACL Graph Construction)
    # Ref: vllm_ascend.utils.update_aclgraph_sizes
    max_batch_sizes = math.floor((1800 - num_comm_groups * 40) /
                                 num_acl_graphs / (1 + num_comm_groups * 2))
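    # Worked example, assuming Qwen3-0.6B reports 28 hidden layers (a model
    # assumption, not something this test reads ahead of time):
    # num_acl_graphs = 29, num_comm_groups = 1, so
    # max_batch_sizes = floor((1800 - 40) / 29 / 3) = 20, and expected_capture
    # below works out to 20 * 29 * 2 = 1160.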

    expected_capture = max_batch_sizes * num_acl_graphs * dp_size
    assert (
        actual_capture == expected_capture
    ), f"Capture count mismatch. Expected: {expected_capture}, Got: {actual_capture}"

    # Metric 2: Model Execution (NPUModelRunner.execute_model)
    # vLLM step breakdown per rank:
    #   1. First step (prefill of the rank's prompts)
    #   2. Generation steps (max_tokens)
    #   3. A final step (EOS/idle), which triggers no graph replay
    total_steps = max_tokens + 1  # steps 1 and 2; step 3 is added below
    expected_exec_model = (total_steps + 1) * dp_size
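    # Worked example: with max_tokens=4, total_steps = 5 and
    # expected_exec_model = (5 + 1) * 2 = 12 execute_model calls in total.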

    assert (
        num_execute_model == expected_exec_model
    ), f"Model execution count mismatch. Expected: {expected_exec_model}, Got: {num_execute_model}"

    # Metric 3: Dummy Runs (Warmup & Alignment)
    # vLLM synchronizes globally every 32 steps.
    # Ref: vllm.v1.engine.core.DPEngineCoreProc._has_global_unfinished_reqs
    aligned_steps = (total_steps + 31) // 32 * 32
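    # Worked example: total_steps = 5 (max_tokens=4) rounds up to 32, while
    # total_steps = 37 (max_tokens=36) rounds up to 64.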

    # Part A: Warmup runs (profile run + 2 runs per captured batch size)
    warmup_runs = 1 + (2 * max_batch_sizes)

    # Part B: Alignment padding (empty runs to reach the 32-step boundary)
    padding_runs = aligned_steps - total_steps

    expected_dummy_run = (warmup_runs + padding_runs) * dp_size
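    # Worked example (under the 28-layer assumption above): warmup_runs =
    # 1 + 2 * 20 = 41; for max_tokens=4, padding_runs = 32 - 5 = 27, so
    # expected_dummy_run = (41 + 27) * 2 = 136.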

    assert (
        num_dummy_run == expected_dummy_run
    ), f"Dummy run count mismatch. Expected: {expected_dummy_run}, Got: {num_dummy_run}"

    # Metric 4: Graph Replay (Inference Execution)
    # Replays happen for every aligned step across all graphs.
    expected_replay = num_acl_graphs * aligned_steps * dp_size
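    # Worked example (under the 28-layer assumption above): for max_tokens=4
    # this is 29 * 32 * 2 = 1856 replays.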

    assert (
        actual_replay == expected_replay
    ), f"Replay count mismatch. Expected: {expected_replay}, Got: {actual_replay}"