 from torch.nn.modules import Module
 import torch_npu
 from vllm.config import get_current_vllm_config
-from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
+from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, FusedMoEQuantConfig,
+                                                         int4_w4a16_moe_quant_config,
+                                                         int8_w8a16_moe_quant_config)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                RowParallelLinear, UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import \
@@ -76,7 +78,6 @@ def npu_fused_experts(
     )
     expert_tokens = expert_tokens.to(torch.int64)
     # gmm1: gate_up_proj
-    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
     if not use_wna16:
         hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
     scale_args13 = {
@@ -92,8 +93,6 @@ def npu_fused_experts(
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
         weight=[w13],
-        scale=[w13_scale.to(scale_dtype)],
-        per_token_scale=[pertoken_scale],
         **scale_args13,
         split_item=2,
         group_list_type=0,
@@ -103,7 +102,6 @@ def npu_fused_experts(
     )[0]
     # act_fn: swiglu
     hidden_states = torch_npu.npu_swiglu(hidden_states)
-    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
     if not use_wna16:
         hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
 
@@ -117,15 +115,14 @@ def npu_fused_experts(
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
         weight=[w2],
-        scale=[w2_scale.to(scale_dtype)],
-        per_token_scale=[pertoken_scale],
         **scale_args2,
         split_item=2,
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
         output_dtype=original_dtype,
     )[0]
+
     final_hidden_states = torch_npu.npu_moe_finalize_routing(
         hidden_states,
         skip1=None,
@@ -270,91 +267,86 @@ def __init__(self, quant_config: AWQQuantConfig):
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
-        self.moe = layer
-        layer.quant_config = self.quant_config
-        bit8_pack_factor = self.quant_config.pack_factor
-        group_size = self.quant_config.group_size
-        group_size_div_factor = 1
-
-        # make intermediate_size and hidden_size divisible by group_size
-        # we reduce the group size to ensure that
-        # and we would repeat the loaded_weight later
-        while intermediate_size_per_partition % group_size or \
-                hidden_size % group_size:
-            group_size = group_size // 2
-            group_size_div_factor *= 2
-            assert group_size >= 32
-        layer.group_size = group_size
-        layer.group_size_div_factor = group_size_div_factor
-
-        strategy = FusedMoeWeightScaleSupported.GROUP.value
-        extra_weight_attrs.update({
-            "quant_method": strategy,
-            "is_transposed": False
-        })
-
-        assert 'weight_loader' in extra_weight_attrs
-        weight_loader = extra_weight_attrs['weight_loader']
-        wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
-            layer, weight_loader)
-        extra_weight_attrs['weight_loader'] = wrapped_weight_loader
-
-        # Fused gate_up_proj (column parallel)
-        w13_qweight = torch.nn.Parameter(torch.empty(
-            num_experts,
-            2 * intermediate_size_per_partition,
-            hidden_size // bit8_pack_factor,
-            dtype=torch.uint8),
-                                         requires_grad=False)
+        extra_weight_attrs.update(
+            {
+                "is_transposed": True,
+                "quant_method": FusedMoeWeightScaleSupported.GROUP.value,
+            }
+        )
+
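+        # Fused gate_up_proj (column parallel): AWQ weights are stored transposed
+        # and int32-packed along the output dimension.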
+        w13_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w13_qweight", w13_qweight)
         set_weight_attrs(w13_qweight, extra_weight_attrs)
 
-        # down_proj (row parallel)
-        w2_qweight = torch.nn.Parameter(torch.empty(
-            num_experts,
-            hidden_size,
-            intermediate_size_per_partition // bit8_pack_factor,
-            dtype=torch.uint8),
-                                        requires_grad=False)
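+        # down_proj (row parallel): same transposed, int32-packed layout, packed along hidden_size.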
+        w2_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                hidden_size // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w2_qweight", w2_qweight)
         set_weight_attrs(w2_qweight, extra_weight_attrs)
 
-        w13_scales = torch.nn.Parameter(torch.zeros(
-            num_experts,
-            2 * intermediate_size_per_partition,
-            hidden_size // group_size,
-            dtype=params_dtype),
-                                        requires_grad=False)
+        num_groups_w13 = hidden_size // self.quant_config.group_size
+        num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        w13_scales = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                num_groups_w13,
+                intermediate_size_per_partition * 2,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
         layer.register_parameter("w13_scales", w13_scales)
         set_weight_attrs(w13_scales, extra_weight_attrs)
 
-        w2_scales = torch.nn.Parameter(torch.zeros(
-            num_experts,
-            hidden_size,
-            intermediate_size_per_partition // group_size,
-            dtype=params_dtype),
-                                       requires_grad=False)
+        w2_scales = torch.nn.Parameter(
+            torch.empty(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
+            requires_grad=False,
+        )
         layer.register_parameter("w2_scales", w2_scales)
         set_weight_attrs(w2_scales, extra_weight_attrs)
 
-        if self.quant_config.zero_point:
-            w13_qzeros = torch.nn.Parameter(torch.zeros(
+        # WEIGHT_ZERO_POINT
+        # Allocate 2 zero points for w1 and w3 respectively.
+        w13_qzeros = torch.nn.Parameter(
+            torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition // bit8_pack_factor,
-                hidden_size // group_size,
-                dtype=torch.uint8),
-                                            requires_grad=False)
-            layer.register_parameter("w13_qzeros", w13_qzeros)
-            set_weight_attrs(w13_qzeros, extra_weight_attrs)
-
-            w2_qzeros = torch.nn.Parameter(torch.zeros(
+                num_groups_w13,
+                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qzeros", w13_qzeros)
+        set_weight_attrs(w13_qzeros, extra_weight_attrs)
+
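+        # Zero points for down_proj (w2), int32-packed along hidden_size.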
+        w2_qzeros = torch.nn.Parameter(
+            torch.empty(
                 num_experts,
-                hidden_size // bit8_pack_factor,
-                intermediate_size_per_partition // group_size,
-                dtype=torch.uint8),
-                                           requires_grad=False)
-            layer.register_parameter("w2_qzeros", w2_qzeros)
-            set_weight_attrs(w2_qzeros, extra_weight_attrs)
+                num_groups_w2,
+                hidden_size // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qzeros", w2_qzeros)
+        set_weight_attrs(w2_qzeros, extra_weight_attrs)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         w13_qweight_tmp = torch.zeros_like(layer.w13_qweight.data)
@@ -406,6 +398,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
406398 "w2_qweight" , torch .nn .Parameter (w2_qweight_tmp , requires_grad = False )
407399 )
408400
401+ def get_fused_moe_quant_config (
402+ self , layer : torch .nn .Module
403+ ) -> FusedMoEQuantConfig | None :
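+        # This NPU AWQ MoE method does not expose a FusedMoEQuantConfig, so return None.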
+        return None
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -428,6 +425,7 @@ def apply(
         expert_load_view: Optional[torch.Tensor] = None,
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         assert self.fused_experts is None
 