Commit d1be882

[Quantization][Feature] Add AWQ quantization in vllm-ascend.
Signed-off-by: menogrey <1299267905@qq.com>
1 parent: 5db33d3

2 files changed: 78 additions, 134 deletions

vllm_ascend/ops/fused_moe/fused_moe.py

Lines changed: 1 addition & 55 deletions
@@ -100,7 +100,7 @@ def process_weights_after_loading(self, layer):
                 1, 2).contiguous()
             layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
 
-            self.transpose = False
+            #self.transpose = False
         else:
             w13_data = self._maybe_pad_weight(layer.w13_weight.data)
             layer.w13_weight = torch.nn.Parameter(w13_data,
@@ -402,60 +402,6 @@ def forward_impl(self, hidden_states: torch.Tensor,
 
         return final_hidden_states
 
-    def transpose_weight(self, loaded_weight, expert_data, shard_dim):
-        # Ensure training and inference weight shapes match during RL weight updates
-        if (
-            loaded_weight.shape[1] != expert_data.shape[1] and \
-            loaded_weight.shape[0] != expert_data.shape[0]
-        ):
-            shard_dim = int(not shard_dim)
-            loaded_weight = loaded_weight.transpose(0, 1).contiguous()
-        return loaded_weight, shard_dim
-
-    def _load_w13(self,
-                  expert_data: torch.Tensor,
-                  shard_dim: int,
-                  shard_id: str,
-                  loaded_weight: torch.Tensor,
-                  tp_rank: int,
-                  load_full: bool = False):
-        # Index the loaded weight for tp sharding.
-        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
-        loaded_weight, shard_dim = self.transpose_weight(
-            loaded_weight, expert_data, shard_dim)
-        shard_size = expert_data.shape[shard_dim] // 2
-        if not load_full:
-            loaded_weight = loaded_weight.narrow(shard_dim,
-                                                 shard_size * tp_rank,
-                                                 shard_size)
-        # Narrow parameter and load.
-        # w1, gate_proj: Load into first logical weight of w13.
-        if shard_id == "w1":
-            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
-        # w3, up_proj: Load into second logical weight of w13.
-        else:
-            assert shard_id == "w3"
-            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
-        expert_data.copy_(loaded_weight)
-
-    def _load_w2(self,
-                 expert_data: torch.Tensor,
-                 shard_dim: int,
-                 loaded_weight: torch.Tensor,
-                 tp_rank: int,
-                 load_full: bool = False):
-        # Index the loaded weight for tp sharding.
-        # down_proj: "RowParallel" so tp sharding on input_dim
-        # Narrow parameter and load.
-        loaded_weight, shard_dim = self.transpose_weight(
-            loaded_weight, expert_data, shard_dim)
-        shard_size = expert_data.shape[shard_dim]
-        if not load_full:
-            loaded_weight = loaded_weight.narrow(shard_dim,
-                                                 shard_size * tp_rank,
-                                                 shard_size)
-        # w2, down_proj: Load into only logical weight of w2.
-        expert_data.copy_(loaded_weight)
 
 
 class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
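
For illustration only (not from this diff): the deleted transpose_weight / _load_w13 / _load_w2 helpers narrowed each checkpoint tensor to the current rank's shard before copying it into the fused w13 / w2 expert parameters. A minimal sketch of that narrow-based sharding, with made-up sizes:

import torch

# Sketch only: mimics the removed _load_w13 logic with hypothetical sizes.
tp_size, tp_rank = 2, 1
hidden_size, intermediate_full = 8, 4
intermediate_pp = intermediate_full // tp_size         # per-partition size

# Runtime parameter: w1 (gate_proj) and w3 (up_proj) fused along the output dim.
expert_w13 = torch.zeros(2 * intermediate_pp, hidden_size)

# Checkpoint tensor for w1: full output dim, sharded at load time.
loaded_w1 = torch.randn(intermediate_full, hidden_size)

shard_dim = 0                                           # MergedColumnParallel: shard the output dim
shard_size = expert_w13.shape[shard_dim] // 2           # half of w13 belongs to w1
loaded_w1 = loaded_w1.narrow(shard_dim, shard_size * tp_rank, shard_size)
# w1 fills the first half of w13; a w3 shard would fill the second half.
expert_w13.narrow(shard_dim, 0, shard_size).copy_(loaded_w1)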

vllm_ascend/quantization/awq/awq.py

Lines changed: 77 additions & 79 deletions
@@ -5,10 +5,12 @@
 from torch.nn.modules import Module
 import torch_npu
 from vllm.config import get_current_vllm_config
-from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
+from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, FusedMoEQuantConfig,
+                                                         int4_w4a16_moe_quant_config,
+                                                         int8_w8a16_moe_quant_config,)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                RowParallelLinear, UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import \
@@ -76,7 +78,6 @@ def npu_fused_experts(
     )
     expert_tokens = expert_tokens.to(torch.int64)
     # gmm1: gate_up_proj
-    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
     if not use_wna16:
         hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
     scale_args13 = {
@@ -92,8 +93,6 @@ def npu_fused_experts(
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
         weight=[w13],
-        scale=[w13_scale.to(scale_dtype)],
-        per_token_scale=[pertoken_scale],
         **scale_args13,
         split_item=2,
         group_list_type=0,
@@ -103,7 +102,6 @@ def npu_fused_experts(
     )[0]
     # act_fn: swiglu
     hidden_states = torch_npu.npu_swiglu(hidden_states)
-    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
     if not use_wna16:
         hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
 
@@ -117,15 +115,14 @@ def npu_fused_experts(
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
         weight=[w2],
-        scale=[w2_scale.to(scale_dtype)],
-        per_token_scale=[pertoken_scale],
         **scale_args2,
         split_item=2,
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
         output_dtype=original_dtype,
     )[0]
+
     final_hidden_states = torch_npu.npu_moe_finalize_routing(
         hidden_states,
         skip1=None,
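
For illustration only (not from this diff): with the scale and per_token_scale keywords removed above, both grouped matmuls now take their quantization arguments through the scale_args13 / scale_args2 dicts, and npu_dynamic_quant runs only when use_wna16 is false. A hedged sketch of how such a dict could be assembled; the scale / per_token_scale names come from the deleted lines, while the antiquant_scale key for the weight-only branch is an assumption, not the commit's actual code:

def build_gmm_scale_args(use_wna16, weight_scale, pertoken_scale, scale_dtype):
    # Sketch only: per-scheme kwargs splatted into torch_npu.npu_grouped_matmul.
    if not use_wna16:
        # W8A8 path: activations were quantized with npu_dynamic_quant, so pass
        # the per-channel weight scale plus the per-token activation scale.
        return {
            "scale": [weight_scale.to(scale_dtype)],
            "per_token_scale": [pertoken_scale],
        }
    # Weight-only path (int4/int8 weights, bf16/fp16 activations): assumed to
    # rely on a weight-dequant argument instead of activation scales.
    return {"antiquant_scale": [weight_scale]}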
@@ -270,91 +267,86 @@ def __init__(self, quant_config: AWQQuantConfig):
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
-        self.moe = layer
-        layer.quant_config = self.quant_config
-        bit8_pack_factor = self.quant_config.pack_factor
-        group_size = self.quant_config.group_size
-        group_size_div_factor = 1
-
-        # make intermediate_size and hidden_size divisible by group_size
-        # we reduce the group size to ensure that
-        # and we would repeat the loaded_weight later
-        while intermediate_size_per_partition % group_size or \
-                hidden_size % group_size:
-            group_size = group_size // 2
-            group_size_div_factor *= 2
-            assert group_size >= 32
-        layer.group_size = group_size
-        layer.group_size_div_factor = group_size_div_factor
-
-        strategy = FusedMoeWeightScaleSupported.GROUP.value
-        extra_weight_attrs.update({
-            "quant_method": strategy,
-            "is_transposed": False
-        })
-
-        assert 'weight_loader' in extra_weight_attrs
-        weight_loader = extra_weight_attrs['weight_loader']
-        wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
-            layer, weight_loader)
-        extra_weight_attrs['weight_loader'] = wrapped_weight_loader
-
-        # Fused gate_up_proj (column parallel)
-        w13_qweight = torch.nn.Parameter(torch.empty(
-            num_experts,
-            2 * intermediate_size_per_partition,
-            hidden_size // bit8_pack_factor,
-            dtype=torch.uint8),
-                                         requires_grad=False)
+        extra_weight_attrs.update(
+            {
+                "is_transposed": True,
+                "quant_method": FusedMoeWeightScaleSupported.GROUP.value,
+            }
+        )
+
+        w13_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w13_qweight", w13_qweight)
         set_weight_attrs(w13_qweight, extra_weight_attrs)
 
-        # down_proj (row parallel)
-        w2_qweight = torch.nn.Parameter(torch.empty(
-            num_experts,
-            hidden_size,
-            intermediate_size_per_partition // bit8_pack_factor,
-            dtype=torch.uint8),
-                                        requires_grad=False)
+        w2_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                hidden_size // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w2_qweight", w2_qweight)
         set_weight_attrs(w2_qweight, extra_weight_attrs)
 
-        w13_scales = torch.nn.Parameter(torch.zeros(
-            num_experts,
-            2 * intermediate_size_per_partition,
-            hidden_size // group_size,
-            dtype=params_dtype),
-                                        requires_grad=False)
+        num_groups_w13 = hidden_size // self.quant_config.group_size
+        num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        w13_scales = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                num_groups_w13,
+                intermediate_size_per_partition * 2,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w13_scales", w13_scales)
         set_weight_attrs(w13_scales, extra_weight_attrs)
 
-        w2_scales = torch.nn.Parameter(torch.zeros(
-            num_experts,
-            hidden_size,
-            intermediate_size_per_partition // group_size,
-            dtype=params_dtype),
-                                       requires_grad=False)
+        w2_scales = torch.nn.Parameter(
+            torch.empty(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
+            requires_grad=False,
+        )
         layer.register_parameter("w2_scales", w2_scales)
         set_weight_attrs(w2_scales, extra_weight_attrs)
 
-        if self.quant_config.zero_point:
-            w13_qzeros = torch.nn.Parameter(torch.zeros(
+        # WEIGHT_ZERO_POINT
+        # Allocate 2 zero points for w1 and w3 respectively.
+        w13_qzeros = torch.nn.Parameter(
+            torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition // bit8_pack_factor,
-                hidden_size // group_size,
-                dtype=torch.uint8),
-                                            requires_grad=False)
-            layer.register_parameter("w13_qzeros", w13_qzeros)
-            set_weight_attrs(w13_qzeros, extra_weight_attrs)
-
-            w2_qzeros = torch.nn.Parameter(torch.zeros(
+                num_groups_w13,
+                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qzeros", w13_qzeros)
+        set_weight_attrs(w13_qzeros, extra_weight_attrs)
+
+        w2_qzeros = torch.nn.Parameter(
+            torch.empty(
                 num_experts,
-                hidden_size // bit8_pack_factor,
-                intermediate_size_per_partition // group_size,
-                dtype=torch.uint8),
-                                           requires_grad=False)
-            layer.register_parameter("w2_qzeros", w2_qzeros)
-            set_weight_attrs(w2_qzeros, extra_weight_attrs)
+                num_groups_w2,
+                hidden_size // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qzeros", w2_qzeros)
+        set_weight_attrs(w2_qzeros, extra_weight_attrs)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         w13_qweight_tmp = torch.zeros_like(layer.w13_qweight.data)
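
For illustration only (not from this diff): a small shape sanity check for the AWQ MoE parameters created above, using hypothetical sizes. It assumes 4-bit AWQ values packed into int32, i.e. pack_factor = 32 // 4 = 8; none of the numbers below come from the commit.

# Shape check for the AWQ MoE parameters created above (hypothetical sizes).
num_experts = 64
hidden_size = 2048
intermediate_size_per_partition = 1408   # per tensor-parallel rank
group_size = 128
pack_factor = 8                          # 32 // 4 for 4-bit AWQ packed into int32

num_groups_w13 = hidden_size // group_size                      # 16
num_groups_w2 = intermediate_size_per_partition // group_size   # 11

shapes = {
    # qweight is stored transposed: (experts, input dim, packed output dim)
    "w13_qweight": (num_experts, hidden_size,
                    2 * intermediate_size_per_partition // pack_factor),   # (64, 2048, 352)
    "w2_qweight": (num_experts, intermediate_size_per_partition,
                   hidden_size // pack_factor),                            # (64, 1408, 256)
    "w13_scales": (num_experts, num_groups_w13,
                   2 * intermediate_size_per_partition),                   # (64, 16, 2816)
    "w2_scales": (num_experts, num_groups_w2, hidden_size),                # (64, 11, 2048)
    "w13_qzeros": (num_experts, num_groups_w13,
                   2 * intermediate_size_per_partition // pack_factor),    # (64, 16, 352)
    "w2_qzeros": (num_experts, num_groups_w2,
                  hidden_size // pack_factor),                             # (64, 11, 256)
}
for name, shape in shapes.items():
    print(name, shape)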
@@ -406,6 +398,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             "w2_qweight", torch.nn.Parameter(w2_qweight_tmp, requires_grad=False)
         )
 
+    def get_fused_moe_quant_config(
+        self, layer: torch.nn.Module
+    ) -> FusedMoEQuantConfig | None:
+        return None
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -428,6 +425,7 @@ def apply(
         expert_load_view: Optional[torch.Tensor] = None,
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         assert self.fused_experts is None
 