14 changes: 14 additions & 0 deletions tests/e2e/singlecard/test_quantization.py
@@ -33,3 +33,17 @@ def test_quant_W8A8():
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)


def test_quant_awq():
max_tokens = 5
example_prompts = [
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
]
with VllmRunner(
snapshot_download("Qwen/Qwen2.5-0.5B-Instruct-AWQ"),
max_model_len=8192,
enforce_eager=False,
gpu_memory_utilization=0.7,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
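
For context, the same AWQ checkpoint can also be exercised outside the test harness through vLLM's offline API. This is only a minimal sketch, assuming a working vllm-ascend install; it reuses the model id and max_model_len from the test above and is not part of the diff:

from vllm import LLM, SamplingParams

# AWQ is detected from the checkpoint's quantization config; no explicit flag needed.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct-AWQ", max_model_len=8192)
params = SamplingParams(temperature=0.0, max_tokens=5)
outputs = llm.generate(["vLLM is a high-throughput inference and serving engine."], params)
print(outputs[0].outputs[0].text)
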
247 changes: 247 additions & 0 deletions tests/ut/quantization/test_awq.py
@@ -0,0 +1,247 @@
from types import MappingProxyType
from unittest.mock import ANY, MagicMock, patch

import torch
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.linear import LinearBase

from tests.ut.base import TestBase
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.quantization.awq.awq import (AWQLinearAscendMethod,
AWQMoEAscendMethod,
AWQQuantConfig)
from vllm_ascend.utils import AWQ_QUANTIZATION_METHOD


class TestAWQQuantization(TestBase):

def setUp(self):
super().setUp()
self.sample_config = {
"quant_method": AWQ_QUANTIZATION_METHOD,
"group_size": 128,
"bits": 4,
"zero_point": True,
"version": "gemm",
"modules_to_not_convert": ["visual"],
}

self.awq_quant_config = AWQQuantConfig.from_config(self.sample_config)
self.awq_quant_config.packed_modules_mapping = MappingProxyType({})

def test_init(self):
self.assertEqual(self.awq_quant_config.group_size, 128)
self.assertEqual(self.awq_quant_config.weight_bits, 4)
self.assertTrue(self.awq_quant_config.zero_point)
self.assertEqual(self.awq_quant_config.modules_to_not_convert,
["visual"])

def test_init_with_invalid_bits(self):
invalid_config = self.sample_config.copy()
invalid_config["bits"] = 8
with self.assertRaises(ValueError):
AWQQuantConfig.from_config(invalid_config)

def test_get_name(self):
self.assertEqual(self.awq_quant_config.get_name(),
AWQ_QUANTIZATION_METHOD)

def test_get_supported_act_dtypes(self):
supported_dtypes = self.awq_quant_config.get_supported_act_dtypes()
self.assertIn(torch.float16, supported_dtypes)
self.assertIn(torch.bfloat16, supported_dtypes)
self.assertEqual(len(supported_dtypes), 2)

def test_get_min_capability(self):
with self.assertRaises(NotImplementedError):
AWQQuantConfig.get_min_capability()

def test_get_config_filenames(self):
filenames = AWQQuantConfig.get_config_filenames()
self.assertIn("quant_config.json", filenames)
self.assertIn("quantize_config.json", filenames)
self.assertEqual(len(filenames), 2)

def test_from_config(self):
config = AWQQuantConfig.from_config(self.sample_config)
self.assertIsInstance(config, AWQQuantConfig)

def test_get_quant_method_for_linear(self):
linear_layer = MagicMock(spec=LinearBase)
# Test skipped layer
quant_method = self.awq_quant_config.get_quant_method(
linear_layer, "visual")
self.assertIsInstance(quant_method, AscendUnquantizedLinearMethod)

# Test quantized layer
quant_method = self.awq_quant_config.get_quant_method(
linear_layer, "attn")
self.assertIsInstance(quant_method, AWQLinearAscendMethod)

def test_get_quant_method_for_fused_moe(self):
fused_moe_layer = MagicMock(spec=FusedMoE)
fused_moe_config = MagicMock(spec=FusedMoEConfig)
fused_moe_layer.moe_config = fused_moe_config

# Test skipped layer
with patch(
'vllm_ascend.quantization.awq.awq.AscendUnquantizedFusedMoEMethod',
return_value=MagicMock()) as mock_ascend_moe:
quant_method = self.awq_quant_config.get_quant_method(
fused_moe_layer, "visual")
self.assertIs(quant_method, mock_ascend_moe.return_value)

# Test quantized layer
with patch('vllm_ascend.quantization.awq.awq.AWQMoEAscendMethod',
return_value=MagicMock()) as mock_ascend_moe:
quant_method = self.awq_quant_config.get_quant_method(
fused_moe_layer, "attn")
self.assertIs(quant_method, mock_ascend_moe.return_value)


class TestAWQLinearAscendMethod(TestBase):

def setUp(self):
super().setUp()
self.sample_config = {
"quant_method": AWQ_QUANTIZATION_METHOD,
"group_size": 128,
"bits": 4,
"zero_point": True,
"version": "gemm",
"modules_to_not_convert": ["visual"],
}

self.awq_quant_config = AWQQuantConfig.from_config(self.sample_config)
self.method = AWQLinearAscendMethod(self.awq_quant_config)

def test_create_weights(self):
with patch("vllm.model_executor.parameter.get_tensor_model_parallel_rank", return_value=0), \
patch("vllm.model_executor.parameter.get_tensor_model_parallel_world_size", return_value=1):

layer = MagicMock(spec=LinearBase)
self.method.create_weights(
layer=layer,
input_size_per_partition=128,
output_partition_sizes=[64],
input_size=128,
output_size=64,
params_dtype=torch.float16,
)
layer.register_parameter.assert_any_call("qweight", ANY)
layer.register_parameter.assert_any_call("qzeros", ANY)
layer.register_parameter.assert_any_call("scales", ANY)

def test_process_weights_after_loading(self):
layer = MagicMock(spec=LinearBase)
layer.qweight = torch.randint(10, (64, 128), dtype=torch.int32)
# AWQ pack order [0 2 4 6 1 3 5 7]
layer.qweight[0][0] = 0x75316420
layer.qzeros = torch.randint(
10, (1, 128 // self.awq_quant_config.group_size),
dtype=torch.int32)
# AWQ pack order [0 2 4 6 1 3 5 7]
layer.qzeros[0][0] = 0x75316420
layer.scales = torch.randn(1,
128 // self.awq_quant_config.group_size,
dtype=torch.float16)

self.method.process_weights_after_loading(layer)
# After unpacking, values are re-centered to signed int4, e.g. 0 -> 0b1000 (-8 as int4) -> nibble 0x8 in the uint32
self.assertEqual(layer.qweight[0][0].to(torch.uint32), 0xFEDCBA98)
self.assertTrue(
torch.equal(
layer.qzeros[0],
torch.Tensor([8., 7., 6., 5., 4., 3., 2.,
1.]).to(torch.float16)))

def test_apply(self):
with patch("torch_npu.npu_weight_quant_batchmatmul") as mock_func:
layer = MagicMock(spec=LinearBase)
layer.qweight = torch.randint(10, (64, 128), dtype=torch.int32)
layer.qzeros = torch.randint(
-8, 8,
(8, 128 // self.awq_quant_config.group_size)).to(torch.float16)
layer.scales = torch.randn(1,
128 // self.awq_quant_config.group_size,
dtype=torch.float16)

x = torch.randn(2, 16, 128, dtype=torch.float16)
self.method.apply(layer, x)
mock_func.assert_called_once()


class TestAWQMoEAscendMethod(TestBase):

def setUp(self):
super().setUp()
self.sample_config = {
"quant_method": AWQ_QUANTIZATION_METHOD,
"group_size": 128,
"bits": 4,
"zero_point": True,
"version": "gemm",
"modules_to_not_convert": ["visual"],
}

self.awq_quant_config = AWQQuantConfig.from_config(self.sample_config)
self.method = AWQMoEAscendMethod(self.awq_quant_config)

def test_create_weights(self):
layer = MagicMock(spec=FusedMoE)
self.method.create_weights(
layer,
num_experts=4,
hidden_size=256,
intermediate_size_per_partition=128,
params_dtype=torch.float16,
)

layer.register_parameter.assert_any_call("w13_qweight", ANY)
layer.register_parameter.assert_any_call("w2_qweight", ANY)
layer.register_parameter.assert_any_call("w13_scales", ANY)
layer.register_parameter.assert_any_call("w2_scales", ANY)
layer.register_parameter.assert_any_call("w13_qzeros", ANY)
layer.register_parameter.assert_any_call("w2_qzeros", ANY)

def test_process_weights_after_loading(self):
layer = MagicMock(spec=FusedMoE)
layer.register_parameter = lambda name, param: setattr(
layer, name, param)
layer.w13_qweight = torch.randint(10, (4, 128, 256), dtype=torch.int32)
# AWQ pack order [0 2 4 6 1 3 5 7]
layer.w13_qweight[0][0][0] = 0x75316420
layer.w13_qzeros = torch.randint(10, (4, 2), dtype=torch.int32)
# AWQ pack order [0 2 4 6 1 3 5 7]
layer.w13_qzeros[0][0] = 0x75316420
layer.w13_scales = torch.randn(4, 2, dtype=torch.float16)

layer.w2_qweight = torch.randint(10, (4, 256, 128), dtype=torch.int32)
# AWQ pack order [0 2 4 6 1 3 5 7]
layer.w2_qweight[0][0][0] = 0x75316420
layer.w2_qzeros = torch.randint(10, (4, 2), dtype=torch.int32)
# AWQ pack order [0 2 4 6 1 3 5 7]
layer.w2_qzeros[0][0] = 0x75316420
layer.w2_scales = torch.randn(4, 2, dtype=torch.float16)

self.method.process_weights_after_loading(layer)

# After unpacking, values are re-centered to signed int4, e.g. 0 -> 0b1000 (-8 as int4) -> nibble 0x8 in the uint32
self.assertEqual(layer.w13_qweight[0][0][0].to(torch.uint32),
0xFEDCBA98)
self.assertTrue(
torch.equal(
layer.w13_qzeros[0][0],
torch.Tensor([8., 7., 6., 5., 4., 3., 2.,
1.]).to(torch.float16)))

# After unpacking, values are re-centered to signed int4, e.g. 0 -> 0b1000 (-8 as int4) -> nibble 0x8 in the uint32
self.assertEqual(layer.w2_qweight[0][0][0].to(torch.uint32),
0xFEDCBA98)
self.assertTrue(
torch.equal(
layer.w2_qzeros[0][0],
torch.Tensor([8., 7., 6., 5., 4., 3., 2.,
1.]).to(torch.float16)))
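
The magic numbers asserted in the two process_weights_after_loading tests can be reproduced by hand. The following is a minimal sketch of the assumed transform, not the production code: unpack the uint32 in AWQ's interleaved nibble order, re-center the unsigned int4 weights to signed by subtracting 8, and express the zero points as the additive offset 8 - z. The last step is an assumption inferred from the expected tensors, not from kernel documentation.

AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]

def unpack_awq_u32(packed: int) -> list[int]:
    # Nibble p of the packed uint32 holds logical element AWQ_PACK_ORDER[p].
    nibbles = [(packed >> (4 * p)) & 0xF for p in range(8)]
    logical = [0] * 8
    for p, dst in enumerate(AWQ_PACK_ORDER):
        logical[dst] = nibbles[p]
    return logical

vals = unpack_awq_u32(0x75316420)  # -> [0, 1, 2, 3, 4, 5, 6, 7]

# Weights: re-center each unsigned int4 value to signed (v - 8) and repack sequentially.
repacked = 0
for i, v in enumerate(vals):
    repacked |= ((v - 8) & 0xF) << (4 * i)
assert repacked == 0xFEDCBA98  # matches the expected qweight value in the tests

# Zero points: stored as the additive offset 8 - z (assumption inferred from the tests).
assert [8 - z for z in vals] == [8, 7, 6, 5, 4, 3, 2, 1]
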
15 changes: 13 additions & 2 deletions tests/ut/quantization/test_quant_config.py
@@ -62,14 +62,25 @@ def test_from_config(self):

@patch('torch.npu.is_available')
def test_override_quantization_method(self, mock_is_available):
# Test when NPU is available
# Test when quant_method is None
mock_is_available.return_value = True
result = AscendQuantConfig.override_quantization_method(None, None)
self.assertIsNone(result)

# Test when NPU is available
mock_is_available.return_value = True
result = AscendQuantConfig.override_quantization_method({}, None)
self.assertEqual(result, ASCEND_QUANTIZATION_METHOD)

# Test when NPU is not available
mock_is_available.return_value = False
result = AscendQuantConfig.override_quantization_method(None, None)
result = AscendQuantConfig.override_quantization_method({}, None)
self.assertIsNone(result)

# Test when quant_method is specified
hf_quant_cfg = {"quant_method": "awq"}
result = AscendQuantConfig.override_quantization_method(
hf_quant_cfg, None)
self.assertIsNone(result)

def test_get_quant_method_for_linear(self):
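
A hedged sketch of the dispatch behaviour these assertions imply (an illustration under assumed semantics, not the actual AscendQuantConfig implementation): the Ascend override is claimed only for a quantization config dict that carries no explicit "quant_method", and only when an NPU is present.

import torch  # torch.npu is provided by torch_npu on Ascend hosts

from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD

def override_quantization_method(hf_quant_cfg, user_quant):
    # No HF quantization config at all -> nothing to override.
    if hf_quant_cfg is None:
        return None
    # An explicit quant_method (e.g. "awq") is handled by its own config class.
    if hf_quant_cfg.get("quant_method"):
        return None
    # Otherwise claim the Ascend method, but only when an NPU is available.
    return ASCEND_QUANTIZATION_METHOD if torch.npu.is_available() else None
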
8 changes: 5 additions & 3 deletions tests/ut/test_platform.py
@@ -10,6 +10,7 @@
from tests.ut.base import TestBase
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD,
AWQ_QUANTIZATION_METHOD,
COMPRESSED_TENSORS_METHOD, AscendDeviceType)


@@ -48,9 +49,10 @@ def test_class_variables(self):
self.assertEqual(NPUPlatform.device_control_env_var,
"ASCEND_RT_VISIBLE_DEVICES")
self.assertEqual(NPUPlatform.dispatch_key, "PrivateUse1")
self.assertEqual(
NPUPlatform.supported_quantization,
[ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD])
self.assertEqual(NPUPlatform.supported_quantization, [
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD,
AWQ_QUANTIZATION_METHOD
])

def test_is_sleep_mode_available(self):
self.assertTrue(self.platform.is_sleep_mode_available())
17 changes: 9 additions & 8 deletions vllm_ascend/attention/mla_v1.py
@@ -1514,9 +1514,10 @@ def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
num_decode_tokens = attn_metadata.num_decode_tokens
num_actual_tokens = attn_metadata.num_actual_tokens
if self.fused_qkv_a_proj is not None:
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
dependency=hidden_states,
enabled=self.enable_prefetch)
if hasattr(self.fused_qkv_a_proj, 'weight'):
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
dependency=hidden_states,
enabled=self.enable_prefetch)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
@@ -1718,11 +1719,11 @@ def forward(
o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
# O proj
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)

if hasattr(self.o_proj, 'weight'):
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(o_proj_input,
is_prefill=prefill_preprocess_res
is not None)[0]
3 changes: 2 additions & 1 deletion vllm_ascend/ops/layernorm.py
@@ -97,7 +97,8 @@ def __init__(
vllm_config = get_current_vllm_config()
self.bias = None
# quantization with anti_method m4 will generate a non-zero norm bias
if vllm_config.quant_config is not None and \
if vllm_config.quant_config is not None and hasattr(vllm_config.quant_config, "quant_description") and \
vllm_config.quant_config.quant_description is not None and \
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
16 changes: 11 additions & 5 deletions vllm_ascend/platform.py
@@ -33,10 +33,13 @@

# isort: off
from vllm_ascend.utils import (
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, AscendDeviceType,
enable_sp, get_ascend_device_type, is_vl_model,
prefill_context_parallel_enable, update_aclgraph_sizes,
update_cudagraph_capture_sizes, update_default_aclgraph_sizes)
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD,
AWQ_QUANTIZATION_METHOD, AscendDeviceType, enable_sp,
get_ascend_device_type, is_vl_model, prefill_context_parallel_enable,
update_aclgraph_sizes, update_cudagraph_capture_sizes,
update_default_aclgraph_sizes)

# isort: on

# set custom ops path
CUR_DIR = os.path.dirname(os.path.realpath(__file__))
@@ -79,7 +82,8 @@ class NPUPlatform(Platform):
dispatch_key: str = "PrivateUse1"

supported_quantization: list[str] = [
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD,
AWQ_QUANTIZATION_METHOD
]

def is_sleep_mode_available(self) -> bool:
@@ -103,6 +107,8 @@ def pre_register_and_update(cls,
if ASCEND_QUANTIZATION_METHOD not in quant_action.choices:
quant_action.choices.append(ASCEND_QUANTIZATION_METHOD)

from vllm_ascend.quantization.awq.awq import \
AWQQuantConfig # noqa: F401
from vllm_ascend.quantization.compressed_tensors.compressed_tensors import \
AscendCompressedTensorsConfig # noqa: F401
from vllm_ascend.quantization.quant_config import \