@@ -25,44 +25,47 @@ def woq_quant_and_pack(weight, group_size, dtype, lowp_mode, sym_quant_weight):
         quantize_per_block,
     )
 
-    if group_size == -1:
-        qweight, scales, zero_points = quantize_per_channel(
-            weight, dtype, None, None, sym_quant_weight
-        )
-    else:
-        qweight, scales, zero_points = quantize_per_block(
-            weight, dtype, group_size, None, None, sym_quant_weight
-        )
+    with torch.no_grad():
+        if group_size == -1:
+            qweight, scales, zero_points = quantize_per_channel(
+                weight, dtype, None, None, sym_quant_weight
+            )
+        else:
+            qweight, scales, zero_points = quantize_per_block(
+                weight, dtype, group_size, None, None, sym_quant_weight
+            )
 
-    _op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
-        qweight,
-        dtype,
-        [weight.shape[0], weight.shape[1]],
-        scales,
-        zero_points,
-        None,  # bias
-        None,  # g_idx
-        None,  # batch size
-        group_size,
-        lowp_mode,
-        WoqActQuantMode.NONE,  # act_quant_mode
-        False,  # cache_weight_for_large_batch
-    )
-    # qweight: {N/block_n, K/block_k, block_k, block_n}
-    if (
-        dtype == WoqWeightDtype.INT8
-        and lowp_mode == WoqLowpMode.INT8
-        and _op_context.get_weight().dim() == 4
-    ):
-        n_blocks, k_blocks, block_k, block_n = _op_context.get_weight().shape
-        weight_view = qweight.view([n_blocks, block_n, k_blocks, block_k])
-        compensation = torch.sum(weight_view, dim=-1, keepdim=False, dtype=torch.int32)
-        compensation = compensation.permute([0, 2, 1]).contiguous()
-    else:
-        compensation = None
-    qweight = _op_context.get_weight()
-    scale = _op_context.get_scales()
-    zero_point = _op_context.get_zero_points()
+        _op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
+            qweight,
+            dtype,
+            [weight.shape[0], weight.shape[1]],
+            scales,
+            zero_points,
+            None,  # bias
+            None,  # g_idx
+            None,  # batch size
+            group_size,
+            lowp_mode,
+            WoqActQuantMode.NONE,  # act_quant_mode
+            False,  # cache_weight_for_large_batch
+        )
+        # qweight: {N/block_n, K/block_k, block_k, block_n}
+        if (
+            dtype == WoqWeightDtype.INT8
+            and lowp_mode == WoqLowpMode.INT8
+            and _op_context.get_weight().dim() == 4
+        ):
+            n_blocks, k_blocks, block_k, block_n = _op_context.get_weight().shape
+            weight_view = qweight.view([n_blocks, block_n, k_blocks, block_k])
+            compensation = torch.sum(
+                weight_view, dim=-1, keepdim=False, dtype=torch.int32
+            )
+            compensation = compensation.permute([0, 2, 1]).contiguous()
+        else:
+            compensation = None
+        qweight = _op_context.get_weight()
+        scale = _op_context.get_scales()
+        zero_point = _op_context.get_zero_points()
     return (qweight, scale, zero_point, compensation)
 
 
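The substance of this hunk is that the quantize and prepack steps now run under torch.no_grad(); the body is otherwise only reindented. Presumably the guard is there because the incoming weight can still be a trainable nn.Parameter, and without it every float op in this path would be recorded by autograd. A minimal standalone sketch of that effect, not the ipex code path itself:

import torch

w = torch.nn.Parameter(torch.randn(8, 8))  # weight as it typically arrives from a module

# Outside no_grad, any float op on the parameter is tracked by autograd.
scaled = w / w.abs().max()
assert scaled.requires_grad and scaled.grad_fn is not None

# Under no_grad, as in the patched helpers, the same math leaves no graph behind.
with torch.no_grad():
    scaled = w / w.abs().max()
assert not scaled.requires_grad and scaled.grad_fn is None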
@@ -73,31 +76,34 @@ def woq_pack(plain_qweight, plain_scales, plain_zp, group_size, dtype, lowp_mode
         WoqActQuantMode,
     )
 
-    _op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
-        plain_qweight,
-        dtype,
-        [plain_qweight.shape[0], plain_qweight.shape[1]],
-        plain_scales,
-        plain_zp,
-        None,  # bias
-        None,  # g_idx
-        None,  # batch size
-        group_size,
-        lowp_mode,
-        WoqActQuantMode.NONE,  # act_quant_mode
-        False,  # cache_weight_for_large_batch
-    )
-    if (
-        dtype == WoqWeightDtype.INT8
-        and lowp_mode == WoqLowpMode.INT8
-        and _op_context.get_weight().dim() == 4
-    ):
-        n_blocks, k_blocks, block_k, block_n = _op_context.get_weight().shape
-        weight_view = plain_qweight.view([n_blocks, block_n, k_blocks, block_k])
-        compensation = torch.sum(weight_view, dim=-1, keepdim=False, dtype=torch.int32)
-        compensation = compensation.permute([0, 2, 1]).contiguous()
-    else:
-        compensation = None
+    with torch.no_grad():
+        _op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
+            plain_qweight,
+            dtype,
+            [plain_qweight.shape[0], plain_qweight.shape[1]],
+            plain_scales,
+            plain_zp,
+            None,  # bias
+            None,  # g_idx
+            None,  # batch size
+            group_size,
+            lowp_mode,
+            WoqActQuantMode.NONE,  # act_quant_mode
+            False,  # cache_weight_for_large_batch
+        )
+        if (
+            dtype == WoqWeightDtype.INT8
+            and lowp_mode == WoqLowpMode.INT8
+            and _op_context.get_weight().dim() == 4
+        ):
+            n_blocks, k_blocks, block_k, block_n = _op_context.get_weight().shape
+            weight_view = plain_qweight.view([n_blocks, block_n, k_blocks, block_k])
+            compensation = torch.sum(
+                weight_view, dim=-1, keepdim=False, dtype=torch.int32
+            )
+            compensation = compensation.permute([0, 2, 1]).contiguous()
+        else:
+            compensation = None
     # pack_qweight: {N/block_n, K/block_k, block_k, block_n}
     return (
         _op_context.get_weight(),
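The compensation branch is identical in both helpers and is not changed by this patch, only moved under the no_grad block: for the INT8-weight / INT8-compute case it reduces the plain int8 weight over each K block, following the prepacked layout {N/block_n, K/block_k, block_k, block_n}. A small self-contained sketch with made-up sizes (block_n = 32 and block_k = 64 are assumptions here; the real values come from the prepacked weight's shape):

import torch

N, K = 128, 256                      # hypothetical output/input features
block_n, block_k = 32, 64            # assumed block sizes, for illustration only
n_blocks, k_blocks = N // block_n, K // block_k

# Stand-in for the plain int8 qweight of shape [N, K].
qweight = torch.randint(-128, 128, (N, K), dtype=torch.int8)

# Same reduction as in woq_quant_and_pack / woq_pack: sum over each K block.
weight_view = qweight.view([n_blocks, block_n, k_blocks, block_k])
compensation = torch.sum(weight_view, dim=-1, keepdim=False, dtype=torch.int32)
# Reorder to [N/block_n, K/block_k, block_n].
compensation = compensation.permute([0, 2, 1]).contiguous()
assert compensation.shape == (n_blocks, k_blocks, block_n)

Such a per-block weight sum is what an int8 x int8 kernel typically needs to cancel the activation zero point, which is presumably why it is only built when both the weight dtype and lowp_mode are INT8.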
@@ -245,11 +251,7 @@ def __init__(self, module, config, tpp=False, woq=False):
             "DeepseekV2ForCausalLM",
             "DeepseekV3ForCausalLM",
         ]:
-            self.unify_experts = False
-            if hasattr(self.mlp, "shared_experts"):
-                if config.n_shared_experts == 1:
-                    self.unify_experts = True
-                    self.unify_shared_expert_id = config.n_routed_experts + 1
+            if hasattr(self.mlp, "shared_experts") and self.unify_experts:
                 if (
                     hasattr(self, "deepseek_lowbit_load")
                     and self.deepseek_lowbit_load