Skip to content

Commit d8f39f4

Browse files
Fix deepseek issue of attention and fused MOE in the frontend (#3604)
* Fix deepseek issue in attention
* Fix memory usage issue of fused MOE in the frontend
* Add minor fix for non-unify-expert path

Co-authored-by: jianan-gu <jianan.gu@intel.com>
1 parent 8e97ea1 commit d8f39f4

File tree

3 files changed

+109
-69
lines changed

3 files changed

+109
-69
lines changed

intel_extension_for_pytorch/transformers/models/cpu/modules/decoder.py

Lines changed: 69 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -25,44 +25,47 @@ def woq_quant_and_pack(weight, group_size, dtype, lowp_mode, sym_quant_weight):
2525
quantize_per_block,
2626
)
2727

28-
if group_size == -1:
29-
qweight, scales, zero_points = quantize_per_channel(
30-
weight, dtype, None, None, sym_quant_weight
31-
)
32-
else:
33-
qweight, scales, zero_points = quantize_per_block(
34-
weight, dtype, group_size, None, None, sym_quant_weight
35-
)
28+
with torch.no_grad():
29+
if group_size == -1:
30+
qweight, scales, zero_points = quantize_per_channel(
31+
weight, dtype, None, None, sym_quant_weight
32+
)
33+
else:
34+
qweight, scales, zero_points = quantize_per_block(
35+
weight, dtype, group_size, None, None, sym_quant_weight
36+
)
3637

37-
_op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
38-
qweight,
39-
dtype,
40-
[weight.shape[0], weight.shape[1]],
41-
scales,
42-
zero_points,
43-
None, # bias
44-
None, # g_idx
45-
None, # batch size
46-
group_size,
47-
lowp_mode,
48-
WoqActQuantMode.NONE, # act_quant_mode
49-
False, # cache_weight_for_large_batch
50-
)
51-
# qweight: {N/block_n, K/block_k, block_k, block_n}
52-
if (
53-
dtype == WoqWeightDtype.INT8
54-
and lowp_mode == WoqLowpMode.INT8
55-
and _op_context.get_weight().dim() == 4
56-
):
57-
n_blocks, k_blocks, block_k, block_n = _op_context.get_weight().shape
58-
weight_view = qweight.view([n_blocks, block_n, k_blocks, block_k])
59-
compensation = torch.sum(weight_view, dim=-1, keepdim=False, dtype=torch.int32)
60-
compensation = compensation.permute([0, 2, 1]).contiguous()
61-
else:
62-
compensation = None
63-
qweight = _op_context.get_weight()
64-
scale = _op_context.get_scales()
65-
zero_point = _op_context.get_zero_points()
38+
_op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
39+
qweight,
40+
dtype,
41+
[weight.shape[0], weight.shape[1]],
42+
scales,
43+
zero_points,
44+
None, # bias
45+
None, # g_idx
46+
None, # batch size
47+
group_size,
48+
lowp_mode,
49+
WoqActQuantMode.NONE, # act_quant_mode
50+
False, # cache_weight_for_large_batch
51+
)
52+
# qweight: {N/block_n, K/block_k, block_k, block_n}
53+
if (
54+
dtype == WoqWeightDtype.INT8
55+
and lowp_mode == WoqLowpMode.INT8
56+
and _op_context.get_weight().dim() == 4
57+
):
58+
n_blocks, k_blocks, block_k, block_n = _op_context.get_weight().shape
59+
weight_view = qweight.view([n_blocks, block_n, k_blocks, block_k])
60+
compensation = torch.sum(
61+
weight_view, dim=-1, keepdim=False, dtype=torch.int32
62+
)
63+
compensation = compensation.permute([0, 2, 1]).contiguous()
64+
else:
65+
compensation = None
66+
qweight = _op_context.get_weight()
67+
scale = _op_context.get_scales()
68+
zero_point = _op_context.get_zero_points()
6669
return (qweight, scale, zero_point, compensation)
6770

6871

@@ -73,31 +76,34 @@ def woq_pack(plain_qweight, plain_scales, plain_zp, group_size, dtype, lowp_mode
7376
WoqActQuantMode,
7477
)
7578

76-
_op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
77-
plain_qweight,
78-
dtype,
79-
[plain_qweight.shape[0], plain_qweight.shape[1]],
80-
plain_scales,
81-
plain_zp,
82-
None, # bias
83-
None, # g_idx
84-
None, # batch size
85-
group_size,
86-
lowp_mode,
87-
WoqActQuantMode.NONE, # act_quant_mode
88-
False, # cache_weight_for_large_batch
89-
)
90-
if (
91-
dtype == WoqWeightDtype.INT8
92-
and lowp_mode == WoqLowpMode.INT8
93-
and _op_context.get_weight().dim() == 4
94-
):
95-
n_blocks, k_blocks, block_k, block_n = _op_context.get_weight().shape
96-
weight_view = plain_qweight.view([n_blocks, block_n, k_blocks, block_k])
97-
compensation = torch.sum(weight_view, dim=-1, keepdim=False, dtype=torch.int32)
98-
compensation = compensation.permute([0, 2, 1]).contiguous()
99-
else:
100-
compensation = None
79+
with torch.no_grad():
80+
_op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
81+
plain_qweight,
82+
dtype,
83+
[plain_qweight.shape[0], plain_qweight.shape[1]],
84+
plain_scales,
85+
plain_zp,
86+
None, # bias
87+
None, # g_idx
88+
None, # batch size
89+
group_size,
90+
lowp_mode,
91+
WoqActQuantMode.NONE, # act_quant_mode
92+
False, # cache_weight_for_large_batch
93+
)
94+
if (
95+
dtype == WoqWeightDtype.INT8
96+
and lowp_mode == WoqLowpMode.INT8
97+
and _op_context.get_weight().dim() == 4
98+
):
99+
n_blocks, k_blocks, block_k, block_n = _op_context.get_weight().shape
100+
weight_view = plain_qweight.view([n_blocks, block_n, k_blocks, block_k])
101+
compensation = torch.sum(
102+
weight_view, dim=-1, keepdim=False, dtype=torch.int32
103+
)
104+
compensation = compensation.permute([0, 2, 1]).contiguous()
105+
else:
106+
compensation = None
101107
# pack_qweight: {N/block_n, K/block_k, block_k, block_n}
102108
return (
103109
_op_context.get_weight(),
@@ -245,11 +251,7 @@ def __init__(self, module, config, tpp=False, woq=False):
245251
"DeepseekV2ForCausalLM",
246252
"DeepseekV3ForCausalLM",
247253
]:
248-
self.unify_experts = False
249-
if hasattr(self.mlp, "shared_experts"):
250-
if config.n_shared_experts == 1:
251-
self.unify_experts = True
252-
self.unify_shared_expert_id = config.n_routed_experts + 1
254+
if hasattr(self.mlp, "shared_experts") and self.unify_experts:
253255
if (
254256
hasattr(self, "deepseek_lowbit_load")
255257
and self.deepseek_lowbit_load

intel_extension_for_pytorch/transformers/models/reference/modules/attentions.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3106,7 +3106,39 @@ def __init__(self, module, config, sdp_module_ref, distributed=False):
31063106
else:
31073107
self.norm_factor_value = 1 / self.norm_factor
31083108
if self.model_backbone in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
3109-
kv_b_proj_weight = self.kv_b_proj.weight.detach()
3109+
if (
3110+
hasattr(self.kv_b_proj, "_op_context")
3111+
and self.kv_b_proj._op_context is not None
3112+
):
3113+
kv_b_proj_weight = self.kv_b_proj._op_context.to_public(
3114+
self.kv_b_proj._op_context.get_weight()
3115+
)
3116+
if self.kv_b_proj._group_size > 1:
3117+
from intel_extension_for_pytorch.quantization._quantize_utils import (
3118+
dequantize_per_block,
3119+
)
3120+
3121+
kv_b_proj_weight = dequantize_per_block(
3122+
kv_b_proj_weight,
3123+
self.kv_b_proj._op_context.get_scales(),
3124+
self.kv_b_proj._op_context.get_zero_points(),
3125+
self.kv_b_proj.dtype,
3126+
self.kv_b_proj._group_size,
3127+
)
3128+
else:
3129+
from intel_extension_for_pytorch.quantization._quantize_utils import (
3130+
dequantize_per_channel,
3131+
)
3132+
3133+
kv_b_proj_weight = dequantize_per_channel(
3134+
kv_b_proj_weight,
3135+
self.kv_b_proj._op_context.get_scales(),
3136+
self.kv_b_proj._op_context.get_zero_points(),
3137+
self.kv_b_proj.dtype,
3138+
)
3139+
kv_b_proj_weight = kv_b_proj_weight.bfloat16()
3140+
else:
3141+
kv_b_proj_weight = self.kv_b_proj.weight.detach()
31103142
self.kv_b_proj_weight = kv_b_proj_weight.transpose(0, 1).contiguous()
31113143
w_kc, w_vc = kv_b_proj_weight.unflatten(
31123144
0, (-1, self.qk_nope_head_dim + self.v_head_dim)

intel_extension_for_pytorch/transformers/models/reference/modules/decoder.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2433,7 +2433,13 @@ def __init__(self, module, config, distributed=False):
24332433
if hasattr(module.mlp, "experts"): # DeepseekV2MoE
24342434
# shared_experts
24352435
if config.n_shared_experts is not None:
2436-
if self.use_fused_moe or self.use_fused_moe_woq:
2436+
self.unify_experts = False
2437+
if config.n_shared_experts == 1:
2438+
self.unify_experts = True
2439+
self.unify_shared_expert_id = config.n_routed_experts + 1
2440+
if self.unify_experts and (
2441+
self.use_fused_moe or self.use_fused_moe_woq
2442+
):
24372443
from intel_extension_for_pytorch.quantization import (
24382444
WoqWeightDtype,
24392445
WoqLowpMode,

0 commit comments

Comments (0)