diff --git a/tests/e2e/singlecard/test_quantization.py b/tests/e2e/singlecard/test_quantization.py index 95f26ee8c29..627ce6a7639 100644 --- a/tests/e2e/singlecard/test_quantization.py +++ b/tests/e2e/singlecard/test_quantization.py @@ -33,3 +33,17 @@ def test_quant_W8A8(): quantization="ascend", ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) + + +def test_quant_awq(): + max_tokens = 5 + example_prompts = [ + "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs." + ] + with VllmRunner( + snapshot_download("Qwen/Qwen2.5-0.5B-Instruct-AWQ"), + max_model_len=8192, + enforce_eager=False, + gpu_memory_utilization=0.7, + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/ut/quantization/test_awq.py b/tests/ut/quantization/test_awq.py new file mode 100644 index 00000000000..8cf771e7d5b --- /dev/null +++ b/tests/ut/quantization/test_awq.py @@ -0,0 +1,247 @@ +from types import MappingProxyType +from unittest.mock import ANY, MagicMock, patch + +import torch +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.linear import LinearBase + +from tests.ut.base import TestBase +from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod +from vllm_ascend.quantization.awq.awq import (AWQLinearAscendMethod, + AWQMoEAscendMethod, + AWQQuantConfig) +from vllm_ascend.utils import AWQ_QUANTIZATION_METHOD + + +class TestAWQQuantization(TestBase): + + def setUp(self): + super().setUp() + self.sample_config = { + "quant_method": AWQ_QUANTIZATION_METHOD, + "group_size": 128, + "bits": 4, + "zero_point": True, + "version": "gemm", + "modules_to_not_convert": ["visual"], + } + + self.awq_quant_config = AWQQuantConfig.from_config(self.sample_config) + self.awq_quant_config.packed_modules_mapping = MappingProxyType({}) + + def test_init(self): + self.assertEqual(self.awq_quant_config.group_size, 128) + self.assertEqual(self.awq_quant_config.weight_bits, 4) + self.assertTrue(self.awq_quant_config.zero_point) + self.assertEqual(self.awq_quant_config.modules_to_not_convert, + ["visual"]) + + def test_init_with_invalid_bits(self): + invalid_config = self.sample_config.copy() + invalid_config["bits"] = 8 + with self.assertRaises(ValueError): + AWQQuantConfig.from_config(invalid_config) + + def test_get_name(self): + self.assertEqual(self.awq_quant_config.get_name(), + AWQ_QUANTIZATION_METHOD) + + def test_get_supported_act_dtypes(self): + supported_dtypes = self.awq_quant_config.get_supported_act_dtypes() + self.assertIn(torch.float16, supported_dtypes) + self.assertIn(torch.bfloat16, supported_dtypes) + self.assertEqual(len(supported_dtypes), 2) + + def test_get_min_capability(self): + with self.assertRaises(NotImplementedError): + AWQQuantConfig.get_min_capability() + + def test_get_config_filenames(self): + filenames = AWQQuantConfig.get_config_filenames() + self.assertIn("quant_config.json", filenames) + self.assertIn("quantize_config.json", filenames) + self.assertEqual(len(filenames), 2) + + def test_from_config(self): + config = AWQQuantConfig.from_config(self.sample_config) + self.assertIsInstance(config, AWQQuantConfig) + + def test_get_quant_method_for_linear(self): + linear_layer = MagicMock(spec=LinearBase) + # Test skipped layer + quant_method = self.awq_quant_config.get_quant_method( + linear_layer, "visual") + self.assertIsInstance(quant_method, AscendUnquantizedLinearMethod) + + # Test 
quantized layer + quant_method = self.awq_quant_config.get_quant_method( + linear_layer, "attn") + self.assertIsInstance(quant_method, AWQLinearAscendMethod) + + def test_get_quant_method_for_fused_moe(self): + fused_moe_layer = MagicMock(spec=FusedMoE) + fused_moe_config = MagicMock(spec=FusedMoEConfig) + fused_moe_layer.moe_config = fused_moe_config + + # Test skipped layer + with patch( + 'vllm_ascend.quantization.awq.awq.AscendUnquantizedFusedMoEMethod', + return_value=MagicMock()) as mock_ascend_moe: + quant_method = self.awq_quant_config.get_quant_method( + fused_moe_layer, "visual") + self.assertIs(quant_method, mock_ascend_moe.return_value) + + # Test quantized layer + with patch('vllm_ascend.quantization.awq.awq.AWQMoEAscendMethod', + return_value=MagicMock()) as mock_ascend_moe: + quant_method = self.awq_quant_config.get_quant_method( + fused_moe_layer, "attn") + self.assertIs(quant_method, mock_ascend_moe.return_value) + + +class TestAWQLinearAscendMethod(TestBase): + + def setUp(self): + super().setUp() + self.sample_config = { + "quant_method": AWQ_QUANTIZATION_METHOD, + "group_size": 128, + "bits": 4, + "zero_point": True, + "version": "gemm", + "modules_to_not_convert": ["visual"], + } + + self.awq_quant_config = AWQQuantConfig.from_config(self.sample_config) + self.method = AWQLinearAscendMethod(self.awq_quant_config) + + def test_create_weights(self): + with patch("vllm.model_executor.parameter.get_tensor_model_parallel_rank", return_value=0), \ + patch("vllm.model_executor.parameter.get_tensor_model_parallel_world_size", return_value=1): + + layer = MagicMock(spec=LinearBase) + self.method.create_weights( + layer=layer, + input_size_per_partition=128, + output_partition_sizes=[64], + input_size=128, + output_size=64, + params_dtype=torch.float16, + ) + layer.register_parameter.assert_any_call("qweight", ANY) + layer.register_parameter.assert_any_call("qzeros", ANY) + layer.register_parameter.assert_any_call("scales", ANY) + + def test_process_weights_after_loading(self): + layer = MagicMock(spec=LinearBase) + layer.qweight = torch.randint(10, (64, 128), dtype=torch.int32) + # AWQ pack order [0 2 4 6 1 3 5 7] + layer.qweight[0][0] = 0x75316420 + layer.qzeros = torch.randint( + 10, (1, 128 // self.awq_quant_config.group_size), + dtype=torch.int32) + # AWQ pack order [0 2 4 6 1 3 5 7] + layer.qzeros[0][0] = 0x75316420 + layer.scales = torch.randn(1, + 128 // self.awq_quant_config.group_size, + dtype=torch.float16) + + self.method.process_weights_after_loading(layer) + # unpacked and signed number. 
eg: 0 -> 1000b(-8 in int4) -> 0x8 in uint32 + self.assertEqual(layer.qweight[0][0].to(torch.uint32), 0xFEDCBA98) + self.assertTrue( + torch.equal( + layer.qzeros[0], + torch.Tensor([8., 7., 6., 5., 4., 3., 2., + 1.]).to(torch.float16))) + + def test_apply(self): + with patch("torch_npu.npu_weight_quant_batchmatmul") as mock_func: + layer = MagicMock(spec=LinearBase) + layer.qweight = torch.randint(10, (64, 128), dtype=torch.int32) + layer.qzeros = torch.randint( + -8, 8, + (8, 128 // self.awq_quant_config.group_size)).to(torch.float16) + layer.scales = torch.randn(1, + 128 // self.awq_quant_config.group_size, + dtype=torch.float16) + + x = torch.randn(2, 16, 128, dtype=torch.float16) + self.method.apply(layer, x) + mock_func.assert_called_once() + + +class TestAWQMoEAscendMethod(TestBase): + + def setUp(self): + super().setUp() + self.sample_config = { + "quant_method": AWQ_QUANTIZATION_METHOD, + "group_size": 128, + "bits": 4, + "zero_point": True, + "version": "gemm", + "modules_to_not_convert": ["visual"], + } + + self.awq_quant_config = AWQQuantConfig.from_config(self.sample_config) + self.method = AWQMoEAscendMethod(self.awq_quant_config) + + def test_create_weights(self): + layer = MagicMock(spec=FusedMoE) + self.method.create_weights( + layer, + num_experts=4, + hidden_size=256, + intermediate_size_per_partition=128, + params_dtype=torch.float16, + ) + + layer.register_parameter.assert_any_call("w13_qweight", ANY) + layer.register_parameter.assert_any_call("w2_qweight", ANY) + layer.register_parameter.assert_any_call("w13_scales", ANY) + layer.register_parameter.assert_any_call("w2_scales", ANY) + layer.register_parameter.assert_any_call("w13_qzeros", ANY) + layer.register_parameter.assert_any_call("w2_qzeros", ANY) + + def test_process_weights_after_loading(self): + layer = MagicMock(spec=FusedMoE) + layer.register_parameter = lambda name, param: setattr( + layer, name, param) + layer.w13_qweight = torch.randint(10, (4, 128, 256), dtype=torch.int32) + # AWQ pack order [0 2 4 6 1 3 5 7] + layer.w13_qweight[0][0][0] = 0x75316420 + layer.w13_qzeros = torch.randint(10, (4, 2), dtype=torch.int32) + # AWQ pack order [0 2 4 6 1 3 5 7] + layer.w13_qzeros[0][0] = 0x75316420 + layer.w13_scales = torch.randn(4, 2, dtype=torch.float16) + + layer.w2_qweight = torch.randint(10, (4, 256, 128), dtype=torch.int32) + # AWQ pack order [0 2 4 6 1 3 5 7] + layer.w2_qweight[0][0][0] = 0x75316420 + layer.w2_qzeros = torch.randint(10, (4, 2), dtype=torch.int32) + # AWQ pack order [0 2 4 6 1 3 5 7] + layer.w2_qzeros[0][0] = 0x75316420 + layer.w2_scales = torch.randn(4, 2, dtype=torch.float16) + + self.method.process_weights_after_loading(layer) + + # unpacked and signed number. eg: 0 -> 1000b(-8 in int4) -> 0x8 in uint32 + self.assertEqual(layer.w13_qweight[0][0][0].to(torch.uint32), + 0xFEDCBA98) + print(layer.w13_qzeros[0]) + self.assertTrue( + torch.equal( + layer.w13_qzeros[0][0], + torch.Tensor([8., 7., 6., 5., 4., 3., 2., + 1.]).to(torch.float16))) + + # unpacked and signed number. 
eg: 0 -> 1000b(-8 in int4) -> 0x8 in uint32 + self.assertEqual(layer.w2_qweight[0][0][0].to(torch.uint32), + 0xFEDCBA98) + self.assertTrue( + torch.equal( + layer.w2_qzeros[0][0], + torch.Tensor([8., 7., 6., 5., 4., 3., 2., + 1.]).to(torch.float16))) diff --git a/tests/ut/quantization/test_quant_config.py b/tests/ut/quantization/test_quant_config.py index b667767ba79..5e7379fac63 100644 --- a/tests/ut/quantization/test_quant_config.py +++ b/tests/ut/quantization/test_quant_config.py @@ -62,14 +62,25 @@ def test_from_config(self): @patch('torch.npu.is_available') def test_override_quantization_method(self, mock_is_available): - # Test when NPU is available + # Test when quant_method is None mock_is_available.return_value = True result = AscendQuantConfig.override_quantization_method(None, None) self.assertIsNone(result) + # Test when NPU is available + mock_is_available.return_value = True + result = AscendQuantConfig.override_quantization_method({}, None) + self.assertEqual(result, ASCEND_QUANTIZATION_METHOD) + # Test when NPU is not available mock_is_available.return_value = False - result = AscendQuantConfig.override_quantization_method(None, None) + result = AscendQuantConfig.override_quantization_method({}, None) + self.assertIsNone(result) + + # Test when quant_method is specified + hf_quant_cfg = {"quant_method": "awq"} + result = AscendQuantConfig.override_quantization_method( + hf_quant_cfg, None) self.assertIsNone(result) def test_get_quant_method_for_linear(self): diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index 5fe5cde3e80..e690ad79cde 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -10,6 +10,7 @@ from tests.ut.base import TestBase from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, + AWQ_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, AscendDeviceType) @@ -48,9 +49,10 @@ def test_class_variables(self): self.assertEqual(NPUPlatform.device_control_env_var, "ASCEND_RT_VISIBLE_DEVICES") self.assertEqual(NPUPlatform.dispatch_key, "PrivateUse1") - self.assertEqual( - NPUPlatform.supported_quantization, - [ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD]) + self.assertEqual(NPUPlatform.supported_quantization, [ + ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, + AWQ_QUANTIZATION_METHOD + ]) def test_is_sleep_mode_available(self): self.assertTrue(self.platform.is_sleep_mode_available()) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 188e66a5948..c43e628f9e0 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -1514,9 +1514,10 @@ def _mla_preprocess(self, layer_name, hidden_states, kv_cache, num_decode_tokens = attn_metadata.num_decode_tokens num_actual_tokens = attn_metadata.num_actual_tokens if self.fused_qkv_a_proj is not None: - maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight, - dependency=hidden_states, - enabled=self.enable_prefetch) + if hasattr(self.fused_qkv_a_proj, 'weight'): + maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight, + dependency=hidden_states, + enabled=self.enable_prefetch) qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] q_c, kv_no_split = qkv_lora.split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], @@ -1718,11 +1719,11 @@ def forward( o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill # O proj MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 - maybe_npu_prefetch(inputs=self.o_proj.weight, - dependency=o_proj_input, - 
max_size=MAX_O_PROJ_PREFETCH_SIZE, - enabled=self.enable_prefetch) - + if hasattr(self.o_proj, 'weight'): + maybe_npu_prefetch(inputs=self.o_proj.weight, + dependency=o_proj_input, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=self.enable_prefetch) output[...] = self.o_proj(o_proj_input, is_prefill=prefill_preprocess_res is not None)[0] diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 8c395b54fd4..b16c0cbd2b6 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -97,7 +97,8 @@ def __init__( vllm_config = get_current_vllm_config() self.bias = None # quantization with anti_method m4 will generate none-zero norm bias - if vllm_config.quant_config is not None and \ + if vllm_config.quant_config is not None and hasattr(vllm_config.quant_config, "quant_description") and \ + vllm_config.quant_config.quant_description is not None and \ any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 7cc84fc6ae3..57d25c1ea0e 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -33,10 +33,13 @@ # isort: off from vllm_ascend.utils import ( - ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, AscendDeviceType, - enable_sp, get_ascend_device_type, is_vl_model, - prefill_context_parallel_enable, update_aclgraph_sizes, - update_cudagraph_capture_sizes, update_default_aclgraph_sizes) + ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, + AWQ_QUANTIZATION_METHOD, AscendDeviceType, enable_sp, + get_ascend_device_type, is_vl_model, prefill_context_parallel_enable, + update_aclgraph_sizes, update_cudagraph_capture_sizes, + update_default_aclgraph_sizes) + +# isort: on # set custom ops path CUR_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -79,7 +82,8 @@ class NPUPlatform(Platform): dispatch_key: str = "PrivateUse1" supported_quantization: list[str] = [ - ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD + ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, + AWQ_QUANTIZATION_METHOD ] def is_sleep_mode_available(self) -> bool: @@ -103,6 +107,8 @@ def pre_register_and_update(cls, if ASCEND_QUANTIZATION_METHOD not in quant_action.choices: quant_action.choices.append(ASCEND_QUANTIZATION_METHOD) + from vllm_ascend.quantization.awq.awq import \ + AWQQuantConfig # noqa: F401 from vllm_ascend.quantization.compressed_tensors.compressed_tensors import \ AscendCompressedTensorsConfig # noqa: F401 from vllm_ascend.quantization.quant_config import \ diff --git a/vllm_ascend/quantization/awq/__init__.py b/vllm_ascend/quantization/awq/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_ascend/quantization/awq/awq.py b/vllm_ascend/quantization/awq/awq.py new file mode 100644 index 00000000000..76c2ec8ff62 --- /dev/null +++ b/vllm_ascend/quantization/awq/awq.py @@ -0,0 +1,469 @@ +from typing import Any, Callable, List, Optional, Union + +import torch +import torch_npu +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization import ( + QUANTIZATION_METHODS, register_quantization_config) +from vllm.model_executor.layers.quantization.awq import AWQLinearMethod +from 
vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.quant_utils import \ + is_layer_skipped +from vllm.model_executor.utils import set_weight_attrs + +from vllm_ascend.ops.fused_moe.experts_selector import select_experts +from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod +from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod +from vllm_ascend.utils import AWQ_QUANTIZATION_METHOD + + +def remove_quantization_method(): + if AWQ_QUANTIZATION_METHOD in QUANTIZATION_METHODS: + QUANTIZATION_METHODS.remove(AWQ_QUANTIZATION_METHOD) + + +remove_quantization_method() + + +def npu_fused_experts( + hidden_states: torch.Tensor, + w13: torch.Tensor, + w13_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + **kwargs, +): + w13_offset = kwargs.get("w13_offset", None) + w2_offset = kwargs.get("w2_offset", None) + use_wna16 = kwargs.get("use_wna16", False) + + original_shape = hidden_states.shape + original_dtype = hidden_states.dtype + scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32 + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + num_tokens = hidden_states.shape[0] + num_experts = w13.shape[0] + row_idx_len = num_tokens * top_k + row_idx = (torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=topk_weights.device).view( + top_k, -1).permute(1, 0).contiguous()) + hidden_states, expanded_row_idx, expanded_expert_idx = ( + torch_npu.npu_moe_init_routing(hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens)) + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + expanded_expert_idx, num_experts) + expert_tokens = expert_tokens.to(torch.int64) + # gmm1: gate_up_proj + if not use_wna16: + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( + hidden_states) + scale_args13 = { + "scale": [w13_scale.to(scale_dtype)], + "per_token_scale": [pertoken_scale], + } + else: + scale_args13 = { + "antiquant_scale": [w13_scale], + "antiquant_offset": [w13_offset], + } + + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w13], + **scale_args13, + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens, + output_dtype=original_dtype, + )[0] + # act_fn: swiglu + hidden_states = torch_npu.npu_swiglu(hidden_states) + if not use_wna16: + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( + hidden_states) + + scale_args2 = { + "scale": [w2_scale.to(scale_dtype)], + "per_token_scale": [pertoken_scale], + } + else: + scale_args2 = { + "antiquant_scale": [w2_scale], + "antiquant_offset": [w2_offset] + } + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + **scale_args2, + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens, + output_dtype=original_dtype, + )[0] + + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + if len(original_shape) == 3: + final_hidden_states = final_hidden_states.view(original_shape) + return final_hidden_states + + +@register_quantization_config(AWQ_QUANTIZATION_METHOD) +class AWQQuantConfig(QuantizationConfig): + + def __init__( + self, + 
weight_bits: int, + group_size: int, + zero_point: bool, + modules_to_not_convert: list[str] | None = None, + ): + super().__init__() + self.weight_bits = weight_bits + self.group_size = group_size + self.zero_point = zero_point + self.modules_to_not_convert = modules_to_not_convert or [] + + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"AWQ, but got {self.weight_bits} bits.") + self.pack_factor = 32 // self.weight_bits + + def get_name(self) -> str: + return AWQ_QUANTIZATION_METHOD + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "Ascend hardware dose not support \"get_min_capability\" feature.") + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq + # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "AWQQuantConfig": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None) + return cls(weight_bits, group_size, zero_point, modules_to_not_convert) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Union["LinearMethodBase", "QuantizeMethodBase"] | None: + if isinstance(layer, LinearBase): + if is_layer_skipped( + prefix, + self.modules_to_not_convert, + self.packed_modules_mapping, + skip_with_substr=True, + ): + return AscendUnquantizedLinearMethod() + return AWQLinearAscendMethod(self) + elif isinstance(layer, FusedMoE): + if is_layer_skipped( + prefix, + self.modules_to_not_convert, + skip_with_substr=True, + ): + return AscendUnquantizedFusedMoEMethod(layer.moe_config) + return AWQMoEAscendMethod(self) + return None + + +class AWQLinearAscendMethod(AWQLinearMethod): + + def __init__(self, quant_config: AWQQuantConfig): + self.quant_config = quant_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.scales = torch.nn.Parameter(layer.scales.data, + requires_grad=False) + qweight_tmp = torch.zeros_like(layer.qweight.data) + qzeros_tmp = layer.qzeros.data + qzeros_list = [] + shifts = [0, 4, 1, 5, 2, 6, 3, 7] + + for i in range(0, self.quant_config.pack_factor): + shift_num = shifts[i] * 4 + qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF) + qweight_tmp.bitwise_or_(((layer.qweight.data >> shift_num) * + (2**(4 * i))) & (0xF << (4 * i))) + + qweight_tmp.bitwise_xor_(0x88888888) + + qzeros_tmp = torch.cat(qzeros_list, + dim=-1).reshape(qzeros_tmp.shape[0], -1) + qzeros_tmp = -(qzeros_tmp - 8) + qzeros_tmp = qzeros_tmp.to(layer.scales.data.dtype) + + layer.qzeros = torch.nn.Parameter(qzeros_tmp, requires_grad=False) + layer.qweight = torch.nn.Parameter(qweight_tmp, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.qweight + scales = layer.scales + qzeros = layer.qzeros + pack_factor = self.quant_config.pack_factor + out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor, ) + reshaped_x = x.reshape(-1, x.shape[-1]) + + if bias is not None and bias.dtype == torch.bfloat16: 
+ bias = bias.float() + + out = torch_npu.npu_weight_quant_batchmatmul( + reshaped_x, + qweight, + antiquant_scale=scales, + antiquant_offset=qzeros, + antiquant_group_size=self.quant_config.group_size, + bias=bias, + ) + + return out.reshape(out_shape) + + +class AWQMoEAscendMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: AWQQuantConfig): + super().__init__(FusedMoEConfig) + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + extra_weight_attrs.update({ + "is_transposed": + True, + "quant_method": + FusedMoeWeightScaleSupported.GROUP.value, + }) + + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + 2 * intermediate_size_per_partition // + self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + hidden_size // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + num_groups_w13 = hidden_size // self.quant_config.group_size + num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_scales = torch.nn.Parameter( + torch.empty( + num_experts, + num_groups_w13, + intermediate_size_per_partition * 2, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = torch.nn.Parameter( + torch.empty(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. 
+ w13_qzeros = torch.nn.Parameter( + torch.empty( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition // + self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = torch.nn.Parameter( + torch.empty( + num_experts, + num_groups_w2, + hidden_size // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + w13_qweight_tmp = torch.zeros_like(layer.w13_qweight.data) + w2_qweight_tmp = torch.zeros_like(layer.w2_qweight.data) + w13_qzeros_list = [] + w2_qzeros_list = [] + shifts = [0, 4, 1, 5, 2, 6, 3, 7] + for i in range(0, self.quant_config.pack_factor): + shift_num = shifts[i] * 4 + w13_qzeros_list.append( + (layer.w13_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF) + w2_qzeros_list.append( + (layer.w2_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF) + w13_qweight_tmp.bitwise_or_(( + (layer.w13_qweight.data >> shift_num) * (2**(4 * i))) + & (0xF << (4 * i))) + w2_qweight_tmp.bitwise_or_(((layer.w2_qweight.data >> shift_num) * + (2**(4 * i))) + & (0xF << (4 * i))) + + w13_qweight_tmp.bitwise_xor_(0x88888888) + w2_qweight_tmp.bitwise_xor_(0x88888888) + + w13_qzeros_tmp = torch.cat(w13_qzeros_list, + dim=-1).reshape(layer.w13_qzeros.shape[0], + layer.w13_qzeros.shape[1], + -1) + w13_qzeros_tmp = -(w13_qzeros_tmp - 8) + w13_qzeros_tmp = w13_qzeros_tmp.to(layer.w13_scales.data.dtype) + w2_qzeros_tmp = torch.cat(w2_qzeros_list, + dim=-1).reshape(layer.w2_qzeros.shape[0], + layer.w2_qzeros.shape[1], -1) + w2_qzeros_tmp = -(w2_qzeros_tmp - 8) + w2_qzeros_tmp = w2_qzeros_tmp.to(layer.w2_scales.data.dtype) + + layer.register_parameter( + "w13_qzeros", + torch.nn.Parameter(w13_qzeros_tmp, requires_grad=False)) + layer.register_parameter( + "w13_qweight", + torch.nn.Parameter(w13_qweight_tmp, requires_grad=False)) + layer.register_parameter( + "w2_qzeros", torch.nn.Parameter(w2_qzeros_tmp, + requires_grad=False)) + layer.register_parameter( + "w2_qweight", + torch.nn.Parameter(w2_qweight_tmp, requires_grad=False)) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None: + return None + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `AWQMoEAscendMethod` yet.") + + assert activation == "silu", "Only SiLU activation is supported." 
+ + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + global_num_experts=global_num_experts) + + return npu_fused_experts( + hidden_states=x, + w13=layer.w13_qweight, + w13_scale=layer.w13_scales, + w13_offset=layer.w13_qzeros, + w2=layer.w2_qweight, + w2_scale=layer.w2_scales, + w2_offset=layer.w2_qzeros, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=topk_ids.shape[1], + use_wna16=True, + ) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 0a74bcbfdcf..a9acfc9dfe8 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -42,6 +42,7 @@ ASCEND_QUANTIZATION_METHOD = "ascend" COMPRESSED_TENSORS_METHOD = "compressed-tensors" +AWQ_QUANTIZATION_METHOD = "awq" SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"] REGISTERED_ASCEND_OPS = {}
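
For reviewers, a simplified pure-Python sketch of the nibble reordering and sign conversion that AWQLinearAscendMethod.process_weights_after_loading performs (this is an illustration only, not code from the patch; the constants mirror the assertions in tests/ut/quantization/test_awq.py). AWQ packs eight int4 values per int32 in pack order [0 2 4 6 1 3 5 7]; the Ascend path repacks them into natural nibble order, XORs each nibble with 0x8 to map the unsigned int4 range [0, 15] onto signed int4 [-8, 7], and turns zero points into antiquant offsets via -(zp - 8).

    PACK_FACTOR = 8
    SHIFTS = [0, 4, 1, 5, 2, 6, 3, 7]  # nibble position holding logical element i

    def repack_awq_word(qweight_word: int) -> int:
        out = 0
        for i in range(PACK_FACTOR):
            shift = SHIFTS[i] * 4
            nibble = (qweight_word >> shift) & 0xF  # extract logical element i
            out |= nibble << (4 * i)                # place it at nibble position i
        return out ^ 0x88888888                     # unsigned int4 -> signed int4 per nibble

    def unpack_awq_zeros(qzeros_word: int) -> list[int]:
        # zero points become per-group offsets: -(zp - 8)
        return [-(((qzeros_word >> (SHIFTS[i] * 4)) & 0xF) - 8)
                for i in range(PACK_FACTOR)]

    # Values asserted in the unit tests above:
    assert repack_awq_word(0x75316420) == 0xFEDCBA98
    assert unpack_awq_zeros(0x75316420) == [8, 7, 6, 5, 4, 3, 2, 1]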