
Commit c1250f6

zyongye authored and Victor49152 committed

[DeepSeek] Fix DeepSeek V3.2 Rope Embedding (vllm-project#28968)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>

1 parent 14eb7a9 commit c1250f6

File tree

2 files changed: +17, -3 lines

vllm/model_executor/layers/mla.py (5 additions, 1 deletion)
@@ -24,6 +24,7 @@ class MLAModules:
     q_b_proj: torch.nn.Module | None
     q_proj: torch.nn.Module | None
     indexer: torch.nn.Module | None
+    indexer_rotary_emb: torch.nn.Module | None
     is_sparse: bool
     topk_indices_buffer: torch.Tensor | None

@@ -80,6 +81,7 @@ def __init__(
         self.rotary_emb = mla_modules.rotary_emb
         self.o_proj = mla_modules.o_proj
         self.indexer = mla_modules.indexer
+        self.indexer_rope_emb = mla_modules.indexer_rotary_emb
         self.is_sparse = mla_modules.is_sparse

         if self.indexer is not None:

@@ -153,7 +155,9 @@ def forward_native(
         )

         if self.indexer and self.is_sparse:
-            _topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)
+            _topk_indices = self.indexer(
+                hidden_states, q_c, positions, self.indexer_rope_emb
+            )

         attn_out = self.mla_attn(
             q,
vllm/model_executor/models/deepseek_v2.py (12 additions, 2 deletions)
@@ -837,8 +837,8 @@ def forward(
         )

         q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
-        q = torch.cat([q_pe, q_nope], dim=-1)
-        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
+        q = torch.cat([q_pe.squeeze(0), q_nope], dim=-1)
+        k = torch.cat([k_pe.squeeze((0, 2)), k_nope], dim=-1)

         # we only quant q here since k quant is fused with cache insertion
         q = q.view(-1, self.head_dim)
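
The reshaping fix above depends on torch.Tensor.squeeze accepting a tuple of dimensions, available since PyTorch 2.0: each listed dim is dropped only if it has size 1 and is left untouched otherwise. A standalone check of that behavior, using illustrative shapes rather than the model's actual ones:

import torch

# stand-in for a k_pe-like tensor with size-1 dims at positions 0 and 2
k_pe = torch.randn(1, 16, 1, 64)
print(k_pe.squeeze((0, 2)).shape)  # torch.Size([16, 64]); both size-1 dims removed

# squeeze(dim) is a no-op when the listed dim is not size 1
q_pe = torch.randn(4, 16, 64)
print(q_pe.squeeze(0).shape)  # torch.Size([4, 16, 64]); dim 0 has size 4, kept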
@@ -987,6 +987,14 @@ def __init__(
         self.is_v32 = hasattr(config, "index_topk")

         if self.is_v32:
+            self.indexer_rope_emb = get_rope(
+                qk_rope_head_dim,
+                rotary_dim=qk_rope_head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+                rope_scaling=rope_scaling,
+                is_neox_style=True,
+            )
             self.indexer = Indexer(
                 vllm_config,
                 config,

@@ -998,6 +1006,7 @@ def __init__(
                 f"{prefix}.indexer",
             )
         else:
+            self.indexer_rope_emb = None
             self.indexer = None

         mla_modules = MLAModules(

@@ -1015,6 +1024,7 @@ def __init__(
             q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
             q_proj=self.q_proj if self.q_lora_rank is None else None,
             indexer=self.indexer,
+            indexer_rotary_emb=self.indexer_rope_emb,
             is_sparse=self.is_v32,
             topk_indices_buffer=topk_indices_buffer,
         )
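
The indexer now receives its own rotary embedding constructed with is_neox_style=True instead of sharing the attention module's rope. For readers unfamiliar with that flag, the two common RoPE layouts rotate different element pairings within each head; the snippet below is a generic illustration of the difference, not vLLM's implementation, and the helper names are made up for this example:

import torch

def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # GPT-NeoX layout (is_neox_style=True): rotate two contiguous halves
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    # GPT-J layout (is_neox_style=False): rotate adjacent (even, odd) pairs
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

x = torch.arange(8, dtype=torch.float32)
print(rotate_neox(x))  # tensor([-4., -5., -6., -7.,  0.,  1.,  2.,  3.])
print(rotate_gptj(x))  # tensor([-1.,  0., -3.,  2., -5.,  4., -7.,  6.])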
