Commit 5a4e8cd

[Feat][BugFix] Support the Qwen3-Next-80B-A3B-Instruct quantization model & fix the NZ issue (#4245)
### What this PR does / why we need it?

Support the Qwen3-Next-80B-A3B-Instruct quantization model and fix the NZ issue. The Triton kernel does not support the NZ data format, so weight conversion to NZ is skipped for the `conv1d` layer.

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@2918c1b

Signed-off-by: IncSec <1790766300@qq.com>
1 parent cbb27fe commit 5a4e8cd

File tree: 10 files changed (+39, -30 lines)


tests/e2e/multicard/test_qwen3_next.py

Lines changed: 26 additions & 0 deletions

@@ -20,6 +20,12 @@

 Run `pytest tests/e2e/multicard/test_qwen3_next.py`.
 """
+
+import os
+from unittest.mock import patch
+
+from modelscope import snapshot_download  # type: ignore
+
 from tests.e2e.conftest import VllmRunner


@@ -106,3 +112,23 @@ def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
     print(f"spec_output: {spec_output[1]}")

     assert matches > int(0.66 * len(ref_outputs))
+
+
+# TODO: will conduct accuracy verification after the subsequent version becomes stable
+@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
+def test_models_distributed_Qwen3_NEXT_W8A8DYNAMIC_WITH_EP():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+    with VllmRunner(
+            snapshot_download(
+                "vllm-ascend/Qwen3-Next-80B-A3B-Instruct-W8A8-Pruning"),
+            max_model_len=4096,
+            tensor_parallel_size=2,
+            gpu_memory_utilization=0.4,
+            max_num_seqs=1,
+            enable_expert_parallel=True,
+            quantization="ascend",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
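For readers who want to try the new checkpoint outside the test harness, here is a hedged sketch of an equivalent offline load through vLLM's `LLM` API, mirroring the arguments the test passes to `VllmRunner`; the checkpoint id comes from the test, while the sampling settings and exporting `HCCL_BUFFSIZE` up front are illustrative assumptions, not part of this change.

import os

from modelscope import snapshot_download
from vllm import LLM, SamplingParams

# Assumption: reuse the same HCCL buffer size the test patches into the environment.
os.environ.setdefault("HCCL_BUFFSIZE", "1024")

model_path = snapshot_download(
    "vllm-ascend/Qwen3-Next-80B-A3B-Instruct-W8A8-Pruning")

llm = LLM(
    model=model_path,
    max_model_len=4096,
    tensor_parallel_size=2,
    gpu_memory_utilization=0.4,
    max_num_seqs=1,
    enable_expert_parallel=True,
    quantization="ascend",
)
# Greedy decoding of a short completion, analogous to generate_greedy in the test.
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)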

tests/ut/attention/test_mla_v1.py

Lines changed: 0 additions & 1 deletion

@@ -797,7 +797,6 @@ def test_q_proj_and_k_up_proj(self):
         self.assertEqual(q_pe.shape[1], self.impl.num_heads)
         self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)

-    @patch('vllm_ascend.utils._ENABLE_NZ', True)
     @patch('torch_npu.npu_format_cast')
     def test_process_weights_after_loading(self, mock_format_cast):
         layer = MagicMock(spec=LinearBase)

tests/ut/models/test_qwen2_5_vl.py

Lines changed: 0 additions & 4 deletions

@@ -1,5 +1,3 @@
-from unittest.mock import patch
-
 import pytest
 import torch
 import torch.nn.functional as F
@@ -367,7 +365,6 @@ def test_pad_qkv_bias(self, mocker: MockerFixture):
         res = attention.pad_qkv_bias(torch.rand((300)))
         assert res.shape[0] == 384

-    @patch('vllm_ascend.utils._ENABLE_NZ', True)
     def test_pad_qkv_weight(self, mocker: MockerFixture):
         attention = self.init_vision_transformer(mocker)
         mocker.patch("torch.nn.Module.__setattr__")
@@ -380,7 +377,6 @@ def test_pad_qkv_weight(self, mocker: MockerFixture):
         res = attention.pad_qkv_weight(torch.rand((300, 300)))
         assert res.shape == (384, 300)

-    @patch('vllm_ascend.utils._ENABLE_NZ', True)
     def test_pad_proj_weight(self, mocker: MockerFixture):
         attention = self.init_vision_transformer(mocker)
         mocker.patch("torch.nn.Module.__setattr__")

tests/ut/quantization/test_w4a8_dynamic.py

Lines changed: 0 additions & 1 deletion

@@ -260,7 +260,6 @@ def build_layer(self,
                              requires_grad=False)
         return layer

-    @patch('vllm_ascend.utils._ENABLE_NZ', False)
     @patch('torch_npu.npu_format_cast')
     @patch('torch_npu.npu_quantize')
     @patch('torch.Tensor.npu')

tests/ut/test_utils.py

Lines changed: 6 additions & 12 deletions

@@ -46,18 +46,12 @@ def test_is_310p(self):
         self.assertFalse(utils.is_310p())

     def test_is_enable_nz(self):
-        # Case when _ENABLE_NZ is already set
-        utils._ENABLE_NZ = True
-        self.assertTrue(utils.is_enable_nz())
-
-        utils._ENABLE_NZ = False
-        self.assertFalse(utils.is_enable_nz())
-
-        # Case when _ENABLE_NZ is None and vllm_config is not provided
-        utils._ENABLE_NZ = None
-        with self.assertRaises(ValueError) as context:
-            utils.is_enable_nz()
-        self.assertIn("vllm_config must be provided", str(context.exception))
+        with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
+                        1):
+            self.assertTrue(utils.is_enable_nz())
+        with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
+                        0):
+            self.assertFalse(utils.is_enable_nz())

     def test_sleep_mode_enabled(self):
         utils._SLEEP_MODE_ENABLED = None

tests/ut/worker/test_worker_v1.py

Lines changed: 1 addition & 1 deletion

@@ -281,9 +281,9 @@ def test_sleep_mode_disabled_raises_error(self, mock_sleep_mode_enabled):

         self.assertIn("Sleep mode is not enabled", str(cm.exception))

-    @patch('vllm_ascend.utils._ENABLE_NZ', False)
     @patch("vllm_ascend.worker.worker_v1.sleep_mode_enabled")
     @patch("vllm_ascend.worker.worker_v1.CaMemAllocator")
+    @patch.dict("os.environ", {"VLLM_ASCEND_ENABLE_NZ": "0"})
     def test_wake_up_mode_enabled(self, mock_allocator_class,
                                   mock_sleep_mode_enabled):
         """Test wake_up method when sleep mode is enabled"""

vllm_ascend/ops/linear.py

Lines changed: 2 additions & 1 deletion

@@ -45,7 +45,8 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
-        if (is_enable_nz() and layer.weight.data.dtype
+        if "conv1d" not in layer.prefix and (
+                is_enable_nz() and layer.weight.data.dtype
                 in [torch.float16, torch.bfloat16]):
             layer.weight.data = torch_npu.npu_format_cast(
                 layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
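Restating the new guard in isolation may help: the sketch below assumes a hypothetical helper `should_cast_to_nz` and illustrative layer prefixes; it is not code from this PR, only a restatement of the condition added above.

import torch

def should_cast_to_nz(prefix: str, dtype: torch.dtype, nz_enabled: bool) -> bool:
    # The conv1d layer of Qwen3-Next is served by a Triton kernel that expects
    # ND-format weights, so it is excluded from the NZ cast regardless of the flag.
    if "conv1d" in prefix:
        return False
    # Otherwise the cast happens only when NZ is enabled and the weight is fp16/bf16.
    return nz_enabled and dtype in (torch.float16, torch.bfloat16)

# Illustrative prefixes (assumptions, not taken from the model definition):
assert should_cast_to_nz("layers.0.self_attn.qkv_proj", torch.bfloat16, True)
assert not should_cast_to_nz("layers.0.linear_attn.conv1d", torch.bfloat16, True)
assert not should_cast_to_nz("layers.0.self_attn.qkv_proj", torch.float32, True)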

vllm_ascend/quantization/quant_config.py

Lines changed: 2 additions & 0 deletions

@@ -222,6 +222,8 @@ def get_scaled_act_names(self) -> List[str]:
         ],
         "gate_up_proj": ["gate_proj", "up_proj"],
         "in_proj": ["in_proj_qkvz", "in_proj_ba"],
+        "experts":
+        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
     },
     "qwen2_5_vl": {
         "qkv_proj": [

vllm_ascend/utils.py

Lines changed: 2 additions & 9 deletions

@@ -59,7 +59,6 @@
 _IS_MOE_MODEL = None
 _ENABLE_SP = None
 _HAS_LAYER_IDX = None
-_ENABLE_NZ = None
 _SUBSCRIBED_COMPUTE_STREAMS = set()
 _GRAPH_PRINT_STREAM = None
 _GRAPH_PRINT_STREAM_LOCK = Lock()
@@ -129,14 +128,8 @@ def is_310p():
     return _IS_310P


-def is_enable_nz(vllm_config: Optional[VllmConfig] = None) -> bool:
-    global _ENABLE_NZ
-    if _ENABLE_NZ is None:
-        if not vllm_config:
-            raise ValueError(
-                "vllm_config must be provided when _ENABLE_NZ is None")
-        _ENABLE_NZ = envs_ascend.VLLM_ASCEND_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next"
-    return _ENABLE_NZ
+def is_enable_nz():
+    return envs_ascend.VLLM_ASCEND_ENABLE_NZ


 def sleep_mode_enabled():
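Since `is_enable_nz()` is now a direct read of `envs_ascend.VLLM_ASCEND_ENABLE_NZ` rather than a cached global primed from `vllm_config`, callers no longer pass a config and tests can simply patch the flag. Below is a small hedged sketch of that pattern, mirroring the updated `tests/ut/test_utils.py` above; it assumes an environment where `vllm_ascend.utils` imports cleanly.

from unittest import mock

from vllm_ascend import utils

# Patching the env-backed flag is enough to flip the behavior; no module-level
# _ENABLE_NZ state has to be reset between cases anymore.
with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ", 1):
    assert utils.is_enable_nz()

with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ", 0):
    assert not utils.is_enable_nz()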

vllm_ascend/worker/worker_v1.py

Lines changed: 0 additions & 1 deletion

@@ -87,7 +87,6 @@ def __init__(
         # register patch for vllm
         from vllm_ascend.utils import adapt_patch
         adapt_patch()
-        is_enable_nz(vllm_config)
         # Register ops when worker init.
         from vllm_ascend import ops
         ops.register_dummy_fusion_op()

0 commit comments
